In [1]:
import os

# Root of the protgps checkout. Set the PROTGPS_PARENT_DIR environment variable
# to run on another machine instead of editing this hard-coded default.
PROTGPS_PARENT_DIR = os.environ.get("PROTGPS_PARENT_DIR", "/home/shd-sun-lab/protgps")

In [2]:
import sys
import os
sys.path.append(PROTGPS_PARENT_DIR) # append the path of protgps
from argparse import Namespace
import pickle
import copy
import yaml
import requests
from tqdm import tqdm
from p_tqdm import p_map
import numpy as np
import pandas as pd
from collections import defaultdict
import torch 
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz
import protpy
from protpy import amino_acids as protpyAA
from sklearn.metrics import roc_auc_score
from matplotlib import pyplot as plt
from protgps.utils.loading import get_object

In [3]:
# Prefer the first GPU when one is visible; otherwise fall back to the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

## Functions

In [4]:
# WHAT'S IN THE LOCAL CODE
# Compartment order used by the local dataset code. "pml-bdoy" is spelled the
# same way in both lists (and presumably in the upstream label names), so do not
# correct it in only one place.
COMPARTMENTS = [
    'transcriptional',
    'chromosome',
    'nuclear_pore_complex',
    'nuclear_speckle', 
    'p-body', 
    'pml-bdoy', 
    'post_synaptic_density',
    'stress_granule',
    'nucleolus',
    'cajal_body',
    'rna_granule',
    'cell_junction'
]

# WHAT'S IN THE COMMITTED CODE
# ORDER WAS CHANGED AT SOME POINT!!! 
# USE THE ONE ON GITHUB
# https://github.com/pgmikhael/nox/blob/1e5b963cbfdad23418a98c7c67a11c6431869cf6/nox/datasets/protein_compartments.py
OLDCOMPS = [
    "nuclear_speckle",
    "p-body",
    "pml-bdoy",
    "post_synaptic_density",
    "stress_granule",
    "chromosome",
    "nucleolus",
    "nuclear_pore_complex",
    "cajal_body",
    "rna_granule",
    "cell_junction",
    "transcriptional"
]

def transform_y(y: torch.Tensor) -> torch.Tensor:
    """Re-order a multi-hot label vector from COMPARTMENTS order to OLDCOMPS order.

    Args:
        y: 1-D multi-hot tensor where position i refers to COMPARTMENTS[i].

    Returns:
        Float tensor of length len(OLDCOMPS) with 1 at the OLDCOMPS position of
        every compartment that is non-zero in ``y`` (all zeros if none are).
    """
    # nonzero() returns a (k, 1) tensor; flatten + tolist gives plain int positions.
    positive = torch.nonzero(y).flatten().tolist()
    # An explicit long dtype keeps scatter_ valid when there are no positives
    # (torch.tensor([]) would otherwise be float and scatter_ would raise).
    new_indices = torch.tensor(
        [OLDCOMPS.index(COMPARTMENTS[i]) for i in positive], dtype=torch.long
    )
    return torch.zeros(len(OLDCOMPS)).scatter_(0, new_indices, 1)

In [5]:
# UniProt FASTA endpoint template; fill it with UNIPROT_ENTRY_URL.format(accession).
# The accession must be a placeholder — a hard-coded ID would make every lookup
# return the same protein.
UNIPROT_ENTRY_URL = "https://rest.uniprot.org/uniprotkb/{}.fasta"

def get_organism(uni):
    """Return the scientific organism name for a UniProt accession.

    Args:
        uni (str): UniProt accession, e.g. "O14983".

    Returns:
        str: ``organism.scientificName`` from the entry's JSON, or "" when the
        response has no ``organism`` field.
    """
    # Query the requested accession. It was previously hard-coded to O14983, so
    # every protein was assigned that entry's organism.
    response = requests.get(f"https://rest.uniprot.org/uniprotkb/{uni}.json").json()
    if 'organism' in response:
        return response['organism']['scientificName']
    return ""
        
def parse_fasta(f):
    """Join the sequence lines of FASTA text, ignoring '>' header lines.

    Args:
        f (str): raw FASTA text

    Returns:
        str: all non-header lines, each whitespace-stripped, concatenated
    """
    residue_lines = (line.strip() for line in f.split("\n") if not line.startswith(">"))
    return "".join(residue_lines)


def get_protein_fasta(uniprot):
    """Download a protein's FASTA record from UniProt and return its sequence.

    Args:
        uniprot (str): accession substituted into UNIPROT_ENTRY_URL

    Returns:
        str | None: the parsed sequence, or None when the HTTP status is not 200
    """
    reply = requests.get(UNIPROT_ENTRY_URL.format(uniprot))
    if reply.status_code != 200:  # any non-success status -> no sequence
        return None
    return parse_fasta(reply.text)

In [6]:
# Smoke test: fetch one known entry and show its sequence.
protein_id = "O14983"  # Example UniProt ID
sequence = get_protein_fasta(protein_id)
message = f"Protein sequence for {protein_id}:\n{sequence}"
print(message)




