In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import seaborn as sns
import os
import spacy
import pandas as pd
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import bcolz 
import pickle
spacy_eng = spacy.load("en")
from tqdm import tqdm

In [2]:
def extract_data():
    dictionary = {}
    filenames = open('Dataset1/image_names.txt').read().strip().split('\n')
    for f in filenames:
        dictionary[f] = []
    captions = open('Dataset1/captions.txt').read().strip().split('\n')
    capt_map = [cap.strip().split('\t') for cap in captions]
    for i in range(len(capt_map)):
        capt_map[i][0] = capt_map[i][0][:-2]
        if capt_map[i][0] in dictionary:
            dictionary[capt_map[i][0]].append(capt_map[i][1])
    df = pd.DataFrame(dictionary.items(), columns = ['IMAGE', 'CAPTIONS'])
    df2 = pd.DataFrame(capt_map, columns =['IMAGE', 'CAPTION']) 
    df.to_csv('dataset_neat.csv',index = False)
    df2.to_csv('dataset_raw.csv',index = False)

In [3]:
extract_data()
df = pd.read_csv('dataset_neat.csv')
df2 = pd.read_csv('dataset_raw.csv')

In [4]:
class Specialised_Dataset(Dataset):
    def __init__(self,df,min_freq=2):
        
        self.min_freq = min_freq
        self.word2ind = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.ind2word = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.word2count = {"<PAD>": 0, "<SOS>": 0, "<EOS>": 0, "<UNK>": 0}  
        self.num_words = 4
        
        self.image_names = df['IMAGE']
        self.image_captions = df['CAPTION']
        self.caption_list = self.image_captions.to_list()
        self.build_vocabulary(self.caption_list)

        self.glove = self.create_embedding_map();
        
        
    
    def create_embedding_map(self):
        vectors = bcolz.open('6B.200.dat')[:]
        words = pickle.load(open('6B.200_words.pkl', 'rb'))
        word2idx = pickle.load(open('6B.200_idx.pkl', 'rb'))
        return {w: vectors[word2idx[w]] for w in words}
    
    
    def build_vocabulary(self,text_list):
        next_index = 4
        for sentence in text_list:
            for word in [tok.text.lower() for tok in spacy_eng.tokenizer(sentence)]:
                if word not in self.word2count:
                    self.word2count[word]=1
                else:
                    self.word2count[word]+=1
                if(self.word2count[word] == self.min_freq):
                    self.word2ind[word] = next_index
                    self.ind2word[next_index] = word
                    next_index+=1
        self.num_words = next_index
    
    def get_gLoVe_form(self,word): 
        if word in self.glove:
            return torch.tensor(self.glove[word])
        else:
            return torch.tensor(np.random.normal(scale=0.6, size=(200, )))

    def generate_gLoVe_matrix(self):
        gloVe_list = []
        for i in range(self.num_words):
            word = self.ind2word[i]
            gloVe_list.append(self.get_gLoVe_form(word))
        return torch.stack(gloVe_list)
    
        
    def numericalize(self,text):
        tokens = [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
        return [
            self.word2ind[token] if token in self.word2ind else self.word2ind["<UNK>"]
            for token in tokens
        ]
    
    def __len__(self):
        return self.num_words
    
    def __getitem__(self,index):
        caption = self.caption_list[index]
        image_file = self.image_names[index]
        #print(Image)
        image = Image.open('Images/'+image_file).convert("RGB")
        #print(image)
        transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(),])
        image = transform(image)
        
        numericalized_caption = [self.word2ind["<SOS>"]]
        numericalized_caption += self.numericalize(caption)
        numericalized_caption.append(self.word2ind["<EOS>"])
        
        return image,torch.tensor(numericalized_caption)

In [5]:
class CustomPadderFunction:
    def __init__(self, padding):
        self.padding = padding

    def __call__(self, batch):
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)
        targets = [item[1] for item in batch]
        targets = pad_sequence(targets, batch_first=False, padding_value=self.padding)
        return imgs, targets

In [6]:
Data = Specialised_Dataset(df2)
padding_index = Data.word2ind["<PAD>"]
#Hyperparams
batch_size = 10
shuffle = True
loader = DataLoader(dataset=Data,batch_size=batch_size,shuffle=shuffle,collate_fn=CustomPadderFunction(padding_index),)

