In [None]:
import yaml
from omegaconf import OmegaConf

import torch
from utils import make_model, set_random_seed, save_model, load_model
from trainer import train
from dataset import ShapeDataset, load_data
from dataset_config import DATASET_CONFIG

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import torchvision
import torchvision.transforms as transforms

import torch.nn.functional as F

from sklearn.cluster import KMeans
import fastcluster
from scipy.cluster.hierarchy import fcluster

import math

import matplotlib.pyplot as plt
from plotting import plot_phases, plot_results, plot_eval, plot_fourier, plot_phases2, plot_masks, plot_slots, build_color_mask, plot_clusters, plot_clusters2

from loss_metrics import get_ar_metrics, compute_pixelwise_accuracy, compute_iou

import os
import numpy as np
import imageio
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from IPython.display import display
import ipywidgets as widgets

import matplotlib.animation as animation
import matplotlib.pyplot as plt
from IPython.display import HTML

import matplotlib.gridspec as gridspec

import seaborn as sns

In [None]:
# Apply seaborn's default theme to the matplotlib figures produced below
sns.set()

# Data Paths

In [None]:
def load_yaml_file(file_path):
    """Parse the YAML file at `file_path` and return its top-level 'params' entry."""
    with open(file_path, 'r') as handle:
        parsed = yaml.safe_load(handle)
    return parsed['params']

# Root directory of all runs and the config file looked up inside each run
folder = 'experiments'
hydra_config_file = '.hydra/config.yaml'

# Run directories (relative to `folder`) of the two models compared below
folders = [
    "ccn8/new_tetronimoes/conv_recurrent2/5/linear_lstm_20iters",
    "ccn8/new_tetronimoes/cornn_model2/9/linear_100iters",
]

# Full run paths, in the same order as `folders`
paths = ["/".join((folder, run_dir)) for run_dir in folders]

# 'params' section of each run's config file, aligned with `paths`
configs = [
    load_yaml_file("/".join((run_path, hydra_config_file)))
    for run_path in paths
]

In [None]:
# Seed via the project helper, then pick the GPU when one is available
seed = 1
set_random_seed(seed)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load models

In [None]:
def load_model(cp_folder, config, device, data_config):
    """Build a model from a run config and load its checkpoint for inference.

    NOTE: this definition shadows the `load_model` imported from `utils` in
    the import cell.

    Args:
        cp_folder: directory that contains the checkpoint file `cp.pt`.
        config: mapping of run parameters. Keys read: 'model_type',
            'num_classes', 'N', 'dt', 'min_iters', 'max_iters', 'c_mid',
            'hidden_channels', 'rnn_kernel', 'kernel_init', 'cell_type',
            'num_layers', 'readout_type'.
        device: torch device the model is built for and finally moved to.
        data_config: mapping of dataset settings. Keys read: 'channels',
            'img_size'.

    Returns:
        The model, switched to eval mode and moved to `device`.

    Warns:
        UserWarning: if the checkpoint's keys do not match the model's
            parameters. Loading stays non-strict, but a mismatch is reported
            instead of being silently ignored.
    """
    import warnings

    net = make_model(
        device,
        config['model_type'],
        config['num_classes'],
        config['N'],
        config['dt'],
        config['min_iters'],
        config['max_iters'],
        data_config['channels'],
        config['c_mid'],
        config['hidden_channels'],
        config['rnn_kernel'],
        data_config['img_size'],
        config['kernel_init'],
        cell_type=config['cell_type'],
        num_layers=config['num_layers'],
        readout_type=config['readout_type'],
    )

    checkpoint_path = f"{cp_folder}/cp.pt"
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    # strict=False tolerates key mismatches. Report them, because otherwise a
    # wrong checkpoint leaves (part of) the model at its initial weights
    # without any visible signal.
    load_result = net.load_state_dict(state_dict, strict=False)
    missing = list(getattr(load_result, "missing_keys", ()))
    unexpected = list(getattr(load_result, "unexpected_keys", ()))
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint {checkpoint_path} does not fully match the model: "
            f"missing keys={missing}, unexpected keys={unexpected}"
        )

    net.eval()
    return net.to(device)

In [None]:
# One loaded model per run, in the same order as `paths` / `configs`
models = [
    load_model(run_path, run_config, device, DATASET_CONFIG['new_tetronimoes'])
    for run_path, run_config in zip(paths, configs)
]

