In [1]:
import time
import torch
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from nltk.translate.bleu_score import corpus_bleu
from torch.utils.data import Dataset
import h5py
import json
import os



In [2]:
from model import Encoder,Attention,DecoderWithAttention
from utils import clip_gradient,adjust_learning_rate,accuracy,save_checkpoint,AverageMeter,CaptionDataset

In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# `device` is read by train(), validate() and main() to place tensors and models.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# IPython shell escape: list the GPUs on this host (informational only).
!nvidia-smi

Sun Apr 17 18:19:37 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.91.03    Driver Version: 460.91.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla M40 24GB      Off  | 00000000:04:00.0 Off |                    0 |
| N/A   37C    P8    15W / 250W |      3MiB / 22945MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla M40 24GB      Off  | 00000000:82:00.0 Off |                    0 |
| N/A   39C    P8    17W / 250W |      3MiB / 22945MiB |      0%      Defaul

TRAINING AND VALIDATION

In [4]:
# Inputs produced by create_input_files.py:
data_folder = 'dataset/'  # directory containing the data files
data_name = 'coco_5_cap_per_img_5_min_word_freq'  # base name shared by every data file in data_folder

In [5]:
# Model parameters
emb_dim = 512        # size of the word embeddings
attention_dim = 512  # size of the attention network's linear layers
decoder_dim = 512    # hidden size of the decoder RNN
dropout = 0.5        # dropout rate passed to DecoderWithAttention
# Let cuDNN benchmark and cache the fastest kernels. This only pays off when input sizes stay
# fixed from batch to batch; with varying sizes the repeated benchmarking adds a lot of overhead.
cudnn.benchmark = True

In [6]:
# Training parameters
start_epoch = 0               # first epoch to run; main() overwrites it when resuming
epochs = 120                  # maximum number of epochs (main() may stop earlier via early stopping)
epochs_since_improvement = 0  # epochs since validation BLEU-4 last improved
batch_size = 32
workers = 1                   # DataLoader worker processes; right now only 1 works with h5py
encoder_lr = 1e-4             # encoder learning rate (only used when fine-tuning)
decoder_lr = 4e-4             # decoder learning rate
grad_clip = 5.                # absolute value to clip gradients at (see utils.clip_gradient); None disables clipping
alpha_c = 1.                  # weight of the 'doubly stochastic attention' regularizer, as in the paper
best_bleu4 = 0.               # best validation BLEU-4 seen so far
print_freq = 100              # print training/validation stats every print_freq batches
fine_tune_encoder = False     # fine-tune the encoder?
checkpoint = None             # path to a checkpoint, None if none -- NOTE(review): main() also decides what to load; check it

In [7]:
def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):
    """
    Run one epoch of training.

    :param train_loader: DataLoader yielding (images, captions, caption lengths) batches
    :param encoder: encoder model
    :param decoder: decoder model
    :param criterion: loss layer applied to the packed scores and targets
    :param encoder_optimizer: optimizer for the encoder, or None when the encoder is not updated
    :param decoder_optimizer: optimizer for the decoder
    :param epoch: epoch number, used only in the progress printout
    """
    # Training mode for both models (dropout etc. active).
    # NOTE(review): the encoder is switched to train mode even when it is frozen
    # (encoder_optimizer is None); if it contains batchnorm layers their running statistics
    # will still be updated here -- confirm this is intended.
    decoder.train()
    encoder.train()

    # Optimizers that take part in this epoch: always the decoder, then the encoder if present.
    active_optimizers = [decoder_optimizer]
    if encoder_optimizer is not None:
        active_optimizers.append(encoder_optimizer)

    batch_time = AverageMeter()  # forward + backward time per batch
    data_time = AverageMeter()   # time spent waiting on the loader
    losses = AverageMeter()      # loss per decoded word
    top5accs = AverageMeter()    # top-5 accuracy per decoded word

    tic = time.time()
    for batch_idx, (imgs, caps, caplens) in enumerate(train_loader):
        data_time.update(time.time() - tic)

        imgs, caps, caplens = (t.to(device) for t in (imgs, caps, caplens))

        imgs = encoder(imgs)
        scores, caps_sorted, decode_lengths, alphas, _ = decoder(imgs, caps, caplens)

        # Decoding starts after <start>, so the targets are the caption shifted by one token.
        # Packing with decode_lengths drops padded / undecoded timesteps from both tensors.
        scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)[0]
        targets = pack_padded_sequence(caps_sorted[:, 1:], decode_lengths, batch_first=True)[0]

        # Cross-entropy plus the doubly stochastic attention penalty.
        loss = criterion(scores, targets)
        loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

        for opt in active_optimizers:
            opt.zero_grad()
        loss.backward()
        if grad_clip is not None:
            for opt in active_optimizers:
                clip_gradient(opt, grad_clip)
        for opt in active_optimizers:
            opt.step()

        n_words = sum(decode_lengths)
        top5 = accuracy(scores, targets, 5)
        losses.update(loss.item(), n_words)
        top5accs.update(top5, n_words)
        batch_time.update(time.time() - tic)

        tic = time.time()

        if batch_idx % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, batch_idx, len(train_loader),
                                                                          batch_time=batch_time,
                                                                          data_time=data_time, loss=losses,
                                                                          top5=top5accs))