In [7]:
def load_model(snargs):
    """Load the classifier checkpoint described by an args namespace.

    Args:
        snargs (Namespace): reads ``lightning_name``, ``model_path`` and
            ``relax_checkpoint_matching``; also forwarded to the model as ``args``.

    Returns:
        tuple: (loaded model, snargs)
    """
    model_class = get_object(snargs.lightning_name, "lightning")
    # load_from_checkpoint is a classmethod, so call it on the class. The old code
    # built a throwaway instance first (model_class(snargs)), which constructed the
    # network twice — visible as duplicated ESM loading messages in the output.
    model = model_class.load_from_checkpoint(
        checkpoint_path=snargs.model_path,
        strict=not snargs.relax_checkpoint_matching,
        args=snargs,
    )
    return model, snargs

In [8]:
def predict_condensates(model, sequences, batch_size, round=True):
    """Score sequences with the classifier in mini-batches.

    Args:
        model: loaded module whose ``.model`` accepts {"x": list_of_sequences}
            and returns a dict with a 'logit' tensor
        sequences (list[str]): protein sequences to score
        batch_size (int): sequences per forward pass
        round (bool): if True, round the scores to 3 decimals

    Returns:
        torch.Tensor: sigmoid scores on the CPU, batches stacked in input order
    """
    batch_scores = []
    starts = range(0, len(sequences), batch_size)
    for start in tqdm(starts, ncols=100):
        chunk = sequences[start : start + batch_size]
        with torch.no_grad():
            logits = model.model({"x": chunk})['logit']
        batch_scores.append(torch.sigmoid(logits).to("cpu"))
    stacked = torch.vstack(batch_scores)
    return torch.round(stacked, decimals=3) if round else stacked

In [9]:
def get_valid_rows(df, cols, max_len=1800):
    """Return positions of rows whose values in ``cols`` are all shorter than max_len.

    Args:
        df (pd.DataFrame): table holding sequence columns
        cols (list[str]): columns that must all pass the length check
        max_len (int): exclusive upper bound on len(value); defaults to 1800,
            the cut-off previously hard-coded here

    Returns:
        list[int]: 0-based *positions* (select with ``df.iloc``, not ``df.loc``)
    """
    return [
        position
        for position in range(len(df))
        if all(len(df[c].iloc[position]) < max_len for c in cols)
    ]

# Predictions on Additional Data

In [21]:
# Load the pickled training args for checkpoint 32bf44b1..., close the file, and
# point the args at this machine's ESM cache and checkpoint.
with open(os.path.join(PROTGPS_PARENT_DIR, 'checkpoints/protgps/32bf44b16a4e770a674896b81dfb3729.args'), 'rb') as args_file:
    args = Namespace(**pickle.load(args_file))
args.pretrained_hub_dir = os.path.join(PROTGPS_PARENT_DIR, "checkpoints/esm2")
args.model_path = os.path.join(PROTGPS_PARENT_DIR, "checkpoints/protgps/32bf44b16a4e770a674896b81dfb3729epoch=26.ckpt")  # Ensure this is set


model = load_model(args)[0]
model.eval()
model = model.to(device)

Using cache found in /home/shd-sun-lab/protgps/checkpoints/esm2/facebookresearch_esm_main
Using cache found in /home/shd-sun-lab/protgps/checkpoints/esm2/facebookresearch_esm_main


Using ESM hidden layers 6
Using ESM hidden layers 6


### Condensate

In [22]:
# Rebuild the train/dev/test splits from the combined dataset file. Sequences
# seen in train or dev are collected so predictions can be tagged later.
args.dataset_file_path = "/home/shd-sun-lab/protgps/data/dataset.json"
splits = {
    split_name: get_object(args.dataset_name, "dataset")(args, split_name)
    for split_name in ("train", "dev", "test")
}
train_dataset, dev_dataset, test_dataset = splits["train"], splits["dev"], splits["test"]
train_sequences = {sample['x'] for sample in train_dataset.dataset + dev_dataset.dataset}

100%|██████████| 5480/5480 [00:00<00:00, 61739.92it/s]


TRAIN DATASET CREATED FOR PROTEIN_CONDENSATES_COMBINED.
Could not produce summary statement


100%|██████████| 5480/5480 [00:00<00:00, 117443.08it/s]


DEV DATASET CREATED FOR PROTEIN_CONDENSATES_COMBINED.
Could not produce summary statement


100%|██████████| 5480/5480 [00:00<00:00, 108922.83it/s]

TEST DATASET CREATED FOR PROTEIN_CONDENSATES_COMBINED.
Could not produce summary statement





In [23]:
# Cluster-level ID mapping spreadsheet (one row per cluster).
condensate_xlsx = "/home/shd-sun-lab/protgps/notebook/data/Condensate_data_idmapping_2023_11_04.xlsx"
data = pd.read_excel(condensate_xlsx)


In [24]:
# Collect every accession mentioned in the spreadsheet: each cluster member (the
# text before the first ',' of every ';'-separated entry) plus the cluster's own
# accession (the prefix of 'From' before the first '_').
protein_ids = set()
for rowid, row in data.iterrows():
    members = row['Cluster members']
    # Rows without members (NaN for empty cells) are skipped. The former
    # `elif np.isnan(...)` branch only `continue`d and would raise TypeError on
    # any non-numeric, non-string value, so it was removed.
    if not isinstance(members, str):
        continue
    protein_ids.update(e.split(',')[0].strip() for e in members.split(";"))
    protein_ids.add(row['From'].split("_")[0])

