In [None]:
import pathlib
import torch
from esm import Alphabet, FastaBatchedDataset, pretrained
import random
from collections import Counter
from tqdm import tqdm

import torch
!pip install -q fair-esm

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import esm
import pathlib
import torch

from esm import FastaBatchedDataset, pretrained

def extract_embeddings(model_name, fasta_file, output_dir, tokens_per_batch=4096, seq_length=1022, repr_layers=(33,)):
    """Embed every sequence of a FASTA file with an ESM model and save per-sequence mean representations.

    One ``<entry_id>.pt`` file is written per sequence into ``output_dir``. Each file holds
    ``{"entry_id": str, "mean_representations": {layer: Tensor}}``, where ``entry_id`` is the
    first whitespace-delimited token of the FASTA header. The tensor is the mean over residue
    positions 1..L of that layer's output; position 0 (the token before the residues) is excluded.

    Parameters
    ----------
    model_name : str
        Checkpoint name passed to ``esm.pretrained.load_model_and_alphabet``
        (e.g. ``'esm2_t33_650M_UR50D'``).
    fasta_file : str or pathlib.Path
        Input FASTA file.
    output_dir : str or pathlib.Path
        Destination directory; created (with parents) if it does not exist.
    tokens_per_batch : int
        Token budget per batch, forwarded to ``FastaBatchedDataset.get_batch_indices``.
    seq_length : int
        Sequences longer than this are truncated by the batch converter before embedding.
    repr_layers : iterable of int
        Layers to average and save. Negative values count back from the last layer
        (``-1`` is the final layer). The keys stored in the output files are the
        resolved, non-negative layer numbers.

    Raises
    ------
    ValueError
        If a requested layer is outside ``[-(num_layers + 1), num_layers]``.
    """
    output_dir = pathlib.Path(output_dir)  # accept plain strings as well as Path objects

    model, alphabet = pretrained.load_model_and_alphabet(model_name)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Resolve negative layer indices the same way esm's extract script does; without this,
    # e.g. -1 would never match a layer and the saved representation dict would be empty.
    num_layers = model.num_layers
    bad_layers = [layer for layer in repr_layers if not -(num_layers + 1) <= layer <= num_layers]
    if bad_layers:
        raise ValueError(f"repr_layers {bad_layers} out of range for a model with {num_layers} layers")
    repr_layers = [(layer + num_layers + 1) % (num_layers + 1) for layer in repr_layers]

    dataset = FastaBatchedDataset.from_file(fasta_file)
    batches = dataset.get_batch_indices(tokens_per_batch, extra_toks_per_seq=1)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(seq_length),
        batch_sampler=batches,
    )

    output_dir.mkdir(parents=True, exist_ok=True)

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):

            print(f'Processing batch {batch_idx + 1} of {len(batches)}')

            toks = toks.to(device=device, non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=False)

            # Only the hidden representations are needed; logits are not copied off the device.
            representations = {layer: t.to(device="cpu") for layer, t in out["representations"].items()}

            for i, (label, seq) in enumerate(zip(labels, strs)):
                entry_id = label.split()[0]
                truncate_len = min(seq_length, len(seq))

                # Slice 1 : truncate_len + 1 skips position 0 and any trailing/padding tokens.
                result = {
                    "entry_id": entry_id,
                    "mean_representations": {
                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    },
                }

                torch.save(result, output_dir / f"{entry_id}.pt")

# Embed the train and test splits with the same ESM-2 checkpoint.
model_name = 'esm2_t33_650M_UR50D'
BLAST_DIR = pathlib.Path('../../first_run/BLAST')
output_dir1 = pathlib.Path('train_embeddings')
output_dir2 = pathlib.Path('test_embeddings')

# NOTE(review): the loading cell below reads embeddings from
# '../../../Downloads/embeddings/{train,test}_embeddings_esm', not from these
# directories. Confirm the files were moved there, or align the paths, so that
# Restart & Run All works.
for fasta_file, output_dir in [
    (BLAST_DIR / 'train_all_toxinpred3.fa', output_dir1),
    (BLAST_DIR / 'test_all_toxinpred3.fa', output_dir2),
]:
    extract_embeddings(model_name, fasta_file, output_dir)

In [1]:
### Convert saved ESM embeddings into train/test feature matrices

In [None]:
#training

TRAIN_FASTA_PATH = "../../../Downloads/TOXIC/first_run/BLAST/train_all_toxinpred3.fa" # Path to P62593.fasta
TRAIN_EMB_PATH = "../../../Downloads/embeddings/train_embeddings_esm" # Path to directory of embeddings for P62593.fasta
EMB_LAYER = 33

Xs_train = []
for header, _seq in esm.data.read_fasta(TRAIN_FASTA_PATH):
    fn = f'{TRAIN_EMB_PATH}/{header}.pt'
    embs = torch.load(fn)
    Xs_train.append(embs['mean_representations'][EMB_LAYER])
Xs_train = torch.stack(Xs_train, dim=0).numpy()

ys_train = pd.read_csv('train_seq1.csv')#['cnrci']
print(len(ys_train))
print(Xs_train.shape)

#testing

TEST_FASTA_PATH = "../../../Downloads/TOXIC/first_run/BLAST/test_all_toxinpred3.fa" # Path to P62593.fasta
TEST_EMB_PATH = "../../../Downloads/embeddings/test_embeddings_esm" # Path to directory of embeddings for P62593.fasta
EMB_LAYER = 33

Xs_test = []
for header, _seq in esm.data.read_fasta(TEST_FASTA_PATH):
    fn = f'{TEST_EMB_PATH}/{header}.pt'
    embs = torch.load(fn)
    Xs_test.append(embs['mean_representations'][EMB_LAYER])
Xs_test = torch.stack(Xs_test, dim=0).numpy()

ys_test = pd.read_csv('test_seq.csv')#['cnrci']
print(len(ys_test))
print(Xs_test.shape)