# Visualize Representations Learned by Matryoshka Representation Learning

In [None]:
from argparse import Namespace

import os

from datasets import load_dataset
from datasets import Dataset

import torch

from torchvision import transforms
import wandb

import pandas as pd
import numpy as np

from sklearn.model_selection import StratifiedKFold

In [None]:
class AnimalClassifierMatryoshka(torch.nn.Module):
    """CNN image classifier with Matryoshka-style nested classification heads.

    Four conv + max-pool stages feed a lazily-sized linear projection that
    produces a ``4 * dims`` embedding. One linear head is attached to each
    power-of-two prefix of that embedding, starting at width 8 and doubling
    up to ``2 ** (int(log2(dims)) + 2)`` (equal to ``4 * dims`` when ``dims``
    is a power of two).

    Args:
        in_channels: number of channels of the input images.
        dims: base channel width; the conv stages use ``dims`` and ``2 * dims``.
        num_labels: number of classes each head predicts.
    """

    def __init__(self, in_channels: int, dims: int,
                 num_labels: int):
        super().__init__()

        wide = 2 * dims

        # Attribute names are kept stable: they are the state-dict keys used
        # when loading a saved checkpoint.
        self.conv_1 = torch.nn.Conv2d(in_channels, dims, kernel_size=12)
        self.max_pool_1 = torch.nn.MaxPool2d(kernel_size=3)

        self.conv_2 = torch.nn.Conv2d(dims, wide, kernel_size=5)
        self.max_pool_2 = torch.nn.MaxPool2d(kernel_size=3)

        self.conv_3 = torch.nn.Conv2d(wide, wide, kernel_size=3)
        self.max_pool_3 = torch.nn.MaxPool2d(kernel_size=2)

        self.conv_4 = torch.nn.Conv2d(wide, wide, kernel_size=3)
        self.max_pool_4 = torch.nn.MaxPool2d(kernel_size=2)

        self.flatten = torch.nn.Flatten()
        # Input width depends on the image size, so it is inferred lazily.
        self.projection = torch.nn.LazyLinear(4 * dims)

        # Head k consumes the first 2 ** (3 + k) embedding features.
        top_exponent = int(np.log2(dims)) + 2
        self.linear_layers = torch.nn.ModuleList(
            torch.nn.Linear(2 ** exponent, num_labels)
            for exponent in range(3, top_exponent + 1)
        )

    def forward(self, x: torch.Tensor):
        """Run the network.

        Returns:
            ``(logits, embedding)``: ``logits`` stacks every head's output
            along dim 1, i.e. ``(batch, num_heads, num_labels)``;
            ``embedding`` is the output of the projection layer.
        """
        stages = (
            (self.conv_1, self.max_pool_1),
            (self.conv_2, self.max_pool_2),
            (self.conv_3, self.max_pool_3),
            (self.conv_4, self.max_pool_4),
        )

        # NOTE(review): there is no activation between stages (max-pool is the
        # only non-linearity). Adding one would change the outputs of any
        # checkpoint trained with this architecture, so it is left as-is.
        hidden = x
        for conv, pool in stages:
            hidden = pool(conv(hidden))

        embedding = self.projection(self.flatten(hidden))

        logits = torch.stack(
            [head(embedding[:, :8 << k])
             for k, head in enumerate(self.linear_layers)],
            dim=1,
        )
        return logits, embedding

def create_model(in_dimensions: int, dims: int, num_labels: int):
    """Build a freshly initialized AnimalClassifierMatryoshka.

    Args:
        in_dimensions: number of input image channels.
        dims: base width forwarded to the model as ``dims``.
        num_labels: number of classes each head predicts.

    Returns:
        The constructed (untrained) model.
    """
    return AnimalClassifierMatryoshka(
        in_channels=in_dimensions, dims=dims, num_labels=num_labels)

In [None]:
SEED = 1

# Prefer CUDA, then Apple MPS, then CPU.
DEVICE = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available() else 'cpu')
# Uncomment to force CPU execution:
# DEVICE = torch.device('cpu')

CONFIG = Namespace(
    run_name='animal-classifier',
    image_size=256,
    hidden_dims=256,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=1,
    seed=SEED,  # single source of truth for the seed
)
CONFIG.device = DEVICE