In [25]:
# Download every sequence in parallel (p_map keeps input order). Build the id
# list once so the accession -> sequence pairing never relies on iterating the
# set twice.
protein_id_list = list(protein_ids)
sequences = p_map(get_protein_fasta, protein_id_list)
protein2sequence = dict(zip(protein_id_list, sequences))
with open("Condensate_data_idmapping_sequences.p", "wb") as handle:
    pickle.dump(protein2sequence, handle)

  0%|          | 0/788355 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Keep only accessions whose download succeeded (None marks a failed request).
protein_ids = [pid for pid, seq in protein2sequence.items() if seq is not None]
sequences = list(map(protein2sequence.__getitem__, protein_ids))

In [None]:
# One sequence per forward pass; the returned scores are already rounded to 3 decimals.
scores = predict_condensates(model=model, sequences=sequences, batch_size=1)

In [None]:
# predict_condensates already returns one stacked tensor rounded to 3 decimals.
# Passing that tensor to torch.vstack (which expects a sequence of tensors) fails,
# and re-rounding is a no-op, so use it directly.
scores_round = scores
# Persist the scores with their accessions; the next cell reloads this file.
torch.save({"scores": scores_round, "protein_ids": protein_ids}, "scores_round.pt")

In [None]:
# Reload the cached scores and the accessions they belong to.
cached = torch.load("scores_round.pt")
scores_round = cached["scores"]
protein_ids_scores = cached["protein_ids"]

In [None]:
# accession -> row of compartment scores
protein_to_scores = dict(zip(protein_ids_scores, scores_round))

In [None]:
# Look up each scored protein's organism in parallel and cache the mapping.
organisms = p_map(get_organism, protein_ids_scores)
protein_ids_scores_to_organisms = dict(zip(protein_ids_scores, organisms))
with open("protein_ids_scores_to_organisms.p", "wb") as handle:
    pickle.dump(protein_ids_scores_to_organisms, handle)

In [None]:
# One output row per scored cluster member. Each row holds spreadsheet metadata,
# whether the sequence appeared in train/dev, the organism, and one score column
# per compartment (OLDCOMPS order). Columns are created in append order, which
# fixes the column order of the final table.
results_df = defaultdict(list)
with tqdm(total=len(data), ncols=100) as progress:
    for row_number, sheet_row in data.iterrows():
        members = sheet_row['Cluster members']
        if isinstance(members, str):
            member_ids = [m.split(',')[0].strip() for m in members.split(";")]
            for member in member_ids:
                if member not in protein_to_scores:
                    continue
                member_seq = protein2sequence[member]
                results_df["ProteinID"].append(member)
                results_df["Protein_Split"].append("train" if member_seq in train_sequences else "test")
                results_df["Organism"].append(protein_ids_scores_to_organisms[member])
                results_df["original_row"].append(row_number)
                for column in ("gene_names", "split", "labels", "From", "Cluster ID", "Cluster Name", "Organism IDs"):
                    results_df[column].append(sheet_row[column])
                results_df["Sequence"].append(member_seq)
                member_scores = protein_to_scores[member]
                for position, compartment in enumerate(OLDCOMPS):
                    results_df[f"{compartment.upper()}_Score"].append(member_scores[position].item())
        progress.update()
               



In [None]:
results_df = pd.DataFrame.from_dict(results_df)  # column name -> values

In [None]:
preds_csv = "Condensate_data_idmapping_2023_11_04_preds.csv"
results_df.to_csv(preds_csv, index=False)

### Substitutions_set_230130

In [None]:
substitutions_xlsx = "substitutions_set_230130.xlsx"
data = pd.read_excel(substitutions_xlsx)

In [None]:
data.head(n=5)

In [None]:
# Positions of rows where both the WT and substituted sequences pass the length filter.
rows_with_valid_seq_len = get_valid_rows(df=data, cols=['WT_Sequence', 'Substitution_seq'])

In [None]:
(len(data), len(rows_with_valid_seq_len))  # rows before vs. after the length filter

In [None]:
# get_valid_rows returns 0-based positions, so select with iloc (loc only agrees
# while the index is the default RangeIndex). .copy() yields an independent frame,
# so the score columns added later do not trigger pandas' SettingWithCopy warnings.
data = data.iloc[rows_with_valid_seq_len].copy()

In [None]:
sequences = data['WT_Sequence'].tolist()
scores = predict_condensates(model, sequences, batch_size=10)



In [None]:
for column_idx, compartment in enumerate(OLDCOMPS):
    data[f"WT_Sequence_{compartment.upper()}_Score"] = scores[:, column_idx].tolist()

In [None]:
sequences = data['Substitution_seq'].tolist()
scores = predict_condensates(model, sequences, batch_size=10)



In [None]:
for column_idx, compartment in enumerate(OLDCOMPS):
    data[f"Substitution_seq_{compartment.upper()}_Score"] = scores[:, column_idx].tolist()

