In [None]:
# this code is based on [ref], which is released under the MIT license
# make sure you reference any code you have studied as above here
# https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
# https://github.com/tneumann/minimal_glo/blob/master/glo.py
# https://github.com/yedidh/glann/blob/master/glo.py
# http://www.cs.cmu.edu/~16385/s17/Slides/3.1_Image_Pyramid.pdf
# https://angms.science/doc/CVX/Proj_l2.pdf

# imports
import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Subset
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA


# hyperparameters
batch_size = 64        # images per optimisation step
n_channels = 3         # image channels seen by the generator (the MNIST branch overrides this to 1)
latent_size = 256      # length of each per-image latent code
dataset = 'stl10'      # one of 'cifar10', 'mnist', 'stl10'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
# optional Google drive integration - this will allow you to save and resume training, and may speed up redownloading the dataset
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
class Subset_with_Index(torch.utils.data.Dataset):
    """View of ``dataset`` restricted to ``indices`` that also reports where each item came from.

    Items of the wrapped dataset are expected to be ``(data, label)`` pairs; each one is
    returned as ``(data, label, (position_in_this_view, index_in_wrapped_dataset))``.
    Items that already have three fields are passed through untouched, so when one
    ``Subset_with_Index`` wraps another, the *inner* wrapper's index tuple is what callers see.
    """

    def __init__(self, dataset: torch.utils.data.Dataset, indices: "list[int]") -> None:
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        source_index = self.indices[idx]
        item = self.dataset[source_index]
        if len(item) == 3:
            # already carries an index tuple (e.g. produced by an inner Subset_with_Index)
            return item
        return item[0], item[1], (idx, source_index)

In [None]:
# helper function to make getting another batch of data easier
def cycle(iterable):
    """Yield the items of ``iterable`` forever, starting a fresh pass after each one ends.

    Unlike ``itertools.cycle`` nothing is cached: every pass calls ``iter(iterable)`` again,
    so a shuffling ``DataLoader`` yields a new permutation on each pass.
    """
    while True:
        yield from iterable

# you may use cifar10 or stl10 datasets
sklearnDIR = "drive/MyDrive/training/" # on local ""

# channels used for the STL-10 PCA/k-means clustering below: 1 = grayscale, 3 = RGB
pca_channels = 1

if dataset == 'cifar10':
    # cifar10 is a collection of 32x32 images by default
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10("./"+sklearnDIR+'cifar10', train=True, download=True,
                                     transform=torchvision.transforms.Compose([
                                         torchvision.transforms.ToTensor(),
                                     ])),
        shuffle=True, batch_size=batch_size, drop_last=True
    )
    # data_iterator (end of this cell) is built from `data_loader`, so expose it under that name too
    data_loader = train_loader
    class_names = ['airplane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

elif dataset == 'mnist':
    # 32x32 MNIST for simpler testing
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST("./"+sklearnDIR+'mnist', train=True, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.Resize(32),
                                       torchvision.transforms.ToTensor(),
                                   ])),
        shuffle=True, batch_size=batch_size, drop_last=True
    )
    data_loader = train_loader
    class_names = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
    n_channels = 1

elif dataset == 'stl10':
    # stl10 has larger images which are much slower to train on. You should develop your method
    # with CIFAR-10 before experimenting with STL-10.
    # stl10 is a collection of 96x96 images by default; Normalize maps ToTensor's [0, 1] to [-1, 1].
    if pca_channels == 3:
        stl_transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    elif pca_channels == 1:
        stl_transform = torchvision.transforms.Compose([
            torchvision.transforms.Grayscale(),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5,), (0.5,)),
        ])
    else:
        stl_transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
        ])
    stl_dataset = torchvision.datasets.STL10("./"+sklearnDIR+'stl10', split='train+unlabeled', download=True,
                                             transform=stl_transform)

    class_names = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck'] # these are slightly different to CIFAR-10
    data_loader = torch.utils.data.DataLoader(stl_dataset, shuffle=True, batch_size=batch_size, drop_last=True)

else:
    raise ValueError(f"unknown dataset {dataset!r}; expected 'cifar10', 'mnist' or 'stl10'")

# NOTE: the clustering and training cells below index `stl_dataset` directly, so the rest of
# the notebook currently only runs end-to-end with dataset == 'stl10'.
data_iterator = iter(cycle(data_loader))

**Select the training set by clustering (PCA + k-means)**

