In [1]:
import torch
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import Dataset
import evaluate
from Bio import motifs
from Bio.Seq import Seq
import random
import requests
import sys
from pathlib import Path
import pandas as pd
import os
import pickle
import requests

src_path = Path("../src")
sys.path.append(str(src_path))

from ts_tf.motifs import fetch_all_motifs, save_to_csv, fetch_all_motif_metadata, save_metadata_to_csv
import ts_tf.protein as prot
import ts_tf.custom_esm as cesm

## RETRIEVE DNA MOTIFS

In [None]:
# Retrieve all high-quality motifs for vertebrates
tax_group = "vertebrates"

try:
    print(f"Fetching high-quality motifs for {tax_group}...")
    all_motifs = fetch_all_motifs(tax_group=tax_group)
    print(f"Retrieved {len(all_motifs)} motifs.")

    # Save motifs to a CSV file
    output_file = "../results/high_quality_motifs_with_pfm_pwm.csv"
    save_to_csv(all_motifs, output_file)
    print(f"Saved motifs to {output_file}")
except ValueError as e:
    print(e)

### alternative: retrieve cached

In [4]:
output_file = "../results/motif/high_quality_motifs_with_pfm_pwm.csv"
motif_df = pd.read_csv(output_file)
motif_ids = list(motif_df["Motif ID"].unique())
print(f'motif_ids retrieved: {len(motif_ids)}')

motif_ids retrieved: 1912


## RETRIEVE UNIPROT ID

In [None]:
# Fetch metadata for all motifs
metadata_list = fetch_all_motif_metadata(motif_ids)

# Save metadata to CSV
save_metadata_to_csv(metadata_list, "../results/motif_metadata_2.csv")

In [13]:
import pickle

with open('outfile', 'wb') as fp:
    pickle.dump(metadata_list, fp)

In [4]:
with open ('outfile', 'rb') as fp:
    metadata_list = pickle.load(fp)

In [None]:
save_metadata_to_csv(metadata_list, "../results/motif/motif_metadata_2.csv")

In [None]:
metadata_list[634]

### alternative: retrieve cached

In [6]:
metadata_df = pd.read_csv("../results/motif/motif_metadata.csv")
metadata_df

Unnamed: 0,Matrix ID,Gene Name,UniProt ID,Species,Taxonomy ID
0,MA0634.1,ALX3,O95076,Homo sapiens,9606
1,MA0634.2,ALX3,O95076,Homo sapiens,9606
2,MA0007.2,AR,P10275,Homo sapiens,9606
3,MA1463.1,ARGFX,A6NJG6,Homo sapiens,9606
4,MA1463.2,ARGFX,A6NJG6,Homo sapiens,9606
...,...,...,...,...,...
1907,MA1630.1,Znf281,Q99LI5,Mus musculus,10090
1908,MA0116.1,Znf423,O08961,Rattus norvegicus,10116
1909,MA0621.1,mix-a,P21711,Xenopus laevis,8355
1910,MA0621.2,mix-a,P21711,Xenopus laevis,8355


## RETRIEVE AA SEQUENCE

In [8]:
metadata_df["AA Sequence"] = None

for i, row in metadata_df.iterrows():

    uniprot_id = row["UniProt ID"]
    aa_seq = prot.fetch_uniprot_sequence(uniprot_id)
    metadata_df.loc[i, "AA Sequence"] = aa_seq

print(f"N retrieved successfully: {len(metadata_df[~metadata_df["AA Sequence"].isnull()])}")
metadata_df.to_csv("../results/motif/motif_metadata_with_uniprot.csv", index=False)

Error fetching sequence for UniProt ID nan: 400
Error fetching sequence for UniProt ID nan: 400
Error fetching sequence for UniProt ID nan: 400
N retrieved successfully: 1909


### alternative: retrieve cached

In [3]:
metadata_df = pd.read_csv("../results/motif/motif_metadata_with_uniprot.csv")
metadata_df