In [8]:
def validate(val_loader, encoder, decoder, criterion):
    """
    Performs one epoch's validation.

    Runs the decoder on the ground-truth validation captions, tracks loss and top-5 accuracy,
    and scores the per-timestep argmax predictions against all reference captions of each
    image with corpus-level BLEU-4. Reads the global ``word_map`` (set by main()) to strip
    <start> and <pad> tokens from the references.

    :param val_loader: DataLoader for validation data; each batch is
        (images, captions, caption lengths, all reference captions for each image)
    :param encoder: encoder model, or None to pass the images to the decoder unchanged
    :param decoder: decoder model
    :param criterion: loss layer
    :return: BLEU-4 score
    """
    decoder.eval()  # eval mode (no dropout or batchnorm)
    if encoder is not None:
        encoder.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top5accs = AverageMeter()

    # Token ids removed from the reference captions; built once instead of once per caption.
    ignored_ids = {word_map['<start>'], word_map['<pad>']}

    references = list()  # references (true captions) for calculating BLEU-4 score
    hypotheses = list()  # hypotheses (predictions)

    start = time.time()

    # explicitly disable gradient calculation to avoid CUDA memory error
    # solves the issue #57
    with torch.no_grad():
        # Batches
        for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):

            # Move to device, if available
            imgs = imgs.to(device)
            caps = caps.to(device)
            caplens = caplens.to(device)

            # Forward prop.
            if encoder is not None:
                imgs = encoder(imgs)
            scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

            # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
            targets = caps_sorted[:, 1:]

            # Keep the padded, batch-first scores for greedy decoding below. pack_padded_sequence
            # builds a new tensor and does not modify its input, so no defensive clone is needed.
            padded_scores = scores

            # Remove timesteps that we didn't decode at, or are pads
            # pack_padded_sequence is an easy trick to do this
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)[0]
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)[0]

            # Calculate loss
            loss = criterion(scores, targets)

            # Add doubly stochastic attention regularization
            loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

            # Keep track of metrics
            losses.update(loss.item(), sum(decode_lengths))
            top5 = accuracy(scores, targets, 5)
            top5accs.update(top5, sum(decode_lengths))
            batch_time.update(time.time() - start)

            start = time.time()

            if i % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
                                                                                loss=losses, top5=top5accs))

            # Store references (true captions), and hypothesis (prediction) for each image
            # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]

            # References -- reorder to match the decoder's sorting of the batch. allcaps comes
            # straight from the loader (never moved to `device`), while sort_ind comes from the
            # decoder (presumably on `device`); recent PyTorch refuses to index a CPU tensor with
            # CUDA indices, so bring the indices onto allcaps' device first.
            allcaps = allcaps[sort_ind.to(allcaps.device)]
            for img_caps in allcaps.tolist():
                # remove <start> and pads from every reference caption of this image
                references.append([[w for w in c if w not in ignored_ids] for c in img_caps])

            # Hypotheses: argmax word at each timestep, truncated to the decode length (removes pads)
            _, preds = torch.max(padded_scores, dim=2)
            hypotheses.extend(p[:length] for p, length in zip(preds.tolist(), decode_lengths))

            assert len(references) == len(hypotheses)

        # Calculate BLEU-4 scores
        bleu4 = corpus_bleu(references, hypotheses)

        print(
            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
                loss=losses,
                top5=top5accs,
                bleu=bleu4))

    return bleu4

