In [8]:
import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.tensorboard import SummaryWriter
from models import Encoder, DecoderWithAttention
from datasets import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# Data parameters
data_folder = './data/images/'  # folder with data files saved by create_input_files.py
data_name = 'coco_5_cap_per_img_5_min_word_freq'  # base name shared by data files

# Model parameters
emb_dim = 512  # dimension of word embeddings
attention_dim = 512  # dimension of attention linear layers
decoder_dim = 512  # dimension of decoder RNN
dropout = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead

# Training parameters
start_epoch = 0
epochs = 6 #120  # number of epochs to train for (if early stopping is not triggered)
epochs_since_improvement = 0  # keeps track of number of epochs since there's been an improvement in validation BLEU
batch_size = 160
workers = 1  # for data-loading; right now, only 1 works with h5py
encoder_lr = 1e-4  # learning rate for encoder if fine-tuning
decoder_lr = 4e-4  # learning rate for decoder
grad_clip = 5.  # clip gradients at an absolute value of
alpha_c = 1.  # regularization parameter for 'doubly stochastic attention', as in the paper
best_bleu4 = 0.  # BLEU-4 score right now
print_freq = 5  # print training/validation stats every __ batches
fine_tune_encoder = False  # fine-tune encoder?
checkpoint = None  # path to checkpoint, None if none


def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map
    timestr = time.strftime("%Y%m%d_%H%M%S")
    writer = SummaryWriter('./logs/' + timestr)
    print('Tensorboard is recording into folder: ' + './logs/' + timestr)

    # Read word map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)
    print(f'word map len: {len(word_map)}')
    # Initialize / load checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        # here you connect the Optimizer to the network parameters
        decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        # enable/disable finetuning
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                             lr=encoder_lr) if fine_tune_encoder else None

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    
    # the caption DataLoader method getitem() returns one (image, caption). The images are repeated because for the caption i the i//captions_per image image is taken 
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch, writer=writer)
        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)
        
        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                        decoder_optimizer, recent_bleu4, is_best)
        
        writer.close()
    print('Tensorboard is recording into folder: ' + PATH_to_log_dir + timestr)



def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch, writer):
    """
    Performs one epoch's training.
    :param train_loader: DataLoader for training data
    :param encoder: encoder model
    :param decoder: decoder model
    :param criterion: loss layer
    :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning)
    :param decoder_optimizer: optimizer to update decoder's weights
    :param epoch: epoch number
    """
    # just sets the modules in training mode
    decoder.train()  # train mode (dropout and batchnorm is used)
    encoder.train()

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss (per word decoded)
    top5accs = AverageMeter()  # top5 accuracy

    start = time.time()

    # Batches
    for i, (imgs, caps, caplens) in enumerate(train_loader):
        data_time.update(time.time() - start)
        #print(imgs.shape)
        #print(torch.all(imgs.data[0] == imgs.data[1]))
        #print(caps.shape)
        #print(caps.data[0])
        #print(caplens.data[0])
        
        # Move to GPU, if available
        imgs = imgs.to(device)
        caps = caps.to(device)
        caplens = caplens.to(device)

        # Forward prop.
        imgs = encoder(imgs)
        #print(f'batch image shape after encoding: {imgs.shape}')
        scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

        # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
        targets = caps_sorted[:, 1:]

        # Remove timesteps that we didn't decode at, or are pads (do not compute loss over padded regions)
        # pack_padded_sequence is an easy trick to do this
        '''
        This function is actually used to perform the same dynamic batching (i.e., processing only the effective batch size at each timestep) we performed in our Decoder, 
        when using an RNN or LSTM in PyTorch. In this case, PyTorch handles the dynamic variable-length graphs internally.
        ''' 
        #print(f'scores before padding: {scores.shape}') #(bs, max(decoded_lenghts), vs)
        scores  = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
        #print(f'scores after padding: {scores.shape}') #(bs*sum(decoded_lengths), vs)
        
        
        
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data
        #print(f'targest: {targets.shape}')

        # Calculate loss
        loss = criterion(scores, targets)

        # Add doubly stochastic attention regularization
        # We want the model to attend to every pixel over the course of generating the entire sequence. 
        # Therefore, we try to minimize the difference between 1 and the sum of a pixel's weights across all timesteps.
        # sum is to sum across all timesteps
        loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

        # Back prop.
        # zeroing the grads at each iteration
        decoder_optimizer.zero_grad()
        if encoder_optimizer is not None:
            encoder_optimizer.zero_grad()
        # compute gradients for every (trainable) parameter    
        loss.backward()

        # Clip gradients
        if grad_clip is not None:
            clip_gradient(decoder_optimizer, grad_clip)
            if encoder_optimizer is not None:
                clip_gradient(encoder_optimizer, grad_clip)

        # Update weights
        # step updates the parameters with the gradient
        decoder_optimizer.step()
        if encoder_optimizer is not None:
            encoder_optimizer.step()

        # Keep track of metrics
        top5 = accuracy(scores, targets, 5)
        losses.update(loss.item(), sum(decode_lengths))
        top5accs.update(top5, sum(decode_lengths))
        batch_time.update(time.time() - start)

        start = time.time()

        # Print status
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                          batch_time=batch_time,
                                                                          data_time=data_time, loss=losses,
                                                                          top5=top5accs))

    writer.add_scalar('Train/Loss', loss.item(), epoch)
    writer.flush()

