#### Optimizing the lambda parameter

In [None]:
def find_best_lambda(lambdas):
    """Train one VAE per candidate ``lambda_`` and collect its loss curves.

    For every value in ``lambdas`` a fresh ``VAE`` (latent size 5) is trained for
    ``num_epochs`` epochs via ``train_epoch(..., use_noise=True)``. After each epoch
    the validation loss from ``test_epoch`` is logged. The weights are saved to
    ``ckpt/net_params_<lambda_*100>.pth`` on the first epoch and whenever the
    validation loss improves. From the second epoch on, each epoch restarts from
    that best checkpoint. Finally the best weights are evaluated on the CPU by
    ``test_noise_crop_opt`` over a grid of noise levels and erasing scales.

    Args:
        lambdas: iterable of values passed as ``lambda_`` to ``VAE``.

    Returns:
        ``(log_loss, log_noise, log_crop)``: lists with one entry per lambda, holding
        respectively the per-epoch validation losses, the losses for each entry of
        ``noise_try`` and the losses for each entry of ``scale_try``.
    """
    import os  # local import keeps this notebook cell self-contained

    data_root_dir = 'Dataset'
    train_dataset, test_dataset = load_datasets(data_root_dir)

    train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=512, shuffle=True)

    # Corruption levels used for the final robustness evaluation.
    noise_try = [0, 0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.5, 2]
    scale_try = [0, 0.01, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

    num_epochs = 25
    load_weights = False  # if True, reload the checkpoint before every epoch (including the first)
    load_best = True      # if True, reload the best checkpoint before every epoch after the first
    encoded_space_dim = 5

    # Fall back to the CPU instead of crashing on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # torch.save does not create missing directories.
    os.makedirs('ckpt', exist_ok=True)

    log_loss = []
    log_noise = []
    log_crop = []
    for lambda_ in lambdas:
        ckpt_path = 'ckpt/net_params_%s.pth' % (str(lambda_ * 100))
        net = VAE(encoded_space_dim=encoded_space_dim, lambda_=lambda_)

        loss_fn = torch.nn.MSELoss()
        optim = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=1e-5)

        net.to(device)

        val_loss_log = []
        for epoch in range(num_epochs):
            if load_weights or (load_best and epoch > 0):
                print('Loaded!')
                net.load_state_dict(torch.load(ckpt_path, map_location=device))
            print('EPOCH %d/%d' % (epoch + 1, num_epochs))
            train_epoch(net, dataloader=train_dataloader, loss_fn=loss_fn, optimizer=optim, show_steps=20,
                        use_noise=True)
            val_loss = test_epoch(net, dataloader=test_dataloader, loss_fn=loss_fn)
            val_loss_log.append(val_loss.item())
            print('\n\n\t VALIDATION - EPOCH %d/%d - loss: %f\n\n' % (epoch + 1, num_epochs, val_loss))

            # Checkpoint on the first epoch and whenever the validation loss improves.
            if epoch == 0 or val_loss.item() < min(val_loss_log[:-1]):
                print('Saved!')
                torch.save(net.state_dict(), ckpt_path)
        log_loss.append(val_loss_log)

        # Evaluate the best (checkpointed) weights, not the last ones, on the CPU.
        net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        net.to('cpu')

        loss_noise, loss_crop = test_noise_crop_opt(net, test_dataloader, loss_fn, noise_try, scale_try)
        log_noise.append(loss_noise)
        log_crop.append(loss_crop)

    # Return the per-lambda logs; previously only the last lambda's curves were returned.
    return log_loss, log_noise, log_crop

#### Find the best dimension of the latent space

