In [1]:
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models
import PIL.Image as Image
import torch.nn as nn
import torchvision
import torch.optim as optim
import sys
from captum.attr import GuidedGradCam
import cv2
from PIL import Image
from matplotlib.colors import LinearSegmentedColormap
from captum.attr import visualization as viz

sys.path.insert(0, '../src')
from bird_dataset import *
from XAI_birds_dataloader import *
from tqdm import tqdm
from models.multi_task_model import *
from XAI_birds_dataloader import *
from XAI_BirdAttribute_dataloader import *

from download import *

import shutil

In [2]:
# download = True
# id = '156fCp5_VvRfnyHSCBCXhInSE9TxEqDaY'
# destination = '../cub.zip'
# if download:
#     download_file_from_google_drive(id, destination)
# # shutil.unpack_archive('../cub.zip', '../CUB_200_2011')

In [3]:
# #hide
# from fastai.vision.all import *
# from fastai.text.all import *
# from fastai.collab import *
# from fastai.tabular.all import *

In [4]:
bd = BirdDataset(preload=True, attr_file='attributes')

In [5]:
## UNCOMMENT THIS LATER
vgg16 = models.vgg16_bn(pretrained=True)


In [6]:
# vgg16

In [7]:
# bd.images

In [8]:
# [bd.images[i]['attributes'] for i in list(bd.images.keys()) if 'has_bill_shape' in bd.images[i]['attributes']]

In [9]:
# bd.images

In [10]:
trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
train_bird_dataset = Bird_Attribute_Loader(bd, attrs=['has_bill_shape', 'has_breast_pattern'], species=True, transform=trans, train=True)
val_bird_dataset = Bird_Attribute_Loader(bd, attrs=['has_breast_pattern','has_bill_shape'], species=True, transform=trans, train=False, val=True)

'has_bill_shape__has_breast_pattern__species'

In [12]:
len(val_bird_dataset)

306

In [11]:
# bd.images

In [12]:
# train_bird_dataset.class_dict

In [13]:
train_bird_dataset[1]['image'].size()

torch.Size([3, 224, 224])

In [14]:
# train_bird_dataset[0]

In [15]:
model = MultiTaskModel(vgg16, train_bird_dataset)

In [16]:
# If there are GPUs, choose the first one for computing. Otherwise use CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# If 'cuda:0' is printed, it means GPU is available.

if torch.cuda.is_available():
    model.cuda()

batch_size = 1
trainloader = DataLoader(train_bird_dataset, batch_size=batch_size, shuffle=True)
valloader = DataLoader(val_bird_dataset, batch_size=batch_size, shuffle=True)

loss_func = MultiTaskLossWrapper().to(device)
opt = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

cuda:0


In [17]:
data = train_bird_dataset[0]
inputs, labels = data['image'].cuda(), torch.LongTensor(data['labels']).cuda()

In [18]:
# inputs.resize(1, -1)

In [19]:
type(inputs)

torch.Tensor

In [20]:
outputs =  model(inputs.reshape((1, 3, 224, 224)))

In [21]:
# inputs.size()

In [22]:
labels

tensor([ 6,  0, 16], device='cuda:0')

In [23]:
loss_func(outputs, labels)

tensor(10.8130, device='cuda:0', grad_fn=<AddBackward0>)

In [None]:
%%time
avg_losses = []
avg_val_losses = []
epochs = 50
print_freq = 100
val_acc = []
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # Get the inputs.
#         print("LABELS:",data['labels'])
        inputs, labels = data['image'], torch.LongTensor(data['labels'])

        # Move the inputs to the specified device.
        inputs, labels = inputs.to(device), labels.to(device)
        
        # Zero the parameter gradients.
        opt.zero_grad()

        # Forward step.
        outputs = model(inputs)
        # print('outputs came')
        # print("OUTPUTS:",outputs) 
        # break
        loss = loss_func(outputs, labels)
        # print(loss)
        # Backward step.
        loss.backward()
        
        # Optimization step (update the parameters).
        opt.step()

        # Print statistics.
        running_loss += loss.item()
#         print('running loss', running_loss)
#         print(outputs)
        if i % print_freq == print_freq - 1: # Print every several mini-batches.
            avg_loss = running_loss / print_freq
            print('[epoch: {}, i: {:5d}] avg mini-batch loss: {:.3f}'.format(
                epoch, i, avg_loss))
            running_loss = 0.0
            
    avg_losses.append(avg_loss)
    
    model.eval()
    with torch.no_grad():
        num_correct=0
        val_losses = []
        data_iter = iter(valloader)
        for val_data in data_iter:
#             print(val_data)
            val_inputs, val_labels = val_data['image'].cuda(), torch.LongTensor(val_data['labels']).cuda()
    #                     print('val inputs: ',val_inputs.size())
    #                     print('val labels: ',val_labels)
    #                     print('inputs: ',inputs.size())
    #             print(val_labels)
            val_outputs = model(val_inputs)
    #                     print('val outputs: ',val_outputs)
    #                     print('val loss: ',nn.CrossEntropyLoss()(val_outputs, val_labels))
            opt.zero_grad() #zero the parameter gradients
            val_predicted = [torch.max(i, 1)[1] for i in val_outputs]
#             print(val_predicted)
            # print(labels)
            # print(predicted)
    #             num_correct_k += topk_accuracy(k, labels, outputs)
            num_correct += sum(np.array(val_labels.cpu())==np.array(val_predicted))
    #             print(predicted)
    #             print(val_labels)
            val_losses.append(loss_func(val_outputs, val_labels).item())
        acc = num_correct/(len(data_iter)*batch_size*len(val_labels))
        val_acc.append(acc)
        print('Validation accuracy:',acc)
        print('Average validation loss:',np.mean(val_losses))
        avg_val_losses.append(np.mean(val_losses))
# model.train()
print('Finished Training.')

