# CafChem tools for fine-tuning the ESM models

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/MauricioCafiero/CafChem/blob/main/notebooks/ESMTuning_CafChem.ipynb)

## This notebook allows you to:
- Load ESM models and fine-tune them for various tasks
- Examples here include:
  * sequence classification by subcellular location
  * token classification (secondary structure)

## Requirements:

- Small models run quickly on an L4 GPU

## Install and import libraries

In [None]:
! pip install -q evaluate

Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 84.1/84.1 kB 5.0 MB/s eta 0:00:00
Installing collected packages: evaluate
Successfully installed evaluate-0.4.6


In [None]:
!apt install git-lfs

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
git-lfs is already the newest version (3.0.2-1ubuntu0.3).
0 upgraded, 0 newly installed, 0 to remove and 41 not upgraded.


In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
from io import BytesIO
import requests
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer, AutoModelForTokenClassification, DataCollatorForTokenClassification
from evaluate import load
import re

## Choose a model

This notebook uses ESM-2. The citation for this model is [Lin et al., 2022](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1).

There are several ESM-2 checkpoints with differing model sizes.

| Checkpoint name | Num layers | Num parameters |
|------------------------------|----|----------|
| `esm2_t48_15B_UR50D`         | 48 | 15B     |
| `esm2_t36_3B_UR50D`          | 36 | 3B      |
| `esm2_t33_650M_UR50D`        | 33 | 650M    |
| `esm2_t30_150M_UR50D`        | 30 | 150M    |
| `esm2_t12_35M_UR50D`         | 12 | 35M     |
| `esm2_t6_8M_UR50D`           | 6  | 8M      |
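
The checkpoint names in the table encode the layer count and parameter count. The sketch below parses them and builds a Hub identifier. The helper names `parse_checkpoint` and `hub_id` are hypothetical, and the `facebook/` prefix is assumed to apply to every checkpoint because it is the prefix this notebook uses for the 35M model.

```python
import re

# Checkpoint names copied from the table above.
CHECKPOINTS = [
    "esm2_t48_15B_UR50D",
    "esm2_t36_3B_UR50D",
    "esm2_t33_650M_UR50D",
    "esm2_t30_150M_UR50D",
    "esm2_t12_35M_UR50D",
    "esm2_t6_8M_UR50D",
]

def parse_checkpoint(name):
    """Return (num_layers, size_string) parsed from an ESM-2 checkpoint name."""
    match = re.fullmatch(r"esm2_t(\d+)_(\d+[MB])_UR50D", name)
    if match is None:
        raise ValueError(f"Unrecognized checkpoint name: {name}")
    return int(match.group(1)), match.group(2)

def hub_id(name, org="facebook"):
    """Build a Hub identifier; the org prefix is an assumption (see lead-in)."""
    return f"{org}/{name}"

for name in CHECKPOINTS:
    layers, size = parse_checkpoint(name)
    print(f"{hub_id(name)}: {layers} layers, {size} parameters")
```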


In [None]:
model_checkpoint = "facebook/esm2_t12_35M_UR50D"

# Sequence classification - based on location

## Data preparation

In [None]:
#@title Set protein fetching parameters

protein_min_length = 10 #@param {type:"integer"}
protein_max_length = 100 #@param {type:"integer"}

#@markdown Include:
sequence = True #@param {type:"boolean"}
subcellular_location = True #@param {type:"boolean"}
protein_name = True #@param {type:"boolean"}
gene_names = True #@param {type:"boolean"}
organism_name = True #@param {type:"boolean"}
interaction = False #@param {type:"boolean"}
#@markdown ---
#@markdown Only Human proteins?
human_only = False #@param {type:"boolean"}

fields = ''
if subcellular_location:
  fields += '%2Ccc_subcellular_location'
if sequence:
  fields += '%2Csequence'
if protein_name:
  fields += '%2Cprotein_name'
if gene_names:
  fields += '%2Cgene_names'
if organism_name:
  fields += '%2Corganism_name'
if interaction:
  fields += '%2Ccc_interaction'

if human_only:
  include_human = 'organism_id%3A9606%29%20AND%20%28'
else:
  include_human = ''

In [None]:
query_url = f"https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession\
{fields}&format=tsv&query=%28%28{include_human}reviewed%3Atrue%29%20AND%20%28length%3A%5B\
{protein_min_length}%20TO%20{protein_max_length}%5D%29%29"

uniprot_request = requests.get(query_url)

bio = BytesIO(uniprot_request.content)

df = pd.read_csv(bio, compression='gzip', sep='\t')
df

Unnamed: 0,Entry,Subcellular location [CC],Sequence,Protein names,Gene Names,Organism
0,A0A068B6Q6,SUBCELLULAR LOCATION: Secreted {ECO:0000305}.,PDGRNAAAKAFDLITPTVRKGCCSNPACILNNPNQCG,Conotoxin Bt1.8,,Conus betulinus (Beech cone)
1,A0A0A1I6E7,SUBCELLULAR LOCATION: Secreted {ECO:0000269|Pu...,MEIKYLLTVFLVLLIVSDHCQAFLFSLIPHAISGLISAFKGRRKRD...,Antimicrobial peptide AcrAP1,,Androctonus crassicauda (Arabian fat-tailed sc...
2,A0A0A1I6N9,SUBCELLULAR LOCATION: Secreted {ECO:0000269|Pu...,MEIKYLLTVFLVLLIVSDHCQAFLFSLIPNAISGLLSAFKGRRKRN...,Antimicrobial peptide AcrAP2,,Androctonus crassicauda (Arabian fat-tailed sc...
3,A0A0B4J1N3,SUBCELLULAR LOCATION: Secreted {ECO:0000250|Un...,MRLLALSGLLCMLLLCFCIFSSEGRRHPAKSLKLRRCCHLSPRSKL...,Protein GPR15LG (Protein GPR15 ligand) (Protei...,Gpr15lg Gpr15l,Mus musculus (Mouse)
4,A0A0B4J2F0,SUBCELLULAR LOCATION: Mitochondrion outer memb...,MFRRLTFAQLLFATVLGIAGGVYIFQPVFEQYAKDQKELKEKMQLV...,Protein PIGBOS1 (PIGB opposite strand protein 1),PIGBOS1,Homo sapiens (Human)
...,...,...,...,...,...,...
57461,Q9ZZW4,SUBCELLULAR LOCATION: Mitochondrion {ECO:00003...,MTGSGTPPSREVNTYYMTMTMTMTMIMIMTMTMNIHFNNNNNNNIN...,"Putative uncharacterized protein Q0142, mitoch...",Q0142 ORF9,Saccharomyces cerevisiae (strain ATCC 204508 /...
57462,Q9ZZX7,SUBCELLULAR LOCATION: Mitochondrion {ECO:00003...,MLMMYMLFIMMKTYPMLSYHMMSYHIMLYTIMWYMKYSTYMRLWLL...,"Putative uncharacterized protein Q0032, mitoch...",Q0032 ORF8,Saccharomyces cerevisiae (strain ATCC 204508 /...
57463,Q9ZZX8,SUBCELLULAR LOCATION: Mitochondrion {ECO:00003...,MCATYMFNITVIITHPTPTLRTRGPGFVRNRDLYIYKYKSNLINNL...,"Putative uncharacterized protein Q0017, mitoch...",Q0017 ORF7,Saccharomyces cerevisiae (strain ATCC 204508 /...
57464,V5QPS4,,MTATIGFRPTEKDEQIINAAMRSGERKSDVIRRALQLLEREVWIKQ...,Putative antitoxin Rv3098B/RVBD_3098B,Rv3098B RVBD_3098B P425_03228,Mycobacterium tuberculosis (strain ATCC 25618 ...
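
The query URL above is percent-encoded by hand, which makes the search terms hard to read. As a sketch of an alternative, the same URL can be built from a readable query string with `urllib.parse.quote`. The function name `build_uniprot_url` and its parameters are hypothetical; the endpoint, field names, and query clauses are the ones used in the cells above.

```python
from urllib.parse import quote

def build_uniprot_url(field_names, min_len, max_len, human_only=False):
    """Build the UniProt stream URL from readable parts instead of hand-encoded text."""
    clauses = ["(reviewed:true)", f"(length:[{min_len} TO {max_len}])"]
    if human_only:
        # 9606 is the organism ID used for human in the cells above.
        clauses.insert(0, "(organism_id:9606)")
    query = "(" + " AND ".join(clauses) + ")"
    fields = ",".join(["accession", *field_names])
    return (
        "https://rest.uniprot.org/uniprotkb/stream?compressed=true"
        f"&fields={quote(fields, safe='')}"
        f"&format=tsv&query={quote(query, safe='')}"
    )

url = build_uniprot_url(["cc_subcellular_location", "sequence"], 10, 100)
print(url)
```

Keeping the query as plain text until the last step makes it easier to add or remove clauses without miscounting `%28` and `%29`.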


Now we'll make one dataframe of proteins whose subcellular location column contains `Cytosol` or `Cytoplasm`, and a second of proteins whose column contains `Membrane` or `Cell membrane`. The match is case-sensitive. To avoid overlap, each dataframe keeps only proteins that do not match the other pair of search terms.

In [None]:
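# Boolean masks over the location column; na=False treats rows with no annotation as non-matching
cytosolic = df['Subcellular location [CC]'].str.contains("Cytosol", na=False) | df['Subcellular location [CC]'].str.contains("Cytoplasm", na=False)
membrane = df['Subcellular location [CC]'].str.contains("Membrane", na=False) | df['Subcellular location [CC]'].str.contains("Cell membrane", na=False)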
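
To make the effect of these case-sensitive substring tests concrete, here is a small pure-Python sketch of the same logic applied to single strings. The function `location_class` and the example annotations are hypothetical and simplified; they are not taken from the UniProt results.

```python
def location_class(location):
    """Classify one annotation with the same case-sensitive terms as the masks."""
    if not isinstance(location, str):  # missing annotations (NaN) match nothing
        return None
    is_cytosolic = "Cytosol" in location or "Cytoplasm" in location
    is_membrane = "Membrane" in location or "Cell membrane" in location
    if is_cytosolic and not is_membrane:
        return "cytosolic"
    if is_membrane and not is_cytosolic:
        return "membrane"
    return None  # matches both groups or neither

examples = [
    "SUBCELLULAR LOCATION: Cytoplasm.",
    "SUBCELLULAR LOCATION: Cell membrane.",
    "SUBCELLULAR LOCATION: Cytoplasm. Cell membrane.",
    "SUBCELLULAR LOCATION: Cell inner membrane. Cytoplasm.",
]
for text in examples:
    print(f"{location_class(text)!s:10} <- {text}")
```

Note the last example: `Cell inner membrane` contains neither `Membrane` (capital M) nor `Cell membrane`, so an entry like this would still count as cytosolic under these terms.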

In [None]:
cytosolic_df = df[cytosolic & ~membrane]
cytosolic_df

Unnamed: 0,Entry,Subcellular location [CC],Sequence,Protein names,Gene Names,Organism
9,A0A0U1RRE5,"SUBCELLULAR LOCATION: Cytoplasm, P-body {ECO:0...",MGDQPCASGRSTLPPGNAREAKPPKKRCLLAPRWDYPEGTPNGGST...,Negative regulator of P-body association (P-bo...,NBDY LINC01420,Homo sapiens (Human)
17,A0A2R8VHR8,SUBCELLULAR LOCATION: Nucleus {ECO:0000250|Uni...,MLKMSGWQRQSQNNSRNLRRECSRRKCIFIHHHT,DDIT3 upstream open reading frame protein (Alt...,Ddit3,Mus musculus (Mouse)
42,A8DYH2,SUBCELLULAR LOCATION: Nucleus {ECO:0000250|Uni...,MSKVTFKITLTSDPKLPFKVLSVPEGTPFTAVLKFASEEFKVPAET...,Ubiquitin-fold modifier 1,Ufm1 CG34191,Drosophila melanogaster (Fruit fly)
45,A8MTZ0,"SUBCELLULAR LOCATION: Cell projection, cilium ...",MLKAAAKRPELSGKNTISNNSDMAEVKSMFREVLPKQGPLFVEDIM...,BBSome-interacting protein 1 (BBSome-interacti...,BBIP1 BBIP10 NCRNA00081,Homo sapiens (Human)
96,C9JLW8,SUBCELLULAR LOCATION: Nucleus {ECO:0000269|Pub...,MTSSPVSRVVYNGKRTSSPRSPPSSSEIFTPAHEENVRFIYEAWQG...,Mapk-regulated corepressor-interacting protein...,MCRIP1 FAM195B GRAN2,Homo sapiens (Human)
...,...,...,...,...,...,...
57239,Q9KAD6,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000255|H...,MLSKQKIERINELAKRAKTTGLTEDELREQKKLREEYIQQFRQSFK...,UPF0291 protein BH2353,BH2353,Halalkalibacterium halodurans (strain ATCC BAA...
57259,Q9KVY3,SUBCELLULAR LOCATION: Cell inner membrane {ECO...,MATPLSPFSWLAIGIVKLYQWFISPLIGPRCRFTPTCSTYAIEALR...,Putative membrane protein insertion efficiency...,VC_0005,Vibrio cholerae serotype O1 (strain ATCC 39315...
57281,Q9PJS0,SUBCELLULAR LOCATION: Cell inner membrane {ECO...,MKTSWIKIFFQGMIHLYRWTISPLLGSPCRFFPSCSEYALVALKKH...,Putative membrane protein insertion efficiency...,TC_0758,Chlamydia muridarum (strain MoPn / Nigg)
57380,Q9X1H3,SUBCELLULAR LOCATION: Cell inner membrane {ECO...,MKKLLIMLIRFYQRYISPLKPPTCRFTPTCSNYFIQALEKHGLLKG...,Putative membrane protein insertion efficiency...,TM_1462,Thermotoga maritima (strain ATCC 43589 / DSM 3...


In [None]:
membrane_df = df[membrane & ~cytosolic]
membrane_df

Unnamed: 0,Entry,Subcellular location [CC],Sequence,Protein names,Gene Names,Organism
12,A0A1P8AQ95,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...,MTKNMTKKKMGLMSPNIAAFVLPMLLVLFTISSQVEVVESTGRKLS...,Secreted transmembrane peptide 4 (Phytocytokin...,STMP4 At1g65486 F5I14,Arabidopsis thaliana (Mouse-ear cress)
22,A0JQ18,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...,MAAKTSNLVALLLSLFLLLLSISSQVGLGEAKRNLRNNLRLDCVSH...,Serine rich endogenous peptide 14 (AtSCOOP14) ...,PROSCOOP14 SCOOP14 STMP2 At1g22890 F19G10.22,Arabidopsis thaliana (Mouse-ear cress)
36,A4IFH6,SUBCELLULAR LOCATION: Endoplasmic reticulum me...,MDKVQYLTRSAIRRASTIEMPQQARQNLQNLFINFCLISICLLLIC...,Phospholamban (PLB),PLN,Bos taurus (Bovine)
46,B0L3A2,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...,MTMFKGSNEMKSRWNWGSITCIICFTCVGSQLSMSSSKASNFSGPL...,Dual endothelin-1/VEGF signal peptide receptor...,FBXW7-AS1 DEAR DEspR,Homo sapiens (Human)
57,B2RUZ4,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...,MQPQESHVHYSRWEDGSRDGVSLGAVSSTEEASRCRRISQRLCTGK...,Small integral membrane protein 1 (Vel blood g...,SMIM1,Homo sapiens (Human)
...,...,...,...,...,...,...
57398,Q9Y068,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...,MPITCGDIPRLICSVIIPPVGVFFQVGCTKDLAINCLLTVLGYIPG...,Protein Ric1,RIC1,Phytophthora infestans (Potato late blight age...
57435,Q9ZDN9,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...,MLKSLKFLLVFIILAQLLSCTPSAPYEIKSPCVSVDIDDNSSLSIN...,Uncharacterized protein RP288,RP288,Rickettsia prowazekii (strain Madrid E)
57440,Q9ZDZ4,SUBCELLULAR LOCATION: Cell membrane {ECO:00003...,MIILHLIHRSLNMLINTSNNLLITTIHLLSSIGAINWGLVGLFNFN...,Uncharacterized protein RP169,RP169,Rickettsia prowazekii (strain Madrid E)
57445,Q9ZE49,SUBCELLULAR LOCATION: Cell membrane {ECO:00003...,MFKHVLLSIIIFLGINQNVYSINSNSYKTDDIIKIVIILGIVILIF...,Uncharacterized protein RP098,RP098,Rickettsia prowazekii (strain Madrid E)


In [None]:
cytosolic_sequences = cytosolic_df["Sequence"].tolist()
cytosolic_labels = [0 for protein in cytosolic_sequences]

In [None]:
membrane_sequences = membrane_df["Sequence"].tolist()
membrane_labels = [1 for protein in membrane_sequences]

In [None]:
sequences = cytosolic_sequences + membrane_sequences
labels = cytosolic_labels + membrane_labels

# Quick check to make sure we got it right
len(sequences) == len(labels)

True

In [None]:
train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.25, shuffle=True)

## Tokenizing the data

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
print(f"size of Tokenizer: {len(tokenizer)}")

size of Tokenizer: 33


In [None]:
train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences)

## Dataset creation

In [None]:
train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

train_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 6247
})

In [None]:
train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 6247
})

## Model loading

In [None]:
num_labels = max(train_labels + test_labels) + 1  # Add 1 since 0 can be a label
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
model_name = model_checkpoint.split("/")[-1]
batch_size = 8

args = TrainingArguments(
    f"{model_name}-finetuned-localization",
    eval_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
)

Next, we define the metric we will use to evaluate our models and write a `compute_metrics` function. We can load this from the `evaluate` library.

In [None]:
metric = load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)

## Model Training

In [None]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)



In [None]:
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 2}.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.1383,0.080703,0.978397
2,0.0352,0.05941,0.985598
3,0.0196,0.060139,0.986078


TrainOutput(global_step=2343, training_loss=0.05616307594406773, metrics={'train_runtime': 149.968, 'train_samples_per_second': 124.967, 'train_steps_per_second': 15.623, 'total_flos': 372044720641122.0, 'train_loss': 0.05616307594406773, 'epoch': 3.0})
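
The `global_step` value reported above can be checked by hand from the training set size, batch size, and epoch count. This sketch assumes a single device, no gradient accumulation, and that the last partial batch of each epoch is kept; the function name `total_optimizer_steps` is hypothetical.

```python
import math

def total_optimizer_steps(num_examples, batch_size, num_epochs):
    """Steps per epoch is ceil(examples / batch size); multiply by the epoch count."""
    return math.ceil(num_examples / batch_size) * num_epochs

# 6247 training rows, batch size 8, and 3 epochs, as in the cells above.
steps = total_optimizer_steps(num_examples=6247, batch_size=8, num_epochs=3)
print(steps)  # 2343
```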

In [None]:
test_seqs = [
"YMLLLLLVLTLGETLLLGVAILLLFRFLLLLKGGNSLFLLKYLAAALQLL",
"GKYSHLLQDELLLPLNQNYFLGSAPCLCTCKLATGASESVALSGLILLLA",
"RSRNEENDQHGRTTRLAAQGAEGNFVPDPQKPSYVLLSLAAFLLSKLLED",
"RMILLLLLLLRSLLLGLLYSRLLLLLLLRNIRALLSELLVVVSLELILHH",
"EMKNLTILLLLLLLLLLLLLALLSALVSLSYCLCLCGAAGSVSHNLAASK",
"KCSDPRKAADPPKLDSTALSEESPSCGVGLLLLLDAGTTEKIELRPQLQS",
"IRLSLLLLLLLLLLLLLLLLEGTALLVLLLRLSLLLSSALLQAELLQYPI",
"AFSQLLSSLQQLKLQSLLLLAEYKEAYAVLLLLLLLLTTAVLLLLLLLLV",
"CPQIILLLLLETLLLLLTVLAEALLKTVILLLELLLLSSLLVRRLVDLLN",
"CLFLLGLLEPPKCCNLLLNGSELLLLALHVLLLALLLACKL",
"CIGGAALLVSALTGLLSAALLLLLLLVPCRLLLLLFLLGLLLLLLLLLHL",
"VLLLLLLLLLLALLLLLLLLLLASLLLLLLLLALSCLLLLLEGNIPRLLL",
"MTVVLVVDGLLVLLLLTLLSLVSLLLAELDGLLAAAPEARRAFLAIQELL",
"SPPKSLLLALLALLLLKDLLGLLLLLNRFTPVNGCHLLAQLLSQLLFLLL",
"LILPLLLLFTAPPEYFLLLLLLGKELLALLLACAVKPDKEKLTEPETIFC",
"VYVPPACCNTEPKPPC",
"YVVVGLLKLPLNEREEDLLLLRNGAIAALL",
"EHKEVVAVRLLRYLAALLTLLVPWLLLNLRLLLVLLLKLKLLAIFLPVLL",
"FIQPTAGFLLTVLGALEGLLCPQVATEELLCAPICCVKLISAFAPTALLL",
"FKLTSLLLLLLLLLLLLLKLGLLDLRLLIRLMLLATARCLLSLNRGNVDL",
]

In [None]:
test_dataset = Dataset.from_dict(tokenizer(test_seqs))
res = trainer.predict(test_dataset)

In [None]:
which_idx = 17

In [None]:
label_hash = {'0': 'cytosolic', '1': 'membrane'}

# res[0] holds one row of logits per sequence (two values: cytosolic, membrane)
max_label = np.argmax(res[0][which_idx])
print(f"{test_seqs[which_idx]} : {label_hash[str(max_label)]}")


***
# Token classification

Here we assign each token (an amino acid, in this case!) to one of three categories: unclassified, helix, or strand.

## Data preparation

In [None]:
#@title Set protein fetching parameters

protein_min_length = 10 #@param {type:"integer"}
protein_max_length = 100 #@param {type:"integer"}

#@markdown Include:
sequence = True #@param {type:"boolean"}
helix = True #@param {type:"boolean"}
strand = True #@param {type:"boolean"}
subcellular_location = False #@param {type:"boolean"}
protein_name = True #@param {type:"boolean"}
gene_names = True #@param {type:"boolean"}
organism_name = True #@param {type:"boolean"}
interaction = False #@param {type:"boolean"}
#@markdown ---
#@markdown Only Human proteins?
human_only = True #@param {type:"boolean"}

fields = ''
if subcellular_location:
  fields += '%2Ccc_subcellular_location'
if sequence:
  fields += '%2Csequence'
if protein_name:
  fields += '%2Cprotein_name'
if gene_names:
  fields += '%2Cgene_names'
if organism_name:
  fields += '%2Corganism_name'
if interaction:
  fields += '%2Ccc_interaction'
if helix:
  fields += '%2Cft_helix'
if strand:
  fields += '%2Cft_strand'

if human_only:
  include_human = 'organism_id%3A9606%29%20AND%20%28'
else:
  include_human = ''

In [None]:
query_url = f"https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession\
{fields}&format=tsv&query=%28%28{include_human}reviewed%3Atrue%29%20AND%20%28length%3A%5B\
{protein_min_length}%20TO%20{protein_max_length}%5D%29%29"

uniprot_request = requests.get(query_url)

bio = BytesIO(uniprot_request.content)

df = pd.read_csv(bio, compression='gzip', sep='\t')
df

Unnamed: 0,Entry,Sequence,Protein names,Gene Names,Organism,Helix,Beta strand
0,A0A0B4J2F0,MFRRLTFAQLLFATVLGIAGGVYIFQPVFEQYAKDQKELKEKMQLV...,Protein PIGBOS1 (PIGB opposite strand protein 1),PIGBOS1,Homo sapiens (Human),,
1,A0A0C5B5G6,MRWQEMGYIFYPRKLR,Mitochondrial-derived peptide MOTS-c (Mitochon...,MT-RNR1,Homo sapiens (Human),,
2,A0A0U1RRE5,MGDQPCASGRSTLPPGNAREAKPPKKRCLLAPRWDYPEGTPNGGST...,Negative regulator of P-body association (P-bo...,NBDY LINC01420,Homo sapiens (Human),,
3,A1L190,MDDADPEERNYDNMLKMLSDLNKDLEKLLEEMEKISVQATWMAYDM...,Synaptonemal complex central element protein 3...,SYCE3 C22orf41 THEG2,Homo sapiens (Human),,
4,A8MT69,MEGAGAGSGFRKELVSRLLHLHFKDDKTKVSGDALQLMVELLKVFV...,Centromere protein X (CENP-X) (FANCM-associate...,CENPX FAAP10 MHF2 STRA13,Homo sapiens (Human),"HELIX 12..20; /evidence=""ECO:0007829|PDB:4NE3""...","STRAND 28..30; /evidence=""ECO:0007829|PDB:7R5S..."
...,...,...,...,...,...,...,...
757,Q9UI25,MEEMSYGENSGTHVGSFSCSPQPSQQMKVLFVGNSFLLTPVLHRQP...,Putative uncharacterized protein PRO0461,PRO0461,Homo sapiens (Human),,
758,Q9UI54,MESPKCLYSRITVNTAFGTKFSHISFIILFKVFLFPRITISKKTKL...,Putative uncharacterized protein PRO0628,PRO0628,Homo sapiens (Human),,
759,Q9UI72,MGMALELYWLCGFRSYWPLGTNAENEGNRKENRRQMQSRNERGCNV...,Putative uncharacterized protein PRO0255,PRO0255,Homo sapiens (Human),,
760,Q9Y3F1,MSLLWTPQILTISFVSYILSLFPSPFPSCYTSCWFETSITTEKELN...,Putative TAP2-associated 6.5 kDa polypeptide,,Homo sapiens (Human),,


In [None]:
no_structure_rows = df["Beta strand"].isna() & df["Helix"].isna()
df = df[~no_structure_rows]
df

Unnamed: 0,Entry,Sequence,Protein names,Gene Names,Organism,Helix,Beta strand
4,A8MT69,MEGAGAGSGFRKELVSRLLHLHFKDDKTKVSGDALQLMVELLKVFV...,Centromere protein X (CENP-X) (FANCM-associate...,CENPX FAAP10 MHF2 STRA13,Homo sapiens (Human),"HELIX 12..20; /evidence=""ECO:0007829|PDB:4NE3""...","STRAND 28..30; /evidence=""ECO:0007829|PDB:7R5S..."
11,L0R8F8,MAPWSREAVLSLYRALLRQGRQLRYTDRDFYFASIRREFRKNQKLE...,Mitochondrial ribosome and complex I assembly ...,MIEF1 AltMiD51 AltMIEF1,Homo sapiens (Human),"HELIX 6..21; /evidence=""ECO:0007829|PDB:7OF0"";...","STRAND 24..26; /evidence=""ECO:0007829|PDB:7QH7..."
12,O00168,MASLGHILVFCVGLLTMAKAESPKEHDPFTYDYQSLQIGGLVIAGI...,Phospholemman (FXYD domain-containing ion tran...,FXYD1 PLM,Homo sapiens (Human),"HELIX 23..26; /evidence=""ECO:0007829|PDB:2JO1""...","STRAND 69..71; /evidence=""ECO:0007829|PDB:2JO1"""
13,O00198,MCPCPLHRGRGPPAVCACSAGRLGLRSSAAQLTAARLKALGDELHQ...,Activator of apoptosis harakiri (BH3-interacti...,HRK BID3,Homo sapiens (Human),"HELIX 31..49; /evidence=""ECO:0007829|PDB:7P0U""...",
14,O00244,MPKHEFSVDMTCGGCAEAVSRVLNKLGGVKYDIDLPNKKVCIESEH...,Copper transport protein ATOX1 (Metal transpor...,ATOX1 HAH1,Homo sapiens (Human),"HELIX 13..26; /evidence=""ECO:0007829|PDB:3IWL""...","STRAND 3..8; /evidence=""ECO:0007829|PDB:3IWL"";..."
...,...,...,...,...,...,...,...
455,Q8N6N7,MALQADFDRAAEDVRKLKARPDDGELKELYGLYKQAIVGDINIACP...,Acyl-CoA-binding domain-containing protein 7,ACBD7,Homo sapiens (Human),"HELIX 2..13; /evidence=""ECO:0007829|PDB:3EPY"";...",
462,Q969E1,MWHLKLCAVLMIFLLLLGQIDGSPIPEVSSAKRRPRRMTPFWRGVS...,Liver-expressed antimicrobial peptide 2 (LEAP-2),LEAP2,Homo sapiens (Human),"HELIX 57..59; /evidence=""ECO:0007829|PDB:2L1Q""","STRAND 46..48; /evidence=""ECO:0007829|PDB:2L1Q"""
470,Q9BQ48,MAVLAGSLLGPTSRSAALLGGRWLQPRAWLGFPDAWGLPTPQQARG...,Large ribosomal subunit protein bL34m (39S rib...,MRPL34,Homo sapiens (Human),"HELIX 57..64; /evidence=""ECO:0007829|PDB:7OF0""...",
653,Q9P1F3,MNVDHEVNLLVEEIHRLGSKNADGKLSVKFGVLFRDDKCANLFEAL...,Costars family protein ABRACL (ABRA C-terminal...,ABRACL C6orf115 HSPC280 PRO2013,Homo sapiens (Human),"HELIX 3..17; /evidence=""ECO:0007829|PDB:2L2O"";...","STRAND 24..29; /evidence=""ECO:0007829|PDB:2L2O..."


In [None]:
#@title Make clean label columns

def build_labels(sequence, strands, helices):
    # Start with all 0s
    labels = np.zeros(len(sequence), dtype=np.int64)

    if isinstance(helices, float): # Indicates missing (NaN)
        found_helices = []
    else:
        found_helices = re.findall(helix_re, helices)
    for helix_start, helix_end in found_helices:
        helix_start = int(helix_start) - 1
        helix_end = int(helix_end)
        assert helix_end <= len(sequence)
        labels[helix_start: helix_end] = 1  # Helix category

    if isinstance(strands, float): # Indicates missing (NaN)
        found_strands = []
    else:
        found_strands = re.findall(strand_re, strands)
    for strand_start, strand_end in found_strands:
        strand_start = int(strand_start) - 1
        strand_end = int(strand_end)
        assert strand_end <= len(sequence)
        labels[strand_start: strand_end] = 2  # Strand category
    return labels

strand_re = r"STRAND\s(\d+)\.\.(\d+)\;"
helix_re = r"HELIX\s(\d+)\.\.(\d+)\;"

re.findall(helix_re, df.iloc[0]["Helix"])

sequences = []
labels = []

for row_idx, row in df.iterrows():
    row_labels = build_labels(row["Sequence"], row["Beta strand"], row["Helix"])
    sequences.append(row["Sequence"])
    labels.append(row_labels)

# Print the full label array next to the first 10 residues of each sequence
for seqs, labs in zip(sequences[:5], labels[:5]):
    print(f"{labs} : {seqs[:10]}")

[0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 2 2 2 0 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 2 2 2 0 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 0] : MEGAGAGSGF
[0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 2 2 2 0 1 1 1 1 1 1 1 1 1 1
 1 1 1 0 0 2 2 2 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0] : MAPWSREAVL
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 1 1 1 0 0 0 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 2 2 2 0 0 0
 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 0 0 0] : MASLGHILVF
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0] : MCPCPLHRGR
[0 0 2 2 2 2 2 2 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 2 2 2 2 2 2 2 0 0 0
 0 2 2 2 2 2 2 2 2 0 1 1 1 1 1 1 1 1 1 0 0 0 0 0 2 2 2 2 2 0 0] : MPKHEFSVDM
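
The label-building step converts UniProt's 1-based, inclusive residue ranges into 0-based Python slices. A minimal sketch on a toy feature string shows the conversion. The function `spans_to_labels` and the feature text are hypothetical; the regular expression follows the `helix_re` pattern above.

```python
import re

HELIX_RE = r"HELIX\s(\d+)\.\.(\d+);"

def spans_to_labels(seq_len, feature_text, pattern, label_id):
    """Mark every residue covered by a 'start..end' range with label_id."""
    labels = [0] * seq_len
    for start, end in re.findall(pattern, feature_text):
        start_idx = int(start) - 1  # 1-based inclusive start -> 0-based index
        end_idx = int(end)          # inclusive end doubles as the exclusive slice end
        labels[start_idx:end_idx] = [label_id] * (end_idx - start_idx)
    return labels

# Hypothetical feature text in the same layout as the Helix column.
toy_features = 'HELIX 3..5; /evidence="hypothetical"; HELIX 8..9; /evidence="hypothetical"'
print(spans_to_labels(10, toy_features, HELIX_RE, 1))
# [0, 0, 1, 1, 1, 0, 0, 1, 1, 0]
```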


## Creating our dataset

In [None]:
train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.25, shuffle=True)

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences)

In [None]:
train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)
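
One detail worth checking at this point: assuming the ESM tokenizer adds one special token at each end of a sequence (`<cls>` and `<eos>`), each tokenized sequence is two tokens longer than its label array. The cells above pass the residue-level labels through unchanged. The sketch below shows one possible alternative, not used in this notebook, that pads the labels with `-100` so each label sits at the same position as its residue token. The helper name `align_labels_to_tokens` is hypothetical.

```python
IGNORE_INDEX = -100  # the value that compute_metrics filters out

def align_labels_to_tokens(residue_labels, n_prefix=1, n_suffix=1):
    """Pad per-residue labels so they line up with tokens that include special tokens."""
    return (
        [IGNORE_INDEX] * n_prefix
        + [int(label) for label in residue_labels]
        + [IGNORE_INDEX] * n_suffix
    )

print(align_labels_to_tokens([0, 1, 1, 2]))
# [-100, 0, 1, 1, 2, -100]
```

If labels were aligned this way, the prediction loop in the test section would also need to skip the first position before pairing logits with residues.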

## Model loading

In [None]:
num_labels = 3
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
data_collator = DataCollatorForTokenClassification(tokenizer)

In [None]:
model_name = model_checkpoint.split("/")[-1]
batch_size = 8

args = TrainingArguments(
    f"{model_name}-finetuned-secondary-structure",
    eval_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.001,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
)

In [None]:
metric = load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    labels = labels.reshape((-1,))
    predictions = np.argmax(predictions, axis=2)
    predictions = predictions.reshape((-1,))
    predictions = predictions[labels!=-100]
    labels = labels[labels!=-100]
    return metric.compute(predictions=predictions, references=labels)
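
The masking in `compute_metrics` can be illustrated on a tiny hand-made batch. This sketch computes the accuracy directly with NumPy instead of calling the `evaluate` metric; the function name `masked_accuracy` and the toy arrays are hypothetical.

```python
import numpy as np

def masked_accuracy(logits, labels, ignore_index=-100):
    """Token-level accuracy that skips positions labelled with ignore_index."""
    flat_labels = labels.reshape(-1)
    flat_preds = np.argmax(logits, axis=2).reshape(-1)
    keep = flat_labels != ignore_index
    return float((flat_preds[keep] == flat_labels[keep]).mean())

# One sequence, three positions, three classes; the last position is padding.
toy_logits = np.array([[[2.0, 0.1, 0.0],
                        [0.0, 3.0, 0.1],
                        [0.2, 0.1, 0.0]]])
toy_labels = np.array([[0, 2, -100]])

print(masked_accuracy(toy_logits, toy_labels))  # 0.5
```

Only the first two positions are scored: the first prediction (class 0) is correct and the second (class 1 against label 2) is not.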


## Train model

In [None]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)

trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 2}.


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.69879,0.747515
2,No log,0.577248,0.792372
3,No log,0.562489,0.795864


TrainOutput(global_step=51, training_loss=0.6677172791724112, metrics={'train_runtime': 4.9382, 'train_samples_per_second': 78.976, 'train_steps_per_second': 10.328, 'total_flos': 7748576912496.0, 'train_loss': 0.6677172791724112, 'epoch': 3.0})

## Test Model

In [None]:
test_seqs = [
"YMLLLLLVLTLGETLLLGVAILLLFRFLLLLKGGNSLFLLKYLAAALQLL",
"GKYSHLLQDELLLPLNQNYFLGSAPCLCTCKLATGASESVALSGLILLLA",
"RSRNEENDQHGRTTRLAAQGAEGNFVPDPQKPSYVLLSLAAFLLSKLLED",
"RMILLLLLLLRSLLLGLLYSRLLLLLLLRNIRALLSELLVVVSLELILHH",
"EMKNLTILLLLLLLLLLLLLALLSALVSLSYCLCLCGAAGSVSHNLAASK",
"KCSDPRKAADPPKLDSTALSEESPSCGVGLLLLLDAGTTEKIELRPQLQS",
"IRLSLLLLLLLLLLLLLLLLEGTALLVLLLRLSLLLSSALLQAELLQYPI",
"AFSQLLSSLQQLKLQSLLLLAEYKEAYAVLLLLLLLLTTAVLLLLLLLLV",
"CPQIILLLLLETLLLLLTVLAEALLKTVILLLELLLLSSLLVRRLVDLLN",
"CLFLLGLLEPPKCCNLLLNGSELLLLALHVLLLALLLACKL",
"CIGGAALLVSALTGLLSAALLLLLLLVPCRLLLLLFLLGLLLLLLLLLHL",
"VLLLLLLLLLLALLLLLLLLLLASLLLLLLLLALSCLLLLLEGNIPRLLL",
"MTVVLVVDGLLVLLLLTLLSLVSLLLAELDGLLAAAPEARRAFLAIQELL",
"SPPKSLLLALLALLLLKDLLGLLLLLNRFTPVNGCHLLAQLLSQLLFLLL",
"LILPLLLLFTAPPEYFLLLLLLGKELLALLLACAVKPDKEKLTEPETIFC",
"VYVPPACCNTEPKPPC",
"YVVVGLLKLPLNEREEDLLLLRNGAIAALL",
"EHKEVVAVRLLRYLAALLTLLVPWLLLNLRLLLVLLLKLKLLAIFLPVLL",
"FIQPTAGFLLTVLGALEGLLCPQVATEELLCAPICCVKLISAFAPTALLL",
"FKLTSLLLLLLLLLLLLLKLGLLDLRLLIRLMLLATARCLLSLNRGNVDL",
]

In [None]:
test_dataset = Dataset.from_dict(tokenizer(test_seqs))
res = trainer.predict(test_dataset)

In [None]:
which_idx = 8

In [None]:
label_hash = {'0': 'unclassified', '1': 'helix', '2': 'strand'}

for token, token_logits in zip(test_seqs[which_idx], res[0][which_idx]):
  # token_logits holds the three class scores for this position
  max_label = np.argmax(token_logits)
  print(f"{token} : {label_hash[str(max_label)]}")

C : unclassified
P : unclassified
Q : unclassified
I : unclassified
I : helix
L : helix
L : helix
L : helix
L : helix
L : helix
E : helix
T : helix
L : helix
L : helix
L : helix
L : helix
L : helix
T : helix
V : helix
L : helix
A : helix
E : helix
A : helix
L : helix
L : helix
K : helix
T : helix
V : helix
I : helix
L : helix
L : helix
L : helix
E : helix
L : helix
L : helix
L : helix
L : helix
S : helix
S : helix
L : helix
L : helix
V : helix
R : helix
R : unclassified
L : unclassified
V : unclassified
D : unclassified
L : unclassified
L : unclassified
N : unclassified
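
The predictions returned by `trainer.predict` are raw logits, and the loop above only reports the highest-scoring class. To see how confident a prediction is, the logits for one position can be turned into probabilities with a softmax. The sketch below uses plain Python; the logit values are illustrative, not model output.

```python
import math

def softmax(logits):
    """Convert a list of logits to probabilities that sum to 1."""
    shift = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]

label_names = ["unclassified", "helix", "strand"]
example_logits = [0.2, 1.4, -1.2]  # illustrative values for one position

probs = softmax(example_logits)
best = max(range(len(probs)), key=probs.__getitem__)
print(label_names[best], [round(p, 3) for p in probs])
```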


In [None]:
res[0][0]

array([[ 1.1182994 ,  0.02025404, -1.0024948 ],
       [ 0.66464204,  0.91449344, -1.212727  ],
       [ 0.23383358,  1.3879018 , -1.187534  ],
       [ 0.03748567,  1.5216719 , -1.1994661 ],
       [-0.09473807,  1.6444285 , -1.1305642 ],
       [-0.12965043,  1.6437149 , -1.0896218 ],
       [-0.18493234,  1.5203499 , -1.1650407 ],
       [ 0.06917883,  1.3380501 , -1.1928856 ],
       [ 0.21839479,  0.97747463, -1.1504468 ],
       [ 0.48156598,  0.9707075 , -1.0550798 ],
       [ 0.49752265,  0.7604192 , -1.0267966 ],
       [ 0.46026355,  0.76377964, -0.9962858 ],
       [ 0.29584888,  0.87685394, -1.0875291 ],
       [ 0.09484729,  1.1764625 , -1.2005545 ],
       [-0.15599872,  1.4452455 , -1.10801   ],
       [-0.04869396,  1.565366  , -1.2679691 ],
       [-0.14631969,  1.7230221 , -1.1689494 ],
       [-0.18232056,  1.7943643 , -0.9927362 ],
       [-0.25862083,  1.6866018 , -0.91176635],
       [-0.36547112,  1.7737633 , -1.0943445 ],
       [-0.435158  ,  1.7539719 , -0.941