# Fine-tuning protein language models

**Q6.** Describe the problem of **predicting the subcellular location** of (prokaryotic) proteins as described in [Moreno et al., 2024](https://doi.org/10.1093/bioinformatics/btae677). Think of a biological question one could answer with proteome-wide predictions of subcellular location, and of potential follow-up experiments.

In [1]:
import numpy as np, pandas as pd, sklearn.preprocessing
import datasets, evaluate, transformers # Hugging Face libraries https://doi.org/10.18653/v1/2020.emnlp-demos.6
import Bio.SeqIO.FastaIO # Biopython for reading fasta files
from IPython.display import display, HTML
random_number = 4 # https://xkcd.com/221/

In [2]:
# Execute once to download the DeepLocPro training data from https://services.healthtech.dtu.dk/services/DeepLocPro-1.0/
!mkdir -p data
!curl https://services.healthtech.dtu.dk/services/DeepLocPro-1.0/data/graphpart_set.fasta -o data/graphpart_set.fasta



**Q7.** Look at the contents of `df_data`. How was the column `fold_id` defined in the paper? Which exact set of sequences is in this data set (check the number of rows)?

- **`fold_id`**: 5-fold **homology-partition** label created with **GraphPart**; sequences in different folds have **≤30% Needleman–Wunsch identity** and folds are **balanced** for organism group + location (64 sequences removed to achieve separation).
- **Dataset sequences**: experimentally verified prokaryotic proteins (UniProt release 2023_03 + PSORTdb 4.0), deduplicated, **len ≥ 40 aa**, **single-label only**; after GraphPart the training dataset has **11,906 sequences** (**= number of rows in `df_data`**).
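To make the ≤30% identity criterion concrete, here is a minimal pure-Python sketch of global (Needleman-Wunsch) alignment identity, plus a check that no pair of sequences in *different* folds exceeds a threshold. It is a simplification under stated assumptions: a toy scoring scheme (match +1, mismatch -1, linear gap -2) and identity defined as identical aligned positions divided by alignment length. GraphPart uses a proper aligner with a substitution matrix, and its exact identity normalization may differ. `nw_identity` and `max_cross_fold_identity` are hypothetical helpers, and the quadratic all-pairs loop would be far too slow for all 11,906 sequences.

```python
from itertools import combinations

def nw_identity(a: str, b: str, match: int = 1, mismatch: int = -1, gap: int = -2) -> float:
    """Toy Needleman-Wunsch global alignment with a linear gap penalty.
    Returns (identical aligned positions) / (alignment length)."""
    n, m = len(a), len(b)
    # score[i][j] = best score for aligning a[:i] with b[:j]
    score = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        score[i][0] = i * gap
    for j in range(1, m + 1):
        score[0][j] = j * gap
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            diag = score[i - 1][j - 1] + (match if a[i - 1] == b[j - 1] else mismatch)
            score[i][j] = max(diag, score[i - 1][j] + gap, score[i][j - 1] + gap)
    # Trace back from the bottom-right cell, counting identities and alignment columns
    i, j, identical, length = n, m, 0, 0
    while i > 0 or j > 0:
        if i > 0 and j > 0 and score[i][j] == score[i - 1][j - 1] + (
            match if a[i - 1] == b[j - 1] else mismatch
        ):
            identical += a[i - 1] == b[j - 1]
            i, j = i - 1, j - 1
        elif i > 0 and score[i][j] == score[i - 1][j] + gap:
            i -= 1  # gap in b
        else:
            j -= 1  # gap in a
        length += 1
    return identical / length if length else 0.0

def max_cross_fold_identity(seqs: dict, folds: dict) -> float:
    """Highest pairwise identity between two sequences assigned to different folds."""
    best = 0.0
    for x, y in combinations(seqs, 2):
        if folds[x] != folds[y]:
            best = max(best, nw_identity(seqs[x], seqs[y]))
    return best

# Toy example: the near-identical p1/p2 share a fold, p3 is unrelated
seqs = {"p1": "MKTAYIAKQR", "p2": "MKTAYIAKQL", "p3": "GGWWPPHHEE"}
folds = {"p1": 0, "p2": 0, "p3": 1}
print(nw_identity(seqs["p1"], seqs["p2"]))    # 0.9
print(max_cross_fold_identity(seqs, folds))   # 0.0
```

A valid partition at the paper's threshold would need `max_cross_fold_identity(...) <= 0.30`; moving `p2` to fold 1 in this toy example would break that.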

In [3]:
with open('data/graphpart_set.fasta') as handle:
    # SimpleFastaParser yields one (header, sequence) tuple per FASTA record
    df_data = pd.DataFrame.from_records(list(Bio.SeqIO.FastaIO.SimpleFastaParser(handle)), columns=['header', 'sequence'])
# Headers are pipe-separated: uniprot_id|subcellular_location|organism_group|fold_id
header_cols = ['uniprot_id', 'subcellular_location', 'organism_group', 'fold_id']
df_data[header_cols] = df_data['header'].str.split('|', expand=True)
# Keep the parsed header fields plus the sequence; fold_id becomes an integer
df_data = df_data[header_cols + ['sequence']].astype({'fold_id': int}).sort_values('fold_id')
df_data

| | uniprot_id | subcellular_location | organism_group | fold_id | sequence |
|---|---|---|---|---|---|
| 8560 | Q8A0Z3 | Cytoplasmic | negative | 0 | MAVTMADITKLRKMTGAGMMDCKNALTEAEGDYDKAMEIIRKKGQA... |
| 8568 | Q8A2N8 | Cytoplasmic | negative | 0 | MIMSKETLIKSIREIPDFPIPGILFYDVTTLFKDPWCLQELSNIMF... |
| 2593 | P32709 | CYtoplasmicMembrane | negative | 0 | MTQTSAFHFESLVWDWPIAIYLFLIGISAGLVTLAVLLRRFYPQAG... |
| 61 | E7FHF8 | Cytoplasmic | archaea | 0 | MKLGVFELTDCGGCALNLLFLYDKLLDLLEFYEIAEFHMATSKKSR... |
| 63 | E7FHU4 | Cytoplasmic | archaea | 0 | MGKVRIGFYALTSCYGCQLQLAMMDELLQLIPNAEIVCWFMIDRDS... |
| ... | ... | ... | ... | ... | ... |
| 5239 | Q97F85 | Cytoplasmic | positive | 4 | MRKLFTSESVTEGHPDKICDQISDAILDAILEKDPNGRVACETTVT... |
| 5226 | P33656 | Cytoplasmic | positive | 4 | MKNKTEVKNGGEKKNSKKVSKEESAKEKNEKMKIVKNLIDKGKKSG... |
| 11894 | P13949 | OuterMembrane | negative | 4 | MCALDRRERPLNSQSVNKYILNVQNIYRNSPVPVCVRNKNRKILYA... |
| 11890 | P42185 | Extracellular | negative | 4 | MRLRFSVPLFFFGCVFVHGVFAGPFPPPGMSLPEYWGEEHVWWDGR... |


In [4]:
# List the 6 unique subcellular locations
unique_locations = df_data['subcellular_location'].unique()
print(f"Your 6 labels are: {unique_locations}")

# See the distribution (how many proteins per location)
print("\nCounts per location:")
print(df_data['subcellular_location'].value_counts())

Your 6 labels are: <ArrowStringArray>
[        'Cytoplasmic', 'CYtoplasmicMembrane',            'Cellwall',
       'OuterMembrane',       'Extracellular',         'Periplasmic']
Length: 6, dtype: str

Counts per location:
subcellular_location
Cytoplasmic            6885
CYtoplasmicMembrane    2535
Extracellular          1077
OuterMembrane           756
Periplasmic             566
Cellwall                 87
Name: count, dtype: int64


In [5]:
# Encode subcellular location as numerical labels
subcellular_location_encoder = sklearn.preprocessing.LabelEncoder()
subcellular_location_encoder.fit(df_data['subcellular_location'])
df_data['label'] = subcellular_location_encoder.transform(df_data['subcellular_location'])
df_data

| | uniprot_id | subcellular_location | organism_group | fold_id | sequence | label |
|---|---|---|---|---|---|---|
| 8560 | Q8A0Z3 | Cytoplasmic | negative | 0 | MAVTMADITKLRKMTGAGMMDCKNALTEAEGDYDKAMEIIRKKGQA... | 2 |
| 8568 | Q8A2N8 | Cytoplasmic | negative | 0 | MIMSKETLIKSIREIPDFPIPGILFYDVTTLFKDPWCLQELSNIMF... | 2 |
| 2593 | P32709 | CYtoplasmicMembrane | negative | 0 | MTQTSAFHFESLVWDWPIAIYLFLIGISAGLVTLAVLLRRFYPQAG... | 0 |
| 61 | E7FHF8 | Cytoplasmic | archaea | 0 | MKLGVFELTDCGGCALNLLFLYDKLLDLLEFYEIAEFHMATSKKSR... | 2 |
| 63 | E7FHU4 | Cytoplasmic | archaea | 0 | MGKVRIGFYALTSCYGCQLQLAMMDELLQLIPNAEIVCWFMIDRDS... | 2 |
| ... | ... | ... | ... | ... | ... | ... |
| 5239 | Q97F85 | Cytoplasmic | positive | 4 | MRKLFTSESVTEGHPDKICDQISDAILDAILEKDPNGRVACETTVT... | 2 |
| 5226 | P33656 | Cytoplasmic | positive | 4 | MKNKTEVKNGGEKKNSKKVSKEESAKEKNEKMKIVKNLIDKGKKSG... | 2 |
| 11894 | P13949 | OuterMembrane | negative | 4 | MCALDRRERPLNSQSVNKYILNVQNIYRNSPVPVCVRNKNRKILYA... | 4 |
| 11890 | P42185 | Extracellular | negative | 4 | MRLRFSVPLFFFGCVFVHGVFAGPFPPPGMSLPEYWGEEHVWWDGR... | 3 |
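`LabelEncoder` assigns integers in sorted order of the class names, which explains the encoding above (e.g. `Cytoplasmic` → 2). Sorting is by character code, so the uppercase `Y` in `CYtoplasmicMembrane` places it before `Cellwall`. The sketch below reproduces the mapping with plain `sorted` as an emulation; in the notebook itself, the fitted order can be read from `subcellular_location_encoder.classes_`.

```python
# Emulate LabelEncoder's ordering: integer labels follow the sorted unique class names.
locations = ['Cytoplasmic', 'CYtoplasmicMembrane', 'Cellwall',
             'OuterMembrane', 'Extracellular', 'Periplasmic']
label_of = {name: i for i, name in enumerate(sorted(set(locations)))}
print(label_of)
# {'CYtoplasmicMembrane': 0, 'Cellwall': 1, 'Cytoplasmic': 2,
#  'Extracellular': 3, 'OuterMembrane': 4, 'Periplasmic': 5}
```

`subcellular_location_encoder.inverse_transform(...)` maps integer predictions back to location names.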


**Q8.** How were the data partitioned during training and evaluation in the paper? What does the code below do, and how does it compare to the approach taken in the paper?

- **Paper partitioning:** 5-fold **GraphPart** homology split (≤30% identity across folds) + **nested CV**.  
  Outer loop: 1 fold = **test**, remaining 4 folds used for **train/val** in **4 inner-loop** combinations ⇒ **20 models** total; test predictions per fold are **averaged over the 4 inner models**.

- **Your code:** a **single fixed split** of folds:  
  train = {0,1,2} (7336), eval = {3} (2579), test = {4} (1991); then prints class counts.

- **Difference:** your approach is **not nested CV** and does **not rotate** the test fold or average across inner models; it’s a **one-shot hold-out** split (still homology-separated).
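The two partitioning schemes can be enumerated in a few lines. This sketch assumes, as described above, that each inner loop of the paper's nested CV uses one of the 4 non-test folds for validation and the remaining 3 for training; the simplified scheme mirrors the `(test_id + 1) % 5` rule used later in this notebook.

```python
folds = range(5)

# Nested CV: every outer test fold is paired with each remaining fold as validation;
# the other 3 folds are used for training -> 5 x 4 = 20 models.
nested = [
    (test, val, sorted(set(folds) - {test, val}))
    for test in folds
    for val in folds
    if val != test
]

# Simplified scheme: a single validation fold (the next one, cyclically) per test fold.
simple = [
    (test, (test + 1) % 5, sorted(set(folds) - {test, (test + 1) % 5}))
    for test in folds
]

print(len(nested), len(simple))  # 20 5
print(nested[:4])                # the 4 inner-loop splits for test fold 0
```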

In [6]:
train_id = {0, 1, 2}
eval_id = {3}
test_id = {4}

df_train = df_data.query('fold_id in @train_id')#.groupby('subcellular_location').sample(n=50, random_state=random_number)
df_eval = df_data.query('fold_id in @eval_id')
df_test = df_data.query('fold_id in @test_id')
print(len(df_train), 'records in training data:')
print(df_train['subcellular_location'].value_counts())
print()
print(len(df_eval), 'records in eval data:')
print(df_eval['subcellular_location'].value_counts())
print()
print(len(df_test), 'records in test data:')
print(df_test['subcellular_location'].value_counts())

7336 records in training data:
subcellular_location
Cytoplasmic            4345
CYtoplasmicMembrane    1534
Extracellular           645
OuterMembrane           433
Periplasmic             320
Cellwall                 59
Name: count, dtype: int64

2579 records in eval data:
subcellular_location
Cytoplasmic            1568
CYtoplasmicMembrane     476
Extracellular           224
OuterMembrane           159
Periplasmic             136
Cellwall                 16
Name: count, dtype: int64

1991 records in test data:
subcellular_location
Cytoplasmic            972
CYtoplasmicMembrane    525
Extracellular          208
OuterMembrane          164
Periplasmic            110
Cellwall                12
Name: count, dtype: int64


In [None]:
print(len(df_train), 'records in training data:')
print(df_train['subcellular_location'].value_counts())
print(df_train['organism_group'].value_counts())
print()
print(len(df_eval), 'records in eval data:')
print(df_eval['subcellular_location'].value_counts())
print(df_eval['organism_group'].value_counts())
print()
print(len(df_test), 'records in test data:')
print(df_test['subcellular_location'].value_counts())
print(df_test['organism_group'].value_counts())

In [7]:
# Prepare train/eval/test data sets for ESM2 model
model_checkpoint = 'facebook/esm2_t6_8M_UR50D'
tokenizer = transformers.AutoTokenizer.from_pretrained(model_checkpoint)

train_tokenized = tokenizer(df_train['sequence'].tolist(), truncation=True, max_length=1024)
eval_tokenized = tokenizer(df_eval['sequence'].tolist(), truncation=True, max_length=1024)
test_tokenized = tokenizer(df_test['sequence'].tolist(), truncation=True, max_length=1024)

train_dataset = datasets.Dataset.from_dict(train_tokenized).add_column('labels', df_train['label'].tolist())
eval_dataset = datasets.Dataset.from_dict(eval_tokenized).add_column('labels', df_eval['label'].tolist())
test_dataset = datasets.Dataset.from_dict(test_tokenized).add_column('labels', df_test['label'].tolist())
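The sequences are tokenized without padding. When a tokenizer is passed as `processing_class`, `Trainer` uses `DataCollatorWithPadding` by default, which pads each batch only to its longest sequence. The ESM tokenizer also adds a `<cls>` and an `<eos>` token, so `max_length=1024` keeps at most 1022 residues. The sketch below shows what per-batch padding does, assuming right-padding with pad id 1 (`padding_idx=1` in the model printout); `pad_batch` is a hypothetical helper and the token ids are toy values.

```python
PAD_ID = 1  # ESM-2 padding index (padding_idx=1 in the model printout)

def pad_batch(batch_input_ids):
    """Right-pad token-id lists to the longest one in the batch and build the
    attention mask (1 = real token, 0 = padding)."""
    longest = max(len(ids) for ids in batch_input_ids)
    input_ids, attention_mask = [], []
    for ids in batch_input_ids:
        n_pad = longest - len(ids)
        input_ids.append(list(ids) + [PAD_ID] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Toy batch (hypothetical ids): <cls> ... <eos> around a few residue tokens
batch = [[0, 20, 15, 11, 2], [0, 20, 2]]
padded = pad_batch(batch)
print(padded["input_ids"])       # [[0, 20, 15, 11, 2], [0, 20, 2, 1, 1]]
print(padded["attention_mask"])  # [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]
```

Because activation memory grows with batch size times padded length (and attention with length squared), one long protein in a batch raises the memory cost for every sequence in that batch.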

**Q9.** There's a warning about uninitialized weights when loading the ESM-2 model using `AutoModelForSequenceClassification` below. Describe the part of the network that has the uninitialized weights. How does the new part connect to the rest?

In [8]:
model = transformers.AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=df_data['label'].nunique())
model

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EsmForSequenceClassification(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 320, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 320, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-5): 6 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=320, out_features=320, bias=True)
              (key): Linear(in_features=320, out_features=320, bias=True)
              (value): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((320,), eps=1e-05,

- **Uninitialized weights:** the **new classification head** (`classifier.dense.*` and `classifier.out_proj.*`) does not exist in the pretrained ESM-2 checkpoint, which was trained for masked-language modelling, so its weights are randomly initialized.
- **Connection:** the pretrained **ESM encoder** outputs a 320-dimensional hidden state per token; the head takes the hidden state of the first (`<cls>`) token → dropout → `classifier.dense` (320→320) → tanh → dropout → `classifier.out_proj` (320→`num_labels`=6), producing one logit per subcellular location.
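The head's forward pass can be sketched in NumPy. This is an illustration under assumptions: a random array stands in for the encoder output, random weights stand in for the newly initialized parameters, and dropout is left out because it is inactive at inference time. PyTorch's `nn.Linear` stores weights as `(out, in)`; here they are stored as `(in, out)` for readability.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, hidden, num_labels = 2, 7, 320, 6

# Stand-in for the encoder's last hidden state: (batch, seq_len, hidden)
last_hidden_state = rng.normal(size=(batch, seq_len, hidden))

# Randomly initialized head parameters (what the warning refers to)
W_dense = rng.normal(scale=0.02, size=(hidden, hidden))
b_dense = np.zeros(hidden)
W_out = rng.normal(scale=0.02, size=(hidden, num_labels))
b_out = np.zeros(num_labels)

x = last_hidden_state[:, 0, :]        # first (<cls>) token embedding: (batch, 320)
x = np.tanh(x @ W_dense + b_dense)    # classifier.dense followed by tanh
logits = x @ W_out + b_out            # classifier.out_proj: (batch, 6)
print(logits.shape)  # (2, 6)
```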

In [9]:
# Track accuracy and macro F1 score throughout the training
# https://huggingface.co/docs/transformers/en/training#evaluate
metric_accuracy = evaluate.load('accuracy')
metric_f1 = evaluate.load('f1')

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    return {
        'accuracy': metric_accuracy.compute(predictions=predictions, references=labels)['accuracy'],
        'f1_macro': metric_f1.compute(predictions=predictions, references=labels, average='macro')['f1'],
    }

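Macro F1 is tracked alongside accuracy because the classes are highly imbalanced. As a worked example, the sketch below scores a trivial baseline that always predicts `Cytoplasmic` (label 2) on the class counts printed earlier. `macro_f1` is a hypothetical NumPy helper, not the `evaluate` implementation; every class here has true members, so no 0/0 case arises.

```python
import numpy as np

def macro_f1(y_true, y_pred, n_classes):
    """Unweighted mean of per-class F1 = 2TP / (2TP + FP + FN)."""
    f1s = []
    for c in range(n_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom else 0.0)
    return float(np.mean(f1s))

# Class counts from the full data set, keyed by encoded label
counts = {0: 2535, 1: 87, 2: 6885, 3: 1077, 4: 756, 5: 566}
y_true = np.concatenate([np.full(n, c) for c, n in counts.items()])
y_pred = np.full_like(y_true, 2)  # always predict 'Cytoplasmic'

print(round(float(np.mean(y_true == y_pred)), 3))  # 0.578
print(round(macro_f1(y_true, y_pred, 6), 3))        # 0.122
```

The majority-class baseline reaches almost 58% accuracy but a macro F1 of only about 0.12, because the five other locations each contribute F1 = 0.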

**Q10.** The fine-tuning may fail by running out of GPU memory. Look up `per_device_train_batch_size` and `per_device_eval_batch_size` in the [TrainingArguments docs](https://huggingface.co/docs/transformers/v4.49.0/en/main_classes/trainer#transformers.TrainingArguments). How would you adjust these parameters to use less GPU memory?

- To use **less GPU memory**, **decrease** `per_device_train_batch_size` and `per_device_eval_batch_size` (e.g., `8 → 4 → 2 → 1`).
- If you still want a similar **effective batch size**, keep batches small and **increase** `gradient_accumulation_steps`.
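The effective batch size is `per_device_train_batch_size × gradient_accumulation_steps × number of devices`. The NumPy sketch below checks the underlying idea on a toy linear-regression loss: averaging 4 scaled micro-batch gradients of size 2 gives the same gradient as one batch of 8, while only one micro-batch has to be held in memory at a time. It assumes equal-sized micro-batches and the common scheme of dividing each micro-batch loss by the number of accumulation steps; the exact loss scaling inside `Trainer` has changed across versions.

```python
import numpy as np

def mse_grad(w, X, y):
    """Gradient of the mean loss 0.5 * mean((X @ w - y) ** 2) with respect to w."""
    residual = X @ w - y
    return X.T @ residual / len(y)

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))  # 8 samples = one effective batch
y = rng.normal(size=8)
w = rng.normal(size=3)

# One step with per_device_train_batch_size=8
full_grad = mse_grad(w, X, y)

# Same effective batch with per_device_train_batch_size=2, gradient_accumulation_steps=4
micro_batch, accum_steps = 2, 4
accum_grad = np.zeros_like(w)
for k in range(accum_steps):
    rows = slice(k * micro_batch, (k + 1) * micro_batch)
    # Scale each micro-batch mean gradient by 1 / accum_steps before summing
    accum_grad += mse_grad(w, X[rows], y[rows]) / accum_steps

print(np.allclose(full_grad, accum_grad))  # True
```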

In [10]:
import sys, accelerate, transformers, torch
print("python:", sys.executable)
print("accelerate:", accelerate.__version__)
print("transformers:", transformers.__version__)
print("torch:", torch.__version__)
print("mps available:", torch.backends.mps.is_available())

python: /home/course/bc_deep_learning_in_biology/.venv/bin/python
accelerate: 1.12.0
transformers: 4.48.3
torch: 2.10.0+cu128
mps available: False


In [11]:
# Set up fine-tuning
trainer_args = transformers.TrainingArguments(
    output_dir=f'{model_checkpoint}-subcellular_location',
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    # Adjust if needed
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
)

trainer = transformers.Trainer(
    model=model,
    args=trainer_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,          # used for dynamic padding; replaces the deprecated `tokenizer=` argument
    compute_metrics=compute_metrics,
)

retrained = trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | F1 Macro |
|---|---|---|---|---|
| 1 | 0.6067 | 0.741129 | 0.863125 | 0.61861 |
| 2 | 0.4546 | 0.736177 | 0.87088 | 0.63857 |
| 3 | 0.2267 | 0.627858 | 0.895308 | 0.760889 |


In [12]:
# Use fine-tuned model to predict on held-out test data
retrained_predict = trainer.predict(test_dataset=test_dataset)
retrained_predict.metrics

{'test_loss': 0.8143479824066162,
 'test_accuracy': 0.8699146157709694,
 'test_f1_macro': 0.7421056811121813,
 'test_runtime': 11.5419,
 'test_samples_per_second': 172.502,
 'test_steps_per_second': 172.502}

In [13]:
# Convert logits into discrete predictions by taking the highest-scoring class
test_labels = np.argmax(retrained_predict.predictions, axis=-1)
# Sanity check: manually recompute the accuracy
print(sum(test_labels == test_dataset['labels']) / len(test_dataset))
test_labels

0.8699146157709694


array([0, 2, 4, ..., 2, 5, 4])
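`retrained_predict.predictions` holds raw logits, not probabilities. If per-class probabilities are needed (e.g. to rank uncertain proteins), a row-wise softmax converts them; because softmax is monotonic, the argmax and hence `test_labels` do not change. A minimal NumPy sketch on toy logits:

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax; subtracting the row maximum avoids overflow in exp."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])  # toy logits, 3 classes
probs = softmax(logits)
print(np.argmax(probs, axis=-1), np.argmax(logits, axis=-1))  # [0 2] [0 2]
```

In the notebook this would be `softmax(retrained_predict.predictions)`, giving an array of shape `(len(test_dataset), 6)` whose rows sum to 1.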

**Q11.** Adjust the code below to reproduce the cross-validation results in Table 2 of the paper. Fine-tune and test separately on every fold, and gather the results in `predicted_labels`.

In [16]:
fold_id = set(df_data.fold_id)
predicted_labels_all = [None] * len(df_data)

for test_id in sorted(fold_id):
    eval_id = (test_id + 1) % len(fold_id)  # the next fold (cyclically) is used for validation
    train_id = fold_id - {eval_id, test_id}

    df_train = df_data.query('fold_id in @train_id')
    df_eval  = df_data.query('fold_id == @eval_id')
    df_test  = df_data.query('fold_id == @test_id')
    print(train_id, eval_id, test_id, len(df_train), len(df_eval), len(df_test))

    # --- Tokenize this fold's splits ---
    train_tokenized = tokenizer(df_train['sequence'].tolist(), truncation=True, max_length=1024)
    eval_tokenized  = tokenizer(df_eval['sequence'].tolist(),  truncation=True, max_length=1024)
    test_tokenized  = tokenizer(df_test['sequence'].tolist(),  truncation=True, max_length=1024)

    train_dataset = datasets.Dataset.from_dict(train_tokenized).add_column('labels', df_train['label'].tolist())
    eval_dataset  = datasets.Dataset.from_dict(eval_tokenized).add_column('labels',  df_eval['label'].tolist())
    test_dataset  = datasets.Dataset.from_dict(test_tokenized).add_column('labels',  df_test['label'].tolist())

    # --- Fresh model for each fold ---
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_checkpoint, num_labels=df_data['label'].nunique()
    )

    # --- Same TrainingArguments as the single-split run above ---
    trainer_args = transformers.TrainingArguments(
        output_dir=f'{model_checkpoint}-fold{test_id}',
        eval_strategy='epoch',
        save_strategy='epoch',
        load_best_model_at_end=True,
        metric_for_best_model='accuracy',
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
    )

    trainer = transformers.Trainer(
        model=model,
        args=trainer_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        compute_metrics=compute_metrics,
    )

    trainer.train()

    # --- Predict and store in correct positions ---
    retrained_predict = trainer.predict(test_dataset=test_dataset)
    fold_preds = list(np.argmax(retrained_predict.predictions, axis=-1))

    for idx, pred in zip(df_test.index, fold_preds):
        predicted_labels_all[df_data.index.get_loc(idx)] = pred

predicted_labels = predicted_labels_all

{2, 3, 4} 1 0 6668 2684 2554




| Epoch | Training Loss | Validation Loss | Accuracy | F1 Macro |
|---|---|---|---|---|
| 1 | 0.6692 | 0.590921 | 0.887481 | 0.629164 |
| 2 | 0.5413 | 0.605567 | 0.891952 | 0.690194 |
| 3 | 0.2869 | 0.591977 | 0.899404 | 0.7232 |


{0, 3, 4} 2 1 7124 2098 2684




| Epoch | Training Loss | Validation Loss | Accuracy | F1 Macro |
|---|---|---|---|---|
| 1 | 0.573 | 0.79508 | 0.868446 | 0.627092 |
| 2 | 0.3973 | 0.706297 | 0.884175 | 0.727017 |
| 3 | 0.3868 | 0.662465 | 0.885129 | 0.743165 |


{0, 1, 4} 3 2 7229 2579 2098




| Epoch | Training Loss | Validation Loss | Accuracy | F1 Macro |
|---|---|---|---|---|
| 1 | 0.5666 | 0.990078 | 0.799147 | 0.589099 |
| 2 | 0.3878 | 0.699573 | 0.873982 | 0.642722 |
| 3 | 0.3051 | 0.745617 | 0.870105 | 0.698404 |


{0, 1, 2} 4 3 7336 1991 2579




| Epoch | Training Loss | Validation Loss | Accuracy | F1 Macro |
|---|---|---|---|---|
| 1 | 0.6764 | 0.855345 | 0.844299 | 0.630254 |
| 2 | 0.4707 | 0.869421 | 0.863887 | 0.708826 |
| 3 | 0.2773 | 0.872674 | 0.857358 | 0.722181 |


{1, 2, 3} 0 4 7361 2554 1991




| Epoch | Training Loss | Validation Loss | Accuracy | F1 Macro |
|---|---|---|---|---|
| 1 | 0.5764 | 0.6169 | 0.884886 | 0.6258 |
| 2 | 0.2589 | 0.819566 | 0.846515 | 0.660137 |
| 3 | 0.2878 | 0.65025 | 0.89585 | 0.714506 |


**Q12.** Discuss differences between the methodology and the resulting performance of the approach taken in the paper, and the reproduction.

In [17]:
# Show table with performance metrics split by organism to match Table 2
def calculate_stats_(df):
    accuracy = metric_accuracy.compute(predictions=df.predicted_labels.values, references=df.label.values)['accuracy']
    f1_macro = metric_f1.compute(predictions=df.predicted_labels.values, references=df.label.values, average='macro')['f1']
    return pd.Series({
        'size': '{:d}'.format(len(df)),
        'accuracy': '{:.2f}'.format(accuracy),
        'f1_macro': '{:.2f}'.format(f1_macro),
    })

df_data['predicted_labels'] = predicted_labels
pd.concat([
    calculate_stats_(df_data),
    calculate_stats_(df_data.query('organism_group == "archaea"')),
    calculate_stats_(df_data.query('organism_group == "positive"')),
    calculate_stats_(df_data.query('organism_group == "negative"')),
], axis=1).set_axis(['Overall', 'Archaea' , 'Gram pos', 'Gram neg'], axis=1)

| | Overall | Archaea | Gram pos | Gram neg |
|---|---|---|---|---|
| size | 11906 | 283 | 3206 | 8417 |
| accuracy | 0.89 | 0.85 | 0.91 | 0.88 |
| f1_macro | 0.71 | 0.46 | 0.51 | 0.69 |


## Q12: Methodology & Performance, Paper vs. Reproduction

### Methodology Differences

| Aspect | Paper (DeepLocPro) | Reproduction |
|---|---|---|
| **ESM-2 backbone** | 650M parameters | 8M parameters |
| **Pooling** | Learned attention pooling (per-residue weights) | No pooling over residues: the `EsmForSequenceClassification` head uses the first-token (`<cls>`) embedding |
| **CV scheme** | Nested 5-fold: **20 models** trained (4 inner models × 5 outer folds), test predictions **averaged over 4 inner models per fold** | Simplified 5-fold: **5 models** (1 per outer fold, single fixed val fold = `(test_id + 1) % 5`), no ensemble averaging |
| **Training length** | Up to **60 epochs** + dynamic LR reduction (×0.1 after 5 stagnant epochs) | **3 epochs**, default LR schedule |
| **Batch size** | 8–64 (hyperparameter searched) | **1** (memory constraint) |
| **Hyperparameter tuning** | Grid search over LR, batch size, dropout per fold | None |
| **Regularization** | Dropout in classification head | None added; the model's dropout layers are printed with `p=0.0` |

---

### Performance Differences

| Metric | Paper | Reproduction |
|---|---|---|
| **Overall Accuracy** | **0.92 ± 0.01** | 0.89 (pooled over the 5 test folds) |
| **Macro F1** | **0.80 ± 0.02** | 0.71 (pooled over the 5 test folds) |

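Despite the gap, **qualitative patterns are preserved**: archaea remain the hardest group to classify (accuracy 0.85 and macro F1 0.46 in the reproduction, although with only 283 archaeal sequences this estimate is noisy), and cytoplasmic proteins remain the easiest. This suggests that the main biological signal is captured even by the 8M-parameter model.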

---

### Why the Gap?

The performance difference likely reflects several compounding factors:

1. **Smaller backbone:** the 8M-parameter ESM-2 model has far less capacity than the 650M model to encode evolutionary and biochemical information.
2. **No attention pooling:** the paper's residue-level attention can up-weight informative regions (e.g. signal peptides); the reproduction classifies from the single `<cls>` token embedding.
3. **Undertraining:** 3 epochs (the `TrainingArguments` default) is likely insufficient compared with up to 60 epochs with adaptive LR decay, and a batch size of 1 makes each gradient estimate noisy.
4. **No ensembling:** the paper averages test predictions over 4 inner-loop models per fold, which reduces variance; the reproduction uses a single model per fold.
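Ensembling over inner-loop models can be sketched as follows. This assumes the ensemble averages per-model softmax probabilities before taking the argmax (the paper may combine predictions differently), and uses random logits as hypothetical stand-ins for the outputs of 4 models on 3 test proteins.

```python
import numpy as np

def softmax(logits):
    """Softmax over the last axis, stabilized by subtracting the maximum."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
inner_logits = rng.normal(size=(4, 3, 6))  # (models, proteins, classes), hypothetical

ensemble_probs = softmax(inner_logits).mean(axis=0)  # average over the 4 models -> (3, 6)
ensemble_pred = ensemble_probs.argmax(axis=-1)       # paper-style ensemble prediction
single_pred = inner_logits[0].argmax(axis=-1)        # reproduction-style: one model per fold
print(ensemble_probs.shape, ensemble_pred, single_pred)
```

Averaging probabilities smooths out errors that individual models make, which is one reason an ensemble tends to have lower variance than any single member.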