In [2]:
import torch
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import Dataset
import evaluate
from Bio import motifs
from Bio.Seq import Seq
import random
import requests
import sys
from pathlib import Path
import pandas as pd
import os
import pickle
import requests

src_path = Path("../src")
sys.path.append(str(src_path))

from ts_tf.motifs import fetch_all_motifs, save_to_csv, fetch_all_motif_metadata, save_metadata_to_csv
import ts_tf.protein as prot
from ts_tf.esm import ProteinDNADataset, CustomEsmForPWM

## RETRIEVE DNA MOTIFS

In [3]:
# Retrieve all high-quality motifs for vertebrates
tax_group = "vertebrates"

try:
    print(f"Fetching high-quality motifs for {tax_group}...")
    all_motifs = fetch_all_motifs(tax_group=tax_group)
    print(f"Retrieved {len(all_motifs)} motifs.")

    # Save motifs to a CSV file
    output_file = "../results/high_quality_motifs_with_pfm_pwm.csv"
    save_to_csv(all_motifs, output_file)
    print(f"Saved motifs to {output_file}")
except ValueError as e:
    print(e)

Fetching high-quality motifs for vertebrates...
Fetching motifs from: http://jaspar.genereg.net/api/v1/matrix/


KeyboardInterrupt: 

### alternative: retrieve cached

In [11]:
output_file = "../results/high_quality_motifs_with_pfm_pwm.csv"
motif_df = pd.read_csv(output_file)
motif_ids = list(motif_df["Motif ID"].unique())
print(f'motif_ids retrieved: {len(motif_ids)}')

motif_ids retrieved: 1912


## RETRIEVE UNIPROT ID

In [None]:
# Fetch metadata for all motifs
metadata_list = fetch_all_motif_metadata(motif_ids)

# Save metadata to CSV
save_metadata_to_csv(metadata_list, "../results/motif_metadata.csv")

### alternative: retrieve cached

In [None]:
metadata_df = pd.read_csv("../results/motif_metadata.csv")
metadata_df

Unnamed: 0,Matrix ID,Gene Name,UniProt IDs,Species,Taxonomy ID
0,MA0634.1,ALX3,O95076,Homo sapiens,9606
1,MA0634.2,ALX3,O95076,Homo sapiens,9606
2,MA0007.2,AR,P10275,Homo sapiens,9606
3,MA1463.1,ARGFX,A6NJG6,Homo sapiens,9606
4,MA1463.2,ARGFX,A6NJG6,Homo sapiens,9606
...,...,...,...,...,...
629,MA0046.1,HNF1A,P20823,Vertebrata,7742
630,MA0046.2,HNF1A,P20823,Homo sapiens,9606
631,MA0046.3,HNF1A,P20823,Homo sapiens,9606
632,MA0153.1,HNF1B,P35680,Homo sapiens,9606


## RETRIEVE AA SEQUENCE

In [None]:
metadata_df["AA Sequence"] = None

for i, row in metadata_df.iterrows():

    uniprot_id = row["UniProt IDs"]
    aa_seq = prot.fetch_uniprot_sequence(uniprot_id)
    metadata_df.loc[i, "AA Sequence"] = aa_seq

print(f"N retrieved successfully: {len(metadata_df[~metadata_df["AA Sequence"].isnull()])}")
metadata_df.to_csv("../results/motif_metadata_with_uniprot.csv", index=False)

### alternative: retrieve cached

In [17]:
metadata_df = pd.read_csv("../results/motif_metadata_with_uniprot.csv")
metadata_df

