In [1]:
import os
import torch
import pandas as pd
from esm import pretrained, Alphabet, BatchConverter

In [5]:
from typing import List, Tuple, Optional, Dict, NamedTuple, Union, Callable
import itertools
import os
import string
from pathlib import Path

import numpy as np
import torch
from scipy.spatial.distance import squareform, pdist, cdist
import matplotlib.pyplot as plt
import matplotlib as mpl
from Bio import SeqIO
from tqdm import tqdm
import pandas as pd

import esm

# Inference-only notebook: switch autograd off globally so that forward passes
# in later cells do not record a computation graph.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7faa713a2f40>

In [9]:
# Load the pretrained `esm2_t6_8M_UR50D` checkpoint together with its alphabet.
esm2, esm2_alphabet = esm.pretrained.esm2_t6_8M_UR50D()
# Put the model in evaluation mode; `eval()` acts in place and returns the module.
esm2.eval()
# The batch converter turns (label, sequence) pairs into token tensors.
esm2_batch_converter = esm2_alphabet.get_batch_converter()

In [8]:
# FASTA file to score; the results CSV is written into the same directory.
input_fasta = "/Users/davidm/182-final-proj/data/uniref50.fasta"
_data_dir, _ = os.path.split(input_fasta)
output_csv = os.path.join(_data_dir, "uniprot50_perplexity.csv")

In [13]:
# Cap used by the FASTA-reading loop below to keep a run small.
MAX_SEQUENCES: int = 10
# Device the token tensors are moved to before the forward pass.
# NOTE(review): the visible code never moves `esm2` itself to this device, so a
# non-CPU value presumably also requires `esm2.to(DEVICE)` — confirm before changing.
DEVICE: str = "cpu"

In [None]:

# ========== Perplexity Function ==========
def compute_perplexity(seq: str, masked: bool = False, batch_size: int = 16) -> float:
    """Return the (pseudo-)perplexity of a protein sequence under ESM-2.

    ESM-2 is a masked language model: the logits at position ``i`` score the
    token at position ``i`` itself, not the next token.  The log-probabilities
    are therefore read at the residue's own position (no causal shift), and the
    BOS/EOS/padding tokens added by the batch converter are excluded from the
    average.

    Args:
        seq: Amino-acid sequence to score.
        masked: If False (default), run a single forward pass on the unmasked
            sequence.  This is cheap but only an approximation, because the
            model sees the residue it is scoring.  If True, mask one residue at
            a time and score it from its context (pseudo-perplexity); this
            costs one forward row per residue.
        batch_size: Number of masked copies per forward pass when
            ``masked=True``.  Ignored otherwise.

    Returns:
        ``exp`` of the mean negative log-probability over the residues.

    Raises:
        ValueError: If ``seq`` yields no residue tokens or ``batch_size < 1``.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    _, _, tokens = esm2_batch_converter([("protein", seq)])
    # NOTE(review): assumes `esm2` already lives on DEVICE; nothing here moves it.
    tokens = tokens.to(DEVICE)
    row = tokens[0]

    # Keep only real residues; special tokens must not contribute to the score.
    is_residue = (
        (row != esm2_alphabet.cls_idx)
        & (row != esm2_alphabet.eos_idx)
        & (row != esm2_alphabet.padding_idx)
    )
    positions = is_residue.nonzero(as_tuple=True)[0]
    if positions.numel() == 0:
        raise ValueError("Cannot compute perplexity of an empty sequence")
    targets = row[positions]

    with torch.no_grad():
        if masked:
            per_chunk = []
            for start in range(0, positions.numel(), batch_size):
                chunk = positions[start:start + batch_size]
                rows = torch.arange(chunk.numel(), device=tokens.device)
                # One copy of the sequence per residue in the chunk, each with
                # exactly that residue replaced by the mask token.
                batch = tokens.repeat(chunk.numel(), 1)
                batch[rows, chunk] = esm2_alphabet.mask_idx
                per_chunk.append(esm2(batch)["logits"][rows, chunk])
            logits = torch.cat(per_chunk)
        else:
            logits = esm2(tokens)["logits"][0, positions]

    # logits: (num_residues, vocab); pick the log-probability of the true residue.
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    token_log_probs = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
    perplexity = torch.exp(-token_log_probs.mean()).item()
    return perplexity

# ========== Read Sequences ==========
from Bio import SeqIO
sequences = []
perplexities = []
ids = []

# Sequences actually handed to the model.  The cap is applied to this counter
# rather than to the raw record index, so records skipped by the "X" filter do
# not eat into the MAX_SEQUENCES budget.  Failed sequences still count, which
# keeps a systematically failing model from scanning the whole FASTA file.
n_attempted = 0

with open(input_fasta) as handle:
    for i, record in enumerate(SeqIO.parse(handle, "fasta")):
        if n_attempted >= MAX_SEQUENCES:
            break
        seq = str(record.seq)
        # Skip sequences containing the ambiguous residue code "X".
        if "X" in seq:
            continue
        n_attempted += 1
        try:
            ppl = compute_perplexity(seq)
        except Exception as e:
            # Best effort: report the failure and move on to the next record.
            print(f"Error on sequence {record.id}: {e}")
            continue
        # Append only after a successful score so the three lists stay aligned.
        sequences.append(seq)
        perplexities.append(ppl)
        ids.append(record.id)
        print(f"[{i}] {record.id}: Perplexity = {ppl:.2f}")

# ========== Save Results ==========
# One row per scored sequence; the three lists are filled in lockstep by the loop above.
_result_columns = {"sequence_id": ids, "sequence": sequences, "perplexity": perplexities}
df = pd.DataFrame(_result_columns)

# Persist without the positional index so the CSV holds only the three columns.
df.to_csv(output_csv, index=False)
print(f"\nSaved perplexity CSV to: {output_csv}")
