In [None]:
from torch.optim import Adam, SGD
import numpy as np, torch.nn as nn, pandas as pd,\
torch.nn.functional as F, matplotlib.pyplot as plt,\
seaborn as sn
import torch, logging, os, re, random, pickle, logging
from sklearn.metrics import confusion_matrix
from sklearn.metrics import (
    roc_curve, 
    roc_auc_score, 
    precision_recall_fscore_support
)
import argparse
import itertools
from copy import deepcopy
from torch.optim.lr_scheduler import StepLR
from src.pipeline import pipeline
from src.training_utils import training_utils
from torchvision.utils import save_image
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
from torch.utils import data as tdataset
import torch.nn as nn
import torch.nn.functional as F
import torch
import argparse
import os

# Experiment grid read by run_experiments: every override dict in "params"
# is combined with every seed in "seeds".
EXP_HPARAMS = dict(
    params=({},),   # a single empty dict: no config key is overridden
    seeds=(3957,),  # seeds tried for each override dict
)

# Read the comment right before cal_linearAcc

In [None]:
class TrainTest():
    """Couples a model with its optimizer, LR scheduler and loss function.

    `train` runs the optimisation loop, `test` computes classification
    metrics and stores them on `self.ts_metrics`, and `save_results` pickles
    a results object to disk.
    """

    def __init__(self, model, optimizer, scheduler, criterion):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion

    def train(
        self, train_loader,
        num_epochs, device, eval_interval,
        clip=None, model_path=None, save_per_epoch=None,
        results_path=None, defaults=None, **kwargs
    ):
        """Train `self.model` for `num_epochs` passes over `train_loader`.

        Every batch is moved to `device`, the loss is computed as
        `criterion(output, labels.view(-1))` and printed, and one optimizer
        step is taken. The scheduler is stepped once per epoch.

        NOTE: `eval_interval`, `clip`, `model_path`, `save_per_epoch`,
        `results_path`, `defaults` and `**kwargs` are accepted but not used
        by this loop; in particular no gradient clipping is applied.
        """
        total_itrs = num_epochs*len(train_loader)
        itr = 0
        self.model.train()
        for _ in range(num_epochs):
            for real_imgs, labels in train_loader:
                real_imgs, labels = real_imgs.to(device), labels.to(device)
                self.optimizer.zero_grad()
                output = self.model(real_imgs)
                tr_loss = self.criterion(output, labels.view(-1))
                print(f'Training: {itr}/{total_itrs} -- loss: {tr_loss.item()}')
                tr_loss.backward()
                self.optimizer.step()
                itr += 1
            self.scheduler.step()

    def save_results(self, results_path, name, results):
        """Pickle `results` to `results_<name>.pkl`.

        The file is written into `results_path` with its last path component
        removed (so './results/' and './results/run' both resolve to
        './results'). The directory is created when missing.
        """
        # A path without any directory part used to produce '' and crash in
        # os.makedirs(''); fall back to the current directory instead.
        results_dir = os.path.dirname(results_path) or '.'
        os.makedirs(results_dir, exist_ok=True)
        with open(os.path.join(results_dir, f'results_{name}.pkl'), 'wb') as save_file:
            pickle.dump(results, save_file)

    def test(self, test_loader, device, all_labels, results_path=None, defaults=None):
        """Evaluate `self.model` on `test_loader` and store the metrics.

        Predictions are the arg-max over dim 1 of the model output. Accuracy,
        weighted precision/recall/F1 and the confusion matrix (both computed
        with `labels=all_labels`) are stored in `self.ts_metrics`. When
        `results_path` is truthy the metrics are also pickled through
        `save_results` under the name 'test'. `defaults` is not used.
        """
        true_labels, pred_labels = [], []
        self.model.eval()
        with torch.no_grad():
            for real_imgs, labels in test_loader:
                real_imgs, labels = real_imgs.to(device), labels.to(device)
                test_output = self.model(real_imgs)
                _, temp_pred = test_output.max(dim=1)
                true_labels.append(labels.view(-1))
                pred_labels.append(temp_pred)
        # Gather all batches on the CPU before handing them to scikit-learn.
        pred_labels = torch.cat(pred_labels).cpu()
        true_labels = torch.cat(true_labels).cpu()
        test_accuracy = torch.sum(pred_labels == true_labels).item() / true_labels.size()[0]
        prf = precision_recall_fscore_support(
            true_labels,
            pred_labels,
            labels=all_labels,
            average='weighted'
        )
        confm = confusion_matrix(true_labels, pred_labels, labels=all_labels)
        self.ts_metrics = {
            'accuracy':test_accuracy,
            'precision':prf[0],
            'recall':prf[1],
            'f1_score':prf[2],
            'confusion_matrix':confm
        }
        if results_path:
            self.save_results(results_path, 'test', self.ts_metrics)

