In [1]:
%load_ext autoreload
%autoreload 2

In [6]:
from plaid.datasets import FunctionOrganismDataModule
import torch

datamodule = FunctionOrganismDataModule(
        train_shards="/data/lux70/data/pfam/compressed/j1v1wv6w/train/shard{0000..4423}.tar",
        val_shards="/data/lux70/data/pfam/compressed/j1v1wv6w/val/shard{0000..0863}.tar",
        config_file="/data/lux70/data/pfam/compressed/j1v1wv6w/config.json",
        go_metadata_fpath="/data/lux70/data/pfam/pfam2go.csv",
        organism_metadata_fpath="/data/lux70/data/pfam/organism_counts.csv",
        cache_dir="/homefs/home/lux70/cache/plaid_data_cache/j1v1wv6w",
        train_epoch_num_batches=100_000,
        val_epoch_num_batches=1_000,
        shuffle_buffer=10_000,
        shuffle_initial=10_000,
        num_workers=4,
        batch_size=2048,
)
datamodule.setup()

val_dataloader = datamodule.val_dataloader()
val_dataset = datamodule.val_ds

print(len(val_dataloader))
print(datamodule.batch_size * len(val_dataloader))

1000
2048000


In [7]:
all_go_idxs = []
all_organism_idxs = []
all_local_paths = []
all_sample_ids = []

from tqdm.notebook import tqdm

for batch in tqdm(val_dataloader):
    embedding, mask, go_idx, organism_idx, pfam_id, sample_id, local_path = batch
    
    all_go_idxs.append(go_idx)
    all_organism_idxs.append(organism_idx)
    all_local_paths.append(local_path)
    all_sample_ids.append(sample_id)

  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [15]:
import itertools

all_go_idxs = torch.cat(all_go_idxs)
all_organism_idxs = torch.cat(all_organism_idxs)
all_local_paths = list(itertools.chain.from_iterable(all_local_paths))
all_sample_ids = list(itertools.chain.from_iterable(all_sample_ids))

In [16]:
import pandas as pd

go_metadata_fpath = "/data/lux70/data/pfam/pfam2go.csv"
organism_metadata_fpath = "/data/lux70/data/pfam/organism_counts.csv"

go_df = pd.read_csv(go_metadata_fpath)
org_df = pd.read_csv(organism_metadata_fpath)

In [17]:
import pandas as pd

df = pd.DataFrame(
    {
        "GO_idx": all_go_idxs,
        "organism_index": all_organism_idxs,
        "local_paths": all_local_paths,
        "sample_ids": all_sample_ids
    }
)

In [19]:
df = df.merge(org_df, on="organism_index")

In [20]:
df = df.merge(go_df, on="GO_idx")

In [21]:
df = df.rename({"count": "GO_counts", "counts": "organism_counts"}, axis=1)

In [22]:
print(df[df.GO_term == "GTPase activity"].shape)
df[df.GO_term == "GTPase activity"].value_counts("organism_id")

(0, 11)


Series([], Name: count, dtype: int64)

In [70]:
# df.to_csv("/data/lux70/data/pfam/compressed/j1v1wv6w/val_dataset_metadata.csv")

In [23]:
df.head()

Unnamed: 0,GO_idx,organism_index,local_paths,sample_ids,organism_id,organism_counts,pfam_id,GO_id,GO_term,GO_level,GO_counts
0,63,2260,/data/lux70/data/pfam/compressed/j1v1wv6w/val/...,sample511336,VULDI,2222,PF00145,GO:0008168,methyltransferase activity,function,37
1,63,2260,/data/lux70/data/pfam/compressed/j1v1wv6w/val/...,sample511336,VULDI,2222,PF00590,GO:0008168,methyltransferase activity,function,37
2,63,2260,/data/lux70/data/pfam/compressed/j1v1wv6w/val/...,sample511336,VULDI,2222,PF01234,GO:0008168,methyltransferase activity,function,37
3,63,2260,/data/lux70/data/pfam/compressed/j1v1wv6w/val/...,sample511336,VULDI,2222,PF01795,GO:0008168,methyltransferase activity,function,37
4,63,2260,/data/lux70/data/pfam/compressed/j1v1wv6w/val/...,sample511336,VULDI,2222,PF03141,GO:0008168,methyltransferase activity,function,37


In [24]:
df.local_paths[0]

'/data/lux70/data/pfam/compressed/j1v1wv6w/val/shard0807.tar'

In [25]:
df.shape

(325098587, 11)

In [28]:
len(df.local_paths.unique())

701

In [99]:
def filter_classes(df, GO_idx, org_idx, seed, max_samples=10_000):
    random.seed(seed)
    np.random.seed(seed)
    
    tmp = df[df.GO_idx == GO_idx]
    tmp = tmp[df.organism_index == org_idx]
    n_samples = min(tmp.shape[0], max_samples)
    tmp = tmp.sample(n=n_samples, random_state=seed)
    return tmp

filtered_samples = filter_classes(df, 3, 30, 42)
filtered_samples.shape

  tmp = tmp[df.organism_index == org_idx]


(3129, 11)

In [137]:
filtered_samples = filtered_samples.sort_values("local_paths")

In [104]:
shard_list = filtered_samples.local_paths.unique()

In [98]:
import webdataset as wds
from plaid.datasets import MetadataParser, make_sample

import random
import numpy as np

metadata_parser = MetadataParser(
    go_metadata_fpath="/data/lux70/data/pfam/pfam2go.csv",
    organism_metadata_fpath="/data/lux70/data/pfam/organism_counts.csv",
)

In [130]:
ds = wds.WebDataset(list(shard_list), resampled=False, repeat=False, shardshuffle=False).map(lambda x: make_sample(x, 512, metadata_parser)).batched(2048)

In [201]:
sample = next(iter(ds))

In [202]:
embedding, mask, go_idx, organism_idx, pfam_id, sample_id, local_path = sample

In [203]:
import numpy as np

unique_shards = np.unique(local_path)

In [204]:
tmp = filtered_samples[filtered_samples.local_paths.isin(unique_shards)]

In [205]:
tmp[tmp.sample_ids.isin(sample_id)]

Unnamed: 0,GO_idx,organism_index,local_paths,sample_ids,organism_id,counts,pfam_id,GO_id,GO_term,GO_level,count
