In [None]:
import os

from sa_convlstm import SAConvLSTM
from convlstm import ConvLSTM
from utils import *

import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

import sys
import pickle
from tqdm import tqdm
import numpy as np
import math
import argparse
import json

# Ask cuDNN for deterministic kernels and disable autotuning so repeated runs
# produce identical results.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Prefer the second GPU ("cuda:1") when it exists. On a single-GPU machine
# "cuda:1" is an invalid device ordinal and every `.to(device)` would fail, so
# fall back to "cuda:0" there, and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [None]:
save_dir = './save_models/saclstm_epoch_50_pinchu_pandamonium/'
# save_dir = './save_models/saclstm_pcrps_epoch_50_pinchu_pandamonium/'

# Restore the training-time hyper-parameters, stored as JSON in args.txt, into an
# argparse Namespace so they can be read as attributes.
parser = argparse.ArgumentParser()
args = parser.parse_args(args="")
with open(os.path.join(save_dir, 'args.txt'), 'r') as args_file:
    args.__dict__ = json.load(args_file)

# Rebuild the architecture the run was trained with. When `prob_crps` is set the
# input and output channel counts are doubled. Note the final `else` is also taken
# when neither `convlstm` nor `saconvlstm` is set.
if args.convlstm and not args.prob_crps:
    sam_network = ConvLSTM(args.input_dim, args.hidden_dim, args.output_dim,
                           args.kernel_size, device, dropout=args.dropout).to(device)
elif args.convlstm and args.prob_crps:
    sam_network = ConvLSTM(2*args.input_dim, args.hidden_dim, 2*args.output_dim,
                           args.kernel_size, device, dropout=args.dropout).to(device)
elif args.saconvlstm and not args.prob_crps:
    sam_network = SAConvLSTM(args.input_dim, args.hidden_dim, args.output_dim, args.attn_dim,
                             args.kernel_size, device, dropout=args.dropout).to(device)
else:
    sam_network = SAConvLSTM(2*args.input_dim, args.hidden_dim, 2*args.output_dim, args.attn_dim,
                             args.kernel_size, device, dropout=args.dropout).to(device)

# Locate the newest "<epoch>_checkpoint.chk". The real file name is kept (not
# rebuilt from the parsed int) so zero-padded names still resolve, and files
# whose prefix is not an integer are skipped instead of crashing int().
checkpoints = []
for fname in os.listdir(save_dir):
    parts = fname.split("_")
    if len(parts) == 2 and parts[1] == "checkpoint.chk" and parts[0].isdecimal():
        checkpoints.append((int(parts[0]), fname))
if not checkpoints:
    raise FileNotFoundError(f"no '<epoch>_checkpoint.chk' file found in {save_dir!r}")
max_ep, chkpnt = max(checkpoints)

# map_location remaps tensors saved on any GPU onto the device selected above.
chk = torch.load(os.path.join(save_dir, chkpnt), map_location=device)
sam_network.load_state_dict(chk['net'])

In [None]:
save_dir = './save_models/saclstm_orth_2_epoch_50_pinchu_pandamonium/'

# Restore the training-time hyper-parameters, stored as JSON in args.txt, into an
# argparse Namespace so they can be read as attributes.
parser = argparse.ArgumentParser()
args = parser.parse_args(args="")
with open(os.path.join(save_dir, 'args.txt'), 'r') as args_file:
    args.__dict__ = json.load(args_file)

# Rebuild the architecture the run was trained with. When `prob_crps` is set the
# input and output channel counts are doubled. Note the final `else` is also taken
# when neither `convlstm` nor `saconvlstm` is set.
if args.convlstm and not args.prob_crps:
    sam_ortho_network = ConvLSTM(args.input_dim, args.hidden_dim, args.output_dim,
                                 args.kernel_size, device, dropout=args.dropout).to(device)
elif args.convlstm and args.prob_crps:
    sam_ortho_network = ConvLSTM(2*args.input_dim, args.hidden_dim, 2*args.output_dim,
                                 args.kernel_size, device, dropout=args.dropout).to(device)
elif args.saconvlstm and not args.prob_crps:
    sam_ortho_network = SAConvLSTM(args.input_dim, args.hidden_dim, args.output_dim, args.attn_dim,
                                   args.kernel_size, device, dropout=args.dropout).to(device)
else:
    sam_ortho_network = SAConvLSTM(2*args.input_dim, args.hidden_dim, 2*args.output_dim, args.attn_dim,
                                   args.kernel_size, device, dropout=args.dropout).to(device)

# Locate the newest "<epoch>_checkpoint.chk". The real file name is kept (not
# rebuilt from the parsed int) so zero-padded names still resolve, and files
# whose prefix is not an integer are skipped instead of crashing int().
checkpoints = []
for fname in os.listdir(save_dir):
    parts = fname.split("_")
    if len(parts) == 2 and parts[1] == "checkpoint.chk" and parts[0].isdecimal():
        checkpoints.append((int(parts[0]), fname))
if not checkpoints:
    raise FileNotFoundError(f"no '<epoch>_checkpoint.chk' file found in {save_dir!r}")
