## AAE Params Tuning

In [17]:
from networks.aae import *
import torch
import sys
import torchvision
import numpy as np
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
import os

DEFAULT_ROOT = '/scratch/sagar/slf/train_set/set_harsh_torch_raw_unnormalized/slf_mat'


class SLF(Dataset):
    """Map-style dataset that reads one pre-saved sample per index from disk.

    Sample ``i`` is read with ``torch.load`` from ``<root>/<i>.pt`` and is
    returned exactly as it was stored.

    Args:
        root: Directory that contains the ``<index>.pt`` files.
        train: Only consulted when ``total_data`` is ``None``; selects the
            assumed dataset size (500000 when true, 2000 otherwise).
        download: Accepted but unused.
            NOTE(review): presumably kept so the class can be constructed
            like the torchvision datasets -- confirm.
        transform: Accepted but unused; no transform is applied to samples.
        total_data: Explicit number of examples. Takes precedence over
            ``train`` when given.
    """

    def __init__(self, root=DEFAULT_ROOT, train=True, download=True, transform=None, total_data=None):
        super().__init__()
        self.root_dir = root
        if total_data is not None:
            self.num_examples = total_data
        elif train:
            self.num_examples = 500000
        else:
            self.num_examples = 2000

    def __len__(self):
        """Return the number of examples this dataset claims to hold."""
        return self.num_examples

    def __getitem__(self, idx):
        """Load and return the sample stored in ``<root>/<idx>.pt``.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``. Raising
                ``IndexError`` (rather than failing later on a missing file)
                keeps indexing consistent with ``__len__`` and lets plain
                ``for sample in dataset`` iteration terminate.
            FileNotFoundError: If the file for a valid index is missing.
        """
        if not 0 <= idx < self.num_examples:
            raise IndexError(
                'index {} is out of range for a dataset of {} examples'.format(
                    idx, self.num_examples))
        filename = os.path.join(self.root_dir, str(idx) + '.pt')
        # NOTE: torch.load unpickles the file; only point ``root`` at data
        # you trust.
        return torch.load(filename)

import os
# Insert your own path here

# Name of the model variant; not referenced anywhere else in the visible code.
variation = "B-VAE"

# Alternative setup (MNIST stored under a mounted drive), kept for reference:
# ds_path = os.path.join(".", "drive", "My Drive", "Machine Learning", "Datasets")
# configuration = {
#     "dataset": "MNIST",
#     "path": ds_path
# }

# Active setup: the SLF dataset, read from the default directory.
# ``configuration`` is the input expected by ``prepare_dataset``.
ds_path = DEFAULT_ROOT
configuration = dict(dataset="SLF", path=ds_path)


# Network layout handed to the Encoder/Decoder constructors; its
# "z_dimension" entry is also what the Discriminator is built with.
# Five convolution stages: four 4x4 / stride-2 stages followed by a single
# 3x3 / stride-1 stage, every stage padded by 1, with the channel count
# doubling from 32 up to 512.
architecture = dict(
    conv_layers=5,
    conv_channels=[32 * 2 ** stage for stage in range(5)],
    conv_kernel_sizes=[(4, 4)] * 4 + [(3, 3)],
    conv_strides=[(2, 2)] * 4 + [(1, 1)],
    conv_paddings=[(1, 1)] * 5,
    z_dimension=256,
)

# Default optimisation settings.
# NOTE(review): the tuning loop visible in this file takes its learning rate
# and batch size from ``params`` instead -- confirm whether this dict is
# still used elsewhere.
hyperparameters = dict(
    epochs=10,
    batch_size=16,
    learning_rate=3e-6,
)

def prepare_dataset(configuration):
    """
    :param dict configuration: Dictionary with a "dataset" key naming the
                               dataset and a "path" key giving its location.
                               "path" is only read for supported datasets.

    :return:        A dictionary with the keys "ds_method" (the dataset
                    class), "ds_shape" (a 3-tuple of ints) and "ds_path"
                    (the configured path), or None when the dataset name is
                    not supported.

    Function used to set some values used by the model based on the dataset selected.
    Supported names are "MNIST", "CIFAR10", "FashionMNIST" and "SLF".
    """
    selected = configuration["dataset"]

    # Resolve the class lazily, branch by branch, so that an unsupported name
    # never touches torchvision or SLF.
    if selected == "MNIST":
        ds_method, ds_shape = torchvision.datasets.MNIST, (1, 28, 28)
    elif selected == "CIFAR10":
        ds_method, ds_shape = torchvision.datasets.CIFAR10, (3, 32, 32)
    elif selected == "FashionMNIST":
        ds_method, ds_shape = torchvision.datasets.FashionMNIST, (1, 28, 28)
    elif selected == "SLF":
        ds_method, ds_shape = SLF, (1, 51, 51)
    else:
        # The message lists every dataset handled above.
        print("Currently only MNIST, CIFAR10, FashionMNIST & SLF datasets are supported")
        return None

    return {
        "ds_method": ds_method,
        "ds_shape": ds_shape,
        "ds_path": configuration["path"],
    }

# Resolve the dataset class, sample shape and path for the selected
# configuration; this is None when the dataset name is not supported.
dataset_info = prepare_dataset(configuration)

## Tune AAE parameters

In [19]:
from collections import OrderedDict, namedtuple
import os
from run_manager_adv import RunBuilder, RunManager
from tqdm import tqdm, trange
from IPython.display import clear_output
import time


# Adversarial-autoencoder tuning script.
#
# For every run configuration it builds an encoder, a decoder and a latent
# discriminator, then alternates three updates per batch:
#   1. reconstruction  - encoder + decoder minimise the MSE to the input,
#   2. discriminator   - learns to tell prior samples from encoded samples,
#   3. generator       - the encoder is updated to fool the discriminator.
# Running loss totals are reported to the RunManager after every batch.

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
devices = ['cuda'] if torch.cuda.is_available() else ['cpu']
print('starting')