Unnamed: 0,Matrix ID,Gene Name,UniProt ID,Species,Taxonomy ID,AA Sequence
0,MA0634.1,ALX3,O95076,Homo sapiens,9606,MDPEHCAPFRVGPAPGPYVASGDEPPGPQGTPAAAPHLHPAPPRGP...
1,MA0634.2,ALX3,O95076,Homo sapiens,9606,MDPEHCAPFRVGPAPGPYVASGDEPPGPQGTPAAAPHLHPAPPRGP...
2,MA0007.2,AR,P10275,Homo sapiens,9606,MEVQLGLGRVYPRPPSKTYRGAFQNLFQSVREVIQNPGPRHPEAAS...
3,MA1463.1,ARGFX,A6NJG6,Homo sapiens,9606,MRNRMAPENPQPDPFINRNYSNMKVIPPQDPASPSFTLLSKLECSG...
4,MA1463.2,ARGFX,A6NJG6,Homo sapiens,9606,MRNRMAPENPQPDPFINRNYSNMKVIPPQDPASPSFTLLSKLECSG...
...,...,...,...,...,...,...
1907,MA1630.1,Znf281,Q99LI5,Mus musculus,10090,MKIGSGFLSGGGGPSSSGGSGSGGSSGSASGGSGGGRRAEMEPTFP...
1908,MA0116.1,Znf423,O08961,Rattus norvegicus,10116,MSRRKQAKPRSVKVEEGEASDFSLAWDSSVAAAGGLEGESECDRKS...
1909,MA0621.1,mix-a,P21711,Xenopus laevis,8355,MDGFSQQLEDLYPSCFSPCPSPLGFSEPVIQPFAMNLAPAAQKDFQ...
1910,MA0621.2,mix-a,P21711,Xenopus laevis,8355,MDGFSQQLEDLYPSCFSPCPSPLGFSEPVIQPFAMNLAPAAQKDFQ...


## CLEAN UP

In [6]:
# Merge motif and sequence dataframes
motif_sequence_df = pd.merge(motif_df, metadata_df, left_on="Motif ID", right_on="Matrix ID", how="inner")

# Identify columns to group by (all except Position and matrix columns)
group_columns = list(motif_sequence_df.columns.difference(['Position', 'A (PFM)', 'C (PFM)', 'G (PFM)', 'T (PFM)', 
                                            'A (PWM)', 'C (PWM)', 'G (PWM)', 'T (PWM)']))

# Group by relevant columns and process
def process_group(group):
    pwm = group[['A (PWM)', 'C (PWM)', 'G (PWM)', 'T (PWM)']].values.tolist()
    pfm = group[['A (PFM)', 'C (PFM)', 'G (PFM)', 'T (PFM)']].values.tolist()
    return pd.Series({'pwm': pwm, 'pfm': pfm})

# Apply the transformation
motif_sequence_df = motif_sequence_df.groupby(group_columns).apply(process_group).reset_index()
print(len(motif_sequence_df[motif_sequence_df["pwm"].isnull()]))
print(len(motif_sequence_df[motif_sequence_df["AA Sequence"].isnull()]))
print(len(motif_sequence_df))
motif_sequence_df = motif_sequence_df.dropna()
print(len(motif_sequence_df))

motif_sequence_df.to_csv("../results/motif/motif_sequence_data.csv", index=False)

0
0
1907
1907


  motif_sequence_df = motif_sequence_df.groupby(group_columns).apply(process_group).reset_index()


## SPLIT DATA

In [7]:
data = pd.read_csv("../results/motif/motif_sequence_data.csv")
train_data = data.sample(frac=0.8, random_state=42)
test_data = data.drop(train_data.index)
train_data.to_csv("../data/esm/train_data.csv", index=False)
test_data.to_csv("../data/esm/test_data.csv", index=False)

## FINE-TUNE 

Fine tune esm using the script esm_finetune.py