In [1]:
import yaml
from omegaconf import OmegaConf

import torch
from utils import make_model, set_random_seed, save_model, load_model
from trainer import train
from dataset import ShapeDataset, load_data
from dataset_config import DATASET_CONFIG

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import torchvision
import torchvision.transforms as transforms



import torch.nn.functional as F

from sklearn.cluster import KMeans
import fastcluster
from scipy.cluster.hierarchy import fcluster

from sklearn.metrics import adjusted_rand_score
from sklearn.metrics import jaccard_score

import math

import matplotlib.pyplot as plt
from plotting import plot_phases, plot_results, plot_eval, plot_fourier, plot_phases2, plot_masks, plot_slots, build_color_mask, plot_clusters, plot_clusters2

from loss_metrics import get_ar_metrics, compute_pixelwise_accuracy, compute_iou
from loss_metrics import compute_pixelwise_accuracy

import os
import glob
import numpy as np
import imageio
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from IPython.display import display
import ipywidgets as widgets

# Data Paths

In [2]:
# Function to load a YAML file
def load_yaml_file(file_path):
    """Parse the YAML file at ``file_path`` and return its top-level ``'params'`` entry.

    The file is parsed with ``yaml.safe_load``; a document without a
    ``'params'`` key raises ``KeyError``.
    """
    with open(file_path, 'r') as stream:
        parsed = yaml.safe_load(stream)
    return parsed['params']

In [3]:
# Experiment grid: one run directory per (dataset, model type, seed) triple.
datasets = ["mnist", "new_tetronimoes"]
model_types = ["baseline3"]
seed_folders = ["1", "2", "3"]

# Relative sub-paths of the form "<dataset>/<model type>/<seed>/", with the
# dataset varying slowest and the seed fastest.
extensions = [
    f"{dataset_name}/{model_name}/{seed_name}/"
    for dataset_name in datasets
    for model_name in model_types
    for seed_name in seed_folders
]

In [4]:
extensions

['mnist/baseline3/1/',
 'mnist/baseline3/2/',
 'mnist/baseline3/3/',
 'new_tetronimoes/baseline3/1/',
 'new_tetronimoes/baseline3/2/',
 'new_tetronimoes/baseline3/3/']

In [5]:
#folder = '/n/ba_lab/Everyone/mjacobs/projects/shape-object_store/realign3/'
# Root directory of the experiment runs; every entry of `extensions` is
# appended to it when searching for run configs.
folder_new = 'experiments/'
# Relative location of the Hydra config inside a run directory.
# NOTE(review): the config search spells out '.hydra' / 'config.yaml' itself
# and does not read this variable.
hydra_config_file = '.hydra/config.yaml'

In [6]:
# Locate the Hydra config of every run below the selected experiment
# directories. `fpaths[k]` is the config file of run k and `folders[k]` is the
# run directory that contains its `.hydra/` folder (with a trailing "/").
fpaths = []
folders = []
for ext in extensions:
    run_root = folder_new + ext
    search_pattern = os.path.join(run_root, '**', '.hydra', 'config.yaml')
    # glob returns matches in arbitrary, filesystem-dependent order. Sort them
    # so the run order (and therefore the row order of the positionally saved
    # score arrays) is reproducible across machines and re-runs.
    for file_path in sorted(glob.iglob(search_pattern, recursive=True)):
        # Skip anything that matches the pattern but is not a regular file.
        if os.path.isfile(file_path):
            fpaths.append(file_path)
            folders.append(os.path.dirname(os.path.dirname(file_path)) + "/")

In [7]:
fpaths

