In [27]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torch
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np

import importlib
import utils
import attention
import glob

import seaborn as sns
import pandas as pd

In [None]:
importlib.reload(utils)
importlib.reload(attention)

# Sweep configuration: train N_REPS independent SimpleViT models on
# pixel-shuffled CIFAR. Each run is expected to be persisted by utils.CIFAR_trainer
# (presumably to the directory read by get_record below -- TODO confirm).
N_REPS = 20
VIT_PARAMS = dict(
    image_size=32,
    patch_size=4,
    num_classes=10,
    dim=1024,
    depth=1,
    heads=16,
    mlp_dim=2048,
)
LEARNING_RATE = 1e-4
TRAIN_PARAMS = dict(batch_size=4, num_epochs=20)

for rep in range(N_REPS):
    # Seed per repetition: runs still differ from one another, but the whole
    # sweep is reproducible after Restart & Run All.
    torch.manual_seed(rep)
    np.random.seed(rep)

    net = attention.SimpleViT(**VIT_PARAMS).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
    trainer = utils.CIFAR_trainer(
        data_params=dict(pixel_shuffled=True),
        # Copy so the trainer cannot mutate the shared config between reps.
        train_params=dict(TRAIN_PARAMS),
        model_optim_params=dict(model=net, criterion=criterion, optimizer=optimizer),
    )
    trainer.train()

[1,  2000] loss: 2.066
[1,  4000] loss: 1.919
[1,  6000] loss: 1.880
[1,  8000] loss: 1.858
[1, 10000] loss: 1.827
[1, 12000] loss: 1.799
[2,  2000] loss: 1.746
[2,  4000] loss: 1.741
[2,  6000] loss: 1.723
[2,  8000] loss: 1.711
[2, 10000] loss: 1.700
[2, 12000] loss: 1.666
[3,  2000] loss: 1.648
[3,  4000] loss: 1.636
[3,  6000] loss: 1.618
[3,  8000] loss: 1.610
[3, 10000] loss: 1.576
[3, 12000] loss: 1.596
[4,  2000] loss: 1.559
[4,  4000] loss: 1.556
[4,  6000] loss: 1.554
[4,  8000] loss: 1.554
[4, 10000] loss: 1.540
[4, 12000] loss: 1.531
[5,  2000] loss: 1.505
[5,  4000] loss: 1.514
[5,  6000] loss: 1.509
[5,  8000] loss: 1.502
[5, 10000] loss: 1.498
[5, 12000] loss: 1.498
[6,  2000] loss: 1.459
[6,  4000] loss: 1.468
[6,  6000] loss: 1.466
[6,  8000] loss: 1.473
[6, 10000] loss: 1.473
[6, 12000] loss: 1.469
[7,  2000] loss: 1.424
[7,  4000] loss: 1.427
[7,  6000] loss: 1.447
[7,  8000] loss: 1.430
[7, 10000] loss: 1.447
[7, 12000] loss: 1.441
[8,  2000] loss: 1.400
[8,  4000] 

In [None]:
# Re-import utils so edits to utils.py (e.g. load_file_pickle) are picked up
# without restarting the kernel.
importlib.reload(utils)

def get_record(model_name, shuffled, outdir="/scratch/gpfs/qanguyen/renorm",
               min_iters=500, loader=None):
    """Collect test losses and training-loss curves for one model/shuffle setting.

    Reads every file matching ``{outdir}/{model_name}_shuffled_{shuffled}_*_rep_*``
    (in sorted order, so results are deterministic), loads it with ``loader`` and
    keeps only runs whose ``train_loss_prog`` has more than ``min_iters`` entries.

    Args:
        model_name: filename prefix, e.g. "attn" or "cnn".
        shuffled: "True"/"False"; a bool formats identically in the pattern.
        outdir: directory holding the saved run records.
        min_iters: runs with ``len(train_loss_prog) <= min_iters`` are skipped
            (presumably unfinished/crashed runs -- TODO confirm). Runs with an
            empty ``train_loss_prog`` are skipped as well.
        loader: callable(path) -> mapping with keys "train_loss_prog" and
            "test_loss". Defaults to ``utils.load_file_pickle``; if that
            unpickles, only point it at trusted files.

    Returns:
        ``(test_losses, lrs, curves)`` where ``test_losses`` has one entry per
        kept run, ``lrs`` is always an empty list (kept so callers can still
        unpack three values), and ``curves`` is a DataFrame with one row per
        entry of each kept run's ``train_loss_prog`` and columns ``itrs`` (index
        within that run), ``train_loss_progs`` and ``testlosses_dict`` (that
        run's test loss, repeated).
    """
    if loader is None:
        loader = utils.load_file_pickle

    test_losses = []
    lrs = []  # never populated; see docstring
    itrs = []
    train_loss_progs = []
    testlosses_dict = []

    pattern = f"{outdir}/{model_name}_shuffled_{shuffled}_*_rep_*"
    for path in sorted(glob.glob(pattern)):
        record = loader(path)
        # Materialise once; the old zip(*enumerate(...)) unpacking raised
        # ValueError on an empty progress list.
        train_loss_prog = list(record["train_loss_prog"])
        n_iters = len(train_loss_prog)
        if n_iters <= min_iters:
            continue
        test_losses.append(record["test_loss"])
        itrs.extend(range(n_iters))
        train_loss_progs.extend(train_loss_prog)
        testlosses_dict.extend([record["test_loss"]] * n_iters)

    print("Number of runs", model_name, shuffled, len(test_losses))
    curves = pd.DataFrame({
        "itrs": itrs,
        "train_loss_progs": train_loss_progs,
        "testlosses_dict": testlosses_dict,
    })
    return test_losses, lrs, curves
    
    
