# 1) Prepare workspace
## 1.1) Clone the Git repo
Clone the thacha/cifar-vit repo from KTH's git server (gits-15.sys.kth.se) into /kaggle/working/code/cifar-vit using `git clone` over SSH. For a similar procedure in Colab, see: https://medium.com/@purba0101/how-to-clone-private-github-repo-in-google-colab-using-ssh-77384cfef18f

In [None]:
import os

# Clean failed attempts
!rm -rf /root/.ssh
!rm -rf /kaggle/working/code
!mkdir -p /kaggle/working/code

git_keys_root = '/kaggle/input/git-keys-kth'
repo_root = '/kaggle/working/code/cifar-vit'
if not os.path.exists(repo_root):
    # Check that ssh keys exist
    id_rsa_abs_drive = f'{git_keys_root}/id_rsa'
    id_rsa_pub_abs_drive = f'{id_rsa_abs_drive}.pub'
    assert os.path.exists(id_rsa_abs_drive)
    assert os.path.exists(id_rsa_pub_abs_drive)
    # On first run: Add ssh key in repo
    if not os.path.exists('/root/.ssh'):
        # Transfer config file
        ssh_config_abs_drive = f'{git_keys_root}/config'
        assert os.path.exists(ssh_config_abs_drive)
        !mkdir -p ~/.ssh
        !cp -f "$ssh_config_abs_drive" ~/.ssh/
        # Add github.com to known hosts
        !ssh-keyscan -t rsa gits-15.sys.kth.se >> ~/.ssh/known_hosts
        # Test ssh connection
        # !ssh -T git@github.com

    # Remove any previous attempts
    !rm -rf "$repo_root"
    !mkdir -p "$repo_root"
    # Clone repo
    !git clone git@gits-15.sys.kth.se:thacha/cifar-vit.git "$repo_root"

    # Fix issue with duplicated files
    !rm -rf $repo_root/src-clone/dataloaders
    !rm -rf $repo_root/src-clone/utils

## 1.2) Install pip packages
All required packages are listed in a requirements.txt file at the repository's root.
Run `pip install -r requirements.txt` from inside that directory to install them.

In [None]:
%cd "$repo_root"
!pip install -r requirements.txt

In [None]:
import torch

# Stop early unless a CUDA device is visible to PyTorch.
cuda_ready = torch.cuda.is_available()
assert cuda_ready
# Record the installed PyTorch version in the cell output.
torch_version = torch.__version__
print(torch_version)

## 1.3) Add the repo root and its source directories to PYTHONPATH
This is necessary so that Python can locate the project's modules when they are run.

In [None]:
import os

# Directories to expose through PYTHONPATH.
content_root_abs = f'{repo_root}'
src_root_abs = f'{repo_root}/src'
# This name was used in the PYTHONPATH value but never defined; the clone cell
# works with `src-clone/` under the repo root, so point it there.
src_clone_root_abs = f'{repo_root}/src-clone'

# Set the variable directly instead of through `%env`: that magic stores the
# value verbatim (surrounding quotes included), and an undefined `$name` can make
# IPython leave the whole line unexpanded.
# NOTE: PYTHONPATH is read at interpreter start-up, so this affects subprocesses
# (e.g. `!python ...`), not this kernel's already-initialised sys.path.
os.environ['PYTHONPATH'] = ':'.join(
    ['/env/python', content_root_abs, src_root_abs, src_clone_root_abs]
)
print(f"env: PYTHONPATH={os.environ['PYTHONPATH']}")

# 2) Train ViT/CVT/CCT models on the CIFAR-10 dataset
In this section we run the actual training loop for our networks.

In [None]:
# Run from src/ so the project modules imported below resolve from the working directory
# (presumably dataset, evaluate, model, train, visualize and experiment live directly in src/ — TODO confirm).
%cd "$repo_root/src"

import random
import sys
import os
import pathlib
from typing import Tuple

import numpy as np
import torch
import yaml
from torch import nn

import dataset
import evaluate
import model
import train
from visualize import plot_learning_curve
# NOTE(review): the code below uses load_config, load_data, load_or_init_model, BASE_PATH and
# datetime without importing them explicitly; presumably they come from this star import — confirm.
from experiment import *


#----------------------------------#
# Run settings
dl_config_filename = 'dl/cifar10_C'   # dataloader config, passed to load_config
model_config = 'model/ViT_Lite_7-8'   # model config, passed to load_config
num_workers = 2                       # forwarded to load_data
device = 'cuda'                       # preferred device; CPU fallback below
no_plots = False # we want plots
#----------------------------------#


# Load configuration
dl_config = load_config(dl_config_filename)
config = load_config(model_config)
# Run identifier (e.g. 'ViT_Lite_7-8__C'), used for the checkpoint, stats and figure names
config_filename = f"{model_config.replace('model/', '')}__{dl_config_filename.replace('dl/cifar10_', '')}"

# Choose the PyTorch device: keep the `device` setting above, but fall back to
# CPU when a CUDA device was requested and none is available.
if device.startswith('cuda') and not torch.cuda.is_available():
    device = "cpu"
print(f"Using {device} device")

# Initialize random number seeds (PyTorch, Python, NumPy) from the model config
torch.manual_seed(config['seed'])
random.seed(config['seed'])
np.random.seed(config['seed'])

# Build the train/validation/test dataloaders and the list of class labels
dl_train, dl_val, dl_test, classes = load_data(dl_config, device, num_workers)

# Build the network described by the model config (presumably restored from its
# checkpoint when one already exists — see load_or_init_model)
model_section = config['model']
model_name, model_params = model_section['which'], model_section['params']
checkpoint_name = f"{config_filename}.pth"
net, chkpt_filepath = load_or_init_model(checkpoint_name, model_name, model_params, device=device)

# Explicit import instead of relying on `from experiment import *` to provide it
from datetime import datetime

# Train model, unless the network returned by load_or_init_model reports it is already trained
if not net.trained:
    oc = config['optim']
    num_epochs = config['num_epochs']
    train_losses, val_losses = train.train(net, dl_train, dl_val, epochs=num_epochs, device=device,
                                           optim_type=oc['which'], **oc['params'])

    # Save model weights to the checkpoint path returned by load_or_init_model
    torch.save(net.state_dict(), chkpt_filepath)

    # One timestamp per run, so the stats file and the learning-curve figure share the same suffix
    # (calling datetime.now() separately for each could straddle a second boundary).
    run_stamp = datetime.now().strftime('%H%M%S-%d%m')

    # Save training stats
    stats_path = os.path.join(BASE_PATH, 'stats')
    os.makedirs(stats_path, exist_ok=True)
    stats_filename = os.path.join(stats_path, f"{config_filename}_{run_stamp}.yaml")
    with open(stats_filename, 'w', encoding='utf-8') as file:
        # NOTE(review): assumes train.train returns plain Python lists of floats — confirm,
        # since yaml.dump emits python-specific tags for tensors / numpy scalars.
        yaml.dump({'train_losses': train_losses, 'val_losses': val_losses}, file)

    # Plot training stats and save
    if not no_plots:
        fig = plot_learning_curve(train_losses, val_losses)
        graphics_path = os.path.join(BASE_PATH, 'graphics')
        os.makedirs(graphics_path, exist_ok=True)
        fig.savefig(os.path.join(graphics_path, f"{config_filename}_{run_stamp}"))

# Evaluate the network on the test split: optionally show sample predictions,
# then always report the test accuracy.
show_samples = not no_plots
if show_samples:
    evaluate.predict_and_display_sample(net, classes, dl_test, device=device)
evaluate.calculate_and_display_test_accuracy(net, dl_test, device=device)