In [1]:
%cd ..
%pwd

/home/julius/projects/University/2023W/PP2/Project


'/home/julius/projects/University/2023W/PP2/Project'

In [14]:
from pathlib import Path
from typing import Union

import polars as pl
from Bio import SeqIO

from src.dataset.trizod_scores.parse import read_score_csv
from src.dataset.clustering.parse import read_cluster_assignments

from typing import Set
import h5py
import torch
import numpy as np
from collections import defaultdict
from tqdm import tqdm

In [3]:
def filter_scores(score_csv: pl.DataFrame, ids: Set[str]) -> pl.DataFrame:
    return score_csv.filter(score_csv["ID"].is_in(ids))[
        ["ID", "pscores"]
    ].group_by("ID").agg(pl.col("pscores"))

def to_file(data: pl.DataFrame, embs: h5py.File, name: str):
    with h5py.File(f"data/{name}.h5", 'w') as f:
        embedding = f.create_group("embedding")
        trizod = f.create_group("trizod")
        cluster_group: Group = f.create_group("cluster")
        cluster = defaultdict(list)

        for row in data.rows():
            cluster[row[2]].append(row[0])
            trizod[row[0]] = torch.tensor(np.array(row[1], dtype=np.float32))
            embedding[row[0]] = torch.tensor(np.array(embs[row[0]]))

        for key in cluster:
            cluster_group[key] = cluster[key]
        

In [15]:
datasets = ["unfiltered", "tolerant", "moderate", "strict"]
clusters = {
    dataset: read_cluster_assignments(f"data/clusters/{dataset}_rest_clu.tsv")
    for dataset in datasets
}
score_csv = {dataset: read_score_csv(f"data/{dataset}.csv") for dataset in datasets}
embs = h5py.File("data/embeddings/unfiltered_all_esm2_3b.h5", "r")


In [5]:
ids = {dataset: set(score_csv[dataset]["ID"]) for dataset in score_csv}
test_ids = set(rec.id for rec in SeqIO.parse(f"data/TriZOD_test_set.fasta", "fasta"))

In [6]:
train_ids = {dataset: ids[dataset] - test_ids for dataset in ids}

In [7]:
test_data = filter_scores(score_csv["strict"], test_ids)
training_data = {
    dataset: filter_scores(score_csv[dataset], train_ids[dataset])
    for dataset in datasets
}


In [8]:
training_data = {
    dataset: training_data[dataset].join(clusters[dataset], left_on="ID", right_on="sequence_id")
    for dataset in datasets
}

test_data = test_data.with_columns(test_data["ID"].alias("cluster_representative_id"))

In [9]:
training_data

{'unfiltered': shape: (13_290, 3)
 ┌──────────────┬────────────────────────┬───────────────────────────┐
 │ ID           ┆ pscores                ┆ cluster_representative_id │
 │ ---          ┆ ---                    ┆ ---                       │
 │ str          ┆ list[f64]              ┆ str                       │
 ╞══════════════╪════════════════════════╪═══════════════════════════╡
 │ 27821_1_1_1  ┆ [null, null, … null]   ┆ 27821_1_1_1               │
 │ 4620_1_1_1   ┆ [null, 0.5436, … null] ┆ 15845_1_1_1               │
 │ 16574_1_2_2  ┆ [null, 0.2467, … null] ┆ 16574_1_2_2               │
 │ 19879_1_1_1  ┆ [null, 0.0998, … null] ┆ 19879_1_1_1               │
 │ …            ┆ …                      ┆ …                         │
 │ 5156_2_1_1   ┆ [null, null, … null]   ┆ 5156_2_1_1                │
 │ 6836_1_1_1   ┆ [null, null, … null]   ┆ 6836_1_1_1                │
 │ 50525_1_1_1  ┆ [null, null, … null]   ┆ 50525_1_1_1               │
 │ 50438_22_1_1 ┆ [null, null, … null]   ┆ 

In [10]:
%timeit -n 1 -r 1
to_file(test_data, embs, "test")

for dataset in datasets:
    to_file(training_data[dataset], embs, f"train_{dataset}")

In [12]:
read_score_csv("data/unfiltered.csv")

Unnamed: 0_level_0,ID,entryID,stID,entity_assemID,entityID,seq_index,seq,k,zscores,pscores,C,CA,CB,HA,H,N,HB
i64,str,i64,i64,i64,i64,i64,str,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64
0,"""36025_1_1_1""",36025,1,1,1,1,"""F""",2,,,,,,,,,
0,"""36025_1_1_1""",36025,1,1,1,2,"""I""",5,7.1754,0.0594,,,,3.888,,,1.893
0,"""36025_1_1_1""",36025,1,1,1,3,"""H""",8,8.7949,0.0733,,,,4.434,8.585,,3.156
0,"""36025_1_1_1""",36025,1,1,1,4,"""H""",9,8.643,0.1185,,,,4.63,8.419,,3.254
0,"""36025_1_1_1""",36025,1,1,1,5,"""I""",9,8.1347,0.156,,,,3.907,8.06,,1.94
0,"""36025_1_1_1""",36025,1,1,1,6,"""I""",8,6.6576,0.2457,,,,3.731,8.055,,1.906
0,"""36025_1_1_1""",36025,1,1,1,7,"""G""",7,4.5009,0.4333,,,,3.8725,8.304,,
0,"""36025_1_1_1""",36025,1,1,1,8,"""G""",7,6.2927,0.241,,,,3.868,8.089,,
0,"""36025_1_1_1""",36025,1,1,1,9,"""L""",8,8.5087,0.0931,,,,4.044,8.236,,1.885
0,"""36025_1_1_1""",36025,1,1,1,10,"""F""",9,10.629,0.0,,,,4.317,8.373,,3.234