# Settings explored by the run builder (presumably one run per combination
# of the listed values -- see RunBuilder.get_runs).
params = OrderedDict(
    lr = [0.0001],
    batch_size = [20],
    device = devices,
    shuffle = [True],
    num_workers = [5]
)

# NOTE(review): ROOT is not defined in the visible part of this file; it must
# come from the star import or an earlier cell. The dataset expects files
# named 0.pt .. 19.pt in this directory.
train_set = SLF(root=os.path.join(ROOT, 'slf_mat'), total_data=20)

m = RunManager(epoch_count_print=30)

# Discriminator targets: 0.9 (rather than 1.0) for samples drawn from the
# prior, 0 for latents produced by the encoder.
real_label = 0.9
fake_label = 0

criterion = nn.BCELoss()
l2_loss = nn.MSELoss()
# The values below are not used by the loop in this cell; they are kept
# because other cells may rely on them.
alpha = 0.0001
Tc = 0
Td = 0
T_train = 100

# Total number of batches processed across all runs and epochs.
count=0

for run in RunBuilder.get_runs(params):
    device = torch.device(run.device)
    encoder = Encoder(architecture, dataset_info).to(device)
    decoder = Decoder(architecture, dataset_info).to(device)
    discriminator = Discriminator(architecture['z_dimension']).to(device)

    loader = torch.utils.data.DataLoader(train_set, batch_size=run.batch_size, shuffle=run.shuffle, num_workers=run.num_workers)
    optimizerEncoder = torch.optim.Adam(encoder.parameters(), lr=run.lr)
    optimizerDecoder = torch.optim.Adam(decoder.parameters(), lr=run.lr)
    optimizerDiscriminator = torch.optim.Adam(discriminator.parameters(), lr=run.lr)

    m.begin_run(run, encoder, loader)

    for epoch in trange(1500):
        m.begin_epoch()

        # Per-epoch running totals.
        total_adv_loss = 0
        total_mse_loss = 0

        total_D_real = 0
        total_D_fake = 0

        # Batches seen in the current epoch.
        t = 0

        for batch in loader:
            t += 1
            # Get data
            in_features = batch.to(device)
            b_size = in_features.size(0)
            # An explicit float dtype is required: without it the integer
            # fill value 0 produces an integer tensor, which BCELoss rejects.
            labels_real = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            labels_fake = torch.full((b_size,), fake_label, dtype=torch.float, device=device)

            # Phase 1: update the autoencoder using the MSE loss.
            encoder.zero_grad()
            decoder.zero_grad()
            x_hat = decoder(encoder(in_features))

            ae_loss = l2_loss(x_hat, in_features)
            ae_loss.backward()
            optimizerEncoder.step()
            optimizerDecoder.step()

            # Phase 2: update the discriminator. The latent code is
            # recomputed because the graph from phase 1 was freed by
            # backward() and the encoder weights have changed since; it is
            # computed without gradients so that only the discriminator
            # learns from this loss.
            discriminator.zero_grad()
            with torch.no_grad():
                fake_latent = encoder(in_features)

            fake_pred = discriminator(fake_latent)
            fake_loss = criterion(fake_pred, labels_fake)

            real_latent = torch.randn_like(fake_latent)
            real_pred = discriminator(real_latent)
            real_loss = criterion(real_pred, labels_real)

            # Mean discriminator outputs on prior / encoded samples.
            count_real = real_pred.mean().item()
            count_fake = fake_pred.mean().item()

            loss_discriminator = 0.5*(real_loss + fake_loss)
            loss_discriminator.backward()
            optimizerDiscriminator.step()

            # Phase 3: update the encoder so that its latent codes are
            # classified as prior samples. Gradients that reach the
            # discriminator here are discarded by its next zero_grad().
            encoder.zero_grad()
            gen_pred = discriminator(encoder(in_features))
            loss_generator = criterion(gen_pred, labels_real)
            loss_generator.backward()
            optimizerEncoder.step()

            total_D_real += real_loss.item()
            total_D_fake += fake_loss.item()

            total_mse_loss += ae_loss.item()
            total_adv_loss += loss_generator.item()

            count += 1

            m.track_loss(G_adv_loss=total_adv_loss, G_mse_loss=total_mse_loss, D_real_loss=total_D_real, D_fake_loss=total_D_fake, D_real_count=count_real, D_fake_count=count_fake)

        m.end_epoch()
    m.end_run()

  0%|          | 0/1500 [00:00<?, ?it/s]

starting


  0%|          | 0/1500 [00:00<?, ?it/s]


FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/sagar/matlab/venv/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/sagar/matlab/venv/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/sagar/matlab/venv/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<ipython-input-17-37d2dad62196>", line 31, in __getitem__
    sample = torch.load(filename)
  File "/home/sagar/matlab/venv/lib/python3.8/site-packages/torch/serialization.py", line 581, in load
    with _open_file_like(f, 'rb') as opened_file:
  File "/home/sagar/matlab/venv/lib/python3.8/site-packages/torch/serialization.py", line 230, in _open_file_like
    return _open_file(name_or_buffer, mode)
  File "/home/sagar/matlab/venv/lib/python3.8/site-packages/torch/serialization.py", line 211, in __init__
    super(_open_file, self).__init__(open(name, mode))
FileNotFoundError: [Errno 2] No such file or directory: '/scratch/sagar/slf/train_set/set_harsh_torch_raw_unnormalized/slf_mat/1.pt'
