In [1]:
!pip install evaluate
!pip install transformers
!pip install datasets

Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.6


In [2]:
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
from transformers import AutoTokenizer, AutoModel, Trainer, TrainingArguments
from transformers import EarlyStoppingCallback
import evaluate

In [3]:
pss_data = pd.read_csv('/content/subset_file_50000_100000.csv')

# Display first few rows of the dataset
pss_data.head()

Unnamed: 0,pdb_id,chain_code,seq,sst8,sst3,len,has_nonstd_aa
0,5XFV,A,MAHHHHHHSAALEVLFQGPGSMSLKVNILGHEFSNPFMNAAGVLCT...,CCCCCCCCCCCCCCCCCCCCCCCCCEEETTEEESSSEEECTTSSCS...,CCCCCCCCCCCCCCCCCCCCCCCCCEEECCEEECCCEEECCCCCCC...,334,False
1,4R9X,A,MLEVIATCLEDVKRIERAGGKRIELISSYTEGGLTPSYAFIKKAVE...,CEEEEESSHHHHHHHHHTTCCEEEECBCGGGTCBCCCHHHHHHHHH...,CEEEEECCHHHHHHHHHCCCCEEEECECHHHCCECCCHHHHHHHHH...,233,False
2,5VAX,A,GSMAHAGRTGYDNREIVMKYIHYKLSQRGYEWDDGDDVEENRTEAP...,CCCCCCCCCCCCHHHHHHHHHHHHHHTTTCCCCCCCCCCCCCCCCC...,CCCCCCCCCCCCHHHHHHHHHHHHHHCCCCCCCCCCCCCCCCCCCC...,168,False
3,5LSJ,D,MGTLQKCFEDSNGKASDFSLEASVAEMKEYITKFSLERQTWDQLLL...,CCCCCCCCCCCCCCCCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHH...,CCCCCCCCCCCCCCCCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHH...,178,False
4,2J5I,B,MSTYEGRWKTVKVEIEDGIAFVILNRPEKRNAMSPTLNREMIDVLE...,CCCCCCCCSSEEEEEETTEEEEEECCGGGTTCBCHHHHHHHHHHHH...,CCCCCCCCCCEEEEEECCEEEEEECCHHHCCCECHHHHHHHHHHHH...,276,False


In [4]:
# Data augmention by oversampling the minority class
# Separate the dataset into different classes
coil_data = pss_data[pss_data['sst3'] == 'C']
sheet_data = pss_data[pss_data['sst3'] == 'E']
helix_data = pss_data[pss_data['sst3'] == 'H']

# Determine the maximum count among the classes (the majority class)
max_class_size = max(coil_data.shape[0], sheet_data.shape[0], helix_data.shape[0])

# Oversample the minority classes to match the size of the majority class
# Check if the dataframes are not empty before sampling
sheet_data_oversampled = sheet_data.sample(max_class_size, replace=True, random_state=42) if not sheet_data.empty else pd.DataFrame(columns=sheet_data.columns)
helix_data_oversampled = helix_data.sample(max_class_size, replace=True, random_state=42) if not helix_data.empty else pd.DataFrame(columns=helix_data.columns)

# Combine the data back into one balanced dataset
balanced_data = pd.concat([coil_data, sheet_data_oversampled, helix_data_oversampled])

# Shuffle the balanced dataset
balanced_data = balanced_data.sample(frac=1, random_state=42).reset_index(drop=True)

# Display the new class distribution
balanced_distribution = balanced_data['sst3'].value_counts()
print(balanced_distribution)

Series([], Name: count, dtype: int64)


In [5]:
pss_data_cleaned = pss_data.dropna(subset=['sst3'])

X = pss_data_cleaned['seq']  # Amino acid sequence
y = pss_data_cleaned['sst3']  # Secondary structure labels (C, E, H)

train_texts, val_texts, train_labels, val_labels = train_test_split(X, y, test_size=0.2, random_state=42)
val_texts, test_texts, val_labels, test_labels = train_test_split(val_texts, val_labels, test_size=0.5, random_state=42)

In [6]:
# Modify the model to use ESM2
model_name = "facebook/esm2_t6_8M_UR50D"  # ESM2 model (choose the appropriate model)

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Tokenization function
def tokenize_data(texts, tokenizer, max_length=128):
    return tokenizer(
        list(texts),
        max_length=max_length,
        truncation=True,
        padding=True,
        return_tensors="pt"  # PyTorch tensors
    )

# Tokenize the datasets
train_encodings = tokenize_data(train_texts, tokenizer)
val_encodings = tokenize_data(val_texts, tokenizer)
test_encodings = tokenize_data(test_texts, tokenizer)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

In [7]:
from datasets import Dataset
import torch

# Ensure labels are correctly mapped to integers
label_map = {'C': 0, 'B': 1, 'H': 2, 'E': 0}  # Assuming 'E' maps to Coil (0)

# Convert string labels (sequences) to a list of numeric labels for each residue and pad/truncate them
def prepare_labels(label_sequences, label_map, max_length, pad_value=-100):
    numeric_labels = []
    for seq in label_sequences:
        numeric_seq = [label_map.get(char, pad_value) for char in seq]  # Use .get with pad_value for unknown chars
        truncated_seq = numeric_seq[:max_length]  # Truncate numeric sequence if longer than max_length
        padded_seq = truncated_seq + [pad_value] * (max_length - len(truncated_seq))  # Pad the numeric sequence
        numeric_labels.append(padded_seq)
    return numeric_labels