def _best_losses(losses, keep=25):
    """Drop NaNs, then keep losses strictly below the (keep+1)-th smallest one."""
    # Same selection as the former `i < np.sort(losses)[25]` (ties at the
    # cutoff are dropped), but the cutoff is computed once and a short list is
    # returned whole instead of raising IndexError.
    finite = [loss for loss in losses if not np.isnan(loss)]
    if len(finite) <= keep:
        return finite
    cutoff = np.sort(finite)[keep]
    return [loss for loss in finite if loss < cutoff]


def plot(model_name, keep=25, curve_quantile=0.5, loss_ylim=(0, 3),
         hist_range=(0, 2), n_bins=50):
    """Compare shuffled vs non-shuffled runs of one model family.

    Draws two figures:
      1. training-loss curves (seaborn mean +/- CI over runs) for runs whose
         test loss is below the ``curve_quantile`` quantile of the
         ``testlosses_dict`` column;
      2. overlaid histograms of the test losses of (roughly) the ``keep`` best
         runs per setting, NaN losses removed.

    Args:
        model_name: record filename prefix passed to ``get_record``.
        keep: number of lowest test losses kept for the histogram.
        curve_quantile: quantile cutoff on test loss for the curve plot.
        loss_ylim: y-limits of the training-loss plot.
        hist_range: (low, high) x-range and bin range of the histogram.
        n_bins: number of bin edges passed to ``np.linspace``.
    """
    shuffled_losses, _, shuffled_curves = get_record(model_name=model_name, shuffled="True")
    unshuffled_losses, _, unshuffled_curves = get_record(model_name=model_name, shuffled="False")

    # NaN filtering now applies to both settings (previously only "True").
    shuffled_losses = _best_losses(shuffled_losses, keep)
    unshuffled_losses = _best_losses(unshuffled_losses, keep)
    print(shuffled_losses, len(shuffled_losses))
    print(unshuffled_losses, len(unshuffled_losses))

    # NOTE(review): the curve DataFrames have one row per logged step, so this
    # quantile is weighted by run length and selects a different set of runs
    # than `_best_losses` does for the histogram -- confirm that is intended.
    def _below_quantile(curves):
        threshold = curves["testlosses_dict"].quantile(curve_quantile)
        return curves[curves["testlosses_dict"] < threshold]

    fig, ax = plt.subplots()
    sns.lineplot(x="itrs", y="train_loss_progs", data=_below_quantile(shuffled_curves),
                 label="Shuffled", ax=ax)
    sns.lineplot(x="itrs", y="train_loss_progs", data=_below_quantile(unshuffled_curves),
                 label="Non-Shuffled", ax=ax)
    ax.set_ylim(*loss_ylim)
    ax.set(xlabel="train_loss_prog index", ylabel="train loss",
           title=f"{model_name}: training loss")
    ax.legend()
    plt.show()

    bins = np.linspace(hist_range[0], hist_range[1], n_bins)
    fig, ax = plt.subplots()
    ax.hist(shuffled_losses, bins=bins, alpha=0.6, label="Shuffled")
    ax.hist(unshuffled_losses, bins=bins, alpha=0.6, label="Non-Shuffled")
    ax.set_xlim(*hist_range)
    ax.set(xlabel="test loss", ylabel="number of runs",
           title=f"{model_name}: test loss of best runs")
    ax.legend()
    plt.show()

# Shuffled vs non-shuffled comparison for every architecture in the study,
# in the same order as before.
MODEL_NAMES = (
    "attn",
    "attn_no_pe",
    "cnn",
    "cnn_chan_1-1",
    "cnn_chan_1-16",
    "mlp",
)
for name in MODEL_NAMES:
    plot(name)