In [16]:
class Convolutional_Encoder(nn.Module):
    def __init__(self,first_layer_depth,second_layer_depth,num_clusters,alpha=1.0,normalize_input=True):
        super().__init__()
        self.depth1 = first_layer_depth
        self.depth2 = second_layer_depth
        self.conv1 = nn.Conv2d(3,self.depth1,3,1)
        self.conv2 = nn.Conv2d(self.depth1,self.depth2,3,1)
        
        self.conv1.weight.data.fill_(0.1)
        self.conv2.weight.data.fill_(0.1)
        
        self.num_clusters = num_clusters
        self.alpha = alpha
        self.normalize_input = normalize_input
        
        self.netvlad_layer = nn.Conv2d(self.depth2, num_clusters, kernel_size=(1, 1), bias=True)
        self.centroids = nn.Parameter(torch.rand(self.num_clusters, self.depth2))
        
        self.netvlad_layer.weight = nn.Parameter(
            (2.0 * self.alpha * self.centroids).unsqueeze(-1).unsqueeze(-1)
        )
        self.netvlad_layer.bias = nn.Parameter(
            - self.alpha * self.centroids.norm(dim=1)
        )
        
    def forward(self,x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x,2,2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x,2,2)
        
        print('XOUT',x)
        
        N,C =  x.shape[:2]
        if self.normalize_input:
            x = nn.functional.normalize(x, p=2, dim=1)
            
        soft_assign = self.netvlad_layer(x).view(N, self.num_clusters, -1)
        soft_assign = nn.functional.softmax(soft_assign, dim=1)
        x_flatten = x.view(N, C, -1)
        residual = x_flatten.expand(self.num_clusters, -1, -1, -1).permute(1, 0, 2, 3) - \
            self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)
        residual *= soft_assign.unsqueeze(2)
        vlad = residual.sum(dim=-1)

        vlad = nn.functional.normalize(vlad, p=2, dim=2)  # intra-normalization
        vlad = vlad.view(x.size(0), -1)  # flatten
        vlad = nn.functional.normalize(vlad, p=2, dim=1)  # L2 normalize
        
        return vlad

In [17]:
class RNN_Decoder(nn.Module):
    def __init__(self,embedding_size,hidden_size, input_size):
        super().__init__()
        self.embedding = nn.Embedding(input_size,embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size, 1)
        self.fc = nn.Linear(hidden_size, input_size)
        
        for par in self.embedding.parameters():
            par.requires_grad = False
        
        
    def forward(self,image_features, captions):
        gLoVe_emb = self.embedding(captions)
        embeddings = torch.cat((image_features.unsqueeze(0), gLoVe_emb), dim=0)     
        hiddens, _ = self.lstm(embeddings)
        outputs = self.fc(hiddens)
        return outputs

In [18]:
class CNN_RNN_Network(nn.Module):
    def __init__(self, embed_size, hidden_size, input_size):
        super().__init__()
        self.encoder = Convolutional_Encoder(30,10,20)
        self.decoder = RNN_Decoder(embed_size,hidden_size,input_size)
    
    def forward(self,images,captions):
        features = self.encoder(images)
        #print(features.shape)
        outputs = self.decoder(features, captions)
        return outputs
    
    def generate_caption(self, image, dictionary, max_length = 30):
        caption = []
        
        with torch.no_grad():
            x = self.encoder(image).unsqueeze(0)
            #print(x)
            #states = None
            #for _ in range(max_length):
            #    #print('states',states)
            #    #print('Input',x)
            #    hiddens, states = self.decoder.lstm(x, states)
            ##    output = self.decoder.fc(hiddens.squeeze(0))
            #    pred = output.argmax(1)
            #    caption.append(pred.item())
            #    x = self.decoder.embedding(pred).unsqueeze(0)
            #    if(dictionary[pred.item()] == "<EOS>"):
            #        break
        
        return x
        #return [dictionary[idx] for idx in caption]

In [19]:
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
learning_rate = 0.1
num_epochs = 100

In [20]:
Net = CNN_RNN_Network(200,800,Data.num_words).to(device)
Net.decoder.embedding.weight.data.copy_(Data.generate_gLoVe_matrix().to(device))

tensor([[ 0.6850,  1.3936,  0.5608,  ...,  0.2518, -1.1372, -0.2139],
        [-0.1817,  0.0833, -0.1953,  ...,  0.8463,  0.6671, -0.4662],
        [ 0.4585,  0.7632,  0.4686,  ..., -0.2959, -0.3993, -0.6911],
        ...,
        [-0.2696,  0.2943, -0.5669,  ..., -0.7939, -0.4814, -0.1525],
        [-0.4282,  0.3101, -0.1896,  ...,  0.1711, -0.3179, -0.8590],
        [ 0.0837,  0.4184,  0.0520,  ..., -0.0124, -0.0672,  0.0180]],
       device='cuda:0')

In [21]:
files = ['test1.jpg','test2.jpg','test3.jpg','test4.jpg','test5.jpg']
def test():
    transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(),])
    for f in files:
        test_img = transform(Image.open(f).convert("RGB")).unsqueeze(0)
        Net.generate_caption(test_img.to(device), Data.ind2word)
        #print(test_img)
        print("Example 1 OUTPUT: "+ " ".join(Net.generate_caption(test_img.to(device), Data.ind2word)))

In [13]:

criterion = nn.CrossEntropyLoss(ignore_index=Data.word2ind["<PAD>"])
optimizer = torch.optim.Adam(Net.parameters(), lr=learning_rate)

for e in range(num_epochs):
    epoch_loss = 0
    for idx, (imgs, captions) in tqdm(enumerate(loader), total=len(loader), leave=False):
        imgs = imgs.to(device)
        captions = captions.to(device)
        outputs = Net(imgs, captions[:-1])
        loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))
        optimizer.zero_grad()
        epoch_loss += loss.item()
        loss.backward(loss)
        optimizer.step()
    test()
    print(epoch_loss)
    #print(Net.parameters())

  0%|          | 2/517 [00:00<01:27,  5.86it/s]

