In [None]:
import os, sys

sys.path.append("../") # add base goldiprox-hydra folder to the path, so can import things. 

import torch
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
from src.datamodules.datamodules import CIFAR10DataModule, CIFAR10_100MergedDataModule
from src.datamodules.datasets.sequence_datasets import indices_CIFAR10, indices_CIFAR10_100_merged

In [None]:
# Build a CIFAR-10 datamodule with valset_fraction=0.5. The next cell rebuilds the same object, so this one is only a quick smoke test.
datamodule = CIFAR10DataModule(batch_size=100, valset_fraction=0.5, data_dir="~/workspace/data")

In [None]:
# Report the overlap between the train and validation subsets (should be 0 for a disjoint split) and the size of each subset.
valset_fraction = 0.5
print(f"valset_fraction={valset_fraction}")
datamodule = CIFAR10DataModule(batch_size=100, valset_fraction=valset_fraction, data_dir="~/workspace/data")
train_members = set(datamodule.train_subset)
val_members = set(datamodule.val_subset)
print(f"Intersection of train and validation subset: {len(train_members & val_members)} datapoints")
print(f"Size of train subset: {len(train_members)} datapoints")
print(f"Size of validation subset: {len(val_members)} datapoints")

In [None]:
# Same split check as above, with valset_fraction=0.75.
valset_fraction = 0.75
print(f"valset_fraction={valset_fraction}")
datamodule = CIFAR10DataModule(batch_size=100, valset_fraction=valset_fraction, data_dir="~/workspace/data")
train_members = set(datamodule.train_subset)
val_members = set(datamodule.val_subset)
print(f"Intersection of train and validation subset: {len(train_members & val_members)} datapoints")
print(f"Size of train subset: {len(train_members)} datapoints")
print(f"Size of validation subset: {len(val_members)} datapoints")

In [None]:
# Same split check as above, with valset_fraction=0.25. `datamodule` keeps this configuration for the following cells.
valset_fraction = 0.25
print(f"valset_fraction={valset_fraction}")
datamodule = CIFAR10DataModule(batch_size=100, valset_fraction=valset_fraction, data_dir="~/workspace/data")
train_members = set(datamodule.train_subset)
val_members = set(datamodule.val_subset)
print(f"Intersection of train and validation subset: {len(train_members & val_members)} datapoints")
print(f"Size of train subset: {len(train_members)} datapoints")
print(f"Size of validation subset: {len(val_members)} datapoints")

In [None]:
# Train-subset size for the most recently built datamodule (valset_fraction=0.25 if the cells above ran in order).
len(datamodule.train_subset)

In [None]:
# Validation-subset size for the most recently built datamodule.
len(datamodule.val_subset)

In [None]:
# Display the train subset itself. NOTE(review): this may render the full contents — consider showing only a slice.
datamodule.train_subset

# First, set up the vision dataset so we can get full information about the dataset

In [None]:
# Merged CIFAR-10/CIFAR-100 datamodule. Note that `datamodule` is rebound to a plain CIFAR-10 datamodule two cells below.
datamodule = CIFAR10_100MergedDataModule(batch_size=100)

In [None]:
# Run the merged datamodule's setup step (presumably loads/splits the data — TODO confirm against the datamodule).
datamodule.setup()

In [None]:
# Build the CIFAR-10 train-split vision dataset using the datamodule's transform. Later cells take item [1] of a lookup as the image to display.
# The data root is expanded from "~" (matching the data_dir used above) instead of a user-specific absolute path.
datamodule = CIFAR10DataModule(batch_size=100)
CIFAR10_vision_dataset = indices_CIFAR10(root=os.path.expanduser("~/workspace/data"), train=True, transform=datamodule.transform)