max_ep, chkpnt = max(checkpoints)

# map_location remaps tensors saved on any GPU onto the device selected above.
chk = torch.load(os.path.join(save_dir, chkpnt), map_location=device)
sam_ortho_network.load_state_dict(chk['net'])

In [None]:
save_dir = './save_models/saclstm_orth_0_epoch_50_pinchu_pandamonium/'

# Restore the training-time hyper-parameters, stored as JSON in args.txt, into an
# argparse Namespace so they can be read as attributes.
parser = argparse.ArgumentParser()
args = parser.parse_args(args="")
with open(os.path.join(save_dir, 'args.txt'), 'r') as args_file:
    args.__dict__ = json.load(args_file)

# Rebuild the architecture the run was trained with. When `prob_crps` is set the
# input and output channel counts are doubled. Note the final `else` is also taken
# when neither `convlstm` nor `saconvlstm` is set.
if args.convlstm and not args.prob_crps:
    sam_ortho_full_network = ConvLSTM(args.input_dim, args.hidden_dim, args.output_dim,
                                      args.kernel_size, device, dropout=args.dropout).to(device)
elif args.convlstm and args.prob_crps:
    sam_ortho_full_network = ConvLSTM(2*args.input_dim, args.hidden_dim, 2*args.output_dim,
                                      args.kernel_size, device, dropout=args.dropout).to(device)
elif args.saconvlstm and not args.prob_crps:
    sam_ortho_full_network = SAConvLSTM(args.input_dim, args.hidden_dim, args.output_dim, args.attn_dim,
                                        args.kernel_size, device, dropout=args.dropout).to(device)
else:
    sam_ortho_full_network = SAConvLSTM(2*args.input_dim, args.hidden_dim, 2*args.output_dim, args.attn_dim,
                                        args.kernel_size, device, dropout=args.dropout).to(device)

# Locate the newest "<epoch>_checkpoint.chk". The real file name is kept (not
# rebuilt from the parsed int) so zero-padded names still resolve, and files
# whose prefix is not an integer are skipped instead of crashing int().
checkpoints = []
for fname in os.listdir(save_dir):
    parts = fname.split("_")
    if len(parts) == 2 and parts[1] == "checkpoint.chk" and parts[0].isdecimal():
        checkpoints.append((int(parts[0]), fname))
if not checkpoints:
    raise FileNotFoundError(f"no '<epoch>_checkpoint.chk' file found in {save_dir!r}")
max_ep, chkpnt = max(checkpoints)

# map_location remaps tensors saved on any GPU onto the device selected above.
chk = torch.load(os.path.join(save_dir, chkpnt), map_location=device)
sam_ortho_full_network.load_state_dict(chk['net'])

In [None]:
# Pairwise cosine similarity between the output filters of `sa.conv_z` in the
# first four layers of `sam_network`, using the [0, 0] spatial tap of each filter
# (presumably a 1x1 conv — TODO confirm). For each layer the resulting list holds
# one value per unordered filter pair (i < j), in row-major (i, j) order.
with torch.no_grad():  # weight inspection only: no autograd graph needed
    sims_by_layer = []
    for layer_idx in range(4):
        filters = sam_network.layers[layer_idx].sa.conv_z.weight[:, :, 0, 0]  # (out_ch, in_ch)
        # Use the real filter count instead of a hard-coded 32.
        n_filters = filters.size(0)
        rows, cols = torch.triu_indices(n_filters, n_filters, offset=1, device=filters.device)
        pair_sims = torch.nn.functional.cosine_similarity(filters[rows], filters[cols], dim=1)
        sims_by_layer.append(pair_sims.tolist())
sims_sa_0, sims_sa_1, sims_sa_2, sims_sa_3 = sims_by_layer

In [None]:
# Pairwise cosine similarity between the output filters of `sa.conv_z` in the
# first four layers of `sam_ortho_network`, using the [0, 0] spatial tap of each
# filter (presumably a 1x1 conv — TODO confirm). For each layer the resulting list
# holds one value per unordered filter pair (i < j), in row-major (i, j) order.
with torch.no_grad():  # weight inspection only: no autograd graph needed
    sims_by_layer = []
    for layer_idx in range(4):
        filters = sam_ortho_network.layers[layer_idx].sa.conv_z.weight[:, :, 0, 0]  # (out_ch, in_ch)
        # Use the real filter count instead of a hard-coded 32.
        n_filters = filters.size(0)
        rows, cols = torch.triu_indices(n_filters, n_filters, offset=1, device=filters.device)
        pair_sims = torch.nn.functional.cosine_similarity(filters[rows], filters[cols], dim=1)
        sims_by_layer.append(pair_sims.tolist())
sims_sa_ortho_0, sims_sa_ortho_1, sims_sa_ortho_2, sims_sa_ortho_3 = sims_by_layer