In [9]:
def main():
    """
    Training and validation.

    Loads the word map, builds fresh models or resumes from a checkpoint, then alternates one
    epoch of training with one epoch of validation until ``epochs`` is reached or validation
    BLEU-4 has not improved for 20 consecutive epochs. After every epoch the full state is
    written to ``checkpoint_<data_name>.pth.tar``; when the epoch is the best so far, the same
    state is also written to ``BEST_checkpoint_<data_name>.pth.tar``.

    Which checkpoint is resumed:
      * if the global ``checkpoint`` is a path, that file is loaded;
      * if it is None and ``checkpoint_<data_name>.pth.tar`` exists in the working directory,
        that file is loaded;
      * otherwise training starts from scratch.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Read word map (also read, as a global, by validate())
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Rolling checkpoint rewritten after every epoch; also the default file to resume from.
    filename = 'checkpoint_' + data_name + '.pth.tar'

    # Honour the configured `checkpoint` path. Without one, resume from the rolling checkpoint only
    # if it already exists, so a fresh run starts from scratch instead of crashing in torch.load.
    # The resolved path lives in a local so the global config value is not clobbered on re-runs.
    checkpoint_path = checkpoint
    if checkpoint_path is None and os.path.isfile(filename):
        checkpoint_path = filename

    # Initialize / load checkpoint
    if checkpoint_path is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                             lr=encoder_lr) if fine_tune_encoder else None

    else:
        # The checkpoint pickles whole modules and optimizers: only load files you trust.
        # map_location lets a checkpoint saved on a GPU be resumed on a CPU-only machine.
        resumed = torch.load(checkpoint_path, map_location=device)
        start_epoch = resumed['epoch'] + 1
        epochs_since_improvement = resumed['epochs_since_improvement']
        # 'bleu-4' is the score of the epoch that was saved, which can be below the best score so
        # far; treating it as best_bleu4 would let a worse model overwrite BEST_<filename>.
        # Checkpoints written by this function also record the best score; older ones do not.
        best_bleu4 = resumed.get('best_bleu-4', resumed['bleu-4'])
        decoder = resumed['decoder']
        decoder_optimizer = resumed['decoder_optimizer']
        encoder = resumed['encoder']
        encoder_optimizer = resumed['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    # Validation order does not change corpus BLEU or the averaged metrics, so keep it fixed:
    # per-batch validation logs then stay comparable across epochs.
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate every 8 consecutive epochs without improvement, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            # Decay the encoder optimizer whenever one is in use (it can come from the checkpoint
            # even if fine_tune_encoder is False), matching how train() decides to step it.
            if encoder_optimizer is not None:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint. 'bleu-4' keeps its original meaning (this epoch's score);
        # 'best_bleu-4' lets a resumed run compare against the true best.
        state = {'epoch': epoch,
                 'epochs_since_improvement': epochs_since_improvement,
                 'bleu-4': recent_bleu4,
                 'best_bleu-4': best_bleu4,
                 'encoder': encoder,
                 'decoder': decoder,
                 'encoder_optimizer': encoder_optimizer,
                 'decoder_optimizer': decoder_optimizer}
        torch.save(state, filename)
        # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint
        if is_best:
            torch.save(state, 'BEST_' + filename)

In [None]:
# Long-running: trains for up to `epochs` epochs (or until early stopping in main()),
# validating and saving a checkpoint after each one.
main()

Epoch: [28][0/17702]	Batch Time 2.455 (2.455)	Data Load Time 0.965 (0.965)	Loss 3.4412 (3.4412)	Top-5 Accuracy 71.316 (71.316)
Epoch: [28][100/17702]	Batch Time 0.483 (0.476)	Data Load Time 0.001 (0.014)	Loss 3.6596 (3.3973)	Top-5 Accuracy 67.989 (71.704)
Epoch: [28][200/17702]	Batch Time 0.432 (0.468)	Data Load Time 0.001 (0.009)	Loss 3.2963 (3.3828)	Top-5 Accuracy 77.067 (71.959)
Epoch: [28][300/17702]	Batch Time 0.459 (0.467)	Data Load Time 0.026 (0.008)	Loss 3.1688 (3.3789)	Top-5 Accuracy 75.798 (71.979)
Epoch: [28][400/17702]	Batch Time 0.469 (0.464)	Data Load Time 0.001 (0.007)	Loss 3.3081 (3.3772)	Top-5 Accuracy 73.278 (71.949)
Epoch: [28][500/17702]	Batch Time 0.461 (0.464)	Data Load Time 0.001 (0.006)	Loss 3.4517 (3.3759)	Top-5 Accuracy 69.672 (71.923)
Epoch: [28][600/17702]	Batch Time 0.439 (0.462)	Data Load Time 0.001 (0.006)	Loss 3.2826 (3.3765)	Top-5 Accuracy 73.879 (71.939)
Epoch: [28][700/17702]	Batch Time 0.453 (0.462)	Data Load Time 0.001 (0.006)	Loss 3.5785 (3.3771)	T

Epoch: [28][6400/17702]	Batch Time 0.412 (0.465)	Data Load Time 0.001 (0.005)	Loss 3.7172 (3.3933)	Top-5 Accuracy 68.232 (71.763)
Epoch: [28][6500/17702]	Batch Time 0.438 (0.464)	Data Load Time 0.001 (0.005)	Loss 3.3531 (3.3937)	Top-5 Accuracy 71.467 (71.755)
Epoch: [28][6600/17702]	Batch Time 0.454 (0.465)	Data Load Time 0.001 (0.005)	Loss 3.5464 (3.3939)	Top-5 Accuracy 69.565 (71.753)
Epoch: [28][6700/17702]	Batch Time 0.591 (0.465)	Data Load Time 0.001 (0.005)	Loss 3.4613 (3.3944)	Top-5 Accuracy 72.368 (71.752)
Epoch: [28][6800/17702]	Batch Time 0.412 (0.465)	Data Load Time 0.001 (0.005)	Loss 3.3663 (3.3947)	Top-5 Accuracy 70.492 (71.744)
Epoch: [28][6900/17702]	Batch Time 0.473 (0.465)	Data Load Time 0.023 (0.005)	Loss 3.6332 (3.3946)	Top-5 Accuracy 70.670 (71.745)
Epoch: [28][7000/17702]	Batch Time 0.440 (0.465)	Data Load Time 0.014 (0.005)	Loss 3.3661 (3.3948)	Top-5 Accuracy 72.141 (71.746)
Epoch: [28][7100/17702]	Batch Time 0.427 (0.465)	Data Load Time 0.014 (0.005)	Loss 3.5393 

Epoch: [28][12700/17702]	Batch Time 0.477 (0.466)	Data Load Time 0.001 (0.005)	Loss 3.4801 (3.4080)	Top-5 Accuracy 71.875 (71.605)
Epoch: [28][12800/17702]	Batch Time 0.502 (0.466)	Data Load Time 0.001 (0.005)	Loss 3.3503 (3.4080)	Top-5 Accuracy 71.979 (71.605)
Epoch: [28][12900/17702]	Batch Time 0.424 (0.466)	Data Load Time 0.001 (0.005)	Loss 3.3928 (3.4081)	Top-5 Accuracy 70.604 (71.603)
Epoch: [28][13000/17702]	Batch Time 0.609 (0.466)	Data Load Time 0.013 (0.005)	Loss 3.8377 (3.4083)	Top-5 Accuracy 64.894 (71.602)
Epoch: [28][13100/17702]	Batch Time 0.567 (0.466)	Data Load Time 0.018 (0.005)	Loss 3.5959 (3.4085)	Top-5 Accuracy 68.564 (71.599)
Epoch: [28][13200/17702]	Batch Time 0.467 (0.466)	Data Load Time 0.001 (0.005)	Loss 3.3311 (3.4087)	Top-5 Accuracy 70.115 (71.595)
Epoch: [28][13300/17702]	Batch Time 0.432 (0.466)	Data Load Time 0.001 (0.005)	Loss 3.8168 (3.4089)	Top-5 Accuracy 61.408 (71.593)
Epoch: [28][13400/17702]	Batch Time 0.485 (0.466)	Data Load Time 0.001 (0.005)	Loss

Epoch: [29][500/17702]	Batch Time 0.466 (0.470)	Data Load Time 0.022 (0.007)	Loss 3.3475 (3.3554)	Top-5 Accuracy 71.925 (72.142)
Epoch: [29][600/17702]	Batch Time 0.699 (0.469)	Data Load Time 0.014 (0.006)	Loss 3.4817 (3.3533)	Top-5 Accuracy 73.250 (72.209)
Epoch: [29][700/17702]	Batch Time 0.441 (0.470)	Data Load Time 0.014 (0.006)	Loss 3.2073 (3.3600)	Top-5 Accuracy 72.581 (72.114)
Epoch: [29][800/17702]	Batch Time 0.520 (0.470)	Data Load Time 0.015 (0.006)	Loss 3.3677 (3.3596)	Top-5 Accuracy 71.676 (72.149)
Epoch: [29][900/17702]	Batch Time 0.407 (0.469)	Data Load Time 0.001 (0.006)	Loss 3.3352 (3.3601)	Top-5 Accuracy 70.933 (72.116)
Epoch: [29][1000/17702]	Batch Time 0.488 (0.470)	Data Load Time 0.001 (0.006)	Loss 3.3963 (3.3608)	Top-5 Accuracy 73.280 (72.111)
Epoch: [29][1100/17702]	Batch Time 0.429 (0.468)	Data Load Time 0.025 (0.005)	Loss 3.2554 (3.3622)	Top-5 Accuracy 74.521 (72.085)
Epoch: [29][1200/17702]	Batch Time 0.491 (0.468)	Data Load Time 0.001 (0.005)	Loss 3.3243 (3.36

Epoch: [29][6900/17702]	Batch Time 0.400 (0.468)	Data Load Time 0.001 (0.005)	Loss 3.3572 (3.3830)	Top-5 Accuracy 72.500 (71.833)
Epoch: [29][7000/17702]	Batch Time 0.438 (0.467)	Data Load Time 0.001 (0.005)	Loss 3.5519 (3.3831)	Top-5 Accuracy 70.170 (71.831)
Epoch: [29][7100/17702]	Batch Time 0.546 (0.467)	Data Load Time 0.001 (0.005)	Loss 3.3567 (3.3833)	Top-5 Accuracy 72.358 (71.830)
Epoch: [29][7200/17702]	Batch Time 0.419 (0.467)	Data Load Time 0.008 (0.005)	Loss 3.4106 (3.3837)	Top-5 Accuracy 73.425 (71.824)
Epoch: [29][7300/17702]	Batch Time 0.441 (0.468)	Data Load Time 0.001 (0.005)	Loss 3.5039 (3.3840)	Top-5 Accuracy 71.698 (71.823)
Epoch: [29][7400/17702]	Batch Time 0.515 (0.468)	Data Load Time 0.013 (0.005)	Loss 3.5752 (3.3842)	Top-5 Accuracy 69.663 (71.821)
Epoch: [29][7500/17702]	Batch Time 0.420 (0.468)	Data Load Time 0.001 (0.005)	Loss 3.3751 (3.3845)	Top-5 Accuracy 73.407 (71.820)
Epoch: [29][7600/17702]	Batch Time 0.578 (0.468)	Data Load Time 0.001 (0.005)	Loss 3.2489 

Epoch: [29][13400/17702]	Batch Time 0.450 (0.468)	Data Load Time 0.001 (0.005)	Loss 3.6869 (3.3970)	Top-5 Accuracy 68.824 (71.700)
Epoch: [29][13500/17702]	Batch Time 0.484 (0.468)	Data Load Time 0.001 (0.005)	Loss 3.1979 (3.3970)	Top-5 Accuracy 75.130 (71.702)
Epoch: [29][13600/17702]	Batch Time 0.415 (0.468)	Data Load Time 0.001 (0.005)	Loss 3.4920 (3.3972)	Top-5 Accuracy 70.868 (71.700)
Epoch: [29][13700/17702]	Batch Time 0.560 (0.468)	Data Load Time 0.014 (0.005)	Loss 3.6799 (3.3973)	Top-5 Accuracy 71.348 (71.700)
Epoch: [29][13800/17702]	Batch Time 0.406 (0.468)	Data Load Time 0.001 (0.005)	Loss 3.3975 (3.3974)	Top-5 Accuracy 72.394 (71.699)
Epoch: [29][13900/17702]	Batch Time 0.489 (0.468)	Data Load Time 0.001 (0.005)	Loss 3.7040 (3.3975)	Top-5 Accuracy 67.095 (71.696)
Epoch: [29][14000/17702]	Batch Time 0.489 (0.468)	Data Load Time 0.001 (0.005)	Loss 3.5084 (3.3976)	Top-5 Accuracy 68.338 (71.695)
Epoch: [29][14100/17702]	Batch Time 0.425 (0.468)	Data Load Time 0.013 (0.005)	Loss

Epoch: [30][1200/17702]	Batch Time 0.429 (0.473)	Data Load Time 0.001 (0.006)	Loss 3.1872 (3.3655)	Top-5 Accuracy 71.468 (72.004)
Epoch: [30][1300/17702]	Batch Time 0.508 (0.472)	Data Load Time 0.001 (0.006)	Loss 3.5292 (3.3647)	Top-5 Accuracy 69.210 (72.020)
Epoch: [30][1400/17702]	Batch Time 0.570 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.3737 (3.3654)	Top-5 Accuracy 70.604 (72.019)
Epoch: [30][1500/17702]	Batch Time 0.415 (0.472)	Data Load Time 0.007 (0.005)	Loss 3.3214 (3.3629)	Top-5 Accuracy 74.432 (72.072)
Epoch: [30][1600/17702]	Batch Time 0.476 (0.472)	Data Load Time 0.014 (0.005)	Loss 3.4909 (3.3628)	Top-5 Accuracy 68.588 (72.086)
Epoch: [30][1700/17702]	Batch Time 0.511 (0.471)	Data Load Time 0.001 (0.005)	Loss 3.3533 (3.3635)	Top-5 Accuracy 68.994 (72.076)
Epoch: [30][1800/17702]	Batch Time 0.649 (0.471)	Data Load Time 0.018 (0.005)	Loss 3.4844 (3.3645)	Top-5 Accuracy 72.928 (72.060)
Epoch: [30][1900/17702]	Batch Time 0.400 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.3639 

Epoch: [30][7600/17702]	Batch Time 0.461 (0.471)	Data Load Time 0.001 (0.005)	Loss 3.3378 (3.3796)	Top-5 Accuracy 72.647 (71.870)
Epoch: [30][7700/17702]	Batch Time 0.420 (0.471)	Data Load Time 0.001 (0.005)	Loss 3.2899 (3.3795)	Top-5 Accuracy 73.278 (71.875)
Epoch: [30][7800/17702]	Batch Time 0.474 (0.471)	Data Load Time 0.001 (0.005)	Loss 3.3092 (3.3798)	Top-5 Accuracy 74.413 (71.871)
Epoch: [30][7900/17702]	Batch Time 0.556 (0.471)	Data Load Time 0.010 (0.005)	Loss 3.5068 (3.3801)	Top-5 Accuracy 70.732 (71.872)
Epoch: [30][8000/17702]	Batch Time 0.418 (0.471)	Data Load Time 0.001 (0.005)	Loss 3.5846 (3.3802)	Top-5 Accuracy 71.429 (71.869)
Epoch: [30][8100/17702]	Batch Time 0.457 (0.470)	Data Load Time 0.001 (0.005)	Loss 3.3096 (3.3804)	Top-5 Accuracy 72.394 (71.868)
Epoch: [30][8200/17702]	Batch Time 0.483 (0.470)	Data Load Time 0.001 (0.005)	Loss 3.2419 (3.3807)	Top-5 Accuracy 74.859 (71.865)
Epoch: [30][8300/17702]	Batch Time 0.536 (0.470)	Data Load Time 0.014 (0.005)	Loss 3.4761 

Epoch: [30][13900/17702]	Batch Time 0.411 (0.470)	Data Load Time 0.010 (0.005)	Loss 3.6831 (3.3918)	Top-5 Accuracy 67.908 (71.739)
Epoch: [30][14000/17702]	Batch Time 0.500 (0.470)	Data Load Time 0.013 (0.005)	Loss 3.5739 (3.3919)	Top-5 Accuracy 70.213 (71.739)
Epoch: [30][14100/17702]	Batch Time 0.485 (0.470)	Data Load Time 0.001 (0.005)	Loss 3.3730 (3.3920)	Top-5 Accuracy 73.315 (71.738)
Epoch: [30][14200/17702]	Batch Time 0.546 (0.470)	Data Load Time 0.022 (0.005)	Loss 3.3212 (3.3921)	Top-5 Accuracy 70.380 (71.737)
Epoch: [30][14300/17702]	Batch Time 0.513 (0.470)	Data Load Time 0.001 (0.005)	Loss 3.2613 (3.3924)	Top-5 Accuracy 73.351 (71.734)
Epoch: [30][14400/17702]	Batch Time 0.488 (0.470)	Data Load Time 0.001 (0.005)	Loss 3.5354 (3.3927)	Top-5 Accuracy 70.141 (71.730)
Epoch: [30][14500/17702]	Batch Time 0.461 (0.470)	Data Load Time 0.001 (0.005)	Loss 3.7891 (3.3930)	Top-5 Accuracy 67.806 (71.728)
Epoch: [30][14600/17702]	Batch Time 0.556 (0.470)	Data Load Time 0.020 (0.005)	Loss

Epoch: [31][1700/17702]	Batch Time 0.600 (0.471)	Data Load Time 0.001 (0.006)	Loss 3.2987 (3.3577)	Top-5 Accuracy 71.169 (72.104)
Epoch: [31][1800/17702]	Batch Time 0.480 (0.471)	Data Load Time 0.001 (0.006)	Loss 3.4231 (3.3578)	Top-5 Accuracy 71.159 (72.100)
Epoch: [31][1900/17702]	Batch Time 0.446 (0.470)	Data Load Time 0.001 (0.006)	Loss 3.6757 (3.3574)	Top-5 Accuracy 68.169 (72.099)
Epoch: [31][2000/17702]	Batch Time 0.498 (0.471)	Data Load Time 0.023 (0.006)	Loss 3.4365 (3.3578)	Top-5 Accuracy 69.022 (72.102)
Epoch: [31][2100/17702]	Batch Time 0.465 (0.470)	Data Load Time 0.001 (0.006)	Loss 3.4707 (3.3574)	Top-5 Accuracy 70.426 (72.100)
Epoch: [31][2200/17702]	Batch Time 0.470 (0.470)	Data Load Time 0.013 (0.005)	Loss 3.5047 (3.3558)	Top-5 Accuracy 70.241 (72.126)
Epoch: [31][2300/17702]	Batch Time 0.516 (0.470)	Data Load Time 0.014 (0.005)	Loss 3.4097 (3.3550)	Top-5 Accuracy 70.543 (72.131)
Epoch: [31][2400/17702]	Batch Time 0.457 (0.470)	Data Load Time 0.019 (0.005)	Loss 3.6751 

Epoch: [31][8100/17702]	Batch Time 0.455 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.3284 (3.3741)	Top-5 Accuracy 73.138 (71.881)
Epoch: [31][8200/17702]	Batch Time 0.454 (0.472)	Data Load Time 0.014 (0.005)	Loss 3.3993 (3.3741)	Top-5 Accuracy 71.745 (71.886)
Epoch: [31][8300/17702]	Batch Time 0.425 (0.472)	Data Load Time 0.014 (0.005)	Loss 3.4974 (3.3743)	Top-5 Accuracy 71.038 (71.886)
Epoch: [31][8400/17702]	Batch Time 0.391 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.2055 (3.3745)	Top-5 Accuracy 74.493 (71.880)
Epoch: [31][8500/17702]	Batch Time 0.456 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.4613 (3.3749)	Top-5 Accuracy 70.647 (71.872)
Epoch: [31][8600/17702]	Batch Time 0.425 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.4944 (3.3755)	Top-5 Accuracy 71.538 (71.867)
Epoch: [31][8700/17702]	Batch Time 0.684 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.4421 (3.3758)	Top-5 Accuracy 72.771 (71.863)
Epoch: [31][8800/17702]	Batch Time 0.451 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.2434 

Epoch: [31][14400/17702]	Batch Time 0.458 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.4326 (3.3876)	Top-5 Accuracy 71.186 (71.744)
Epoch: [31][14500/17702]	Batch Time 0.679 (0.472)	Data Load Time 0.014 (0.005)	Loss 3.6251 (3.3878)	Top-5 Accuracy 68.557 (71.743)
Epoch: [31][14600/17702]	Batch Time 0.490 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.3675 (3.3880)	Top-5 Accuracy 71.159 (71.740)
Epoch: [31][14700/17702]	Batch Time 0.379 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.4325 (3.3881)	Top-5 Accuracy 71.225 (71.740)
Epoch: [31][14800/17702]	Batch Time 0.399 (0.472)	Data Load Time 0.003 (0.005)	Loss 3.4848 (3.3883)	Top-5 Accuracy 74.221 (71.738)
Epoch: [31][14900/17702]	Batch Time 0.562 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.7083 (3.3885)	Top-5 Accuracy 66.216 (71.736)
Epoch: [31][15000/17702]	Batch Time 0.456 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.2379 (3.3889)	Top-5 Accuracy 75.198 (71.732)
Epoch: [31][15100/17702]	Batch Time 0.390 (0.471)	Data Load Time 0.001 (0.005)	Loss

Epoch: [32][2200/17702]	Batch Time 0.409 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.5132 (3.3477)	Top-5 Accuracy 70.506 (72.235)
Epoch: [32][2300/17702]	Batch Time 0.415 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.2976 (3.3476)	Top-5 Accuracy 73.121 (72.244)
Epoch: [32][2400/17702]	Batch Time 0.438 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.5217 (3.3482)	Top-5 Accuracy 70.145 (72.233)
Epoch: [32][2500/17702]	Batch Time 0.406 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.2402 (3.3487)	Top-5 Accuracy 72.853 (72.232)
Epoch: [32][2600/17702]	Batch Time 0.445 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.2468 (3.3485)	Top-5 Accuracy 74.425 (72.234)
Epoch: [32][2700/17702]	Batch Time 0.399 (0.472)	Data Load Time 0.001 (0.005)	Loss 3.2874 (3.3491)	Top-5 Accuracy 72.829 (72.229)
Epoch: [32][2800/17702]	Batch Time 0.591 (0.471)	Data Load Time 0.014 (0.005)	Loss 3.3294 (3.3498)	Top-5 Accuracy 73.421 (72.219)
Epoch: [32][2900/17702]	Batch Time 0.531 (0.472)	Data Load Time 0.014 (0.005)	Loss 3.1424 

Epoch: [32][8600/17702]	Batch Time 0.471 (0.471)	Data Load Time 0.015 (0.005)	Loss 3.0938 (3.3551)	Top-5 Accuracy 73.867 (72.124)
Epoch: [32][8700/17702]	Batch Time 0.423 (0.471)	Data Load Time 0.001 (0.005)	Loss 3.5373 (3.3552)	Top-5 Accuracy 72.823 (72.123)
Epoch: [32][8800/17702]	Batch Time 0.443 (0.471)	Data Load Time 0.001 (0.005)	Loss 3.3232 (3.3554)	Top-5 Accuracy 73.067 (72.121)
Epoch: [32][8900/17702]	Batch Time 0.461 (0.471)	Data Load Time 0.001 (0.005)	Loss 3.5044 (3.3556)	Top-5 Accuracy 69.253 (72.119)
Epoch: [32][9000/17702]	Batch Time 0.414 (0.471)	Data Load Time 0.001 (0.005)	Loss 3.4195 (3.3558)	Top-5 Accuracy 73.729 (72.115)
Epoch: [32][9100/17702]	Batch Time 0.436 (0.471)	Data Load Time 0.001 (0.005)	Loss 3.5268 (3.3559)	Top-5 Accuracy 70.055 (72.113)
Epoch: [32][9200/17702]	Batch Time 0.442 (0.471)	Data Load Time 0.013 (0.005)	Loss 3.1236 (3.3561)	Top-5 Accuracy 74.309 (72.110)
Epoch: [32][9300/17702]	Batch Time 0.446 (0.471)	Data Load Time 0.001 (0.005)	Loss 3.3540 