In [None]:
def imshow(img, title=None):
    """Plot a channels-first image tensor with matplotlib, optionally titled.

    The pixel values are mapped with x / 2 + 0.5 to undo a mean-0.5 / std-0.5
    normalization (assumed to match the dataset transform — TODO confirm). The
    array is then permuted from (C, H, W) to (H, W, C) for ``plt.imshow``.

    Args:
        img: Tensor-like with a ``.numpy()`` method, laid out channels-first.
        title: Optional title string for the current axes.
    """
    channels_last = np.transpose((img / 2 + 0.5).numpy(), (1, 2, 0))
    plt.imshow(channels_last)
    if title is None:
        return
    plt.title(title)

In [None]:
# Quick look at one training image (dataset item 75; element [1] of the item is the image).
imshow(CIFAR10_vision_dataset[75][1])

# Load Irreducible Losses – note that these irreducible losses are from a pretrained model. 

In [None]:
# Load the pretrained-model irreducible losses together with the control images/indices used in the consistency check below.
# torch.load unpickles the file, so only load files you trust.
irred_losses_dict = torch.load("cifar10_irred_losses/irred_losses_and_checks.pt")

# Check images are consistent

In [None]:
# Side by side: the dataset image at the stored control index (left) and the stored control image (right). They should match.
control_idx = irred_losses_dict["idx_of_control_images"][-1]
plt.figure(figsize=(6, 3), dpi=300)

plt.subplot(121)
imshow(CIFAR10_vision_dataset[control_idx][1])

plt.subplot(122)
imshow(irred_losses_dict["control_images"][-1])

# Investigate distribution of irreducible losses