In [None]:
# Split the labelled STL-10 images into two groups by class id (see class_names).
# Reading the dataset's label array avoids loading and transforming every image of the
# 'train+unlabeled' split (twice per image) just to look at its label.
stl_labels = np.asarray(stl_dataset.labels)
# airplane (0) and bird (1)
bird_indices = np.flatnonzero(np.isin(stl_labels, [0, 1])).tolist()
# deer (4) and horse (6)
horse_indices = np.flatnonzero(np.isin(stl_labels, [4, 6])).tolist()

horseSubset = Subset_with_Index(stl_dataset, horse_indices)
horse_loader = torch.utils.data.DataLoader(horseSubset, shuffle=True, batch_size=batch_size, drop_last=True)
horse_iterator = iter(cycle(horse_loader))

birdSubset = Subset_with_Index(stl_dataset, bird_indices)
bird_loader = torch.utils.data.DataLoader(birdSubset, shuffle=True, batch_size=batch_size, drop_last=True)
bird_iterator = iter(cycle(bird_loader))

# preview one random batch of each group (images are in [-1, 1], so rescale to [0, 1])
for preview_iterator in (horse_iterator, bird_iterator):
    x, t, i = next(preview_iterator)
    plt.grid(False)
    plt.imshow(torchvision.utils.make_grid(x / 2. + 0.5).cpu().permute(1, 2, 0), cmap=plt.cm.binary)
    plt.show()


In [None]:
print(len(horse_indices), len(bird_indices))
genHorseKmeans = False
genBirdKmeans = False
loops = 150  # batches sampled to fit PCA/k-means (-> ~10 loops through the labelled data)


def _fit_or_load_clusterer(name, iterator, generate, n_clusters=6):
    """Fit (or load previously pickled) PCA + k-means models for one image group.

    When ``generate`` is true, ``loops`` batches are drawn from ``iterator``, flattened to
    ``pca_channels*96*96`` values per image, reduced to ``latent_size`` dimensions with PCA
    and clustered with k-means; both models are pickled under ``sklearnDIR``.  Otherwise the
    pickles written by an earlier run are loaded.  Returns ``(k_means, pca)``.
    """
    k_means_path = "./" + sklearnDIR + f"{name}-k-means.pkl"
    pca_path = "./" + sklearnDIR + f"{name}-PCA.pkl"

    if generate:
        pca = PCA(n_components=latent_size)
        batches = []
        for _ in tqdm(range(loops)):
            data, _, _ = next(iterator)
            batches.append(data.numpy().reshape(len(data), -1))
        flat = np.concatenate(batches).reshape(batch_size * loops, pca_channels * 96 * 96)
        reduced = pca.fit_transform(flat)

        k_means = KMeans(init="k-means++", n_clusters=n_clusters, n_init=20)
        k_means.fit(reduced)

        with open(k_means_path, "wb") as fh:
            pickle.dump(k_means, fh)
        with open(pca_path, "wb") as fh:
            pickle.dump(pca, fh)
    else:
        # NOTE: pickle.load executes arbitrary code -- only load files this notebook wrote.
        with open(k_means_path, "rb") as fh:
            k_means = pickle.load(fh)
        with open(pca_path, "rb") as fh:
            pca = pickle.load(fh)
    return k_means, pca


horse_k_means, horse_PCA = _fit_or_load_clusterer("horse", horse_iterator, genHorseKmeans)
bird_k_means, bird_PCA = _fit_or_load_clusterer("bird", bird_iterator, genBirdKmeans)

In [None]:
def _assign_clusters(subset, pca, k_means):
    """Return one list per k-means cluster with the positions in ``subset`` assigned to it.

    Each image is flattened to ``pca_channels*96*96`` values, projected with ``pca`` and
    labelled with ``k_means.predict``.  The number of lists follows ``k_means.n_clusters``
    (previously hard-coded as 6 for horses and 8 for birds).
    NOTE: requires stl_dataset to still use the clustering transform; the next cell switches
    it to RGB, after which re-running this cell fails.
    """
    clusters = [[] for _ in range(k_means.n_clusters)]
    for position in range(len(subset)):
        image = subset[position][0]
        reduced = pca.transform(image.reshape(1, pca_channels * 96 * 96))
        clusters[k_means.predict(reduced)[0]].append(position)
    return clusters