In [None]:
substitutions_csv = 'substitutions_set_230130_preds.csv'
data.to_csv(substitutions_csv, index=False)

### termination_set_230129

In [None]:
termination_xlsx = "termination_set_230129.xlsx"
data = pd.read_excel(termination_xlsx)

In [None]:
data.head(n=5)

In [None]:
# Positions of rows where both the WT and truncated sequences pass the length filter.
rows_with_valid_seq_len = get_valid_rows(df=data, cols=['WT_Sequence', 'Termination_sequence'])

In [None]:
# get_valid_rows returns 0-based positions, so select with iloc (loc only agrees
# while the index is the default RangeIndex). .copy() avoids pandas'
# SettingWithCopy warnings when score columns are added later. The counts are
# printed here because the inspection cell below runs after filtering.
print(f"Keeping {len(rows_with_valid_seq_len)} of {len(data)} rows")
data = data.iloc[rows_with_valid_seq_len].copy()

In [None]:
# NOTE: evaluated after the filtering cell, so both counts are equal here.
(len(data), len(rows_with_valid_seq_len))

In [None]:
sequences = data['WT_Sequence'].tolist()
scores = predict_condensates(model, sequences, batch_size=5)



In [None]:
for column_idx, compartment in enumerate(OLDCOMPS):
    data[f"WT_Sequence_{compartment.upper()}_Score"] = scores[:, column_idx].tolist()

In [None]:
sequences = data['Termination_sequence'].tolist()
scores = predict_condensates(model, sequences, batch_size=5)



In [None]:
for column_idx, compartment in enumerate(OLDCOMPS):
    data[f"Termination_sequence_{compartment.upper()}_Score"] = scores[:, column_idx].tolist()

In [None]:
termination_csv = 'termination_set_230129_preds.csv'
data.to_csv(termination_csv, index=False)

### disease_mutations_reference_set

In [None]:
disease_xlsx = "disease_mutations_reference_set.xlsx"
data = pd.read_excel(disease_xlsx)

In [None]:
data.head(n=5)

In [None]:
rows_with_valid_seq_len = get_valid_rows(df=data, cols=['Sequence'])

In [None]:
(len(data), len(rows_with_valid_seq_len))  # rows before vs. after the length filter

In [None]:
# get_valid_rows returns 0-based positions, so select with iloc (loc only agrees
# while the index is the default RangeIndex). .copy() avoids pandas'
# SettingWithCopy warnings when score columns are added later.
data = data.iloc[rows_with_valid_seq_len].copy()

In [None]:
# Upper-case the residues before scoring.
sequences = [seq.upper() for seq in data['Sequence']]
scores = predict_condensates(model, sequences, batch_size=5)



In [None]:
for column_idx, compartment in enumerate(OLDCOMPS):
    data[f"{compartment.upper()}_Score"] = scores[:, column_idx].tolist()

In [None]:
disease_csv = 'disease_mutations_reference_set_preds.csv'
data.to_csv(disease_csv, index=False)

# AUCs

In [7]:
# args
with open(os.path.join(PROTGPS_PARENT_DIR, 'checkpoints/protgps/32bf44b16a4e770a674896b81dfb3729.args'), 'rb') as args_file:
    args = Namespace(**pickle.load(args_file))
args.dataset_file_path = os.path.join(PROTGPS_PARENT_DIR, "data/new_condensate_dataset_m3_c5_mmseqs.json")
# NOTE(review): unlike the prediction section, model_path / pretrained_hub_dir are
# not overridden here, so load_model below uses whatever the pickled args contain.
# Confirm they point at the intended checkpoint.
# Load test dataset
test_dataset = get_object(args.dataset_name, "dataset")(args, "test")





In [10]:
# Raw test labels (not referenced again later in this notebook).
ys = [sample['y'] for sample in test_dataset.dataset]

In [None]:
model = load_model(args)[0]
model.eval()
model = model.to(device)
print()

In [67]:
# Inputs, labels re-ordered to OLDCOMPS (the order used to evaluate this
# checkpoint below), and entry ids for the test split.
test_x, test_y, test_id = [], [], []
for sample in test_dataset.dataset:
    test_x.append(sample['x'])
    test_y.append(transform_y(sample['y']))
    test_id.append(sample['entry_id'])

In [70]:
test_preds = predict_condensates(model, test_x, batch_size=10, round=False)  # unrounded for AUC



In [74]:
test_y = torch.vstack(list(test_y))  # list of label vectors -> (N, n_compartments)

In [78]:
# Per-compartment ROC-AUC (labels were re-ordered to OLDCOMPS by transform_y).
per_compartment_auc = {
    name: roc_auc_score(test_y[:, col], test_preds[:, col])
    for col, name in enumerate(OLDCOMPS)
}
for condensate, auc in per_compartment_auc.items():
    print(f"{condensate}:\t{round(auc,3)}")



# Analysis