In [None]:
# Pairwise cosine similarity between the output filters of `sa.conv_z` in the
# first four layers of `sam_ortho_full_network`, using the [0, 0] spatial tap of
# each filter (presumably a 1x1 conv — TODO confirm). For each layer the resulting
# list holds one value per unordered filter pair (i < j), in row-major (i, j) order.
with torch.no_grad():  # weight inspection only: no autograd graph needed
    sims_by_layer = []
    for layer_idx in range(4):
        filters = sam_ortho_full_network.layers[layer_idx].sa.conv_z.weight[:, :, 0, 0]  # (out_ch, in_ch)
        # Use the real filter count instead of a hard-coded 32.
        n_filters = filters.size(0)
        rows, cols = torch.triu_indices(n_filters, n_filters, offset=1, device=filters.device)
        pair_sims = torch.nn.functional.cosine_similarity(filters[rows], filters[cols], dim=1)
        sims_by_layer.append(pair_sims.tolist())
sims_sa_ortho_full_0, sims_sa_ortho_full_1, sims_sa_ortho_full_2, sims_sa_ortho_full_3 = sims_by_layer

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter

# Overlay the filter-pair similarity distributions of layer 1 ("Cell 1") for the
# three models. Each series is weighted by 1/len(series), so its bars sum to 1 and
# the y-axis reads as the share of filter pairs per bin. Bin counts are chosen per
# series, so bin widths differ between models.
fig, ax = plt.subplots()
for sims, n_bins, colour, label in (
    (sims_sa_0, 20, 'r', 'SACLSTM - Cell 1'),
    (sims_sa_ortho_0, 30, 'g', 'OSACLSTM - Cell 1'),
    (sims_sa_ortho_full_0, 1, 'b', 'f-OSACLSTM - Cell 1'),
):
    ax.hist(sims, bins=n_bins, color=colour,
            weights=np.ones(len(sims)) / len(sims), alpha=0.5, label=label)
ax.yaxis.set_major_formatter(PercentFormatter(1))
ax.set_title("Cell 1")
ax.set_xlabel("Cosine similarity between conv_z filter pairs")
ax.set_ylabel("Share of filter pairs")
# Legend is disabled for this figure; call ax.legend() to show it.
plt.show()

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter

# Overlay the filter-pair similarity distributions of layer 2 ("Cell 2") for the
# three models. Each series is normalised by ITS OWN length, so its bars sum to 1
# and the y-axis reads as the share of filter pairs per bin. Bin counts are chosen
# per series, so bin widths differ between models.
fig, ax = plt.subplots()
for sims, n_bins, colour, label in (
    (sims_sa_1, 30, 'r', 'SACLSTM'),
    (sims_sa_ortho_1, 20, 'g', 'OSACLSTM'),
    (sims_sa_ortho_full_1, 1, 'b', 'f-OSACLSTM'),
):
    ax.hist(sims, bins=n_bins, color=colour,
            weights=np.ones(len(sims)) / len(sims), alpha=0.5, label=label)
ax.yaxis.set_major_formatter(PercentFormatter(1))
ax.set_title("Cell 2")
ax.set_xlabel("Cosine similarity between conv_z filter pairs")
ax.set_ylabel("Share of filter pairs")
ax.legend()
plt.show()

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter

# Overlay the filter-pair similarity distributions of layer 3 ("Cell 3") for the
# three models. Each series is normalised by ITS OWN length, so its bars sum to 1
# and the y-axis reads as the share of filter pairs per bin. Bin counts are chosen
# per series, so bin widths differ between models.
fig, ax = plt.subplots()
for sims, n_bins, colour, label in (
    (sims_sa_2, 25, 'r', 'SACLSTM - Cell 3'),
    (sims_sa_ortho_2, 30, 'g', 'OSACLSTM - Cell 3'),
    (sims_sa_ortho_full_2, 1, 'b', 'f-OSACLSTM - Cell 3'),
):
    ax.hist(sims, bins=n_bins, color=colour,
            weights=np.ones(len(sims)) / len(sims), alpha=0.5, label=label)
ax.yaxis.set_major_formatter(PercentFormatter(1))
ax.set_title("Cell 3")
ax.set_xlabel("Cosine similarity between conv_z filter pairs")
ax.set_ylabel("Share of filter pairs")
# Legend is disabled for this figure; call ax.legend() to show it.
plt.show()

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter

# Overlay the filter-pair similarity distributions of layer 4 ("Cell 4") for the
# three models. Each series is normalised by ITS OWN length, so its bars sum to 1
# and the y-axis reads as the share of filter pairs per bin. Bin counts are chosen
# per series, so bin widths differ between models.
fig, ax = plt.subplots()
for sims, n_bins, colour, label in (
    (sims_sa_3, 25, 'r', 'SACLSTM - Cell 4'),
    (sims_sa_ortho_3, 10, 'g', 'OSACLSTM - Cell 4'),
    (sims_sa_ortho_full_3, 1, 'b', 'f-OSACLSTM - Cell 4'),
):
    ax.hist(sims, bins=n_bins, color=colour,
            weights=np.ones(len(sims)) / len(sims), alpha=0.5, label=label)
ax.yaxis.set_major_formatter(PercentFormatter(1))
ax.set_title("Cell 4")
ax.set_xlabel("Cosine similarity between conv_z filter pairs")
ax.set_ylabel("Share of filter pairs")
# Legend is disabled for this figure; call ax.legend() to show it.
plt.show()