horse_cluster_indices = _assign_clusters(horseSubset, horse_PCA, horse_k_means)
bird_cluster_indices = _assign_clusters(birdSubset, bird_PCA, bird_k_means)
# a later cell `del`s this name; it holds the bird subset size (number of images, not batches)
datasetBatches = len(birdSubset)


In [None]:

# Switch the dataset back to 3-channel RGB in [-1, 1] for viewing and training.
# NOTE: this mutates stl_dataset in place, so re-running the cluster-assignment cell after
# this one would feed RGB images to the grayscale PCA and fail.
if pca_channels == 1:
    stl_dataset.transforms = stl_dataset.transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

# show two random batches from every non-empty cluster of each group
for subset, cluster_indices in [(horseSubset, horse_cluster_indices), (birdSubset, bird_cluster_indices)]:
    for cluster_id, members in enumerate(cluster_indices):
        print("cluster: ", cluster_id)
        if not members:
            continue
        cluster_loader = torch.utils.data.DataLoader(Subset(subset, members), shuffle=True,
                                                     batch_size=min(batch_size, len(members)),
                                                     drop_last=True)
        for _ in range(2):
            # a fresh iterator reshuffles, so each preview is a new random batch
            x, t, i = next(iter(cluster_loader))
            plt.grid(False)
            plt.imshow(torchvision.utils.make_grid(x / 2. + 0.5).cpu().permute(1, 2, 0),
                       cmap=plt.cm.binary)
            plt.show()



In [None]:


# hand-picked horse clusters, in this order (cluster 1 is left out)
kept_horse_clusters = (2, 0, 3, 4, 5)
newHorses = Subset_with_Index(
    horseSubset,
    [position for cluster in kept_horse_clusters for position in horse_cluster_indices[cluster]],
)

horse_loader = torch.utils.data.DataLoader(newHorses, shuffle=True, batch_size=batch_size, drop_last=True)
horse_iterator = iter(cycle(horse_loader))

# preview two consecutive batches of the filtered horse set
for _ in range(2):
    x, t, i = next(horse_iterator)
    x, t = x.to(device), t.to(device)
    plt.grid(False)
    plt.imshow(torchvision.utils.make_grid(x / 2. + 0.5).cpu().data.permute(1, 2, 0), cmap=plt.cm.binary)
    plt.show()

In [None]:


# hand-picked bird clusters 1 and 2
newBirds = Subset(birdSubset, bird_cluster_indices[1] + bird_cluster_indices[2])

# birdSubset and newHorses both report idx[0] as a position local to birdSubset / horseSubset,
# so the old `newBirds + newHorses` gave birds and horses overlapping idx[0] values and they
# shared rows of the latent embedding Z.  Map every selected image back to its stl_dataset
# index and wrap the combined list once: idx[0] is then a unique position in
# [0, len(stl_loader.dataset)), matching the embedding table sized from that length;
# idx[1] is still the stl_dataset index.
training_stl_indices = ([birdSubset.indices[position] for position in newBirds.indices]
                        + [horseSubset.indices[position] for position in newHorses.indices])
training_set = Subset_with_Index(stl_dataset, training_stl_indices)

stl_loader = torch.utils.data.DataLoader(training_set, shuffle=True, batch_size=batch_size, drop_last=True)
stl_iterator = iter(cycle(stl_loader))

x, t, i = next(stl_iterator)
plt.grid(False)
plt.imshow(torchvision.utils.make_grid(x / 2. + 0.5).cpu().permute(1, 2, 0), cmap=plt.cm.binary)
plt.show()


In [None]:
# delete unused variables: clustering bookkeeping that training does not need
del (horse_indices, bird_indices, bird_loader, bird_iterator,
     bird_cluster_indices, horse_cluster_indices, datasetBatches)

**Define the GLO generator and the Laplacian-pyramid loss**

In [None]:
# Helpers for the spatial size produced by a convolution / transposed-convolution layer.
# Conv2d (dilation 1): H_out = (H_in + 2×padding − kernel_size) // stride + 1
# ConvTranspose2d:     H_out = (H_in−1)×stride[0]−2×padding[0]+dilation[0]×(kernel_size[0]−1)+output_padding[0]+1

def conv2d_size_out(size, kernel_size=4, stride=2, padding=0):
    """Spatial output size of an ``nn.Conv2d`` (dilation 1) applied to ``size`` input pixels."""
    padded = size + 2 * padding
    return (padded - kernel_size) // stride + 1
