We're going to do a visual comparison of the ProtHash and ESMC embeddings as a sanity check. If ProtHash successfully learned from its ESMC teacher, we expect the two embeddings to produce similar 2D plots, with the same overall cluster structure. The two layouts may still end up rotated, flipped, or shifted relative to each other, since t-SNE doesn't preserve absolute positions. Admittedly, this isn't an apples-to-apples comparison: we're taking two high-dimensional embeddings, likely of different dimensions, and squashing both down to just two dimensions. But as a sanity check, it's good enough.
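The 2D plots are a qualitative check. A quantitative complement that also copes with embeddings of different dimensions is linear centered kernel alignment (CKA), which scores how similarly two representations arrange the same set of samples, independent of rotation and uniform scaling. This isn't part of the original notebook; it's a swapped-in technique, sketched here in NumPy and demonstrated on synthetic data rather than real ProtHash or ESMC outputs.

```python
import numpy as np


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """Linear CKA between two embedding matrices of shape (n, d1) and (n, d2).

    Row i of x and row i of y must describe the same sample; d1 and d2 may differ.
    Returns a score in [0, 1], where 1 means identical geometry up to an
    orthogonal transform and uniform scaling.
    """
    # Center each feature (column) so constant offsets don't count as similarity.
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)

    # Cross-covariance energy, normalized by each matrix's self-covariance energy.
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return float(cross / (norm_x * norm_y))


rng = np.random.default_rng(0)
teacher_like = rng.normal(size=(500, 64))

# Same points in a rotated/reflected, rescaled basis: the geometry is unchanged.
rotation, _ = np.linalg.qr(rng.normal(size=(64, 64)))
rotated = 3.0 * teacher_like @ rotation

# Unrelated points with a different dimensionality.
unrelated = rng.normal(size=(500, 16))

print(round(linear_cka(teacher_like, teacher_like), 6))  # 1.0
print(round(linear_cka(teacher_like, rotated), 6))       # 1.0
print(linear_cka(teacher_like, unrelated))               # much lower
```

Once the embeddings are computed further down, calling linear_cka(student_embeddings.numpy(), teacher_embeddings.numpy()) would give a single number that should sit close to 1 if the student reproduces the teacher's geometry.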

Let's kick this party off by defining some configuration variables.

In [1]:
from torch.cuda import is_available as cuda_is_available
from torch.backends.mps import is_available as mps_is_available

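# Sequence length bounds passed to the SwissProt dataset.
min_sequence_length=1
max_sequence_length=2048

# Number of SwissProt sequences to embed, and how many to embed per batch.
num_samples=1000
batch_size=32

# Pretrained ESMC model used as the distillation teacher.
teacher_model_name="esmc_300m"

# Path to the ProtHash (student) training checkpoint.
checkpoint_path="checkpoints/checkpoint.pt"

# Prefer CUDA, then Apple Silicon (MPS), then fall back to the CPU.
device="cuda" if cuda_is_available() else "mps" if mps_is_available() else "cpu"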

Then we'll load the ESM protein sequence tokenizer and the SwissProt dataset, keep the first num_samples sequences, and batch them with a DataLoader.

In [2]:
from esm.tokenization import EsmSequenceTokenizer

from data import SwissProt

from torch.utils.data import Subset, DataLoader

tokenizer = EsmSequenceTokenizer()

dataset = SwissProt(
    tokenizer=tokenizer,
    min_sequence_length=min_sequence_length,
    max_sequence_length=max_sequence_length,
)

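# Keep only the first num_samples sequences so the comparison stays quick.
dataset = Subset(dataset, range(num_samples))

# Subset wraps the original dataset, so reach through .dataset for its collate function.
# shuffle=False keeps the embeddings in dataset order.
dataloader = DataLoader(
    dataset, batch_size=batch_size, shuffle=False, collate_fn=dataset.dataset.collate_pad_right
)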
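The notebook doesn't show SwissProt.collate_pad_right, but its name suggests it pads each batch on the right. The sketch below is a hypothetical stand-in, assuming every dataset item is a 1-D tensor of token IDs. Right padding matters for this notebook because it keeps each sequence's first token at position 0, which is the position the embedding step reads from the teacher. In the real notebook the pad ID would presumably come from tokenizer.pad_token_id; the value 1 below is just for illustration.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_pad_right(batch: list[torch.Tensor], pad_token_id: int) -> torch.Tensor:
    """Hypothetical collate: stack 1-D token tensors into (batch, max_len).

    Shorter sequences are padded on the right with pad_token_id.
    """
    return pad_sequence(batch, batch_first=True, padding_value=pad_token_id)


# Toy batch of token IDs (values chosen arbitrarily for illustration).
batch = [
    torch.tensor([0, 5, 6, 7, 2]),
    torch.tensor([0, 8, 2]),
]
padded = collate_pad_right(batch, pad_token_id=1)
print(padded)
# tensor([[0, 5, 6, 7, 2],
#         [0, 8, 2, 1, 1]])
```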

Next we'll load the teacher model, ESMC, from its pretrained weights.

In [3]:
from esm.models.esmc import ESMC

teacher = ESMC.from_pretrained(teacher_model_name)

teacher = teacher.to(device)

teacher.eval()

print("Teacher model loaded successfully")


Teacher model loaded successfully


Now that you've made it this far, it's time for some fun. Let's get down and dirty and load one of our ProtHash model checkpoints into memory.

In [None]:
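import torch

from src.prothash.model import ProtHash

checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

student = ProtHash(**checkpoint["model_args"])

# The checkpoint was saved with the distillation adapter head ("head.*" keys)
# still attached, but a freshly constructed ProtHash has no head, so strict
# loading fails with "Unexpected key(s) in state_dict". We only want the
# backbone here, so drop the head weights before loading. Since the head is
# never loaded, there's also no adapter head to remove afterwards.
state_dict = {
    key: value for key, value in checkpoint["model"].items() if not key.startswith("head.")
}

student.load_state_dict(state_dict)

student = student.to(device)

student.eval()

print("Model checkpoint loaded successfully")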
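When a checkpoint was saved with an adapter head attached but the model being loaded was built without one, strict loading raises an "Unexpected key(s) in state_dict" error. Passing strict=False to load_state_dict still reports which keys didn't match, so you can tolerate the head while failing loudly on anything else. The toy Backbone class and its head attribute below are hypothetical stand-ins, not the real ProtHash model.

```python
import torch
from torch import nn


class Backbone(nn.Module):
    """Toy stand-in for a student model; not the real ProtHash class."""

    def __init__(self, with_head: bool) -> None:
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        if with_head:
            self.head = nn.Linear(8, 4)  # hypothetical adapter head


# A checkpoint saved while the adapter head was still attached.
saved = Backbone(with_head=True).state_dict()

# A fresh model constructed without the adapter head.
fresh = Backbone(with_head=False)

result = fresh.load_state_dict(saved, strict=False)
print(result.missing_keys)     # []
print(result.unexpected_keys)  # ['head.weight', 'head.bias']

# Fail loudly if anything other than the adapter head is mismatched.
assert not result.missing_keys
assert all(key.startswith("head.") for key in result.unexpected_keys)
```

Dropping the head keys before a strict load and inspecting the result of a non-strict load are equivalent here; the second is handy when you want the mismatch report for debugging.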

You've made it this far, so there's no turning back. It's literally life or death from here on out. Next, we'll embed the same subset of SwissProt with both models. For ESMC, we take the last hidden layer's vector at position 0 as the sequence embedding. I'll know if you turn back now.

In [None]:
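student_embeddings = []
teacher_embeddings = []

# Inference only: disable autograd for both models so activations aren't kept around.
with torch.no_grad():
    for x in dataloader:
        x = x.to(device)

        # ESMC: last hidden layer, vector at position 0 (typically the beginning-of-sequence token).
        out_teacher = teacher.forward(x)
        y_teacher = out_teacher.hidden_states[-1][:, 0, :]

        # ProtHash is assumed to return one embedding per sequence, shape (batch, dim).
        y_student = student.forward(x)

        # Cast to float32 on the CPU so .numpy() works later even if a model runs in reduced precision.
        student_embeddings.append(y_student.float().cpu())
        teacher_embeddings.append(y_teacher.float().cpu())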

assert len(student_embeddings) == len(teacher_embeddings)

student_embeddings = torch.cat(student_embeddings, dim=0)
teacher_embeddings = torch.cat(teacher_embeddings, dim=0)
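The cell above summarizes each ESMC sequence by the vector at position 0 of the last hidden layer. A common alternative, which the notebook does not use, is masked mean pooling: averaging the hidden states over every real token while ignoring padding. The sketch below contrasts the two on a toy tensor; the token IDs and the pad ID of 1 are illustrative assumptions, and in the notebook tokenizer.pad_token_id would be the authoritative value.

```python
import torch


def first_token_pool(hidden: torch.Tensor) -> torch.Tensor:
    """Take the vector at position 0 for each sequence: (batch, seq, dim) -> (batch, dim)."""
    return hidden[:, 0, :]


def masked_mean_pool(hidden: torch.Tensor, tokens: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Average hidden states over non-padding positions only."""
    mask = (tokens != pad_token_id).unsqueeze(-1).to(hidden.dtype)  # (batch, seq, 1)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)  # avoid dividing by zero for all-pad rows
    return summed / counts


# Toy example: batch of 2, sequence length 3, hidden size 2.
tokens = torch.tensor([[0, 5, 2], [0, 2, 1]])  # second row ends in one pad token (ID 1)
hidden = torch.tensor(
    [
        [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
        [[4.0, 0.0], [6.0, 0.0], [100.0, 100.0]],  # last position is padding
    ]
)

print(first_token_pool(hidden))
# tensor([[1., 1.],
#         [4., 0.]])
print(masked_mean_pool(hidden, tokens, pad_token_id=1))
# tensor([[2., 2.],
#         [5., 0.]])
```

Whichever pooling you pick, the padding mask matters: without it, the large values at the padded position in the second row would leak into its mean.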

In [None]:
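from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

s = student_embeddings.numpy()
t = teacher_embeddings.numpy()

# Reduce each embedding to 32 dimensions with its own PCA first; this trims noise
# and speeds up t-SNE.
s = PCA(n_components=32).fit_transform(s)
t = PCA(n_components=32).fit_transform(t)

# Then project each down to 2D. A fixed random_state makes the plots reproducible.
s = TSNE(n_components=2, random_state=0).fit_transform(s)
t = TSNE(n_components=2, random_state=0).fit_transform(t)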
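Because t-SNE layouts are only meaningful up to rotation, reflection, and run-to-run variation, eyeballing two scatter plots can be misleading. One hedged, quantitative alternative (not part of the original notebook) is k-nearest-neighbor overlap: for each protein, check how many of its k nearest neighbors in the student space are also among its k nearest neighbors in the teacher space. It works on the original or PCA-reduced embeddings, so the two spaces can have different dimensions. The sketch below uses scikit-learn's NearestNeighbors on synthetic data; k=10 is an arbitrary choice.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors


def knn_overlap(a: np.ndarray, b: np.ndarray, k: int = 10) -> float:
    """Mean fraction of shared k nearest neighbors between two embeddings.

    Row i of a and row i of b must describe the same item; the two
    embeddings may have different dimensionalities.
    """
    # Calling kneighbors() without X returns neighbors of the fitted points,
    # excluding each point itself.
    idx_a = NearestNeighbors(n_neighbors=k).fit(a).kneighbors(return_distance=False)
    idx_b = NearestNeighbors(n_neighbors=k).fit(b).kneighbors(return_distance=False)

    shared = [len(set(row_a) & set(row_b)) for row_a, row_b in zip(idx_a, idx_b)]
    return float(np.mean(shared) / k)


rng = np.random.default_rng(0)
teacher_like = rng.normal(size=(300, 64))

# A "student" that is a noisy linear projection of the teacher into 32 dims.
projection = rng.normal(size=(64, 32))
student_like = teacher_like @ projection + 0.1 * rng.normal(size=(300, 32))

# Unrelated points: overlap should sit near chance, about k / (n - 1).
unrelated = rng.normal(size=(300, 32))

print(knn_overlap(teacher_like, teacher_like))  # 1.0
print(knn_overlap(teacher_like, student_like))
print(knn_overlap(teacher_like, unrelated))
```

Applied to the notebook, knn_overlap(student_embeddings.numpy(), teacher_embeddings.numpy()) would report how much local neighborhood structure the student preserves, which is roughly what the side-by-side t-SNE plots try to show visually.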

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

axes[0].scatter(s[:, 0], s[:, 1], s=5, alpha=0.7)
axes[0].set_title('Student embeddings (2D)')
axes[1].scatter(t[:, 0], t[:, 1], s=5, alpha=0.7, color='orange')
axes[1].set_title('Teacher embeddings (2D)')

for ax in axes:
    ax.set_xlabel('dim 1')
    ax.set_ylabel('dim 2')

plt.tight_layout()
plt.show()