# Imports, install and mount

<!--  -->

In [6]:
import torch
import torch_geometric
from torch_geometric.data import Data
import pandas as pd
from torch_geometric.loader import DataLoader
import pickle
from torch_geometric.transforms import RandomLinkSplit

# Settings


In [7]:
file_path = "/home/elliot/Documents/ESL2024/data/genes_to_phenotypes_iric.tsv"

mapping_save_path = "/home/elliot/Documents/ESL2024/data/mapped_Os_to_GO_iric.tsv"
datasets_save_path = '/home/elliot/Documents/ESL2024/data/mapping_datasets_and_model_for_genes_to_phenotypes_iric/dataset_'
val_path = datasets_save_path + 'VAL' +  '.pickle'
test_path = datasets_save_path + 'TEST' +  '.pickle'
train_path = datasets_save_path + 'TRAIN' +  '.pickle'

device = 'cpu'

## Reading and mapping graph

In [8]:
iric = pd.read_csv(file_path, delimiter='\t', names = ['subject', 'predicate','object'])

# Mapping entities and relations to ids
entity_set = set(iric['object']).union(set(iric['subject']))
entity_to_mapping = {entity: int(i) for i, entity in enumerate(entity_set)}
relation_set = set(iric['predicate'])
relation_to_mapping = {relation: int(i) for i, relation in enumerate(relation_set)}

iric['mapped_subject'] = iric['subject'].apply(lambda x: entity_to_mapping[x])
iric['mapped_predicate'] = iric['predicate'].apply(lambda x: relation_to_mapping[x])
iric['mapped_object'] = iric['object'].apply(lambda x: entity_to_mapping[x])

iric.to_csv(mapping_save_path, sep='\t', index=False)
print("Saved mapping :")
display(pd.read_csv(mapping_save_path, sep = '\t'))

Saved mapping :


Unnamed: 0,subject,predicate,object,mapped_subject,mapped_predicate,mapped_object
0,OsNippo01g010050,gene ontology,GO:0031267,8201,0,6566
1,OsNippo01g010050,gene ontology,GO:0006886,8201,0,20154
2,OsNippo01g010050,gene ontology,GO:0005622,8201,0,20826
3,OsNippo01g010050,gene ontology,GO:0005623,8201,0,10373
4,OsNippo01g010050,gene ontology,GO:0090630,8201,0,2733
...,...,...,...,...,...,...
169243,OsNippo12g248550,gene ontology,GO:0009409,20245,0,12440
169244,OsNippo12g248550,gene ontology,GO:0001666,20245,0,4625
169245,OsNippo12g250550,gene ontology,GO:0008270,20383,0,15186
169246,OsNippo12g255100,gene ontology,GO:0005576,29052,0,8295


## Building init vars for Data :

In [9]:
# Edges index
heads = list(iric['mapped_subject'])
tails = list(iric['mapped_object'])
edge_index = torch.tensor([heads,tails], dtype=torch.long)
# edges states
edge_attributes = torch.tensor(iric['mapped_predicate'])

iric_pyg = Data(
                num_nodes = len(entity_set),
                edge_index = edge_index,
                edge_attr = edge_attributes
                )

print(iric_pyg)

print("\nDataset looks valid ? \n",iric_pyg.validate(raise_on_error=True))

Data(edge_index=[2, 169248], edge_attr=[169248], num_nodes=30396)

Dataset looks valid ? 
 True


## Splitting dataset

In [10]:
transform = RandomLinkSplit(
                            num_val = 0.1,
                            num_test = 0.1,
                            is_undirected=False,
                            add_negative_train_samples=False,
                            )

train, val, test = transform(iric_pyg)

torch.save(obj=train, f = train_path)
torch.save(obj=test, f = test_path)
torch.save(obj=val, f = val_path)

print(f'test saved at {test_path}\nval saved at : {val_path}\ntrain saved at : {train_path}\n')

print('Loaded datasets look valid (val, test, train):',
torch.load(val_path).validate(raise_on_error=True),
torch.load(test_path).validate(raise_on_error=True),
torch.load(train_path).validate(raise_on_error=True),'\n')

print('Before :', val)
print(' After :', torch.load(val_path))

test saved at /home/elliot/Documents/ESL2024/data/mapping_datasets_and_model_for_genes_to_phenotypes_iric/dataset_TEST.pickle
val saved at : /home/elliot/Documents/ESL2024/data/mapping_datasets_and_model_for_genes_to_phenotypes_iric/dataset_VAL.pickle
train saved at : /home/elliot/Documents/ESL2024/data/mapping_datasets_and_model_for_genes_to_phenotypes_iric/dataset_TRAIN.pickle

Loaded datasets look valid (val, test, train): True True True 

Before : Data(edge_index=[2, 135400], edge_attr=[135400], num_nodes=30396, edge_label=[33848], edge_label_index=[2, 33848])
 After : Data(edge_index=[2, 135400], edge_attr=[135400], num_nodes=30396, edge_label=[33848], edge_label_index=[2, 33848])
