In [None]:
!pip install scipy==1.1.0


Collecting scipy==1.1.0
  Downloading scipy-1.1.0-cp37-cp37m-manylinux1_x86_64.whl (31.2 MB)
[K     |████████████████████████████████| 31.2 MB 1.4 MB/s 
Installing collected packages: scipy
  Attempting uninstall: scipy
    Found existing installation: scipy 1.4.1
    Uninstalling scipy-1.4.1:
      Successfully uninstalled scipy-1.4.1
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
pymc3 3.11.4 requires scipy>=1.2.0, but you have scipy 1.1.0 which is incompatible.
plotnine 0.6.0 requires scipy>=1.2.0, but you have scipy 1.1.0 which is incompatible.
jax 0.2.25 requires scipy>=1.2.1, but you have scipy 1.1.0 which is incompatible.
albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.[0m
Successfully installed scipy-1.1.0


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!cp /content/drive/MyDrive/CV/modules/models.py /content/
!cp /content/drive/MyDrive/CV/modules/utils.py /content/
!cp /content/drive/MyDrive/CV/modules/eval.py /content/
!cp /content/drive/MyDrive/CV/modules/datasets.py /content/

In [None]:
!pip install torch



In [None]:
import time
import torch.backends.cudnn as cudnn
import torch.optim
from torch.utils.data  import *
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
import torchvision
from models import CNN, LSTMwithAttention
from utils import *
from nltk.translate.bleu_score import corpus_bleu
import h5py
import json
import os
import numpy as np
from scipy.misc import imread, imresize
from tqdm import tqdm
from collections import Counter
from random import seed, choice, sample


In [None]:

def clip_gradient(optimizer, grad_clip):
    """Clamp gradients element-wise to the range [-grad_clip, grad_clip].

    Walks every parameter group of ``optimizer`` and clamps, in place, the
    gradient of each parameter that currently has one. Parameters whose
    ``grad`` is ``None`` are left untouched.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dicts
            that each hold a ``'params'`` list.
        grad_clip: non-negative bound applied symmetrically around zero.
    """
    for group in optimizer.param_groups:
        # Only parameters that received a gradient can be clamped.
        with_grad = (p for p in group['params'] if p.grad is not None)
        for p in with_grad:
            p.grad.data.clamp_(min=-grad_clip, max=grad_clip)

In [None]:



class AvgMtr:
    """Keeps the latest value and the running weighted average of a metric.

    The training and validation loops instantiate this class as ``AvgMtr()``
    and print its ``val`` and ``avg`` attributes, so those are the names used
    here.

    Attributes:
        val: the most recent value passed to ``update``.
        avg: weighted mean of all values seen since the last ``reset``.
        sum: running total of ``value * weight``.
        count: running total of the weights.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, vals, num=1):
        """Record a new value.

        Args:
            vals: the new measurement (e.g. a batch-average loss).
            num: weight of the measurement (e.g. number of items it covers).
        """
        self.val = vals
        self.sum += vals * num
        self.count += num
        # Guard against a zero total weight so the first update cannot
        # raise ZeroDivisionError when num == 0.
        self.avg = self.sum / self.count if self.count else 0


def adjust_learning_rate(optim, s_factor):
    """Multiply the learning rate of every parameter group by ``s_factor``.

    Args:
        optim: object exposing ``param_groups``, a list of dicts that each
            carry an ``'lr'`` entry.
        s_factor: multiplier applied to each group's current learning rate
            (a value below 1 shrinks it).
    """
    for group in optim.param_groups:
        scaled_lr = group['lr'] * s_factor
        group['lr'] = scaled_lr


def accuracy(score, target, top_k):
    """Return the top-k accuracy as a percentage.

    Args:
        score: 2-D tensor of scores, one row per item and one column per class.
        target: tensor holding the true class index of every row.
        top_k: how many of the highest-scoring classes count as a hit.

    Returns:
        A Python float: the number of rows whose target appears among the
        ``top_k`` best-scoring classes, multiplied by ``100 / rows``.
    """
    n_rows = target.size(0)
    # Indices of the k highest scores in each row, best first.
    top_idx = score.topk(top_k, dim=1, largest=True, sorted=True)[1]
    wanted = target.view(-1, 1).expand_as(top_idx)
    hits = top_idx.eq(wanted).view(-1).float().sum()  # 0-D tensor
    return hits.item() * (100.0 / n_rows)


# Dataset

In [None]:
class Captions_Dataset(Dataset):
    """Dataset of (image, caption) pairs backed by an HDF5 file and JSON files.

    Every image has ``captions_per_image`` captions, stored consecutively, so
    item ``i`` pairs caption ``i`` with image ``i // captions_per_image``.

    Args:
        data_folder: directory that contains the data files.
        data_name: common suffix of the data file names.
        splits: split prefix of the file names; ``'TRAIN'`` items are
            3-tuples, any other split also returns all captions of the image.
        transforms: optional callable applied to the image tensor.
        transform: alias of ``transforms`` (the spelling used by the callers
            in this notebook); takes precedence when both are given.
    """

    def __init__(self, data_folder, data_name, splits, transforms=None, transform=None):
        self.splits = splits

        # NOTE(review): the '_cap_lenss_' infix is kept exactly as found; the
        # sibling files use upper-case infixes ('_IMAGES_', '_CAPTIONS_'), so
        # confirm this name against the files that actually exist on disk.
        with open(os.path.join(data_folder, self.splits + '_cap_lenss_' + data_name + '.json'), 'r') as j:
            self.cap_lens = json.load(j)

        with open(os.path.join(data_folder, self.splits + '_CAPTIONS_' + data_name + '.json'), 'r') as j:
            self.captions = json.load(j)

        # One dataset item per caption; must be computed after the captions
        # have been loaded.
        self.dataset_size = len(self.captions)

        # The HDF5 handle stays open for the lifetime of the dataset because
        # images are read lazily in __getitem__.
        self.h5py = h5py.File(os.path.join(data_folder, self.splits + '_IMAGES_' + data_name + '.hdf5'), 'r')
        self.imgs = self.h5py['images']
        self.captions_per_image = self.h5py.attrs['captions_per_image']

        self.transform = transform if transform is not None else transforms

    def __getitem__(self, i):
        """Return item ``i``.

        Returns:
            ``(img, caption, cap_len)`` for the ``'TRAIN'`` split, otherwise
            ``(img, caption, cap_len, all_captions)`` where ``all_captions``
            holds every caption that belongs to the same image.
        """
        img_idx = i // self.captions_per_image
        # Scale raw pixel values by 1/255.
        img = torch.FloatTensor(self.imgs[img_idx] / 255.)
        if self.transform is not None:
            img = self.transform(img)

        cap_len = torch.LongTensor([self.cap_lens[i]])
        caption = torch.LongTensor(self.captions[i])

        # Compare by value: identity (`is`) on strings is not reliable.
        if self.splits == 'TRAIN':
            return img, caption, cap_len

        first = img_idx * self.captions_per_image
        all_captions = torch.LongTensor(self.captions[first:first + self.captions_per_image])
        return img, caption, cap_len, all_captions

    def __len__(self):
        return self.dataset_size


# Model

In [None]:


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class CNN(nn.Module):
    """Image encoder built from a pretrained torchvision ResNet-50.

    Args:
        encode_image_size: side length of the square feature grid produced by
            the adaptive pooling layer.
    """

    def __init__(self, encode_image_size=14):
        super().__init__()
        self.encode_image_size = encode_image_size

        resnet50 = torchvision.models.resnet50(pretrained=True)

        # Drop the last two child modules (in torchvision's ResNet these are
        # the global pooling layer and the classification head) so the
        # network outputs a spatial feature map.
        modules = list(resnet50.children())[:-2]
        self.resnet50 = nn.Sequential(*modules)

        # Resize the feature map to a fixed grid regardless of input size.
        self.adaptive_pools = nn.AdaptiveAvgPool2d((encode_image_size, encode_image_size))

        self.fine_tune()

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: 4-D image batch, channels in dimension 1.

        Returns:
            Feature tensor with the channel dimension moved last, i.e.
            ``(batch, encode_image_size, encode_image_size, channels)``.
        """
        out = self.resnet50(images)
        out = self.adaptive_pools(out)
        out = out.permute(0, 2, 3, 1)
        return out

    def fine_tune(self, fine_tune=True):
        """Freeze the backbone, then optionally unfreeze its later stages.

        All parameters are frozen first; the child modules from index 5
        onward then get ``requires_grad = fine_tune``.
        """
        for p in self.resnet50.parameters():
            p.requires_grad = False
        for c in list(self.resnet50.children())[5:]:
            for p in c.parameters():
                p.requires_grad = fine_tune


class Attention(nn.Module):
    """Additive attention over the encoder's feature locations.

    Args:
        CNN_dim: size of each encoder feature vector.
        decoder_dim: size of the decoder hidden state.
        attention_dim: size of the shared space both inputs are projected to.
    """

    def __init__(self, CNN_dim, decoder_dim, attention_dim):
        super().__init__()
        # Projections of the two inputs into the shared attention space.
        self.CNN_att = nn.Linear(CNN_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        # Collapses each projected location to a single score.
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        # Normalises the scores across locations (dimension 1).
        self.softmax = nn.Softmax(dim=1)

    def forward(self, CNN_out, decoder_hidden):
        """Compute the attention-weighted encoding.

        Args:
            CNN_out: encoder features, ``(batch, locations, CNN_dim)``.
            decoder_hidden: decoder state, ``(batch, decoder_dim)``.

        Returns:
            A pair ``(attention_weighted_encoding, alpha)``: the weighted sum
            of features ``(batch, CNN_dim)`` and the weights
            ``(batch, locations)``.
        """
        projected_features = self.CNN_att(CNN_out)
        # Extra dimension lets the hidden state broadcast over all locations.
        projected_hidden = self.decoder_att(decoder_hidden).unsqueeze(1)
        energies = self.full_att(self.relu(projected_features + projected_hidden))
        alpha = self.softmax(energies.squeeze(2))
        weighted = CNN_out * alpha.unsqueeze(2)
        attention_weighted_encoding = weighted.sum(dim=1)

        return attention_weighted_encoding, alpha


class LSTMWithAttention(nn.Module):
    """LSTM caption decoder that attends over encoder features at every step.

    Args:
        attention_dim: size of the attention network's internal space.
        embed_dim: size of the word embeddings.
        decoder_dim: size of the LSTM hidden and cell states.
        vocab_size: number of words in the vocabulary.
        CNN_dim: size of each encoder feature vector.
        dropout: dropout probability applied before the output layer.
    """

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, CNN_dim=2048, dropout=0.5):
        super().__init__()

        self.CNN_dim = CNN_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size

        self.attention = Attention(CNN_dim, decoder_dim, attention_dim)

        self.embeddings = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(p=dropout)
        # Each step consumes the previous word embedding concatenated with
        # the attention-weighted encoding.
        self.decode_step = nn.LSTMCell(embed_dim + CNN_dim, decoder_dim, bias=True)
        # Map the mean encoder feature to the initial hidden / cell state.
        self.init_h = nn.Linear(CNN_dim, decoder_dim)
        self.init_c = nn.Linear(CNN_dim, decoder_dim)
        # Produces a sigmoid gate applied to the attention-weighted encoding.
        self.f_beta = nn.Linear(decoder_dim, CNN_dim)
        self.sigmoid = nn.Sigmoid()
        # Scores over the vocabulary.
        self.fc = nn.Linear(decoder_dim, vocab_size)
        self.init_weights()

    def init_weights(self):
        """Initialise embedding and output-layer weights uniformly in [-0.1, 0.1]."""
        self.embeddings.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        """Replace the embedding weights with the given tensor."""
        self.embeddings.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        """Enable or disable gradient updates for the embedding layer."""
        for p in self.embeddings.parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, CNN_out):
        """Derive the initial LSTM state from the mean encoder feature.

        Args:
            CNN_out: encoder features, ``(batch, locations, CNN_dim)``.

        Returns:
            ``(h, c)``, each of shape ``(batch, decoder_dim)``.
        """
        mean_CNN_out = CNN_out.mean(dim=1)
        hs = self.init_h(mean_CNN_out)
        cs = self.init_c(mean_CNN_out)
        return hs, cs

    def forward(self, CNN_out, encoded_captions, caption_lengths):
        """Decode captions with teacher forcing.

        Args:
            CNN_out: encoder features whose last dimension is ``CNN_dim``;
                all dimensions between the first and last are flattened.
            encoded_captions: word indices, ``(batch, max_caption_length)``.
            caption_lengths: caption lengths, ``(batch, 1)``.

        Returns:
            ``(predictions, encoded_captions, decode_lengths, alphas,
            sort_ind)``. The batch is reordered by decreasing caption length;
            ``sort_ind`` is that permutation and ``encoded_captions`` is
            returned in the new order. ``predictions`` and ``alphas`` stay
            zero at positions beyond each caption's decode length.
        """
        batch_size = CNN_out.size(0)
        CNN_dim = CNN_out.size(-1)
        vocab_size = self.vocab_size

        # Flatten the spatial grid into a list of locations.
        CNN_out = CNN_out.view(batch_size, -1, CNN_dim)
        num_pixels = CNN_out.size(1)

        # Sort by decreasing length so that, at step t, the captions that are
        # still being decoded form a prefix of the batch.
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        CNN_out = CNN_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        embeddings = self.embeddings(encoded_captions)
        h, c = self.init_hidden_state(CNN_out)

        # The last token of a caption is never fed as input, hence length - 1.
        decode_lengths = (caption_lengths - 1).tolist()
        max_len = max(decode_lengths)

        # Allocate on the same device as the inputs instead of relying on a
        # module-level `device` variable.
        predictions = torch.zeros(batch_size, max_len, vocab_size, device=CNN_out.device)
        alphas = torch.zeros(batch_size, max_len, num_pixels, device=CNN_out.device)

        for t in range(max_len):
            # Number of captions that still have a token to predict at step t.
            batch_size_t = sum(l > t for l in decode_lengths)
            attention_weighted_encoding, alpha = self.attention(CNN_out[:batch_size_t],
                                                                h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))
            preds = self.fc(self.dropout(h))
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind


