In [None]:
import torch
class Linear(torch.nn.Module):
    """Bias-free dense layer applied to the flattened trailing axes of its input.

    The input is reshaped to ``(x.shape[0], x.shape[1], -1)`` and multiplied by
    the learnable matrix ``W`` of shape ``(in_features, output_features)``.

    Args:
        in_features: Size of the flattened trailing dimension of the input.
        output_features: Size of the last dimension of the output.

    Raises:
        TypeError: If either size is left at its ``None`` default.
    """

    def __init__(self, in_features=None, output_features=None):
        super(Linear, self).__init__()
        if in_features is None or output_features is None:
            # The defaults are placeholders only; a weight matrix cannot be
            # allocated without both sizes.
            raise TypeError(
                "Linear requires both in_features and output_features")
        self.output_features = output_features
        self.in_features = in_features
        # Allocate on the GPU when one is available (the previous behaviour),
        # otherwise fall back to the CPU instead of crashing.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # torch.empty: the values are overwritten by the Xavier initialiser
        # right below, so drawing uniform noise first was wasted work.
        self.W = torch.nn.Parameter(
            data=torch.empty(self.in_features, output_features, device=device))
        # In-place variant; the un-suffixed name is a deprecated alias.
        torch.nn.init.xavier_normal_(self.W)

    def forward(self, x):
        """Return ``x`` flattened to 3-D and multiplied by ``W`` (no bias term).

        ``x`` must have at least two dimensions; everything after the first
        two is flattened into one axis that must equal ``in_features``.
        """
        return torch.matmul(x.reshape(x.shape[0], x.shape[1], -1), self.W)

    def __repr__(self):
        return "Linear(in_features = " + str(self.in_features) + " out_features = " + str(self.output_features) + ")"


class Relu(torch.nn.Module):
    """Parameter-free module that applies the rectified linear unit."""

    def __init__(self):
        # Call to the superclass constructor
        super().__init__()

    def forward(self, x):
        """Return ``x`` with every negative entry replaced by zero."""
        rectified = torch.relu(x)
        return rectified

