# Necessary

- Create VE
- Create dataset

In [1]:
import sys
import os
sys.path.append(os.path.abspath("/home/gdallagl/myworkdir/ESMSec/utils"))  # Adds current folder to Python path

import utils.my_functions as mf
import utils.models as my_models

import torch
import random
import time
from torch.utils.data import TensorDataset
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForSequenceClassification

print(torch.__version__)

  from .autonotebook import tqdm as notebook_tqdm


2.5.0+cu121


In [None]:
# Configuration / hyperparameters
config = {
    "FLUID": "CSF",         # Dataset name
    "NUM_ITERS": 500,        # Total iterations
    "BATCH_SIZE": 32,       # Batch size
    "LR": 5e-5,             # Learning rate
    "LR_DECAY_GAMMA": 1,    # Learning rate decay
    "LR_DECAY_STEPS": 1000, # Learning rate decay steps
    "EVAL_SIZE": 100,       # Evaluation frequency
    "SEED": 43215,           # Random seed
    "PROTEIN_MAX_LENGTH": 1000, # Max protein length (for ESM2)
    "PRETRAIN_ESM_CHECKPOINT_NAME": "facebook/esm2_t6_8M_UR50D", # ESM2 model name
    "PRETRAIN_ESM_CACHE_DIR": "/home/gdallagl/myworkdir/data/esm2-models", # ESM2 model cache dir
    "DATASET_PATH": "/home/gdallagl/myworkdir/data/ESMSec/protein/CSF_my_dataset.csv", # Path to dataset
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu", # Device to use (cuda or cpu)
    "PATH_TO_SAVE_MODEL": "/home/gdallagl/myworkdir/models/ESMSec/CSF_trained_model.pth" # Path to save the model
}

# Initializations
random.seed(config["SEED"])
np.random.seed(config["SEED"])
torch.manual_seed(config["SEED"])
torch.backends.cudnn.benchmark = True 

### Instantiate the model

In [7]:
# Load pre-trained ESM model
esm_model = AutoModel.from_pretrained(config["PRETRAIN_ESM_CHECKPOINT_NAME"],  cache_dir=config["PRETRAIN_ESM_CACHE_DIR"])
# Checj whcih model has been moded by AutoModel.from_pretrained()
print("\nESM model type", type(esm_model), "\n")

# Load relative tokenizer
tokenizer = AutoTokenizer.from_pretrained(config["PRETRAIN_ESM_CHECKPOINT_NAME"])

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



ESM model type <class 'transformers.models.esm.modeling_esm.EsmModel'> 



In [None]:
# initialise model (ESM+ HEad)
net = my_models.EsmDeepSec(esm_model).to(config["DEVICE"])
# hidden dim of final embeggin of each aa after trnafoerm 
print("\nESM hidden dim", net.ESM_hidden_dim, "\n")


ESM hidden dim 320 



In [9]:
# Block ESM paramters to be trained
for param in net.esm_model.parameters():
    param.requires_grad = False

### Load dataset

In [None]:
# Load dataset
data = pd.read_csv(config["DATASET_PATH"])

# Truncate sequences and add as a new column
data['trunc_sequence'] = data['sequence'].apply(
    lambda seq: mf.truncate_sequence(seq, max_length=config["PROTEIN_MAX_LENGTH"])
)

# Tokenize all sequences at once (vectorized)
tokenized = tokenizer(
    data['trunc_sequence'].tolist(),
    padding='max_length',
    max_length=config["PROTEIN_MAX_LENGTH"],
    truncation=True,
    return_tensors="pt"
)

# Add tokenized input_ids and attention_mask to DataFrame
data['input_ids'] = list(tokenized['input_ids'])
data['attention_mask'] = list(tokenized['attention_mask'])

# Convert labels to tensor
labels_tensor = torch.tensor(data['label'].values)

# Split indices
train_idx = data['set'] == 'train'
valid_idx = data['set'] == 'validation'
test_idx  = data['set'] == 'test'

# Create TensorDatasets
train_dataset = TensorDataset(
    tokenized['input_ids'][train_idx],
    tokenized['attention_mask'][train_idx],
    labels_tensor[train_idx]
)

valid_dataset = TensorDataset(
    tokenized['input_ids'][valid_idx],
    tokenized['attention_mask'][valid_idx],
    labels_tensor[valid_idx]
)

test_dataset = TensorDataset(
    tokenized['input_ids'][test_idx],
    tokenized['attention_mask'][test_idx],
    labels_tensor[test_idx]
)

