In [None]:
%load_ext autoreload
%autoreload 2

In [21]:
import logging

import pandas as pd
import torch
import torch.nn.functional as F
# from nn_core.common import PROJECT_ROOT
from os import getcwd
import random

from pathlib import Path

# from la.utils.utils import MyDatasetDict
from datasets import load_dataset


try:
    # be ready for 3.10 when it drops
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum
# from pytorch_lightning import seed_everything
# import matplotlib.pyplot as plt
import random
from collections import namedtuple
# import timm
# from transformers import AutoModel, AutoProcessor
# from typing import Sequence, List
# from PIL.Image import Image
from tqdm import tqdm
# import functools
# from timm.data import resolve_data_config
# from datasets import load_dataset, load_from_disk, Dataset, DatasetDict

# from timm.data import create_transform

# Data loading

In [14]:
# PROJECT_ROOT = '/'.join(getcwd().split('/')[:-1])

/Users/amahmed/Desktop/UMass/Thesis/models/latent-aggregation
/Users/amahmed/Desktop/UMass/Thesis/models/latent-aggregation/notebooks


In [25]:
# DATASET_DIR: Path = PROJECT_ROOT / "data" / "cifar100_tasks"
# NOTE(review): this loads the plain Hugging Face "cifar100" dataset, whose
# DatasetDict only has "train"/"test" splits (see the output below). The cells
# that follow index `dataset["metadata"]`, `dataset[f"task_{i}_train"]` and a
# "rexnet_100" embedding column, which this dataset does not provide — hence the
# KeyError: 'metadata' further down. Presumably the pre-built task dataset under
# DATASET_DIR (commented out above) should be loaded here instead. TODO confirm.
DATASET_NAME = "cifar100"
dataset = load_dataset(DATASET_NAME)
dataset

DatasetDict({
    train: Dataset({
        features: ['img', 'fine_label', 'coarse_label'],
        num_rows: 50000
    })
    test: Dataset({
        features: ['img', 'fine_label', 'coarse_label'],
        num_rows: 10000
    })
})

### Check that every pair of tasks shares exactly the expected number of classes

In [24]:
# Every pair of tasks must share exactly `num_shared_classes` fine labels.
num_tasks = dataset["metadata"]["num_tasks"]
expected_num_shared_classes = dataset["metadata"]["num_shared_classes"]

# Read each task's label column once, instead of twice per (i, j) pair.
task_classes = [set(dataset[f"task_{i}_train"]["fine_label"]) for i in range(num_tasks)]

for i in range(num_tasks):
    for j in range(i + 1, num_tasks):
        num_shared_classes = len(task_classes[i] & task_classes[j])
        assert num_shared_classes == expected_num_shared_classes, (
            f"tasks {i} and {j} share {num_shared_classes} classes, "
            f"expected {expected_num_shared_classes}"
        )

KeyError: 'metadata'

In [None]:
metadata = dataset["metadata"]
print(metadata)

# Number of rows belonging to shared classes in each task split:
# samples per class times the number of shared classes.
num_shared_samples = metadata["num_train_samples_per_class"] * metadata["num_shared_classes"]
print(num_shared_samples)

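### Reconstruct the original space
Rebuild the un-split training set: the shared-class samples once (from task 0), followed by each task's novel samples in task order.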

In [None]:
# Rebuild the un-split training set in the same row order as the task splits:
# the shared-class prefix (identical in every task, so taken once from task 0)
# followed by each task's novel rows, in task order.
key = "task_0_train"
shared_sample_embeddings = dataset[key]["rexnet_100"][:num_shared_samples]
all_sample_embeddings = [shared_sample_embeddings]

for task_idx in tqdm(range(num_tasks)):
    key = f"task_{task_idx}_train"
    # rows past the shared prefix, i.e. this task's novel samples
    all_sample_embeddings.append(dataset[key]["rexnet_100"][num_shared_samples:])

In [None]:
# (num_samples, embedding_dim)
original_space = torch.cat([torch.Tensor(sample_embedding) for sample_embedding in all_sample_embeddings], dim=0)
print(original_space.shape)

# Obtain anchors

### Get shared sample indices
Get the indices of the samples that belong to the shared classes; anchors are sampled only from these.

In [None]:
shared_classes = set(dataset["metadata"]["shared_classes"])

samples = dataset["task_0_train"]
labels = samples["fine_label"]

# Only the label is needed to decide membership. Enumerating `samples` itself
# would build every full row (all columns) just to obtain its index.
shared_indices = [
    ind for ind, label in enumerate(tqdm(labels)) if label in shared_classes
]

In [None]:
# Shared-class samples must occupy exactly the first `num_shared_samples` rows of
# task_0_train. Reuse the value already computed from the metadata instead of
# re-hardcoding it, so this check cannot silently diverge from the slicing used
# to rebuild the original space.
assert shared_indices == list(range(num_shared_samples)), (
    "shared-class samples are not the leading contiguous block of task_0_train"
)

### Get non-shared sample indices

In [None]:
# Start from every row index of task_0_train and strike out the shared-class ones.
non_shared_indices = set(range(len(samples)))
non_shared_indices.difference_update(shared_indices)
print(len(non_shared_indices))

In [None]:
# Novel (non-shared) samples must form the contiguous block right after the
# shared prefix. `num_shared_samples` is reused from earlier cells rather than
# re-hardcoded here.
# NOTE(review): 2500 is a hard-coded expected count of novel samples per task;
# presumably derivable from the metadata — TODO confirm.
num_novel_samples = 2500
num_samples_per_task = num_shared_samples + num_novel_samples

# A set has no guaranteed iteration order, so sort before comparing to a range.
assert sorted(non_shared_indices) == list(range(num_shared_samples, num_samples_per_task)), (
    "novel samples are not the trailing contiguous block of task_0_train"
)

### Sample anchor indices

In [None]:
# Anchors are drawn only from shared-class rows, so the same rows exist in every
# task. A dedicated, seeded RNG makes the selection reproducible across kernel
# restarts without touching the global `random` state.
ANCHOR_SEED = 42
num_anchors = 512
shared_anchor_indices = random.Random(ANCHOR_SEED).sample(shared_indices, num_anchors)

### Select the anchors

In [None]:
# For every task keep its full embedding matrix and the rows chosen as anchors.
embeddings = []
anchors = []

for task_idx in tqdm(range(num_tasks)):
    key = f"task_{task_idx}_train"
    task_i_embeddings = torch.Tensor(dataset[key]["rexnet_100"])  # (num_task_samples, embedding_dim)
    embeddings.append(task_i_embeddings)
    anchors.append(task_i_embeddings[shared_anchor_indices])  # (num_anchors, embedding_dim)

print(anchors[0].shape)

### Check that the anchors are the same across tasks

In [None]:
# Every task's anchors are taken at the same shared indices, so they must match
# task 0's anchors exactly. Exact equality is transitive, so comparing against
# task 0 covers every pair (the previous `range(i, ...)` loop also compared each
# tensor with itself). `torch.equal` rejects shape mismatches instead of
# broadcasting like `torch.eq`.
for i in range(1, num_tasks):
    assert torch.equal(anchors[0], anchors[i]), f"anchors of task {i} differ from task 0"

# Map to relative spaces

In [None]:
# Relative representation of each task: cosine similarity of every sample with
# every anchor (both sides L2-normalised, then a matrix product).
relatives = []

for task_embeddings, task_anchors in zip(embeddings, anchors):
    abs_space = F.normalize(task_embeddings, p=2, dim=-1)
    # Keeps the name `norm_anchors`: later cells may read it after the loop
    # (it then holds the last task's normalised anchors).
    norm_anchors = F.normalize(task_anchors, p=2, dim=-1)

    # (num_task_samples, num_anchors)
    rel_space = abs_space @ norm_anchors.T
    relatives.append(rel_space)

### Split each relative space into shared and novel (disjoint) samples

In [None]:
# Split every relative space into the shared prefix (identical across tasks) and
# the task-specific (disjoint) tail. Reuses the `num_shared_samples` already in
# scope rather than re-hardcoding 40000, so the merged space is split at the same
# row as the original space was rebuilt with.
shared_samples = [relative[:num_shared_samples] for relative in relatives]
disjoint_samples = [relative[num_shared_samples:] for relative in relatives]

### Check that the shared samples are the same across tasks

In [None]:
# The shared prefix of every task's relative space must equal task 0's exactly.
# Comparing each task against task 0 covers every pair (exact equality is
# transitive) and avoids self-comparisons. `torch.equal` also fails on shape
# mismatches instead of broadcasting.
for i in range(1, num_tasks):
    assert torch.equal(shared_samples[0], shared_samples[i]), (
        f"shared samples of task {i} differ from task 0"
    )

### Concatenate the shared samples and the disjoint samples to build the merged space

In [None]:
# Stack every task's novel rows one after another, task 0 first.
all_disjoint_samples = torch.cat(tuple(disjoint_samples), dim=0)
all_disjoint_samples.shape

In [None]:
# Merged space: the shared prefix (taken once, from task 0) followed by all
# novel rows, mirroring the row order used to rebuild the original space.
merged_space = torch.cat([shared_samples[0], all_disjoint_samples], dim=0)
merged_space.shape

In [None]:
# Preview only the first rows instead of dumping the whole merged space.
merged_space[:5]

### Project the original space into the relative space


In [None]:
# Project the original (un-split) embeddings onto the anchors. Anchors are
# identical across tasks (checked earlier), so normalise task 0's anchors
# explicitly instead of relying on a loop variable leaked from another cell.
abs_original_space = F.normalize(original_space, p=2, dim=-1)
original_norm_anchors = F.normalize(anchors[0], p=2, dim=-1)

# (num_samples, num_anchors)
original_rel_space = abs_original_space @ original_norm_anchors.T
print(original_rel_space)

In [None]:
# The merged relative space must reproduce the relative projection of the
# original space row for row. Check the shapes first: `torch.allclose` may
# broadcast, so a shape mismatch could otherwise go unnoticed or fail obscurely.
assert original_rel_space.shape == merged_space.shape, (
    f"shape mismatch: {tuple(original_rel_space.shape)} vs {tuple(merged_space.shape)}"
)
assert torch.allclose(original_rel_space, merged_space, rtol=1e-05, atol=1e-08), (
    f"max abs diff: {(original_rel_space - merged_space).abs().max().item():.3e}"
)