[epoch: 0, i:    99] avg mini-batch loss: 7.919
[epoch: 0, i:   199] avg mini-batch loss: 7.138
[epoch: 0, i:   299] avg mini-batch loss: 6.625
[epoch: 0, i:   399] avg mini-batch loss: 6.401
[epoch: 0, i:   499] avg mini-batch loss: 6.043
[epoch: 0, i:   599] avg mini-batch loss: 6.000
[epoch: 0, i:   699] avg mini-batch loss: 6.248
[epoch: 0, i:   799] avg mini-batch loss: 5.790
[epoch: 0, i:   899] avg mini-batch loss: 5.989
[epoch: 0, i:   999] avg mini-batch loss: 5.731
[epoch: 0, i:  1099] avg mini-batch loss: 5.927
[epoch: 0, i:  1199] avg mini-batch loss: 6.136
Validation accuracy: 0.39651416122004357
Average validation loss: 6.933001776146733
[epoch: 1, i:    99] avg mini-batch loss: 6.765
[epoch: 1, i:   199] avg mini-batch loss: 6.002
[epoch: 1, i:   299] avg mini-batch loss: 5.842
[epoch: 1, i:   399] avg mini-batch loss: 6.093
[epoch: 1, i:   499] avg mini-batch loss: 5.997
[epoch: 1, i:   599] avg mini-batch loss: 5.952
[epoch: 1, i:   699] avg mini-batch loss: 5.882
[epo

[epoch: 12, i:   699] avg mini-batch loss: 5.773
[epoch: 12, i:   799] avg mini-batch loss: 5.679
[epoch: 12, i:   899] avg mini-batch loss: 5.850
[epoch: 12, i:   999] avg mini-batch loss: 5.740
[epoch: 12, i:  1099] avg mini-batch loss: 5.894
[epoch: 12, i:  1199] avg mini-batch loss: 5.742
Validation accuracy: 0.39433551198257083
Average validation loss: 5.623630969352972
[epoch: 13, i:    99] avg mini-batch loss: 5.683
[epoch: 13, i:   199] avg mini-batch loss: 5.634
[epoch: 13, i:   299] avg mini-batch loss: 5.657
[epoch: 13, i:   399] avg mini-batch loss: 5.869
[epoch: 13, i:   499] avg mini-batch loss: 5.535
[epoch: 13, i:   599] avg mini-batch loss: 5.955
[epoch: 13, i:   699] avg mini-batch loss: 5.710
[epoch: 13, i:   799] avg mini-batch loss: 5.666
[epoch: 13, i:   899] avg mini-batch loss: 5.534
[epoch: 13, i:   999] avg mini-batch loss: 5.582
[epoch: 13, i:  1099] avg mini-batch loss: 5.941
[epoch: 13, i:  1199] avg mini-batch loss: 5.881
Validation accuracy: 0.38562091503

[epoch: 24, i:   999] avg mini-batch loss: 5.830
[epoch: 24, i:  1099] avg mini-batch loss: 5.602
[epoch: 24, i:  1199] avg mini-batch loss: 5.827
Validation accuracy: 0.38562091503267976
Average validation loss: 5.6101285663305545
[epoch: 25, i:    99] avg mini-batch loss: 5.688
[epoch: 25, i:   199] avg mini-batch loss: 5.871
[epoch: 25, i:   299] avg mini-batch loss: 5.697
[epoch: 25, i:   399] avg mini-batch loss: 5.841
[epoch: 25, i:   499] avg mini-batch loss: 5.620
[epoch: 25, i:   599] avg mini-batch loss: 5.514
[epoch: 25, i:   699] avg mini-batch loss: 5.557
[epoch: 25, i:   799] avg mini-batch loss: 5.628
[epoch: 25, i:   899] avg mini-batch loss: 5.571
[epoch: 25, i:   999] avg mini-batch loss: 5.912
[epoch: 25, i:  1099] avg mini-batch loss: 5.781
[epoch: 25, i:  1199] avg mini-batch loss: 5.761
Validation accuracy: 0.3888888888888889
Average validation loss: 5.677523842044905
[epoch: 26, i:    99] avg mini-batch loss: 5.654
[epoch: 26, i:   199] avg mini-batch loss: 5.645

In [29]:
len(val_labels)

3

In [26]:
model.eval()
with torch.no_grad():
    num_correct=0
    val_losses = []
    data_iter = iter(valloader)
    for val_data in data_iter:
        print(val_data)
        val_inputs, val_labels = val_data['image'].cuda(), torch.LongTensor(val_data['labels']).cuda()
#                     print('val inputs: ',val_inputs.size())
#                     print('val labels: ',val_labels)
#                     print('inputs: ',inputs.size())
#             print(val_labels)
        val_outputs = model(val_inputs)
#                     print('val outputs: ',val_outputs)
#                     print('val loss: ',nn.CrossEntropyLoss()(val_outputs, val_labels))
        opt.zero_grad() #zero the parameter gradients
        val_predicted = [torch.max(i, 1)[1] for i in val_outputs]
        print(val_predicted)
        # print(labels)
        # print(predicted)
#             num_correct_k += topk_accuracy(k, labels, outputs)
        num_correct += sum(np.array(val_labels.cpu())==np.array(val_predicted))
#             print(predicted)
#             print(val_labels)
        val_losses.append(loss_func(val_outputs, val_labels).item())
    acc = num_correct/(len(data_iter)*batch_size*len(val_labels))
    val_acc.append(acc)
    print('Validation accuracy:',acc)
    print('Average validation loss:',np.mean(val_losses))
    avg_val_losses.append(np.mean(val_losses))
# model.train()
print('Finished Training.')

{'image': tensor([[[[0.6667, 0.6667, 0.6588,  ..., 0.6549, 0.6471, 0.6431],
          [0.6667, 0.6667, 0.6627,  ..., 0.6588, 0.6510, 0.6471],
          [0.6627, 0.6627, 0.6667,  ..., 0.6549, 0.6549, 0.6510],
          ...,
          [0.6275, 0.7176, 0.5882,  ..., 0.6706, 0.6667, 0.6667],
          [0.7922, 0.7608, 0.6588,  ..., 0.6667, 0.6706, 0.6667],
          [0.9059, 0.8784, 0.7059,  ..., 0.6706, 0.6706, 0.6667]],

         [[0.6941, 0.7020, 0.7020,  ..., 0.7020, 0.7098, 0.7098],
          [0.7020, 0.7020, 0.6980,  ..., 0.7059, 0.7098, 0.7020],
          [0.7020, 0.6980, 0.6980,  ..., 0.6980, 0.7059, 0.7020],
          ...,
          [0.3961, 0.4784, 0.3843,  ..., 0.7098, 0.7059, 0.7059],
          [0.5608, 0.5333, 0.4392,  ..., 0.7059, 0.7098, 0.7059],
          [0.6941, 0.6824, 0.4667,  ..., 0.7059, 0.7098, 0.7059]],

         [[0.7412, 0.7451, 0.7373,  ..., 0.7490, 0.7529, 0.7490],
          [0.7451, 0.7451, 0.7412,  ..., 0.7529, 0.7529, 0.7451],
          [0.7490, 0.7451, 0.749

{'image': tensor([[[[0.1686, 0.1137, 0.0745,  ..., 0.6039, 0.6000, 0.5922],
          [0.4627, 0.4431, 0.3490,  ..., 0.6118, 0.6000, 0.6000],
          [0.3529, 0.5686, 0.5922,  ..., 0.6157, 0.6039, 0.6039],
          ...,
          [0.3255, 0.3176, 0.3216,  ..., 0.3490, 0.3373, 0.3333],
          [0.3216, 0.3176, 0.3098,  ..., 0.3451, 0.3333, 0.3294],
          [0.3255, 0.3216, 0.3137,  ..., 0.3412, 0.3294, 0.3216]],

         [[0.1922, 0.1490, 0.1529,  ..., 0.6118, 0.5961, 0.5922],
          [0.5059, 0.4824, 0.3922,  ..., 0.6196, 0.6078, 0.6000],
          [0.3882, 0.6078, 0.6235,  ..., 0.6275, 0.6118, 0.6118],
          ...,
          [0.2824, 0.2784, 0.2745,  ..., 0.3255, 0.2941, 0.2863],
          [0.2784, 0.2784, 0.2706,  ..., 0.3216, 0.2941, 0.2863],
          [0.2863, 0.2784, 0.2745,  ..., 0.3255, 0.2941, 0.2784]],

         [[0.1569, 0.1176, 0.1098,  ..., 0.6667, 0.6549, 0.6431],
          [0.4471, 0.4471, 0.3373,  ..., 0.6784, 0.6667, 0.6549],
          [0.3529, 0.6118, 0.643

{'image': tensor([[[[0.4824, 0.4706, 0.5020,  ..., 0.6118, 0.6588, 0.6941],
          [0.4549, 0.4941, 0.5098,  ..., 0.6471, 0.6510, 0.6706],
          [0.3961, 0.4510, 0.5216,  ..., 0.6510, 0.6392, 0.6706],
          ...,
          [0.5216, 0.6039, 0.6196,  ..., 0.7059, 0.5922, 0.5098],
          [0.4314, 0.4902, 0.5451,  ..., 0.8078, 0.6510, 0.5529],
          [0.5490, 0.5804, 0.6471,  ..., 0.6314, 0.5882, 0.5882]],

         [[0.5725, 0.5412, 0.5373,  ..., 0.6118, 0.6588, 0.6941],
          [0.5490, 0.5765, 0.5608,  ..., 0.6510, 0.6471, 0.6667],
          [0.4980, 0.5451, 0.5961,  ..., 0.6549, 0.6392, 0.6667],
          ...,
          [0.4588, 0.5216, 0.5216,  ..., 0.5216, 0.4980, 0.4588],
          [0.4000, 0.4392, 0.4784,  ..., 0.6392, 0.5647, 0.5098],
          [0.5451, 0.5569, 0.6118,  ..., 0.4627, 0.5137, 0.5569]],

         [[0.5686, 0.5373, 0.5373,  ..., 0.6039, 0.6353, 0.6667],
          [0.5451, 0.5686, 0.5569,  ..., 0.6392, 0.6431, 0.6627],
          [0.4902, 0.5333, 0.584

[tensor([7], device='cuda:0'), tensor([0], device='cuda:0'), tensor([54], device='cuda:0')]
{'image': tensor([[[[0.5922, 0.5843, 0.5804,  ..., 0.2627, 0.2431, 0.2353],
          [0.5922, 0.5843, 0.5765,  ..., 0.2627, 0.2510, 0.2510],
          [0.5922, 0.5804, 0.5765,  ..., 0.2627, 0.2510, 0.2510],
          ...,
          [0.4235, 0.4314, 0.4275,  ..., 0.3569, 0.2314, 0.2510],
          [0.4431, 0.4471, 0.4471,  ..., 0.4706, 0.2627, 0.2353],
          [0.4667, 0.4706, 0.4667,  ..., 0.4980, 0.2824, 0.2118]],

         [[0.5922, 0.5843, 0.5804,  ..., 0.2431, 0.2157, 0.1961],
          [0.5922, 0.5843, 0.5804,  ..., 0.2431, 0.2235, 0.2118],
          [0.5922, 0.5804, 0.5765,  ..., 0.2431, 0.2235, 0.2157],
          ...,
          [0.4235, 0.4392, 0.4471,  ..., 0.3569, 0.2353, 0.2431],
          [0.4471, 0.4588, 0.4667,  ..., 0.4745, 0.2667, 0.2275],
          [0.4706, 0.4824, 0.4863,  ..., 0.5137, 0.2980, 0.2078]],

         [[0.5608, 0.5451, 0.5333,  ..., 0.2275, 0.1922, 0.1569],
      

{'image': tensor([[[[0.5020, 0.5059, 0.5059,  ..., 0.5882, 0.5882, 0.5804],
          [0.5059, 0.5059, 0.5098,  ..., 0.5882, 0.5843, 0.5804],
          [0.5020, 0.5020, 0.5020,  ..., 0.5882, 0.5843, 0.5804],
          ...,
          [0.7725, 0.6784, 0.5686,  ..., 0.4627, 0.5020, 0.5098],
          [0.7843, 0.6667, 0.5569,  ..., 0.4745, 0.5176, 0.4745],
          [0.7373, 0.6157, 0.5686,  ..., 0.4902, 0.5059, 0.5294]],

         [[0.5412, 0.5451, 0.5412,  ..., 0.6392, 0.6314, 0.6196],
          [0.5451, 0.5412, 0.5451,  ..., 0.6392, 0.6353, 0.6353],
          [0.5373, 0.5373, 0.5373,  ..., 0.6353, 0.6353, 0.6353],
          ...,
          [0.6588, 0.5961, 0.5098,  ..., 0.3843, 0.4078, 0.4392],
          [0.6510, 0.5961, 0.5137,  ..., 0.3804, 0.4275, 0.3961],
          [0.6118, 0.5569, 0.5098,  ..., 0.3922, 0.4196, 0.4471]],

         [[0.4314, 0.4275, 0.4275,  ..., 0.4353, 0.4431, 0.4353],
          [0.4314, 0.4275, 0.4314,  ..., 0.4353, 0.4392, 0.4392],
          [0.4275, 0.4314, 0.427

{'image': tensor([[[[0.4706, 0.4667, 0.4275,  ..., 0.6941, 0.6902, 0.6980],
          [0.4431, 0.4392, 0.4471,  ..., 0.7059, 0.7020, 0.7059],
          [0.4353, 0.4627, 0.4510,  ..., 0.7059, 0.7020, 0.6980],
          ...,
          [0.2196, 0.2314, 0.2314,  ..., 0.6863, 0.6941, 0.7098],
          [0.2235, 0.2157, 0.2157,  ..., 0.6863, 0.6941, 0.7098],
          [0.2235, 0.2000, 0.2039,  ..., 0.6824, 0.6941, 0.7059]],

         [[0.4941, 0.4824, 0.4431,  ..., 0.6353, 0.6353, 0.6314],
          [0.4706, 0.4549, 0.4549,  ..., 0.6314, 0.6314, 0.6314],
          [0.4627, 0.4784, 0.4627,  ..., 0.6392, 0.6275, 0.6235],
          ...,
          [0.2118, 0.2314, 0.2235,  ..., 0.5882, 0.5922, 0.6000],
          [0.2118, 0.2157, 0.2196,  ..., 0.5843, 0.5882, 0.5922],
          [0.2118, 0.2078, 0.2118,  ..., 0.5922, 0.5961, 0.6000]],

         [[0.5020, 0.5137, 0.4784,  ..., 0.5725, 0.5922, 0.5804],
          [0.5059, 0.4941, 0.4902,  ..., 0.5765, 0.5765, 0.5725],
          [0.4980, 0.5098, 0.486

[tensor([7], device='cuda:0'), tensor([0], device='cuda:0'), tensor([54], device='cuda:0')]
{'image': tensor([[[[0.1451, 0.1647, 0.1843,  ..., 0.1569, 0.1569, 0.1569],
          [0.1529, 0.1686, 0.1843,  ..., 0.1608, 0.1608, 0.1608],
          [0.1608, 0.1647, 0.1765,  ..., 0.1725, 0.1686, 0.1647],
          ...,
          [0.1294, 0.1412, 0.1529,  ..., 0.2078, 0.2157, 0.2314],
          [0.1294, 0.1294, 0.1412,  ..., 0.2000, 0.2039, 0.2196],
          [0.1294, 0.1216, 0.1294,  ..., 0.1961, 0.1961, 0.2078]],

         [[0.1686, 0.1804, 0.1961,  ..., 0.1412, 0.1412, 0.1412],
          [0.1686, 0.1843, 0.2000,  ..., 0.1490, 0.1451, 0.1412],
          [0.1725, 0.1843, 0.2000,  ..., 0.1569, 0.1569, 0.1529],
          ...,
          [0.1216, 0.1294, 0.1373,  ..., 0.1882, 0.1961, 0.2118],
          [0.1176, 0.1176, 0.1255,  ..., 0.1804, 0.1882, 0.2078],
          [0.1137, 0.1098, 0.1176,  ..., 0.1765, 0.1804, 0.1961]],

         [[0.1216, 0.1373, 0.1529,  ..., 0.1216, 0.1216, 0.1216],
      

[tensor([7], device='cuda:0'), tensor([0], device='cuda:0'), tensor([96], device='cuda:0')]
{'image': tensor([[[[0.3176, 0.3059, 0.3059,  ..., 0.3098, 0.3098, 0.3216],
          [0.3059, 0.2941, 0.2941,  ..., 0.2980, 0.2980, 0.3059],
          [0.3020, 0.2941, 0.2941,  ..., 0.3020, 0.2980, 0.3059],
          ...,
          [0.3412, 0.3294, 0.3333,  ..., 0.0627, 0.0588, 0.0667],
          [0.3451, 0.3333, 0.3333,  ..., 0.0588, 0.0627, 0.0706],
          [0.3569, 0.3412, 0.3412,  ..., 0.0745, 0.0706, 0.0902]],

         [[0.5843, 0.5804, 0.5804,  ..., 0.5804, 0.5804, 0.5843],
          [0.5765, 0.5725, 0.5725,  ..., 0.5804, 0.5765, 0.5843],
          [0.5804, 0.5765, 0.5765,  ..., 0.5804, 0.5765, 0.5843],
          ...,
          [0.6118, 0.6039, 0.6078,  ..., 0.0471, 0.0431, 0.0510],
          [0.6118, 0.6039, 0.6078,  ..., 0.0431, 0.0471, 0.0549],
          [0.6235, 0.6118, 0.6157,  ..., 0.0627, 0.0627, 0.0706]],

         [[0.7725, 0.7765, 0.7765,  ..., 0.7725, 0.7804, 0.7765],
      

{'image': tensor([[[[0.4863, 0.4627, 0.4980,  ..., 0.4902, 0.4902, 0.4863],
          [0.4588, 0.4863, 0.4902,  ..., 0.4784, 0.4863, 0.4824],
          [0.4588, 0.4980, 0.4980,  ..., 0.4745, 0.4824, 0.4784],
          ...,
          [0.4745, 0.4510, 0.4549,  ..., 0.4078, 0.4039, 0.4039],
          [0.4784, 0.4745, 0.4667,  ..., 0.4078, 0.4000, 0.3922],
          [0.4824, 0.4863, 0.4784,  ..., 0.4078, 0.4039, 0.3961]],

         [[0.5059, 0.4824, 0.5176,  ..., 0.5412, 0.5412, 0.5373],
          [0.4784, 0.5020, 0.5098,  ..., 0.5333, 0.5373, 0.5333],
          [0.4784, 0.5137, 0.5176,  ..., 0.5333, 0.5333, 0.5294],
          ...,
          [0.5333, 0.5137, 0.5137,  ..., 0.5059, 0.4941, 0.4941],
          [0.5333, 0.5294, 0.5176,  ..., 0.5020, 0.4902, 0.4863],
          [0.5333, 0.5373, 0.5255,  ..., 0.4980, 0.4980, 0.4902]],

         [[0.4824, 0.4588, 0.4902,  ..., 0.4314, 0.4275, 0.4314],
          [0.4549, 0.4784, 0.4863,  ..., 0.4157, 0.4196, 0.4235],
          [0.4549, 0.4902, 0.494

[tensor([7], device='cuda:0'), tensor([0], device='cuda:0'), tensor([139], device='cuda:0')]
{'image': tensor([[[[0.1961, 0.2000, 0.2471,  ..., 0.4078, 0.4118, 0.4196],
          [0.2000, 0.2078, 0.2431,  ..., 0.4157, 0.4118, 0.3922],
          [0.1922, 0.2118, 0.2627,  ..., 0.4392, 0.4078, 0.3608],
          ...,
          [0.2235, 0.1882, 0.1647,  ..., 0.2941, 0.2863, 0.2667],
          [0.2235, 0.1922, 0.1725,  ..., 0.2980, 0.2824, 0.2510],
          [0.2157, 0.2039, 0.1961,  ..., 0.2941, 0.2745, 0.2353]],

         [[0.2941, 0.2980, 0.3529,  ..., 0.3804, 0.3569, 0.3608],
          [0.2980, 0.3098, 0.3569,  ..., 0.4275, 0.4039, 0.3804],
          [0.2941, 0.3098, 0.3529,  ..., 0.5098, 0.4902, 0.4471],
          ...,
          [0.3765, 0.3608, 0.3490,  ..., 0.3961, 0.3843, 0.3647],
          [0.3882, 0.3725, 0.3686,  ..., 0.3961, 0.3843, 0.3529],
          [0.3843, 0.3804, 0.3725,  ..., 0.3922, 0.3725, 0.3373]],

         [[0.1804, 0.1725, 0.2078,  ..., 0.3804, 0.3647, 0.3725],
     

[tensor([6], device='cuda:0'), tensor([0], device='cuda:0'), tensor([192], device='cuda:0')]
{'image': tensor([[[[1.0000, 1.0000, 1.0000,  ..., 0.8784, 0.9098, 0.9294],
          [1.0000, 1.0000, 1.0000,  ..., 0.8824, 0.9137, 0.9333],
          [1.0000, 1.0000, 1.0000,  ..., 0.8863, 0.9176, 0.9373],
          ...,
          [0.9922, 0.9922, 0.9922,  ..., 0.8745, 0.9176, 0.9294],
          [0.9922, 0.9922, 0.9922,  ..., 0.8667, 0.9059, 0.9255],
          [0.9922, 0.9922, 0.9882,  ..., 0.8510, 0.9059, 0.9294]],

         [[1.0000, 1.0000, 1.0000,  ..., 0.8706, 0.9059, 0.9255],
          [1.0000, 1.0000, 1.0000,  ..., 0.8745, 0.9098, 0.9294],
          [1.0000, 1.0000, 1.0000,  ..., 0.8784, 0.9098, 0.9333],
          ...,
          [0.9922, 0.9922, 0.9922,  ..., 0.8627, 0.9137, 0.9294],
          [0.9922, 0.9922, 0.9922,  ..., 0.8588, 0.9059, 0.9255],
          [0.9961, 0.9922, 0.9882,  ..., 0.8471, 0.8980, 0.9216]],

         [[0.9922, 0.9922, 0.9922,  ..., 0.8745, 0.8941, 0.9176],
     

{'image': tensor([[[[0.5294, 0.5294, 0.5294,  ..., 0.5373, 0.5412, 0.5412],
          [0.5294, 0.5294, 0.5294,  ..., 0.5373, 0.5373, 0.5373],
          [0.5294, 0.5294, 0.5294,  ..., 0.5333, 0.5294, 0.5294],
          ...,
          [0.4980, 0.4980, 0.4980,  ..., 0.4824, 0.4784, 0.4784],
          [0.4980, 0.4980, 0.4980,  ..., 0.4745, 0.4745, 0.4745],
          [0.5020, 0.5020, 0.4941,  ..., 0.4745, 0.4706, 0.4706]],

         [[0.4863, 0.4863, 0.4863,  ..., 0.4941, 0.4980, 0.4980],
          [0.4863, 0.4863, 0.4863,  ..., 0.4941, 0.4941, 0.4941],
          [0.4863, 0.4863, 0.4863,  ..., 0.4902, 0.4863, 0.4863],
          ...,
          [0.4549, 0.4549, 0.4549,  ..., 0.4941, 0.4902, 0.4902],
          [0.4588, 0.4588, 0.4588,  ..., 0.4863, 0.4863, 0.4863],
          [0.4627, 0.4627, 0.4549,  ..., 0.4863, 0.4824, 0.4824]],

         [[0.4706, 0.4706, 0.4706,  ..., 0.4784, 0.4824, 0.4824],
          [0.4706, 0.4706, 0.4706,  ..., 0.4784, 0.4784, 0.4784],
          [0.4706, 0.4706, 0.470

[tensor([7], device='cuda:0'), tensor([0], device='cuda:0'), tensor([96], device='cuda:0')]
{'image': tensor([[[[0.3843, 0.3843, 0.3843,  ..., 0.5333, 0.5412, 0.5765],
          [0.3843, 0.3765, 0.3765,  ..., 0.5294, 0.5451, 0.5608],
          [0.3725, 0.3725, 0.3725,  ..., 0.4941, 0.5176, 0.5412],
          ...,
          [0.6784, 0.6667, 0.6627,  ..., 0.7608, 0.7255, 0.6980],
          [0.7098, 0.6980, 0.7020,  ..., 0.7529, 0.7176, 0.6863],
          [0.7294, 0.7216, 0.7294,  ..., 0.7333, 0.7137, 0.7020]],

         [[0.4824, 0.4824, 0.4824,  ..., 0.5451, 0.5490, 0.5843],
          [0.4824, 0.4745, 0.4745,  ..., 0.5412, 0.5529, 0.5686],
          [0.4706, 0.4706, 0.4706,  ..., 0.5059, 0.5255, 0.5529],
          ...,
          [0.7216, 0.7098, 0.7098,  ..., 0.7686, 0.7451, 0.7216],
          [0.7569, 0.7451, 0.7490,  ..., 0.7647, 0.7412, 0.7137],
          [0.7765, 0.7725, 0.7804,  ..., 0.7529, 0.7412, 0.7333]],

         [[0.1255, 0.1216, 0.1216,  ..., 0.5373, 0.5451, 0.5804],
      

{'image': tensor([[[[0.3490, 0.2980, 0.2549,  ..., 0.2941, 0.3059, 0.3059],
          [0.3529, 0.3020, 0.2588,  ..., 0.3098, 0.3137, 0.3137],
          [0.3529, 0.3176, 0.2706,  ..., 0.3098, 0.3098, 0.3137],
          ...,
          [0.3725, 0.3765, 0.4196,  ..., 0.4157, 0.3922, 0.3882],
          [0.4549, 0.4627, 0.4784,  ..., 0.4549, 0.4314, 0.4235],
          [0.4980, 0.5020, 0.5176,  ..., 0.5333, 0.5373, 0.5412]],

         [[0.3294, 0.3020, 0.2549,  ..., 0.2902, 0.2941, 0.2902],
          [0.3333, 0.3137, 0.2588,  ..., 0.2902, 0.2902, 0.2902],
          [0.3451, 0.3216, 0.2667,  ..., 0.2863, 0.2824, 0.2784],
          ...,
          [0.2784, 0.2863, 0.3176,  ..., 0.3255, 0.3020, 0.2941],
          [0.3490, 0.3529, 0.3686,  ..., 0.3333, 0.3098, 0.3020],
          [0.4118, 0.4157, 0.4314,  ..., 0.4000, 0.3961, 0.3961]],

         [[0.1882, 0.1608, 0.1294,  ..., 0.0784, 0.0863, 0.0824],
          [0.2000, 0.1529, 0.1333,  ..., 0.0784, 0.0784, 0.0784],
          [0.2039, 0.1686, 0.137

{'image': tensor([[[[0.0314, 0.0314, 0.0275,  ..., 0.4941, 0.4941, 0.4980],
          [0.0235, 0.0353, 0.0314,  ..., 0.4824, 0.4941, 0.4941],
          [0.0314, 0.0353, 0.0353,  ..., 0.4824, 0.4824, 0.4863],
          ...,
          [0.4078, 0.4157, 0.4157,  ..., 0.9922, 0.9961, 0.9961],
          [0.4118, 0.4157, 0.4157,  ..., 0.9922, 0.9961, 1.0000],
          [0.4157, 0.4118, 0.4118,  ..., 0.9922, 0.9961, 1.0000]],

         [[0.0353, 0.0353, 0.0314,  ..., 0.3569, 0.3569, 0.3608],
          [0.0314, 0.0392, 0.0392,  ..., 0.3451, 0.3569, 0.3569],
          [0.0392, 0.0431, 0.0431,  ..., 0.3451, 0.3451, 0.3490],
          ...,
          [0.3804, 0.3882, 0.3882,  ..., 0.9569, 0.9608, 0.9608],
          [0.3843, 0.3882, 0.3882,  ..., 0.9569, 0.9608, 0.9647],
          [0.3882, 0.3843, 0.3843,  ..., 0.9569, 0.9608, 0.9647]],

         [[0.0078, 0.0078, 0.0039,  ..., 0.2314, 0.2314, 0.2353],
          [0.0000, 0.0078, 0.0000,  ..., 0.2196, 0.2314, 0.2314],
          [0.0000, 0.0078, 0.003

[tensor([7], device='cuda:0'), tensor([0], device='cuda:0'), tensor([96], device='cuda:0')]
{'image': tensor([[[[0.0275, 0.0235, 0.0275,  ..., 0.0510, 0.0471, 0.0471],
          [0.0196, 0.0196, 0.0235,  ..., 0.0549, 0.0471, 0.0510],
          [0.0196, 0.0196, 0.0235,  ..., 0.0510, 0.0471, 0.0510],
          ...,
          [0.3098, 0.3059, 0.2510,  ..., 0.4118, 0.4314, 0.4353],
          [0.2588, 0.2667, 0.2549,  ..., 0.3843, 0.4000, 0.3922],
          [0.2431, 0.2471, 0.2510,  ..., 0.3373, 0.3333, 0.3176]],

         [[0.0196, 0.0157, 0.0196,  ..., 0.1059, 0.0980, 0.0980],
          [0.0118, 0.0118, 0.0157,  ..., 0.1098, 0.0980, 0.1020],
          [0.0118, 0.0118, 0.0157,  ..., 0.1059, 0.0980, 0.1020],
          ...,
          [0.3098, 0.3020, 0.2471,  ..., 0.4157, 0.4314, 0.4314],
          [0.2549, 0.2627, 0.2510,  ..., 0.3882, 0.3961, 0.3882],
          [0.2392, 0.2431, 0.2471,  ..., 0.3412, 0.3333, 0.3176]],

         [[0.0235, 0.0196, 0.0196,  ..., 0.0549, 0.0588, 0.0588],
      

[tensor([7], device='cuda:0'), tensor([0], device='cuda:0'), tensor([96], device='cuda:0')]
{'image': tensor([[[[0.3686, 0.3647, 0.3608,  ..., 0.1373, 0.1412, 0.1529],
          [0.3647, 0.3569, 0.3529,  ..., 0.1333, 0.1373, 0.1451],
          [0.3725, 0.3647, 0.3608,  ..., 0.1216, 0.1137, 0.1176],
          ...,
          [0.0824, 0.0863, 0.0784,  ..., 0.3529, 0.3490, 0.3451],
          [0.0863, 0.0941, 0.0863,  ..., 0.3569, 0.3490, 0.3451],
          [0.0784, 0.0941, 0.0980,  ..., 0.3569, 0.3490, 0.3451]],

         [[0.5647, 0.5608, 0.5608,  ..., 0.1098, 0.1137, 0.1176],
          [0.5608, 0.5608, 0.5608,  ..., 0.0902, 0.0980, 0.1059],
          [0.5608, 0.5608, 0.5647,  ..., 0.0706, 0.0745, 0.0863],
          ...,
          [0.0784, 0.0784, 0.0745,  ..., 0.5765, 0.5725, 0.5725],
          [0.0745, 0.0863, 0.0784,  ..., 0.5765, 0.5765, 0.5725],
          [0.0745, 0.0824, 0.0824,  ..., 0.5765, 0.5765, 0.5765]],

         [[0.7020, 0.6980, 0.6941,  ..., 0.1098, 0.1176, 0.1451],
      

{'image': tensor([[[[0.1216, 0.1216, 0.1137,  ..., 0.5804, 0.5608, 0.5529],
          [0.1176, 0.1216, 0.1176,  ..., 0.5059, 0.5059, 0.4667],
          [0.1176, 0.1255, 0.1176,  ..., 0.4941, 0.4627, 0.4314],
          ...,
          [0.5333, 0.5569, 0.5804,  ..., 0.2471, 0.2118, 0.2118],
          [0.5373, 0.5647, 0.5804,  ..., 0.3373, 0.2510, 0.2353],
          [0.5647, 0.5922, 0.5843,  ..., 0.4353, 0.3098, 0.2471]],

         [[0.1216, 0.1216, 0.1137,  ..., 0.5098, 0.5020, 0.5020],
          [0.1176, 0.1216, 0.1176,  ..., 0.4549, 0.4588, 0.4235],
          [0.1176, 0.1255, 0.1176,  ..., 0.4549, 0.4314, 0.4000],
          ...,
          [0.5451, 0.5569, 0.5765,  ..., 0.2549, 0.2196, 0.2235],
          [0.5137, 0.5333, 0.5608,  ..., 0.3647, 0.2627, 0.2392],
          [0.5412, 0.5608, 0.5608,  ..., 0.4745, 0.3294, 0.2510]],

         [[0.1216, 0.1216, 0.1137,  ..., 0.4471, 0.4431, 0.4510],
          [0.1176, 0.1216, 0.1176,  ..., 0.4000, 0.4039, 0.3765],
          [0.1176, 0.1255, 0.117

[tensor([7], device='cuda:0'), tensor([0], device='cuda:0'), tensor([96], device='cuda:0')]
{'image': tensor([[[[0.4549, 0.4627, 0.4706,  ..., 0.3059, 0.3098, 0.3098],
          [0.4627, 0.4706, 0.4824,  ..., 0.2980, 0.3020, 0.3020],
          [0.4667, 0.4784, 0.4941,  ..., 0.2941, 0.2941, 0.2941],
          ...,
          [0.7686, 0.7686, 0.7686,  ..., 0.4118, 0.3686, 0.6039],
          [0.7725, 0.7725, 0.7725,  ..., 0.4745, 0.6510, 0.8902],
          [0.7804, 0.7804, 0.7804,  ..., 0.4667, 0.7137, 0.6627]],

         [[0.3765, 0.3922, 0.4078,  ..., 0.2980, 0.2980, 0.2980],
          [0.3882, 0.4039, 0.4235,  ..., 0.2902, 0.2902, 0.2902],
          [0.3961, 0.4157, 0.4392,  ..., 0.2824, 0.2824, 0.2824],
          ...,
          [0.6745, 0.6745, 0.6745,  ..., 0.3647, 0.2941, 0.5020],
          [0.6784, 0.6784, 0.6784,  ..., 0.4196, 0.5569, 0.7922],
          [0.6863, 0.6863, 0.6863,  ..., 0.4078, 0.6431, 0.6196]],

         [[0.3804, 0.3922, 0.4078,  ..., 0.2235, 0.2235, 0.2235],
      

{'image': tensor([[[[0.1843, 0.1961, 0.1804,  ..., 0.2549, 0.2745, 0.2706],
          [0.1961, 0.2000, 0.1922,  ..., 0.2824, 0.2980, 0.2863],
          [0.1922, 0.1882, 0.1922,  ..., 0.2941, 0.3059, 0.2980],
          ...,
          [0.4235, 0.4235, 0.4392,  ..., 0.9804, 0.9647, 0.9569],
          [0.3843, 0.3922, 0.4157,  ..., 0.9882, 0.9647, 0.9569],
          [0.3608, 0.3804, 0.3922,  ..., 0.9922, 0.9725, 0.9412]],

         [[0.2353, 0.2431, 0.2275,  ..., 0.2588, 0.2667, 0.2745],
          [0.2471, 0.2431, 0.2235,  ..., 0.2784, 0.2941, 0.2941],
          [0.2392, 0.2235, 0.2078,  ..., 0.2902, 0.3098, 0.3216],
          ...,
          [0.5020, 0.4902, 0.4824,  ..., 0.9765, 0.9608, 0.9490],
          [0.4471, 0.4549, 0.4510,  ..., 0.9882, 0.9647, 0.9569],
          [0.4078, 0.4392, 0.4392,  ..., 0.9922, 0.9765, 0.9490]],

         [[0.1765, 0.1804, 0.1686,  ..., 0.2353, 0.2510, 0.2627],
          [0.1804, 0.1882, 0.1765,  ..., 0.2353, 0.2667, 0.2706],
          [0.1686, 0.1725, 0.176

[tensor([6], device='cuda:0'), tensor([0], device='cuda:0'), tensor([192], device='cuda:0')]
{'image': tensor([[[[0.4000, 0.4235, 0.4235,  ..., 0.3294, 0.3333, 0.3529],
          [0.4000, 0.4196, 0.4196,  ..., 0.3490, 0.3569, 0.3608],
          [0.4157, 0.4157, 0.4157,  ..., 0.3725, 0.3647, 0.3608],
          ...,
          [0.3451, 0.3333, 0.3373,  ..., 0.9020, 0.8353, 0.8353],
          [0.3373, 0.3294, 0.3373,  ..., 0.9059, 0.8824, 0.8902],
          [0.3451, 0.3294, 0.3294,  ..., 0.9255, 0.8941, 0.8941]],

         [[0.5529, 0.5490, 0.5373,  ..., 0.4471, 0.4588, 0.4745],
          [0.5608, 0.5529, 0.5451,  ..., 0.4706, 0.4706, 0.4745],
          [0.5686, 0.5569, 0.5412,  ..., 0.4902, 0.4824, 0.4863],
          ...,
          [0.5373, 0.5451, 0.5451,  ..., 0.8824, 0.8235, 0.8078],
          [0.5333, 0.5412, 0.5490,  ..., 0.8824, 0.8784, 0.8706],
          [0.5373, 0.5373, 0.5451,  ..., 0.9059, 0.8941, 0.8745]],

         [[0.2627, 0.2667, 0.2706,  ..., 0.2078, 0.1922, 0.1922],
     

{'image': tensor([[[[0.2235, 0.2235, 0.2235,  ..., 0.3569, 0.3333, 0.3333],
          [0.2275, 0.2275, 0.2275,  ..., 0.3608, 0.3922, 0.4314],
          [0.2314, 0.2314, 0.2314,  ..., 0.3333, 0.3333, 0.3529],
          ...,
          [0.4549, 0.4627, 0.4745,  ..., 0.2157, 0.2235, 0.2275],
          [0.4549, 0.4627, 0.4745,  ..., 0.2275, 0.2275, 0.2275],
          [0.4549, 0.4627, 0.4745,  ..., 0.2078, 0.1922, 0.1804]],

         [[0.2078, 0.2078, 0.2078,  ..., 0.3176, 0.2941, 0.2941],
          [0.2118, 0.2118, 0.2118,  ..., 0.3216, 0.3529, 0.3922],
          [0.2157, 0.2157, 0.2157,  ..., 0.2941, 0.2941, 0.3137],
          ...,
          [0.4588, 0.4667, 0.4784,  ..., 0.2157, 0.2196, 0.2235],
          [0.4549, 0.4627, 0.4745,  ..., 0.2314, 0.2314, 0.2314],
          [0.4549, 0.4627, 0.4745,  ..., 0.2235, 0.2078, 0.1961]],

         [[0.2118, 0.2118, 0.2118,  ..., 0.3098, 0.2863, 0.2863],
          [0.2157, 0.2157, 0.2157,  ..., 0.3137, 0.3451, 0.3843],
          [0.2196, 0.2196, 0.219

{'image': tensor([[[[0.8118, 0.7843, 0.7647,  ..., 0.2902, 0.2745, 0.2824],
          [0.8275, 0.8000, 0.7608,  ..., 0.2784, 0.3255, 0.3294],
          [0.8510, 0.8314, 0.7961,  ..., 0.3765, 0.4000, 0.3804],
          ...,
          [0.4863, 0.4667, 0.4745,  ..., 0.1569, 0.1529, 0.1529],
          [0.4745, 0.4627, 0.4588,  ..., 0.1843, 0.1725, 0.1765],
          [0.3569, 0.3686, 0.3922,  ..., 0.2784, 0.2706, 0.2941]],

         [[0.6510, 0.6431, 0.6392,  ..., 0.3725, 0.3608, 0.3647],
          [0.6549, 0.6471, 0.6314,  ..., 0.3804, 0.4275, 0.4588],
          [0.6627, 0.6510, 0.6392,  ..., 0.4941, 0.5255, 0.5137],
          ...,
          [0.5020, 0.5020, 0.4980,  ..., 0.2196, 0.2118, 0.2157],
          [0.4980, 0.4980, 0.4902,  ..., 0.2667, 0.2627, 0.2627],
          [0.3843, 0.4235, 0.4353,  ..., 0.3882, 0.3804, 0.3922]],

         [[0.4902, 0.4706, 0.4549,  ..., 0.2039, 0.1961, 0.2118],
          [0.5216, 0.4902, 0.4667,  ..., 0.2157, 0.2314, 0.2196],
          [0.5451, 0.5176, 0.490

{'image': tensor([[[[1.0000, 1.0000, 1.0000,  ..., 0.3765, 0.3529, 0.3843],
          [1.0000, 1.0000, 1.0000,  ..., 0.3608, 0.4118, 0.4980],
          [1.0000, 1.0000, 1.0000,  ..., 0.5686, 0.6471, 0.6706],
          ...,
          [0.9961, 0.9804, 0.9608,  ..., 0.4745, 0.5098, 0.5294],
          [1.0000, 0.9843, 0.9647,  ..., 0.4902, 0.5255, 0.5529],
          [0.9961, 0.9882, 0.9608,  ..., 0.5020, 0.5373, 0.5608]],

         [[1.0000, 1.0000, 1.0000,  ..., 0.3176, 0.2980, 0.3765],
          [1.0000, 1.0000, 1.0000,  ..., 0.3020, 0.3686, 0.5020],
          [1.0000, 1.0000, 1.0000,  ..., 0.5255, 0.6235, 0.6745],
          ...,
          [0.9961, 0.9843, 0.9647,  ..., 0.5294, 0.5451, 0.5608],
          [1.0000, 0.9843, 0.9647,  ..., 0.5373, 0.5647, 0.5804],
          [0.9961, 0.9882, 0.9647,  ..., 0.5490, 0.5765, 0.6000]],

         [[1.0000, 1.0000, 1.0000,  ..., 0.2588, 0.2157, 0.2667],
          [1.0000, 1.0000, 1.0000,  ..., 0.2235, 0.2275, 0.3922],
          [1.0000, 1.0000, 1.000

{'image': tensor([[[[0.5647, 0.5765, 0.5843,  ..., 0.4824, 0.4627, 0.4706],
          [0.5529, 0.5686, 0.5882,  ..., 0.4824, 0.4706, 0.4588],
          [0.5608, 0.5647, 0.5922,  ..., 0.4824, 0.4706, 0.4627],
          ...,
          [0.6745, 0.5961, 0.5686,  ..., 0.4863, 0.4902, 0.4941],
          [0.7373, 0.6471, 0.5882,  ..., 0.5255, 0.5255, 0.5294],
          [0.8078, 0.7176, 0.6196,  ..., 0.5529, 0.5529, 0.5569]],

         [[0.7020, 0.7137, 0.7137,  ..., 0.5882, 0.5804, 0.5686],
          [0.6980, 0.7098, 0.7255,  ..., 0.5922, 0.5804, 0.5686],
          [0.7020, 0.7059, 0.7255,  ..., 0.6078, 0.5804, 0.5686],
          ...,
          [0.7216, 0.7020, 0.7059,  ..., 0.5333, 0.5412, 0.5451],
          [0.7608, 0.7294, 0.7176,  ..., 0.5647, 0.5725, 0.5765],
          [0.8118, 0.7686, 0.7373,  ..., 0.5922, 0.5922, 0.6039]],

         [[0.3490, 0.3647, 0.3451,  ..., 0.1333, 0.1216, 0.1412],
          [0.3529, 0.3804, 0.3804,  ..., 0.1216, 0.1137, 0.1294],
          [0.3725, 0.3922, 0.400

{'image': tensor([[[[0.0078, 0.0196, 0.0196,  ..., 0.2627, 0.2667, 0.2745],
          [0.0078, 0.0196, 0.0235,  ..., 0.2627, 0.2588, 0.2627],
          [0.0078, 0.0235, 0.0196,  ..., 0.2627, 0.2667, 0.2667],
          ...,
          [0.0039, 0.0078, 0.0118,  ..., 0.0118, 0.0157, 0.0118],
          [0.0078, 0.0039, 0.0039,  ..., 0.0039, 0.0118, 0.0157],
          [0.0039, 0.0078, 0.0039,  ..., 0.0118, 0.0118, 0.0157]],

         [[0.0627, 0.0588, 0.0627,  ..., 0.3098, 0.3098, 0.3137],
          [0.0588, 0.0627, 0.0706,  ..., 0.3059, 0.3059, 0.3059],
          [0.0627, 0.0667, 0.0667,  ..., 0.3059, 0.3098, 0.3098],
          ...,
          [0.0235, 0.0314, 0.0314,  ..., 0.0706, 0.0784, 0.0784],
          [0.0314, 0.0275, 0.0275,  ..., 0.0706, 0.0784, 0.0784],
          [0.0275, 0.0353, 0.0314,  ..., 0.0745, 0.0824, 0.0824]],

         [[0.0667, 0.0667, 0.0667,  ..., 0.2431, 0.2431, 0.2471],
          [0.0627, 0.0706, 0.0706,  ..., 0.2392, 0.2353, 0.2392],
          [0.0667, 0.0745, 0.066

{'image': tensor([[[[0.0902, 0.0784, 0.0667,  ..., 0.2314, 0.2353, 0.2588],
          [0.0824, 0.0824, 0.0745,  ..., 0.1804, 0.1961, 0.2157],
          [0.0863, 0.0706, 0.0745,  ..., 0.1490, 0.1725, 0.1882],
          ...,
          [0.0039, 0.0039, 0.0039,  ..., 0.1137, 0.1098, 0.1020],
          [0.0118, 0.0196, 0.0078,  ..., 0.1098, 0.1098, 0.1059],
          [0.0118, 0.0157, 0.0118,  ..., 0.1059, 0.1020, 0.1333]],

         [[0.0941, 0.0824, 0.0706,  ..., 0.2627, 0.2667, 0.2902],
          [0.0863, 0.0863, 0.0784,  ..., 0.2314, 0.2431, 0.2588],
          [0.0902, 0.0745, 0.0784,  ..., 0.2039, 0.2314, 0.2471],
          ...,
          [0.0039, 0.0039, 0.0039,  ..., 0.1333, 0.1176, 0.0980],
          [0.0118, 0.0196, 0.0078,  ..., 0.1373, 0.1255, 0.1059],
          [0.0118, 0.0157, 0.0118,  ..., 0.1255, 0.1137, 0.1216]],

         [[0.0627, 0.0510, 0.0392,  ..., 0.1569, 0.1490, 0.1686],
          [0.0549, 0.0549, 0.0471,  ..., 0.1216, 0.1255, 0.1373],
          [0.0588, 0.0431, 0.047

{'image': tensor([[[[0.5216, 0.5216, 0.5255,  ..., 0.8275, 0.8275, 0.8314],
          [0.5098, 0.5137, 0.5137,  ..., 0.8353, 0.8353, 0.8353],
          [0.4902, 0.5020, 0.5020,  ..., 0.8431, 0.8431, 0.8392],
          ...,
          [0.4667, 0.4627, 0.4549,  ..., 0.6078, 0.6118, 0.6078],
          [0.4588, 0.4627, 0.4667,  ..., 0.6078, 0.6118, 0.6157],
          [0.4588, 0.4706, 0.4745,  ..., 0.6118, 0.6157, 0.6157]],

         [[0.6353, 0.6431, 0.6471,  ..., 0.8784, 0.8784, 0.8824],
          [0.6275, 0.6353, 0.6392,  ..., 0.8863, 0.8863, 0.8863],
          [0.6118, 0.6235, 0.6275,  ..., 0.8941, 0.8941, 0.8902],
          ...,
          [0.5922, 0.5882, 0.5804,  ..., 0.6784, 0.6941, 0.6980],
          [0.5843, 0.5882, 0.5882,  ..., 0.6980, 0.7059, 0.7098],
          [0.5843, 0.5961, 0.6000,  ..., 0.7059, 0.7098, 0.7098]],

         [[0.2706, 0.2745, 0.2784,  ..., 0.6706, 0.6706, 0.6745],
          [0.2588, 0.2667, 0.2706,  ..., 0.6784, 0.6784, 0.6784],
          [0.2431, 0.2549, 0.258

[tensor([7], device='cuda:0'), tensor([0], device='cuda:0'), tensor([96], device='cuda:0')]
{'image': tensor([[[[0.7216, 0.7255, 0.7216,  ..., 0.7059, 0.7059, 0.7020],
          [0.7255, 0.7216, 0.7255,  ..., 0.7059, 0.7059, 0.7020],
          [0.7255, 0.7216, 0.7255,  ..., 0.7059, 0.7059, 0.7020],
          ...,
          [0.6706, 0.6667, 0.6667,  ..., 0.6157, 0.6196, 0.6196],
          [0.6667, 0.6627, 0.6706,  ..., 0.6157, 0.6157, 0.6157],
          [0.6667, 0.6627, 0.6706,  ..., 0.6196, 0.6157, 0.6118]],

         [[0.7804, 0.7843, 0.7804,  ..., 0.7765, 0.7725, 0.7686],
          [0.7843, 0.7804, 0.7843,  ..., 0.7725, 0.7725, 0.7686],
          [0.7843, 0.7804, 0.7843,  ..., 0.7725, 0.7725, 0.7686],
          ...,
          [0.7333, 0.7255, 0.7294,  ..., 0.6863, 0.6902, 0.6902],
          [0.7333, 0.7373, 0.7333,  ..., 0.6863, 0.6863, 0.6863],
          [0.7333, 0.7333, 0.7333,  ..., 0.6902, 0.6902, 0.6902]],

         [[0.9412, 0.9451, 0.9412,  ..., 0.9412, 0.9412, 0.9373],
      

In [25]:
# torch.save(model.state_dict(), '../models/transfer_multi-task_2_50_epoch_state_dict.pth')

In [40]:
labels[0]

tensor(7, device='cuda:0')

In [44]:
model(inputs)

(tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan]], device='cuda:0',
        grad_fn=<AddmmBackward>),
 tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
        device='cuda:0', grad_fn=<AddmmBackward>))

In [45]:
model

MultiTaskModel(
  (encoder): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ReLU(inplace=True)
      (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (12): ReLU(inplace=True)
      (13): MaxPool2d(kernel_size=2, stride=2, paddin

In [42]:
outputs

(tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan]], device='cuda:0',
        grad_fn=<AddmmBackward>),
 tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
        device='cuda:0', grad_fn=<AddmmBackward>))

In [41]:
outputs[0].reshape(1, -1)

tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan]], device='cuda:0',
       grad_fn=<AsStridedBackward>)

In [125]:
labels[0]

tensor(6., device='cuda:0')

In [128]:
outputs[0].size()

torch.Size([1, 9])

In [131]:
labels[0]

tensor(11., device='cuda:0')

In [129]:
outputs[1].size()

torch.Size([1, 15])

In [126]:
nn.CrossEntropyLoss()(outputs, labels)

AttributeError: 'tuple' object has no attribute 'log_softmax'