In [None]:
def find_best_hidden(hidden):
    """Train one VAE per candidate latent dimension and collect its loss curves.

    ``lambda_`` is fixed at 0.75. For every value in ``hidden`` a fresh ``VAE`` is
    trained for ``num_epochs`` epochs via ``train_epoch(..., use_noise=True)``. After
    each epoch the validation loss from ``test_epoch`` is logged. The weights are
    saved to ``ckpt/net_params_<hid>.pth`` on the first epoch and whenever the
    validation loss improves. From the second epoch on, each epoch restarts from
    that best checkpoint. Finally the best weights are evaluated on the CPU by
    ``test_noise_crop_opt`` over a grid of noise levels and erasing scales.

    Args:
        hidden: iterable of values passed as ``encoded_space_dim`` to ``VAE``.

    Returns:
        ``(log_loss, log_noise, log_crop)``: lists with one entry per latent size,
        holding respectively the per-epoch validation losses, the losses for each
        entry of ``noise_try`` and the losses for each entry of ``scale_try``.
    """
    import os  # local import keeps this notebook cell self-contained

    data_root_dir = 'Dataset'
    train_dataset, test_dataset = load_datasets(data_root_dir)

    train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=512, shuffle=True)

    # Corruption levels used for the final robustness evaluation (loop-invariant).
    noise_try = [0, 0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.5, 2]
    scale_try = [0, 0.01, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

    num_epochs = 25
    load_weights = False  # if True, reload the checkpoint before every epoch (including the first)
    load_best = True      # if True, reload the best checkpoint before every epoch after the first
    lambda_ = 0.75

    # Fall back to the CPU instead of crashing on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # torch.save does not create missing directories.
    os.makedirs('ckpt', exist_ok=True)

    log_loss = []
    log_noise = []
    log_crop = []
    for hid in hidden:
        ckpt_path = 'ckpt/net_params_%s.pth' % (str(hid))
        net = VAE(encoded_space_dim=hid, lambda_=lambda_)

        loss_fn = torch.nn.MSELoss()
        optim = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=1e-5)

        net.to(device)

        val_loss_log = []
        for epoch in range(num_epochs):
            if load_weights or (load_best and epoch > 0):
                print('Loaded!')
                net.load_state_dict(torch.load(ckpt_path, map_location=device))
            print('EPOCH %d/%d' % (epoch + 1, num_epochs))
            train_epoch(net, dataloader=train_dataloader, loss_fn=loss_fn, optimizer=optim, show_steps=20,
                        use_noise=True)
            val_loss = test_epoch(net, dataloader=test_dataloader, loss_fn=loss_fn)
            val_loss_log.append(val_loss.item())
            print('\n\n\t VALIDATION - EPOCH %d/%d - loss: %f\n\n' % (epoch + 1, num_epochs, val_loss))

            # Checkpoint on the first epoch and whenever the validation loss improves.
            if epoch == 0 or val_loss.item() < min(val_loss_log[:-1]):
                print('Saved!')
                torch.save(net.state_dict(), ckpt_path)
        log_loss.append(val_loss_log)

        # Evaluate the best (checkpointed) weights, not the last ones, on the CPU.
        net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        net.to('cpu')

        loss_noise, loss_crop = test_noise_crop_opt(net, test_dataloader, loss_fn, noise_try, scale_try)
        log_noise.append(loss_noise)
        log_crop.append(loss_crop)

    return log_loss, log_noise, log_crop

#### Testing performance under various levels of noise and cropping (random erasing)

In [None]:
def test_noise_crop_opt(net, dataloader, loss_fn, noise_try, scale_try):
    """Measure the reconstruction loss of ``net`` on corrupted inputs.

    ``net`` is moved to the CPU and switched to eval mode; it is not switched back.
    For every corruption level the whole ``dataloader`` is passed through ``net``
    with corrupted images. The loss is computed once over all batches, between the
    reconstructions and the *uncorrupted* images.

    Args:
        net: model called as ``net(image_batch)``.
        dataloader: iterable of batches whose first element is the image tensor.
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar tensor.
        noise_try: multipliers of additive standard-normal noise to test.
        scale_try: erased-area fractions passed to ``transforms.RandomErasing``
            (0 means no erasing).

    Returns:
        ``(loss_log_noise, loss_log_crop)``: one float per entry of ``noise_try``
        and ``scale_try`` respectively (``nan`` when ``dataloader`` is empty).
    """
    net.eval()
    device = torch.device('cpu')
    net.to(device)

    # Noise
    print('\n@@@@@@@@@@@@@@@@@@@@@@@   Noise testing:\n')
    loss_log_noise = []
    for noise_level in noise_try:
        # The lambda is consumed within this iteration, so late binding is not an issue.
        val_loss = _corrupted_loss(net, dataloader, loss_fn,
                                   lambda x: x + noise_level * torch.randn_like(x))
        loss_log_noise.append(val_loss)
        print('Test loss with a %.1f %% noise level: %.4f' % (noise_level * 100, val_loss))

    # Crop
    print('\n@@@@@@@@@@@@@@@@@@@@@@@   Crop testing:\n')
    loss_log_crop = []
    for scale in scale_try:
        if scale == 0:
            val_loss = _corrupted_loss(net, dataloader, loss_fn, lambda x: x)
        else:
            eraser = transforms.RandomErasing(p=1, scale=(scale, scale), ratio=(0.5, 2))
            val_loss = _corrupted_loss(net, dataloader, loss_fn,
                                       lambda x: _erase_each(x, eraser))
        loss_log_crop.append(val_loss)
        print('Test loss with a %.1f %% crop: %.4f' % (scale * 100, val_loss))

    return loss_log_noise, loss_log_crop


def _corrupted_loss(net, dataloader, loss_fn, corrupt):
    """Return ``loss_fn(net(corrupt(images)), images)`` over the whole dataloader as a float."""
    outputs, targets = [], []
    with torch.no_grad():
        for sample_batch in dataloader:
            clean = sample_batch[0]
            # Corrupt a copy so the loader's batch tensor is never modified in place.
            outputs.append(net(corrupt(clean.clone())).cpu())
            targets.append(clean.cpu())
    if not outputs:
        # Nothing to evaluate; the loss of two empty tensors would also be nan.
        return float('nan')
    # Concatenate once at the end instead of re-copying an ever-growing tensor per batch.
    return loss_fn(torch.cat(outputs), torch.cat(targets)).item()


def _erase_each(images, eraser):
    """Apply ``eraser`` to every image of ``images`` independently, in place, and return it."""
    for i in range(len(images)):
        images[i] = eraser(images[i])
    return images