<a href="https://colab.research.google.com/github/Cognition-And-Vision-Amsterdam-CAVA/UvA2024NeuroAI/blob/main/TutorialDay2_part2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Tutorial Day 2 - SPoSE

by Giacomo Aldegheri

\
\
In this tutorial, we will look at **SPoSE** (Sparse Positive Similarity Embedding), a model proposed by Martin Hebart and colleagues in a [2020 paper](https://www.nature.com/articles/s41562-020-00951-3) and then in a [follow-up paper](https://elifesciences.org/articles/82580) to estimate the dimensions underlying human judgments of image similarity. What does that mean?

They ran a large-scale online behavioral study in which, on each trial, participants saw three images and indicated which one was the odd one out (the image most dissimilar to the others):

<img src='https://drive.google.com/uc?id=1zONI4kRZmsP0jlO-tNgPYy46EhVsHXVL' width=400>

They did this for 1854 object concepts (from the [THINGS database](https://things-initiative.org/)) and 1.46 million trials.

Using this data, they tried to understand the dimensions underlying subjects' similarity judgments. They used the following procedure:

\
<img src='https://drive.google.com/uc?id=1zPHHATd117y1YWGSs4FbRoD9FI5TWGjh' width=750>

\
In essence, they pre-specified a number of embedding dimensions (e.g. 40) and created a simple linear model with a `1854 x 40` (n. concepts x n. dimensions) weight matrix. They fed each concept as a 1854-dimensional one-hot vector to the model, and transformed it into a 40-dimensional embedding.
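To see why feeding a one-hot vector through a bias-free linear layer amounts to looking up one row of the weight matrix, here is a toy sketch in plain Python. The sizes (4 concepts, 2 dimensions), the weight values, and the helper names `one_hot` and `embed` are made up for illustration and are not part of the official implementation.

```python
# Toy weight matrix: 4 concepts x 2 embedding dimensions (made-up values).
W = [
    [0.9, 0.0],
    [0.8, 0.1],
    [0.0, 0.7],
    [0.2, 0.0],
]

def one_hot(index, n_items):
    """Return a list with 1.0 at `index` and 0.0 everywhere else."""
    return [1.0 if i == index else 0.0 for i in range(n_items)]

def embed(x, W):
    """Multiply a row vector x (length n_items) by W (n_items x n_dims)."""
    n_dims = len(W[0])
    return [sum(x[i] * W[i][d] for i in range(len(W))) for d in range(n_dims)]

x = one_hot(2, n_items=len(W))
embedding = embed(x, W)
print(embedding)  # the same values as the row W[2]
```

In the real model the same thing happens with 1854 concepts and 40 dimensions: each concept's embedding is simply its own row of weights.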

The embeddings were extracted for each of the three items in a triplet (**step 1** in the figure above), and their pairwise similarity was computed with a dot product (**step 2**). These dot products were turned into choice probabilities using the [softmax function](https://en.wikipedia.org/wiki/Softmax_function) (**step 3**), and the model's choice (pair with the highest probability) was compared to a human subject's choice (**step 4**).

The idea is that among the three pairs, the highest probability (pairwise similarity) should be assigned to the pair that does *not* include the odd-one-out. For example, if from the triplet `glass, bottle, car` I judge the car to be the odd-one-out, the model should predict the pair `glass, bottle` to have a higher similarity than `glass, car` and `bottle, car`. The distance ([cross-entropy loss](https://en.wikipedia.org/wiki/Cross-entropy)) between the model's choice and the correct choice was then [backpropagated](https://en.wikipedia.org/wiki/Backpropagation) to the weights, making them increasingly informative about subjects' judgments.
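Steps 2 to 4 can be traced by hand on the `glass, bottle, car` example. The sketch below uses plain Python and made-up 3-dimensional embeddings (the real ones are learned and 40-dimensional); the helper names `dot` and `softmax` are ours.

```python
import math

# Made-up embeddings for the three items of one triplet.
glass = [0.9, 0.1, 0.0]
bottle = [0.8, 0.2, 0.0]
car = [0.0, 0.1, 0.9]

def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def softmax(values):
    """Turn a list of scores into probabilities that sum to 1."""
    m = max(values)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

# Step 2: similarity of each of the three possible pairs.
pairs = {
    ("glass", "bottle"): dot(glass, bottle),
    ("glass", "car"): dot(glass, car),
    ("bottle", "car"): dot(bottle, car),
}

# Step 3: choice probabilities over the three pairs.
probs = dict(zip(pairs, softmax(list(pairs.values()))))

# Step 4: the predicted odd-one-out is the item left out of the most similar pair.
best_pair = max(probs, key=probs.get)
odd_one_out = ({"glass", "bottle", "car"} - set(best_pair)).pop()
print(best_pair, odd_one_out)
```

With these toy values the pair `glass, bottle` gets the highest probability, so the model's predicted odd-one-out is `car`.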

They also put two additional constraints on the weights:

- They had to be **positive**: the intuition is that each weight should reflect the presence, or absence, of a given feature. E.g. an animal can be more or less furry, but it can't be negative furry.

- They had to be **sparse**: for any given object, most features' activations should be 0. When many features are active for every object, the features are usually not very interpretable, so only a few should be active for any given object.

Both of these constraints were enforced through dedicated loss terms, as we will see below.

The resulting features turn out to be interpretable: they reflect the dimensions along which people's internal representations of objects are organized.

\
With all that in mind, time to dive into the code! This tutorial closely follows the [official implementation](https://github.com/ViCCo-Group/SPoSE).



## Import libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import numpy as np
import os
import re
from tqdm.notebook import tqdm
from typing import Tuple
import random

rand_seed = 123
random.seed(rand_seed)
np.random.seed(rand_seed)
torch.manual_seed(rand_seed)

## Set device

If possible, use the GPU for much faster computation!

In [None]:
device = 'cuda' # @param ['cuda', 'cpu']
if device == 'cuda':
  assert torch.cuda.is_available(), 'GPU not available! Please select a GPU runtime or use CPU.'
device = torch.device(device)

## Mount Google Drive

You should have already added a shortcut to the data folder in your Google Drive during yesterday's tutorial. If not, this is how you do it:

<img src='https://drive.google.com/uc?id=15TNjV__sWCcnBRlxbXNbJfpidx-C6nrk' width=500>

Go to the [folder](https://drive.google.com/drive/folders/1AjDOejWLjfXGkr-hK07SZJ_4ni1nypjw?usp=sharing), right click on its name, and select `Organize -> Add shortcut`. It will add a shortcut to your own Google Drive without the need to copy any data or occupy any storage.

In [None]:
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)
data_dir = '/content/drive/MyDrive/UvA_encodingtutorial/SPoSE/triplet_dataset/'

## Specify directory to save trained model parameters

After we have trained the model (as well as during training, to see how the weights evolve) we want to save the model's weights.

This should be on your own Google Drive, not in the shared folder, where you don't have write permissions.

In [None]:
log_dir = 'SPoSE_weights' #@param {type: 'string'}
log_dir = f'/content/drive/MyDrive/{log_dir}/'
if not os.path.isdir(log_dir):
  os.makedirs(log_dir)

In [None]:
# @title Utility functions

def load_triplets(partition:str, data_dir=data_dir):
  if partition == 'train':
    fname = 'trainset.txt'
    n_items = 500000
  elif partition == 'val':
    fname = 'validationset.txt'
    n_items = 20000
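  elif partition == 'test':
    fname = 'testset1.txt'
    n_items = 10000
  else:
    # Fail early instead of hitting an undefined `fname` below
    raise ValueError(f"Unknown partition '{partition}': expected 'train', 'val' or 'test'.")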

  triplets = np.loadtxt(os.path.join(data_dir, fname))

  return torch.from_numpy(triplets[:n_items]).type(torch.LongTensor)

def accuracy_(probas:np.ndarray) -> float:
    choices = np.where(probas.mean(axis=1) == probas.max(axis=1), -1, np.argmax(probas, axis=1))
    acc = np.where(choices == 0, 1, 0).mean()
    return acc

def choice_accuracy(anchor:torch.Tensor, positive:torch.Tensor, negative:torch.Tensor) -> float:
    similarities  = compute_similarities(anchor, positive, negative)
    probas = F.softmax(torch.stack(similarities, dim=-1), dim=1).detach().cpu().numpy()
    return accuracy_(probas)

def filter_nonneg(W, threshold=0.1):
  W = W*(W>threshold)
  return W

## Dataset

Here, to speed up computations, we will only use a subset of the available data. Specifically, this dataset contains 4.12M triplets, but we will only use 500K for training. It's less than 1/8 of the data, but we will see that it works quite well! For validation, we will use 20K.

In [None]:
train_triplets = load_triplets('train')
val_triplets = load_triplets('val')

print('Number of training samples:', len(train_triplets))
print('Number of validation samples:', len(val_triplets))

The data contains 1854 concepts, as we can verify:

In [None]:
n_items = len(torch.unique(train_triplets))
print('N. unique concepts:', n_items)

Let's create a dataset to feed the triplets to our model. First, the unique concept IDs need to be coded as one-hot vectors (1854-dimensional, with 1 for the current concept's entry and 0 elsewhere).

In [None]:
def encode_as_onehot(I:torch.Tensor, triplets:torch.Tensor) -> torch.Tensor:
    """encode item triplets as one-hot-vectors"""
    return I[triplets.flatten(), :]


class TripletDataset(Dataset):

    def __init__(self, n_items:int, dataset:torch.Tensor):
        self.I = torch.eye(n_items)
        self.dataset = dataset

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx:int) -> torch.Tensor:
        sample = encode_as_onehot(self.I, self.dataset[idx])
        return sample

## Model

Finally, time to code the actual model! As you can see, it's really simple. Just a single linear layer (only weights, no biases), with the number of concepts as input size and the number of embedding dimensions as output size.

**EXERCISE:** fill in the code for the single linear layer in the model.

In [None]:
class SPoSE(nn.Module):

    def __init__(
                self,
                in_size:int,
                out_size:int,
                init_weights:bool=True,
                ):
        super(SPoSE, self).__init__()
        self.in_size = in_size
        self.out_size = out_size
        # YOUR CODE HERE
        #self.fc = ...

        if init_weights:
            self._initialize_weights()

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        return self.fc(x)

    def _initialize_weights(self) -> None:
        mean, std = .1, .01
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(mean, std)

## Loss functions

And now one of the most important ingredients: the loss functions. First we code the two regularizers: the L1 regularization to enforce sparsity, and the positivity penalty to enforce positive weights.

- For **sparsity**, we use the [L1 loss](https://towardsdatascience.com/intuitions-on-l1-and-l2-regularisation-235f2db4c261): $\sum_{i=1}^n|\textbf{W}_i|$ for each weight $\textbf{W}_i$.

- For **positivity**, we use the summed magnitude of all negative weights as a loss term: $\sum_{i=1}^{n}\operatorname{ReLU}(-\textbf{W}_i)$ for each weight $\textbf{W}_i$. Weights that are already positive contribute 0.
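As a quick numerical check of the two formulas, here is a worked toy example in plain Python. The four weight values are made up; the PyTorch version operating on the model's weight matrix is still left for the exercise.

```python
# Made-up weights; two of them are negative.
weights = [0.5, -0.2, 0.0, -0.1]

# L1 term: sum of absolute values. Every non-zero weight is penalized,
# which pushes weights that are not needed toward exactly 0.
l1 = sum(abs(w) for w in weights)

# Positivity term: ReLU(-w) equals -w for negative w and 0 otherwise,
# so only the negative weights contribute.
positivity = sum(max(0.0, -w) for w in weights)

print(round(l1, 3), round(positivity, 3))
```

Here the L1 term is 0.5 + 0.2 + 0.1 = 0.8, while the positivity term only counts the two negative weights: 0.2 + 0.1 = 0.3.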

\
**EXERCISE:** fill in the code for the positivity loss.

In [None]:
def l1_regularization(model, device=device) -> torch.Tensor:
    l1_reg = torch.tensor(0., requires_grad=True)
    for n, p in model.named_parameters():
        if re.search(r'weight', n):
            l1_reg = l1_reg + torch.norm(p, 1)
    return l1_reg.to(device)

def pos_penalty(model) -> torch.Tensor:
  W = model.fc.weight
  # YOUR CODE HERE:
  #return ... # positivity constraint to enforce non-negative values in embedding matrix

Now, we code our main loss function, `trinomial_loss`.

It's based on the [cross-entropy](https://en.wikipedia.org/wiki/Cross-entropy) loss:

$$
H(p, q) = -\sum_{x \in \mathcal{X}}p(x)\log{q(x)}
$$

Where $p$ and $q$ are two probability distributions, corresponding to the ground-truth distribution and the model's outputs.

Basically, it makes sure the similarities between the triplets correspond to the choices in the odd-one-out task.
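When the ground truth $p$ puts all its mass on one option (the pair the participant chose), the sum collapses to a single term: the negative log of the probability the model assigns to that option. A small sketch in plain Python, with made-up probabilities and a hypothetical helper `cross_entropy`:

```python
import math

def cross_entropy(p, q):
    """H(p, q) = -sum_x p(x) * log q(x); entries with p(x) = 0 contribute nothing."""
    return -sum(px * math.log(qx) for px, qx in zip(p, q) if px > 0)

# Ground truth: the first pair was chosen by the participant.
p = [1.0, 0.0, 0.0]

# Two hypothetical model outputs over the three pairs.
q_good = [0.7, 0.2, 0.1]  # most of the probability on the chosen pair
q_bad = [0.2, 0.7, 0.1]   # most of the probability on a different pair

loss_good = cross_entropy(p, q_good)  # equals -log(0.7)
loss_bad = cross_entropy(p, q_bad)    # equals -log(0.2)
print(loss_good < loss_bad)
```

The more probability the model puts on the chosen pair, the smaller the loss, which is exactly what training exploits.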

First, we need to compute the similarities between the embeddings (using the dot product).

Then, since in the dataset the chosen pair (the one that doesn't include the odd-one-out) always comes first in the triplet, as the anchor and the positive, we just need to ensure that the model's estimated choice probability/similarity for that first pair is always the highest.

**EXERCISE:** implement the dot product computation for the three pairs: `pos_sim` (anchor, positive), `neg_sim` (anchor, negative) and `neg_sim_2` (positive, negative)

In [None]:
def compute_similarities(anchor:torch.Tensor, positive:torch.Tensor, negative:torch.Tensor) -> Tuple:
    # YOUR CODE HERE
    # pos_sim = ...
    # neg_sim = ...
    # neg_sim_2 = ...

    return pos_sim, neg_sim, neg_sim_2

def weighted_softmax(sims: tuple, t:float) -> torch.Tensor:
  return torch.exp(sims[0] / t) / torch.sum(torch.stack([torch.exp(sim / t) for sim in sims]), dim=0)

def cross_entropy_loss(sims:tuple, t:float) -> torch.Tensor:
    return torch.mean(-torch.log(weighted_softmax(sims, t)))

def trinomial_loss(anchor:torch.Tensor, positive:torch.Tensor, negative:torch.Tensor, t:float) -> torch.Tensor:
  sims = compute_similarities(anchor, positive, negative)
  return cross_entropy_loss(sims, t)

## Hyperparameters

Now we're all set! We just need to set a few hyperparameters:

- `lmbda`: the weight of the L1 regularization (we can't use the name `lambda` as it's a Python keyword 😁)
- `temperature`: the temperature for the softmax function.
- `lr`: the learning rate.
- `batch_size`: the batch size for training.
- `embed_dim`: the embedding's dimensionality.
- `n_epochs`: number of training epochs.
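To get a feel for what `temperature` does, the sketch below divides the similarities by `t` before the softmax, as `weighted_softmax` does above. It is plain Python with made-up similarity values, and the helper name `softmax_with_temperature` is ours.

```python
import math

def softmax_with_temperature(sims, t):
    """Softmax over similarities after dividing each one by the temperature t."""
    exps = [math.exp(s / t) for s in sims]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up similarities for the three pairs of one triplet.
sims = [0.74, 0.01, 0.02]

# Low temperatures sharpen the distribution, high temperatures flatten it.
results = {t: softmax_with_temperature(sims, t) for t in (0.1, 1.0, 10.0)}
for t, probs in results.items():
    print(t, [round(p, 3) for p in probs])
```

With `temperature = 1.`, as used in this tutorial, the similarities are passed to the softmax unchanged.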

In [None]:
lmbda = 0.02
temperature = 1.
lr = 0.001
batch_size = 100
embed_dim = 40
n_epochs = 16

Let's define our model, datasets, dataloaders and the optimizer.

In [None]:
model = SPoSE(in_size=n_items, out_size=embed_dim, init_weights=True)
model = model.to(device)

# Training/validation datasets and dataloaders:
trainset = TripletDataset(n_items, train_triplets)
validationset = TripletDataset(n_items, val_triplets)
trainloader = DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)
valloader = DataLoader(dataset=validationset, batch_size=batch_size, shuffle=False)

# Optimizer (Adam with default settings)
optim = Adam(model.parameters(), lr=lr)

## Training loop

Ok, now it's time to train the model! With a GPU backend, it should take around 10 minutes.

In [None]:
crossentropies = []
complexity_losses = []
train_losses = []
train_accs = []

for epoch in tqdm(range(n_epochs)):

  model.train()

  # To keep track of the losses
  batch_crossentropies = torch.zeros(len(trainloader))
  batch_complosses = torch.zeros(len(trainloader))
  batch_losses_train = torch.zeros(len(trainloader))
  batch_accs_train = torch.zeros(len(trainloader))

  for i, batch in enumerate(trainloader):
    optim.zero_grad()
    batch = batch.to(device)
    logits = model(batch)

    # separate the three embeddings:
    anchor, positive, negative = torch.unbind(logits, dim=1)

    c_entropy = trinomial_loss(anchor, positive, negative, temperature)
    l1_pen = (lmbda/n_items) * l1_regularization(model, device=device)
    pos_pen = pos_penalty(model)

    # Sum everything into one big loss:
    loss = c_entropy + 0.01 * pos_pen + l1_pen
    loss.backward()
    optim.step()

    batch_losses_train[i] += loss.item()
    batch_crossentropies[i] += c_entropy.item()
    batch_complosses[i] += l1_pen.item()
    batch_accs_train[i] += choice_accuracy(anchor, positive, negative)

  avg_crossentropy = torch.mean(batch_crossentropies).item()
  avg_comploss = torch.mean(batch_complosses).item()
  avg_train_loss = torch.mean(batch_losses_train).item()
  avg_train_acc = torch.mean(batch_accs_train).item()


  ####################################
  # Validation
  ####################################

  val_accs = torch.zeros(len(valloader))
  val_losses = torch.zeros(len(valloader))

  model.eval()

  with torch.no_grad():
    for i, batch in enumerate(valloader):

      batch = batch.to(device)
      logits = model(batch)
      anchor, positive, negative = torch.unbind(logits, dim=1)

      val_loss = trinomial_loss(anchor, positive, negative, temperature)
      val_acc = choice_accuracy(anchor, positive, negative)

      val_losses[i] += val_loss.item()
      val_accs[i] += val_acc.item()

  avg_val_loss = torch.mean(val_losses).item()
  avg_val_acc = torch.mean(val_accs).item()


  print('\n==========================================================')
  print(f'Epoch: {epoch+1}, Train acc: {avg_train_acc:.5f}, Train loss: {avg_train_loss:.5f}, Val acc: {avg_val_acc:.5f}, Val loss: {avg_val_loss:.5f}')
  print('==========================================================\n')

  if (epoch + 1) % 2 == 0:
    # Save the weights every other epoch:
    W = model.fc.weight.detach().cpu().numpy().T
    np.savetxt(os.path.join(log_dir, f'weights_epoch{epoch+1:04d}.txt'), W)

# Save final model
W = model.fc.weight.detach().cpu().numpy().T
np.savetxt(os.path.join(log_dir, 'weights_final.txt'), W)

If all went well, you should have reached ~64% accuracy on the validation set, well above the 33% chance level of a three-way choice! That's quite good, given that we are using less than 1/8 of the original dataset for training! Now, let's see what the embedding dimensions we have learned look like. Are they interpretable?
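For reference, this is roughly how the accuracy is counted by the `accuracy_` utility defined earlier: a trial is correct when the first pair (the one the participant chose) gets the highest probability, and a trial where all three probabilities are equal counts as wrong. The sketch below re-implements that logic in plain Python on made-up probabilities; `choice_accuracy_toy` is a hypothetical name.

```python
def choice_accuracy_toy(probas):
    """Fraction of trials in which index 0 has the highest probability.

    Trials where all three probabilities are identical are counted as wrong.
    """
    correct = 0
    for trial in probas:
        all_equal = max(trial) == min(trial)
        if not all_equal and trial.index(max(trial)) == 0:
            correct += 1
    return correct / len(probas)

# Made-up choice probabilities for four trials.
probas = [
    [0.6, 0.3, 0.1],        # first pair wins: correct
    [0.2, 0.5, 0.3],        # second pair wins: wrong
    [1 / 3, 1 / 3, 1 / 3],  # no preference at all: wrong
    [0.5, 0.4, 0.1],        # first pair wins: correct
]

acc = choice_accuracy_toy(probas)
print(acc)  # 2 correct trials out of 4
```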

## Inspect the learned dimensions

Let's code some visualization utilities.

In [None]:
# @title Visualization utilities

import matplotlib.pyplot as plt
from matplotlib import gridspec
import matplotlib.image as mpimg
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pandas as pd
from glob import glob
from scipy.io import loadmat
from scipy.stats import pearsonr

def plot_weights_across_time(log_dir=log_dir, n_epochs=n_epochs, n_rows=100):

    all_epochs = np.arange(2, n_epochs+2, 2)

    fig = plt.figure(figsize=(15, 5))
    gs = gridspec.GridSpec(1, len(all_epochs)+1, width_ratios=[1]*len(all_epochs) + [0.05])

    allweights = []
    for epoch in all_epochs:
        thisfile = os.path.join(log_dir, f'weights_epoch{epoch:04d}.txt')
        allweights.append(filter_nonneg(np.loadtxt(thisfile)[:n_rows]))
    allweights = np.dstack(allweights)

    vmin = allweights.min()
    vmax = allweights.max()

    axes = []
    for i, epoch in enumerate(all_epochs):
        ax = fig.add_subplot(gs[0, i])
        cax = ax.matshow(allweights[:,:,i], cmap='viridis', vmin=vmin, vmax=vmax)
        ax.set_title(f'Epoch {epoch}')

        # Remove ticks and labels
        ax.set_xticks([])
        ax.set_yticks([])
        ax.tick_params(labelbottom=False, labelleft=False)

        axes.append(ax)

    # Add a colorbar to the right of the last subplot
    cbar_ax = fig.add_subplot(gs[0, -1])
    fig.colorbar(cax, cax=cbar_ax)

    plt.tight_layout()
    plt.show()


def get_top_k(W, dim, k):

  sorted_indices = np.argsort(W[:, dim])
  topk_indices = list(sorted_indices[-k:][::-1])

  return topk_indices

def show_top_concepts(W, dim, k, concept_list, img_dir):

  fig, axes = plt.subplots(1, k, figsize=(k*5,5))
  axes = axes.flatten()

  topk_indices = get_top_k(W, dim, k)
  for ax, i in zip(axes, topk_indices):
    this_concept = concept_list[i]
    this_img = random.choice(glob(os.path.join(img_dir, this_concept, '*.jpg')))
    this_img = mpimg.imread(this_img)
    ax.imshow(this_img)
    ax.set_title(this_concept.replace('_', ' '), fontsize=22)
    ax.axis('off')

  plt.tight_layout()
  plt.show()

def check_rdm_size(rdmA, rdmB):

  assert rdmA.shape == rdmB.shape, 'RDMs must have the same size!'
  assert rdmA.shape[0] == rdmA.shape[1], 'RDMs must be square!'

def plot_rdms(realrdm, modelrdm):

  fig, axs = plt.subplots(1, 2, figsize=(12, 6))

  # Plot human RDM
  im1 = axs[0].imshow(realrdm, cmap='viridis')
  divider1 = make_axes_locatable(axs[0])
  cax1 = divider1.append_axes("right", size="5%", pad=0.1)
  fig.colorbar(im1, cax=cax1)
  axs[0].set_title('Human RDM', fontsize=22, pad=15)

  # Plot model RDM
  im2 = axs[1].imshow(modelrdm, cmap='viridis')
  divider2 = make_axes_locatable(axs[1])
  cax2 = divider2.append_axes("right", size="5%", pad=0.1)
  fig.colorbar(im2, cax=cax2)
  axs[1].set_title('Model RDM', fontsize=22, pad=15)

  plt.tight_layout()
  plt.show()

def compute_rdm_correlation(rdmA, rdmB):

  check_rdm_size(rdmA, rdmB)

  loweridx = np.tril_indices(rdmA.shape[0], k=-1)

  rdmA = rdmA[loweridx]
  rdmB = rdmB[loweridx]

  return pearsonr(rdmA, rdmB)[0]


## Plot weights across time

We plot, across training epochs, what the weights of our model look like. Can you notice some structure emerging?

In [None]:
plot_weights_across_time(n_epochs=n_epochs)

## Show dimensions

For each of our 40 model dimensions, we want to see examples of the concepts that activate it most strongly. Do the images have anything in common? You can try to make sense of what the different dimensions correspond to, and perhaps name them.

In [None]:
# Uncomment this if you need to load the model's final weights again
# (e.g. the runtime got disconnected)
#W = np.loadtxt(os.path.join(log_dir, 'weights_final.txt'))

In [None]:
# Get list of concept names and image directory path

things_concepts = pd.read_csv(os.path.join(data_dir, 'things_concepts.tsv'), sep='\t')
concept_list = list(things_concepts['uniqueID'].values)
img_dir = os.path.join(data_dir, 'images')

Pick a few dimensions and plot some example concepts/images from each:

In [None]:
dim = 0 # @param {type:"slider", min:0, max:39, step:1}
show_top_concepts(W, dim=dim, k=5, concept_list=concept_list, img_dir=img_dir)

In [None]:
dim = 29 # @param {type:"slider", min:0, max:39, step:1}
show_top_concepts(W, dim=dim, k=5, concept_list=concept_list, img_dir=img_dir)

In [None]:
dim = 18 # @param {type:"slider", min:0, max:39, step:1}
show_top_concepts(W, dim=dim, k=5, concept_list=concept_list, img_dir=img_dir)

## Compare model and human RDM

As a final test of the model we have learned, we check how well it can predict an RDM obtained from human behavioral judgments. After all, this is an RSA tutorial...

Why is this a non-trivial task? Because in our training data, only a subset of possible concept pairs was "seen" by the model. The model, then, needs to reconstruct the full RDM from a sample. Here, we test it on a 48 x 48 RDM covering 48 concepts.

**NOTE:** we are actually turning the Representational **DIS**similarity Matrix into a Representational **Similarity** Matrix, to directly compare it with the dot products (similarities) generated by our model. We still call it an RDM just to make it more confusing.

**EXERCISE:** from the onehot encodings of the 48 concepts in the RDM, and the weight matrix, compute the embeddings. Then, from the embeddings, compute the matrix of pairwise dot products.

In [None]:
# Load the RDM and turn it into an "RSM":
rdm48 = loadmat(os.path.join(data_dir, 'RDM48_triplet.mat'))['RDM48_triplet']
rdm48 = 1. - rdm48

# Load the list of 48 concepts used for this RDM, so we can feed them to the model:
words48 = [w[0][0] for w in loadmat(os.path.join(data_dir, 'words48.mat'))['words48']]
ids48 = []

for w in words48:
  w = w.replace(' ', '_')
  if w not in concept_list:
    pattern = re.compile(rf"^{re.escape(w)}\d+$")
    theseconcepts = [c for c in concept_list if pattern.match(c)]
    w = random.choice(theseconcepts)
  ids48.append(concept_list.index(w))

# Encode them as one-hot vectors:
onehot48 = np.eye(n_items)[np.array(ids48)]

# Get the embeddings by feeding the one-hot vectors
# to the network:
# YOUR CODE HERE
# embeddings48 = ...

# From the 48 x 40 embeddings matrix, get the 48 x 48 matrix
# of rowwise dot products:
# YOUR CODE HERE
# rdm48_model = ...

In [None]:
# Plot the two RDMs side by side
plot_rdms(rdm48, rdm48_model)

They look quite similar! This is promising... let's check how correlated they are.
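A note on how the correlation is computed: both matrices are symmetric, so `compute_rdm_correlation` only compares the entries strictly below the diagonal. Including the diagonal or counting each pair twice would distort the estimate. The sketch below shows the same idea in plain Python on two made-up 3 x 3 matrices; `lower_triangle` and `pearson` are hypothetical helpers standing in for `np.tril_indices` and `scipy.stats.pearsonr`.

```python
import math

def lower_triangle(matrix):
    """Values strictly below the diagonal, row by row."""
    return [matrix[i][j] for i in range(len(matrix)) for j in range(i)]

def pearson(x, y):
    """Pearson correlation coefficient between two equal-length lists."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)

# Made-up symmetric similarity matrices for 3 concepts.
human = [
    [1.0, 0.8, 0.1],
    [0.8, 1.0, 0.3],
    [0.1, 0.3, 1.0],
]
model = [
    [2.0, 1.5, 0.2],
    [1.5, 1.7, 0.4],
    [0.2, 0.4, 2.2],
]

r = pearson(lower_triangle(human), lower_triangle(model))
print(round(r, 3))
```

Note that the two matrices are on different scales (the model's dot products are not bounded by 1). The Pearson correlation is insensitive to such linear rescaling, which is why it is a convenient measure here.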

In [None]:
compute_rdm_correlation(rdm48, rdm48_model)

In conclusion, we can say that:

- SPoSE can learn to accurately predict human choices in a triplet odd-one-out task.

- It does so by generating interpretable dimensions underlying human judgments.

- It is able to generalize and predict full pairwise similarity matrices from a separate experiment.

Really not bad for only having 1854 x 40 linear weights!