class LinearAccuracy(nn.Module):
    """An encoder followed by one linear layer that outputs log-probabilities.

    Args:
        encoder: module applied to the input images.
        encoder_dim: input feature size of the linear layer.
        output_dim: output size of the linear layer.
    """

    def __init__(self, encoder, encoder_dim, output_dim):
        super().__init__()
        self.encoder = encoder
        self.linear = nn.Linear(encoder_dim, output_dim)

    def forward(self, imgs):
        """Return `log_softmax(linear(encoder(imgs)))` over the last axis."""
        logits = self.linear(self.encoder(imgs))
        # The linear layer puts its outputs on the last axis, so normalise
        # there explicitly. Calling log_softmax without `dim` is deprecated
        # and warns on every call; for 2-D logits it picked this same axis.
        return F.log_softmax(logits, dim=-1)

def cal_accuracy(pred_labels, true_labels):
    """Return the fraction of rows whose arg-max over dim 1 equals the true label.

    `true_labels` is flattened with `view(-1)` before the comparison.
    """
    predicted = pred_labels.max(dim=1)[1]  # index of the largest score per row
    flat_truth = true_labels.view(-1)
    n_correct = (predicted == flat_truth).sum().item()
    return n_correct / flat_truth.size(0)

def img_generator(model, dataloader, real_dir, gen_dir):
    """Write real images and generated images side by side as numbered PNGs.

    Both output directories are created if needed and the model is switched
    to eval mode. For every batch, `model.generate_imgs(cls=labels)` is
    called with the labels moved to the model's device; for each generated
    image, the real image at the same batch index goes to
    `real_dir/real_<k>.png` and the generated one to `gen_dir/gen_<k>.png`,
    where `k` counts images across all batches.
    """
    for out_dir in (real_dir, gen_dir):
        os.makedirs(out_dir, exist_ok=True)
    model.eval()
    n_saved = 0
    # Labels must live on the same device as the model's parameters.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        for real_batch, batch_labels in dataloader:
            generated, _ = model.generate_imgs(cls=batch_labels.to(model_device))
            for idx in range(generated.size(0)):
                save_image(real_batch[idx], os.path.join(real_dir, f'real_{n_saved}.png'))
                save_image(generated[idx], os.path.join(gen_dir, f'gen_{n_saved}.png'))
                n_saved += 1

def get_loader(data_path, image_size, batch_size, train=False):
    """Return a DataLoader over torchvision's MNIST split stored at `data_path`.

    The dataset is downloaded when missing. Images are resized and
    centre-cropped to `image_size`, then converted to tensors. For both
    splits the loader keeps the dataset order (`shuffle=False`) and drops the
    final incomplete batch (`drop_last=True`).
    """
    preprocessing = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ])
    mnist_split = MNIST(data_path, train=train, transform=preprocessing, download=True)
    return tdataset.DataLoader(
        mnist_split,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
    )