def conv2dTrans_size_out(size, kernel_size = 4, stride = 2,padding=0, output_padding=0, dilation=1):
    """Spatial output size of an ``nn.ConvTranspose2d`` applied to ``size`` input pixels.

    Implements the PyTorch formula
    ``(size - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1``.
    The previous ``stride*size + 1 + kernel_size - 2*padding`` was too large by
    ``stride + 1`` (e.g. 5 instead of 2 for a 1x1 input with k=4, s=2, p=1).
    ``output_padding`` and ``dilation`` default to PyTorch's defaults.
    """
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

# simple block of convolution, batchnorm, and leakyrelu
class ConvBlock(nn.Module):
    """3x3 ``Conv2d`` -> ``BatchNorm2d`` -> ``LeakyReLU``.

    Padding is 1, so the spatial size changes only through ``stride``.
    """

    def __init__(self, in_f, out_f,stride=1):
        super().__init__()
        layers = [
            nn.Conv2d(in_f, out_f, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_f),
            nn.LeakyReLU(inplace=True),
        ]
        self.f = nn.Sequential(*layers)

    def forward(self, x):
        return self.f(x)

class ConvTransBlock(nn.Module):
    """``ConvTranspose2d`` -> ``BatchNorm2d`` -> ``LeakyReLU``.

    With the defaults (kernel 4, stride 2, padding 1) the spatial size doubles.
    """

    def __init__(self, in_f, out_f,kernel=4,stride=2,padding=1):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_f, out_f, kernel_size=kernel, stride=stride, padding=padding),
            nn.BatchNorm2d(out_f),
            nn.LeakyReLU(inplace=True),
        ]
        self.f = nn.Sequential(*layers)

    def forward(self, x):
        return self.f(x)

def gauss_kernel(channels=None, kernel_device=None):
    """Return a normalised 5x5 Gaussian blur kernel shaped for depthwise convolution.

    The integer 5x5 approximation below is divided by its sum (273) so it sums to 1, then
    repeated to shape ``(channels, 1, 5, 5)`` so ``F.conv2d(..., groups=channels)`` blurs
    each channel independently.

    Args:
        channels: number of image channels; defaults to the global ``n_channels``.
        kernel_device: device to place the kernel on; defaults to the global ``device``.
    """
    # standard definition for sigma=1 - https://homepages.inf.ed.ac.uk/rbf/HIPR2/gsmooth.htm
    if channels is None:
        channels = n_channels
    if kernel_device is None:
        kernel_device = device
    kernel = torch.tensor([[1., 4., 7., 4., 1.],
                           [4., 16., 26., 16., 4.],
                           [7., 26., 41., 26., 7.],
                           [4., 16., 26., 16., 4.],
                           [1., 4., 7., 4., 1.]])
    kernel /= kernel.sum()
    return kernel.repeat(channels, 1, 1, 1).to(kernel_device)

# convolve gauss kernel over image causing a blur
def conv_gauss(img, kernel):
    """Blur ``img`` with the depthwise ``kernel`` (one filter per channel, ``groups=kernel.shape[0]``).

    Two pixels of reflect padding on every side keep the output the same spatial size as
    the input when ``kernel`` is 5x5 (as built by ``gauss_kernel``).  Assumes ``img`` is
    (N, C, H, W) with C == kernel.shape[0] -- TODO confirm for other callers.
    """
    padded = F.pad(img, (2, 2, 2, 2), mode='reflect')
    return F.conv2d(padded, kernel, groups=kernel.shape[0])

def lap_pyramid(img,kernel,max_depth=5):
    """Build a ``max_depth + 1`` level band-pass pyramid of ``img``.

    For each of the first ``max_depth`` levels the detail band is ``current - blur(current)``
    at the current resolution; the blurred image is then 2x average-pooled to become the
    next ``current``.  The last entry is the remaining low-pass image.
    NOTE: the detail band uses the same-resolution blur rather than an
    upsample(downsample(...)) reconstruction.
    """
    levels = []
    current = img
    for _ in range(max_depth):
        smoothed = conv_gauss(current, kernel)
        levels.append(current - smoothed)
        current = F.avg_pool2d(smoothed, 2)
    levels.append(current)
    return levels

# Visual sanity check: show every pyramid level of one training batch, then the batch itself.
x, t, i = next(stl_iterator)

kernel = gauss_kernel()
x = x.to(device)
pyr = lap_pyramid(x, kernel, 5)