In [None]:
# CIFAR-10 class names in the standard torchvision label order. They are defined here because this is the first cell that
# maps integer targets to names; previously `classes` only existed after a later cell ran.
# Assumes "sorted_targets" uses the standard CIFAR-10 label indices — TODO confirm.
classes = ("airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

irred_losses_all = irred_losses_dict['irreducible_losses'].numpy()
global_indices = np.arange(irred_losses_all.size)
targets = irred_losses_dict["sorted_targets"].numpy()
targets_str = np.array([classes[i] for i in targets])

# Treat exactly-0 and NaN losses as missing (presumably points that were not scored — TODO confirm).
# This matches the validity mask used by `filter_irred_losses` later in the notebook.
valid_irred_loss_mask = (irred_losses_all != 0) & ~np.isnan(irred_losses_all)

irred_losses = irred_losses_all[valid_irred_loss_mask]
irred_losses_global_idxs = global_indices[valid_irred_loss_mask]
targets = targets[valid_irred_loss_mask]
targets_str = targets_str[valid_irred_loss_mask]

In [None]:
# Histogram of irreducible losses on a log x-axis, with a rug of the individual points drawn at y=-100 (below the count bars).
plt.figure(figsize=(4, 3), dpi=300)
# log_scale=True makes the bins uniform in log space. Equal-width linear bins shown on a log axis
# would lump the small losses, which span several decades, into the first bin or two.
sns.histplot(irred_losses, log_scale=True)
plt.scatter(irred_losses, -100*np.ones_like(irred_losses), s=3, color="tab:blue", marker='d', alpha=0.3)
plt.xscale("log")

In [None]:
# 5x5 grid of the points with the highest irreducible loss, each titled with its class name and loss.
plt.figure(figsize=(8, 8), dpi=300)

highest_first = np.argsort(-irred_losses)[:25]
for panel, idx in enumerate(highest_first, start=1):
    plt.subplot(5, 5, panel)
    target = targets_str[idx]
    irred_loss = irred_losses[idx]
    image = CIFAR10_vision_dataset[irred_losses_global_idxs[idx]][1]
    imshow(image, f"{target}\n: Irred Loss: {irred_loss:.2f}")

plt.suptitle("Points with highest irred loss")
plt.tight_layout()

In [None]:
# 5x5 grid of the points with the lowest irreducible loss, each titled with its class name and loss.
plt.figure(figsize=(8, 8), dpi=300)

lowest_first = np.argsort(irred_losses)[:25]
for panel, idx in enumerate(lowest_first, start=1):
    plt.subplot(5, 5, panel)
    target = targets_str[idx]
    irred_loss = irred_losses[idx]
    image = CIFAR10_vision_dataset[irred_losses_global_idxs[idx]][1]
    imshow(image, f"{target}\n: Irred Loss: {irred_loss:.2g}")

plt.suptitle("Points with lowest irred loss")
plt.tight_layout()

In [None]:
# One violin per class showing the spread of (natural-)log irreducible loss.
log_irred_losses = np.log(irred_losses)
plt.figure(figsize=(4, 5), dpi=300)
sns.violinplot(x=log_irred_losses, y=targets_str, cut=0, inner=None)
plt.title("Irreducible Losses by Class: All")
plt.xlabel("Log irreducible loss")

In [None]:
def percentile_plot(percentile, irred_losses, targets_str):
    """Bar-plot the class breakdown (in %) of points whose loss exceeds a quantile.

    Args:
        percentile: Quantile as a fraction in [0, 1], e.g. 0.9 for the 90th
            percentile. The callers in this notebook pass fractions, and the
            title multiplies by 100.
        irred_losses: 1-D array of irreducible losses.
        targets_str: Array of class names aligned element-wise with ``irred_losses``.
    """
    # np.percentile expects q in [0, 100], so passing 0.9 would threshold at the 0.9th percentile.
    # np.quantile takes the fraction directly, which matches the title below.
    q = np.quantile(irred_losses, percentile)
    mask = irred_losses > q

    classes, counts = np.unique(targets_str[mask], return_counts=True)
    plt.bar(classes, 100*counts/np.sum(counts))
    plt.title(f"{percentile*100:.0f}th Percentile")
    plt.ylabel("%")
    # Pass a number: newer matplotlib only accepts a number or 'vertical'/'horizontal', not "-90".
    plt.xticks(rotation=-90)

In [None]:
# Rebinds the global `classes` to the sorted unique class names in `targets_str`, along with their counts.
# Later cells index `classes` by integer label. That only lines up if the sorted names match label order,
# which holds for the standard CIFAR-10 names.
classes, counts = np.unique(targets_str, return_counts=True)

In [None]:
# 2x2 grid: class breakdown of the points above each quantile threshold (given as fractions).
percentiles = [0, 0.9, 0.95, 0.99]
plt.figure(figsize=(8, 8), dpi=300)

for panel, p in enumerate(percentiles, start=1):
    plt.subplot(2, 2, panel)
    percentile_plot(p, irred_losses, targets_str)

plt.tight_layout()

# Load Irreducible Losses – note that these irreducible losses are from a model trained on the validation set. 

In [None]:
# Load the "valtrain" irreducible losses and their control images/indices. This overwrites the pretrained-model dict loaded earlier.
# torch.load unpickles the file, so only load files you trust.
irred_losses_dict = torch.load("cifar10_irred_losses/irred_losses_and_checks_valtrain.pt")

# Check images are consistent

In [None]:
# Side by side: the dataset image at the stored control index (left) and the stored control image (right). They should match.
control_idx = irred_losses_dict["idx_of_control_images"][-1]
plt.figure(figsize=(6, 3), dpi=300)

plt.subplot(121)
imshow(CIFAR10_vision_dataset[control_idx][1])

plt.subplot(122)
imshow(irred_losses_dict["control_images"][-1])

# Investigate distribution of irreducible losses

In [None]:
# Relies on `classes` from an earlier cell to map integer targets to class names.
irred_losses_all = irred_losses_dict['irreducible_losses'].numpy()
global_indices = np.arange(irred_losses_all.size)
targets = irred_losses_dict["sorted_targets"].numpy()
targets_str = np.array([classes[i] for i in targets])

# Treat exactly-0 and NaN losses as missing (presumably points that were not scored — TODO confirm).
# This matches the validity mask used by `filter_irred_losses` later in the notebook.
valid_irred_loss_mask = (irred_losses_all != 0) & ~np.isnan(irred_losses_all)

irred_losses = irred_losses_all[valid_irred_loss_mask]
irred_losses_global_idxs = global_indices[valid_irred_loss_mask]
targets = targets[valid_irred_loss_mask]
targets_str = targets_str[valid_irred_loss_mask]

In [None]:
# Histogram of irreducible losses on a log x-axis, with a rug of the individual points drawn at y=-100 (below the count bars).
plt.figure(figsize=(4, 3), dpi=300)
# log_scale=True makes the bins uniform in log space. Equal-width linear bins shown on a log axis
# would lump the small losses, which span several decades, into the first bin or two.
sns.histplot(irred_losses, log_scale=True)
plt.scatter(irred_losses, -100*np.ones_like(irred_losses), s=3, color="tab:blue", marker='d', alpha=0.3)
plt.xscale("log")

In [None]:
# 5x5 grid of the points with the highest irreducible loss, each titled with its class name and loss.
plt.figure(figsize=(8, 8), dpi=300)

highest_first = np.argsort(-irred_losses)[:25]
for panel, idx in enumerate(highest_first, start=1):
    plt.subplot(5, 5, panel)
    target = targets_str[idx]
    irred_loss = irred_losses[idx]
    image = CIFAR10_vision_dataset[irred_losses_global_idxs[idx]][1]
    imshow(image, f"{target}\n: Irred Loss: {irred_loss:.2f}")

plt.suptitle("Points with highest irred loss")
plt.tight_layout()

In [None]:
# 5x5 grid of the points with the lowest irreducible loss, each titled with its class name and loss.
plt.figure(figsize=(8, 8), dpi=300)

lowest_first = np.argsort(irred_losses)[:25]
for panel, idx in enumerate(lowest_first, start=1):
    plt.subplot(5, 5, panel)
    target = targets_str[idx]
    irred_loss = irred_losses[idx]
    image = CIFAR10_vision_dataset[irred_losses_global_idxs[idx]][1]
    imshow(image, f"{target}\n: Irred Loss: {irred_loss:.2g}")

plt.suptitle("Points with lowest irred loss")
plt.tight_layout()

In [None]:
# One violin per class showing the spread of (natural-)log irreducible loss.
log_irred_losses = np.log(irred_losses)
plt.figure(figsize=(4, 5), dpi=300)
sns.violinplot(x=log_irred_losses, y=targets_str, cut=0, inner=None)
plt.title("Irreducible Losses by Class: All")
plt.xlabel("Log irreducible loss")

In [None]:
def percentile_plot(percentile, irred_losses, targets_str):
    """Bar-plot the class breakdown (in %) of points whose loss exceeds a quantile.

    This is the same helper defined earlier in the notebook; this definition shadows it.

    Args:
        percentile: Quantile as a fraction in [0, 1], e.g. 0.9 for the 90th
            percentile. The callers in this notebook pass fractions, and the
            title multiplies by 100.
        irred_losses: 1-D array of irreducible losses.
        targets_str: Array of class names aligned element-wise with ``irred_losses``.
    """
    # np.percentile expects q in [0, 100], so passing 0.9 would threshold at the 0.9th percentile.
    # np.quantile takes the fraction directly, which matches the title below.
    q = np.quantile(irred_losses, percentile)
    mask = irred_losses > q

    classes, counts = np.unique(targets_str[mask], return_counts=True)
    plt.bar(classes, 100*counts/np.sum(counts))
    plt.title(f"{percentile*100:.0f}th Percentile")
    plt.ylabel("%")
    # Pass a number: newer matplotlib only accepts a number or 'vertical'/'horizontal', not "-90".
    plt.xticks(rotation=-90)

In [None]:
# Rebinds the global `classes` to the sorted unique class names in `targets_str`, along with their counts.
# `filter_irred_losses` below reads `classes` by integer label. That only lines up if the sorted names match label order.
classes, counts = np.unique(targets_str, return_counts=True)

In [None]:
# 2x2 grid: class breakdown of the points above each quantile threshold (given as fractions).
percentiles = [0, 0.9, 0.95, 0.99]
plt.figure(figsize=(8, 8), dpi=300)

for panel, p in enumerate(percentiles, start=1):
    plt.subplot(2, 2, panel)
    percentile_plot(p, irred_losses, targets_str)

plt.tight_layout()

# Compare Irreducible Losses

In [None]:
def filter_irred_losses(irred_losses_dict, class_names=None):
    """Drop points whose irreducible loss is missing (exactly 0 or NaN).

    Args:
        irred_losses_dict: Dict loaded from an ``irred_losses_and_checks*.pt`` file.
            Its ``'irreducible_losses'`` and ``'sorted_targets'`` entries are read
            via ``.numpy()`` (assumed 1-D and aligned by global index — TODO confirm).
        class_names: Sequence mapping integer labels to class names. Defaults to
            the notebook-level ``classes`` so existing callers keep working.

    Returns:
        Tuple ``(irred_losses, global_idxs, targets, targets_str)`` restricted to
        the valid points. ``global_idxs`` holds each kept point's position in the
        unfiltered arrays.
    """
    if class_names is None:
        class_names = classes  # notebook-level global; kept as the default for backward compatibility

    irred_losses_all = irred_losses_dict['irreducible_losses'].numpy()
    global_indices = np.arange(irred_losses_all.size)
    targets = irred_losses_dict["sorted_targets"].numpy()
    targets_str = np.asarray(class_names)[targets]

    valid_irred_loss_mask = np.logical_and(irred_losses_all != 0, ~np.isnan(irred_losses_all))

    irred_losses = irred_losses_all[valid_irred_loss_mask]
    irred_losses_global_idxs = global_indices[valid_irred_loss_mask]
    targets = targets[valid_irred_loss_mask]
    targets_str = targets_str[valid_irred_loss_mask]

    return irred_losses, irred_losses_global_idxs, targets, targets_str

# Load both irreducible-loss files and align them point by point for comparison.
irred_losses_dict_1 = torch.load("cifar10_irred_losses/irred_losses_and_checks.pt")
irred_losses_dict_2 = torch.load("cifar10_irred_losses/irred_losses_and_checks_valtrain.pt")

irred_losses_1, global_idxs_1, _, targets_str_1 = filter_irred_losses(irred_losses_dict_1)
irred_losses_2, global_idxs_2, _, targets_str_2 = filter_irred_losses(irred_losses_dict_2)

# Each file is filtered with its own validity mask, so the two arrays need not cover the same points
# (or even have the same length). Keep only the global indices valid in both files, so that element i
# of each array refers to the same datapoint.
_, pos_1, pos_2 = np.intersect1d(global_idxs_1, global_idxs_2, assume_unique=True, return_indices=True)
irred_losses_1, targets_str_1 = irred_losses_1[pos_1], targets_str_1[pos_1]
irred_losses_2, targets_str_2 = irred_losses_2[pos_2], targets_str_2[pos_2]

# Sanity check: both files must agree on the label of every shared point.
assert np.all(targets_str_1 == targets_str_2)

In [None]:
# 2x2 Pearson correlation matrix of the two loss vectors; the off-diagonal entry is their correlation.
np.corrcoef(irred_losses_1, irred_losses_2)

In [None]:
# Log-log scatter of pretrained vs. valset irreducible loss for the first 2500 points, coloured by class, with y = x for reference.
plt.figure(figsize=(6, 4), dpi=300)
# x/y must be passed by keyword: seaborn 0.11 deprecated positional vectors and 0.12 removed them (TypeError).
sns.scatterplot(x=irred_losses_1[:2500], y=irred_losses_2[:2500], hue=targets_str_1[:2500], s=8, alpha=0.25)
plt.xlabel("Pretrained irreducible loss")
plt.ylabel("Valset irreducible loss")
plt.xscale("log")
plt.yscale("log")
plt.legend(bbox_to_anchor=(1.01, 0.99), loc="upper left", title="classes")
plt.xlim([10**-4, 10**1])
plt.ylim([10**-4, 10**1])
plt.plot([10**-4, 10**1], [10**-4, 10**1], "--k", linewidth=1)
plt.title("Correlation: {np.c}