In [11]:
# Designed sequences analysed below. Each entry has the target condensate, a
# sample name, a seed (presumably the generation seed; cf. the trajectory
# configs) and the amino-acid sequence. The attribution loop asserts that the
# model's top prediction for sequence + mcherry equals "condensate".
experimental_sequences = [
    {
        "condensate": "nucleolus",
        "name": "mc2_nuc1",
        "seed": 6,
        "sequence": "FMLVSTLWWKQKRLNNAVRTHTKFLTTINNPWRDFCSHRKKYCQKRKHEHATLKSWGTNNGSRRAAGICSGYGPEHSPDANTVKHCCIDYDSIDPIRCTR"
    },
    {
        "condensate": "nucleolus",
        "name": "mc2_nuc2",
        "seed": 1,
        "sequence": "HFMRIADRKVMHHGCAKQGNSWNHIGQKPCCSKVKKGEQSQKADAVVWGVKCHMKWEARSQCNQSFEKMQLHCPMSCRVQESSHNQHNIQPKANHQAMIH"
    },
    {
        "condensate": "nucleolus",
        "name": "mc2_nuc7",
        "seed": 7,
        "sequence": "HGQNRRRKNIGTLKMHTIRGFFPMFSEIRNNHTFTIHGSKSFNSDFQDQNLHCHDRMMHLQISDSMNNTGEEWMTEKVNSLPRKGKSGGPPYKPKVWSVQ"
    },
    {
        "condensate": "nuclear_speckle",
        "name": "mc2_spk2",
        "seed": 8,
        "sequence": "VNDITDVEMAVGRVPREGGNATERCYACFHHLDDYDLHQQMHGRDAPHMRNNSYKKAAHSEHINEVDHQGLQSDVEEYEGVMNEDTFKYMADERDCSPRN"
    },
    {
        "condensate": "nuclear_speckle",
        "name": "mc2_spk3",
        "seed": 7,
        "sequence": "TKIKKHRSTPNMIQSPVTYPDEDHTNNHAGWKTTKAAAPKFRCAARQINRTAMMRCENFAITIDDMPSQDWPHKDDHGAGDDKKDCMPARYDGHTEETND"
    },
    
]

In [12]:
# Tail appended to every designed sequence before scoring and attribution.
# NOTE(review): this literal does not resemble the published mCherry protein
# sequence (which begins "MVSKGEE..."). Confirm it is the intended fusion tail.
mcherry = 'LVQLVHAAGGVAALGAFVLFHDGVVLVVGKDVQLDVDVVGAGQLHGLLGLVGGLDLSVVVAAVLQLQPLLDLALQGAVLGVHPLSGGLPAHGTTLHYGAVGGEVGAAQLHLVDELAVLQGGVLGHGHHAAVLEVHHALPIEALGEGQLQVVGDVGGVLHVGLGAVHELRGQDVPGEGQGATLGHLQLGGLGALVGAALALALDLELVAVHGALHVHLEAHELLDDGHVILLALAH'

## Integrated Gradients | Attributions

In [49]:
def visualize_text(datarecords, legend: bool = True) -> "HTML":  # In quotes because this type doesn't exist in standalone mode
    """Render per-residue attribution records as an HTML table and display it.

    Appears adapted from captum's ``visualize_text``, reduced to three columns:
    amino-acid importance, sample id and target (score).

    Args:
        datarecords: iterable of ``viz.VisualizationDataRecord``. The sample id
            shown is the part of ``true_class`` after the first '_' (the whole
            name if it contains no '_').
        legend: if True, include a negative/neutral/positive colour legend.

    Returns:
        The ``viz.HTML`` object that was displayed.
    """
    # Valid HTML attribute (the former "<table width: 100%>" was not).
    dom = ['<table width="100%">']
    rows = [
        "<tr>"
        "<th>Amino Acid Importance</th>"
        "<th>Sample ID</th>"
        "<th>Target (score)</th>"
        "</tr>"
    ]
    for datarecord in datarecords:
        name_parts = datarecord.true_class.split('_')
        # Avoid IndexError for names without an underscore.
        sample_id = name_parts[1] if len(name_parts) > 1 else datarecord.true_class
        rows.append(
            "".join(
                [
                    "<tr>",
                    viz.format_word_importances(
                        datarecord.raw_input_ids, datarecord.word_attributions
                    ),
                    viz.format_classname(sample_id),
                    viz.format_classname(
                        "{0} ({1:.2f})".format(
                            datarecord.pred_class, datarecord.pred_prob
                        )
                    ),
                    "</tr>",  # close the row (was a second opening "<tr>")
                ]
            )
        )

    if legend:
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; \
            padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")

        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; \
                border: 1px solid; background-color: \
                {value}"></span> {label}  '.format(
                    value=viz._get_color(value), label=label
                )
            )
        dom.append("</div>")

    dom.append("".join(rows))
    dom.append("</table>")
    html = viz.HTML("".join(dom))
    viz.display(html)

    return html

In [50]:
def forward(batch_tokens):
    """Score token batches through ESM + the MLP head (forward fn for integrated gradients).

    Args:
        batch_tokens: token-id tensor as produced by the ESM batch converter.

    Returns:
        Sigmoid scores from the MLP applied to the mean of the chosen ESM layer.
    """
    model.zero_grad()
    encoder = model.model.encoder
    layer = encoder.repr_layer
    esm_out = encoder.model(batch_tokens, repr_layers=[layer], return_contacts=False)
    # NOTE(review): the mean runs over every token position, including the
    # cls/eos tokens. Confirm this matches how model.model pools at inference.
    pooled = esm_out["representations"][layer].mean(axis=1)
    return torch.sigmoid(model.model.mlp({'x': pooled})["logit"])