# Get max_length from tokenization for each split
train_max_length = train_encodings["input_ids"].shape[1]
val_max_length = val_encodings["input_ids"].shape[1]
test_max_length = test_encodings["input_ids"].shape[1]

train_labels_numeric_padded = prepare_labels(train_labels, label_map, train_max_length)
val_labels_numeric_padded = prepare_labels(val_labels, label_map, val_max_length)
test_labels_numeric_padded = prepare_labels(test_labels, label_map, test_max_length)

# Convert the data into Hugging Face Dataset format with torch tensors
train_data = Dataset.from_dict({
    "input_ids": train_encodings["input_ids"].clone().detach(),
    "attention_mask": train_encodings["attention_mask"].clone().detach(),
    "labels": torch.tensor(train_labels_numeric_padded, dtype=torch.long)  # Use the padded numeric labels
})

val_data = Dataset.from_dict({
    "input_ids": val_encodings["input_ids"].clone().detach(),
    "attention_mask": val_encodings["attention_mask"].clone().detach(),
    "labels": torch.tensor(val_labels_numeric_padded, dtype=torch.long)  # Use the padded numeric labels
})

test_data = Dataset.from_dict({
    "input_ids": test_encodings["input_ids"].clone().detach(),
    "attention_mask": test_encodings["attention_mask"].clone().detach(),
    "labels": torch.tensor(test_labels_numeric_padded, dtype=torch.long)  # Use the padded numeric labels
})

In [8]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer  # Added import here

# Function to define the ESM2 + CNN model
class ESM2_CNN_Model(torch.nn.Module):
    def __init__(self, model_name, num_labels=3):
        super(ESM2_CNN_Model, self).__init__()
        # Load pre-trained ESM2 model
        self.esm2 = AutoModel.from_pretrained(model_name)

        # Add a CNN head on top of the ESM2 model
        self.conv1 = torch.nn.Conv1d(in_channels=self.esm2.config.hidden_size, out_channels=128, kernel_size=3, padding=1)
        self.pool = torch.nn.MaxPool1d(kernel_size=2)
        self.fc1 = torch.nn.Linear(128 * (train_max_length // 2), num_labels)  # Adjusted for max length

    def forward(self, input_ids, attention_mask=None):
        # Get embeddings from ESM2 model
        esm2_output = self.esm2(input_ids, attention_mask=attention_mask)
        hidden_states = esm2_output.last_hidden_state  # (batch_size, seq_len, hidden_size)

        # Apply CNN layer
        x = self.conv1(hidden_states.transpose(1, 2))  # Convert to (batch_size, hidden_size, seq_len)
        x = torch.relu(x)

        # Apply max pooling
        x = self.pool(x)

        # Flatten the output and apply fully connected layer
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

# Define model and training arguments
model = ESM2_CNN_Model(model_name=model_name, num_labels=3)

training_args = TrainingArguments(
    output_dir="./esm2_cnn_output",  # Changed output directory name
    eval_strategy="epoch",
    learning_rate=1e-5,  # Fixed learning rate for simplicity
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=4,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",  # Still using F1 as a metric
    greater_is_better=True,
    report_to="none",
)

config.json:   0%|          | 0.00/775 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/31.4M [00:00<?, ?B/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
# Define the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)

# Train the model
trainer.train()

  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Q3 Accuracy,F1
1,0.2517,0.242331,0.901143,0.901005
2,0.1879,0.239164,0.903382,0.90296
3,0.1913,0.239717,0.903951,0.903541
4,0.1784,0.239832,0.903986,0.903537


TrainOutput(global_step=8996, training_loss=0.20861049732561904, metrics={'train_runtime': 555.6714, 'train_samples_per_second': 258.952, 'train_steps_per_second': 16.189, 'total_flos': 0.0, 'train_loss': 0.20861049732561904, 'epoch': 4.0})

In [20]:
# Delete the conflicting global variable
if 'f1_score' in globals() and not callable(f1_score):
    del f1_score

In [24]:
# Evaluate the model on the test data
eval_results = trainer.evaluate(test_data)

# The eval_results dictionary contains the metrics calculated by compute_metrics, including Q3 accuracy and f1
q3_accuracy = eval_results["eval_q3_accuracy"]
weighted_f1 = eval_results["eval_f1"]

# Print Q3 accuracy and F1 score
print("Q3 Accuracy: {:.4f}".format(q3_accuracy))
print("Weighted F1 Score: {:.4f}".format(weighted_f1))

# Print classification report
predictions = trainer.predict(test_data)
predicted_labels = np.argmax(predictions.predictions, axis=-1)
true_labels = predictions.label_ids

# Flatten and filter out padding (-100)
flat_predicted_labels = predicted_labels.flatten()
flat_true_labels = true_labels.flatten()

mask = flat_true_labels != -100
filtered_predicted_labels = flat_predicted_labels[mask]
filtered_true_labels = flat_true_labels[mask]

# Define target names based on your label_map
target_names = ["Coil", "Sheet", "Helix"]
# Explicitly specify the labels to report on
report_labels = [0, 1, 2]

print("\nClassification Report:")
print(classification_report(filtered_true_labels, filtered_predicted_labels, labels=report_labels, target_names=target_names))

Q3 Accuracy: 0.9032
Weighted F1 Score: 0.9028

Classification Report:
              precision    recall  f1-score   support

        Coil       0.92      0.94      0.93    374566
       Sheet       0.00      0.00      0.00         0
       Helix       0.85      0.83      0.84    168294

    accuracy                           0.90    542860
   macro avg       0.59      0.59      0.59    542860
weighted avg       0.90      0.90      0.90    542860



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
