# 1) Mount Google Drive, clone the GitHub repo and install packages

## 1.1) Mount Drive and define paths
Run the Colab-provided snippet to mount Google Drive, then define the dataset and checkpoint paths relative to the mount point.

In [None]:
from google.colab import drive

# Google Drive is exposed inside the Colab VM under this mount point.
mount_root_abs = '/content/drive'
drive.mount(mount_root_abs)

# Project folder on Drive; both locations below live underneath it.
drive_root = mount_root_abs + '/MyDrive/IWAE'
data_path = drive_root + '/data'
chkpts_dir_path = drive_root + '/checkpoints'

## 1.2) Clone the GitHub repo
Clone the `kth-ml-course-projects/iwae-pytorch` repo into `/content/code/iwae-pytorch` over SSH using `git clone`. For more info see: https://medium.com/@purba0101/how-to-clone-private-github-repo-in-google-colab-using-ssh-77384cfef18f

In [None]:
import os

repo_root = '/content/code/iwae-pytorch'
!rm -rf $repo_root
if not os.path.exists(repo_root):
    # Check that ssh keys exist
    assert os.path.exists(f'{drive_root}/ssh_keys')
    id_rsa_abs_drive = f'{drive_root}/ssh_keys/id_rsa'
    id_rsa_pub_abs_drive = f'{id_rsa_abs_drive}.pub'
    assert os.path.exists(id_rsa_abs_drive)
    assert os.path.exists(id_rsa_pub_abs_drive)
    # On first run: Add ssh key in repo
    if not os.path.exists('/root/.ssh'):
        # Transfer config file
        ssh_config_abs_drive = f'{drive_root}/ssh_keys/config'
        assert os.path.exists(ssh_config_abs_drive)
        !mkdir -p ~/.ssh
        !cp -f "$ssh_config_abs_drive" ~/.ssh/
        # # Add github.com to known hosts
        !ssh-keyscan -t rsa github.com >> ~/.ssh/known_hosts
        # Test: !ssh -T git@github.com

    # Remove any previous attempts
    !rm -rf "$repo_root"
    !mkdir -p "$repo_root"
    # Clone repo
    !git clone git@github.com:kth-ml-course-projects/iwae-pytorch.git "$repo_root"
    
    # Fix issue with duplicated files
    !rm -rf $repo_root/src-clone/dataloaders
    !rm -rf $repo_root/src-clone/utils

## 1.3) Install pip packages
All required packages are listed in a `requirements.txt` file at the repository's root.
Use `pip install -r requirements.txt` from inside the dir to install required packages.

In [None]:
%cd "$repo_root"
!pip install -r requirements.txt

In [None]:
import torch

# Fail early, with an actionable message, when the runtime has no GPU.
assert torch.cuda.is_available(), (
    'CUDA is not available - switch the Colab runtime to a GPU '
    '(Runtime > Change runtime type) and re-run the notebook.'
)
print(torch.__version__)

## 1.5) Add code/, */src/ to path
This is necessary in order to be able to run the modules.

In [None]:
content_root_abs = f'{repo_root}'
src_root_abs = f'{repo_root}/src'
# %env PYTHONPATH="/env/python:$content_root_abs:$src_root_abs"
%set_env PYTHONPATH=/env/python:$content_root_abs:$src_root_abs:$src_clone_root_abs

# 2) Train IWAE model on MNIST Dataset
In this section we run the actual training loop for the IWAE network. The IWAE model consists of an encoder with 1 or 2 stochastic layers and a mirrored decoder, where each stochastic layer consists of FC layers with `Tanh()` activations that produce the distribution parameters.

In [None]:
%cd "$repo_root/src"

from train import train_and_save_checkpoints
from ifaces import DownloadableDataset

DownloadableDataset.set_data_directory(data_path)
train_and_save_checkpoints( seed=42,
                            cuda=True,
                            k=50,
                            num_layers=2,
                            dataset='mnist',
                            model_type='iwae',
                            use_clone=False,
                            batch_size=1000,
                            debug=False,
                            dtype=torch.float64,
                            chkpts_dir_path=chkpts_dir_path)