In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, EsmForSequenceClassification, set_seed
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm import tqdm
import numpy as np
import ast
import pandas as pd
import random
import re
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import os
from transformers import EsmModel, AutoTokenizer
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # checks if CUDA-compatible GPU is available, if yes moves to CUDA
dataset_path = "drive/MyDrive/takehome_dataset.csv"
df = pd.read_csv(dataset_path)
print(df["classification"].unique())
df.head()

['HYDROLASE' 'IMMUNE SYSTEM' 'ISOMERASE' 'LYASE' 'OXIDOREDUCTASE'
 'SIGNALING PROTEIN' 'TRANSCRIPTION' 'TRANSFERASE' 'TRANSPORT PROTEIN'
 'VIRAL PROTEIN']


Unnamed: 0.1,Unnamed: 0,structureId,sequence,classification
0,0,2XJF,MSTSWSDRLQNAADMPANMDKHALKKYRREAYHRVFVNRSLAMEKI...,HYDROLASE
1,1,4XXO,MEASPASGPRHLMDPHIFTSNFNNGIGRHKTYLCYEVERLDNGTSV...,HYDROLASE
2,2,4UOJ,MRSMRALSLSLSIFAGAAVATTESTGKASIHDLALQKWTVTNEYGN...,HYDROLASE
3,3,1LQW,MLTMKDIIRDGHPTLRQKAAELELPLTKEEKETLIAMREFLVNSQD...,HYDROLASE
4,4,2VBO,MNTKYNKEFLLYLAGFVDGDGSIIAQIKPNQSYKFKHQLSLTFQVT...,HYDROLASE


In [4]:
total_classes = df["classification"].astype(str) # getting clasisfication column from dataframe and ensuring type string
classes = np.unique(total_classes) # finding unique classifications

In [5]:
# Set random seeds for reproducibility of your trainings run
def set_seeds(s):
    torch.manual_seed(s)
    np.random.seed(s)
    random.seed(s)
    set_seed(s)


# Set all random seeds
seed= 42      #random seed
set_seeds(seed)

In [6]:
def clean_sequence(seq):
    return re.sub(r'[^ACDEFGHIKLMNPQRSTVWY]', 'X', seq)

df["sequence"] = df["sequence"].apply(clean_sequence) # apply function to all sequences in sequence column


In [7]:
# need labels to be numerical
# converts categorical labels to integer labels using label encoder from scikit-learn, which encodes labels from 0 to n_classes-1
def convert_labels(labels):
  label_encoder = LabelEncoder() #create new label encoder
  int_labels = label_encoder.fit_transform(labels) #fit to data and transform to integers
  return int_labels

In [8]:
#convert classes in "classification" column to integer labels using convert_labels function
int_labels = convert_labels(df["classification"].astype(str).tolist()) # ensure all classes are strings then convert to list for use in the function
df['labels']= int_labels # create new column called 'labels' that contains the integer labels

In [9]:
train_sequences, all_test_sequences, train_labels, all_test_labels = train_test_split(
    df['sequence'], df['labels'], test_size=0.2, random_state=seed)

train_sequences = train_sequences.reset_index(drop=True)
train_labels = train_labels.reset_index(drop=True)

test_sequences, validation_sequences, test_labels, validation_labels = train_test_split(
    all_test_sequences, all_test_labels, test_size=0.5, random_state=seed)

test_sequences = test_sequences.reset_index(drop=True)
test_labels = test_labels.reset_index(drop=True)

validation_sequences = validation_sequences.reset_index(drop=True)
validation_labels = validation_labels.reset_index(drop=True)

In [10]:


tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_encoder = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
esm_encoder.eval()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/724 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/2.61G [00:00<?, ?B/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EsmModel(
  (embeddings): EsmEmbeddings(
    (word_embeddings): Embedding(33, 1280, padding_idx=1)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): EsmEncoder(
    (layer): ModuleList(
      (0-32): 33 x EsmLayer(
        (attention): EsmAttention(
          (self): EsmSelfAttention(
            (query): Linear(in_features=1280, out_features=1280, bias=True)
            (key): Linear(in_features=1280, out_features=1280, bias=True)
            (value): Linear(in_features=1280, out_features=1280, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (rotary_embeddings): RotaryEmbedding()
          )
          (output): EsmSelfOutput(
            (dense): Linear(in_features=1280, out_features=1280, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        )
        (intermediate): EsmIntermediate(
          (dense): Linear(in_features=1280, out_feature

In [23]:
def embed_dataset(model, sequences, tokenizer, device, max_length=1026):
    embeddings = []
    with torch.no_grad():
        for seq in sequences:
            encoded = tokenizer(
                seq,
                max_length=max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            input_ids = encoded['input_ids'].to(device)
            attention_mask = encoded['attention_mask'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
            # Take CLS token embedding from the last hidden layer
            cls_emb = outputs.hidden_states[-1][:, 0, :]  # shape: (1, hidden_size)
            embeddings.append(cls_emb.squeeze(0).cpu())
    return torch.stack(embeddings)  # shape: (num_samples, hidden_size)

train_embeddings = embed_dataset(esm_encoder, train_sequences, tokenizer, device)
val_embeddings = embed_dataset(esm_encoder, validation_sequences, tokenizer, device)
test_embeddings = embed_dataset(esm_encoder, test_sequences, tokenizer, device)


In [14]:
from torch.utils.data import Dataset

class ProteinEmbeddingDataset(Dataset):
    def __init__(self, embeddings, labels):
        self.embeddings = embeddings
        self.labels = torch.tensor(labels, dtype=torch.long)
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        return {
            'embedding': self.embeddings[idx],
            'label': self.labels[idx]
        }

train_dataset = ProteinEmbeddingDataset(train_embeddings, train_labels.values)
val_dataset = ProteinEmbeddingDataset(val_embeddings, validation_labels.values)
test_dataset = ProteinEmbeddingDataset(test_embeddings, test_labels.values)


In [48]:
class FlexibleClassifier(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_layers=2, hidden_units=128, dropout_rate=0.2):
        """
        A flexible fully connected neural network classifier.

        Parameters:
        - input_dim (int): Number of input features (e.g., the size of the embedding).
        - output_dim (int): Number of output classes.
        - hidden_layers (int): Number of hidden layers in the classifier.
        - hidden_units (int): Number of units in each hidden layer.
        - dropout_rate (float): Dropout rate to prevent overfitting.
        """
        super(FlexibleClassifier, self).__init__()
        self.output_dim=output_dim
        # Create a list of layers
        layers = []
        # First hidden layer
        layers.append(nn.Linear(input_dim, hidden_units))
        layers.append(nn.ReLU())
        layers.append(nn.Dropout(dropout_rate))

        # Additional hidden layers (if more than 1)
        for _ in range(hidden_layers - 1):
            layers.append(nn.Linear(hidden_units, hidden_units))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))

        # Output layer
        layers.append(nn.Linear(hidden_units, output_dim))

        # Combine the layers into a Sequential model
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


In [85]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight)
        nn.init.constant_(m.bias, 0)



In [96]:
# Training parameters
input_dim = esm_encoder.config.hidden_size
output_dim = 10
classifier = FlexibleClassifier(input_dim, output_dim, hidden_layers=3, hidden_units=128, dropout_rate=0.1).to(device)
classifier.apply(init_weights)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(classifier.parameters(), lr=1e-3, weight_decay=0.1)
epochs = 20
train_batch_size = 32
eval_batch_size = 32
gradient_accumulation_steps = 1



In [97]:
# Compute total training steps
total_steps = (len(train_dataset) // train_batch_size // gradient_accumulation_steps) * epochs
warmup_steps = int(0.1 * total_steps)
# Optional: Use linear warmup instead of ReduceLROnPlateau
from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)


In [98]:
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=eval_batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False)

In [99]:
# Save best model
best_f1 = 0
save_path = f"best_model_{seed}.pt"
# Initialize GradScaler for AMP here:
#scaler = torch.cuda.amp.GradScaler()

In [100]:
scaler = torch.cuda.amp.GradScaler()  # Initialize before training loop

best_f1 = 0.0

for epoch in range(epochs):
    classifier.train()
    total_loss = 0.0
    optimizer.zero_grad()

    for step, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
        embeddings = batch['embedding'].to(device)
        labels = batch['label'].to(device)

        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            outputs = classifier(embeddings)
            loss = criterion(outputs, labels)
            loss = loss / gradient_accumulation_steps

        scaler.scale(loss).backward()

        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_loader):
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(classifier.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

        total_loss += loss.item() * gradient_accumulation_steps

    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} - Training Loss: {avg_train_loss:.4f}")

    # Validation
    classifier.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            embeddings = batch['embedding'].to(device)
            labels = batch['label'].to(device)

            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                outputs = classifier(embeddings)

            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='macro')
    print(f"Validation - Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

    # Save best model
    if f1 > best_f1:
        best_f1 = f1
        torch.save(classifier.state_dict(), save_path)
        print(f"New best model saved with F1: {best_f1:.4f}")

# Load best model before testing
classifier.load_state_dict(torch.load(save_path))
classifier.eval()

# Testing
all_preds, all_labels = [], []
with torch.no_grad():
    for batch in test_loader:
        embeddings = batch['embedding'].to(device)
        labels = batch['label'].to(device)

        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            outputs = classifier(embeddings)

        preds = torch.argmax(outputs, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

test_acc = accuracy_score(all_labels, all_preds)
precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='macro')
print(f"Test - Accuracy: {test_acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")


  scaler = torch.cuda.amp.GradScaler()  # Initialize before training loop
Epoch 1: 100%|██████████| 50/50 [00:00<00:00, 266.97it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 1 - Training Loss: 2.3372
Validation - Accuracy: 0.2350, Precision: 0.1957, Recall: 0.2274, F1: 0.1777
New best model saved with F1: 0.1777


Epoch 2: 100%|██████████| 50/50 [00:00<00:00, 276.11it/s]


Epoch 2 - Training Loss: 2.1808


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation - Accuracy: 0.2950, Precision: 0.3253, Recall: 0.2880, F1: 0.2322
New best model saved with F1: 0.2322


Epoch 3: 100%|██████████| 50/50 [00:00<00:00, 282.36it/s]


Epoch 3 - Training Loss: 2.0152


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation - Accuracy: 0.3700, Precision: 0.3063, Recall: 0.3470, F1: 0.3040
New best model saved with F1: 0.3040


Epoch 4: 100%|██████████| 50/50 [00:00<00:00, 279.96it/s]


Epoch 4 - Training Loss: 1.8823
Validation - Accuracy: 0.3950, Precision: 0.3573, Recall: 0.3858, F1: 0.3616
New best model saved with F1: 0.3616


Epoch 5: 100%|██████████| 50/50 [00:00<00:00, 279.40it/s]


Epoch 5 - Training Loss: 1.7620
Validation - Accuracy: 0.4350, Precision: 0.3833, Recall: 0.4064, F1: 0.3766
New best model saved with F1: 0.3766


Epoch 6: 100%|██████████| 50/50 [00:00<00:00, 283.14it/s]


Epoch 6 - Training Loss: 1.6624
Validation - Accuracy: 0.4200, Precision: 0.4768, Recall: 0.4084, F1: 0.3621


Epoch 7: 100%|██████████| 50/50 [00:00<00:00, 269.91it/s]


Epoch 7 - Training Loss: 1.5917
Validation - Accuracy: 0.4250, Precision: 0.4104, Recall: 0.4021, F1: 0.3849
New best model saved with F1: 0.3849


Epoch 8: 100%|██████████| 50/50 [00:00<00:00, 277.61it/s]


Epoch 8 - Training Loss: 1.5168
Validation - Accuracy: 0.5250, Precision: 0.4886, Recall: 0.5026, F1: 0.4690
New best model saved with F1: 0.4690


Epoch 9: 100%|██████████| 50/50 [00:00<00:00, 274.38it/s]


Epoch 9 - Training Loss: 1.4577
Validation - Accuracy: 0.5100, Precision: 0.4957, Recall: 0.4822, F1: 0.4527


Epoch 10: 100%|██████████| 50/50 [00:00<00:00, 276.81it/s]


Epoch 10 - Training Loss: 1.4297
Validation - Accuracy: 0.5100, Precision: 0.5228, Recall: 0.5004, F1: 0.4777
New best model saved with F1: 0.4777


Epoch 11: 100%|██████████| 50/50 [00:00<00:00, 282.57it/s]


Epoch 11 - Training Loss: 1.3905
Validation - Accuracy: 0.5050, Precision: 0.5179, Recall: 0.4965, F1: 0.4770


Epoch 12: 100%|██████████| 50/50 [00:00<00:00, 279.32it/s]


Epoch 12 - Training Loss: 1.3357
Validation - Accuracy: 0.5350, Precision: 0.5385, Recall: 0.5269, F1: 0.5125
New best model saved with F1: 0.5125


Epoch 13: 100%|██████████| 50/50 [00:00<00:00, 277.01it/s]


Epoch 13 - Training Loss: 1.3143
Validation - Accuracy: 0.5600, Precision: 0.5497, Recall: 0.5484, F1: 0.5373
New best model saved with F1: 0.5373


Epoch 14: 100%|██████████| 50/50 [00:00<00:00, 281.10it/s]


Epoch 14 - Training Loss: 1.2791
Validation - Accuracy: 0.5500, Precision: 0.5448, Recall: 0.5366, F1: 0.5260


Epoch 15: 100%|██████████| 50/50 [00:00<00:00, 280.90it/s]


Epoch 15 - Training Loss: 1.2586
Validation - Accuracy: 0.5650, Precision: 0.5412, Recall: 0.5470, F1: 0.5382
New best model saved with F1: 0.5382


Epoch 16: 100%|██████████| 50/50 [00:00<00:00, 281.52it/s]


Epoch 16 - Training Loss: 1.2253
Validation - Accuracy: 0.5950, Precision: 0.5754, Recall: 0.5804, F1: 0.5678
New best model saved with F1: 0.5678


Epoch 17: 100%|██████████| 50/50 [00:00<00:00, 280.27it/s]


Epoch 17 - Training Loss: 1.2010
Validation - Accuracy: 0.5700, Precision: 0.5713, Recall: 0.5547, F1: 0.5447


Epoch 18: 100%|██████████| 50/50 [00:00<00:00, 278.67it/s]


Epoch 18 - Training Loss: 1.1744
Validation - Accuracy: 0.5700, Precision: 0.5520, Recall: 0.5561, F1: 0.5464


Epoch 19: 100%|██████████| 50/50 [00:00<00:00, 278.63it/s]


Epoch 19 - Training Loss: 1.1731
Validation - Accuracy: 0.5900, Precision: 0.5784, Recall: 0.5764, F1: 0.5689
New best model saved with F1: 0.5689


Epoch 20: 100%|██████████| 50/50 [00:00<00:00, 275.41it/s]


Epoch 20 - Training Loss: 1.1445
Validation - Accuracy: 0.5850, Precision: 0.5681, Recall: 0.5705, F1: 0.5611
Test - Accuracy: 0.4900, Precision: 0.4901, Recall: 0.4999, F1: 0.4820