def validate(val_loader, encoder, decoder, criterion):
    """
    Performs one epoch's validation.
    :param val_loader: DataLoader for validation data.
    :param encoder: encoder model
    :param decoder: decoder model
    :param criterion: loss layer
    :return: BLEU-4 score
    """
    decoder.eval()  # eval mode (no dropout or batchnorm)
    if encoder is not None:
        encoder.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top5accs = AverageMeter()

    start = time.time()

    references = list()  # references (true captions) for calculating BLEU-4 score
    hypotheses = list()  # hypotheses (predictions)

    # explicitly disable gradient calculation to avoid CUDA memory error
    # solves the issue #57
    with torch.no_grad():
        # Batches
        for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):

            # Move to device, if available
            imgs = imgs.to(device)
            caps = caps.to(device)
            caplens = caplens.to(device)

            # Forward prop.
            if encoder is not None:
                imgs = encoder(imgs)
            scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

            # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
            targets = caps_sorted[:, 1:]

            # Remove timesteps that we didn't decode at, or are pads
            # pack_padded_sequence is an easy trick to do this
            scores_copy = scores.clone()
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

            # Calculate loss
            loss = criterion(scores, targets)

            # Add doubly stochastic attention regularization
            loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

            # Keep track of metrics
            losses.update(loss.item(), sum(decode_lengths))
            top5 = accuracy(scores, targets, 5)
            top5accs.update(top5, sum(decode_lengths))
            batch_time.update(time.time() - start)

            start = time.time()

            if i % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
                                                                                loss=losses, top5=top5accs))

            # Store references (true captions), and hypothesis (prediction) for each image
            # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...] (one hypotesis per image, k true captions per image)

            # References
            allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
            for j in range(allcaps.shape[0]):
                img_caps = allcaps[j].tolist()
                img_captions = list(
                    map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],
                        img_caps))  # remove <start> and pads
                references.append(img_captions)
            print('references')
            print(references)
            # Hypotheses
            _, preds = torch.max(scores_copy, dim=2)
            preds = preds.tolist()
            temp_preds = list()
            for j, p in enumerate(preds):
                temp_preds.append(preds[j][:decode_lengths[j]])  # remove pads
            preds = temp_preds
            hypotheses.extend(preds[0])
            print('hypotesis')
            print(hypotheses)
            assert len(references) == len(hypotheses)

        # Calculate BLEU-4 scores
        bleu4 = corpus_bleu(references, hypotheses)

        print(
            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
                loss=losses,
                top5=top5accs,
                bleu=bleu4))

    return bleu4

In [9]:
main()