# The rest of the notebook instantiates the decoder as `LSTMwithAttention`
# (lower-case "w"); keep both spellings bound to this definition.
LSTMwithAttention = LSTMWithAttention




In [None]:

def train(train_loader, CNN, decoder, criterion, CNN_optimizer, decoder_optimizer, epoch):
    """Train the encoder/decoder pair for one epoch.

    Reads the module-level settings ``device``, ``alpha_c``, ``grad_clip``
    and ``print_freq``.

    Args:
        train_loader: iterable yielding ``(imgs, caps, cap_lens)`` batches.
        CNN: image encoder module.
        decoder: caption decoder module.
        criterion: loss applied to the packed scores and targets.
        CNN_optimizer: optimizer for the encoder, or ``None`` to leave the
            encoder weights untouched.
        decoder_optimizer: optimizer for the decoder.
        epoch: epoch number, used only in the progress message.
    """
    decoder.train()
    CNN.train()

    batch_time = AvgMtr()   # forward + backward time per batch
    data_time = AvgMtr()    # time spent waiting for the data loader
    losses = AvgMtr()       # loss per decoded word
    top5accs = AvgMtr()     # top-5 accuracy per decoded word

    # Decoder first, encoder second (when present) - same order for
    # zero_grad, clipping and step.
    optimizers = [decoder_optimizer]
    if CNN_optimizer is not None:
        optimizers.append(CNN_optimizer)

    tick = time.time()

    for step, (imgs, caps, cap_lens) in enumerate(train_loader):
        data_time.update(time.time() - tick)

        imgs = imgs.to(device)
        caps = caps.to(device)
        cap_lens = cap_lens.to(device)

        # Forward pass.
        features = CNN(imgs)
        scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(features, caps, cap_lens)

        # Targets are the captions shifted by one: everything after the
        # first token.
        targets = caps_sorted[:, 1:]

        # Packing drops the padded positions so they do not count in the loss.
        scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

        loss = criterion(scores, targets)

        # Regularisation pushing each location's attention weights to sum
        # to 1 over the time steps.
        loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

        # Backward pass and parameter update.
        for opt in optimizers:
            opt.zero_grad()
        loss.backward()

        if grad_clip is not None:
            for opt in optimizers:
                clip_gradient(opt, grad_clip)

        for opt in optimizers:
            opt.step()

        # Book-keeping, weighted by the number of decoded words.
        n_words = sum(decode_lengths)
        top5 = accuracy(scores, targets, 5)
        losses.update(loss.item(), n_words)
        top5accs.update(top5, n_words)
        batch_time.update(time.time() - tick)

        tick = time.time()

        if step % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, step, len(train_loader),
                                                                          batch_time=batch_time,
                                                                          data_time=data_time, loss=losses,
                                                                          top5=top5accs))