def run_experiments(data_path, dataset='MNIST', model_architecture='bigbigan',include_sx=True, include_sz=True, include_sxz=True):
    """Train one BigBiGAN pipeline per (override dict, seed) pair in EXP_HPARAMS.

    For each pair, the config returned by `training_utils.get_config(dataset)`
    is updated with the overrides, the architecture name, a string describing
    the overrides, the seed and the three `has_*` flags. The random seed is
    then set and the pipeline is built from the config and trained.

    Returns:
        The pipeline trained for the last pair of the grid, or None when the
        grid is empty.
    """
    training_pipeline = None
    for hparams_overwrite_list, seed in itertools.product(EXP_HPARAMS["params"], EXP_HPARAMS["seeds"]):
        config = training_utils.get_config(dataset)
        for k, v in hparams_overwrite_list.items():
            config[k] = v
        hparams_str = "_".join(f"{k}-{v}" for k, v in hparams_overwrite_list.items())
        config["model_architecture"] = model_architecture
        config["hparams_str"] = hparams_str.strip("_")
        config["seed"] = seed
        config['has_sx'] = include_sx
        config['has_sz'] = include_sz
        config['has_sxz'] = include_sxz
        training_utils.set_random_seed(seed=config['seed'], device=config['device'])
        training_pipeline = pipeline.BigBiGANPipeline.from_config(data_path=data_path, config=config)
        training_pipeline.train_model()
    # Return only after the whole grid has run: returning from inside the loop
    # silently skipped every pair after the first one.
    return training_pipeline