# Forward

In [None]:
def fft_readout(net, y_seq, B, H, W):
    """Return the magnitude of the real FFT of `y_seq` taken along dim 1.

    `net`, `B`, `H` and `W` are not used by this function (presumably kept so
    the signature mirrors `linear_readout`).

    Returns:
        Tensor of magnitudes, annotated as (B, K, c_out, H, W) in the
        original code.
    """
    spectrum = torch.fft.rfft(y_seq, dim=1)
    return spectrum.abs()

def linear_readout(net, y_seq, B, H, W):
    """Apply `net.fc_time` along the time axis and restore the spatial layout.

    Args:
        net: object exposing `T`, `c_out` and a callable `fc_time`.
        y_seq: tensor reshapeable to (B, net.T, net.c_out, H*W).
        B, H, W: batch size and spatial size used for the reshapes.

    Returns:
        Tensor of shape (B, K, C, H, W), where K and C are dims 1 and 2 of
        the transposed `fc_time` output.
    """
    # Flatten space and move time last so fc_time acts on the time axis
    per_pixel = y_seq.reshape(B, net.T, net.c_out, -1).transpose(1, 3)
    # Swap back: (B, K, C, H*W)
    projected = net.fc_time(per_pixel).transpose(1, 3)
    n_out, n_channels = projected.size(1), projected.size(2)
    return projected.reshape(B, n_out, n_channels, H, W)

# Set up data

In [None]:
# Load data
data_config1 = DATASET_CONFIG['new_tetronimoes']
# Only the second of the three values returned by load_data is kept
# (presumably the validation split, going by its name)
_, valset, _ = load_data('new_tetronimoes', data_config1)

# shuffle=True: which samples end up in the batch depends on the RNG state here
val_loader = DataLoader(valset, batch_size=16, shuffle=True, drop_last=False)
# A single batch is drawn once and reused by every cell below
batch1 = next(iter(val_loader))

# Evaluation batch per dataset name; indexed with config['dataset'] in the forward pass
testsets = {
    'new_tetronimoes' : batch1,
}

In [None]:
# Run every model once on the batch of its dataset and collect, per model:
#   states - second output of the model (the sequence passed to the readout)
#   ffts   - output of linear_readout applied with the model's classifier
#   masks  - argmax of the logits over dim 1
states = []
ffts = []
masks = []
# Inference only: without no_grad every stored tensor would keep its autograd
# graph alive for the rest of the notebook.
with torch.no_grad():
    for i, net in enumerate(models):
        config = configs[i]
        dataset = config['dataset']
        batch = testsets[dataset]
        x, x_target = batch
        x = x.to(device)
        logits, y_seq = net(x)
        fft_mag = linear_readout(net.classifier, y_seq, x.size(0), x.size(-2), x.size(-1))
        states.append(y_seq)
        ffts.append(fft_mag)
        masks.append(logits.argmax(dim=1))

In [None]:
# Sanity check: shapes of the collected state sequences of both models
print(states[0].shape, states[1].shape)

In [None]:
# Sanity check: shapes of the readout outputs of both models
print(ffts[0].shape, ffts[1].shape)

In [None]:
# Sanity check: shapes of the predicted masks of both models
print(masks[0].shape, masks[1].shape)

# Plot masks

In [None]:
def plot_masks(masks, title, max_plots=16):
    """Show a batch of masks side by side in a single row.

    NOTE: this shadows the `plot_masks` imported from `plotting`, and a later
    cell of this notebook defines another `plot_masks` with a different
    signature that replaces this one once it has run.

    Args:
        masks: tensor whose first axis indexes the masks; each entry is a
            2-D image accepted by `imshow`.
        title: title placed above the first panel.
        max_plots: maximum number of masks shown (default 16, the value that
            used to be hard-coded). Fewer are shown when the batch is smaller.
    """
    masks = masks.detach().cpu().numpy()
    n_plots = min(len(masks), max_plots)
    if n_plots == 0:
        # Nothing to draw; plt.subplots cannot create an empty grid
        return
    # squeeze=False keeps `axes` 2-D even when a single mask is shown
    fig, axes = plt.subplots(1, n_plots, figsize=(n_plots, 1), squeeze=False)
    for ax, mask in zip(axes[0], masks):
        ax.imshow(mask)
        ax.set_xticks([])
        ax.set_yticks([])
    axes[0, 0].set_title(title)
    plt.show()