In [None]:

def validate(val_loader, CNN, decoder, criterion):
    """Evaluate the model on the validation set and return its BLEU-4 score.

    Reads the module-level settings ``device``, ``alpha_c``, ``print_freq``
    and the vocabulary ``word_map``.

    Args:
        val_loader: iterable yielding ``(imgs, caps, cap_lens, allcaps)``
            batches.
        CNN: image encoder module, or ``None`` when the loader already
            yields encoded features.
        decoder: caption decoder module.
        criterion: loss applied to the packed scores and targets.

    Returns:
        The corpus-level BLEU-4 score of the teacher-forced predictions.
    """
    decoder.eval()
    if CNN is not None:
        CNN.eval()

    batch_time = AvgMtr()
    losses = AvgMtr()
    top5accs = AvgMtr()

    start = time.time()

    references = list()   # per image: list of reference captions
    hypotheses = list()   # per image: the predicted caption

    # Tokens removed from the references before scoring; built once rather
    # than for every word.
    skip_tokens = {word_map['<start>'], word_map['<pad>']}

    with torch.no_grad():
        # Batches
        for i, (imgs, caps, cap_lens, allcaps) in enumerate(val_loader):

            imgs = imgs.to(device)
            caps = caps.to(device)
            cap_lens = cap_lens.to(device)

            if CNN is not None:
                imgs = CNN(imgs)
            scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, cap_lens)

            # Targets are the captions shifted by one.
            targets = caps_sorted[:, 1:]

            # Keep the unpacked scores: they are needed per caption below.
            scores_copy = scores.clone()
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)[0]
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)[0]

            loss = criterion(scores, targets)

            loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

            losses.update(loss.item(), sum(decode_lengths))
            top5 = accuracy(scores, targets, 5)
            top5accs.update(top5, sum(decode_lengths))
            batch_time.update(time.time() - start)

            start = time.time()

            if i % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
                                                                                loss=losses, top5=top5accs))

            # References: reorder like the decoder reordered the batch. The
            # index is moved to allcaps' device because allcaps was never
            # transferred to `device`.
            allcaps = allcaps[sort_ind.to(allcaps.device)]
            for j in range(allcaps.shape[0]):
                img_caps = allcaps[j].tolist()
                img_captions = [[w for w in c if w not in skip_tokens] for c in img_caps]
                references.append(img_captions)

            # Hypotheses: highest-scoring word at every step, cut to the
            # caption's decode length.
            _, preds = torch.max(scores_copy, dim=2)
            preds = preds.tolist()
            hypotheses.extend(p[:decode_lengths[j]] for j, p in enumerate(preds))  # remove pads

            assert len(references) == len(hypotheses)

        bleu4 = corpus_bleu(references, hypotheses)

        print(
            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
                loss=losses,
                top5=top5accs,
                bleu=bleu4))

    return bleu4