# Create DataLoaders
train_dl = DataLoader(train_dataset, batch_size=config["BATCH_SIZE"], shuffle=True, pin_memory=True)
valid_dl = DataLoader(valid_dataset, batch_size=config["BATCH_SIZE"], shuffle=False, pin_memory=True)
test_dl  = DataLoader(test_dataset, batch_size=config["BATCH_SIZE"], shuffle=False, pin_memory=True)

# Optional: inspect the DataFrame
display(data.head(5))
print(data.loc[0, "input_ids"])

Unnamed: 0,protein,sequence,label,set,trunc_sequence,input_ids,attention_mask
0,P22694,MGNAATAKKGSEVESVKEFLAKAKEDFLKKWENPTQNNAGLEDFER...,1,train,MGNAATAKKGSEVESVKEFLAKAKEDFLKKWENPTQNNAGLEDFER...,"[tensor(0), tensor(20), tensor(6), tensor(17),...","[tensor(1), tensor(1), tensor(1), tensor(1), t..."
1,Q8NEV1,MSGPVPSRARVYTDVNTHRPREYWDYESHVVEWGNQDDYQLVRKLG...,1,validation,MSGPVPSRARVYTDVNTHRPREYWDYESHVVEWGNQDDYQLVRKLG...,"[tensor(0), tensor(20), tensor(8), tensor(6), ...","[tensor(1), tensor(1), tensor(1), tensor(1), t..."
2,Q6P3V2,MPANWTSPQKSSALAPEDHGSSYEGSVSFRDVAIDFSREEWRHLDP...,1,train,MPANWTSPQKSSALAPEDHGSSYEGSVSFRDVAIDFSREEWRHLDP...,"[tensor(0), tensor(20), tensor(14), tensor(5),...","[tensor(1), tensor(1), tensor(1), tensor(1), t..."
3,Q8N0Y7,MAAYKLVLIRHGESTWNLENRFSCWYDADLSPAGHEEAKRGGQALR...,1,train,MAAYKLVLIRHGESTWNLENRFSCWYDADLSPAGHEEAKRGGQALR...,"[tensor(0), tensor(20), tensor(5), tensor(5), ...","[tensor(1), tensor(1), tensor(1), tensor(1), t..."
4,Q9NY65,MRECISVHVGQAGVQIGNACWELFCLEHGIQADGTFDAQASKINDD...,1,train,MRECISVHVGQAGVQIGNACWELFCLEHGIQADGTFDAQASKINDD...,"[tensor(0), tensor(20), tensor(10), tensor(9),...","[tensor(1), tensor(1), tensor(1), tensor(1), t..."


tensor([ 0, 20,  6, 17,  5,  5, 11,  5, 15, 15,  6,  8,  9,  7,  9,  8,  7, 15,
         9, 18,  4,  5, 15,  5, 15,  9, 13, 18,  4, 15, 15, 22,  9, 17, 14, 11,
        16, 17, 17,  5,  6,  4,  9, 13, 18,  9, 10, 15, 15, 11,  4,  6, 11,  6,
         8, 18,  6, 10,  7, 20,  4,  7, 15, 21, 15,  5, 11,  9, 16, 19, 19,  5,
        20, 15, 12,  4, 13, 15, 16, 15,  7,  7, 15,  4, 15, 16, 12,  9, 21, 11,
         4, 17,  9, 15, 10, 12,  4, 16,  5,  7, 17, 18, 14, 18,  4,  7, 10,  4,
         9, 19,  5, 18, 15, 13, 17,  8, 17,  4, 19, 20,  7, 20,  9, 19,  7, 14,
         6,  6,  9, 20, 18,  8, 21,  4, 10, 10, 12,  6, 10, 18,  8,  9, 14, 21,
         5, 10, 18, 19,  5,  5, 16, 12,  7,  4, 11, 18,  9, 19,  4, 21,  8,  4,
        13,  4, 12, 19, 10, 13,  4, 15, 14,  9, 17,  4,  4, 12, 13, 21, 16,  6,
        19, 12, 16,  7, 11, 13, 18,  6, 18,  5, 15, 10,  7, 15,  6, 10, 11, 22,
        11,  4, 23,  6, 11, 14,  9, 19,  4,  5, 14,  9, 12, 12,  4,  8, 15,  6,
        19, 17, 15,  5,  7, 13, 22, 22, 

### Train

In [None]:
results = mf.train(net, train_dl, valid_dl, test_dl, config)

### Plot

In [None]:
mf.summarize_training(*results)

### Save Model

In [None]:
torch.save(net, config["PATH_TO_SAVE_MODEL"])

### Load Model

In [None]:
model = torch.load(config["PATH_TO_SAVE_MODEL"])
model.to(config["DEVICE"])
model.eval()

### Umaps

### A-scanning