# Pairwise cosine similarities between embeddings throughout training

Fair warning, my kernel kept dying randomly so... beware.

In [2]:
# download embeds. only run this if you don't already have them
from huggingface_hub import snapshot_download
snapshot_download(repo_id="amphora/pythia-12b-weights", 
                  local_dir='pythia-12b-weights',
                  local_dir_use_symlinks=False,
                  allow_patterns="embed_only_0-29000.pkl")

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

embed_only_0-29000.pkl:   0%|          | 0.00/41.5G [00:00<?, ?B/s]

'/home/pyn/pythia-embedding-analysis/notebooks/pythia-12b-weights'

In [1]:
import torch
import torch.nn as nn
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
import matplotlib.pyplot as plt
import numpy as np

cuda


Here we actually do the pairwise processing of cos sim for everything at once using pytorch GPU which is faster than iteratively sklearn computing it.

In [2]:
def cos_sim(embed_matrix):
    """
    Takes in embedding matrix at any step. ensure it is on cuda for MAXIMUM EFFICIENCY.
    Expected that the output cosine similarity matrix is of torch.Size([50688, 50688]) and norms matrix is torch.Size([50688])
    """
    weights = embed_matrix
    norms = weights.norm(dim=1)
    embedding_matrix_norm = weights / norms[:, None]
    cosine_similarity = torch.mm(embedding_matrix_norm, embedding_matrix_norm.transpose(0,1))
    return cosine_similarity, norms

In [None]:
import pickle
from pathlib import Path

Path("./pairwise-sims/").mkdir(parents=True, exist_ok=True)

files = [
    './pythia-12b-weights/embed_only_0-29000.pkl',
    #'./pythia-12b-weights/embed_only_30000-69000.pkl',
    #'./pythia-12b-weights/embed_only_70000-109000.pkl',
    #'./pythia-12b-weights/embed_only_110000-143000.pkl'
]

for filename in files:
    with open(filename, 'rb') as in_file:
        data = pickle.load(in_file)
        for step_idx, embeddings in data.items():
            sim_matrix, norms = cos_sim(torch.from_numpy(embeddings).to(device))
            sim_matrix, norms = sim_matrix.to('cpu'), norms.to('cpu')
            outfile_name = f"./pairwise-sims/{step_idx}.pkl"
            with open(outfile_name, "wb") as outfile:
                pickle.dump((sim_matrix, norms), outfile)

Helper function to read the pickle file we dumped.

In [None]:
import pickle
import os
dir = './pairwise-sims/'

for filename in os.listdir(dir):
    filepath = os.path.join(dir, filename)
    print(filepath)
    with open(filepath, 'rb') as infile:
        sim_matrix, norms = pickle.load(infile)
        print(sim_matrix, norms)