In [None]:
# ==========================================
# Cell 1: Setup and Installation
# ==========================================

# Install PyTorch and transformers
!pip install torch accelerate transformers datasets scikit-learn

# Import libraries
import os
import pickle
import shutil
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import Dataset

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
# ==========================================
# Cell 2: Mount Drive and Setup Paths
# ==========================================

from google.colab import drive

# Mount Google Drive at /content/drive
drive.mount('/content/drive')

# Project layout: every artefact directory hangs off a single root on Drive
PROJECT_ROOT = '/content/drive/MyDrive/protein_classification'
DATA_DIR, MODELS_DIR, RESULTS_DIR, OUTPUTS_DIR = (
    f'{PROJECT_ROOT}/{subdir}'
    for subdir in ('data', 'models', 'results', 'outputs')
)

# Work from the project root and echo where we ended up
os.chdir(PROJECT_ROOT)
print(f"Working directory: {os.getcwd()}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Working directory: /content/drive/MyDrive/protein_classification


In [None]:
# ==========================================
# Cell 3: Load Data Splits
# ==========================================

def load_data_splits(splits_path=None):
    """Load the prepared data splits from a pickle file.

    Args:
        splits_path: Optional path to the pickle file. When omitted, the
            default ``{RESULTS_DIR}/data_splits.pkl`` is used.

    Returns:
        The unpickled object (iterated below as a mapping of split name to
        split data), or ``None`` when the file does not exist or cannot be
        loaded. Failures are reported with a printed message, not an exception.
    """
    if splits_path is None:
        splits_path = f'{RESULTS_DIR}/data_splits.pkl'

    if not os.path.exists(splits_path):
        print(" Data splits not found!")
        print(" Please run 01_data_preparation.ipynb first.")
        return None

    try:
        # NOTE(security): pickle.load can execute arbitrary code embedded in the
        # file. Only load split files that were produced by a trusted source.
        with open(splits_path, 'rb') as f:
            data_splits = pickle.load(f)
        print(" Data splits loaded successfully!")

        # Show data summary (only list-valued entries are counted)
        for split_name, split_data in data_splits.items():
            if isinstance(split_data, list):
                print(f"   {split_name}: {len(split_data):,} samples")

        return data_splits
    except Exception as e:
        # Unpickling can fail in many ways (truncated file, missing classes, ...);
        # report the problem and let the caller handle the None result.
        print(f" Error loading data: {e}")
        return None

# Load the splits once at notebook level; `data_splits` is None if loading failed
data_splits = load_data_splits()

 Data splits loaded successfully!
   train_seq: 48,000 samples
   train_labels: 48,000 samples
   val_seq: 16,000 samples
   val_labels: 16,000 samples
   test_seq: 16,000 samples
   test_labels: 16,000 samples


In [None]:
# ==========================================
# Cell 4: Dataset Class and Model Setup
# ==========================================

class ProteinDataset(Dataset):
    """Custom Dataset that tokenizes one protein sequence per item.

    Args:
        sequences: Sized, integer-indexable collection of sequences. Each item
            is converted with ``str()`` before tokenization.
        labels: Sized, integer-indexable collection of integer class labels,
            aligned position-by-position with ``sequences``.
        tokenizer: Callable invoked as
            ``tokenizer(text, truncation=..., padding=..., max_length=...,
            return_tensors='pt')`` that returns a mapping with ``'input_ids'``
            and ``'attention_mask'`` tensors.
        max_length: Length every sequence is truncated/padded to.

    Raises:
        ValueError: If ``sequences`` and ``labels`` differ in length.
    """
    def __init__(self, sequences, labels, tokenizer, max_length=1024):
        # Fail fast: a mismatch would otherwise surface as an IndexError in the
        # middle of training, or silently ignore the surplus labels.
        if len(sequences) != len(labels):
            raise ValueError(
                f"sequences and labels must have the same length "
                f"(got {len(sequences)} and {len(labels)})"
            )
        self.sequences = sequences
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of sequences in the dataset."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Tokenize the sequence at ``idx`` and return model-ready tensors.

        Returns:
            dict with flattened ``input_ids`` and ``attention_mask`` tensors and
            a ``labels`` tensor of dtype ``torch.long``.
        """
        sequence = str(self.sequences[idx])
        label = self.labels[idx]

        # Fixed-length padding: every item comes back with max_length tokens.
        encoding = self.tokenizer(
            sequence,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            # The tokenizer output is flattened to a 1-D tensor per item.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Load the ESM-2 tokenizer and a sequence-classification model from the same checkpoint.
# Alternative checkpoint, left disabled (swap it in by uncommenting):
# model_checkpoint = "facebook/esm2_t12_35M_UR50D"
model_checkpoint = "facebook/esm2_t33_650M_UR50D"
print(f" Loading {model_checkpoint}...")

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# num_labels=2 requests a two-class classification head on top of the checkpoint.
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)

print(" Model and tokenizer loaded!")


 Loading facebook/esm2_t33_650M_UR50D...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/724 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.61G [00:00<?, ?B/s]

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


 Model and tokenizer loaded!


In [None]:
# ==========================================
# Cell 5: Create Datasets
# ==========================================

# Fail fast: without splits there is nothing to wrap, and skipping silently
# would only defer the failure to a NameError on `train_dataset` during training.
if data_splits is None:
    raise RuntimeError(
        "Data splits are not loaded. Please run 01_data_preparation.ipynb first."
    )

print(" Creating datasets...")

# One ProteinDataset per split, built from the `<split>_seq` / `<split>_labels` pairs.
train_dataset, val_dataset, test_dataset = (
    ProteinDataset(
        data_splits[f'{split}_seq'],
        data_splits[f'{split}_labels'],
        tokenizer
    )
    for split in ('train', 'val', 'test')
)

print(f" Datasets created:")
print(f"   Train: {len(train_dataset):,} samples")
print(f"   Validation: {len(val_dataset):,} samples")
print(f"   Test: {len(test_dataset):,} samples")

 Creating datasets...
 Datasets created:
   Train: 48,000 samples
   Validation: 16,000 samples
   Test: 16,000 samples


In [None]:
# ==========================================
# Cell 6: Metrics Function
# ==========================================

def compute_metrics(eval_pred):
    """Score evaluation outputs for the binary classification task.

    ``eval_pred`` unpacks into ``(predictions, labels)``. The predicted class
    is the arg-max over axis 1 of ``predictions``.

    Returns:
        dict with the keys ``accuracy``, ``precision``, ``recall`` and ``f1``
        (precision/recall/F1 use ``average='binary'``).
    """
    scores, labels = eval_pred
    predicted_classes = np.argmax(scores, axis=1)

    # The fourth element of the result is not used here.
    prf = precision_recall_fscore_support(
        labels, predicted_classes, average='binary'
    )

    metrics = {'accuracy': accuracy_score(labels, predicted_classes)}
    metrics.update(zip(('precision', 'recall', 'f1'), prf[:3]))
    return metrics


In [None]:
# ==========================================
# Cell 7: Stage 1 Training (Classification Head Only)
# ==========================================

print(" Starting Stage 1: Classification Head Training")
print("=" * 60)

# Freeze everything outside the classification head and tally parameter counts
frozen_params = 0
trainable_params = 0

for name, param in model.named_parameters():
    if 'classifier' in name:
        # Head parameter: left as-is and reported
        trainable_params += param.numel()
        print(f"   Training: {name}")
        continue
    # Backbone parameter: exclude from gradient updates
    param.requires_grad = False
    frozen_params += param.numel()

print(f" Frozen parameters: {frozen_params:,}")
print(f" Trainable parameters: {trainable_params:,}")

# Stage 1 training configuration
stage1_args = TrainingArguments(
    output_dir=f'{MODELS_DIR}/esm2_stage1_results',

    # Optimisation schedule
    num_train_epochs=4,
    learning_rate=2e-5,
    warmup_ratio=0.05,
    weight_decay=0.01,

    # Batching
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=1,

    # Logging, evaluation and checkpointing cadence (save and eval steps match)
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=250,
    save_strategy="steps",
    save_steps=250,
    save_total_limit=3,

    # Keep the checkpoint with the highest validation F1
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,

    # Data loading and precision
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    fp16=True,
    remove_unused_columns=False,

    report_to="none",
    prediction_loss_only=False,
)

# Trainer for the head-only stage; validation data is used for evaluation during training
stage1_trainer = Trainer(
    model=model,
    args=stage1_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
)

# Run Stage 1
print(" Training Stage 1...")
stage1_trainer.train()

# Score the Stage 1 model on the test split and report the headline metrics
stage1_results = stage1_trainer.evaluate(test_dataset)
print(f"\n Stage 1 Results:")
for _metric_label, _metric_key in (
    ('F1-Score', 'eval_f1'),
    ('Accuracy', 'eval_accuracy'),
    ('Precision', 'eval_precision'),
    ('Recall', 'eval_recall'),
):
    print(f"   {_metric_label}: {stage1_results[_metric_key]:.4f}")


 Starting Stage 1: Classification Head Training
   Training: classifier.dense.weight
   Training: classifier.dense.bias
   Training: classifier.out_proj.weight
   Training: classifier.out_proj.bias
 Frozen parameters: 649,400,981
 Trainable parameters: 1,642,242
 Training Stage 1...


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
250,0.6241,0.533044,0.857313,0.871668,0.838,0.854503
500,0.3675,0.346326,0.886375,0.906925,0.861125,0.883432
750,0.318,0.282263,0.9005,0.906393,0.89325,0.899773
1000,0.2683,0.255654,0.90625,0.909526,0.90225,0.905873
1250,0.2538,0.241503,0.90675,0.90675,0.90675,0.90675
1500,0.2274,0.231112,0.910312,0.917674,0.9015,0.909515
1750,0.2394,0.225857,0.910188,0.914174,0.905375,0.909753
2000,0.2143,0.221609,0.914375,0.927301,0.89925,0.91306
2250,0.2203,0.217804,0.912813,0.918303,0.90625,0.912237
2500,0.2225,0.214757,0.916625,0.927527,0.903875,0.915548


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
250,0.6241,0.533044,0.857313,0.871668,0.838,0.854503
500,0.3675,0.346326,0.886375,0.906925,0.861125,0.883432
750,0.318,0.282263,0.9005,0.906393,0.89325,0.899773
1000,0.2683,0.255654,0.90625,0.909526,0.90225,0.905873
1250,0.2538,0.241503,0.90675,0.90675,0.90675,0.90675
1500,0.2274,0.231112,0.910312,0.917674,0.9015,0.909515
1750,0.2394,0.225857,0.910188,0.914174,0.905375,0.909753
2000,0.2143,0.221609,0.914375,0.927301,0.89925,0.91306
2250,0.2203,0.217804,0.912813,0.918303,0.90625,0.912237
2500,0.2225,0.214757,0.916625,0.927527,0.903875,0.915548



 Stage 1 Results:
   F1-Score: 0.9166
   Accuracy: 0.9179
   Precision: 0.9309
   Recall: 0.9028


In [None]:
# ==========================================
# Cell 8: Stage 2 Training (Full Model Fine-tuning)
# ==========================================

# Stage 1 set requires_grad=False on every non-classifier parameter. Re-enable
# gradients for the whole network here; otherwise this stage would keep training
# only the classification head instead of fine-tuning the full model.
for param in model.parameters():
    param.requires_grad = True

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f" Trainable parameters: {trainable_params:,}")

# Stage 2 training arguments
stage2_args = TrainingArguments(
    output_dir=f'{MODELS_DIR}/esm2_stage2_results',

    num_train_epochs=2,
    learning_rate=1e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,  # effective train batch of 16 per device

    warmup_ratio=0.03,
    weight_decay=0.02,

    logging_steps=50,
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,

    # Keep the checkpoint with the highest validation F1
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,

    save_total_limit=3,

    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    fp16=True,
    remove_unused_columns=False,

    # Trade compute for memory now that the whole backbone receives gradients
    gradient_checkpointing=True,

    report_to="none",
    prediction_loss_only=False,
)

# Stage 2 trainer (same model object as Stage 1, so training continues from its weights)
stage2_trainer = Trainer(
    model=model,
    args=stage2_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=7)]
)

# Train Stage 2
print(" Training Stage 2...")
stage2_trainer.train()

# Final evaluation on the test split
final_results = stage2_trainer.evaluate(test_dataset)


 Training Stage 2...


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
500,0.1034,0.091017,0.972688,0.979217,0.965875,0.9725


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
500,0.1034,0.091017,0.972688,0.979217,0.965875,0.9725
1000,0.0704,0.057799,0.983062,0.990606,0.975375,0.982931
1500,0.0354,0.042537,0.98875,0.992071,0.985375,0.988712
2000,0.0148,0.040521,0.9905,0.99001,0.991,0.990505
2500,0.0235,0.040299,0.991187,0.992851,0.9895,0.991173
3000,0.0283,0.029156,0.992563,0.993488,0.991625,0.992556
3500,0.0239,0.033964,0.993,0.992016,0.994,0.993007
4000,0.0088,0.033766,0.992687,0.994728,0.990625,0.992672
4500,0.013,0.032907,0.993812,0.995236,0.992375,0.993804
5000,0.0011,0.034088,0.993313,0.991653,0.995,0.993324


In [None]:
# ==========================================
# Cell 9: Save Model and Results
# ==========================================

# Save the final model together with its tokenizer
model_save_path = f'{MODELS_DIR}/esm2_ecm_model_enhanced'
stage2_trainer.save_model(model_save_path)
tokenizer.save_pretrained(model_save_path)
print(f" Model saved to: {model_save_path}")

# Accuracy the run is aiming for; used for the verdict below and in the saved results
TARGET_ACCURACY = 0.99
final_accuracy = final_results['eval_accuracy']
target_achieved = final_accuracy >= TARGET_ACCURACY

# Calculate improvement (test-set F1 after Stage 2 minus test-set F1 after Stage 1)
improvement = final_results['eval_f1'] - stage1_results['eval_f1']

# Print results with better formatting
print("\n" + "=" * 60)
print(" ESM-2 ENHANCED TWO-STAGE TRAINING COMPLETE")
print("=" * 60)

print(f" STAGE COMPARISON:")
print(f"   Stage 1 (Classification Head Only):")
print(f"     ├─ F1-Score:  {stage1_results['eval_f1']:.4f}")
print(f"     ├─ Accuracy:  {stage1_results['eval_accuracy']:.4f}")
print(f"     ├─ Precision: {stage1_results['eval_precision']:.4f}")
print(f"     └─ Recall:    {stage1_results['eval_recall']:.4f}")

print(f"\n   Stage 2 (Full Model Fine-tuning) - FINAL:")
print(f"     ├─ F1-Score:  {final_results['eval_f1']:.4f}")
print(f"     ├─ Accuracy:  {final_accuracy:.4f}")
print(f"     ├─ Precision: {final_results['eval_precision']:.4f}")
print(f"     └─ Recall:    {final_results['eval_recall']:.4f}")

# The `+` format flag prints the sign itself, so a regression shows as "-0.0123"
# rather than "+-0.0123".
print(f"\n IMPROVEMENT:")
print(f"   Stage 1 → Stage 2: {improvement:+.4f} F1-Score")

# Performance assessment
if target_achieved:
    print(f"\n TARGET ACHIEVED: {final_accuracy:.4f} accuracy (≥99%)")
elif final_accuracy >= 0.98:
    print(f"\n EXCELLENT: {final_accuracy:.4f} accuracy (≥98%)")
elif final_accuracy >= 0.95:
    print(f"\n VERY GOOD: {final_accuracy:.4f} accuracy (≥95%)")
else:
    print(f"\n ROOM FOR IMPROVEMENT: {final_accuracy:.4f} accuracy")
    print(f"   Gap to 99%: {TARGET_ACCURACY - final_accuracy:.4f}")

print("=" * 60)

# Prepare and save results
results = {
    'model_name': 'ESM-2_Two_Stage_Enhanced',
    'model_checkpoint': model_checkpoint,
    'training_summary': {
        # Configured epoch budget of both stages, read from the training arguments
        # so it cannot drift from the values actually used. Early stopping may end
        # a stage before its budget is spent.
        'total_epochs': stage1_args.num_train_epochs + stage2_args.num_train_epochs,
        'final_accuracy': final_accuracy,
        'final_f1': final_results['eval_f1'],
        'target_achieved': target_achieved
    },
    'stage_results': {
        'stage1': stage1_results,
        'stage2_final': final_results
    },
    'improvement': improvement,
    'training_history': {
        'stage1_history': stage1_trainer.state.log_history,
        'stage2_history': stage2_trainer.state.log_history
    },
    'model_config': {
        'model_size': '650M',
        # Taken from the objects used for training instead of being repeated by hand
        'max_length': train_dataset.max_length,
        'enhanced_classifier': True,
        'fp16_training': stage2_args.fp16,
        'early_stopping': True,
        'more_data': True
    }
}

# Save results
results_path = f'{RESULTS_DIR}/esm2_enhanced_results.pkl'
with open(results_path, 'wb') as f:
    pickle.dump(results, f)

print(f" Results saved to: {results_path}")

# Optional: Save a simple plain-text summary
summary_path = f'{RESULTS_DIR}/training_summary.txt'
with open(summary_path, 'w') as f:
    f.write("ESM-2 Enhanced Two-Stage Training Summary\n")
    f.write("=" * 50 + "\n\n")
    f.write(f"Model: {model_checkpoint}\n")
    f.write(f"Final Accuracy: {final_accuracy:.4f}\n")
    f.write(f"Final F1-Score: {final_results['eval_f1']:.4f}\n")
    f.write(f"Improvement: {improvement:+.4f} F1\n")
    f.write(f"Target (99%) Achieved: {'Yes' if target_achieved else 'No'}\n")

print(f" Summary saved to: {summary_path}")

 Model saved to: /content/drive/MyDrive/protein_classification/models/esm2_ecm_model_enhanced

 ESM-2 ENHANCED TWO-STAGE TRAINING COMPLETE
 STAGE COMPARISON:
   Stage 1 (Classification Head Only):
     ├─ F1-Score:  0.9166
     ├─ Accuracy:  0.9179
     ├─ Precision: 0.9309
     └─ Recall:    0.9028

   Stage 2 (Full Model Fine-tuning) - FINAL:
     ├─ F1-Score:  0.9940
     ├─ Accuracy:  0.9940
     ├─ Precision: 0.9950
     └─ Recall:    0.9930

 IMPROVEMENT:
   Stage 1 → Stage 2: +0.0774 F1-Score

 TARGET ACHIEVED: 0.9940 accuracy (≥99%)
 Results saved to: /content/drive/MyDrive/protein_classification/results/esm2_enhanced_results.pkl
 Summary saved to: /content/drive/MyDrive/protein_classification/results/training_summary.txt