Unnamed: 0,Matrix ID,Gene Name,UniProt IDs,Species,Taxonomy ID,AA Sequence
0,MA0634.1,ALX3,O95076,Homo sapiens,9606,MDPEHCAPFRVGPAPGPYVASGDEPPGPQGTPAAAPHLHPAPPRGP...
1,MA0634.2,ALX3,O95076,Homo sapiens,9606,MDPEHCAPFRVGPAPGPYVASGDEPPGPQGTPAAAPHLHPAPPRGP...
2,MA0007.2,AR,P10275,Homo sapiens,9606,MEVQLGLGRVYPRPPSKTYRGAFQNLFQSVREVIQNPGPRHPEAAS...
3,MA1463.1,ARGFX,A6NJG6,Homo sapiens,9606,MRNRMAPENPQPDPFINRNYSNMKVIPPQDPASPSFTLLSKLECSG...
4,MA1463.2,ARGFX,A6NJG6,Homo sapiens,9606,MRNRMAPENPQPDPFINRNYSNMKVIPPQDPASPSFTLLSKLECSG...
...,...,...,...,...,...,...
629,MA0046.1,HNF1A,P20823,Vertebrata,7742,MVSKLSQLQTELLAALLESGLSKEALIQALGEPGPYLLAGEGPLDK...
630,MA0046.2,HNF1A,P20823,Homo sapiens,9606,MVSKLSQLQTELLAALLESGLSKEALIQALGEPGPYLLAGEGPLDK...
631,MA0046.3,HNF1A,P20823,Homo sapiens,9606,MVSKLSQLQTELLAALLESGLSKEALIQALGEPGPYLLAGEGPLDK...
632,MA0153.1,HNF1B,P35680,Homo sapiens,9606,MVSKLTSLQQELLSALLSSGVTKEVLVQALEELLPSPNFGVKLETL...


## CLEAN UP

In [18]:
# Merge motif and sequence dataframes
motif_sequence_df = pd.merge(motif_df, metadata_df, left_on="Motif ID", right_on="Matrix ID", how="inner")

# Identify columns to group by (all except Position and matrix columns)
group_columns = list(motif_sequence_df.columns.difference(['Position', 'A (PFM)', 'C (PFM)', 'G (PFM)', 'T (PFM)', 
                                            'A (PWM)', 'C (PWM)', 'G (PWM)', 'T (PWM)']))

# Group by relevant columns and process
def process_group(group):
    pwm = group[['A (PWM)', 'C (PWM)', 'G (PWM)', 'T (PWM)']].values.tolist()
    pfm = group[['A (PFM)', 'C (PFM)', 'G (PFM)', 'T (PFM)']].values.tolist()
    return pd.Series({'pwm': pwm, 'pfm': pfm})

# Apply the transformation
motif_sequence_df = motif_sequence_df.groupby(group_columns).apply(process_group).reset_index()

motif_sequence_df.to_csv("../results/motif_sequence_data.csv", index=False)

  motif_sequence_df = motif_sequence_df.groupby(group_columns).apply(process_group).reset_index()


In [1]:
import argparse
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold
import numpy as np
import ast
import torch
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import Dataset
import evaluate
from Bio import motifs
from Bio.Seq import Seq
import random
import requests
import sys
from pathlib import Path
import pandas as pd
import os
import pickle
import requests

src_path = Path("../src")
sys.path.append(str(src_path))

from ts_tf.esm import ProteinDNADataset, CustomEsmForPWM

def fine_tune_esm(csv_file: str, epochs: int = 3, lr: float = 2e-5, k_folds: int = 5):
    data = pd.read_csv(csv_file)
    sequences = data['AA Sequence']
    pwms = data['pwm']

    example_dataset = ProteinDNADataset(sequences, pwms)
    model = CustomEsmForPWM(output_shape=(example_dataset.max_rows, example_dataset.max_cols))
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    kfold = KFold(n_splits=k_folds, shuffle=True)
    for fold, (train_ids, val_ids) in enumerate(kfold.split(sequences)): # take out test ids before
        print(f'Fold {fold + 1}/{k_folds}')
        train_sequences = sequences.iloc[train_ids]
        train_pwms = pwms.iloc[train_ids]
        val_sequences = sequences.iloc[val_ids]
        val_pwms = pwms.iloc[val_ids]

        train_dataset = ProteinDNADataset(train_sequences, train_pwms)
        val_dataset = ProteinDNADataset(val_sequences, val_pwms)
        train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True) # increase batch size for more stable gradient descent
        val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False)

        model.train()
        for epoch in range(epochs):
            epoch_loss = 0
            for batch in train_dataloader:
                sequences, pwms = batch
                sequences = list(sequences)
                pwms = torch.tensor(pwms, dtype=torch.float).to(next(model.parameters()).device)

                optimizer.zero_grad()
                outputs = model(sequences) # predicted pwms
                print(outputs.shape, pwms.shape)
                loss = torch.nn.functional.mse_loss(outputs, pwms) ###
                print(loss)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

            print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss / len(train_dataloader)}")
            
            model.eval()
            val_loss = 0


