In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# Detect whether the notebook is running inside Google Colab.
# Only a missing module means "not Colab": catching ImportError (rather than a
# bare `except:`) lets KeyboardInterrupt and genuine errors propagate.
try:
    import google.colab  # noqa: F401  (imported only to probe availability)
    IN_COLAB = True
except ImportError:
    IN_COLAB = False
print(f"Are we using Colab now? {IN_COLAB}")

Are we using Colab now? False


In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.utils.data
import torchvision
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
from icecream import ic
from collections import Counter

In [14]:
import sys
import time, os
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    GOOGLE_DRIVE_PATH_AFTER_MYDRIVE = "Out-of-Distribution-GANs"
    GOOGLE_DRIVE_PATH = os.path.join('drive', 'My Drive', GOOGLE_DRIVE_PATH_AFTER_MYDRIVE)
    print(os.listdir(GOOGLE_DRIVE_PATH))
    sys.path.append(GOOGLE_DRIVE_PATH)
    %pip install icecream tensorboard
    %pip install umap-learn pandas matplotlib datashader bokeh holoviews colorcet scikit-image
else:
#     ic(sys.prefix)
#     ic(sys.path)
#     pass
    # ic(sys.path)
    sys.path.insert(0, '../../')
os.environ["TZ"] = "US/Eastern"
time.tzset()

In [21]:
from config import *
from dataset import MNIST,CIFAR10, MNIST_SUB, SVHN
# from models.mnist_cnn import MNISTCNN
from models.hparam import HParam
from models.gans import *
from models.dc_gan_model import *
from utils import *
from models.ood_gan_backbone import *
from ood_gan import *

### Load datasets (CIFAR-10 in-distribution, SVHN out-of-distribution)

In [13]:
# Seed torch and numpy before any random draw. The SVHN batch taken below is
# what gets passed to trainer.train, and the model initializations in later
# cells also consume torch's RNG, so both become reproducible across runs.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# In-distribution data: CIFAR-10 train / validation splits and loaders.
bsz_tri, bsz_val = 256, 128
cifartri_set, cifarval_set, cifar_tri_loader, cifar_val_loader = CIFAR10(bsz_tri, bsz_val)

# Out-of-distribution data: SVHN train / validation splits and loaders.
ood_bsz_tri = 64
ood_bsz_val = 128
svhn_tri_set, svhn_val_set, svhn_triloader, svhn_valloader = SVHN(ood_bsz_tri, ood_bsz_val)
# A single batch of `ood_bsz_tri` SVHN images; which images it contains depends
# on whether `svhn_triloader` shuffles (presumably it does — confirm in dataset.py).
ood_img_batch, ood_img_label = next(iter(svhn_triloader))
ic(ood_img_batch.shape)


Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./Datasets/SVHN/train_32x32.mat
Using downloaded and verified file: ./Datasets/SVHN/test_32x32.mat


ic| ood_img_batch.shape: torch.Size([64, 3, 32, 32])


torch.Size([64, 3, 32, 32])

### Pretrain the discriminator on CIFAR-10

In [28]:
# Pretrain the discriminator D as a plain CIFAR-10 classifier (cross-entropy),
# logging per-batch and per-epoch metrics to TensorBoard under "PreTrain".
# The input batches are moved to DEVICE below, so the model must be as well.
D = Discriminator().to(DEVICE)
pretrain_writer = SummaryWriter("PreTrain")
model = D
if IN_COLAB:
    pretrain_addr = GOOGLE_DRIVE_PATH + '/checkpoint/CIFAR-SVHN/pretrainedD.pt'
else:
    pretrain_addr = 'checkpoint/CIFAR-SVHN/pretrainedD.pt'
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.5, 0.999))
# Spelled via `torch.nn`; the import cell has no explicit `import torch.nn as nn`.
criterion = torch.nn.CrossEntropyLoss()
num_epoch = 8