# Local directory holding the trained checkpoint (expects `model.pt` inside).
MODEL_ARTIFACT = './artifacts/animal-classifier-model-v1:v2'
if not os.path.exists(MODEL_ARTIFACT):
    run = wandb.init(project='Animal-Classifier', entity=None,
                     job_type='visualize',
                     name=CONFIG.run_name,
                     config=CONFIG)
    artifact = run.use_artifact(
        'pkthunder/Animal-Classifier/animal-classifier-model-v1:v2', type='model')
    artifact_dir = artifact.download()
    # Use the directory W&B actually downloaded into; it may differ from the
    # hard-coded path above (e.g. a custom artifact root or path sanitization).
    MODEL_ARTIFACT = artifact_dir
    # Nothing else is logged to this run, so close it once the download is done.
    run.finish()

In [None]:
def prepare_dataloader(config: Namespace):
    """Build seeded train/val DataLoaders over the ``cats_vs_dogs`` dataset.

    Images that are 100 px or smaller on either side are dropped. The
    remaining ``train`` split is partitioned with a stratified, shuffled
    10-fold split: the held-out part of the first fold becomes the validation
    set and the other nine folds the training set.

    Each batch is a dict with keys:
        ``'image'``: resized to ``config.image_size`` and normalized to [-1, 1];
        ``'label'``: the dataset's ``labels`` value;
        ``'original-image'``: un-normalized 512x512 copy for W&B visualization.

    Args:
        config: must provide ``image_size``, ``seed``,
            ``per_device_train_batch_size`` and ``per_device_eval_batch_size``.

    Returns:
        ``(train_dataloader, val_dataloader)``.
    """

    preprocess = transforms.Compose(
        [
            transforms.Resize((config.image_size, config.image_size)),
            transforms.ToTensor(),  # Convert to tensor in [0, 1]
            transforms.Normalize([0.5], [0.5]),  # Map to [-1, 1]
        ])

    # Un-normalized copy of each image, used for visualization in W&B.
    preprocess_original = transforms.Compose(
        [
            transforms.Resize((512, 512)),
            transforms.ToTensor(),  # Convert to tensor in [0, 1]
        ])

    # Load the dataset and drop images that are 100 px or smaller on either side.
    dataset = load_dataset('cats_vs_dogs')
    dataset = dataset.filter(
        lambda example: example['image'].size[0] > 100 and example['image'].size[1] > 100)

    def transform(examples):
        """On-the-fly batch transform producing model input, label and a visualization copy."""
        # Convert each image to RGB once and reuse it for both pipelines.
        rgb_images = [image.convert('RGB') for image in examples['image']]
        return {'image': [preprocess(image) for image in rgb_images],
                'label': examples['labels'],
                'original-image': [preprocess_original(image) for image in rgb_images]
                }

    full_train = dataset['train']
    labels = full_train['labels']

    # Stratified split preserving label proportions; only the first of the 10
    # folds is used (its held-out part is validation, the rest is training).
    cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=config.seed)
    train_indices, val_indices = next(cv.split(np.zeros(len(full_train)), labels))

    # select() creates index-mapped views of the filtered split instead of
    # copying (and re-encoding) every example through Dataset.from_generator.
    train_dataset = full_train.select(train_indices)
    val_dataset = full_train.select(val_indices)

    train_dataset.set_transform(transform)
    val_dataset.set_transform(transform)

    # Independent, identically seeded generators make shuffling reproducible.
    train_gen = torch.Generator().manual_seed(config.seed)
    val_gen = torch.Generator().manual_seed(config.seed)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.per_device_train_batch_size,
        shuffle=True, generator=train_gen)

    val_dataloader = torch.utils.data.DataLoader(
        val_dataset, batch_size=config.per_device_eval_batch_size,
        shuffle=True, generator=val_gen)

    return train_dataloader, val_dataloader