['/n/ba_lab/Everyone/mjacobs/projects/shape-object_store/realign4/mnist/baseline3/1/lstm_100/.hydra/config.yaml',
 '/n/ba_lab/Everyone/mjacobs/projects/shape-object_store/realign4/mnist/baseline3/1/rnn_100/.hydra/config.yaml',
 '/n/ba_lab/Everyone/mjacobs/projects/shape-object_store/realign4/mnist/baseline3/2/lstm_100/.hydra/config.yaml',
 '/n/ba_lab/Everyone/mjacobs/projects/shape-object_store/realign4/mnist/baseline3/2/rnn_100/.hydra/config.yaml',
 '/n/ba_lab/Everyone/mjacobs/projects/shape-object_store/realign4/mnist/baseline3/3/lstm_100/.hydra/config.yaml',
 '/n/ba_lab/Everyone/mjacobs/projects/shape-object_store/realign4/mnist/baseline3/3/rnn_100/.hydra/config.yaml',
 '/n/ba_lab/Everyone/mjacobs/projects/shape-object_store/realign4/new_tetronimoes/baseline3/1/lstm_100/.hydra/config.yaml',
 '/n/ba_lab/Everyone/mjacobs/projects/shape-object_store/realign4/new_tetronimoes/baseline3/1/rnn_100/.hydra/config.yaml',
 '/n/ba_lab/Everyone/mjacobs/projects/shape-object_store/realign4/new_te

In [8]:
# Parsed `params` section of each discovered config, in the same order as `fpaths`.
configs = list(map(load_yaml_file, fpaths))

In [9]:
from dataset import MNISTSegmentationDataset, load_new_tetrominoes, ShapeDataset
from torchvision import datasets as torchvision_datasets

def load_data(dataset, data_config, num_train=None, num_test=None, scale_min=0.7, transform_set='set1', normalize=True):
    """Return the test split of ``dataset``.

    This notebook-local definition shadows the ``load_data`` imported from
    ``dataset`` at the top of the notebook and only loads test data.

    Args:
        dataset (str): ``'mnist'`` or ``'new_tetronimoes'``.
        data_config (dict): Dataset settings. For ``'mnist'`` the keys
            ``'test_path'`` and ``'img_size'`` are read; for
            ``'new_tetronimoes'`` the keys ``'x_test_path'`` and
            ``'y_test_path'`` are read.
        num_train, num_test, scale_min, transform_set, normalize: Accepted but
            not used by this function (presumably kept so the signature lines
            up with the shadowed loader -- TODO confirm).

    Returns:
        ``MNISTSegmentationDataset`` wrapping the torchvision MNIST test set,
        or a ``ShapeDataset`` built from the tetromino test arrays.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names. The
            function previously fell through and returned ``None``, which only
            surfaced later as an unrelated error.
    """
    if dataset == 'mnist':
        # download=False: the MNIST files must already exist under test_path.
        testset = torchvision_datasets.MNIST(data_config['test_path'], train=False, download=False,
                                             transform=transforms.ToTensor())
        return MNISTSegmentationDataset(testset, data_config['img_size'])
    if dataset == 'new_tetronimoes':
        x, y = load_new_tetrominoes(data_config['x_test_path'], data_config['y_test_path'])
        return ShapeDataset(x, y)
    raise ValueError(f"Unsupported dataset: {dataset!r} (expected 'mnist' or 'new_tetronimoes')")

In [10]:
 # Setup
# Fixed seed handed to the project's `set_random_seed` helper below.
seed = 7
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_random_seed(seed)

# Load data
# One test set per dataset, each built from that dataset's entry in DATASET_CONFIG.
data_config1 = DATASET_CONFIG['new_tetronimoes']
data_config2 = DATASET_CONFIG['mnist']
testset1 = load_data('new_tetronimoes', data_config1)
testset2 = load_data('mnist', data_config2)

# Load models

In [11]:
# Load saved models
def load_model(cp_folder, config, device, data_config):
    """Build the network described by ``config`` and restore its checkpoint.

    This notebook-local definition shadows the ``load_model`` imported from
    ``utils`` at the top of the notebook.

    Args:
        cp_folder (str): Run directory that contains the checkpoint ``cp.pt``.
        config (dict): Run parameters (the ``params`` section of the run's
            config). ``'cell_type'`` is optional; the time step is read from
            ``'dt1'`` when present and from ``'dt'`` otherwise.
        device (torch.device): Device the checkpoint is mapped to and the
            model is moved to.
        data_config (dict): Mapping from dataset name to dataset settings; the
            entry for ``config['dataset']`` supplies ``'channels'`` and
            ``'img_size'``.

    Returns:
        The model in eval mode on ``device``.
    """
    import warnings

    data_config = data_config[config['dataset']]
    cell_type = config.get('cell_type')
    # Only fall back to 'dt' when 'dt1' is absent, so configs that define just
    # one of the two keys keep working.
    dt = config['dt1'] if 'dt1' in config else config['dt']

    # Make model
    net = make_model(
        device,
        config['block_type'],
        config['model_type'],
        config['oscillator_type'],
        config['num_classes'],
        config['N'],
        config['M'],
        dt,
        config['min_iters'],
        config['max_iters'],
        data_config['channels'],
        config['hidden_channels'],
        config['rnn_kernel'],
        config['num_blocks'],
        config['num_slots'],
        config['num_iters'],
        data_config['img_size'],
        config['kernel_init'],
        cell_type=cell_type,
        num_layers=config['num_layers'],
    )

    # os.path.join works whether or not cp_folder ends with a separator.
    checkpoint_path = os.path.join(cp_folder, "cp.pt")
    # map_location lets checkpoints saved on a GPU load on CPU-only machines.
    state_dict = torch.load(checkpoint_path, map_location=device)
    # strict=False tolerates key mismatches; report them, because a checkpoint
    # that does not match the architecture leaves those weights at their
    # initial values and the evaluation would silently score an untrained model.
    load_result = net.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {checkpoint_path} does not fully match the model: "
            f"missing keys {list(load_result.missing_keys)}, "
            f"unexpected keys {list(load_result.unexpected_keys)}"
        )
    net.eval()
    return net.to(device)

