# Imports, install and mount

<!--  -->

In [1]:
import torch
from torch_geometric.data import Data
import pandas as pd
import pickle
from torch_geometric.transforms import RandomLinkSplit

# Settings


In [2]:
file_path = "/home/ebutz/ESL2024/data/iric.tsv"

mapping_save_path = "/home/ebutz/ESL2024/data/altailed_mapped_iric.tsv"
datasets_save_path = '/home/elliot/Documents/ESL2024/data/mapping_datasets_and_model_for_genes_to_phenotypes_iric/little_dataset_'
val_path = datasets_save_path + 'VAL'
test_path = datasets_save_path + 'TEST'
train_path = datasets_save_path + 'TRAIN'

device = 'cpu'

## Reading and mapping graph

In [7]:
iric = pd.read_csv(file_path, delimiter='\t')
display(iric)

Unnamed: 0,subject,predicate,object
0,GO:0000001,is_a,GO:0048311
1,GO:0000001,is_a,GO:0048308
2,GO:0000002,is_a,GO:0007005
3,GO:0000003,is_a,GO:0008150
4,GO:0000006,is_a,GO:0005385
...,...,...,...
1452516,OsNippo12g255000,interacts_with,OsNippo07g025800
1452517,OsNippo12g255000,interacts_with,OsNippo07g207000
1452518,OsNippo12g255000,interacts_with,OsNippo07g207600
1452519,OsNippo12g255000,interacts_with,OsNippo10g150350


In [22]:
iric.groupby(['object'])

<pandas.core.groupby.generic.DataFrameGroupBy object at 0x7f85918afad0>

In [23]:
iric = pd.read_csv(file_path, delimiter='\t')

# Mapping entities and relations to ids
entity_set = set(iric['object']).union(set(iric['subject']))
entity_to_mapping = {entity: int(i) for i, entity in enumerate(entity_set)}
relation_set = set(iric['predicate'])
relation_to_mapping = {relation: int(i) for i, relation in enumerate(relation_set)}

iric['mapped_subject'] = iric['subject'].apply(lambda x: entity_to_mapping[x])
iric['mapped_predicate'] = iric['predicate'].apply(lambda x: relation_to_mapping[x])
iric['mapped_object'] = iric['object'].apply(lambda x: entity_to_mapping[x])

display(iric)

from tqdm import tqdm
tqdm.pandas()

iric['mapped_alt_tails'] = iric.groupby(['subject']).progress_apply(
                                             lambda row: 
                                             iric.loc[(iric['mapped_subject'] == row['mapped_subject']) & 
                                                      (iric['mapped_predicate'] == row['mapped_predicate']) & 
                                                      (iric['mapped_object'] != row['mapped_object']), 
                                                                                                               'mapped_object'].values, 
                                            #  axis=1
                                             )
display(iric)

iric.to_csv(mapping_save_path, sep='\t', index=False)
print("Saved mapping :")
display(pd.read_csv(mapping_save_path, sep = '\t'))



Unnamed: 0,subject,predicate,object,mapped_subject,mapped_predicate,mapped_object
0,GO:0000001,is_a,GO:0048311,75577,7,36695
1,GO:0000001,is_a,GO:0048308,75577,7,5784
2,GO:0000002,is_a,GO:0007005,34930,7,34589
3,GO:0000003,is_a,GO:0008150,62056,7,81714
4,GO:0000006,is_a,GO:0005385,31254,7,57013
...,...,...,...,...,...,...
1452516,OsNippo12g255000,interacts_with,OsNippo07g025800,45303,4,20123
1452517,OsNippo12g255000,interacts_with,OsNippo07g207000,45303,4,9055
1452518,OsNippo12g255000,interacts_with,OsNippo07g207600,45303,4,19684
1452519,OsNippo12g255000,interacts_with,OsNippo10g150350,45303,4,17834


  0%|          | 2/73396 [00:00<1:55:19, 10.61it/s]


TypeError: <lambda>() got an unexpected keyword argument 'axis'

In [19]:
display(iric)

Unnamed: 0,subject,predicate,object,mapped_subject,mapped_predicate,mapped_object,mapped_alt_tails
0,GO:0000001,is_a,GO:0048311,75577,7,36695,66091
1,GO:0000001,is_a,GO:0048308,75577,7,5784,40112
2,GO:0000002,is_a,GO:0007005,34930,7,34589,54300
3,GO:0000003,is_a,GO:0008150,62056,7,81714,79634
4,GO:0000006,is_a,GO:0005385,31254,7,57013,41857
...,...,...,...,...,...,...,...
1452516,OsNippo12g255000,interacts_with,OsNippo07g025800,45303,4,20123,24221
1452517,OsNippo12g255000,interacts_with,OsNippo07g207000,45303,4,9055,71014
1452518,OsNippo12g255000,interacts_with,OsNippo07g207600,45303,4,19684,28614
1452519,OsNippo12g255000,interacts_with,OsNippo10g150350,45303,4,17834,18104


KeyError: 'mapped_subject'

## Building init vars for Data :

In [4]:
# Edges index
heads = list(iric['mapped_subject'])
tails = list(iric['mapped_object'])
edge_index = torch.tensor([heads,tails], dtype=torch.long)
# edges states
edge_attributes = torch.tensor(iric['mapped_predicate'])

iric_pyg = Data(
                num_nodes = len(entity_set),
                edge_index = edge_index,
                edge_attr = edge_attributes
                )

print(iric_pyg)

print("\nDataset looks valid ? \n",iric_pyg.validate(raise_on_error=True))

Data(edge_index=[2, 10001], edge_attr=[10001], num_nodes=3343)

Dataset looks valid ? 
 True


## Splitting dataset

In [5]:
transform = RandomLinkSplit(
                            num_val = 0.1,
                            num_test = 0.1,
                            is_undirected=False,
                            add_negative_train_samples=False,
                            )

train, val, test = transform(iric_pyg)

torch.save(obj=train, f = train_path)
torch.save(obj=test, f = test_path)
torch.save(obj=val, f = val_path)

print(f'test saved at {test_path}\nval saved at : {val_path}\ntrain saved at : {train_path}\n')

print('Loaded datasets look valid (val, test, train):',
torch.load(val_path).validate(raise_on_error=True),
torch.load(test_path).validate(raise_on_error=True),
torch.load(train_path).validate(raise_on_error=True),'\n')

print('Before :', val)
print(' After :', torch.load(val_path))

test saved at /home/elliot/Documents/ESL2024/data/mapping_datasets_and_model_for_genes_to_phenotypes_iric/little_dataset_TEST.pickle
val saved at : /home/elliot/Documents/ESL2024/data/mapping_datasets_and_model_for_genes_to_phenotypes_iric/little_dataset_VAL.pickle
train saved at : /home/elliot/Documents/ESL2024/data/mapping_datasets_and_model_for_genes_to_phenotypes_iric/little_dataset_TRAIN.pickle

Loaded datasets look valid (val, test, train): True True True 

Before : Data(edge_index=[2, 8001], edge_attr=[8001], num_nodes=3343, edge_label=[2000], edge_label_index=[2, 2000])
 After : Data(edge_index=[2, 8001], edge_attr=[8001], num_nodes=3343, edge_label=[2000], edge_label_index=[2, 2000])
