In [1]:
# Imports 
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn
from tqdm.autonotebook import tqdm, trange

import sys
import matplotlib.pyplot as plt
from torchvision import models

import imagenet_autoencoder_utils as utils
plt.rcParams['figure.figsize'] = [15, 10]

from IPython import display
# Render inline matplotlib figures as SVG rather than the default raster format.
# NOTE(review): newer IPython releases deprecate this helper in favour of
# matplotlib_inline — confirm against the IPython version in use.
display.set_matplotlib_formats('svg')

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
def loss_function(W, x, recons_x, h, lam):
    """Compute the Contractive AutoEncoder (CAE) loss.

    Evaluates the CAE loss: the mean squared reconstruction error plus ``lam``
    times the squared l2 (Frobenius) norm of the Jacobian of the hidden units
    with respect to the inputs, using the closed form
    ``sum_{b,j} (h_bj * (1 - h_bj))**2 * sum_i W_ji**2``.
    That closed form assumes ``h`` are sigmoid activations of a single linear
    layer with weight ``W``.
    See reference below for an in-depth discussion:
      #1: http://wiseodd.github.io/techblog/2016/12/05/contractive-autoencoder

    Args:
        W (Tensor): encoder weight, (N_hidden x N), where N_hidden and N are the
          dimensions of the hidden units and input respectively.
        x (Tensor): the input to the network, (N_batch x N).
        recons_x (Tensor): the reconstruction of the input, (N_batch x N).
        h (Tensor): the hidden units of the network, (N_batch x N_hidden).
        lam (float): the weight given to the Jacobian regulariser term.

    Returns:
        Tensor: the 0-dim (scalar) CAE loss.
    """
    mse = torch.nn.functional.mse_loss(recons_x, x)
    # Element-wise h * (1 - h): N_batch x N_hidden.
    dh = h * (1 - h)
    # W is N_hidden x N, so no transpose is needed (unlike #1). Summing the squares over
    # the input dimension gives one value per hidden unit. keepdim makes it N_hidden x 1
    # for the matrix product below.
    w_sum = torch.sum(W ** 2, dim=1, keepdim=True)
    # Full reduction to a 0-dim tensor, so the returned loss is a true scalar.
    contractive_loss = torch.sum(torch.mm(dh ** 2, w_sum))
    return mse + lam * contractive_loss

    

In [2]:
# Get data
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image
import os

DATA_PATH_TRAIN = '../input/imagenetmini-1000/imagenet-mini/train'
DATA_PATH_VAL = '../input/imagenetmini-1000/imagenet-mini/val'

# Pre-processing for the ImageNet-pretrained VGG:
# - resize, then centre-crop to 224x224;
# - convert to a float tensor (ToTensor scales pixel values into [0, 1]);
# - standardise each channel with the ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

batch_size = 32

train_dataset = datasets.ImageFolder(DATA_PATH_TRAIN, transform=transform)
val_dataset = datasets.ImageFolder(DATA_PATH_VAL, transform=transform)