tensor([[[[46.7839, 46.7333, 46.7663,  ..., 46.0651, 45.9816, 45.9227],
          [47.5992, 47.0710, 46.7533,  ..., 46.1345, 46.0721, 46.0639],
          [53.3086, 52.0286, 50.9886,  ..., 46.2545, 46.2321, 46.1345],
          ...,
          [31.6427, 31.6439, 31.5698,  ..., 31.1510, 30.8333, 30.3874],
          [30.7616, 30.6721, 30.3392,  ..., 29.2333, 29.0804, 29.3204],
          [30.0721, 30.0098, 29.6333,  ..., 29.3439, 29.8416, 30.5004]],

         [[46.7009, 46.6503, 46.6832,  ..., 45.9820, 45.8985, 45.8397],
          [47.5161, 46.9879, 46.6703,  ..., 46.0515, 45.9891, 45.9808],
          [53.2256, 51.9456, 50.9056,  ..., 46.1714, 46.1491, 46.0514],
          ...,
          [31.5597, 31.5609, 31.4867,  ..., 31.0679, 30.7503, 30.3044],
          [30.6785, 30.5891, 30.2562,  ..., 29.1503, 28.9973, 29.2373],
          [29.9891, 29.9267, 29.5503,  ..., 29.2609, 29.7585, 30.4173]],

         [[46.7438, 46.6932, 46.7261,  ..., 46.0250, 45.9414, 45.8826],
          [47.5591, 47.0308, 4

  1%|          | 4/517 [00:00<01:14,  6.87it/s]

tensor([[[[ 79.8479,  81.0595,  82.0020,  ...,  74.2340,  73.7117,  72.8570],
          [ 81.3372,  82.4960,  83.3104,  ...,  75.3832,  74.9951,  74.4942],
          [ 82.7828,  83.8281,  84.3964,  ...,  76.4662,  76.1341,  75.7593],
          ...,
          [ 50.0972,  48.1363,  49.8265,  ...,  73.6613,  76.0138,  78.0457],
          [ 37.9003,  43.0924,  43.0786,  ...,  74.9698,  70.8300,  67.6509],
          [ 42.4154,  45.5787,  44.5348,  ...,  67.9991,  62.3145,  58.4417]],

         [[326.5729, 331.5247, 335.3774,  ..., 303.6276, 301.4930, 298.0014],
          [332.6607, 337.3976, 340.7243,  ..., 308.3240, 306.7382, 304.6923],
          [338.5678, 342.8396, 345.1629,  ..., 312.7503, 311.3931, 309.8616],
          ...,
          [204.9155, 196.9099, 203.8044,  ..., 301.2950, 310.8863, 319.1758],
          [155.1642, 176.3722, 176.3124,  ..., 306.6331, 289.6865, 276.6510],
          [173.5264, 186.4922, 182.2363,  ..., 278.0844, 254.8792, 239.0671]],

         [[318.3563, 323.1825,

  1%|          | 6/517 [00:00<01:09,  7.38it/s]

tensor([[[[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[534.8586, 506.8801, 403.0543,  ..., 388.8328, 393.6372, 402.2244],
          [568.6816, 532.6161, 355.6373,  ..., 292.7814, 425.6642, 464.4602],
          [436.9697, 440.2132, 305.3352,  ..., 303.9993, 422.5151, 453.5696],
          ...,
          [152.9193, 195.1150, 215.2498,  ..., 458.3470, 387.4246, 227.0489],
          [273.9245, 360.6491, 352.1894,  ..., 434.1558, 361.3039, 258.7000],
          [360.2505, 461.9155, 441.6761,  ..., 395.1671, 348.4297, 281.3212]],

         [[223.4811, 211.7939,

  1%|▏         | 7/517 [00:00<01:07,  7.54it/s]

tensor([[[[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[484.5269, 478.9976, 449.2804,  ..., 407.6689, 406.5086, 405.3383],
          [576.2620, 521.3083, 453.8225,  ..., 416.8549, 413.1987, 410.9923],
          [600.7654, 548.3375, 475.6035,  ..., 442.9283, 427.1281, 420.1806],
          ...,
          [722.2838, 725.1492, 725.9605,  ..., 729.1105, 730.3602, 729.9655],
          [725.3828, 726.3331, 731.3735,  ..., 727.3298, 727.7751, 727.7480],
          [726.2138, 725.5165, 730.0369,  ..., 725.0592, 723.3698, 724.2482]],

         [[ 87.9257,  86.9191,

  2%|▏         | 10/517 [00:01<01:00,  8.34it/s]

tensor([[[[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[930.2390, 927.6406, 929.4360,  ..., 911.9596, 926.8910, 937.6786],
          [933.1984, 931.0643, 933.5697,  ..., 916.8058, 926.5967, 928.2999],
          [938.6121, 937.5914, 937.2936,  ..., 913.2525, 926.8073, 926.3621],
          ...,
          [732.8148, 733.6430, 720.6847,  ..., 597.3179, 467.8056, 344.8335],
          [738.2285, 740.3475, 748.4113,  ..., 597.0654, 593.3212, 449.9055],
          [705.4380, 695.7717, 689.3499,  ..., 570.8735, 578.9266, 459.0474]],

         [[ 36.3126,  36.2100,

  2%|▏         | 12/517 [00:01<00:58,  8.58it/s]

tensor([[[[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

         [[956.8022, 966.6373, 968.3981,  ..., 977.7172, 974.4448, 971.8544],
          [967.5309, 969.2415, 971.1315,  ..., 909.6108, 963.8970, 971.2451],
          [975.4237, 972.0055, 974.3854,  ..., 407.7240, 648.0964, 894.9727],
          ...,
          [218.4301, 267.7963, 314.6507,  ..., 597.0206, 564.9747, 592.9345],
          [342.7336, 349.3069, 447.4215,  ..., 600.3127, 581.0516, 589.0652],
          [405.8161, 425.2846, 513.8604,  ..., 587.0604, 565.6645, 584.6022]],

         [[  0.0000,   0.0000,

  3%|▎         | 15/517 [00:01<00:51,  9.67it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[1233.3088, 1225.9091, 1184.0018,  ..., 1256.1097, 1242.0638,
           1062.1842],
          [1010.2872, 1017.1023,  904.1022,  ..., 1107.8931, 1055.1971,
            792.9312],
          [ 942.4503,  725.7240,  496.9242,  ...,  777.8467,  664.0216,
            640.7065],
          ...,
          [ 325.9172,  326.5590,  325.2694,  ...,  409.9974,  402.7468,
            388.9811],
          [ 336.62

  3%|▎         | 17/517 [00:01<00:51,  9.72it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 344.2733,  345.0247,  347.6792,  ...,  469.9708,  471.6978,
            472.6022],
          [ 353.3908,  353.8494,  356.7508,  ...,  478.7924,  479.4730,
            481.0864],
          [ 361.8504,  361.7742,  365.6286,  ...,  485.3956,  486.9050,
            487.8346],
          ...,
          [ 699.4207,  717.0732,  740.0394,  ..., 1035.4080, 1035.4835,
           1031.8942],
          [ 729.63

  4%|▍         | 20/517 [00:02<00:49,  9.99it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 343.0610,  335.4399,  397.4086,  ...,  193.6290,  185.5818,
            178.7377],
          [ 259.3180,  312.5072,  398.6923,  ...,  189.1300,  190.8106,
            182.3400],
          [ 258.4603,  323.6714,  411.9320,  ...,  196.0940,  198.9613,
            197.0969],
          ...,
          [ 206.7467,  269.6044,  237.7965,  ..., 1433.9839, 1431.4076,
           1431.2736],
          [ 136.55

  4%|▍         | 22/517 [00:02<00:47, 10.36it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[1540.8553, 1443.3297, 1379.7015,  ..., 1785.5465, 1810.3162,
           1810.4875],
          [1549.7061, 1457.2196, 1385.2971,  ..., 1473.2404, 1739.1210,
           1808.9413],
          [1136.7070, 1117.0642, 1155.9287,  ..., 1085.5244, 1350.1235,
           1685.4611],
          ...,
          [ 481.3944,  427.6414,  383.2349,  ...,  283.3112,  244.6583,
            219.1523],
          [ 438.24

  5%|▌         | 26/517 [00:02<00:44, 11.01it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[1398.0636, 1392.4631, 1388.8287,  ...,  857.2170,  857.7365,
            870.6469],
          [1416.5470, 1409.7828, 1405.3344,  ...,  873.5681,  878.8365,
            891.0326],
          [1440.4465, 1432.6072, 1426.8734,  ...,  889.5920,  895.7975,
            898.1554],
          ...,
          [1654.6339, 1519.3051, 1582.7863,  ...,  309.9922,  357.9934,
            310.2112],
          [1613.97

  5%|▌         | 28/517 [00:02<00:43, 11.24it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[1928.4351, 1927.3759, 1926.6010,  ..., 1741.9360, 1737.1727,
           1731.5372],
          [1931.7203, 1929.4393, 1927.8279,  ..., 1744.7410, 1738.0065,
           1733.9482],
          [1930.7904, 1928.7639, 1926.2506,  ..., 1744.5142, 1738.5168,
           1736.7438],
          ...,
          [1304.3301, 1302.3633, 1298.2727,  ...,  960.4331,  959.2361,
            955.1515],
          [1241.91

  6%|▌         | 30/517 [00:03<00:43, 11.30it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 206.0940,  272.5124,  498.4057,  ...,  276.6037,  280.8809,
            273.5795],
          [ 208.0915,  278.2290,  518.3284,  ...,  276.0393,  281.3122,
            272.8057],
          [ 209.2982,  244.8399,  494.8116,  ...,  281.8203,  283.8585,
            275.8998],
          ...,
          [ 399.7170,  458.0662,  507.0143,  ...,  701.1400,  662.9056,
            667.0590],
          [ 335.10

  7%|▋         | 34/517 [00:03<00:45, 10.71it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[1074.0316, 1009.7103,  975.0651,  ...,  393.9131,  323.6444,
            243.3629],
          [1229.2223, 1155.2743, 1023.9580,  ...,  408.9867,  345.8232,
            261.1426],
          [1192.5210, 1195.8053, 1068.5736,  ...,  572.5364,  467.1155,
            436.5665],
          ...,
          [1183.7727, 1315.4298, 1378.3829,  ..., 1620.9104, 1569.5867,
           1370.8624],
          [ 617.11

  7%|▋         | 36/517 [00:03<00:46, 10.24it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 891.0751,  882.7134,  878.8534,  ..., 1005.0031,  858.6639,
            799.1205],
          [ 895.9348,  890.9213,  884.4321,  ...,  957.8354,  910.9810,
            831.3495],
          [ 906.8182,  900.9456,  892.6049,  ...,  903.6516,  855.0884,
            733.6610],
          ...,
          [2193.4297, 2189.2847, 2185.1140,  ..., 1354.3284,  962.2811,
            858.7110],
          [2194.06

  7%|▋         | 38/517 [00:03<00:46, 10.22it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[  67.7802,   60.2007,   59.4132,  ...,  449.5602,  202.6916,
            191.6806],
          [ 105.8137,   95.9498,   88.5305,  ...,  411.2064,  193.7702,
            185.5276],
          [ 175.6891,  177.5360,  176.4031,  ...,  378.3857,  187.8882,
            174.9231],
          ...,
          [ 317.3306,  315.3916,  323.1647,  ...,   62.1467,   60.9751,
             62.7690],
          [ 320.84

  8%|▊         | 42/517 [00:04<00:45, 10.55it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 384.0815,  377.9288,  378.6100,  ...,  893.0887,  639.7296,
            488.4915],
          [ 362.4959,  349.4034,  348.9276,  ...,  816.4844,  592.7648,
            472.3441],
          [ 332.4545,  337.8174,  340.3260,  ..., 1060.6410,  700.1093,
            492.6030],
          ...,
          [1228.5576, 1237.5825, 1251.8223,  ..., 1567.4059, 1554.4983,
           1541.5504],
          [1225.65

  9%|▊         | 44/517 [00:04<00:44, 10.52it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[1693.0192, 1777.6537, 1813.5288,  ..., 1746.4352, 1732.6702,
           1722.0677],
          [1589.2739, 1675.0767, 1751.0752,  ..., 1760.7823, 1751.0186,
           1751.0518],
          [1609.4414, 1592.8165, 1647.0564,  ..., 1786.8988, 1782.2263,
           1772.5670],
          ...,
          [1149.7977, 1156.6630, 1154.8218,  ..., 1697.2734, 1686.6238,
           1675.5391],
          [1153.53

  9%|▉         | 46/517 [00:04<00:44, 10.54it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[1493.8661, 1501.0533, 1518.4562,  ..., 1277.6670, 1272.1926,
           1263.0192],
          [1472.1801, 1488.3081, 1498.1395,  ..., 1280.3278, 1276.7725,
           1271.1702],
          [1471.9550, 1522.0476, 1511.4795,  ..., 1368.8369, 1375.1500,
           1380.9154],
          ...,
          [1115.5997, 1093.3568, 1042.5223,  ..., 1635.3368, 1637.1796,
           1684.0104],
          [1088.57

 10%|▉         | 50/517 [00:05<00:43, 10.66it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 567.0007,  751.5171,  534.2816,  ...,  415.4544,  561.0919,
            428.6677],
          [ 537.8743,  715.4865,  517.2421,  ...,  511.7489,  560.8432,
            369.1427],
          [ 551.6310,  722.4313,  511.6402,  ...,  577.5707,  548.0916,
            475.0262],
          ...,
          [1852.6357, 1868.4113, 1870.2532,  ..., 1847.8876, 1853.2393,
           1869.8029],
          [1911.20

 10%|█         | 52/517 [00:05<00:45, 10.31it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 354.0034,  297.5412,  340.3116,  ..., 1223.1113, 1184.6737,
           1100.2219],
          [ 286.0974,  275.1077,  285.9328,  ..., 1362.3953, 1405.7023,
           1240.4834],
          [ 226.4556,  258.2952,  280.5315,  ..., 1490.0852, 1349.8027,
           1225.9806],
          ...,
          [ 825.0243,  840.1982,  790.1258,  ..., 1645.0813, 1646.9458,
           1611.1736],
          [ 787.13

 11%|█         | 56/517 [00:05<00:44, 10.34it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[2040.7214, 2041.8586, 2036.0598,  ..., 1936.7871, 1849.4181,
           1863.1545],
          [2034.8422, 2042.4509, 2034.8472,  ..., 1923.1133, 1839.4019,
           1841.8391],
          [2035.9836, 2035.7859, 2022.7471,  ..., 1906.2046, 1825.3167,
           1833.0043],
          ...,
          [ 256.1844,  262.9476,  274.5112,  ...,  214.4607,  211.4156,
            204.8083],
          [ 252.77

 12%|█▏        | 60/517 [00:06<00:43, 10.41it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 997.0959, 1003.3498, 1030.6470,  ..., 1187.5677, 1196.3156,
           1208.2096],
          [1162.4459, 1161.9120, 1148.7142,  ..., 1210.0006, 1206.8605,
           1212.1428],
          [1172.8975, 1172.8015, 1194.5262,  ..., 1198.4276, 1190.3864,
           1217.4905],
          ...,
          [ 994.9715,  937.8797,  895.4340,  ...,  952.4808,  956.7768,
            936.6061],
          [ 756.04

 12%|█▏        | 62/517 [00:06<00:44, 10.19it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[1124.4541, 1300.0337, 1395.9972,  ..., 1859.4291, 1474.5205,
           1061.1068],
          [1242.6914, 1307.8778, 1370.4602,  ..., 1857.9109, 1547.2871,
           1283.9479],
          [1337.8831, 1184.5327, 1078.2130,  ..., 1637.4576, 1507.6056,
           1518.8301],
          ...,
          [1310.9269, 1231.0730,  998.1948,  ...,  725.2395,  810.5607,
            830.9194],
          [1650.96

 13%|█▎        | 66/517 [00:06<00:43, 10.30it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 394.5795,  464.2731,  486.5869,  ...,  583.0881,  588.7965,
            901.9837],
          [ 441.3425,  443.1500,  422.2701,  ..., 1100.4276, 1063.6833,
            950.2322],
          [ 522.3204,  497.7282,  476.6920,  ..., 1183.6605, 1155.5398,
           1067.2495],
          ...,
          [1932.3958, 1338.4612,  777.4719,  ..., 1610.6279, 1626.5129,
           1640.1567],
          [1993.53

 13%|█▎        | 68/517 [00:06<00:42, 10.53it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[1193.7874, 1271.9547, 1391.7145,  ..., 1739.8248, 1613.0964,
           1649.7400],
          [1192.3348, 1307.7821, 1337.4307,  ..., 1864.2435, 1799.4568,
           1598.4240],
          [1271.2227, 1315.3860, 1320.1349,  ..., 1843.6851, 1813.9843,
           1707.4069],
          ...,
          [1703.9822, 1846.1161, 1893.3168,  ..., 1213.5020, 1204.0338,
           1144.3683],
          [1803.32

 14%|█▍        | 72/517 [00:07<00:41, 10.83it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 118.2172,   58.9910,   58.9910,  ..., 1847.0236, 1842.4543,
           1833.3402],
          [ 434.5902,   95.3723,   60.2601,  ..., 1873.0131, 1865.4943,
           1855.7256],
          [1033.5891,  424.7715,   99.1433,  ..., 1877.1534, 1867.6787,
           1861.9822],
          ...,
          [ 535.3133,  608.7954,  732.7943,  ..., 1198.2816, 1536.4733,
           1603.4891],
          [ 499.68

 14%|█▍        | 74/517 [00:07<00:40, 10.92it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 460.0701,  454.3695,  511.2523,  ...,  341.8026,  440.7187,
            410.8807],
          [ 478.9928,  452.2034,  462.1933,  ...,  380.7556,  441.7427,
            391.3853],
          [ 470.2381,  445.6492,  468.7524,  ...,  382.1997,  404.8680,
            358.7169],
          ...,
          [1482.5585, 1497.3423, 1513.0841,  ..., 1465.5458, 1449.7362,
           1428.2845],
          [1466.75

 15%|█▌        | 78/517 [00:07<00:39, 11.22it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 628.3632,  642.7060,  652.5885,  ...,  824.6552,  909.8527,
            930.0585],
          [ 638.1961,  638.4747,  653.8693,  ...,  887.9289,  980.5404,
            986.4010],
          [ 654.0853,  642.4449,  648.1242,  ...,  933.6024, 1039.4790,
           1045.8748],
          ...,
          [ 565.9937,  585.5607,  599.6006,  ..., 2316.0181, 2316.0181,
           2316.0181],
          [ 549.14

 15%|█▌        | 80/517 [00:07<00:39, 10.98it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[2260.2573, 2280.8582, 2277.2349,  ..., 1546.2554, 1554.3579,
           1314.9808],
          [1934.5995, 1995.3291, 2003.2913,  ..., 1703.5437, 1703.5530,
           1577.6628],
          [1757.1855, 1740.9169, 1648.9622,  ..., 1587.5879, 1583.4377,
           1620.4269],
          ...,
          [1382.3434, 1518.3683, 1233.4459,  ...,  916.7114, 1074.6807,
           1144.8887],
          [1288.34

 16%|█▌        | 84/517 [00:08<00:38, 11.27it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 806.6628,  814.7521,  818.3989,  ...,  757.7548,  546.6041,
            549.3721],
          [ 816.7142,  825.5634,  829.4670,  ...,  728.4682,  595.6277,
            558.4760],
          [ 826.8591,  833.8618,  838.1945,  ...,  649.5930,  601.6722,
            603.9188],
          ...,
          [2038.8275, 2067.5715, 2082.5513,  ..., 2085.5090, 2063.9875,
           2024.4130],
          [2029.58

 17%|█▋        | 86/517 [00:08<00:38, 11.31it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[1998.2970, 1988.3809, 1896.9501,  ..., 2114.8823, 2074.0754,
           2032.7878],
          [1947.8441, 1878.2842, 1859.2458,  ..., 1934.6245, 1921.7780,
           1854.0311],
          [1850.9137, 1787.9033, 1814.6433,  ..., 1677.5293, 1701.0786,
           1699.1927],
          ...,
          [1634.6914, 1635.6205, 1661.9413,  ..., 1422.9222, 1322.5570,
           1311.1284],
          [1623.82

 17%|█▋        | 90/517 [00:08<00:38, 11.01it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 698.4769,  712.2150,  683.0466,  ...,  480.3061,  438.2519,
            319.0505],
          [ 716.6053,  751.6976,  749.7188,  ...,  485.4153,  437.9123,
            385.7342],
          [ 723.9771,  811.7003,  871.7224,  ...,  535.9206,  525.3539,
            500.9443],
          ...,
          [ 743.1971,  763.0276,  769.0887,  ..., 1350.2211, 1444.3247,
           1483.4929],
          [ 726.42

 18%|█▊        | 92/517 [00:08<00:39, 10.82it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 317.1753,  268.3942,  187.6546,  ...,  328.2368,  240.8516,
            202.5377],
          [ 366.1490,  306.2257,  219.6722,  ...,  262.7205,  294.1484,
            285.6987],
          [ 404.7349,  397.8792,  277.9308,  ...,  324.2107,  451.0444,
            425.0350],
          ...,
          [2084.0808, 2080.9624, 2084.9485,  ..., 1853.1715, 1846.2128,
           1845.1935],
          [2135.92

 19%|█▊        | 96/517 [00:09<00:38, 10.91it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[1146.9359, 1162.4512, 1161.3442,  ..., 1847.0535, 1807.7773,
           1536.9292],
          [1223.6284, 1209.5875, 1146.5338,  ..., 1831.5946, 1476.8839,
            904.8354],
          [1212.7203, 1223.0052, 1157.7678,  ..., 1656.6674,  945.9913,
            665.4825],
          ...,
          [ 392.9653,  409.7436,  419.2237,  ...,  737.8057,  891.3768,
           1042.2462],
          [ 385.87

 19%|█▉        | 98/517 [00:09<00:39, 10.74it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[1042.9335,  916.9550,  827.3541,  ...,  824.1709,  902.1760,
           1105.6736],
          [1059.6245, 1000.3244,  896.0578,  ...,  812.3113,  803.5283,
            993.2057],
          [1011.2839,  987.3519,  910.3784,  ...,  881.4209,  756.6957,
            778.5094],
          ...,
          [ 958.9827, 1098.9823, 1146.7086,  ..., 1091.6270, 1153.9766,
           1141.6300],
          [ 948.74

 20%|█▉        | 102/517 [00:09<00:37, 10.94it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 316.5153,  336.7933,  334.5897,  ...,  558.1561,  395.7130,
            187.0245],
          [ 328.4366,  335.9090,  334.3690,  ...,  435.1104,  214.3756,
            160.1301],
          [ 357.0469,  335.8099,  332.1218,  ...,  310.1415,  195.7874,
            167.5865],
          ...,
          [1544.6311, 1683.9696, 1703.0209,  ..., 1730.3972, 1726.6910,
           1701.5300],
          [1600.09

 20%|██        | 104/517 [00:10<00:38, 10.61it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 192.4408,  143.7543,  162.5214,  ..., 1035.0591, 1029.0444,
           1064.2198],
          [ 264.0995,  253.2417,  239.0890,  ..., 1138.4625,  974.1572,
            945.0747],
          [ 521.1202,  515.2576,  496.0368,  ..., 1092.6030,  978.7849,
            911.3881],
          ...,
          [ 382.4500,  415.6150,  429.1219,  ...,  617.7682,  601.8311,
            705.9087],
          [ 373.55

 21%|██        | 108/517 [00:10<00:37, 10.94it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[2271.1355, 2271.9204, 2271.8496,  ..., 1969.6451, 1987.7449,
           2054.8398],
          [2275.1958, 2275.3743, 2275.3743,  ..., 2098.7305, 2196.1018,
           2236.0276],
          [2279.4507, 2278.5613, 2277.6621,  ..., 2199.9495, 2243.0911,
           2251.7896],
          ...,
          [ 783.5945,  809.4882,  803.6371,  ...,  862.9297,  782.0861,
            757.9264],
          [ 779.70

 21%|██▏       | 110/517 [00:10<00:36, 11.13it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 946.3475,  983.0334, 1214.4470,  ...,  182.1125,  171.0453,
            184.0864],
          [ 955.4636,  965.0853, 1211.9675,  ...,  166.5953,  191.0484,
            214.1080],
          [ 940.3630,  955.3439, 1215.1398,  ...,  169.1243,  227.0007,
            237.4424],
          ...,
          [1009.7834, 1036.6095, 1088.5540,  ..., 1210.2185, 1211.3932,
           1112.8046],
          [1046.51

 22%|██▏       | 114/517 [00:10<00:37, 10.87it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[1372.5164, 1373.5521, 1379.3492,  ...,  684.8991,  674.2502,
            662.9584],
          [1381.3557, 1404.5674, 1426.9609,  ...,  718.9985,  722.0116,
            695.3818],
          [1374.5276, 1405.8878, 1427.6945,  ...,  765.4999,  731.7361,
            702.1078],
          ...,
          [1452.4940, 1350.3342, 1696.4641,  ...,  389.6334,  457.7556,
            532.3132],
          [1436.19

 22%|██▏       | 116/517 [00:11<00:36, 10.99it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 464.6991,  547.7827,  481.1628,  ...,  948.8883, 1010.9937,
           1034.1923],
          [ 504.3790,  549.6101,  488.3733,  ..., 1068.3918, 1082.2806,
           1091.4425],
          [ 548.9428,  613.9152,  494.7053,  ..., 1073.2909, 1070.2461,
           1054.8844],
          ...,
          [ 635.8395,  655.4061,  656.8107,  ...,  400.2500,  388.4428,
            368.2956],
          [ 599.00

 23%|██▎       | 120/517 [00:11<00:35, 11.24it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[ 969.0964, 1057.5695,  998.1807,  ...,  564.2018,  517.9747,
            463.6769],
          [1139.8329, 1160.5111, 1174.7510,  ...,  544.8013,  518.3967,
            466.5053],
          [ 996.1653, 1112.7216, 1151.8230,  ...,  484.7108,  468.5808,
            414.7408],
          ...,
          [ 930.5153,  971.6259,  971.3088,  ..., 1199.2611, 1233.4087,
           1222.3608],
          [ 925.00

 24%|██▎       | 122/517 [00:11<00:34, 11.31it/s]

tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],

         [[1157.0194,  955.6046,  549.5120,  ...,  945.5104,  933.7246,
            861.3597],
          [ 561.2236,  406.8064,  507.5593,  ...,  836.4211,  852.9457,
            846.5664],
          [ 457.9375,  337.3545,  534.1287,  ...,  801.3091,  802.2013,
            805.3586],
          ...,
          [ 177.6912,  249.0210,  320.1911,  ...,  580.9341,  390.3537,
            306.3602],
          [ 224.17

                                                 

KeyboardInterrupt: 

In [22]:
#0 A boy wearing a red shirt and jeans is doing a flip on his bike.
#1 A person flipping a bicycle upside down.
#2 a person flips on a bike.
#3 A person in a red shirt doing tricks on a bicycle .
#4 A person is show upside down on his bicycle over a large field .

#0	A group of people are walking while holding numerous baskets and rugs .
#1	A man carrying many woven baskets next to a man carrying a stack of rugs on his back .
#2	Two individuals carrying numerous products to sell are walking on a beach .
#3	Two people carrying colorful baskets and blankets walk near a building .
#4	Two people walking laden down with baskets and blankets .

#0	A man in black is sitting next to a modern art structure in front of a glass building .
#1	A man sits and reads a newspaper by a sculpture outside of an office building .
#2	a man sits near a large statue .
#3	A man sitting in front of a metal sculpture in front of a building .
#4	The man with the backpack is sitting in a buildings courtyard in front of an art sculpture reading.

#0	A crowd watching air balloons at night .
#1	A group of hot air balloons lit up at night .
#2	People are watching hot air balloons in the park .
#3	People watching hot air balloons .
#4	Seven large balloons are lined up at nighttime near a crowd .

#0	A cyclist is riding a bicycle on a curved road up a hill .
#1	A man in aerodynamic gear riding a bicycle down a road around a sharp curve .
#2	A man on a mountain bike is pedaling up a hill .
#3	Man bicycle up a road , while cows graze on a hill nearby .
#4	The biker is riding around a curve in the road .

In [23]:
cnet = Convolutional_Encoder(30,10,20)
for img,cap in loader:
    print(cnet(img))
    break

XOUT tensor([[[[31.6706, 28.4788, 16.9788,  ..., 24.4377, 35.2012, 35.8153],
          [29.3412, 27.0988, 15.4224,  ..., 22.9988, 27.4235, 27.6530],
          [26.5824, 24.0600, 13.9447,  ..., 23.7471, 27.1730, 26.8083],
          ...,
          [40.3059, 22.5647, 14.9871,  ...,  9.1330,  9.4353,  9.6612],
          [48.2036, 42.0836, 35.5612,  ..., 16.2741, 14.5271, 14.1271],
          [52.5318, 52.9283, 49.4012,  ..., 18.6930, 18.1377, 15.2400]],

         [[31.6676, 28.4758, 16.9759,  ..., 24.4347, 35.1982, 35.8123],
          [29.3382, 27.0958, 15.4194,  ..., 22.9959, 27.4206, 27.6500],
          [26.5794, 24.0570, 13.9417,  ..., 23.7441, 27.1700, 26.8053],
          ...,
          [40.3029, 22.5617, 14.9841,  ...,  9.1300,  9.4323,  9.6582],
          [48.2006, 42.0806, 35.5582,  ..., 16.2712, 14.5241, 14.1241],
          [52.5288, 52.9253, 49.3982,  ..., 18.6900, 18.1347, 15.2370]],

         [[31.6626, 28.4708, 16.9708,  ..., 24.4297, 35.1932, 35.8073],
          [29.3332, 27.09