In [None]:
# One row of predicted masks per model, titled with the model type
for model_config, model_masks in zip(configs, masks):
    plot_masks(model_masks, title=model_config['model_type'])

# Look at gifs and choose timesteps we want to plot

In [None]:
def plot_hidden_state_video(y_seq, sample_idx=0, interval=200):
    """
    Given y_seq of shape (T,B,H,W), animate the hidden state for the sample
    `sample_idx` across timesteps T.

    - `y_seq` is a torch tensor; it may still require grad (it is detached
      before conversion to numpy).
    - `sample_idx` must satisfy 0 <= sample_idx < B (checked by an assert).
    - `interval` controls the animation speed (milliseconds between frames).
    - Sequences longer than 100 steps are subsampled to 100 evenly spaced
      frames; the `t` shown under the plot is then the index into the
      subsampled sequence. `plot_states_channel` in this notebook applies the
      same subsampling, so indices read off the animation can be used there.
    - returns: HTML object that, when displayed in Jupyter, shows the animation.
    """
    T, B, H, W = y_seq.shape
    assert 0 <= sample_idx < B, f"sample_idx must be in [0..{B-1}]"

    # Subsample to 100 frames if sequence is too long
    if T > 100:
        indices = np.linspace(0, T-1, 100, dtype=int)
        y_seq = y_seq[indices]
        T = 100

    # Frames for the chosen sample, shape (T,H,W). detach() lets tensors that
    # still carry autograd history be converted to numpy.
    y_seq_np = y_seq[:, sample_idx].detach().cpu().numpy()

    # Colour limits over the entire sequence give a stable colour scale
    vmin = y_seq_np.min()
    vmax = y_seq_np.max()

    fig, ax = plt.subplots()
    im = ax.imshow(y_seq_np[0], cmap='bwr', vmin=vmin, vmax=vmax)
    ax.set_title(f"Hidden state evolution (sample={sample_idx})")
    plt.colorbar(im, ax=ax)

    def animate(t):
        # Swap in frame t and show its index below the image
        im.set_array(y_seq_np[t])
        ax.set_xlabel(f"t = {t}")
        return [im]

    ani = animation.FuncAnimation(
        fig, animate,
        frames=T,
        interval=interval,
        blit=True
    )
    plt.close(fig)  # so that we don't get a duplicate static plot
    return HTML(ani.to_jshtml())

def plot_hidden(y, sample, channel, interval=200):
    """Animate one channel of one sample of a state sequence.

    Args:
        y: tensor whose first two axes are swapped before plotting and whose
            third axis is indexed by `channel` (presumably
            (B, T, C, H, W) — confirm against the model output).
        sample: sample index passed on as `sample_idx`.
        channel: index along the third axis.
        interval: milliseconds between animation frames. This is now
            forwarded; it used to be ignored in favour of a fixed 200.

    Returns:
        The HTML animation produced by `plot_hidden_state_video`.
    """
    y = torch.transpose(y, 0, 1).detach()
    return plot_hidden_state_video(y[:, :, channel], sample_idx=sample, interval=interval)

In [None]:
# Index (within the evaluation batch) of the sample used by the plots below
sample = 5

In [None]:
# Animation of channel 1 of the first model's states for the chosen sample
plot_hidden(states[0], sample=sample, channel=1)

In [None]:
# Animation of channel 1 of the second model's states for the chosen sample
plot_hidden(states[1], sample=sample, channel=1)

In [None]:
# Frame indices picked for plotting; assigned again with the same values in
# the CORNN plotting cell further down
ts2 = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]

# Generate sequential code

In [None]:
def plot_fft(the_f, sample):
    """Plot every (bin, channel) map of one sample as a K x C grid of images.

    Args:
        the_f: tensor indexed as [sample, K, C, H, W]; the original author
            noted an example shape of torch.Size([5, 151, 2, 64, 64]).
        sample: index along the first axis.
    """
    f_plot = the_f.detach().cpu().numpy()[sample]  # K x C x N x N
    K, C, H, W = f_plot.shape
    # squeeze=False keeps `axes` 2-D even when K == 1 or C == 1; without it
    # the axes[i][j] lookup fails for single-row or single-column grids
    fig, axes = plt.subplots(K, C, figsize=(C, K), squeeze=False)
    for i in range(K):
        for j in range(C):
            ax = axes[i, j]
            ax.imshow(f_plot[i][j], cmap='gray')
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title(f"K{i}, C{j}")
    fig.tight_layout()
    plt.show()