class BatchNorm1D(torch.nn.BatchNorm1d):
    """Hand-written batch normalisation over every axis except the channel axis.

    Statistics are computed per channel (dim 1) and reduced over all other
    dimensions, so ``(N, C)``, ``(N, C, L)`` and ``(N, C, H, W)`` inputs are
    all accepted.

    The learnable scale and shift are the ``weight`` and ``bias`` parameters
    created by the parent class; ``gamma`` and ``beta`` are read-only aliases
    for them. This keeps them registered with the module (so they are
    optimised and follow ``.to()`` / ``.cuda()``) without adding new
    ``state_dict`` keys.

    Args:
        num_features: Number of channels, i.e. the size of dim 1 of the input.
        eps: Constant added to the variance before the square root.
        momentum: Weight given to the current batch statistic when updating
            the running estimates.
        affine: Whether a learnable scale and shift are applied.
        track_running_stats: Whether running mean/variance are maintained and
            used in evaluation mode.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(BatchNorm1D, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        # Not used by forward(); retained so existing attribute reads still work.
        self.stability = 1e-05

    @property
    def gamma(self):
        """Scale parameter (alias of ``weight``; ``None`` when ``affine=False``)."""
        return self.weight

    @property
    def beta(self):
        """Shift parameter (alias of ``bias``; ``None`` when ``affine=False``)."""
        return self.bias

    def forward(self, x):
        """Normalise ``x`` per channel, then apply the optional scale and shift.

        In training mode the batch statistics are used and the running
        estimates are updated; in evaluation mode the running estimates are
        used. Without running statistics the batch statistics are always used.

        Raises:
            ValueError: If ``x`` has fewer than two dimensions.
        """
        if x.dim() < 2:
            raise ValueError(
                "expected an input with at least 2 dimensions, got " + str(x.dim()))

        # Reduce over everything but the channel axis, and remember the shape
        # that broadcasts a per-channel vector back against x.
        reduce_dims = [d for d in range(x.dim()) if d != 1]
        stat_shape = [1, -1] + [1] * (x.dim() - 2)

        has_running_stats = self.running_mean is not None
        if self.training or not has_running_stats:
            mean = x.mean(reduce_dims)
            # use biased var in train
            var = x.var(reduce_dims, unbiased=False)
            if self.training and has_running_stats:
                n = x.numel() / x.size(1)
                # Bessel correction; guarded so one value per channel does not
                # divide by zero (the biased variance is 0 in that case anyway).
                unbiased_factor = n / max(n - 1, 1)
                with torch.no_grad():
                    # In-place updates keep the registered buffers (and their
                    # device) instead of rebinding them to new tensors.
                    self.running_mean.mul_(1 - self.momentum).add_(
                        self.momentum * mean)
                    # update running_var with unbiased var
                    self.running_var.mul_(1 - self.momentum).add_(
                        self.momentum * var * unbiased_factor)
                    if self.num_batches_tracked is not None:
                        self.num_batches_tracked.add_(1)
        else:
            mean = self.running_mean
            var = self.running_var

        # Scale and shift
        x = (x - mean.reshape(stat_shape)) / torch.sqrt(var.reshape(stat_shape) + self.eps)
        if self.weight is not None:
            x = x * self.weight.reshape(stat_shape) + self.bias.reshape(stat_shape)
        return x

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch

class VAE(torch.nn.Module):
    """Fully connected variational auto-encoder with batch-normalised layers.

    Args:
        input_shape: Flattened size of one input channel (e.g. ``28*28``).
        channels: Channel count, passed as ``num_features`` to every
            ``BatchNorm1D`` layer.
        h_dim: Width of the hidden representation.
        z_dim: Width of the latent representation.
    """

    def __init__(self,input_shape,channels,h_dim,z_dim):
        super(VAE,self).__init__()
        # Encoder, image to hidden space
        self.encode_input_batch = BatchNorm1D(channels,eps=1e-3)
        self.img_2hid = Linear(input_shape,h_dim)
        self.fc_mu,self.fc_std = Linear(h_dim,z_dim),Linear(h_dim,z_dim)
        self.norm_mu,self.norm_std = BatchNorm1D(channels,eps=1e-3), BatchNorm1D(channels,eps=1e-3)
        # The decoder from the z space goes back to the hidden
        self.z_2hidden = Linear(z_dim,h_dim)
        self.hidden_batch_norm = BatchNorm1D(channels,eps=1e-3)
        # Finally from the hidden we go back to the input space
        self.hidden_2img = Linear(h_dim,input_shape)
        self.img_batch = BatchNorm1D(channels,eps=1e-3)
        self.relu = Relu()
        self.rloss = torch.nn.MSELoss(reduction="sum")
        # One placement for the whole model: the GPU when available (what the
        # former per-layer .cuda() calls did), the CPU otherwise.
        self.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    @staticmethod
    def _dump_and_raise_if_nan(probe, snapshot, path):
        """Save ``snapshot`` to ``path`` and raise OverflowError if ``probe`` holds any NaN."""
        # any(), not all(): a single NaN already invalidates the loss.
        if torch.isnan(probe).any():
            torch.save(snapshot, path)
            raise OverflowError

    def encode(self,x):
        """Map an input batch to the two latent heads.

        Returns:
            tuple: ``(mu, sigma)``; each head goes through its own batch norm
            and a final ReLU, so both outputs are non-negative.
        """
        hidden = self.relu(self.encode_input_batch(self.img_2hid(x)))
        mu = self.norm_mu(self.fc_mu(hidden))
        sigma = self.norm_std(self.fc_std(hidden))
        # NOTE(review): the ReLU keeps both heads >= 0, and forward() uses
        # sigma.exp() as the scale, so the scale can never be below 1.
        # Confirm this is intended.
        return self.relu(mu),self.relu(sigma)

    def decode(self,x):
        """Map latent samples back to the (flattened) input space.

        Raises:
            OverflowError: If the output contains NaN; every intermediate
                activation is saved to the error log first.
        """
        first_hidden = self.z_2hidden(x)
        hidden_norm = self.hidden_batch_norm(first_hidden)
        leaky_relu = torch.nn.functional.leaky_relu(hidden_norm)
        hidden_2img_space = self.hidden_2img(leaky_relu)
        img_batch = self.img_batch(hidden_2img_space)
        last_relu = self.relu(img_batch)
        self._dump_and_raise_if_nan(last_relu, {
            'x_shape': x.shape,
            'first_hidden':first_hidden,
            'hidden_norm':hidden_norm,
            'leaky_relu':leaky_relu,
            'hidden_2img_space':hidden_2img_space,
            'img_batch':img_batch,
            'last_relu':last_relu,
        },'drive/MyDrive/ColabVAE/errorLog/decode_error.pt')
        return last_relu

    def normalize(self,x):
        """Min-max rescale ``x`` to the range [0, 1] using its global min and max.

        The result stays on the device of ``x``.

        Raises:
            OverflowError: If the result contains NaN (e.g. a constant input,
                where max equals min and the division is 0/0); the inputs are
                saved to the error log first.
        """
        original = x
        v_min, v_max = x.min(), x.max()
        # Target range is [0, 1], so the generic rescaling formula reduces to
        # plain min-max scaling.
        x = (x - v_min) / (v_max - v_min)
        self._dump_and_raise_if_nan(x, {
            'original': original,
            'min': v_min,
            'max': v_max,
            'x': x,
        }, 'drive/MyDrive/ColabVAE/errorLog/normalization_error.pt')
        return x


    def forward(self,x):
        """Encode, sample, decode and score one batch.

        Returns:
            tuple: ``(elbo, recon_error, x_hat)`` where ``recon_error`` is the
            summed squared error between ``x_hat`` and ``x``, ``elbo`` is
            ``recon_error - log_q`` and ``x_hat`` has the shape of ``x``.

        Raises:
            OverflowError: If the latent sample, the decoder output or the
                normalised reconstruction contains NaN.
        """
        original = x.shape
        mu, sigma = self.encode(x)
        normal = torch.distributions.Normal(0,1)
        # Reparameterised sample from the encoder distribution; sigma is used
        # as a log-scale here.
        epsilon = torch.randn_like(sigma)
        z = mu + sigma.exp() * epsilon
        self._dump_and_raise_if_nan(z, {
            'mu': mu,
            'sigma': sigma,
            'epsilon': epsilon,
            'z': z,
        }, 'drive/MyDrive/ColabVAE/errorLog/forward_error.pt')
        x_hat = self.decode(z)
        x_hat = self.normalize(x_hat).reshape(original)
        log_q = torch.sum(normal.log_prob(epsilon) - sigma)
        recon_error = self.rloss(x_hat,x)
        # NOTE(review): this objective subtracts log q(z|x) and has no prior
        # term log p(z); a standard negative ELBO is recon + log q - log p.
        # Left unchanged here - confirm the intended loss.
        elbo = recon_error - log_q
        return elbo,recon_error,x_hat

In [None]:
import itertools
def build_hyper_params():
    """Enumerate the full hyper-parameter grid.

    Returns:
        list[tuple]: every combination as
        ``(epochs, batch, hidden_features, z_features, lr, grad_clip)``,
        ordered so that the right-most axis varies fastest.
    """
    search_space = (
        [15],               # epochs
        [1000],             # batch
        [400, 600, 800],    # hidden_features
        [100, 200, 400],    # z_features
        [1, 1e-5, 3e-2],    # lr
        [1.0, 2e-4],        # grad_clip
    )
    return [combo for combo in itertools.product(*search_space)]



In [None]:
import torchvision
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

def load_data(batch,data_dir="./data"):
    """Build the MNIST train, validation and test data loaders.

    The official training set is downloaded to ``data_dir`` if needed and
    split at random into 80% training / 20% validation; the official test set
    is used as-is.

    Args:
        batch: Batch size used by all three loaders.
        data_dir: Directory where the dataset is stored / downloaded.

    Returns:
        tuple: ``(trainloader, valloader, testloader)``; the first two
        shuffle their data, the test loader does not.
    """
    mnist_train = torchvision.datasets.MNIST(
        root=data_dir, train=True, download=True, transform=transforms.ToTensor())
    mnist_test = torchvision.datasets.MNIST(
        root=data_dir, train=False, download=True, transform=transforms.ToTensor())

    # Partitioning the dataset in 80% train & 20% validation
    n_train = int(len(mnist_train) * 0.8)
    n_val = len(mnist_train) - n_train
    train_part, val_part = random_split(mnist_train, [n_train, n_val])

    def make_loader(dataset, shuffle):
        # All loaders share the batch size and a single worker process.
        return DataLoader(dataset, batch_size=batch, shuffle=shuffle, num_workers=1)

    return (
        make_loader(train_part, True),
        make_loader(val_part, True),
        make_loader(mnist_test, False),
    )

In [None]:
from tqdm import tqdm
########## MAIN LOOP #############################

# Run on the GPU when one is present (the original target), otherwise on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _detached(value):
    """Return a graph-free CPU copy of a tensor metric; pass anything else through."""
    return value.detach().cpu() if torch.is_tensor(value) else value


configurations = build_hyper_params()
# Only the grid points strictly between these two configurations are run
# (presumably resuming an interrupted search - adjust or remove for a full run).
configurations = configurations[configurations.index((15, 1000, 400, 400, 0.03, 1.0)) + 1:configurations.index((15, 1000, 600, 400, 1, 1.0))]
loop = tqdm(configurations)
print("Model selection on ",len(configurations)," models")
for i,config in enumerate(loop):
    tr_loss_list,val_loss_list = [],[]
    tr_rec_list, val_rec_list = [], []
    epochs, batch, hidden_features, z_features, lr, grad_clip = config
    net = VAE(28*28,1,hidden_features,z_features).to(device)
    opt = torch.optim.Adam(net.parameters(),lr)
    #Load data
    trainset,valset,_ = load_data(batch)
    # Defined up front so the bookkeeping below works even with an empty loader.
    tr_loss,val_loss = 0,0
    tr_rec,val_rec = 0,0

    for epoch in range(epochs):
        # Initial checkpoint
        torch.save({
            'epoch': epoch,
            'config': config,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'loss_list': tr_loss_list,
            'val_loss_list':val_loss_list,
            'tr_rec_list': tr_rec_list,
            'val_rec_list': val_rec_list
        },'drive/MyDrive/ColabVAE/checkpoints/before_training.pt')

        # Train loop
        net.train()
        loop.set_postfix(Trying=config, Status='Training',Epoch=epoch)
        # No enumerate here: the batch index was unused and shadowed the
        # outer loop's `i`.
        for x,labels in trainset:
            opt.zero_grad()
            x = x.to(device)
            tr_loss,tr_rec,_ = net(x)
            tr_loss.backward()
            # Clip with the configured grad_clip; the learning rate was being
            # passed as max_norm, so grad_clip was never used.
            torch.nn.utils.clip_grad_norm_(net.parameters(),grad_clip)
            opt.step()

        #Validate
        net.eval()
        loop.set_postfix(Trying=config, Status='Validating',Epoch=epoch)
        with torch.no_grad():
            for x,labels in valset:
                x = x.to(device)
                val_loss,val_rec,_ = net(x)

        # Record the metrics of the LAST batch of the epoch, detached and on
        # CPU so the lists neither keep the autograd graph alive nor require a
        # GPU to load from a checkpoint.
        tr_loss_list.append(_detached(tr_loss))
        val_loss_list.append(_detached(val_loss))
        tr_rec_list.append(_detached(tr_rec))
        val_rec_list.append(_detached(val_rec))
        torch.save({
            'epoch': epoch,
            'config': config,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'loss_list': tr_loss_list,
            'val_loss_list':val_loss_list,
            'tr_rec_list': tr_rec_list,
            'val_rec_list': val_rec_list
        },'drive/MyDrive/ColabVAE/checkpoints/after_val.pt')

    # Save model & performance
    torch.save(net,'drive/MyDrive/ColabVAE/models/model' + str(config) + '.pt')
    torch.save({
        'epoch': epoch,
        'config': config,
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'loss_list': tr_loss_list,
        'val_loss_list': val_loss_list,
        'tr_rec_list': tr_rec_list,
        'val_rec_list': val_rec_list
    },'drive/MyDrive/ColabVAE/performance/model' + str(config) + 'perf.pt')



IndentationError: ignored

Model  0  Loss:  195899.125  Validation:  215140.984375  ReconLoss:  54371.19921875  ReconVal:  72968.796875
Model  1  Loss:  220538.359375  Validation:  215072.71875  ReconLoss:  78426.640625  ReconVal:  73285.09375
Model  2  Loss:  263671.875  Validation:  265030.53125  ReconLoss:  83039.203125  ReconVal:  85131.7890625
Model  3  Loss:  264015.8125  Validation:  262635.25  ReconLoss:  83062.796875  ReconVal:  82671.796875
Model  4  Loss:  190098.09375  Validation:  189923.53125  ReconLoss:  47923.69921875  ReconVal:  48096.3984375
Model  5  Loss:  190142.15625  Validation:  190219.8125  ReconLoss:  48408.3515625  ReconVal:  48268.77734375
Model  6  Loss:  330854.5  Validation:  332560.3125  ReconLoss:  47244.5625  ReconVal:  48438.671875
Model  7  Loss:  336466.875  Validation:  337272.59375  ReconLoss:  52819.56640625  ReconVal:  54076.36328125
Model  8  Loss:  445759.125  Validation:  444013.96875  ReconLoss:  85258.34375  ReconVal:  84896.9453125
Model  9  Loss:  445676.1875  Vali