# Round each split down to a whole number of batches.
train_set_size = (len(train_dataset) // batch_size) * batch_size
val_set_size = (len(val_dataset) // batch_size) * batch_size

print(len(train_dataset))
print(len(val_dataset))

# Keep only the first `*_set_size` samples of each split.
# NOTE(review): if ImageFolder lists samples class by class, the dropped remainder
# comes from the last class(es) only. Confirm that is acceptable.
train_indecies = torch.arange(train_set_size)
val_indecies = torch.arange(val_set_size)

train_set = Subset(train_dataset, train_indecies)
val_set = Subset(val_dataset, val_indecies)

train_dataloader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    drop_last=False,
)
val_dataloader = DataLoader(
    val_set,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    drop_last=False,
)


In [3]:
# loading the classes
if not os.path.exists("/kaggle/working/imagenet_classes.txt"):
    !wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
if not os.path.exists("/kaggle/working/map_clsloc.txt"):
    !wget https://gist.githubusercontent.com/aaronpolhamus/964a4411c0906315deb9f4a3723aac57/raw/aa66dd9dbf6b56649fa3fab83659b2acbf3cbfd1/map_clsloc.txt    
    
with open("/kaggle/working/imagenet_classes.txt", "r") as f:
    classes = [s.strip() for s in f.readlines()]

In [4]:
class VggSplit(nn.Module):
    """VGG-19 (with batch norm) split into five encoder stages, just before the avg. pool layer.

    ``forward`` returns ``(encoder_out, classifier_out, inter_outputs)``:
    - ``inter_outputs`` is the output of every encoder stage (the decoder uses them as skip
      connections);
    - ``encoder_out`` is the last stage after an extra 2x2 max-pool;
    - ``classifier_out`` is the original VGG avg-pool + classifier applied to
      ``encoder_out``.
    """

    # Slice boundaries into ``vgg.features``. The final layer is excluded ([39:-1]) and
    # replaced by ``lastMax``, so the pre-pool activation of the last stage is exposed.
    # NOTE(review): the indices assume torchvision's vgg19_bn feature layout.
    _STAGE_BOUNDS = ((0, 6), (6, 13), (13, 26), (26, 39), (39, -1))

    def __init__(self, vgg):
        super().__init__()
        layers = list(vgg.features.children())
        # nn.ModuleList (not a plain Python list) registers the stages as submodules.
        # Without it, .to(device), .eval()/.train() and .parameters() never reach them,
        # so the encoder BatchNorm layers would silently stay in training mode.
        self.encoder = nn.ModuleList(
            nn.Sequential(*layers[start:stop]) for start, stop in self._STAGE_BOUNDS
        )
        self.lastMax = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

        self.avg_pool = vgg.avgpool
        self.classifier = nn.Sequential(*list(vgg.classifier.children()))

    def forward(self, x):
        inter_outputs = []
        for encoder_layer in self.encoder:
            x = encoder_layer(x)
            inter_outputs.append(x)

        encoder_out = self.lastMax(x)
        avg_pool_out = self.avg_pool(encoder_out)
        flat = torch.flatten(avg_pool_out, 1)
        classifier_out = self.classifier(flat)
        return encoder_out, classifier_out, inter_outputs

In [5]:
def plotValidation(images, model, classes):
    """Plot input images (top row) against their reconstructions (bottom row).

    Top-row titles are the model's predicted top-1 class name. Bottom-row titles are
    the reconstruction error ``criterion(image, reconstruction)``.

    Args:
        images: batch of normalised images (N x C x H x W); one column per image.
        model: network returning ``(reconstruction, class_scores)``.
        classes: sequence mapping class index -> class name.

    Uses the module-level ``device``, ``utils`` and ``criterion``.
    NOTE(review): a later cell redefines ``plotValidation`` with an extra
    ``img_categs`` argument, and that definition shadows this one.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    test_fig, test_axis = plt.subplots(2, len(images), squeeze=False)

    images = images.to(device)
    with torch.no_grad():
        out, class_out = model(images)

    # clone(): never modify the caller's tensors, in case utils.unNormalize works in place.
    out = utils.unNormalize(out.clone(), mean=mean, std=std)
    images = utils.unNormalize(images.clone(), mean=mean, std=std)
    # Scores are indexed [image, class]: normalise over the class dimension, not the batch.
    probabilities = torch.nn.functional.softmax(class_out, dim=1)

    # N x H x W x C layout for imshow.
    out = torch.permute(out.to('cpu'), (0, 2, 3, 1))
    images = torch.permute(images.to('cpu'), (0, 2, 3, 1))
    top1 = torch.topk(probabilities.to('cpu'), 1).indices.numpy()

    for inx, image in enumerate(images):
        # Un-normalised values may fall slightly outside [0, 1]; clamp for imshow.
        test_axis[0, inx].imshow(image.squeeze().clamp(0, 1).numpy())
        test_axis[0, inx].set_title(str(classes[top1[inx, 0]]))
        test_axis[1, inx].imshow(out[inx].squeeze().clamp(0, 1).numpy())
        test_axis[1, inx].set_title('{:.2f}'.format(criterion(image, out[inx])))

    test_fig.tight_layout()
    test_fig.show()

In [6]:
# Download the ImageNet-pretrained VGG-19 with batch norm and move it to the device.
vgg = models.vgg19_bn(pretrained=True, progress=True)
# A freshly constructed torchvision model is in training mode. Switch it to eval so that
# BatchNorm uses (and does not overwrite) the pretrained running statistics when the
# network is evaluated below and later reused as a frozen encoder.
vgg = vgg.to(device).eval()

In [8]:
# Validation images served one per batch, in shuffled order.
model_test_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=1,
    shuffle=True,
    pin_memory=True,
)

def val_VGG(model, val_dataloader, k=5):
    """Measure the top-``k`` classification accuracy of ``model`` over ``val_dataloader``.

    Args:
        model: classifier whose ``model(x)`` returns class scores indexed [image, class].
        val_dataloader: iterable of ``(images, labels)`` batches of any size.
        k: a sample counts as correct when its label is among the ``k`` highest scores.

    Returns:
        ``utils.AverageMeter`` updated once per batch with that batch's hit fraction,
        weighted by the batch size. This assumes ``update(val, n)`` is a weighted
        running average, so ``.avg`` is the overall top-k accuracy.
    """
    # Inference mode. Otherwise BatchNorm normalises with per-batch statistics and
    # overwrites the pretrained running statistics.
    model.eval()
    acc = utils.AverageMeter()
    with torch.no_grad():
        for X_val, Y_val in val_dataloader:
            X_val = X_val.to(device)
            Y_val = Y_val.to(device)

            classification = model(X_val)
            top_k = torch.topk(classification, k, dim=1).indices  # one row of k indices per image
            # Compare each label with its own row, so batches larger than 1 work too.
            hits = (top_k == Y_val.unsqueeze(1)).any(dim=1)
            acc.update(hits.float().mean().item(), X_val.size(0))
    return acc

# Top-5 accuracy of the pretrained VGG on the validation images, one image per batch.
res = val_VGG(model=vgg, val_dataloader=model_test_dataloader)
print(res.avg)

In [17]:
# Wrap the pretrained network so it also returns each encoder stage's activation,
# then move it to the device and put it in inference mode.
vgg = VggSplit(vgg).to(device)
vgg.eval()

In [19]:
class DecoderBN(nn.Module):
    """Decoder with batch normalization over the frozen VGG encoder (U-Net style skips).

    Each stage stacks pre-activation units (ReLU -> BatchNorm2d -> 3x3 reflect-padded
    Conv2d without bias). All stages except the last end with a 2x nearest-neighbour
    upsample. After each of ``dec0``..``dec4`` the result is concatenated along the
    channel dimension with the matching encoder activation, deepest first. ``dec5``
    maps the final features to 3 channels.

    NOTE(review): ``forward`` calls the module-level name ``vgg``, not a stored
    submodule, so ``vgg`` must exist when the model runs. ``vgg`` is also not part of
    this module's state_dict.
    """

    def __init__(self):
        super().__init__()

        def unit(in_ch, out_ch):
            # One pre-activation conv unit. It is returned as a list so that a stage
            # flattens into a single nn.Sequential, which keeps the state_dict indices
            # dec*.1 (BN), dec*.2 (conv), dec*.4, ...
            return [
                nn.ReLU(),
                nn.BatchNorm2d(in_ch),
                nn.Conv2d(in_ch, out_ch, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
                          padding_mode='reflect', bias=False),
            ]

        def stage(channel_pairs, upsample=True):
            modules = [module for pair in channel_pairs for module in unit(*pair)]
            if upsample:
                modules.append(nn.Upsample(scale_factor=2, mode='nearest'))
            return nn.Sequential(*modules)

        self.dec0 = stage([])
        self.dec1 = stage([(1024, 512), (512, 512), (512, 512), (512, 512)])
        self.dec2 = stage([(1024, 512), (512, 512), (512, 512), (512, 256)])
        self.dec3 = stage([(512, 256), (256, 256), (256, 256), (256, 128)])
        self.dec4 = stage([(256, 128), (128, 64)])
        self.dec5 = stage([(128, 64), (64, 3)], upsample=False)

    def forward(self, x):
        # The encoder is frozen: no autograd graph is recorded for it.
        with torch.no_grad():
            x, classifier, encoder_layers = vgg(x)
        skips = encoder_layers[::-1]  # deepest activation first
        decoders = (self.dec0, self.dec1, self.dec2, self.dec3, self.dec4)
        for index, decode in enumerate(decoders):
            x = torch.cat((decode(x), skips[index]), 1)
        return self.dec5(x), classifier

In [20]:
# Build the decoder and move it to the chosen device in one step.
model = DecoderBN().to(device)

In [21]:
import torch.optim as optim

def train(model, epochs = 10, lr = 0.01, criterion = nn.MSELoss(), optimizer='adam', name='', set_size=10000,
          train_loader=None, val_loader=None, save_dir='/kaggle/working', step_size=2, gamma=0.1):
    """Train an auto-encoder on reconstruction and keep the best checkpoint.

    Each epoch does the following:
    - one optimisation pass over the training loader (``model(x)`` must return
      ``(reconstruction, anything)``);
    - one StepLR step;
    - a validation pass.
    Whenever the mean validation loss improves, ``model.state_dict()`` is saved to
    ``{save_dir}/test-c_{name}.pth``.

    Args:
        model: module to optimise.
        epochs: number of epochs.
        lr: initial learning rate.
        criterion: reconstruction loss, called as ``criterion(input, reconstruction)``.
        optimizer: ``'adam'`` or ``'sgd'``.
        name: suffix for the checkpoint file name.
        set_size: total shown by the per-epoch progress bar (display only).
        train_loader / val_loader: iterables of ``(x, label)``. They default to the
            module-level ``train_dataloader`` / ``val_dataloader``.
        save_dir: directory for the best checkpoint.
        step_size, gamma: StepLR schedule (lr is multiplied by ``gamma`` every
            ``step_size`` epochs).

    Returns:
        (loss_history, val_history): numpy arrays with the mean training and validation
        loss per epoch.

    Raises:
        ValueError: if ``optimizer`` is not recognised.
    """
    if optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=lr)
    elif optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=lr)
    else:
        raise ValueError('Optimizer not recognized!')

    if train_loader is None:
        train_loader = train_dataloader
    if val_loader is None:
        val_loader = val_dataloader

    best_val_loss = float('inf')
    loss_history = np.zeros(epochs)
    val_history = np.zeros(epochs)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    for epoch in tqdm(range(epochs)):
        model.train()
        epoch_loss = utils.AverageMeter()
        progress = tqdm(total=set_size, desc='|_ Mini-batch ' + str(epoch), unit='eval')
        for X_train, _ in train_loader:
            X_train = X_train.to(device)
            n_batch = X_train.size(0)  # actual size; the last batch may be smaller

            out, _ = model(X_train)
            curr_loss = criterion(X_train, out)
            epoch_loss.update(curr_loss.item(), n_batch)

            optimizer.zero_grad()
            curr_loss.backward()
            optimizer.step()
            progress.set_postfix(lr=optimizer.param_groups[0]["lr"], epoch_loss=epoch_loss.avg)
            progress.update(n_batch)
        progress.close()
        loss_history[epoch] = epoch_loss.avg
        scheduler.step()

        # Validation: inference mode, no autograd graph.
        model.eval()
        val_loss = utils.AverageMeter()
        with torch.no_grad():
            for X_val, _ in val_loader:
                X_val = X_val.to(device)
                res, _ = model(X_val)
                val_loss.update(criterion(X_val, res).item(), X_val.size(0))
        val_history[epoch] = val_loss.avg

        if val_loss.avg < best_val_loss:
            print('Model saved: ' + str(val_loss.avg))
            torch.save(model.state_dict(), save_dir + '/test' + '-c_' + name + '.pth')
            best_val_loss = val_loss.avg

        print(val_loss.avg)

    return loss_history, val_history

In [None]:
import torch.optim as optim

def contractive_train(model, epochs = 10, lr = 0.01, criterion = nn.MSELoss(), optimizer='sgd', name='', set_size=10000,
                      weight_key=None, lam=1e-4, train_loader=None, val_loader=None,
                      save_dir='/kaggle/working', step_size=2, gamma=0.1):
    """Train an auto-encoder with the contractive loss ``loss_function``.

    Contract for ``model``: ``model(x)`` must return ``(reconstruction, hidden)``, where
    ``hidden`` (N_batch x N_hidden) are the hidden units produced with the weight
    parameter named ``weight_key`` (N_hidden x N). That parameter is used live, so the
    contractive penalty also produces gradients for it.
    NOTE(review): ``DecoderBN`` returns classifier scores as its second output, so it
    does not satisfy this contract as-is.

    Validation uses the plain reconstruction ``criterion(input, reconstruction)``.
    Whenever it improves, ``model.state_dict()`` is saved to
    ``{save_dir}/test-c_{name}.pth``.

    Args:
        model, epochs, lr, criterion, optimizer, name, set_size: as in ``train``.
        weight_key: name (as in ``model.named_parameters()``) of the encoder weight W.
        lam: weight of the contractive (Jacobian) penalty.
        train_loader / val_loader: iterables of ``(x, label)``. They default to the
            module-level ``train_dataloader`` / ``val_dataloader``.
        save_dir: directory for the best checkpoint.
        step_size, gamma: StepLR schedule.

    Returns:
        (loss_history, val_history): numpy arrays with per-epoch mean losses.

    Raises:
        ValueError: unknown ``optimizer``, or ``weight_key`` is not a parameter of ``model``.
    """
    if optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=lr)
    elif optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=lr)
    else:
        raise ValueError('Optimizer not recognized!')

    params = dict(model.named_parameters())
    if weight_key not in params:
        raise ValueError(
            'weight_key must name the encoder weight used by the contractive penalty; '
            'got {!r}, available: {}'.format(weight_key, sorted(params)))
    W = params[weight_key]

    if train_loader is None:
        train_loader = train_dataloader
    if val_loader is None:
        val_loader = val_dataloader

    best_val_loss = float('inf')
    loss_history = np.zeros(epochs)
    val_history = np.zeros(epochs)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    for epoch in tqdm(range(epochs)):
        model.train()
        epoch_loss = utils.AverageMeter()
        progress = tqdm(total=set_size, desc='|_ Mini-batch ' + str(epoch), unit='eval')
        for X_train, _ in train_loader:
            X_train = X_train.to(device)
            n_batch = X_train.size(0)

            out, hidden = model(X_train)
            curr_loss = loss_function(W, X_train, out, hidden, lam)
            epoch_loss.update(curr_loss.item(), n_batch)

            optimizer.zero_grad()
            curr_loss.backward()
            optimizer.step()
            progress.set_postfix(lr=optimizer.param_groups[0]["lr"], epoch_loss=epoch_loss.avg)
            progress.update(n_batch)
        progress.close()
        loss_history[epoch] = epoch_loss.avg
        scheduler.step()

        model.eval()
        val_loss = utils.AverageMeter()
        with torch.no_grad():
            for X_val, _ in val_loader:
                X_val = X_val.to(device)
                res, _ = model(X_val)
                val_loss.update(criterion(X_val, res).item(), X_val.size(0))
        val_history[epoch] = val_loss.avg

        if val_loss.avg < best_val_loss:
            print('Model saved: ' + str(val_loss.avg))
            torch.save(model.state_dict(), save_dir + '/test' + '-c_' + name + '.pth')
            best_val_loss = val_loss.avg

        print(val_loss.avg)

    return loss_history, val_history

In [22]:
# Reconstruction objective. Every later `criterion(a, b)` call expects a two-argument
# loss. `loss_function` is the five-argument contractive loss and cannot be called
# without arguments.
criterion = nn.MSELoss()
loss_history, val_history = train(model, epochs=12, criterion=criterion, optimizer='sgd', name='MSE', set_size=train_set_size)

In [None]:

def plotLosses(losses, val_history):
    """Plot the training (top) and validation (bottom) loss curves on log-scaled y axes."""
    panels = (
        (211, losses, 'Training error'),
        (212, val_history, 'Validation error'),
    )
    for position, series, title in panels:
        axis = plt.subplot(position, yscale='log')
        axis.plot(series)
        axis.set_title(title)
    plt.tight_layout()
    plt.show()

In [None]:
# Per-epoch training and validation loss returned by `train`.
plotLosses(losses=loss_history, val_history=val_history)

In [23]:
# Persist the final (last-epoch) decoder weights and both per-epoch loss curves.
for obj_to_save, save_path in (
    (model.state_dict(), '/kaggle/working/test' + '-c_final.pth'),
    (loss_history, '/kaggle/working/test_loss_his.pth'),
    (val_history, '/kaggle/working/test_val_his.pth'),
):
    torch.save(obj_to_save, save_path)
# Release memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()

In [24]:
def plotValidation(images, img_categs, model, classes):
    """Plot input images (top row) against their reconstructions (bottom row).

    Top-row titles are the ground-truth class names. Bottom-row titles are the
    predicted top-1 class name followed by ``[criterion(image, reconstruction)]``.

    Args:
        images: batch of normalised images (N x C x H x W); one column per image.
        img_categs: ground-truth class index per image (at least N entries).
        model: network returning ``(reconstruction, class_scores)``.
        classes: sequence mapping class index -> class name.

    Uses the module-level ``device``, ``utils`` and ``criterion``.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    test_fig, test_axis = plt.subplots(2, len(images), squeeze=False)

    images = images.to(device)
    with torch.no_grad():
        out, class_out = model(images)

    # clone(): never modify the caller's tensors, in case utils.unNormalize works in place.
    out = utils.unNormalize(out.clone(), mean=mean, std=std)
    images = utils.unNormalize(images.clone(), mean=mean, std=std)
    # Scores are indexed [image, class]: normalise over the class dimension, not the batch.
    probabilities = torch.nn.functional.softmax(class_out, dim=1)

    # N x H x W x C layout for imshow.
    out = torch.permute(out.to('cpu'), (0, 2, 3, 1))
    images = torch.permute(images.to('cpu'), (0, 2, 3, 1))
    top1 = torch.topk(probabilities.to('cpu'), 1).indices.numpy()

    for inx, image in enumerate(images):
        # Un-normalised values may fall slightly outside [0, 1]; clamp for imshow.
        test_axis[0, inx].imshow(image.squeeze().clamp(0, 1).numpy())
        test_axis[0, inx].set_title(str(classes[int(img_categs[inx])]))
        test_axis[1, inx].imshow(out[inx].squeeze().clamp(0, 1).numpy())
        test_axis[1, inx].set_title('{} [{:.2f}]'.format(classes[top1[inx, 0]], criterion(image, out[inx])))

    test_fig.tight_layout()
    test_fig.show()

In [25]:
model = DecoderBN()
model = model.to(device)
# map_location: a checkpoint written on a GPU can also be restored on a CPU-only kernel.
model.load_state_dict(torch.load('/kaggle/working/test-c_MSE.pth', map_location=device))
model.eval()

images, img_cat = next(iter(val_dataloader))
test_images = images[:5]

plotValidation(test_images, img_cat, model, classes)

In [26]:
class BoolMeter:
    """Running tally of boolean outcomes, e.g. per-image classification hits.

    Attributes:
        num_succ / num_fail: number of truthy / falsy values seen.
        count: total number of updates.
        values: every value passed to ``update``, in order, stored unchanged.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget all recorded outcomes."""
        self.num_succ, self.num_fail, self.count = 0, 0, 0
        self.values = []

    def update(self, value):
        """Record one outcome; its truthiness decides success vs. failure."""
        # Truthiness is evaluated before any state changes. If it raises, nothing is recorded.
        succeeded = bool(value)
        self.num_succ += succeeded
        self.num_fail += not succeeded
        self.count += 1
        self.values.append(value)


def validate_dataset(model, val_dataloader, rec_criterion, set_size, k=5):
    """Record per-image reconstruction error and top-``k`` classification success.

    Args:
        model: network returning ``(reconstruction, class_scores)``, with scores
            indexed [image, class].
        val_dataloader: iterable of ``(images, labels)`` batches of any size.
        rec_criterion: reconstruction loss, called as ``rec_criterion(recon, input)``.
        set_size: total shown by the progress bar (display only).
        k: an image counts as correctly classified if its label is in the top ``k``.

    Returns:
        (rec_err, cat_acc): a ``utils.AverageMeter`` and a ``BoolMeter``, each updated
        once per image, in the same order. Their per-image records therefore line up
        one-to-one.
    """
    model.eval()
    cat_acc = BoolMeter()
    rec_err = utils.AverageMeter()
    t = tqdm(total=set_size, desc='VAL ', unit='eval')
    with torch.no_grad():
        for X_val, Y_val in val_dataloader:
            X_val = X_val.to(device)
            Y_val = Y_val.to(device)

            rec_out, classification = model(X_val)
            top_k = torch.topk(classification, k, dim=1).indices
            hits = (top_k == Y_val.unsqueeze(1)).any(dim=1)

            # One update per image, so every batch size is handled; the previous
            # single-image-only branch left `curr_acc` undefined or stale otherwise.
            for i in range(X_val.size(0)):
                rec_err.update(rec_criterion(rec_out[i:i + 1], X_val[i:i + 1]).item(), 1)
                cat_acc.update(bool(hits[i]))
            t.update(X_val.size(0))
    t.close()
    return rec_err, cat_acc

In [166]:
# Size the progress bar from the loader's dataset instead of a hard-coded sample count.
rec_err, cat_acc = validate_dataset(model, model_test_dataloader, criterion, len(model_test_dataloader.dataset))

In [167]:
print(cat_acc.num_succ, cat_acc.num_fail)
# One point per image. The y value is a boolean hit (1) / miss (0), not a percentage.
fig, ax = plt.subplots()
ax.scatter(rec_err.values, cat_acc.values)
ax.set_xlabel('Reconstruction error (MSELoss)')
ax.set_ylabel('Top-5 correct (1 = hit, 0 = miss)')
ax.set_title('Error correlation')
plt.show()

In [174]:
def plot_barchart(buckets, succ_buckets, fail_buckets, bar_labels, width=0.001, xlim=(-0.001, 0.04)):
    """Plot classification outcomes per reconstruction-error bucket, as two figures.

    1. Stacked bars of successes (bottom) and failures (top) per bucket, with
       ``bar_labels`` written above each stack.
    2. ``bar_labels`` themselves as bars per bucket. These are ratios in [0, 1],
       not percentages.

    Args:
        buckets: x positions (bucketed reconstruction errors).
        succ_buckets / fail_buckets: success / failure counts per bucket.
        bar_labels: one ratio per bucket, used as annotation and for figure 2.
        width: bar width in x units (a scalar or one value per bar).
        xlim: x-axis range of the stacked chart.
    """
    fig, ax = plt.subplots()
    ax.bar(buckets, succ_buckets, width, label='Success')
    fails = ax.bar(buckets, fail_buckets, width, bottom=succ_buckets, label='Fail')
    ax.bar_label(fails, bar_labels)

    ax.set_ylabel('Categorisation')
    ax.set_xlabel('Reconstruction error (MSE)')
    ax.set_xlim(xlim)
    ax.set_title('Categorisation success relative to reconstruction error')
    ax.legend()
    plt.show()

    ratio_fig, ratio_ax = plt.subplots()
    ratio_ax.bar(buckets, bar_labels, width, label='Succ ratio')
    ratio_ax.set_title('Categorisation accuracy in relation to reconstruction acc.')
    ratio_ax.set_ylabel('Success ratio (0-1)')
    ratio_ax.set_xlabel('Reconstruction error (MSE)')
    plt.show()


In [173]:
# Bucket each image's reconstruction error to 3 decimals. For every bucket, record:
# - the number of images whose label was / was not in the top-5;
# - the hit ratio within the bucket;
# - the bucket's hits as a fraction of all images (ratios_scaled).
rec_err_vals_rounded = torch.round(torch.tensor(rec_err.values), decimals=3).numpy()
cat_acc_vals = cat_acc.values
unique = np.unique(rec_err_vals_rounded)
succs, fails, ratios, ratios_scaled = [], [], [], []
for val in unique:
    occ = np.extract(rec_err_vals_rounded == val, cat_acc_vals)
    n_hits = np.count_nonzero(occ == True)
    n_misses = np.count_nonzero(occ == False)
    succs.append(n_hits)
    fails.append(n_misses)
    ratios.append(round(n_hits / len(occ), 2))
    ratios_scaled.append(round(n_hits * (len(occ) / len(rec_err_vals_rounded)) / len(occ), 2))

for bar_labels in (ratios, ratios_scaled):
    plot_barchart(unique, succs, fails, bar_labels)

In [164]:
# Feed the reconstruction back into the auto-encoder repeatedly. Stop when successive
# outputs differ by criterion <= 0.001, or after 1000 passes.
recriation_imgs = test_images
print(recriation_imgs.size())
errs = []

err = 9999
counter = 0
recriation_imgs = recriation_imgs.to(device)
while err > 0.001 and counter < 1000:
    with torch.no_grad():
        rec_out, _ = model(recriation_imgs)

    err = criterion(rec_out, recriation_imgs).detach().item()
    errs.append(err)
    print(err, end="\r")
    recriation_imgs, counter = rec_out, counter + 1

In [165]:
# Compare the original inputs (top row) with the result of the recreation loop (bottom
# row). Un-normalised copies go into new names so that re-running this cell (or the loop
# above) still sees normalised `test_images` / `recriation_imgs`.
n_shown = len(test_images)
test_fig, test_axis = plt.subplots(2, n_shown, squeeze=False)

# clone(): protects the source tensors in case utils.unNormalize works in place.
originals_display = torch.permute(
    utils.unNormalize(test_images.clone(), mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).to('cpu'),
    (0, 2, 3, 1))
recreated_display = torch.permute(
    utils.unNormalize(recriation_imgs.clone(), mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).to('cpu'),
    (0, 2, 3, 1))

for inx in range(n_shown):
    # Clamp to [0, 1]: un-normalised values may overshoot slightly.
    test_axis[0, inx].imshow(originals_display[inx].detach().squeeze().clamp(0, 1).numpy())
    test_axis[1, inx].imshow(recreated_display[inx].detach().squeeze().clamp(0, 1).numpy())

test_fig.tight_layout()
test_fig.show()