In [51]:
# Attributions below are computed on the CPU.
model = model.to(torch.device('cpu'))

In [52]:
# ESM vocabulary: supplies the cls/mask/eos token ids used for the IG baseline.
encoder = model.model.encoder
alphabet = encoder.alphabet

In [53]:
# Attribute the scores to the ESM token-embedding layer.
embedding_layer = model.model.encoder.model.embed_tokens
lig = LayerIntegratedGradients(forward, embedding_layer)

In [54]:
# For each designed sequence fused to the mcherry tail:
#   1. check that the model's top prediction is the intended condensate;
#   2. run layer integrated gradients for that class against a mask baseline;
#   3. store the L2-normalised per-residue attributions on the dict and as a
#      captum visualisation record.
records = []
sequence_dict_copy = copy.deepcopy(experimental_sequences)
for sequence_dict in sequence_dict_copy:
    fused_seq = sequence_dict['sequence'] + mcherry

    # Baseline of the same length: cls + one <mask> per residue + eos.
    baseline = torch.tensor(
        [alphabet.cls_idx] + [alphabet.mask_idx] * len(fused_seq) + [alphabet.eos_idx]
    ).unsqueeze(0)

    _, _, batch_tokens = model.model.encoder.batch_converter([(0, fused_seq)])

    with torch.no_grad():
        model.eval()
        out = model.model({'x': [fused_seq]})
    probs = torch.sigmoid(out['logit']).detach().cpu()
    pred_class = probs.argmax().item()
    pred_class_name = OLDCOMPS[pred_class]
    assert pred_class_name == sequence_dict["condensate"]

    attributions, delta = lig.attribute(
        inputs=batch_tokens,
        baselines=baseline,
        return_convergence_delta=True,
        target=pred_class,
        n_steps=50,
    )
    # Sum over the embedding dimension, drop the cls/eos positions, L2-normalise.
    residue_attr = attributions.sum(-1)[0, 1:-1]
    residue_attr = residue_attr / torch.norm(residue_attr)
    sequence_dict["attributions"] = residue_attr.tolist()

    records.append(viz.VisualizationDataRecord(
        word_attributions=residue_attr * 10,
        pred_prob=probs.max().item(),
        pred_class=pred_class_name,
        true_class=sequence_dict["name"],
        attr_class="-",
        attr_score=attributions[0, 1:-1].sum(),
        raw_input_ids=fused_seq,
        convergence_score=delta,
    ))

In [55]:
html = visualize_text(records, legend=True)

In [56]:
# Save the rendered attribution table for viewing outside the notebook.
with open('html_file.html', 'w') as html_out:
    html_out.write(html.data)

In [57]:
# Record the fused sequence that was actually scored.
for sequence_dict in sequence_dict_copy:
    sequence_dict['full_sequence'] = sequence_dict['sequence'] + mcherry

In [58]:
attributions_df = pd.DataFrame(sequence_dict_copy)
attributions_df.to_csv('attributions.csv', index=False)

## Trajectories

In [27]:
# Gather design trajectories. Every *.txt under <esm_directory>/trajectories holds
# tab-separated lines that are later unpacked into three fields, and the run's Hydra
# config (.hydra/config.yaml in the same folder) provides the seed. Results are
# keyed "<condensate>_<seed>", where <condensate> is the file name up to its first '.'.
# NOTE(review): `esm_directory` is not defined anywhere in this notebook.
if "esm_directory" not in globals():
    raise NameError("Set `esm_directory` (folder containing 'trajectories') before running this cell.")

paths = []
for root, _, files in os.walk(os.path.join(esm_directory, "trajectories")):
    paths.extend(os.path.join(root, f) for f in files if f.endswith('.txt'))

idr2scores = defaultdict(list)
for p in paths:
    config_path = os.path.join(os.path.dirname(p), ".hydra/config.yaml")
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    condensate = os.path.basename(p).split('.')[0]
    with open(p, 'r') as f:
        preds = [line.strip('\n') for line in f.readlines()]
    # Skip blank lines (e.g. a trailing empty line); they cannot be unpacked later.
    preds = [line for line in preds if line]
    idr2scores[f"{condensate}_{config['seed']}"].extend(line.split('\t') for line in preds)

In [28]:
# Inspect the "<condensate>_<seed>" keys collected from the trajectory files.
idr2scores.keys()

In [30]:
# Column name -> list of values for the long-format trajectories table.
trajectories = defaultdict(list)

