In [2]:
from transformers import AutoTokenizer, EsmModel
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
import os
import torch.nn as nn
import torch.optim as optim

model_name = "esm2_t6_8M_UR50D"
# os.mkdir(f"embeddings/{model_name}")
max_seq_len = 2000
max_tax_len = 40
num_taxonomy_ids = 4118  # Example: Assuming 14,680 possible taxonomy classes
batch_size = 64

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
df = pd.read_csv("taxonomy_data.csv")[:10000]
lens = df['Sequence'].apply(len)
df = df[lens < max_seq_len]
print(df.shape)

(9640, 3)


In [4]:
def list_encoder(s):
    return [int(i) for i in s.replace("[", "").replace("]", "").split(", ")]

# Example list of taxonomy IDs for multiple sequences
taxonomy_ids_list = [list_encoder(tax_id_str) for tax_id_str in df['tree trace'].values]

set_list = [set(t) for t in taxonomy_ids_list]
union_set = set().union(*set_list)
print("Set Len: ", len(union_set))

# Flatten the list of taxonomy IDs and get unique taxonomy IDs
all_taxonomy_ids = set([tax_id for sublist in taxonomy_ids_list for tax_id in sublist])
# Create a mapping from taxonomy ID to index
taxonomy_id_to_idx = {tax_id: idx for idx, tax_id in enumerate(all_taxonomy_ids)}
taxonomy_idx_to_id = {idx: tax_id for idx, tax_id in enumerate(all_taxonomy_ids)}

# Apply the mapping to each list of taxonomy IDs
mapped_taxonomy_ids_list = [[taxonomy_id_to_idx[tax_id] for tax_id in tax_ids] + [0 for _ in range(max_tax_len - len(tax_ids))] for tax_ids in taxonomy_ids_list]

# Character vocabulary for protein sequences (20 amino acids + 1 padding)
vocab = "ACDEFGHIKLMNPQRSTVWY"
char_to_idx = {char: idx + 1 for idx, char in enumerate(vocab)}  # Start index from 1 for padding

# Sequence encoder: Convert the protein sequence into integers
def encode_sequence(sequence):
    return [char_to_idx.get(char, 0) for char in sequence] + [0 for _ in range(max_seq_len - len(sequence))]  # 0 for unknown characters or padding 

def encode_taxonomy(taxonomy):
    return [1 if _ in taxonomy else 0 for _ in range(num_taxonomy_ids)]

Set Len:  4118


In [5]:
class ProteinDataset(Dataset):
    def __init__(self, sequences, taxonomy_ids):
        self.sequences = sequences
        self.taxonomy_ids = taxonomy_ids

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        taxonomy = self.taxonomy_ids[idx]

        # Encode sequence and taxonomy (example encoding for demonstration)
        sequence_encoded = torch.tensor(encode_sequence(sequence), dtype=torch.long)
        taxonomy_encoded = torch.tensor(encode_taxonomy(taxonomy), dtype=torch.long)

        return sequence_encoded, taxonomy_encoded

In [6]:
# Split the dataset into training, validation, and test sets
train_sequences, test_sequences, train_taxonomy, test_taxonomy = train_test_split(df['Sequence'].values, mapped_taxonomy_ids_list, test_size=0.2, random_state=42)
train_sequences, val_sequences, train_taxonomy, val_taxonomy = train_test_split(train_sequences, train_taxonomy, test_size=0.25, random_state=42)

# Create Dataset objects
train_dataset = ProteinDataset(train_sequences, train_taxonomy)
val_dataset = ProteinDataset(val_sequences, val_taxonomy)
test_dataset = ProteinDataset(test_sequences, test_taxonomy)

# Create DataLoader objects for batching
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Example usage of the DataLoader
for sequences_batch, taxonomy_batch in train_loader:
    print(f"Sequences batch shape: {sequences_batch.shape}")
    print(f"Taxonomy batch shape: {taxonomy_batch.shape}")
    break  # Print one batch and exit

Sequences batch shape: torch.Size([64, 2000])
Taxonomy batch shape: torch.Size([64, 4118])


In [None]:
class ProteinDataset(Dataset):
    def __init__(self, dataframe):
        self.data = dataframe

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data.iloc[idx]['Sequence']
        entry = self.data.iloc[idx]['Entry']
        return entry, sequence

# Instantiate dataset and dataloader with batch size
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
dataset = ProteinDataset(train_df)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