for level in pyr + [x]:
    plt.figure(1, (10, 10))
    plt.grid(False)
    plt.imshow(torchvision.utils.make_grid(level / 2. + 0.5).cpu().data.permute(1, 2, 0), cmap=plt.cm.binary)
    plt.show()

def lapLoss(x,x_hat,kernel):
    """Laplacian-pyramid L1 loss.

    Builds depth-5 pyramids (6 levels) of ``x_hat`` and ``x`` with ``kernel`` and sums the
    mean absolute difference of each pair of corresponding levels.
    """
    levels_hat = lap_pyramid(x_hat, kernel, 5)
    levels = lap_pyramid(x, kernel, 5)
    total = 0
    for level_hat, level in zip(levels_hat, levels):
        total = total + F.l1_loss(level_hat, level)
    return total

In [None]:
# define the model
# https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
# https://github.com/tneumann/minimal_glo/blob/master/glo.py
# https://github.com/yedidh/glann/blob/master/glo.py
class GLO(nn.Module):
    """Generative Latent Optimization (GLO) generator.

    Maps a batch of ``latent_size``-dim codes to ``n_channels`` x 96 x 96 images in
    [-1, 1] (Tanh output) through six transposed convolutions:
    1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 96 pixels.  ``f`` scales the channel widths.
    """

    def __init__(self, f=16):
        super().__init__()
        # blur kernel used by the Laplacian-pyramid loss in calc_loss
        self.gauss_kernel = gauss_kernel()

        # ConvTranspose2d output size: (H_in - 1)*stride - 2*padding + kernel_size
        self.generator = nn.Sequential(

            nn.ConvTranspose2d(latent_size, f * 16, 4, 2, 1, bias=False), # 1x1 -> 2x2
            nn.BatchNorm2d(f * 16), nn.ReLU(True),

            # nn.ConvTranspose2d(latent_size, f * 8, 4, 1, 0, bias=False), # or start at 4x4

            nn.ConvTranspose2d(f*16, f * 8, 4, 2, 1, bias=False), # 4x4
            nn.BatchNorm2d(f * 8), nn.ReLU(True),
            nn.ConvTranspose2d(f * 8, f * 4, 4, 2, 1, bias=False),  # 8x8
            nn.BatchNorm2d(f * 4), nn.ReLU(True),
            nn.ConvTranspose2d(f * 4, f * 2, 4, 2, 1, bias=False),  # 16x16
            nn.BatchNorm2d(f * 2), nn.ReLU(True),
            nn.ConvTranspose2d(f * 2, f, 4, 2, 1, bias=False),  # 32x32
            nn.BatchNorm2d(f), nn.ReLU(True),
            nn.ConvTranspose2d(f, n_channels, 5, 3, 1, bias=False),  # 96x96
            nn.Tanh(),
        )

    def forward(self, latent_code):
        """Decode ``latent_code`` (reshaped to (B, latent_size, 1, 1)) into a batch of images."""
        return self.generator(latent_code.view(latent_code.size(0), latent_size, 1, 1))

    def projection_l2_sphere(self,z:torch.Tensor): # from GLO paper -> project latent space onto the unit ball
        """Project each row of ``z`` onto the unit L2 ball: ``z / max(1, ||z||_2)``.

        Rows with norm <= 1 are returned unchanged.  ``clamp`` keeps the computation on
        ``z``'s device; the previous comparison against a CPU ``torch.Tensor([1])`` fails
        with a device mismatch for CUDA inputs.
        """
        norms = torch.linalg.norm(z, dim=1, keepdim=True)
        return z / norms.clamp(min=1)

    def calc_loss(self,recon,actual):
        """Laplacian-pyramid L1 loss plus 0.01 x the pixel-wise L1 loss."""
        return lapLoss(recon, actual, self.gauss_kernel) + nn.functional.l1_loss(actual, recon) * 0.01