In [None]:


# Data parameters
data_folder = '/content/drive/MyDrive/CV/out/'  # directory holding the word map and the dataset files
data_name = 'flickr8k_5_cap_per_img_5_min_word_freq'  # common suffix of the data file names

# Model parameters
emb_dim = 512  # passed to the decoder as embed_dim
attention_dim = 512  # passed to the decoder as attention_dim
decoder_dim = 512  # passed to the decoder as decoder_dim
dropout = 0.5  # dropout probability passed to the decoder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # GPU when available, else CPU
cudnn.benchmark = True  # enable cuDNN benchmark mode (see torch.backends.cudnn)

# Training parameters
start_epoch = 0  # first epoch of the training loop; replaced when resuming from a checkpoint
epochs = 10  # upper bound of the epoch range in main()
epochs_since_improvement = 0  # consecutive epochs without a better validation BLEU-4
batch_size = 32  # batch size of both data loaders
workers = 1  # num_workers of both data loaders
CNN_lr = 1e-4  # Adam learning rate for the encoder (only used when fine_tune_CNN is True)
decoder_lr = 4e-4  # Adam learning rate for the decoder
grad_clip = 5.  # gradients are clamped to [-grad_clip, grad_clip]; None skips clipping
alpha_c = 1.  # weight of the attention regularisation term added to the loss
best_bleu4 = 0.  # best validation BLEU-4 seen so far
print_freq = 100  # print progress every print_freq batches
fine_tune_CNN = False  # whether the encoder's later stages are trained
checkpoint = None  # path given to torch.load to resume; None starts from scratch