tokenizer = AutoTokenizer.from_pretrained(f"facebook/{model_name}")
model = EsmModel.from_pretrained(f"facebook/{model_name}").to("cuda:1")

# Iterate through batches of data
for batch_idx, batch in enumerate(dataloader):
    entries, sequences = batch
    
    # Check if the embeddings already exist, and if not, process them
    # if f"{batch_idx}.pt" not in os.listdir(f"embeddings/{model_name}"):
    print(f"Processing: {batch_idx}")

    # Tokenize the batch of sequences and move inputs to GPU
    inputs = tokenizer(sequences, return_tensors="pt", padding=True).to("cuda:1")

    # Forward pass through the model
    outputs = model(**inputs).last_hidden_state

    # Compute mean across the sequence dimension (or any other pooling method)
    output_embeddings = outputs.mean(dim=1).cpu()  # Move back to CPU
    print(output_embeddings.shape)

    # Save the embeddings
    torch.save(output_embeddings, f"embeddings/{model_name}/{batch_idx}.pt")
    
    # Clear cache after each batch to avoid memory overflow
    torch.cuda.empty_cache()

    print(f"Batch {batch_idx + 1}/{len(dataloader)} processed.")

In [7]:
# Create the model with an Attention Layer
class SimpleAttentionClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_taxonomy_ids):
        super(SimpleAttentionClassifier, self).__init__()
        # Embedding layer for sequences
        self.sequence_embedding = nn.Embedding(vocab_size, embedding_dim)
        # Attention mechanism
        self.attention = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=4, batch_first=True)
        # Fully connected layer for predicting taxonomy
        self.fc = nn.Linear(embedding_dim, num_taxonomy_ids)
        
        self.sigmoid = nn.Sigmoid()

    def forward(self, sequences):
        # Embed the input sequences
        embedded_seq = self.sequence_embedding(sequences)  # (batch_size, seq_len, embed_dim)
        
        # Attention mechanism (self-attention here)
        attn_output, _ = self.attention(embedded_seq, embedded_seq, embedded_seq)  # (batch_size, seq_len, embed_dim)
        
        # Mean pooling across the sequence length dimension
        attn_output = attn_output.mean(dim=1)  # (batch_size, embed_dim)
        
        # Pass through a fully connected layer to predict taxonomy IDs
        output = self.fc(attn_output)  # (batch_size, num_taxonomy_ids)
        
        return output

In [40]:
# Hyperparameters
vocab_size = len(vocab) + 1  # +1 for padding
embedding_dim = 16
hidden_dim = 64
num_taxonomy_ids = 4118  # Example: Assuming 14,680 possible taxonomy classes
num_epochs = 10
device = "cuda:1"

# Initialize the model, optimizer, and loss function
model = SimpleAttentionClassifier(vocab_size=vocab_size, embedding_dim=embedding_dim, hidden_dim=hidden_dim, num_taxonomy_ids=num_taxonomy_ids).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()  # Multi-label classification requires BCEWithLogitsLoss

In [41]:
from tqdm import tqdm

for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    running_loss = 0.0

    for batch_idx, (sequences, taxonomy_ids) in tqdm(enumerate(train_loader)):
        # print("Running Batch idx:", batch_idx, len(train_loader))
        
        sequences = sequences.to(device)
        taxonomy_ids = taxonomy_ids.to(device)

        outputs = model(sequences)
        
        # Calculate the loss
        loss = criterion(outputs, taxonomy_ids.float())

        # Backpropagation: Zero the gradients, compute the backward pass, and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Track the loss
        running_loss += loss.item()

    # Print loss for this epoch
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {epoch_loss:.4f}")
    
    # Evaluate the model on the test set
    val_loss, val_accuracy, val_f1 = evaluate(model, val_loader, criterion)
    print(f"val Loss: {val_loss:.4f}, val Accuracy: {val_accuracy:.4f}, val F1 Score: {val_f1:.4f}")

91it [00:15,  5.85it/s]


Epoch [1/10], Train Loss: 0.3698
2 33154 33208 33213 131567 2759 6072
2 131567 2157 2258 2259 2260 183968 28890 53953 70601
val Loss: 0.0319, val Accuracy: 0.9949, val F1 Score: 0.3289


91it [00:15,  6.00it/s]


Epoch [2/10], Train Loss: 0.0202
2 33154 33208 33213 131567 2759 6072
2 131567 2157 2258 2259 2260 183968 28890 53953 70601
val Loss: 0.0168, val Accuracy: 0.9949, val F1 Score: 0.3376