class embeddingNetwork(nn.Module):
    """Learnable table of per-image latent codes: ``totalImages`` rows of ``latent_size``."""

    def __init__(self, totalImages):
        super().__init__()
        self.n = totalImages
        self.embedding = nn.Embedding(self.n, latent_size)

    def norm_embedding(self): # normalise the embedding weights, using l2 norm
        """Rescale every embedding row to unit L2 norm, in place.

        This puts every row on the unit *sphere*: rows shorter than 1 are scaled up.
        An all-zero row is left at zero instead of becoming NaN.
        """
        with torch.no_grad():
            norms = self.embedding.weight.norm(2, dim=1, keepdim=True)
            self.embedding.weight.div_(norms.clamp(min=1e-12))

    def forward(self, idx):
        """Look up the latent codes for ``idx``.

        A 1-D ``idx`` of length B (or a scalar, treated as B=1) gives ``(B, latent_size)``.
        Singleton index dims after the batch dim (e.g. ``idx`` of shape (B, 1)) are dropped,
        but the batch dim is always kept.  The previous ``squeeze()`` turned a batch of one
        into a 1-D vector, which ``GLO.forward``'s ``view(latent_code.size(0), ...)`` rejects.
        """
        z = self.embedding(torch.atleast_1d(idx))
        for dim in range(z.dim() - 2, 0, -1):
            if z.shape[dim] == 1:
                z = z.squeeze(dim)
        return z



glo = GLO(96).to(device)
n_glo_params = sum(param.numel() for param in glo.parameters())
print(f'> Number of GLO parameters {n_glo_params}')
glo_optimiser = torch.optim.Adam(glo.parameters(), lr=0.00001, weight_decay=0.0005)

# one learnable latent row per training image, optimised jointly with the generator
Z = embeddingNetwork(len(stl_loader.dataset)).to(device)
Z_optimiser = torch.optim.Adam(Z.parameters(), lr=0.00001, weight_decay=0.0005)

In [None]:
# generate initial latent Z's


# Z = torch.zeros((len(stl_loader.dataset)-(len(stl_loader.dataset)%64), latent_size))
# PCA_bool = True
#
# pca_inst = PCA(n_components=latent_size)
#
# x = []
# for data,t,idx in tqdm(stl_loader):
#     x.append(data.numpy().reshape(len(data),-1))
#
# pca_inst.fit(np.array(x).reshape((len(stl_loader.dataset)-(len(stl_loader.dataset)%64),27648)))
# # misses len(stl_loader) - 1408 images -> ie difference in batch size
# for data,target,i in tqdm(stl_loader):
#
#     if PCA_bool:
#       Z[i[0]] = torch.Tensor(pca_inst.transform(data.reshape(batch_size,-1)))
#     else:
#       temp = torch.randn((batch_size,latent_size))
#
#       Z[i[0]]=temp
#       x,t = x.to(device), t.to(device)
#       plt.grid(False)
#       plt.imshow(torchvision.utils.make_grid(data/ 2. + 0.5).cpu().data.permute(0,2,1).contiguous().permute(2,1,0), cmap=plt.cm.binary)
#       plt.show()
# Z = glo.projection_l2_sphere(Z)

In [None]:
# initial sample
# decode the (still untrained) latent codes of one batch
x, t, idx = next(stl_iterator)
with torch.no_grad():  # display only -- no autograd graph needed
    g = glo(Z(idx[0].to(device)))

plt.rcParams['figure.dpi'] = 100
plt.grid(False)
plt.imshow(torchvision.utils.make_grid(g / 2. + 0.5).cpu().permute(1, 2, 0), cmap=plt.cm.binary)
plt.show()

In [None]:

EPOCHS = 150
STEPS_PER_EPOCH = 200  # batches drawn per "epoch" (not a full pass over the training set)

for epoch in range(EPOCHS):
    loss_arr = []

    # total= gives tqdm a known length; the old tqdm(200, ...) treated 200 as an iterable
    pbar = tqdm(total=STEPS_PER_EPOCH, desc='epoch % 3d' % epoch)
    for step in range(STEPS_PER_EPOCH):
        x, t, idx = next(stl_iterator)
        x = x.to(device)

        glo_optimiser.zero_grad()
        Z_optimiser.zero_grad()

        # idx[0] selects each image's row in the latent table Z
        recon = glo(Z(idx[0].to(device)))
        loss = glo.calc_loss(recon, x)

        # backpropagate into both the generator weights and the latent codes
        loss.backward()
        glo_optimiser.step()
        Z_optimiser.step()

        loss_arr.append(loss.item())
        pbar.set_postfix({'loss': np.mean(loss_arr[-100:])})
        pbar.update(1)
    # renormalise every latent code once per epoch
    Z.norm_embedding()
    pbar.close()
    # epoch-mean loss (previously the last batch's tensor repr)
    print('loss ' + str(np.mean(loss_arr)))

    # output reconstructions
    if epoch % 5 == 0:
        x, t, idx = next(stl_iterator)
        with torch.no_grad():  # display only
            g = glo(Z(idx[0].to(device)))

        plt.rcParams['figure.dpi'] = 100
        plt.grid(False)
        plt.imshow(torchvision.utils.make_grid(g / 2. + 0.5).cpu().permute(1, 2, 0), cmap=plt.cm.binary)
        plt.show()

        if epoch % 25 == 0:
            torch.save({'GLO params': glo.state_dict(), 'glo optimiser': glo_optimiser.state_dict(), 'epoch': epoch,
                        "Z Latent params": Z.state_dict(), "Z optimiser": Z_optimiser.state_dict()},
                       "./" + sklearnDIR + f'/save_embedding_{epoch}.chkpt')


