In [25]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import ast


class ProteinResidueDataset(Dataset):
    """Per-residue domain-annotation dataset backed by a CSV file.

    Each CSV row describes one protein. The columns read here are
    ``accession``, ``length`` and ``fragments``; ``fragments`` holds the
    string form of a list of ``{'start': ..., 'end': ...}`` dicts, where
    ``end`` is inclusive.

    Args:
        csv_path: Path of the CSV file to load.
        label_to_index: Optional mapping from family name to integer class
            index. When omitted it is built from the sorted unique values of
            the ``family`` column.
        family: Family name written to the ``family`` column of every row
            (default ``'PF01370'``). Pass ``None`` to keep a ``family``
            column that already exists in the CSV.
        index_base: Coordinate of the first residue as used in
            ``fragments``: ``0`` (default, the historical behaviour of this
            class) or ``1``.
            NOTE(review): the fragments look like InterPro-style
            annotations, which are presumably 1-based inclusive; if so the
            default shifts every domain by one residue and ``index_base=1``
            should be passed. Confirm against the data source.

    Raises:
        ValueError: If ``index_base`` is not 0 or 1, or if ``family`` is
            ``None`` and the CSV has no ``family`` column.
    """

    def __init__(self, csv_path, label_to_index=None, family='PF01370', index_base=0):
        if index_base not in (0, 1):
            raise ValueError(f"index_base must be 0 or 1, got {index_base!r}")
        self.index_base = index_base

        self.df = pd.read_csv(csv_path)
        # Fragments are stored as the repr of a list of dicts; literal_eval
        # parses them without executing arbitrary code.
        self.df['fragments'] = self.df['fragments'].apply(ast.literal_eval)

        if family is not None:
            self.df['family'] = family
        elif 'family' not in self.df.columns:
            raise ValueError("family=None requires a 'family' column in the CSV")

        # Create label -> index mapping if not given
        if label_to_index is None:
            families = sorted(self.df['family'].unique())
            self.label_to_index = {fam: idx for idx, fam in enumerate(families)}
        else:
            self.label_to_index = label_to_index

    def __len__(self):
        """Return the number of proteins (CSV rows)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the record for row ``idx``.

        Returns:
            dict with keys ``'accession'``, ``'label'`` (class index of the
            row's family), ``'residue_labels'`` (``torch.long`` tensor of
            shape ``(length,)``: 1 inside any fragment, 0 elsewhere) and
            ``'length'`` (int).
        """
        row = self.df.iloc[idx]
        length = int(row['length'])

        # Per-residue labels: 1 for family domain, 0 for background
        label_mask = torch.zeros(length, dtype=torch.long)
        for fragment in row['fragments']:
            # Convert the inclusive [start, end] range to a half-open slice
            # and clamp it to the sequence, so a negative coordinate cannot
            # wrap around and mark residues at the other end.
            lo = min(max(fragment['start'] - self.index_base, 0), length)
            hi = min(max(fragment['end'] - self.index_base + 1, 0), length)
            label_mask[lo:hi] = 1

        # Later on, we will save the embeddings (and one-hot encodings) and load them instead
        return {
            'accession': row['accession'],
            'label': self.label_to_index[row['family']],
            'residue_labels': label_mask,
            'length': length
        }

In [26]:
# Smoke test: build the dataset from the results CSV
path = '../data/results_with_sequence.csv'
dataset = ProteinResidueDataset(csv_path=path)

In [28]:
# Show the per-residue label mask of the first record
dataset[0]['residue_labels']

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [4]:
# Display the DataFrame bound to `df`.
# NOTE(review): `df` is not defined in any visible cell, and this cell's
# execution count is out of order relative to the cells above - presumably it
# was loaded earlier (e.g. via pd.read_csv); confirm this still works after
# Restart Kernel -> Run All.
df

Unnamed: 0,accession,length,source_database,fragments,sequence
0,A0A003,340,unreviewed,"[{'start': 15, 'end': 249}]",MSSDTHGTDLADGDVLVTGAAGFIGSHLVTELRNSGRNVVAVDRRP...
1,A0A009GZV8,323,unreviewed,"[{'start': 3, 'end': 208}]",MNVLITGGTGFIGKQIAKEILKAGSLTLDDNKPQSIDKIILFDAFA...
2,A0A009H3J1,335,unreviewed,"[{'start': 2, 'end': 260}]",MILVTGGLGFIGSHIALSLMAQGQEVVIVDNLANSTLQTLERLEFI...
3,A0A009H7U9,338,unreviewed,"[{'start': 4, 'end': 263}]",MAKILVTGGAGYIGSHTCVELLNAGHEVIVFDNLSNSSEESLKRVQ...
4,A0A009HJQ2,301,unreviewed,"[{'start': 5, 'end': 220}]",MNKNVLITGASGFIGTHLIKFLLQKNYNVIAVTRQAGKASDHPALQ...
...,...,...,...,...,...
9995,A0A0D7E5F6,300,unreviewed,"[{'start': 3, 'end': 222}]",MNILLTGGTGLIGRALCRRWLADGHRLWVWSRTPQRVAMLCGAEVQ...
9996,A0A0D7E685,352,unreviewed,"[{'start': 5, 'end': 222}]",MTNQALVVGASGIVGSALSRLLADEGWNVAGLARRPNTDAGVTPIS...
9997,A0A0D7E6N8,325,unreviewed,"[{'start': 9, 'end': 231}]",MARYLNQTIFVAGHRGMVGSAIVRRLRALGYGNILTAERDELNLLD...
9998,A0A0D7E8A6,356,unreviewed,"[{'start': 11, 'end': 268}]",MTQSSQQDTKVLVTGGAGYIGSHTCVELIRAGYGVVIYDNFSNSHR...