In [12]:
# One restored network per discovered run folder, paired with the config at the same index.
models = [
    load_model(run_folder, configs[idx], device, DATASET_CONFIG)
    for idx, run_folder in enumerate(folders)
]

# More score functions

In [13]:
def calculate_acc(predictions, ground_truth, ignore_class=0):
    """
    Sum over the batch of each image's pixel-wise accuracy on non-ignored pixels.

    For every image, accuracy is the fraction of pixels whose ground-truth label
    differs from ``ignore_class`` and whose prediction equals the ground truth.
    The per-image values are summed (not averaged) so a caller can accumulate
    over batches and divide by the dataset size once.

    Args:
        predictions (torch.Tensor): Predicted segmentation masks, shape (B, H, W)
        ground_truth (torch.Tensor): Ground truth segmentation masks, shape (B, H, W)
        ignore_class (int, optional): Class to ignore in accuracy calculation. Default is 0.

    Returns:
        numpy scalar: Sum of the per-image accuracies. An image without any
        non-ignored pixel contributes 0 instead of NaN.

    Raises:
        ValueError: If the two tensors differ in shape.
    """
    # Ensure the tensors are of the same shape
    if predictions.shape != ground_truth.shape:
        raise ValueError("Shape of predictions and ground_truth must match.")

    # Pixels that take part in the score: everything except ignore_class.
    mask = ground_truth != ignore_class  # Shape: (B, H, W)

    # Correct predictions restricted to the scored pixels.
    correct = (predictions == ground_truth) & mask  # Shape: (B, H, W)

    # reshape (rather than view) also accepts non-contiguous inputs.
    correct_per_image = correct.reshape(correct.size(0), -1).sum(dim=1).float()  # Shape: (B,)
    total_per_image = mask.reshape(mask.size(0), -1).sum(dim=1).float()  # Shape: (B,)

    # An image with no scored pixel has correct == total == 0. Clamping the
    # denominator turns the former 0/0 = NaN (which poisoned the whole sum)
    # into 0 and leaves every other image unchanged.
    accuracy = correct_per_image / total_per_image.clamp(min=1)  # Shape: (B,)
    return np.sum(accuracy.cpu().numpy())