In [47]:
# Flatten each (target, seed) trajectory into long-format rows sorted by step.
# Reset the accumulator so re-running this cell does not append duplicate rows.
trajectories = defaultdict(list)
for generated, traj in idr2scores.items():
    # Keys are "<condensate>_<seed>". Splitting on the last '_' handles multi-digit
    # seeds; the former [:-2] / [-1] slicing assumed a single-digit seed.
    target, seed = generated.rsplit('_', 1)
    # Step number is the part of the first field before ':'.
    trajectory = [(int(i.split(':')[0]), s, float(p)) for i, s, p in traj]
    trajectory = sorted(trajectory, key=lambda x: x[0])

    steps = [s[0] for s in trajectory]
    seqs = [s[1] for s in trajectory]
    scores = [s[2] for s in trajectory]

    trajectories["Target Compartment"].extend([target] * len(steps))
    trajectories["Seed"].extend([seed] * len(steps))
    trajectories["Step"].extend(steps)
    trajectories["IDR Sequence"].extend(seqs)
    trajectories["Localization Score"].extend(scores)

In [48]:
trajectories_df = pd.DataFrame(trajectories)
trajectories_df.to_csv('trajectories.csv', index=False)

In [271]:
# Score vs. step for `trajectory` as left by the loop above, i.e. only the last
# (target, seed) pair processed — not every trajectory.
fig, ax = plt.subplots()
ax.plot([i[0] for i in trajectory], [i[-1] for i in trajectory])
ax.set(xlabel="Step", ylabel="Localization Score", title="Last processed trajectory")
plt.show()

# Revisions

In [6]:
# model
from sklearn.ensemble import RandomForestClassifier 
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.multioutput import MultiOutputClassifier, ClassifierChain
from sklearn.metrics import roc_auc_score

In [7]:
def make_features(sequence):
    """Hand-crafted protpy descriptors for one protein sequence.

    Concatenates column-wise the amino-acid composition, then the CTD
    composition, transition and distribution for each of five physicochemical
    properties (in that order), and converts the result with ``np.array``.

    Args:
        sequence (str): amino-acid sequence

    Returns:
        np.ndarray: the concatenated protpy descriptor frame as an array
    """
    ctd_functions = (protpy.ctd_composition, protpy.ctd_transition, protpy.ctd_distribution)
    properties = ("hydrophobicity", "polarity", "charge", "solvent_accessibility", "polarizability")
    blocks = [protpy.amino_acid_composition(sequence)]
    blocks += [fn(sequence, property=prop) for prop in properties for fn in ctd_functions]
    return np.array(pd.concat(blocks, axis=1))

## Classic Model

In [3]:
# Reuse the 32bf44b1... training args (file closed after loading) with local paths.
with open(os.path.join(PROTGPS_PARENT_DIR, 'checkpoints/protgps/32bf44b16a4e770a674896b81dfb3729.args'), 'rb') as args_file:
    args = Namespace(**pickle.load(args_file))
args.pretrained_hub_dir = "/home/protgps/esm_models/esm2"
args.dataset_file_path = os.path.join(PROTGPS_PARENT_DIR, "data/new_condensate_dataset_m3_c5_mmseqs.json")

In [4]:
splits = {
    split_name: get_object(args.dataset_name, "dataset")(args, split_name)
    for split_name in ("train", "dev", "test")
}
train_dataset, dev_dataset, test_dataset = splits["train"], splits["dev"], splits["test"]















In [55]:
train_data_classic = []
for sample in tqdm(train_dataset.dataset, ncols=100):
    # Skip sequences containing residues outside protpy's amino-acid alphabet.
    if not all(residue in protpyAA for residue in sample['x']):
        continue
    features = make_features(sample['x'])
    train_data_classic.append({"x": features, 'y': sample['y']})



In [64]:
# Rows are stacked along axis 0 (presumably one row per make_features result).
trainX = np.concatenate([row['x'] for row in train_data_classic], axis=0)
trainY = np.stack([row['y'] for row in train_data_classic])

In [68]:
test_data_classic = []
for sample in tqdm(test_dataset.dataset, ncols=100):
    # Skip sequences containing residues outside protpy's amino-acid alphabet.
    if not all(residue in protpyAA for residue in sample['x']):
        continue
    features = make_features(sample['x'])
    test_data_classic.append({"x": features, 'y': sample['y']})



In [69]:
(len(test_data_classic), len(test_dataset.dataset))  # featurised vs. all test samples

In [70]:
testX = np.concatenate([row['x'] for row in test_data_classic], axis=0)
testY = np.stack([row['y'] for row in test_data_classic])

In [109]:
# RANDOM FOREST
rf = RandomForestClassifier(
    n_estimators=100,
    max_depth=400,
    random_state=0,
)

multi_target_rf = ClassifierChain(rf)
multi_target_rf.fit(trainX, trainY)
predY = multi_target_rf.predict_proba(testX)

# trainY/testY are the dataset labels as-is (no transform_y), so column i is
# COMPARTMENTS[i]. Labelling with OLDCOMPS attached each AUC to the wrong name.
for i, c in enumerate(COMPARTMENTS):
    auc = roc_auc_score(testY[:,i], predY[:,i])
    print(f"ROC-AUC {c}: {auc}")



In [113]:
# Logistic Regression
logreg = LogisticRegression(solver="liblinear", random_state=0)

multi_target_lr = ClassifierChain(logreg)
multi_target_lr.fit(trainX, trainY)
predY = multi_target_lr.predict_proba(testX)