fine_tune_esm("../results/motif_sequence_data.csv", epochs=3, lr=2e-5, k_folds=5)

Using cache found in /home/anton.thieme/.cache/torch/hub/facebookresearch_esm_main


Output shape: (35, 4)
<esm.data.Alphabet object at 0x7f75a4676bd0>
<class 'esm.data.Alphabet'>
Fold 1/5


  pwms = torch.tensor(pwms, dtype=torch.float).to(next(model.parameters()).device)


torch.Size([8, 35, 4]) torch.Size([8, 35, 4])
tensor(41.7219, grad_fn=<MseLossBackward0>)
torch.Size([8, 35, 4]) torch.Size([8, 35, 4])
tensor(44.7803, grad_fn=<MseLossBackward0>)
torch.Size([8, 35, 4]) torch.Size([8, 35, 4])
tensor(35.0697, grad_fn=<MseLossBackward0>)
torch.Size([8, 35, 4]) torch.Size([8, 35, 4])
tensor(39.8811, grad_fn=<MseLossBackward0>)
torch.Size([8, 35, 4]) torch.Size([8, 35, 4])
tensor(51.9166, grad_fn=<MseLossBackward0>)
torch.Size([8, 35, 4]) torch.Size([8, 35, 4])
tensor(49.9725, grad_fn=<MseLossBackward0>)
torch.Size([8, 35, 4]) torch.Size([8, 35, 4])
tensor(40.9698, grad_fn=<MseLossBackward0>)


KeyboardInterrupt: 

In [1]:
import torch
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import Dataset
import evaluate
from Bio import motifs
from Bio.Seq import Seq
import random
import requests
import sys
from pathlib import Path
import pandas as pd
import os
import pickle
import requests

src_path = Path("../scripts")
sys.path.append(str(src_path))

import esm_finetune

In [3]:
torch.version.cuda

'12.4'

In [5]:
esm_finetune.run_fine_tune_esm(csv_file="../results/motif_sequence_data.csv")

2025-01-11 17:09:10,543 - INFO - Starting fine-tuning process.
Using cache found in /home/anton.thieme/.cache/torch/hub/facebookresearch_esm_main
2025-01-11 17:09:11,676 - INFO - Starting Fold 1/5


Output shape: (35, 4)
<esm.data.Alphabet object at 0x7fca19167110>
<class 'esm.data.Alphabet'>


  pwms = torch.tensor(pwms, dtype=torch.float).to(next(model.parameters()).device)


torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size([8, 35, 4])
torch.Size(

2025-01-11 17:21:43,713 - INFO - Epoch 1/3, Loss: 47.8605
  pwms = torch.tensor(pwms, dtype=torch.float).to(next(model.parameters()).device)
  val_loss += torch.nn.functional.mse_loss(outputs, pwms).item()
2025-01-11 17:21:47,701 - ERROR - Error during validation batch: The size of tensor a (35) must match the size of tensor b (22) at non-singleton dimension 1
2025-01-11 17:21:47,707 - ERROR - An unexpected error occurred: The size of tensor a (35) must match the size of tensor b (22) at non-singleton dimension 1
2025-01-11 17:21:47,709 - CRITICAL - Fine-tuning process failed: The size of tensor a (35) must match the size of tensor b (22) at non-singleton dimension 1