91it [00:15,  6.03it/s]


Epoch [3/10], Train Loss: 0.0162
2 33154 33208 33213 131567 2759 6072
2 131567 2157 2258 2259 2260 183968 28890 53953 70601
val Loss: 0.0158, val Accuracy: 0.9949, val F1 Score: 0.3406


91it [00:14,  6.14it/s]


Epoch [4/10], Train Loss: 0.0156
2 33154 33208 33213 131567 2759 6072
2 131567 2157 2258 2259 2260 183968 28890 53953 70601
val Loss: 0.0155, val Accuracy: 0.9949, val F1 Score: 0.3408


91it [00:15,  6.03it/s]


Epoch [5/10], Train Loss: 0.0153
2 33154 33208 33213 131567 2759 6072
2 131567 2157 2258 2259 2260 183968 28890 53953 70601
val Loss: 0.0154, val Accuracy: 0.9949, val F1 Score: 0.3381


91it [00:15,  5.94it/s]


Epoch [6/10], Train Loss: 0.0152
2 33154 33208 33213 131567 2759 6072
2 131567 2157 2258 2259 2260 183968 28890 53953 70601
val Loss: 0.0153, val Accuracy: 0.9949, val F1 Score: 0.3420


91it [00:15,  5.97it/s]


Epoch [7/10], Train Loss: 0.0151
2 33154 33208 33213 131567 2759 6072
2 131567 2157 2258 2259 2260 183968 28890 53953 70601
val Loss: 0.0152, val Accuracy: 0.9949, val F1 Score: 0.3415


91it [00:15,  6.05it/s]


Epoch [8/10], Train Loss: 0.0151
2 33154 33208 33213 131567 33511 2759 117570 117571 6072 89593 7711 7742 7776 32523
2 131567 2157 2258 2259 2260 183968 28890 53953 70601
val Loss: 0.0153, val Accuracy: 0.9951, val F1 Score: 0.4784


91it [00:15,  6.06it/s]


Epoch [9/10], Train Loss: 0.0150
2 33154 33208 33213 131567 2759 6072
2 131567 2157 2258 2259 2260 183968 28890 53953 70601
val Loss: 0.0152, val Accuracy: 0.9949, val F1 Score: 0.3414


91it [00:15,  5.97it/s]


Epoch [10/10], Train Loss: 0.0150
2 33154 33208 33213 131567 2759 6072
2 131567 2157 2258 2259 2260 183968 28890 53953 70601
val Loss: 0.0152, val Accuracy: 0.9949, val F1 Score: 0.3422


In [38]:
from sklearn.metrics import accuracy_score, f1_score

def evaluate(model, test_loader, criterion, device='cuda:1'):
    model.eval()  # Set model to evaluation mode
    running_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():  # Disable gradient computation during evaluation
        for batch_idx, (sequences, taxonomy_ids) in enumerate(test_loader):
            sequences = sequences.to(device)
            taxonomy_ids = taxonomy_ids.to(device)

            outputs = model(sequences)

            # Calculate loss
            loss = criterion(outputs, taxonomy_ids.float())
            running_loss += loss.item()
            
            # Convert model outputs to binary predictions (e.g., threshold = 0.5)
            preds = torch.sigmoid(outputs) > 0.6  # Binary predictions

            all_preds.append(preds.cpu())
            all_labels.append(taxonomy_ids.cpu())
    
    # Concatenate all batches into single tensors
    all_preds = torch.cat(all_preds).int()
    # print(all_preds[0])
    all_labels = torch.cat(all_labels)
    # print(all_labels[0])
    
    for p, l in zip(all_preds, all_labels):
        print(*[taxonomy_idx_to_id[i] for i, p_ in enumerate(p) if p_])
        print(*[taxonomy_idx_to_id[i] for i, l_ in enumerate(l) if l_])
        break
    
    # Compute evaluation metrics (example: accuracy, F1 score)
    accuracy = accuracy_score(all_labels.numpy(), all_preds.numpy())
    
    accuracies = [accuracy_score(p, l) for p, l in zip(all_preds, all_labels)]
    accuracy = np.mean(accuracies)
    
    f1 = f1_score(all_labels.numpy(), all_preds.numpy(), average='micro')  # F1-score for multi-label classification
    avg_loss = running_loss / len(test_loader)
    
    return avg_loss, accuracy, f1