def plot_fft_channel(fft, bins, sample, channel, fpath):
    """Plot selected bins of one channel of one sample in a row and save it.

    Args:
        fft: tensor indexed as [sample, bin, channel, H, W].
        bins: sequence of bin indices to show, one panel each.
        sample: index along the first axis.
        channel: index along the channel axis.
        fpath: output path passed to `savefig`. When it is a filesystem path,
            missing parent directories are created first.
    """
    fft = fft.detach().cpu().numpy()[sample][:, channel]

    num_plots = len(bins)

    # squeeze=False keeps `axes` 2-D so a single bin also works
    fig, axes = plt.subplots(1, num_plots, figsize=(4 * num_plots, 4), squeeze=False)
    for ax, bin_idx in zip(axes[0], bins):
        ax.imshow(fft[bin_idx], cmap='gray', interpolation='bilinear')
        ax.set_xticks([])
        ax.set_yticks([])
    fig.tight_layout()

    # savefig does not create directories; do it here so a fresh checkout works
    if isinstance(fpath, (str, os.PathLike)):
        os.makedirs(os.path.dirname(os.fspath(fpath)) or ".", exist_ok=True)
    plt.savefig(fpath)
    plt.show()

def plot_states_channel(states, ts, fpath, global_scale=False, sample=0, channel=0, plot_name=None, cmap='gray'):
    """Plot selected frames of one channel of one sample in a row and save it.

    Args:
        states: tensor indexed as [sample, frame, channel, H, W].
        ts: sequence of frame indices to show. When the sequence has more
            than 100 frames it is first subsampled to 100 evenly spaced ones,
            and `ts` indexes the subsampled sequence.
        fpath: output path passed to `savefig` (dpi=300). When it is a
            filesystem path, missing parent directories are created first.
        global_scale: if True, all panels share colour limits taken from the
            min/max over every (subsampled) frame; otherwise each panel is
            scaled on its own.
        sample: index along the first axis.
        channel: index along the channel axis.
        plot_name: unused.
        cmap: matplotlib colormap name.
    """
    states = states.detach().cpu().numpy()
    states = states[sample][:, channel]

    # SUBSAMPLE
    T = len(states)
    if T > 100:
        indices = np.linspace(0, T-1, 100, dtype=int)
        states = states[indices]

    num_plots = len(ts)

    if global_scale:
        vmin, vmax = np.min(states), np.max(states)
    else:
        vmin, vmax = None, None

    # squeeze=False keeps `axes` 2-D so a single timestep also works
    fig, axes = plt.subplots(1, num_plots, figsize=(4 * num_plots, 4), squeeze=False)
    for ax, t in zip(axes[0], ts):
        # Hard-coded crop: only rows [:32] and columns [32:] of each frame are shown
        state_to_plot = states[t][:32, 32:]
        ax.imshow(state_to_plot, cmap=cmap, vmin=vmin, vmax=vmax, interpolation='bilinear')
        ax.set_xticks([])
        ax.set_yticks([])
    fig.tight_layout()

    # savefig does not create directories; do it here so a fresh checkout works
    if isinstance(fpath, (str, os.PathLike)):
        os.makedirs(os.path.dirname(os.fspath(fpath)) or ".", exist_ok=True)
    plt.savefig(fpath, dpi=300)
    plt.show()