In [None]:
# now show your best batch of data for the submission, right click and save the image for your report
x, t, idx = next(stl_iterator)
with torch.no_grad():  # display only -- no autograd graph needed
    g = glo(Z(idx[0].to(device)))
plt.rcParams['figure.dpi'] = 175
plt.grid(False)
plt.imshow(torchvision.utils.make_grid(g / 2. + 0.5).cpu().permute(1, 2, 0), cmap=plt.cm.binary)
plt.show()

In [None]:
#
# params = torch.load("./" + sklearnDIR + '/saveV15.chkpt', map_location=torch.device('cpu'))
# glo.load_state_dict(params['GLO params'])
# Z.load_state_dict(params["Z Latent params"])

# Loop through latent space to find learned features


# Browse reconstructions of every row of the latent table, 16 at a time, to find
# interesting codes.  (Iterating `horse_loader` yielded batches, so `range(i, ...)` raised.)
chunk = 16
for i in range(0, Z.n, chunk):
    print(i)
    rows = torch.arange(i, min(i + chunk, Z.n), device=device)  # indices on Z's device
    with torch.no_grad():
        # reshape keeps a 2-D batch even for a final chunk of one
        recon = glo(Z(rows).reshape(-1, latent_size))
    plt.rcParams['figure.dpi'] = 400
    plt.grid(False)
    plt.imshow(torchvision.utils.make_grid(recon / 2. + 0.5).cpu().permute(1, 2, 0),
               cmap=plt.cm.binary)
    plt.title(f"Recon {i}")

    plt.show()

# White Horses - 101, 145, 161, 201, 325(gray), 345,605(best) 798,

#birds 659, 263, 264, 176, 287,819, 857,862


# 798


In [None]:
exploreLatent=True

if exploreLatent:
    # hand-picked rows of Z (found with the browsing cell); they only make sense for the
    # trained model and training-set ordering they were chosen from
    birdTarget = torch.LongTensor([659, 263, 264, 176, 287, 819, 857, 862])
    horseTarget = torch.LongTensor([605, 101, 145, 161, 201, 325, 345, 798])

    def _show_latents(latents, title=None):
        """Decode a batch of latent codes and plot the images as a single grid."""
        with torch.no_grad():
            images = glo(latents.reshape(-1, latent_size))
        plt.grid(False)
        plt.imshow(torchvision.utils.make_grid(images / 2. + 0.5).cpu().permute(1, 2, 0),
                   cmap=plt.cm.binary)
        if title is not None:
            plt.title(title)
        plt.show()

    with torch.no_grad():
        # the indices must be on the same device as the embedding table
        horseLatent = Z(horseTarget.to(device)).reshape(-1, latent_size)
        birdLatent = Z(birdTarget.to(device)).reshape(-1, latent_size)

    plt.rcParams['figure.dpi'] = 1000
    _show_latents(horseLatent)
    _show_latents(birdLatent)

    # blend every horse code towards bird code j: weight w on the horses, (1 - w) on bird j
    for j in range(len(birdLatent)):
        for w in np.linspace(1, 0, 8):
            mixed = float(w) * horseLatent + (1 - float(w)) * birdLatent[j]
            _show_latents(mixed, title=f"Bird {j} scaled with {w:.2f}")


In [None]:
# optional example code to save your training progress for resuming later if you authenticated Google Drive previously
checkpoint = {
    'GLO params': glo.state_dict(),
    'glo optimiser': glo_optimiser.state_dict(),
    'epoch': epoch,  # last epoch reached by the training loop
    "Z Latent params": Z.state_dict(),
    "Z optimiser": Z_optimiser.state_dict(),
}
torch.save(checkpoint, "./" + sklearnDIR + '/save.chkpt')