# 1) Env setup

## 1.0) Create a new Kaggle Dataset with name `chkpts-<EXPERIMENT_DATASET>-<EXPERIMENT_MODEL>-l<EXPERIMENT_L>`
e.g.: `chkpts-fashionmnist-iwae-l2` (no `_` allowed)

Upload any existing checkpoints there, then attach the dataset to the current notebook by clicking `+Add data` in the upper right corner. Once training has produced more checkpoints, add them by opening the dataset, clicking `+ New Version` (in the Data Explorer section of the Data tab), and uploading the new files.
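
The naming scheme above can be sketched as a small helper. The function name `checkpoint_dataset_name` is hypothetical, and taking `<EXPERIMENT_L>` to be the number of stochastic layers is an assumption based on the `l2` example; the character check is also stricter than the stated "no `_`" rule.

```python
import re


def checkpoint_dataset_name(dataset: str, model: str, num_layers: int) -> str:
    """Build a dataset name following the `chkpts-<dataset>-<model>-l<L>` scheme."""
    # Lowercase and drop underscores, since `_` is not allowed in the name
    parts = [p.lower().replace('_', '') for p in (dataset, model)]
    name = f"chkpts-{'-'.join(parts)}-l{num_layers}"
    # Assumed guard: only lowercase letters, digits and single hyphens
    if not re.fullmatch(r'[a-z0-9]+(-[a-z0-9]+)*', name):
        raise ValueError(f'invalid dataset name: {name!r}')
    return name


print(checkpoint_dataset_name('fashion_mnist', 'iwae', 2))
```

Building the name in one place helps keep the dataset created here and the `/kaggle/input/...` path used later in sync.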

## 1.1) Clone GitHub repo
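Clone the `kth-ml-course-projects/iwae-pytorch` repo into `/kaggle/working/code/iwae-pytorch` over SSH using `git clone`. The SSH key pair and `config` file are read from the `git-keys2` input dataset. For a similar procedure in Colab,
see: https://medium.com/@purba0101/how-to-clone-private-github-repo-in-google-colab-using-ssh-77384cfef18f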

In [None]:
import os

# Clean failed attempts
!rm -rf /root/.ssh
!rm -rf /kaggle/working/code
!mkdir -p /kaggle/working/code

git_keys_root = '/kaggle/input/git-keys2'
repo_root = '/kaggle/working/code/iwae-pytorch'
if not os.path.exists(repo_root):
    # Check that ssh keys exist
    id_rsa_abs_drive = f'{git_keys_root}/id_rsa'
    id_rsa_pub_abs_drive = f'{id_rsa_abs_drive}.pub'
    assert os.path.exists(id_rsa_abs_drive)
    assert os.path.exists(id_rsa_pub_abs_drive)
    # On first run: Add ssh key in repo
    if not os.path.exists('/root/.ssh'):
        # Transfer config file
        ssh_config_abs_drive = f'{git_keys_root}/config'
        assert os.path.exists(ssh_config_abs_drive)
        !mkdir -p ~/.ssh
        !cp -f "$ssh_config_abs_drive" ~/.ssh/
        # Add github.com to known hosts
        !ssh-keyscan -t rsa github.com >> ~/.ssh/known_hosts
        # Test ssh connection
        # !ssh -T git@github.com

    # Remove any previous attempts
    !rm -rf "$repo_root"
    !mkdir -p "$repo_root"
    # Clone repo
    !git clone git@github.com:kth-ml-course-projects/iwae-pytorch.git "$repo_root"

    # Fix issue with duplicated files
    !rm -rf "$repo_root/src-clone/dataloaders"
    !rm -rf "$repo_root/src-clone/utils"

## 1.2) Install pip packages
All required packages are listed in the `requirements.txt` file at the repository's root.
Run `pip install -r requirements.txt` from inside that directory to install them.
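
After installing, it can help to confirm that nothing from the requirements file is missing. The following is a minimal sketch using only the standard library; the helper name `missing_requirements` is hypothetical, it checks presence only (not version pins), and it skips pip options and URL-style requirements.

```python
import re
from importlib import metadata
from pathlib import Path


def missing_requirements(requirements_file: str) -> list:
    """Return the requirement names in the file that are not installed."""
    missing = []
    for line in Path(requirements_file).read_text().splitlines():
        # Drop inline comments and surrounding whitespace
        line = line.split('#', 1)[0].strip()
        # Skip blanks, pip options (e.g. `-r other.txt`) and URL requirements
        if not line or line.startswith('-') or '://' in line:
            continue
        # Keep only the distribution name (drop extras, pins and markers)
        match = re.match(r'[A-Za-z0-9][A-Za-z0-9._-]*', line)
        if match is None:
            continue
        try:
            metadata.version(match.group(0))
        except metadata.PackageNotFoundError:
            missing.append(match.group(0))
    return missing
```

In the notebook this could be called as `missing_requirements(f'{repo_root}/requirements.txt')`; an empty list means every listed distribution was found.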

In [None]:
%cd "$repo_root"
!pip install -r requirements.txt

In [None]:
import torch

assert torch.cuda.is_available()
print(torch.__version__)

## 1.3) Add code/, */src/ to path
This is necessary in order to be able to run the modules.
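
Note that `PYTHONPATH` is read when a Python interpreter starts, so setting it from a running notebook affects subprocesses launched afterwards (e.g. via `!python ...`) but not the kernel's own `sys.path`. A minimal sketch for making directories importable in the current kernel follows; the helper name `add_to_sys_path` is hypothetical.

```python
import os
import sys


def add_to_sys_path(*dirs: str) -> list:
    """Prepend each directory to sys.path once, preserving the given order."""
    added = []
    # Insert in reverse so the first argument ends up first on sys.path
    for d in reversed(dirs):
        d = os.path.abspath(d)
        if d not in sys.path:
            sys.path.insert(0, d)
            added.append(d)
    return list(reversed(added))
```

For example, `add_to_sys_path(repo_root, f'{repo_root}/src')` would let the kernel import the repository's modules regardless of the current working directory.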

In [None]:
content_root_abs = f'{repo_root}'
src_root_abs = f'{repo_root}/src'
# Assumption: the cloned sources live in the repo's `src-clone` dir (see 1.1)
src_clone_root_abs = f'{repo_root}/src-clone'
# %env PYTHONPATH="/env/python:$content_root_abs:$src_root_abs"
%set_env PYTHONPATH=/env/python:$content_root_abs:$src_root_abs:$src_clone_root_abs

# 2) Train IWAE model on FashionMNIST Dataset
In this section we run the actual training loop for the IWAE network. The IWAE consists of an encoder with one or two stochastic layers and a mirrored decoder; each stochastic layer uses fully connected (FC) layers with `Tanh()` activations to produce the distribution parameters.
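
For context on the `k` argument in the training calls: in the IWAE objective, `k` is the number of importance samples drawn per data point, and the bound is the log of the average of the `k` importance weights. The sketch below computes that quantity from log-weights with the log-sum-exp trick. The helper name `iwae_bound` and the example values are illustrative and not taken from the repository, which presumably computes this on tensors.

```python
import math


def iwae_bound(log_weights):
    """Return log((1/k) * sum_i exp(log_w_i)) in a numerically stable way."""
    k = len(log_weights)
    # Subtract the maximum before exponentiating to avoid underflow
    m = max(log_weights)
    return m + math.log(sum(math.exp(lw - m) for lw in log_weights)) - math.log(k)


# Illustrative log-weights, log w_i = log p(x, h_i) - log q(h_i | x)
log_w = [-90.0, -92.0, -95.0]
bound = iwae_bound(log_w)
elbo_estimate = sum(log_w) / len(log_w)
# By Jensen's inequality the k-sample bound is at least the plain average
assert bound >= elbo_estimate
```

With `k=1` the expression reduces to the single log-weight, i.e. the usual single-sample VAE bound estimate.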

In [None]:
%cd "$repo_root/src"

import sys
from train import train_and_save_checkpoints
from ifaces import DownloadableDataset

data_path = '/kaggle/working/data'
!mkdir -p $data_path

chkpts_dir_path = '/kaggle/working/checkpoints'
!mkdir -p $chkpts_dir_path
!cp /kaggle/input/chkpts-fashionmnist-iwae-l2/*.pkl $chkpts_dir_path
!ls -l $chkpts_dir_path

DownloadableDataset.set_data_directory(data_path)
try:
    train_and_save_checkpoints(seed=42,
                               cuda=True,
                               k=5,
                               num_layers=2,
                               dataset='fashion_mnist',
                               model_type='iwae',
                               use_clone=True,
                               batch_size=400,
                               debug=False,
                               dtype=torch.float32,
                               chkpts_dir_path=chkpts_dir_path,
                               use_grad_clip=False)
except RuntimeError as e:
    print('[EXCEPTION] k=5 FAILed: ' + str(e), file=sys.stderr)

torch.cuda.empty_cache()
try:
    train_and_save_checkpoints(seed=42,
                               cuda=True,
                               k=50,
                               num_layers=2,
                               dataset='fashion_mnist',
                               model_type='iwae',
                               use_clone=True,
                               batch_size=400,
                               debug=False,
                               dtype=torch.float32,
                               chkpts_dir_path=chkpts_dir_path,
                               use_grad_clip=False)
except RuntimeError as e:
    print('[EXCEPTION] k=50 FAILed: ' + str(e), file=sys.stderr)
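
The two training calls above differ only in `k`, so they could be driven from one loop. A minimal sketch follows; the helper name `run_for_each_k` and its `after_each` hook are hypothetical, and it assumes the training function accepts `k` as a keyword argument, as in the calls above.

```python
import sys


def run_for_each_k(train_fn, ks, after_each=None, **common_kwargs):
    """Call train_fn(k=..., **common_kwargs) per k; report RuntimeErrors and go on."""
    failed = []
    for k in ks:
        try:
            train_fn(k=k, **common_kwargs)
        except RuntimeError as e:
            print(f'[EXCEPTION] k={k} FAILed: {e}', file=sys.stderr)
            failed.append(k)
        finally:
            # Optional cleanup between runs, e.g. freeing cached GPU memory
            if after_each is not None:
                after_each()
    return failed
```

In the notebook this might be invoked as `run_for_each_k(train_and_save_checkpoints, (5, 50), after_each=torch.cuda.empty_cache, seed=42, cuda=True, num_layers=2, ...)` with the remaining shared arguments passed once.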

## 2.2) Download checkpoints

In [None]:
!zip -j chkpts.zip $chkpts_dir_path/*.pkl
from IPython.display import FileLink

FileLink('chkpts.zip')
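
Where the `zip` command is unavailable, the same flat archive can be built with the standard library. This is a sketch under that assumption; the helper name `zip_checkpoints` is hypothetical. Unlike `zip -j` with an unmatched glob, it produces an empty archive when no files match.

```python
import zipfile
from pathlib import Path


def zip_checkpoints(chkpts_dir, zip_path='chkpts.zip', pattern='*.pkl'):
    """Store matching files in zip_path without directory parts (like `zip -j`)."""
    files = sorted(Path(chkpts_dir).glob(pattern))
    with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
        for f in files:
            # arcname=f.name drops the leading directories
            zf.write(f, arcname=f.name)
    return [f.name for f in files]
```

Calling `zip_checkpoints(chkpts_dir_path)` writes `chkpts.zip` to the working directory, after which `FileLink('chkpts.zip')` works as before.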