In [14]:
def calc_iou(predictions, ground_truth, ignore_class=0, num_classes=None):
    """
    Sum over the batch of each image's Intersection over Union (IoU).

    Pixels whose ground-truth label equals ``ignore_class`` are removed from
    both masks. For each image the intersection and union counts are then
    pooled over all classes and pixels before dividing, i.e. one IoU value per
    image rather than a mean over per-class IoUs. The per-image values are
    summed (not averaged) so a caller can accumulate over batches and divide by
    the dataset size once.

    Args:
        predictions (torch.Tensor): Predicted segmentation masks, shape (B, H, W)
        ground_truth (torch.Tensor): Ground truth segmentation masks, shape (B, H, W)
        ignore_class (int, optional): Class to ignore in IoU calculation. Default is 0.
        num_classes (int, optional): Total number of classes. If None, inferred from data.

    Returns:
        numpy scalar: Sum of the per-image IoU values. An image without any
        non-ignored pixel contributes 0 instead of NaN.

    Raises:
        ValueError: If the two tensors differ in shape.
    """
    if predictions.shape != ground_truth.shape:
        raise ValueError("Shape of predictions and ground_truth must match.")

    if num_classes is None:
        num_classes = int(max(predictions.max(), ground_truth.max()) + 1)

    # Keep only pixels whose ground truth is not ignore_class; the extra axis
    # lets the mask broadcast over the class dimension.
    mask = (ground_truth != ignore_class).unsqueeze(1)  # (B, 1, H, W)

    # One-hot encode and move the class axis to dim 1: (B, H, W, C) -> (B, C, H, W)
    predictions_one_hot = torch.nn.functional.one_hot(predictions, num_classes=num_classes).permute(0, 3, 1, 2)
    ground_truth_one_hot = torch.nn.functional.one_hot(ground_truth, num_classes=num_classes).permute(0, 3, 1, 2)

    # Zero out every class channel at ignored pixels.
    predictions_one_hot = predictions_one_hot * mask
    ground_truth_one_hot = ground_truth_one_hot * mask

    intersection = (predictions_one_hot & ground_truth_one_hot).sum((1, 2, 3))  # (B,)
    union = (predictions_one_hot | ground_truth_one_hot).sum((1, 2, 3))  # (B,)

    # When an image has no scored pixel both counts are 0. Clamping the
    # denominator turns the former 0/0 = NaN (which poisoned the whole sum)
    # into 0 and leaves every other image unchanged.
    iou = intersection / union.clamp(min=1)
    return np.sum(iou.cpu().numpy())

In [15]:
def calc_ari(predictions, ground_truth, ignore_class=0):
    """
    Sum over the batch of each image's Adjusted Rand Index (ARI).

    Each image is flattened, pixels whose ground-truth label equals
    ``ignore_class`` are dropped, and the remaining labels are scored with
    ``sklearn.metrics.adjusted_rand_score``.

    Args:
        predictions (torch.Tensor): Predicted segmentation masks, shape (B, H, W)
        ground_truth (torch.Tensor): Ground truth segmentation masks, shape (B, H, W)
        ignore_class (int, optional): Class to ignore in ARI calculation. Default is 0.

    Returns:
        numpy scalar: Sum of the per-image ARI values.
    """
    from sklearn.metrics import adjusted_rand_score

    per_image_ari = []
    for idx in range(predictions.shape[0]):
        flat_pred = predictions[idx].flatten()
        flat_gt = ground_truth[idx].flatten()

        # Score only the pixels whose ground truth is not the ignored class.
        keep = flat_gt != ignore_class

        # sklearn works on numpy arrays, so move the kept labels to the CPU.
        labels_pred = flat_pred[keep].cpu().numpy()
        labels_true = flat_gt[keep].cpu().numpy()

        per_image_ari.append(adjusted_rand_score(labels_true, labels_pred))

    return np.sum(per_image_ari)

# Evaluate Scores

In [16]:
def eval_scores(net, testset, device, batch_size, num_classes, ignore_class=0):
    """Evaluate ``net`` on ``testset`` and return dataset-level averages.

    Args:
        net: Model called as ``net(x)``; its output is fed to
            ``nn.CrossEntropyLoss`` and arg-maxed over dim 1 (assumed to be
            per-class scores of shape (B, C, H, W) -- TODO confirm).
        testset: Dataset yielding ``(image, target)`` pairs.
        device (torch.device): Device the batches are moved to.
        batch_size (int): Evaluation batch size.
        num_classes (int): Number of classes, forwarded to ``calc_iou``.
        ignore_class (int, optional): Label excluded from the accuracy and IoU
            computations. The cross-entropy loss is computed over all pixels.

    Returns:
        tuple: ``(loss, iou, acc)``, each averaged over ``len(testset)``.
        ARI is not computed.
    """
    loss_func = nn.CrossEntropyLoss()
    net.eval()
    total_loss = 0
    total_acc = 0
    total_iou = 0

    # Every total below is a sum over all samples, so the visiting order does
    # not matter; shuffling an evaluation set only makes runs harder to
    # reproduce and consumes global RNG state.
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, drop_last=False)
    with torch.no_grad():
        for x, x_target in testloader:
            x_target = x_target.to(device).type(torch.long)
            x = x.to(device)
            b = x.size(0)
            x_pred_classifier = net(x)

            # CrossEntropyLoss averages over the batch by default; weight it
            # by the batch size so a smaller last batch is averaged correctly.
            loss = loss_func(x_pred_classifier, x_target).item()

            # Hard class prediction per pixel.
            x_pred_classifier = torch.argmax(x_pred_classifier, dim=1)

            # Both helpers return sums over the batch, not means.
            acc = calculate_acc(x_pred_classifier, x_target, ignore_class=ignore_class)
            iou = calc_iou(x_pred_classifier, x_target, num_classes=num_classes, ignore_class=ignore_class)

            total_loss += loss * b
            total_acc += acc
            total_iou += iou

    num_samples = len(testset)
    loss = total_loss / num_samples
    iou = total_iou / num_samples
    acc = total_acc / num_samples

    return loss, iou, acc