Tensorboard is recording into folder: ./logs/20200402_111016
word map len: 9490
Epoch: [0][0/3541]	Batch Time 4.959 (4.959)	Data Load Time 0.559 (0.559)	Loss 10.1236 (10.1236)	Top-5 Accuracy 0.000 (0.000)
Epoch: [0][5/3541]	Batch Time 1.584 (2.062)	Data Load Time 0.000 (0.093)	Loss 8.1364 (8.9190)	Top-5 Accuracy 29.608 (21.844)
Epoch: [0][10/3541]	Batch Time 1.452 (1.781)	Data Load Time 0.000 (0.051)	Loss 7.4783 (8.3537)	Top-5 Accuracy 32.406 (26.210)
Epoch: [0][15/3541]	Batch Time 1.403 (1.672)	Data Load Time 0.000 (0.035)	Loss 7.0378 (7.9795)	Top-5 Accuracy 33.442 (28.276)
Epoch: [0][20/3541]	Batch Time 1.390 (1.611)	Data Load Time 0.000 (0.027)	Loss 6.7381 (7.7105)	Top-5 Accuracy 34.330 (29.573)
Epoch: [0][25/3541]	Batch Time 1.402 (1.573)	Data Load Time 0.000 (0.022)	Loss 6.5660 (7.5078)	Top-5 Accuracy 33.333 (30.257)
Epoch: [0][30/3541]	Batch Time 1.431 (1.553)	Data Load Time 0.000 (0.018)	Loss 6.3480 (7.3321)	Top-5 Accuracy 35.522 (30.932)
Epoch: [0][35/3541]	Batch Time 1.405 (1.

Epoch: [0][325/3541]	Batch Time 1.431 (1.448)	Data Load Time 0.000 (0.002)	Loss 5.0022 (5.6496)	Top-5 Accuracy 50.756 (43.669)
Epoch: [0][330/3541]	Batch Time 1.415 (1.447)	Data Load Time 0.000 (0.002)	Loss 4.9494 (5.6395)	Top-5 Accuracy 51.054 (43.781)
Epoch: [0][335/3541]	Batch Time 1.421 (1.447)	Data Load Time 0.000 (0.002)	Loss 4.9637 (5.6299)	Top-5 Accuracy 51.961 (43.884)
Epoch: [0][340/3541]	Batch Time 1.411 (1.447)	Data Load Time 0.000 (0.002)	Loss 5.0225 (5.6205)	Top-5 Accuracy 50.513 (43.984)
Epoch: [0][345/3541]	Batch Time 1.418 (1.447)	Data Load Time 0.000 (0.002)	Loss 4.8610 (5.6110)	Top-5 Accuracy 52.533 (44.095)
Epoch: [0][350/3541]	Batch Time 1.402 (1.447)	Data Load Time 0.000 (0.002)	Loss 4.8056 (5.6016)	Top-5 Accuracy 53.333 (44.198)
Epoch: [0][355/3541]	Batch Time 1.442 (1.447)	Data Load Time 0.000 (0.002)	Loss 5.1522 (5.5936)	Top-5 Accuracy 49.172 (44.281)
Epoch: [0][360/3541]	Batch Time 1.412 (1.447)	Data Load Time 0.000 (0.002)	Loss 4.8962 (5.5857)	Top-5 Accuracy 

Epoch: [0][650/3541]	Batch Time 1.486 (1.450)	Data Load Time 0.000 (0.001)	Loss 4.8542 (5.2411)	Top-5 Accuracy 53.142 (48.311)
Epoch: [0][655/3541]	Batch Time 1.480 (1.450)	Data Load Time 0.000 (0.001)	Loss 4.5626 (5.2368)	Top-5 Accuracy 56.472 (48.359)
Epoch: [0][660/3541]	Batch Time 1.411 (1.450)	Data Load Time 0.000 (0.001)	Loss 4.5602 (5.2326)	Top-5 Accuracy 56.578 (48.411)
Epoch: [0][665/3541]	Batch Time 1.449 (1.451)	Data Load Time 0.000 (0.001)	Loss 4.6850 (5.2284)	Top-5 Accuracy 54.886 (48.460)
Epoch: [0][670/3541]	Batch Time 1.470 (1.451)	Data Load Time 0.000 (0.001)	Loss 4.7643 (5.2246)	Top-5 Accuracy 54.841 (48.508)
Epoch: [0][675/3541]	Batch Time 1.466 (1.452)	Data Load Time 0.000 (0.001)	Loss 4.7690 (5.2212)	Top-5 Accuracy 53.947 (48.552)
Epoch: [0][680/3541]	Batch Time 1.403 (1.452)	Data Load Time 0.000 (0.001)	Loss 4.6395 (5.2173)	Top-5 Accuracy 54.741 (48.598)
Epoch: [0][685/3541]	Batch Time 1.438 (1.451)	Data Load Time 0.000 (0.001)	Loss 4.6240 (5.2130)	Top-5 Accuracy 

Epoch: [0][975/3541]	Batch Time 1.427 (1.450)	Data Load Time 0.000 (0.001)	Loss 4.4313 (5.0211)	Top-5 Accuracy 57.413 (50.953)
Epoch: [0][980/3541]	Batch Time 1.613 (1.450)	Data Load Time 0.000 (0.001)	Loss 4.5598 (5.0184)	Top-5 Accuracy 56.816 (50.988)
Epoch: [0][985/3541]	Batch Time 1.410 (1.450)	Data Load Time 0.000 (0.001)	Loss 4.3819 (5.0154)	Top-5 Accuracy 58.122 (51.026)
Epoch: [0][990/3541]	Batch Time 1.486 (1.450)	Data Load Time 0.000 (0.001)	Loss 4.5697 (5.0126)	Top-5 Accuracy 57.428 (51.063)
Epoch: [0][995/3541]	Batch Time 1.499 (1.450)	Data Load Time 0.000 (0.001)	Loss 4.5343 (5.0098)	Top-5 Accuracy 57.221 (51.096)
Epoch: [0][1000/3541]	Batch Time 1.409 (1.449)	Data Load Time 0.000 (0.001)	Loss 4.5826 (5.0073)	Top-5 Accuracy 56.999 (51.128)
Epoch: [0][1005/3541]	Batch Time 1.411 (1.449)	Data Load Time 0.000 (0.001)	Loss 4.5005 (5.0048)	Top-5 Accuracy 56.474 (51.157)
Epoch: [0][1010/3541]	Batch Time 1.409 (1.449)	Data Load Time 0.000 (0.001)	Loss 4.3733 (5.0018)	Top-5 Accura

Epoch: [0][1300/3541]	Batch Time 1.413 (1.448)	Data Load Time 0.000 (0.001)	Loss 4.4612 (4.8688)	Top-5 Accuracy 57.319 (52.851)
Epoch: [0][1305/3541]	Batch Time 1.435 (1.448)	Data Load Time 0.000 (0.001)	Loss 4.2229 (4.8670)	Top-5 Accuracy 60.708 (52.876)
Epoch: [0][1310/3541]	Batch Time 1.382 (1.448)	Data Load Time 0.000 (0.001)	Loss 4.3336 (4.8652)	Top-5 Accuracy 60.976 (52.899)
Epoch: [0][1315/3541]	Batch Time 1.416 (1.448)	Data Load Time 0.000 (0.001)	Loss 4.3955 (4.8631)	Top-5 Accuracy 58.734 (52.922)
Epoch: [0][1320/3541]	Batch Time 1.452 (1.448)	Data Load Time 0.000 (0.001)	Loss 4.4088 (4.8613)	Top-5 Accuracy 58.410 (52.947)
Epoch: [0][1325/3541]	Batch Time 1.391 (1.447)	Data Load Time 0.000 (0.001)	Loss 4.1129 (4.8591)	Top-5 Accuracy 62.358 (52.974)
Epoch: [0][1330/3541]	Batch Time 1.436 (1.448)	Data Load Time 0.000 (0.001)	Loss 4.4303 (4.8570)	Top-5 Accuracy 58.470 (53.000)
Epoch: [0][1335/3541]	Batch Time 1.410 (1.448)	Data Load Time 0.000 (0.001)	Loss 4.2943 (4.8551)	Top-5 A

Epoch: [0][1625/3541]	Batch Time 1.445 (1.447)	Data Load Time 0.000 (0.001)	Loss 4.3002 (4.7566)	Top-5 Accuracy 59.784 (54.269)
Epoch: [0][1630/3541]	Batch Time 1.396 (1.447)	Data Load Time 0.000 (0.001)	Loss 4.2857 (4.7551)	Top-5 Accuracy 60.123 (54.288)
Epoch: [0][1635/3541]	Batch Time 1.399 (1.447)	Data Load Time 0.000 (0.001)	Loss 4.1101 (4.7537)	Top-5 Accuracy 61.669 (54.305)
Epoch: [0][1640/3541]	Batch Time 1.413 (1.447)	Data Load Time 0.000 (0.001)	Loss 4.3050 (4.7521)	Top-5 Accuracy 60.891 (54.326)
Epoch: [0][1645/3541]	Batch Time 1.589 (1.447)	Data Load Time 0.000 (0.001)	Loss 4.1734 (4.7507)	Top-5 Accuracy 62.003 (54.344)
Epoch: [0][1650/3541]	Batch Time 1.435 (1.447)	Data Load Time 0.000 (0.001)	Loss 4.3973 (4.7493)	Top-5 Accuracy 58.547 (54.362)
Epoch: [0][1655/3541]	Batch Time 1.436 (1.447)	Data Load Time 0.000 (0.001)	Loss 4.3154 (4.7479)	Top-5 Accuracy 60.454 (54.380)
Epoch: [0][1660/3541]	Batch Time 1.411 (1.447)	Data Load Time 0.000 (0.001)	Loss 4.2897 (4.7463)	Top-5 A

Epoch: [0][1950/3541]	Batch Time 1.429 (1.448)	Data Load Time 0.000 (0.001)	Loss 4.1048 (4.6689)	Top-5 Accuracy 62.411 (55.392)
Epoch: [0][1955/3541]	Batch Time 1.411 (1.448)	Data Load Time 0.000 (0.001)	Loss 4.2538 (4.6678)	Top-5 Accuracy 61.915 (55.406)
Epoch: [0][1960/3541]	Batch Time 1.482 (1.448)	Data Load Time 0.000 (0.001)	Loss 4.2106 (4.6666)	Top-5 Accuracy 60.524 (55.420)
Epoch: [0][1965/3541]	Batch Time 1.465 (1.448)	Data Load Time 0.000 (0.001)	Loss 4.2510 (4.6654)	Top-5 Accuracy 61.334 (55.436)
Epoch: [0][1970/3541]	Batch Time 1.405 (1.448)	Data Load Time 0.000 (0.001)	Loss 4.1469 (4.6643)	Top-5 Accuracy 61.102 (55.448)
Epoch: [0][1975/3541]	Batch Time 1.394 (1.448)	Data Load Time 0.000 (0.001)	Loss 4.1487 (4.6632)	Top-5 Accuracy 61.822 (55.460)
Epoch: [0][1980/3541]	Batch Time 1.399 (1.448)	Data Load Time 0.000 (0.001)	Loss 4.0911 (4.6619)	Top-5 Accuracy 62.575 (55.475)
Epoch: [0][1985/3541]	Batch Time 1.397 (1.448)	Data Load Time 0.000 (0.001)	Loss 4.1996 (4.6606)	Top-5 A

Epoch: [0][2275/3541]	Batch Time 1.423 (1.448)	Data Load Time 0.000 (0.000)	Loss 4.2210 (4.5974)	Top-5 Accuracy 62.089 (56.298)
Epoch: [0][2280/3541]	Batch Time 1.428 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.0741 (4.5964)	Top-5 Accuracy 61.886 (56.311)
Epoch: [0][2285/3541]	Batch Time 1.419 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.1994 (4.5954)	Top-5 Accuracy 61.028 (56.325)
Epoch: [0][2290/3541]	Batch Time 1.416 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.1276 (4.5945)	Top-5 Accuracy 62.943 (56.337)
Epoch: [0][2295/3541]	Batch Time 1.416 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.2654 (4.5936)	Top-5 Accuracy 60.099 (56.348)
Epoch: [0][2300/3541]	Batch Time 1.399 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.0067 (4.5925)	Top-5 Accuracy 63.361 (56.362)
Epoch: [0][2305/3541]	Batch Time 1.394 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.1886 (4.5915)	Top-5 Accuracy 61.035 (56.374)
Epoch: [0][2310/3541]	Batch Time 1.423 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.1107 (4.5904)	Top-5 A

Epoch: [0][2600/3541]	Batch Time 1.498 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.0183 (4.5384)	Top-5 Accuracy 62.767 (57.054)
Epoch: [0][2605/3541]	Batch Time 1.428 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.0186 (4.5376)	Top-5 Accuracy 64.009 (57.065)
Epoch: [0][2610/3541]	Batch Time 1.493 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.0696 (4.5369)	Top-5 Accuracy 63.077 (57.075)
Epoch: [0][2615/3541]	Batch Time 1.422 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.1302 (4.5361)	Top-5 Accuracy 62.114 (57.085)
Epoch: [0][2620/3541]	Batch Time 1.463 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.2090 (4.5354)	Top-5 Accuracy 61.958 (57.094)
Epoch: [0][2625/3541]	Batch Time 1.388 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.0593 (4.5348)	Top-5 Accuracy 63.205 (57.103)
Epoch: [0][2630/3541]	Batch Time 1.426 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.0829 (4.5340)	Top-5 Accuracy 62.225 (57.113)
Epoch: [0][2635/3541]	Batch Time 1.500 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.0879 (4.5332)	Top-5 A

Epoch: [0][2925/3541]	Batch Time 1.453 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.0784 (4.4885)	Top-5 Accuracy 61.661 (57.685)
Epoch: [0][2930/3541]	Batch Time 1.446 (1.447)	Data Load Time 0.000 (0.000)	Loss 3.9381 (4.4877)	Top-5 Accuracy 65.523 (57.695)
Epoch: [0][2935/3541]	Batch Time 1.453 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.2678 (4.4870)	Top-5 Accuracy 60.670 (57.705)
Epoch: [0][2940/3541]	Batch Time 1.473 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.1516 (4.4862)	Top-5 Accuracy 62.843 (57.714)
Epoch: [0][2945/3541]	Batch Time 1.479 (1.447)	Data Load Time 0.000 (0.000)	Loss 3.9942 (4.4856)	Top-5 Accuracy 63.547 (57.723)
Epoch: [0][2950/3541]	Batch Time 1.574 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.0063 (4.4849)	Top-5 Accuracy 63.791 (57.732)
Epoch: [0][2955/3541]	Batch Time 1.466 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.0170 (4.4842)	Top-5 Accuracy 63.454 (57.741)
Epoch: [0][2960/3541]	Batch Time 1.477 (1.447)	Data Load Time 0.000 (0.000)	Loss 4.0411 (4.4835)	Top-5 A

Epoch: [0][3250/3541]	Batch Time 1.438 (1.446)	Data Load Time 0.000 (0.000)	Loss 4.0283 (4.4444)	Top-5 Accuracy 63.596 (58.249)
Epoch: [0][3255/3541]	Batch Time 1.416 (1.446)	Data Load Time 0.000 (0.000)	Loss 4.1291 (4.4438)	Top-5 Accuracy 62.670 (58.257)
Epoch: [0][3260/3541]	Batch Time 1.492 (1.446)	Data Load Time 0.000 (0.000)	Loss 4.1914 (4.4432)	Top-5 Accuracy 62.228 (58.264)
Epoch: [0][3265/3541]	Batch Time 1.401 (1.446)	Data Load Time 0.000 (0.000)	Loss 4.0384 (4.4426)	Top-5 Accuracy 63.382 (58.272)
Epoch: [0][3270/3541]	Batch Time 1.453 (1.446)	Data Load Time 0.000 (0.000)	Loss 4.0452 (4.4420)	Top-5 Accuracy 62.352 (58.280)
Epoch: [0][3275/3541]	Batch Time 1.438 (1.446)	Data Load Time 0.000 (0.000)	Loss 3.9902 (4.4414)	Top-5 Accuracy 64.244 (58.287)
Epoch: [0][3280/3541]	Batch Time 1.427 (1.446)	Data Load Time 0.000 (0.000)	Loss 4.0363 (4.4408)	Top-5 Accuracy 64.348 (58.295)
Epoch: [0][3285/3541]	Batch Time 1.455 (1.446)	Data Load Time 0.000 (0.000)	Loss 4.0352 (4.4401)	Top-5 A

AssertionError: 