In [None]:
# Rebuild the architecture (3 input channels, 2 output classes) and load the
# trained weights, mapping them to CPU first so any saved device works.
MODEL = create_model(3, CONFIG.hidden_dims, 2)
# NOTE(review): torch.load unpickles the file and this checkpoint comes from a
# downloaded W&B artifact; if it only holds tensors, consider weights_only=True.
MODEL.load_state_dict(
    torch.load(
        f"{MODEL_ARTIFACT}/model.pt",
        map_location=torch.device('cpu'),
    )
)
# Module.to() returns the module itself, so eval() applies to MODEL.
MODEL.to(CONFIG.device).eval()

# We are going to visualize the validation data
_, val_dataloader = prepare_dataloader(CONFIG)

In [None]:
@torch.no_grad()
def get_embeddings(model: "AnimalClassifierMatryoshka", dataloader, device=None):
    """Collect Matryoshka prefix embeddings and one class name per sample.

    Args:
        model: module whose forward returns ``(logits, features)``; for each
            head in ``model.linear_layers`` the features are sliced to that
            head's ``in_features`` width.
        dataloader: iterable of dict batches with ``'image'`` and ``'label'``
            tensors. Label 0 maps to ``'cat'``, any other value to ``'dog'``.
        device: device to run the model on. Defaults to the device of the
            model's first parameter (CPU if it has none).

    Returns:
        ``(features_by_subset, inst_labels)``: ``features_by_subset[i]`` is a
        CPU tensor of shape ``(num_samples, width_i)`` for head ``i``, and
        ``inst_labels`` holds one label string per feature row, in order.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    widths = [layer.in_features for layer in model.linear_layers]
    features_by_subset = {i: [] for i in range(len(widths))}
    inst_labels = []

    for batch in dataloader:
        _, features = model(batch['image'].to(device))

        for i, width in enumerate(widths):
            features_by_subset[i].append(features[:, :width].cpu())
        # One label per sample, so labels stay aligned with feature rows for
        # any batch size.
        inst_labels.extend(
            'cat' if value == 0 else 'dog' for value in batch['label'].tolist())

    for i, width in enumerate(widths):
        chunks = features_by_subset[i]
        # An empty dataloader yields empty (0, width) tensors instead of crashing.
        features_by_subset[i] = (
            torch.cat(chunks, dim=0) if chunks else torch.empty(0, width))
        print(f"Shape of features for subset {i}: {features_by_subset[i].shape}")

    return features_by_subset, inst_labels

In [None]:
from sklearn.preprocessing import normalize

# Embed the validation set, then L2-normalize every row of each subset's
# feature matrix (sklearn's default: norm='l2' along axis=1).
features_by_subset, inst_labels = get_embeddings(MODEL, val_dataloader)
features_by_subset = {
    subset_id: normalize(subset_features)
    for subset_id, subset_features in features_by_subset.items()
}

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme(font_scale=2.5)
# Set the default colormap through rcParams: plt.set_cmap() also calls gci(),
# which creates an empty figure as a side effect when none exists yet.
plt.rcParams['image.cmap'] = 'tab20'

from sklearn.decomposition import PCA

In [None]:
color_map = plt.get_cmap('tab20')

# Distinct labels in first-seen order, so colour assignment is stable.
unique_labels = list(dict.fromkeys(inst_labels))

# One 2-D PCA scatter plot per Matryoshka feature subset.
for subset_id, features in features_by_subset.items():
    pca_embedding = PCA(n_components=2).fit_transform(features)
    df = pd.DataFrame(pca_embedding, columns=['pca1', 'pca2'])
    df['label'] = inst_labels

    fig = plt.figure(figsize=(13, 10))
    ax = fig.add_subplot(111)
    for i, label in enumerate(unique_labels):
        label_df = df[df.label == label]
        # Use `color=` rather than `c=`: an RGBA tuple passed as `c` is
        # interpreted as four values to colormap when a class has 4 points.
        ax.scatter(label_df.pca1, label_df.pca2, label=str(label),
                   color=color_map(i))

    ax.set_xlabel("PCA Dimension 1")
    ax.set_ylabel("PCA Dimension 2")
    # Shrink the axes to leave room for the legend on the right.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.7, box.height])
    ax.legend(loc='center left', bbox_to_anchor=(1.00, 0.5))
    fig.savefig(f"subset-{subset_id+1}-features.png", bbox_inches='tight')
    plt.close(fig)