def run_pretrain_epoch(model, loader, criterion, writer, split, step, optimizer=None):
    """Run one pass of `model` over `loader`, logging per-batch metrics.

    Trains (forward + backward + optimizer step) when `optimizer` is given;
    otherwise evaluates with gradients disabled and the model in eval mode.

    Args:
        model: classifier whose output holds class logits along dim 1.
        loader: iterable of (images, integer labels) batches.
        criterion: loss taking (logits, labels) and returning the batch mean.
        writer: SummaryWriter receiving "<split>/Accuracy" and "<split>/Loss".
        split: TensorBoard tag prefix, e.g. "Training" or "Validation".
        step: global step for the first logged batch.
        optimizer: optimizer to step; None means evaluation only.

    Returns:
        (mean_loss, accuracy, next_step). Both epoch metrics are weighted by
        sample count, so a short final batch is not over-weighted.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    is_train = optimizer is not None
    model.train(is_train)  # train(False) is equivalent to eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.set_grad_enabled(is_train):
        for img, label in loader:
            img, label = img.to(DEVICE), label.to(DEVICE)
            logits = model(img)
            loss = criterion(logits, label)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            n = label.shape[0]
            correct = (torch.argmax(logits, dim=1) == label).sum().item()
            batch_loss = loss.item()
            writer.add_scalar(f"{split}/Accuracy", correct / n, step)
            writer.add_scalar(f"{split}/Loss", batch_loss, step)
            step += 1
            total_loss += batch_loss * n
            total_correct += correct
            total_seen += n
    if total_seen == 0:
        raise ValueError(f"{split} loader yielded no batches")
    return total_loss / total_seen, total_correct / total_seen, step


iter_count_train = 0
iter_count_val = 0
try:
    for epoch in tqdm(range(num_epoch)):
        train_loss, train_acc, iter_count_train = run_pretrain_epoch(
            model, cifar_tri_loader, criterion, pretrain_writer,
            "Training", iter_count_train, optimizer=optimizer)
        pretrain_writer.add_scalar("Training/Accuracy (Epoch)", train_acc, epoch)
        pretrain_writer.add_scalar("Training/Loss (Epoch)", train_loss, epoch)
        print(f"Epoch  # {epoch + 1} | training loss: {train_loss} "
              f"| training acc: {train_acc}")

        # Validation metrics get their own tags; logging them under "Training/..."
        # would interleave them with (and corrupt) the training curves.
        val_loss, val_acc, iter_count_val = run_pretrain_epoch(
            model, cifar_val_loader, criterion, pretrain_writer,
            "Validation", iter_count_val)
        pretrain_writer.add_scalar("Validation/Accuracy (Epoch)", val_acc, epoch)
        pretrain_writer.add_scalar("Validation/Loss (Epoch)", val_loss, epoch)
        print(f"Epoch  # {epoch + 1} | validation loss: {val_loss} "
              f"| validation acc: {val_acc}")

    # Persist the pretrained weights to `pretrain_addr` so they survive a kernel restart.
    # NOTE(review): saved as a state_dict — confirm this matches what the
    # trainer expects if it is later passed via `pretrainedD`.
    os.makedirs(os.path.dirname(pretrain_addr), exist_ok=True)
    torch.save(model.state_dict(), pretrain_addr)
finally:
    # Flush and close the event file even if the loop is interrupted.
    pretrain_writer.close()

  0%|          | 0/8 [00:00<?, ?it/s]

KeyboardInterrupt: 

### Initialize the OOD-GAN trainer

In [23]:
# Checkpoint location: under the Drive project folder on Colab, otherwise
# relative to the current working directory.
ckpt_dir = (GOOGLE_DRIVE_PATH + '/checkpoint/CIFAR-SVHN/'
            if IN_COLAB else 'checkpoint/CIFAR-SVHN/')
# Run name encodes the OOD batch size; reused as the TensorBoard writer name.
ckpt_name = f'CIFAR-SVHN[{ood_bsz_tri}]'
writer_name = ckpt_name

# Hyperparameters. HParam weights (ce / wass / dist) are presumably loss-term
# weights — confirm in models/hparam.py.
hp = HParam(ce=1, wass=0.1, dist=0.8)
max_epoch = 1
n_steps_log = 1
noise_dim = 96

# Freshly initialized networks (D is built before G) and one Adam optimizer each.
D = Discriminator()
G = Generator(noise_dim)
D_solver, G_solver = [
    torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.5, 0.999))
    for net in (D, G)
]

# In-distribution training data for the GAN.
ind_loader = cifar_tri_loader


In [22]:
# Build the OOD-GAN trainer from the values set in the configuration cell.
# `n_steps_log` is passed through (it was previously hard-coded to 1), so
# changing it in the config cell now actually takes effect.
trainer = OOD_GAN_TRAINER(D=D, G=G,
                          noise_dim=noise_dim,
                          bsz_tri=bsz_tri,
                          gd_steps_ratio=1,
                          hp=hp,
                          max_epochs=max_epoch,
                          writer_name=writer_name,
                          ckpt_name=ckpt_name,
                          ckpt_dir=ckpt_dir,
                          n_steps_log=n_steps_log)

In [None]:
# Train on CIFAR-10 batches from `ind_loader` together with the fixed SVHN
# batch `ood_img_batch`. pretrainedD=None / checkpoint=None presumably mean
# "start from the fresh D/G above, no resume" — confirm in ood_gan.py.
trainer.train(
    ind_loader,
    ood_img_batch,
    D_solver,
    G_solver,
    pretrainedD=None,
    checkpoint=None,
)