"""
mask: n x n array of class index values for a single image
num_slots: integer
"""
def build_color_mask(mask, num_slots=6):
    """Convert an array of class indices into an RGB image.

    NOTE: this shadows the `build_color_mask` imported from `plotting`.

    Args:
        mask: integer array of class indices. The colour canvas takes the
            shape of `mask`, so rectangular and batched masks work as well
            as square ones.
        num_slots: class indices 0..num_slots-1 are coloured; the palette
            has 6 entries. Pixels holding any other value stay black.

    Returns:
        uint8 array of shape `mask.shape + (3,)`.
    """
    mask_colors = np.array([
        [0, 0, 0],        # Black
        [255, 0, 0],      # Red
        [255, 127, 0],    # Orange
        [255, 255, 0],    # Yellow
        [0, 255, 0],      # Green
        [0, 0, 255],      # Blue
    ])
    # Size the canvas from the full mask shape rather than its last dimension
    # only, which broke for any mask that was not square
    colored_mask = np.zeros(tuple(mask.shape) + (3,), dtype=np.uint8)
    # Assign colors to each pixel based on class
    for i in range(num_slots):
        colored_mask[mask == i] = mask_colors[i]
    return colored_mask

def plot_masks(masks, sample, fpath):
    """Colour one predicted mask, show it without ticks and save it.

    NOTE: this replaces the `plot_masks(masks, title)` defined in an earlier
    cell of this notebook (and the one imported from `plotting`).

    Args:
        masks: tensor whose first axis indexes samples; each entry holds
            class indices.
        sample: index of the mask to plot.
        fpath: output path passed to `savefig`. When it is a filesystem path,
            missing parent directories are created first.
    """
    mask = masks.detach().cpu().numpy()[sample]
    colored = build_color_mask(mask, 6)
    fig = plt.figure()
    plt.imshow(colored)
    plt.xticks([])
    plt.yticks([])

    # savefig does not create directories; do it here so a fresh checkout works
    if isinstance(fpath, (str, os.PathLike)):
        os.makedirs(os.path.dirname(os.fspath(fpath)) or ".", exist_ok=True)
    plt.savefig(fpath)
    plt.show()

# LSTM - LOOK AT FULL FOURIER

In [None]:
# Full grid of readout maps (all bins x channels) of the first model for the chosen sample
plot_fft(ffts[0], sample=sample)

In [None]:
# Save the first ten readout bins of channel 1 for the first model
channel = 1
b1 = list(range(10))
fp = f"results/fig_lstm_nwm/lstm_{sample}_{channel}"
plot_fft_channel(
    ffts[0],
    b1,
    fpath=fp + "_fft.pdf",
    sample=sample,
    channel=channel,
)

In [None]:
# Save every second frame (0, 2, ..., 18) of channel 1 for the first model
channel = 1
ts1 = list(range(0, 20, 2))
fp = f"results/fig_lstm_nwm/lstm_{sample}_{channel}"
plot_states_channel(
    states[0],
    ts1,
    fpath=fp + "_state.pdf",
    sample=sample,
    channel=channel,
    cmap='bwr',
    global_scale=True,
)

# CORNN - LOOK AT FULL FOURIER

In [None]:
# Grid of readout maps of the second model; only the first 51 entries along dim 1 are shown
plot_fft(ffts[1][:,0:51], sample=sample)

In [None]:
# Save readout bins (every 5th, 0..45) and frames (every 10th, 0..90) of
# channel 1 for the second model
channel = 1
b2 = list(range(0, 50, 5))
ts2 = list(range(0, 100, 10))

# Both figures share the same file-name prefix
fp = f"results/fig_lstm_nwm/cornn_{sample}_{channel}"
plot_fft_channel(
    ffts[1],
    b2,
    fpath=fp + "_fft.pdf",
    sample=sample,
    channel=channel,
)
plot_states_channel(
    states[1],
    ts2,
    fpath=fp + "_state.pdf",
    sample=sample,
    channel=channel,
    cmap='bwr',
    global_scale=True,
)

# Save GT Image


In [None]:
# First element of the stored batch (the model input), at the chosen sample index
gt_image = testsets['new_tetronimoes'][0][sample]

In [None]:
# Move the first axis last (presumably channel-first -> channel-last for
# imshow) and convert to numpy. The image is taken from the batch directly
# rather than by rebinding `gt_image` from tensor to ndarray, so re-running
# this cell no longer fails on an already-converted array.
gt_image = testsets['new_tetronimoes'][0][sample].permute(1, 2, 0).cpu().numpy()

In [None]:
# Draw the selected input image without axis ticks and save it as a PDF
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(gt_image)
ax.set_xticks([])
ax.set_yticks([])
fig.savefig(f"results/fig_lstm_nwm/sample{sample}_gt.pdf")