# This function trains a classifier made of the BigBiGAN encoder followed by a linear layer.
# Set the path to the train and test files inside it.
def cal_linearAcc(model, test_loader=None):
    """Train a linear classifier on top of `model.encoder` and evaluate it.

    The encoder parameters are frozen and a `LinearAccuracy` head (100 -> 10)
    is trained for 15 epochs with SGD (lr 0.1, momentum 0.9, StepLR with
    step 3 and gamma 0.1) on the MNIST training split, using `F.nll_loss`.

    Args:
        model: object exposing the trained encoder as `model.encoder`.
        test_loader: loader used for evaluation. When omitted, the MNIST test
            split is loaded here with the same settings as the training
            loader, so the function no longer needs a notebook-level
            `test_loader` variable to exist.

    Returns:
        The `TrainTest` instance; its `ts_metrics` attribute holds the test
        metrics, which are also pickled under './results/'.

    NOTE: the encoder is wrapped, not copied. Freezing its parameters and
    moving it to the device therefore also affects the encoder held by
    `model`.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lr = .1
    num_epochs = 15
    eval_interval = 40
    linear_model = LinearAccuracy(model.encoder, 100, 10).to(device)
    # Only the linear head is trained; freeze everything under `encoder`.
    for name, p in linear_model.named_parameters():
        if "encoder" in name:
            p.requires_grad = False
    optimizer = SGD(linear_model.parameters(), lr=lr, momentum=0.9)
    scheduler = StepLR(optimizer, step_size=3, gamma=0.1)
    # Set path (directory) to train and test files here (Note: files retrieved from torchvision)
    train_loader = get_loader('./mnist/', 32, 256, train=True)
    if test_loader is None:
        test_loader = get_loader('./mnist/', 32, 256, train=False)
    traintest = TrainTest(linear_model, optimizer, scheduler, F.nll_loss)
    traintest.train(
        train_loader,
        num_epochs,
        device,
        eval_interval,
        model_path='./models/',
        save_per_epoch=4,
        results_path='./results/',
        clip=5
    )
    traintest.test(
        test_loader,
        device,
        list(range(10)),
        results_path='./results/'
    )
    return traintest

# Choose which loss you want to ignore

In [None]:
ignore_loss = 1 # include_all = 0, ignore_sx = 1, ignore_sz = 2, ignore_sx_and_sz = 3
# Each mode maps to (message to print, keyword overrides for run_experiments).
loss_modes = {
    0: ('all losses included', {}),
    1: ('sx loss ignored', {'include_sx': False}),
    2: ('sz loss ignored', {'include_sz': False}),
    3: ('sx and sz losses ignored', {'include_sx': False, 'include_sz': False}),
}
# Any unrecognised value falls back to the "all losses included" setup.
mode_message, mode_kwargs = loss_modes.get(ignore_loss, loss_modes[0])
print(mode_message)
pipline = run_experiments('./mnist/', **mode_kwargs)

In [None]:
# Take the model off the pipeline returned by run_experiments and switch it
# to evaluation mode.
model = pipline.model
model.eval()
# base_model = deepcopy(model)

# Set path (directory) to train and test files here (Note: files retrieved from torchvision)

In [None]:
# Load the MNIST test split (image size 32, batches of 256), then write the
# real images to ./real_images/ and the images returned by
# model.generate_imgs for the same batch labels to ./gen_images/.
test_loader = get_loader('./mnist/', 32, 256, train=False)
img_generator(model, test_loader, './real_images/', './gen_images/')

In [None]:
# Archive the generated and real image folders.
# NOTE(review): the images were written to relative ./gen_images and
# ./real_images, so these absolute paths assume the working directory is
# /content (e.g. Colab) — confirm before running elsewhere.
!zip -r /content/sx_gen.zip /content/gen_images
!zip -r /content/sx_real.zip /content/real_images

In [None]:
# Train the linear classifier on top of the encoder and print the test
# metrics stored on the returned TrainTest object.
train_tester = cal_linearAcc(model)
print(train_tester.ts_metrics)

In [None]:
# Install the pytorch-fid package used by the FID cells below (no version is pinned).
!pip install pytorch-fid

In [None]:
# Print the command-line help of pytorch_fid.
!python -m pytorch_fid -h

## Set the paths to the image zip files if the images are not already in separate directories. Otherwise, set the directory paths in the next cell

In [None]:
# Unpack each image archive into its own folder.
# NOTE(review): the zip cell in this notebook only produces sx_gen.zip and
# sx_real.zip; the other archives presumably come from runs with different
# ignore_loss settings — verify they exist in the working directory.
!unzip sx_gen.zip -d ./sx_gen
!unzip sz_gen.zip -d ./sz_gen
!unzip sxsz_gen.zip -d ./sxsz_gen
!unzip all_gen.zip -d ./all_gen
!unzip all_real.zip -d ./all_real

In [None]:
# Run pytorch_fid on the unpacked sxsz generated images and the real images, on cuda:0.
!python -m pytorch_fid sxsz_gen/content/gen_images/ all_real/content/real_images/ --device cuda:0

# Set path to saved pickle results file

In [None]:
# NOTE: pickle.load can execute arbitrary code stored in the file; only open
# results files that you produced yourself.
with open('./results/sxsz_results_test.pkl', 'rb') as results_file:
    test_results = pickle.load(results_file)

# Label the rows and columns of the confusion matrix with the digit classes.
classes = [f'digit: {x}' for x in range(10)]
test_confm = pd.DataFrame(test_results['confusion_matrix'], index=classes, columns=classes)

# Draw the confusion matrix as a heatmap and save it next to the results.
plt.figure()
sn.set(font_scale=1)
sn.heatmap(test_confm, annot=False, annot_kws={"size": 10})
plt.autoscale(True)
plt.savefig(os.path.join('./results/', 'sxsz-test-confusion-matrix.png'), dpi=300, bbox_inches="tight")

# Print the scalar test metrics between two rows of asterisks.
banner = "*" * 20
report_lines = [
    f'{banner} Test Metrics: {banner}',
    f'Accuracy: {test_results["accuracy"]:.3f}',
    f'Weighted Precision: {test_results["precision"]:.3f}',
    f'Weighted Recall: {test_results["recall"]:.3f}',
    f'Weighted F1-score: {test_results["f1_score"]:.3f}',
    "*" * 55,
]
print('\n'.join(report_lines))