In [None]:
def main():
    """Build (or restore) the models, then run the train/validate loop.

    Reads the configuration from module-level variables and updates
    ``best_bleu4``, ``epochs_since_improvement``, ``start_epoch`` and
    ``word_map`` in place. Training stops after ``epochs`` epochs or after
    20 epochs without a better validation BLEU-4; learning rates are
    multiplied by 0.8 after every 8 consecutive epochs without improvement.
    Nothing in this function writes a checkpoint.
    """
    # Only names that are re-bound here need the declaration; the other
    # settings are just read.
    global best_bleu4, epochs_since_improvement, start_epoch, word_map

    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # The encoder instance lives in `encoder`: assigning to a local called
    # `CNN` would hide the CNN class and make `CNN()` raise
    # UnboundLocalError.
    if checkpoint is None:
        decoder = LSTMwithAttention(attention_dim=attention_dim,
                                    embed_dim=emb_dim,
                                    decoder_dim=decoder_dim,
                                    vocab_size=len(word_map),
                                    dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = CNN()
        encoder.fine_tune(fine_tune_CNN)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                             lr=CNN_lr) if fine_tune_CNN else None

    else:
        # NOTE: torch.load unpickles whole objects - only load checkpoints
        # from a trusted source. The result is kept in a local so the global
        # `checkpoint` path survives and main() can be run again.
        state = torch.load(checkpoint)
        start_epoch = state['epoch'] + 1
        epochs_since_improvement = state['epochs_since_improvement']
        best_bleu4 = state['bleu-4']
        decoder = state['decoder']
        decoder_optimizer = state['decoder_optimizer']
        encoder = state['CNN']
        encoder_optimizer = state['CNN_optimizer']
        if fine_tune_CNN is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_CNN)
            encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=CNN_lr)

    decoder = decoder.to(device)
    encoder = encoder.to(device)

    criterion = nn.CrossEntropyLoss().to(device)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # The image transform is passed positionally so the call does not depend
    # on how the dataset spells that parameter.
    train_loader = torch.utils.data.DataLoader(
        Captions_Dataset(data_folder, data_name, 'TRAIN', transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        Captions_Dataset(data_folder, data_name, 'VAL', transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

    for epoch in tqdm(range(start_epoch, epochs), desc = "Epochs"):

        # Early stopping and learning-rate decay on stagnation.
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_CNN:
                adjust_learning_rate(encoder_optimizer, 0.8)

        train(train_loader=train_loader,
              CNN=encoder,
              decoder=decoder,
              criterion=criterion,
              CNN_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        recent_bleu4 = validate(val_loader=val_loader,
                                CNN=encoder,
                                decoder=decoder,
                                criterion=criterion)

        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0


if __name__ == '__main__':
    main()


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch: [0][0/938]	Batch Time 1.099 (1.099)	Data Load Time 0.212 (0.212)	Loss 8.8387 (8.8387)	Top-5 Accuracy 0.000 (0.000)
Epoch: [0][100/938]	Batch Time 0.735 (0.774)	Data Load Time 0.000 (0.002)	Loss 5.7041 (6.1299)	Top-5 Accuracy 40.230 (35.036)
Epoch: [0][200/938]	Batch Time 0.755 (0.774)	Data Load Time 0.000 (0.001)	Loss 4.9754 (5.7669)	Top-5 Accuracy 50.000 (39.685)
Epoch: [0][300/938]	Batch Time 0.751 (0.772)	Data Load Time 0.000 (0.001)	Loss 4.7979 (5.5155)	Top-5 Accuracy 53.911 (43.424)
Epoch: [0][400/938]	Batch Time 0.721 (0.772)	Data Load Time 0.000 (0.001)	Loss 4.7721 (5.3472)	Top-5 Accuracy 52.959 (45.823)
Epoch: [0][500/938]	Batch Time 0.787 (0.773)	Data Load Time 0.001 (0.001)	Loss 4.5057 (5.2194)	Top-5 Accuracy 57.368 (47.621)
Epoch: [0][600/938]	Batch Time 0.814 (0.773)	Data Load Time 0.003 (0.001)	Loss 4.5344 (5.1082)	Top-5 Accuracy 58.929 (49.224)
Epoch: [0][700/938]	Batch Time 0.764 (0.774)	Data Load Time 0.000 (0.001)	Loss 4.2406 (5.0170)	Top-5 Accuracy 64.000 (50.4

Epochs:  10%|█         | 1/10 [13:42<2:03:25, 822.86s/it]

Epoch: [1][0/938]	Batch Time 1.060 (1.060)	Data Load Time 0.247 (0.247)	Loss 4.3736 (4.3736)	Top-5 Accuracy 57.027 (57.027)
Epoch: [1][100/938]	Batch Time 0.759 (0.774)	Data Load Time 0.000 (0.003)	Loss 4.2424 (4.2176)	Top-5 Accuracy 61.478 (60.864)
Epoch: [1][200/938]	Batch Time 0.793 (0.775)	Data Load Time 0.000 (0.002)	Loss 3.8129 (4.1845)	Top-5 Accuracy 67.662 (61.369)
Epoch: [1][300/938]	Batch Time 0.772 (0.774)	Data Load Time 0.000 (0.001)	Loss 4.1959 (4.1760)	Top-5 Accuracy 61.602 (61.435)
Epoch: [1][400/938]	Batch Time 0.778 (0.775)	Data Load Time 0.000 (0.001)	Loss 4.0903 (4.1616)	Top-5 Accuracy 64.925 (61.591)
Epoch: [1][500/938]	Batch Time 0.730 (0.774)	Data Load Time 0.000 (0.001)	Loss 4.0507 (4.1454)	Top-5 Accuracy 62.286 (61.762)
Epoch: [1][600/938]	Batch Time 0.771 (0.774)	Data Load Time 0.000 (0.001)	Loss 3.8274 (4.1342)	Top-5 Accuracy 65.957 (61.975)
Epoch: [1][700/938]	Batch Time 0.791 (0.774)	Data Load Time 0.000 (0.001)	Loss 3.8930 (4.1210)	Top-5 Accuracy 65.995 (62

Epochs:  20%|██        | 2/10 [27:23<1:49:33, 821.69s/it]

Epoch: [2][0/938]	Batch Time 1.066 (1.066)	Data Load Time 0.187 (0.187)	Loss 3.8447 (3.8447)	Top-5 Accuracy 66.253 (66.253)
Epoch: [2][100/938]	Batch Time 0.766 (0.779)	Data Load Time 0.000 (0.002)	Loss 3.6172 (3.8684)	Top-5 Accuracy 69.974 (65.227)
Epoch: [2][200/938]	Batch Time 0.774 (0.776)	Data Load Time 0.000 (0.001)	Loss 4.3711 (3.8596)	Top-5 Accuracy 58.649 (65.516)
Epoch: [2][300/938]	Batch Time 0.784 (0.775)	Data Load Time 0.000 (0.001)	Loss 3.8770 (3.8549)	Top-5 Accuracy 60.741 (65.558)
Epoch: [2][400/938]	Batch Time 0.775 (0.775)	Data Load Time 0.000 (0.001)	Loss 3.7714 (3.8510)	Top-5 Accuracy 68.490 (65.659)
Epoch: [2][500/938]	Batch Time 0.773 (0.774)	Data Load Time 0.000 (0.001)	Loss 3.8423 (3.8468)	Top-5 Accuracy 66.220 (65.759)
Epoch: [2][600/938]	Batch Time 0.823 (0.774)	Data Load Time 0.000 (0.001)	Loss 3.7524 (3.8394)	Top-5 Accuracy 68.828 (65.898)
Epoch: [2][700/938]	Batch Time 0.756 (0.774)	Data Load Time 0.000 (0.001)	Loss 3.8953 (3.8354)	Top-5 Accuracy 66.576 (65

Epochs:  30%|███       | 3/10 [41:04<1:35:48, 821.25s/it]

Epoch: [3][0/938]	Batch Time 1.057 (1.057)	Data Load Time 0.186 (0.186)	Loss 3.6786 (3.6786)	Top-5 Accuracy 66.223 (66.223)
Epoch: [3][100/938]	Batch Time 0.791 (0.777)	Data Load Time 0.000 (0.002)	Loss 3.5566 (3.6460)	Top-5 Accuracy 69.877 (68.512)
Epoch: [3][200/938]	Batch Time 0.802 (0.775)	Data Load Time 0.001 (0.001)	Loss 3.7763 (3.6570)	Top-5 Accuracy 68.408 (68.259)
Epoch: [3][300/938]	Batch Time 0.767 (0.774)	Data Load Time 0.000 (0.001)	Loss 3.3867 (3.6623)	Top-5 Accuracy 71.802 (68.088)
Epoch: [3][400/938]	Batch Time 0.774 (0.773)	Data Load Time 0.002 (0.001)	Loss 3.5198 (3.6578)	Top-5 Accuracy 68.449 (68.186)
Epoch: [3][500/938]	Batch Time 0.752 (0.773)	Data Load Time 0.000 (0.001)	Loss 3.4854 (3.6603)	Top-5 Accuracy 68.286 (68.159)
Epoch: [3][600/938]	Batch Time 0.730 (0.773)	Data Load Time 0.000 (0.001)	Loss 3.6074 (3.6634)	Top-5 Accuracy 68.935 (68.148)
Epoch: [3][700/938]	Batch Time 0.767 (0.773)	Data Load Time 0.000 (0.001)	Loss 3.5650 (3.6630)	Top-5 Accuracy 70.876 (68

Epochs:  40%|████      | 4/10 [54:45<1:22:06, 821.01s/it]

Epoch: [4][0/938]	Batch Time 1.053 (1.053)	Data Load Time 0.256 (0.256)	Loss 3.4832 (3.4832)	Top-5 Accuracy 72.434 (72.434)
Epoch: [4][100/938]	Batch Time 0.798 (0.780)	Data Load Time 0.000 (0.003)	Loss 3.2866 (3.5110)	Top-5 Accuracy 75.062 (70.396)
Epoch: [4][200/938]	Batch Time 0.777 (0.779)	Data Load Time 0.000 (0.002)	Loss 3.4480 (3.5103)	Top-5 Accuracy 69.171 (70.397)
Epoch: [4][300/938]	Batch Time 0.767 (0.779)	Data Load Time 0.000 (0.001)	Loss 3.3328 (3.5110)	Top-5 Accuracy 72.576 (70.392)
Epoch: [4][400/938]	Batch Time 0.822 (0.778)	Data Load Time 0.000 (0.001)	Loss 3.7398 (3.5203)	Top-5 Accuracy 66.826 (70.238)
Epoch: [4][500/938]	Batch Time 0.765 (0.778)	Data Load Time 0.000 (0.001)	Loss 3.4282 (3.5203)	Top-5 Accuracy 70.000 (70.264)
Epoch: [4][600/938]	Batch Time 0.791 (0.777)	Data Load Time 0.000 (0.001)	Loss 3.6524 (3.5167)	Top-5 Accuracy 67.308 (70.361)
Epoch: [4][700/938]	Batch Time 0.783 (0.777)	Data Load Time 0.000 (0.001)	Loss 3.5951 (3.5179)	Top-5 Accuracy 68.653 (70

Epochs:  50%|█████     | 5/10 [1:08:29<1:08:30, 822.09s/it]

Epoch: [5][0/938]	Batch Time 1.165 (1.165)	Data Load Time 0.253 (0.253)	Loss 3.2936 (3.2936)	Top-5 Accuracy 74.046 (74.046)
Epoch: [5][100/938]	Batch Time 0.738 (0.779)	Data Load Time 0.001 (0.003)	Loss 3.3916 (3.3968)	Top-5 Accuracy 69.565 (72.163)
Epoch: [5][200/938]	Batch Time 0.791 (0.776)	Data Load Time 0.000 (0.002)	Loss 3.4732 (3.3998)	Top-5 Accuracy 69.975 (72.117)
Epoch: [5][300/938]	Batch Time 0.791 (0.775)	Data Load Time 0.000 (0.001)	Loss 3.3504 (3.4066)	Top-5 Accuracy 71.638 (71.998)
Epoch: [5][400/938]	Batch Time 0.791 (0.775)	Data Load Time 0.000 (0.001)	Loss 3.5559 (3.4037)	Top-5 Accuracy 69.797 (72.048)
Epoch: [5][500/938]	Batch Time 0.767 (0.774)	Data Load Time 0.000 (0.001)	Loss 3.4585 (3.4025)	Top-5 Accuracy 71.391 (72.089)
Epoch: [5][600/938]	Batch Time 0.719 (0.774)	Data Load Time 0.000 (0.001)	Loss 3.3505 (3.4060)	Top-5 Accuracy 72.393 (72.017)
Epoch: [5][700/938]	Batch Time 0.761 (0.774)	Data Load Time 0.000 (0.001)	Loss 3.3175 (3.4040)	Top-5 Accuracy 73.829 (72

Epochs:  60%|██████    | 6/10 [1:22:09<54:46, 821.62s/it]  

Epoch: [6][0/938]	Batch Time 1.092 (1.092)	Data Load Time 0.189 (0.189)	Loss 3.1084 (3.1084)	Top-5 Accuracy 75.556 (75.556)
Epoch: [6][100/938]	Batch Time 0.767 (0.775)	Data Load Time 0.000 (0.002)	Loss 3.4841 (3.2899)	Top-5 Accuracy 69.048 (73.795)
Epoch: [6][200/938]	Batch Time 0.793 (0.776)	Data Load Time 0.002 (0.001)	Loss 3.5419 (3.2849)	Top-5 Accuracy 69.330 (73.840)
Epoch: [6][300/938]	Batch Time 0.774 (0.776)	Data Load Time 0.000 (0.001)	Loss 3.2572 (3.2915)	Top-5 Accuracy 75.067 (73.667)
Epoch: [6][400/938]	Batch Time 0.756 (0.775)	Data Load Time 0.000 (0.001)	Loss 3.1360 (3.2907)	Top-5 Accuracy 72.652 (73.692)
Epoch: [6][500/938]	Batch Time 0.739 (0.775)	Data Load Time 0.000 (0.001)	Loss 3.4010 (3.3011)	Top-5 Accuracy 71.225 (73.534)
Epoch: [6][600/938]	Batch Time 0.779 (0.774)	Data Load Time 0.000 (0.001)	Loss 3.5580 (3.2998)	Top-5 Accuracy 68.146 (73.585)
Epoch: [6][700/938]	Batch Time 0.768 (0.774)	Data Load Time 0.000 (0.001)	Loss 3.1819 (3.3018)	Top-5 Accuracy 74.271 (73

Epochs:  70%|███████   | 7/10 [1:35:50<41:03, 821.25s/it]

Epoch: [7][0/938]	Batch Time 1.142 (1.142)	Data Load Time 0.232 (0.232)	Loss 3.0755 (3.0755)	Top-5 Accuracy 75.641 (75.641)
Epoch: [7][100/938]	Batch Time 0.793 (0.783)	Data Load Time 0.000 (0.003)	Loss 3.2048 (3.1910)	Top-5 Accuracy 74.169 (75.065)
Epoch: [7][200/938]	Batch Time 0.770 (0.777)	Data Load Time 0.000 (0.002)	Loss 3.4207 (3.1917)	Top-5 Accuracy 71.129 (75.056)
Epoch: [7][300/938]	Batch Time 0.773 (0.776)	Data Load Time 0.000 (0.001)	Loss 3.3171 (3.2002)	Top-5 Accuracy 72.533 (74.996)
Epoch: [7][400/938]	Batch Time 0.745 (0.776)	Data Load Time 0.000 (0.001)	Loss 3.2309 (3.2042)	Top-5 Accuracy 76.438 (75.001)
Epoch: [7][500/938]	Batch Time 0.756 (0.776)	Data Load Time 0.000 (0.001)	Loss 3.2991 (3.2069)	Top-5 Accuracy 73.529 (74.984)
Epoch: [7][600/938]	Batch Time 0.738 (0.775)	Data Load Time 0.000 (0.001)	Loss 3.0519 (3.2113)	Top-5 Accuracy 76.705 (74.878)
Epoch: [7][700/938]	Batch Time 0.748 (0.774)	Data Load Time 0.000 (0.001)	Loss 3.2480 (3.2126)	Top-5 Accuracy 76.177 (74

Epochs:  80%|████████  | 8/10 [1:49:31<27:22, 821.39s/it]

Epoch: [8][0/938]	Batch Time 1.103 (1.103)	Data Load Time 0.190 (0.190)	Loss 3.1718 (3.1718)	Top-5 Accuracy 73.629 (73.629)
Epoch: [8][100/938]	Batch Time 0.754 (0.775)	Data Load Time 0.000 (0.002)	Loss 3.2382 (3.0813)	Top-5 Accuracy 74.731 (76.981)
Epoch: [8][200/938]	Batch Time 0.755 (0.772)	Data Load Time 0.000 (0.001)	Loss 3.1947 (3.1080)	Top-5 Accuracy 76.045 (76.567)
Epoch: [8][300/938]	Batch Time 0.751 (0.773)	Data Load Time 0.000 (0.001)	Loss 3.2882 (3.1152)	Top-5 Accuracy 73.596 (76.442)
Epoch: [8][400/938]	Batch Time 0.820 (0.773)	Data Load Time 0.000 (0.001)	Loss 3.0245 (3.1214)	Top-5 Accuracy 76.812 (76.300)
Epoch: [8][500/938]	Batch Time 0.766 (0.773)	Data Load Time 0.000 (0.001)	Loss 3.1636 (3.1247)	Top-5 Accuracy 74.400 (76.231)
Epoch: [8][600/938]	Batch Time 0.771 (0.773)	Data Load Time 0.000 (0.001)	Loss 3.0861 (3.1312)	Top-5 Accuracy 80.851 (76.145)
Epoch: [8][700/938]	Batch Time 0.754 (0.773)	Data Load Time 0.000 (0.001)	Loss 3.2985 (3.1363)	Top-5 Accuracy 75.698 (76

Epochs:  90%|█████████ | 9/10 [2:03:12<13:41, 821.18s/it]

Epoch: [9][0/938]	Batch Time 1.125 (1.125)	Data Load Time 0.234 (0.234)	Loss 3.0588 (3.0588)	Top-5 Accuracy 76.338 (76.338)
Epoch: [9][100/938]	Batch Time 0.779 (0.775)	Data Load Time 0.000 (0.003)	Loss 3.0365 (3.0246)	Top-5 Accuracy 77.719 (77.860)
Epoch: [9][200/938]	Batch Time 0.786 (0.772)	Data Load Time 0.000 (0.002)	Loss 3.0576 (3.0254)	Top-5 Accuracy 78.646 (77.805)
Epoch: [9][300/938]	Batch Time 0.757 (0.771)	Data Load Time 0.000 (0.001)	Loss 3.0208 (3.0387)	Top-5 Accuracy 76.554 (77.609)
Epoch: [9][400/938]	Batch Time 0.779 (0.772)	Data Load Time 0.000 (0.001)	Loss 3.1705 (3.0414)	Top-5 Accuracy 76.517 (77.517)
Epoch: [9][500/938]	Batch Time 0.777 (0.773)	Data Load Time 0.000 (0.001)	Loss 3.0602 (3.0491)	Top-5 Accuracy 77.979 (77.408)
Epoch: [9][600/938]	Batch Time 0.767 (0.773)	Data Load Time 0.000 (0.001)	Loss 3.0656 (3.0535)	Top-5 Accuracy 77.454 (77.337)
Epoch: [9][700/938]	Batch Time 0.779 (0.774)	Data Load Time 0.000 (0.001)	Loss 3.1444 (3.0585)	Top-5 Accuracy 77.151 (77

Epochs: 100%|██████████| 10/10 [2:16:53<00:00, 821.35s/it]
