# Training of Encoder-Decoder network with attention mechanism
## Data preparation
* Create our input files from the Microsoft COCO dataset.
* By default, COCO comes with 5 captions per image.
* We ignore captions that are longer than 50 words. As seen in the Image-Captioning-Project, they are rare.
* The output files are stored under the "output_folder" directory.
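
The length filter described above can be sketched in a few lines. This is an illustration only: `filter_captions` is a hypothetical helper, not the code inside `create_input_files`, and it assumes captions are already tokenized into word lists.

```python
def filter_captions(captions, max_len=50):
    """Keep only the tokenized captions with at most `max_len` words.

    Hypothetical helper for illustration; `create_input_files` in utils.py
    may implement the filter differently.
    """
    return [tokens for tokens in captions if len(tokens) <= max_len]


captions = [
    "a man riding a wave on top of a surfboard".split(),
    ("very " * 60 + "long caption").split(),  # 62 words, exceeds the limit
]
kept = filter_captions(captions, max_len=50)
print(len(kept))  # 1
```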

In [4]:
from utils import create_input_files

# This will create large files on your computer.
create_input_files(dataset='coco',
                   karpathy_json_path='dataset_coco.json',
                   image_folder='../cocoapi-master/images',
                   captions_per_image=5,
                   min_word_freq=5,
                   output_folder='../cocoapi-master/images',
                   max_len=50
                  )


Reading TRAIN images and captions, storing to file...
100%|██████████████████████████████████████████████████████████████████████████| 113287/113287 [48:08<00:00, 39.21it/s]

Reading VAL images and captions, storing to file...
100%|██████████████████████████████████████████████████████████████████████████████| 5000/5000 [03:26<00:00, 32.00it/s]

Reading TEST images and captions, storing to file...
100%|██████████████████████████████████████████████████████████████████████████████| 5000/5000 [02:57<00:00, 28.14it/s]


In [1]:
data_folder = '../cocoapi-master/images'         # the output_folder of created input_files
data_name = 'coco_5_cap_per_img_5_min_word_freq' # the same way we created them
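
The value of `data_name` mirrors the arguments passed to `create_input_files`. Below is a minimal sketch of that naming pattern, inferred from the string above rather than from `utils.py` (`build_data_name` is a hypothetical helper). It also shows how the word-map path used later in this notebook is derived.

```python
import os


def build_data_name(dataset, captions_per_image, min_word_freq):
    # Hypothetical helper: reproduces the pattern seen in `data_name`.
    return f"{dataset}_{captions_per_image}_cap_per_img_{min_word_freq}_min_word_freq"


data_name = build_data_name("coco", captions_per_image=5, min_word_freq=5)
print(data_name)  # coco_5_cap_per_img_5_min_word_freq

# The word map is stored in the same folder as the other generated files.
word_map_file = os.path.join("../cocoapi-master/images", "WORDMAP_" + data_name + ".json")
print(word_map_file)
```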

## Training setup

In [2]:
# first import necessary packages
import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from models import Encoder, DecoderWithAttention
from datasets import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu

### specify the model parameters

In [3]:
emb_dim = 512         # dimension of word embeddings
attention_dim = 512   # dimension of attention linear layers
decoder_dim = 512     # dimension of LSTM decoder
dropout = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # use the GPU if one is available
cudnn.benchmark = True # set to True only if model inputs have a fixed size; otherwise it adds significant computational overhead

### specify the training parameters

In [4]:
start_epoch = 0
epochs = 7                    # Max number of training epochs if early-stopping is not triggered
epochs_since_improvement = 0  # This tracks the number of epochs since the last improvement was made
batch_size = 32
encoder_lr = 1e-4             # This is used with the encoder if fine-tuning is enabled
decoder_lr = 4e-4             # LSTM learning rate
grad_clip = 5.0               # Clip the gradients at this value
alpha_c = 1.0                 # Regularization for attention; implementation from the original paper
best_bleu4 = 0.0              # This tracks the best BLEU-4 score so far
print_freq = 200    
fine_tune_encoder = True     # Do we train the encoder?
checkpoint = 'BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar'             # Path to checkpoint, "None" if none
torch.cuda.get_device_name(0) # show the name of GPU 0

'GeForce GTX 1080 Ti'

## A walk through of what will happen later

In [5]:
# take the next batch of training data from the loader
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
train_loader = torch.utils.data.DataLoader(
    CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
    batch_size=batch_size, shuffle=True, pin_memory=True)
imgs, caps, caplens = next(iter(train_loader))

In [6]:
imgs.size()

torch.Size([32, 3, 256, 256])

In [18]:
caps.size()

torch.Size([32, 52])
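
The second dimension of 52 is consistent with `max_len=50` plus one `<start>` and one `<end>` token. Here is a sketch of how a single caption could be encoded to that fixed length, assuming a toy `word_map` with `<start>`, `<end>`, `<unk>`, and `<pad>` entries (`encode_caption` is a hypothetical helper, not the code in `utils.py`):

```python
def encode_caption(tokens, word_map, max_len=50):
    """Encode a tokenized caption as <start> + words + <end> + padding.

    Returns the padded id list (length max_len + 2) and the true length,
    counting <start> and <end>.
    """
    ids = [word_map["<start>"]]
    ids += [word_map.get(w, word_map["<unk>"]) for w in tokens]
    ids.append(word_map["<end>"])
    caplen = len(ids)
    ids += [word_map["<pad>"]] * (max_len + 2 - caplen)
    return ids, caplen


# Toy vocabulary for illustration only.
word_map = {"<pad>": 0, "a": 1, "dog": 2, "runs": 3, "<unk>": 4, "<start>": 5, "<end>": 6}
ids, caplen = encode_caption(["a", "dog", "runs", "fast"], word_map)
print(len(ids), caplen)  # 52 6
print(ids[:7])           # [5, 1, 2, 3, 4, 6, 0]
```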

In [20]:
# load data into GPU device
imgs = imgs.to(device)
caps = caps.to(device)
caplens = caplens.to(device)

# Pass through CNN encoder
imgs = encoder(imgs)   # (batch_size, encoded_image_size, encoded_image_size, 2048)

# Pass through RNN decoder
scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

In [21]:
scores.size()

torch.Size([32, 16, 9490])

In [22]:
caps_sorted.size()

torch.Size([32, 52])

In [23]:
len(decode_lengths)

32

In [24]:
# pack the padded sequences so that pad positions are excluded from the loss
scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)
targets = caps_sorted[:, 1:]
targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)

In [25]:
loss = criterion(scores.data, targets.data)
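
To see why `scores.data` and `targets.data` line up, it helps to pack a tiny padded batch by hand. The sketch below uses made-up token ids with 0 as the pad id. The lengths are given in descending order, which matches the sorted order the decoder returns.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Three padded target sequences (0 is the pad id), sorted by length.
targets = torch.tensor([[5, 6, 7, 2],
                        [8, 9, 2, 0],
                        [4, 2, 0, 0]])
decode_lengths = [4, 3, 2]

packed = pack_padded_sequence(targets, decode_lengths, batch_first=True)

# `.data` holds only the non-pad entries, ordered timestep by timestep.
print(packed.data)         # tensor([5, 8, 4, 6, 9, 2, 7, 2, 2])
print(packed.batch_sizes)  # tensor([3, 3, 2, 1])
print(packed.data.numel() == sum(decode_lengths))  # True
```

Packing a score tensor of shape `(batch, time, vocab)` with the same lengths yields `(sum(decode_lengths), vocab)`, which pairs row by row with the 1-D packed targets in `nn.CrossEntropyLoss`.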

## Now we have the loss for backpropagation
PyTorch handles the rest.

In [26]:
loss

tensor(2.2221, device='cuda:0', grad_fn=<NllLossBackward>)
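
The remaining work after computing the loss is the usual zero-grad, backward, clip, step sequence. Below is a minimal sketch on a toy linear model, assuming element-wise gradient clamping as the clipping rule. `clip_gradient_sketch` is a hypothetical stand-in: the `clip_gradient` in `utils.py` is not shown here and may differ.

```python
import torch
from torch import nn


def clip_gradient_sketch(optimizer, grad_clip):
    # Clamp every gradient element to [-grad_clip, grad_clip].
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)


torch.manual_seed(0)
model = nn.Linear(4, 3)                      # toy stand-in for the decoder
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4)
criterion = nn.CrossEntropyLoss()

inputs = torch.randn(8, 4)
labels = torch.randint(0, 3, (8,))

loss = criterion(model(inputs), labels)
optimizer.zero_grad()                        # clear old gradients
loss.backward()                              # compute new gradients
clip_gradient_sketch(optimizer, grad_clip=5.0)
optimizer.step()                             # update the weights
print(loss.item())
```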

## DEFINE Training function

In [7]:
def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):
    """
    train_loader: DataLoader for the training data
    encoder: encoder model
    decoder: decoder model
    criterion: loss function
    encoder_optimizer: optimizer for the encoder (if fine-tuning is on)
    decoder_optimizer: optimizer for the LSTM
    epoch: epoch number
    """
    
    decoder.train()    # Training mode
    encoder.train()
    
    # AverageMeter() is imported from utils.py to track the most recent, average, sum, and count of a metric.
    batch_time = AverageMeter()   # forward prop + back prop time
    data_time = AverageMeter()    # data loading time
    losses = AverageMeter()       # loss (per word decoded)
    top5accs = AverageMeter()     # top5 accuracy
    
    start = time.time()
    
    # Batches
    for i, (imgs, caps, caplens) in enumerate(train_loader):
        data_time.update(time.time() - start)
        
        # Move data to GPU
        imgs = imgs.to(device)
        caps = caps.to(device)
        caplens = caplens.to(device)
        
        # Forward prop.
        imgs = encoder(imgs)   # (batch_size, encoded_image_size, encoded_image_size, 2048)
        scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens) 
        # scores: (batch_size, max(decode_lengths), vocab_size)
        # caps_sorted: (batch_size, max_caption_length)
        # decode_lengths: caption lengths minus 1
        # alphas: pixel-wise attention weights, (batch_size, max(decode_lengths), num_pixels)
        # sort_ind: indices that sort the captions by length in descending order
        
        # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
        targets = caps_sorted[:, 1:]
        
        # Remove timesteps that we didn't decode at, or are pads
        # pack_padded_sequence can do this trick
        scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)
        
        
        # Compute the loss
        loss = criterion(scores.data, targets.data)
        
        # Add doubly stochastic attention regularization
        loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
        
        # BackProp.
        decoder_optimizer.zero_grad()
        if encoder_optimizer is not None:
            encoder_optimizer.zero_grad()
        loss.backward()
        
        # Clip gradients
        if grad_clip is not None:
            clip_gradient(decoder_optimizer, grad_clip)
            if encoder_optimizer is not None:
                clip_gradient(encoder_optimizer, grad_clip)
        
        # Update weights
        decoder_optimizer.step()
        if encoder_optimizer is not None:
            encoder_optimizer.step()
            
        # Keep track of metrics
        top5 = accuracy(scores.data, targets.data, 5)
        losses.update(loss.item(), sum(decode_lengths))
        top5accs.update(top5, sum(decode_lengths))
        batch_time.update(time.time() - start)

        start = time.time()

        # Print status
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                          batch_time=batch_time,
                                                                          data_time=data_time, loss=losses,
                                                                          top5=top5accs))
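
The regularization term in `train` encourages the attention that each pixel receives, summed over all decoding steps, to stay close to 1. Here is a small sketch with random attention weights, using the `(batch_size, timesteps, num_pixels)` shape noted in the comments and assuming the weights are a softmax over pixels at each step:

```python
import torch

torch.manual_seed(0)
batch_size, timesteps, num_pixels = 2, 3, 4
alpha_c = 1.0

# Softmax over pixels: each timestep's weights sum to 1 (an assumption here).
alphas = torch.softmax(torch.randn(batch_size, timesteps, num_pixels), dim=2)

# Total attention each pixel receives over the whole caption.
per_pixel_total = alphas.sum(dim=1)          # (batch_size, num_pixels)

# Penalize pixels whose total attention deviates from 1.
penalty = alpha_c * ((1.0 - per_pixel_total) ** 2).mean()
print(per_pixel_total.shape)  # torch.Size([2, 4])
print(penalty.item() >= 0)    # True
```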

## DEFINE validation function

In [8]:
def validate(val_loader, encoder, decoder, criterion):
    """
    val_loader: DataLoader for validation data
    encoder: trained encoder model
    decoder: trained decoder model
    criterion: loss function
    
    return: BLEU-4 score
    """
    decoder.eval() # eval mode (no dropout or batchnorm)
    if encoder is not None:
        encoder.eval()
    batch_time = AverageMeter()
    losses = AverageMeter()
    top5accs = AverageMeter()
    start = time.time()
    
    references = list()  # references (true captions) for calculating BLEU-4 score
    hypotheses = list()  # hypotheses (predictions)
    
    # explicitly disable gradient calculation to avoid CUDA memory error
    with torch.no_grad():
        # Batches
        for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):
            # move to GPU
            imgs = imgs.to(device)
            caps = caps.to(device)
            caplens = caplens.to(device)
            
            # Forward prop.
            if encoder is not None:
                imgs = encoder(imgs)
            scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
            
            # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
            # "caps_sorted" is the "caps" tensor sorted with descending caption_length order. (batch_size, max_caption_length)
            targets = caps_sorted[:, 1:]   # this aligns the targets with the predictions
            
            # Remove timesteps that we didn't decode at, or are pads
            # pack_padded_sequence can do this trick
            scores_copy = scores.clone()
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)
            
            # Calculate loss
            loss = criterion(scores.data, targets.data)
            
            # Add doubly stochastic attention regularization
            loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
            
            # Keep track of metrics
            losses.update(loss.item(), sum(decode_lengths))
            top5 = accuracy(scores.data, targets.data, 5)
            top5accs.update(top5, sum(decode_lengths))
            batch_time.update(time.time() - start)
            
            start = time.time()
            
            if i % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
                                                                                loss=losses, top5=top5accs))
                
            # Store references (true captions), and hypothesis (prediction) for each image
            # If for n images, we have n hypotheses, and references a, b, c... for each image, we need
            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b],...], hypotheses = [hyp1, hyp2, ...]
            
            # References
            allcaps = allcaps[sort_ind] # the captions are sorted in the decoder
            for j in range(allcaps.shape[0]):
                img_caps = allcaps[j].tolist()
                img_captions = list(
                    map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}], 
                        img_caps))   # remove <start> and pads
                references.append(img_captions)
            
            # Hypotheses
            _, preds = torch.max(scores_copy, dim=2)
            preds = preds.tolist()
            temp_preds = list()
            for j,p in enumerate(preds):
                temp_preds.append(preds[j][:decode_lengths[j]])   # remove pads
            preds = temp_preds
            hypotheses.extend(preds)
            
            assert len(references) == len(hypotheses)
            
        # Calculate BLEU-4 scores
        bleu4 = corpus_bleu(references, hypotheses)
        
        print(
            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
                loss=losses,
                top5=top5accs,
                bleu=bleu4))

    return bleu4
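
The nesting that `corpus_bleu` expects is easy to get wrong, so here is a pure-Python sketch of the reference and hypothesis lists built in `validate`, using a toy `word_map` and made-up token ids:

```python
# Toy vocabulary for illustration only.
word_map = {"<pad>": 0, "a": 1, "dog": 2, "runs": 3, "<start>": 4, "<end>": 5}
skip = {word_map["<start>"], word_map["<pad>"]}

# Two reference captions for one image, padded to the same length.
allcaps_for_image = [
    [4, 1, 2, 3, 5, 0],   # <start> a dog runs <end> <pad>
    [4, 1, 2, 5, 0, 0],   # <start> a dog <end> <pad> <pad>
]
references = [[[w for w in cap if w not in skip] for cap in allcaps_for_image]]

# One predicted sequence for that image, cut to its decode length.
pred, decode_length = [1, 2, 3, 5, 5, 5], 4
hypotheses = [pred[:decode_length]]

print(references)  # [[[1, 2, 3, 5], [1, 2, 5]]]
print(hypotheses)  # [[1, 2, 3, 5]]
assert len(references) == len(hypotheses)
```

Lists shaped like this, with one list of references per image and one hypothesis per image, are what `corpus_bleu(references, hypotheses)` takes as input.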

# Start training and validation

## load our word_map

In [9]:
global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
with open(word_map_file, 'r') as j:
    word_map = json.load(j)
print(f"The total size of the vocabulary is :{len(word_map)}")

The total size of the vocabulary is :9490


## Initialize / load checkpoint

In [10]:
if checkpoint is None:
    decoder = DecoderWithAttention(attention_dim=attention_dim,
                                   embed_dim=emb_dim,
                                   decoder_dim=decoder_dim,
                                   vocab_size=len(word_map),
                                   dropout=dropout
                                  )
    decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
    encoder = Encoder()
    encoder.fine_tune(fine_tune_encoder)
    encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                             lr=encoder_lr) if fine_tune_encoder else None
else:
    checkpoint = torch.load(checkpoint)
    start_epoch = checkpoint['epoch'] + 1
    epochs_since_improvement = checkpoint['epochs_since_improvement']
    best_bleu4 = checkpoint['bleu-4']
    decoder = checkpoint['decoder']
    decoder_optimizer = checkpoint['decoder_optimizer']
    encoder = checkpoint['encoder']
    encoder_optimizer = checkpoint['encoder_optimizer']
    if fine_tune_encoder is True and encoder_optimizer is None:
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

## The actual training

In [11]:
# Move to GPU, if available
decoder = decoder.to(device)
encoder = encoder.to(device)

# Loss function
criterion = nn.CrossEntropyLoss().to(device)

# Custom dataloaders
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
train_loader = torch.utils.data.DataLoader(
    CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
    batch_size=batch_size, shuffle=True, pin_memory=True)
val_loader = torch.utils.data.DataLoader(
    CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
    batch_size=batch_size, shuffle=True, pin_memory=True)

In [12]:
# Epochs
for epoch in range(start_epoch, epochs):

    # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
    if epochs_since_improvement == 20:
        break
    if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
        adjust_learning_rate(decoder_optimizer, 0.8)
        if fine_tune_encoder:
            adjust_learning_rate(encoder_optimizer, 0.8)

    # One epoch's training
    train(train_loader=train_loader,
          encoder=encoder,
          decoder=decoder,
          criterion=criterion,
          encoder_optimizer=encoder_optimizer,
          decoder_optimizer=decoder_optimizer,
          epoch=epoch)

    # One epoch's validation
    recent_bleu4 = validate(val_loader=val_loader,
                            encoder=encoder,
                            decoder=decoder,
                            criterion=criterion)

    # Check if there was an improvement
    is_best = recent_bleu4 > best_bleu4
    best_bleu4 = max(recent_bleu4, best_bleu4)
    if not is_best:
        epochs_since_improvement += 1
        print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
    else:
        epochs_since_improvement = 0

    # Save checkpoint
    save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                    decoder_optimizer, recent_bleu4, is_best)

Epoch: [4][0/17702]	Batch Time 6.032 (6.032)	Data Load Time 0.731 (0.731)	Loss 3.0988 (3.0988)	Top-5 Accuracy 78.723 (78.723)
Epoch: [4][200/17702]	Batch Time 1.219 (1.284)	Data Load Time 0.562 (0.540)	Loss 3.0774 (3.0194)	Top-5 Accuracy 77.233 (78.635)
Epoch: [4][400/17702]	Batch Time 1.266 (1.270)	Data Load Time 0.500 (0.537)	Loss 3.0700 (3.0262)	Top-5 Accuracy 77.929 (78.618)
Epoch: [4][600/17702]	Batch Time 1.172 (1.254)	Data Load Time 0.469 (0.528)	Loss 3.0951 (3.0341)	Top-5 Accuracy 78.462 (78.564)
Epoch: [4][800/17702]	Batch Time 1.234 (1.243)	Data Load Time 0.562 (0.518)	Loss 2.6472 (3.0410)	Top-5 Accuracy 82.778 (78.485)
Epoch: [4][1000/17702]	Batch Time 1.125 (1.231)	Data Load Time 0.484 (0.509)	Loss 3.3399 (3.0441)	Top-5 Accuracy 74.419 (78.434)
Epoch: [4][1200/17702]	Batch Time 1.094 (1.220)	Data Load Time 0.422 (0.499)	Loss 3.0285 (3.0448)	Top-5 Accuracy 81.429 (78.457)
Epoch: [4][1400/17702]	Batch Time 1.203 (1.210)	Data Load Time 0.500 (0.490)	Loss 2.9988 (3.0431)	Top-5 

Epoch: [4][12600/17702]	Batch Time 1.062 (1.097)	Data Load Time 0.359 (0.380)	Loss 3.0545 (3.0689)	Top-5 Accuracy 77.008 (78.259)
Epoch: [4][12800/17702]	Batch Time 1.031 (1.096)	Data Load Time 0.328 (0.380)	Loss 3.0507 (3.0686)	Top-5 Accuracy 76.944 (78.263)
Epoch: [4][13000/17702]	Batch Time 1.203 (1.096)	Data Load Time 0.469 (0.379)	Loss 3.0377 (3.0689)	Top-5 Accuracy 78.670 (78.257)
Epoch: [4][13200/17702]	Batch Time 1.156 (1.096)	Data Load Time 0.391 (0.379)	Loss 3.0287 (3.0690)	Top-5 Accuracy 78.713 (78.259)
Epoch: [4][13400/17702]	Batch Time 1.000 (1.095)	Data Load Time 0.328 (0.379)	Loss 3.4252 (3.0692)	Top-5 Accuracy 73.458 (78.258)
Epoch: [4][13600/17702]	Batch Time 1.516 (1.095)	Data Load Time 0.422 (0.378)	Loss 3.4462 (3.0695)	Top-5 Accuracy 77.078 (78.254)
Epoch: [4][13800/17702]	Batch Time 0.953 (1.094)	Data Load Time 0.297 (0.378)	Loss 3.1200 (3.0694)	Top-5 Accuracy 75.931 (78.256)
Epoch: [4][14000/17702]	Batch Time 1.062 (1.094)	Data Load Time 0.391 (0.378)	Loss 3.1111 

Epoch: [5][6600/17702]	Batch Time 1.156 (1.108)	Data Load Time 0.406 (0.391)	Loss 3.0694 (3.0064)	Top-5 Accuracy 76.336 (79.078)
Epoch: [5][6800/17702]	Batch Time 1.047 (1.108)	Data Load Time 0.391 (0.391)	Loss 3.1284 (3.0075)	Top-5 Accuracy 77.557 (79.063)
Epoch: [5][7000/17702]	Batch Time 1.125 (1.107)	Data Load Time 0.359 (0.391)	Loss 2.7860 (3.0073)	Top-5 Accuracy 83.511 (79.067)
Epoch: [5][7200/17702]	Batch Time 1.125 (1.107)	Data Load Time 0.375 (0.390)	Loss 3.1335 (3.0076)	Top-5 Accuracy 77.394 (79.062)
Epoch: [5][7400/17702]	Batch Time 1.219 (1.107)	Data Load Time 0.531 (0.390)	Loss 2.9187 (3.0078)	Top-5 Accuracy 80.172 (79.060)
Epoch: [5][7600/17702]	Batch Time 1.031 (1.107)	Data Load Time 0.344 (0.390)	Loss 3.1439 (3.0081)	Top-5 Accuracy 75.871 (79.059)
Epoch: [5][7800/17702]	Batch Time 1.109 (1.107)	Data Load Time 0.375 (0.390)	Loss 2.9092 (3.0086)	Top-5 Accuracy 81.267 (79.053)
Epoch: [5][8000/17702]	Batch Time 1.297 (1.107)	Data Load Time 0.484 (0.390)	Loss 3.1913 (3.0093)

Epoch: [6][600/17702]	Batch Time 1.078 (1.113)	Data Load Time 0.375 (0.401)	Loss 2.7387 (2.9278)	Top-5 Accuracy 83.161 (80.097)
Epoch: [6][800/17702]	Batch Time 1.109 (1.112)	Data Load Time 0.422 (0.401)	Loss 2.7227 (2.9276)	Top-5 Accuracy 84.426 (80.129)
Epoch: [6][1000/17702]	Batch Time 1.266 (1.112)	Data Load Time 0.484 (0.401)	Loss 2.9361 (2.9319)	Top-5 Accuracy 79.082 (80.070)
Epoch: [6][1200/17702]	Batch Time 1.359 (1.118)	Data Load Time 0.687 (0.406)	Loss 2.9906 (2.9369)	Top-5 Accuracy 79.006 (80.023)
Epoch: [6][1400/17702]	Batch Time 1.094 (1.118)	Data Load Time 0.422 (0.406)	Loss 2.9327 (2.9392)	Top-5 Accuracy 78.223 (79.981)
Epoch: [6][1600/17702]	Batch Time 1.078 (1.118)	Data Load Time 0.375 (0.405)	Loss 2.9290 (2.9403)	Top-5 Accuracy 79.144 (79.958)
Epoch: [6][1800/17702]	Batch Time 1.000 (1.116)	Data Load Time 0.328 (0.403)	Loss 2.7221 (2.9404)	Top-5 Accuracy 84.718 (79.954)
Epoch: [6][2000/17702]	Batch Time 1.125 (1.113)	Data Load Time 0.500 (0.400)	Loss 2.8470 (2.9418)	T

Epoch: [6][13200/17702]	Batch Time 1.094 (1.091)	Data Load Time 0.406 (0.378)	Loss 3.3548 (2.9706)	Top-5 Accuracy 74.359 (79.614)
Epoch: [6][13400/17702]	Batch Time 1.094 (1.091)	Data Load Time 0.312 (0.378)	Loss 2.9509 (2.9706)	Top-5 Accuracy 80.637 (79.614)
Epoch: [6][13600/17702]	Batch Time 1.109 (1.091)	Data Load Time 0.406 (0.378)	Loss 3.1058 (2.9709)	Top-5 Accuracy 76.519 (79.609)
Epoch: [6][13800/17702]	Batch Time 1.078 (1.091)	Data Load Time 0.375 (0.378)	Loss 2.9133 (2.9713)	Top-5 Accuracy 80.541 (79.603)
Epoch: [6][14000/17702]	Batch Time 1.078 (1.091)	Data Load Time 0.375 (0.378)	Loss 3.0685 (2.9718)	Top-5 Accuracy 78.992 (79.598)
Epoch: [6][14200/17702]	Batch Time 1.047 (1.091)	Data Load Time 0.328 (0.378)	Loss 3.0349 (2.9720)	Top-5 Accuracy 75.806 (79.596)
Epoch: [6][14400/17702]	Batch Time 1.078 (1.091)	Data Load Time 0.375 (0.378)	Loss 3.3803 (2.9725)	Top-5 Accuracy 72.576 (79.590)
Epoch: [6][14600/17702]	Batch Time 1.062 (1.091)	Data Load Time 0.359 (0.378)	Loss 2.8036 
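
The bookkeeping around `epochs_since_improvement` in the training loop can be traced without training anything. The sketch below replays the same rules (decay every 8 stagnant epochs, stop at 20) on a made-up list of BLEU-4 scores. Both `simulate_schedule` and the scores are hypothetical.

```python
def simulate_schedule(bleu_scores, lr=4e-4, shrink=0.8, decay_every=8, patience=20):
    """Replay the improvement tracking on a list of per-epoch BLEU-4 scores."""
    best, since_improvement = 0.0, 0
    for epoch, bleu in enumerate(bleu_scores):
        if since_improvement == patience:
            print(f"stopping before epoch {epoch}")
            break
        if since_improvement > 0 and since_improvement % decay_every == 0:
            lr *= shrink
            print(f"epoch {epoch}: learning rate decayed to {lr:.2e}")
        if bleu > best:
            best, since_improvement = bleu, 0
        else:
            since_improvement += 1
    return best, since_improvement, lr


# Made-up scores: two improvements, then nine epochs without progress.
best, stalled, lr = simulate_schedule([0.20, 0.22] + [0.21] * 9)
print(best, stalled)  # 0.22 9
```

With `epochs = 7`, a run started from scratch never accumulates 8 stagnant epochs, so the decay branch stays idle in that configuration.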