In [1]:
from tqdm import trange, tqdm
import numpy as np
import scipy.sparse as sp

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from utils.utils import normalize_adjacency, sparse_mx_to_torch_sparse_tensor, load_data

In [2]:
version = "train"
version = "valid"
n_labels = 4888

In [3]:
device = 'cpu'
adj, features, edge_features = load_data()
del edge_features
adj = [normalize_adjacency(A) for A in adj]
adj_shapes = np.array([at.shape[0] for at in adj])
adj = [adj[idx] + sp.identity(adj_shapes[idx]) for idx in range(len(adj))]

adj = np.array(adj)
features = np.array(features)

features = features[:n_labels] if version != "valid" else features[n_labels:]
adj = adj[:n_labels] if version != "valid" else adj[n_labels:]

  features = np.array(features)


In [4]:
len(features), len(adj)

(1223, 1223)

In [27]:
# documents = [] 
# with open('data/sequences.txt', "r") as f1:
#     for line in f1:
#         documents.append(' '.join(list(line[:-1])))

# documents = documents[:n_labels:] if version != "valid" else documents[:n_labels] # String protein sequences

In [9]:
class Dataset(Dataset):
    def __init__(self, path_documents, path_labels, tokenizer, max_len, version):
        self.version = version # "train" (features & labels) / "test" (features & labels) / "valid" (features)
        self.n_labels = 4888 if version != "valid" else 1223
        ## TODO valid

        # Labels
        self.labels = [] # Class labels
        self.valid_id = []
        with open(path_labels, "r") as f1:
            for line in f1:
                s1, s2 = line.strip().split(',')
                if len(s2.strip())>0:
                    self.labels.append(int(s2))
                else :
                    self.valid_id.append(s1)


        # Protein sequences
        self.max_len = max_len # Maximum sequence length threshold, max in train database is 989
        self.tokenizer = tokenizer # Language model tokenizer
        
        documents = [] 
        with open(path_documents, "r") as f1:
            for line in f1:
                documents.append(' '.join(list(line[:-1])))
                
        self.documents = documents[:self.n_labels] if version != "valid" else documents[-self.n_labels:] # String protein sequences

        if version=='valid':
            self.labels = [-1]*self.n_labels #dummy labels for valid

        print(self.n_labels, len(self.labels),len(self.documents))

        assert len(self.labels) == len(self.documents)

        # Graphe features
        # Inplace with getitem's indexs

    def __len__(self):
        return len(self.documents)

    def __getitem__(self, index):
        # print(index)

        # Load protein sequence
        sequence = self.documents[index].split()
        if len(sequence) > self.max_len - 1:
            sequence = sequence[:self.max_len-1]
            
        # Tokenize the sequence
        encoding = self.tokenizer.encode_plus(
            sequence,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )

        # Label
        target = [self.labels[index]]

        sample = {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'indexs': index,
            "target": torch.tensor(target),
        }
        return sample

def get_loader(path_documents='data/sequences.txt', path_labels='data/graph_labels.txt', 
                tokenizer=None, max_len=600, batch_size=16, shuffle=False, version=version):
        
    dataset = Dataset(path_documents=path_documents, path_labels=path_labels, tokenizer=tokenizer, max_len=600, version=version)

    data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True, drop_last=False)
    return data_loader

In [10]:
from transformers import BertModel, BertTokenizer

PRE_TRAINED_MODEL_NAME = 'yarongef/DistilProtBert'
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME, do_lower_case=False)

loader = get_loader(tokenizer=tokenizer, version=version)

1223 1223 1223


In [21]:
for e in loader:
    ()
e

{'input_ids': tensor([[ 2, 12, 17,  ...,  0,  0,  0],
         [ 2, 11, 17,  ...,  0,  0,  0],
         [ 2, 10, 13,  ...,  0,  0,  0],
         ...,
         [ 2, 10, 24,  ...,  0,  0,  0],
         [ 2, 23,  6,  ...,  0,  0,  0],
         [ 2, 20, 10,  ...,  0,  0,  0]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'indexs': tensor([4880, 4881, 4882, 4883, 4884, 4885, 4886, 4887]),
 'target': tensor([[11],
         [ 0],
         [14],
         [14],
         [14],
         [ 5],
         [14],
         [14]])}

In [22]:
features = np.array(features)[e['indexs']]
features = np.vstack(features)
features = torch.FloatTensor(features).to(device)

In [23]:
adj = adj[e['indexs']]
adj = sp.block_diag(adj)
adj = sparse_mx_to_torch_sparse_tensor(adj).to(device)