# trainY/testY are the dataset labels as-is (no transform_y), so column i is
# COMPARTMENTS[i]. Labelling with OLDCOMPS attached each AUC to the wrong name.
for i, c in enumerate(COMPARTMENTS):
    auc = roc_auc_score(testY[:,i], predY[:,i])
    print(f"ROC-AUC {c}: {auc}")



## MMSeqs

In [9]:
# Args for checkpoint 7c4853cd... (file closed after loading), pointed at local files.
with open(os.path.join(PROTGPS_PARENT_DIR, 'checkpoints/protgps/7c4853cd22080b250ef89af2a1b25102.args'), 'rb') as args_file:
    args = Namespace(**pickle.load(args_file))
args.from_checkpoint = True
args.checkpoint_path = os.path.join(PROTGPS_PARENT_DIR,"checkpoints/protgps/7c4853cd22080b250ef89af2a1b25102epoch=3.ckpt")
args.model_path = args.checkpoint_path
args.pretrained_hub_dir = "/home/protgps/esm_models/esm2"
args.dataset_file_path = os.path.join(PROTGPS_PARENT_DIR, "data/new_condensate_dataset_m3_c5_mmseqs.json")

In [None]:
model = load_model(args)[0]
model.eval()
print()

In [269]:
dataset_class = get_object(args.dataset_name, "dataset")
test_dataset = dataset_class(args, "test")





In [270]:
test_x = [sample['x'] for sample in test_dataset.dataset]
test_id = [sample['entry_id'] for sample in test_dataset.dataset]
# NOTE(review): unlike the AUC section, labels are NOT passed through transform_y,
# so the columns follow the dataset's COMPARTMENTS order. Confirm this checkpoint
# predicts in that order too.
test_y = torch.vstack([sample['y'] for sample in test_dataset.dataset])

In [280]:
test_preds = predict_condensates(model, test_x, batch_size=1, round=False)  # unrounded for AUC



In [283]:
# NOTE(review): the names come from OLDCOMPS, while test_y is in the dataset's
# COMPARTMENTS order. Confirm which order this checkpoint was trained with before
# trusting the per-compartment labels.
for j, condensate in enumerate(OLDCOMPS):
    column_auc = roc_auc_score(test_y[:, j], test_preds[:, j])
    print(f"{condensate}:\t{round(column_auc,3)}")



### Classical models

In [12]:
splits = {
    split_name: get_object(args.dataset_name, "dataset")(args, split_name)
    for split_name in ("train", "dev", "test")
}
train_dataset, dev_dataset, test_dataset = splits["train"], splits["dev"], splits["test"]













In [13]:
train_data_classic = []
for sample in tqdm(train_dataset.dataset, ncols=100):
    # Skip sequences containing residues outside protpy's amino-acid alphabet.
    if not all(residue in protpyAA for residue in sample['x']):
        continue
    features = make_features(sample['x'])
    train_data_classic.append({"x": features, 'y': sample['y']})

trainX = np.concatenate([row['x'] for row in train_data_classic], axis=0)
trainY = np.stack([row['y'] for row in train_data_classic])



In [14]:
test_data_classic = []
for sample in tqdm(test_dataset.dataset, ncols=100):
    # Skip sequences containing residues outside protpy's amino-acid alphabet.
    if not all(residue in protpyAA for residue in sample['x']):
        continue
    features = make_features(sample['x'])
    test_data_classic.append({"x": features, 'y': sample['y']})

testX = np.concatenate([row['x'] for row in test_data_classic], axis=0)
testY = np.stack([row['y'] for row in test_data_classic])



In [15]:
(len(test_data_classic), len(test_dataset.dataset))  # featurised vs. all test samples

In [19]:
# RANDOM FOREST
rf = RandomForestClassifier(
    n_estimators=100,
    max_depth=400,
    random_state=0,
)

multi_target_rf = ClassifierChain(rf)
multi_target_rf.fit(trainX, trainY)
predY = multi_target_rf.predict_proba(testX)

# trainY/testY are the dataset labels as-is (no transform_y), so column i is
# COMPARTMENTS[i]. Labelling with OLDCOMPS attached each AUC to the wrong name.
for i, c in enumerate(COMPARTMENTS):
    auc = roc_auc_score(testY[:,i], predY[:,i])
    print(f"ROC-AUC {c}: {auc}")



In [20]:
# Logistic Regression
logreg = LogisticRegression(solver="liblinear", random_state=0)

multi_target_lr = ClassifierChain(logreg)
multi_target_lr.fit(trainX, trainY)
predY = multi_target_lr.predict_proba(testX)

# trainY/testY are the dataset labels as-is (no transform_y), so column i is
# COMPARTMENTS[i]. Labelling with OLDCOMPS attached each AUC to the wrong name.
for i, c in enumerate(COMPARTMENTS):
    auc = roc_auc_score(testY[:,i], predY[:,i])
    print(f"ROC-AUC {c}: {auc}")



In [30]:
# Exploratory: CTD distribution features (called without `property`) for the last
# sample visited by the featurisation loop, as a dict.
protpy.ctd_distribution(sample['x']).to_dict(orient="dict")