In [17]:
# Test split to use for a run, keyed by the dataset name stored in the run's config.
testsets = dict(new_tetronimoes=testset1, mnist=testset2)
# Class label excluded from the accuracy / IoU computations.
ignore_class = 0

In [18]:
# Score every loaded model on the test split of the dataset named in its config.
# Each entry of `scores` is the (loss, iou, acc) tuple returned by eval_scores,
# in the same order as `models` / `configs`.
#
# A partial copy is written once half of the models are done so that a crash
# late in the run does not lose everything. The index is derived from the
# number of models (it equals the former hard-coded 24 for 50 models); with
# the fixed value the checkpoint was never written for shorter model lists.
halfway_index = len(models) // 2 - 1

scores = []
for i, net in enumerate(models):
    print(f"{i + 1}/{len(models)}")
    config = configs[i]
    score = eval_scores(net, testsets[config['dataset']], device, batch_size=64, num_classes=config['num_classes'], ignore_class=ignore_class)
    scores.append(score)
    if i == halfway_index:
        np.save('scores_halfway_forgot.npy', np.array(scores))
scores = np.array(scores)
np.save("scores_forgot.npy", scores)

1/12
2/12
3/12
4/12
5/12
6/12
7/12
8/12
9/12
10/12
11/12
12/12


In [19]:
# Group the per-model score rows by dataset and by model variant, so that the
# runs of one configuration (e.g. its different seeds) end up in the same lists:
#   all_scores[dataset][variant] = {'loss': [...], 'iou': [...], 'acc': [...]}
all_scores = {}
for i in range(len(models)):
    config, score = configs[i], scores[i]
    dataset = config['dataset']
    model_type = config['model_type']

    # baseline3 variants are further split by cell type (e.g. lstm vs rnn),
    # baseline1_flexible by its number of layers; all others keep their name.
    if model_type in ('baseline3_fft', 'baseline3'):
        cell_type = config['cell_type']
        full_model_type = f"{model_type}-{cell_type}"
    elif model_type == 'baseline1_flexible':
        num_layers = config['num_layers']
        full_model_type = f"{model_type}-{num_layers}"
    else:
        full_model_type = model_type

    metric_lists = all_scores.setdefault(dataset, {}).setdefault(
        full_model_type, {'loss': [], 'iou': [], 'acc': []}
    )
    # A score row is ordered (loss, iou, acc); ARI is not part of it.
    for position, metric in enumerate(('loss', 'iou', 'acc')):
        metric_lists[metric].append(score[position])

In [20]:
# Collapse each list of per-run values into a (mean, std) pair:
#   score_stats[dataset][variant][metric] = (mean, std)
score_stats = {
    dataset: {
        model_type: {
            metric: (np.mean(values), np.std(values))
            for metric, values in curr_scores.items()
        }
        for model_type, curr_scores in per_model.items()
    }
    for dataset, per_model in all_scores.items()
}

In [21]:
score_stats

{'mnist': {'baseline3-lstm': {'loss': (0.6948826508522034,
    7.92936117696532e-05),
   'iou': (0.0, 0.0),
   'acc': (0.0, 0.0)},
  'baseline3-rnn': {'loss': (0.6949432949384055, 3.315571287538999e-05),
   'iou': (0.0, 0.0),
   'acc': (0.0, 0.0)}},
 'new_tetronimoes': {'baseline3-lstm': {'loss': (0.6459340534210205,
    0.00031856562923701377),
   'iou': (0.0, 0.0),
   'acc': (0.0, 0.0)},
  'baseline3-rnn': {'loss': (0.6455641514460245, 0.00013795588264729538),
   'iou': (0.0, 0.0),
   'acc': (0.0, 0.0)}}}

In [22]:
import json

In [23]:
# Persist the summary statistics; the (mean, std) tuples are written as JSON arrays.
with open("score_stats_forgot.json", 'w') as stats_file:
    json.dump(score_stats, stats_file, indent=4)