# Imports, install and mount

<!--  -->

In [1]:
# ! pip install cuda
# ! pip install torch_geometric
# ! pip install nxontology
# ! pip install tensordict
# ! pip install pandas
# ! pip install tensorflow
# ! pip install scipy
# ! pip install matplotlib

# ! pip3 install torch==2.0.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117

import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch_geometric
from torch_geometric.nn import ComplEx
from torch_geometric.data import Data
import pandas as pd
from tqdm import tqdm
from torch_geometric.loader import DataLoader

import wandb

import pickle



# Settings


In [2]:
# ComplEx embeddings :

hidden_channels = 10
batch_size = 4096
epochs = 1000

file_path = "/home/elliot/Documents/ESL2024/data/genes_to_phenotypes_iric.tsv"

params_save_name = f"PARAMS_ComplEx_HC_6_times_{hidden_channels}_on_full_Os_GO"
params_save_path = "/home/elliot/Documents/ESL2024/data/mapping_datasets_and_model_for_genes_to_phenotypes_iric/"+params_save_name

mapping_save_path = "/home/elliot/Documents/ESL2024/data/mapping_datasets_and_model_for_genes_to_phenotypes_iric/mapped_Os_to_GO_iric.tsv"

datasets_save_path = '/home/elliot/Documents/ESL2024/data/mapping_datasets_and_model_for_genes_to_phenotypes_iric/dataset_'
val_path = datasets_save_path + 'VAL' +  '.pickle'
test_path = datasets_save_path + 'TEST' +  '.pickle'
train_path = datasets_save_path + 'TRAIN' +  '.pickle'

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

wandb.init(
    settings=wandb.Settings(start_method="fork"),
    # set the wandb project where this run will be logged
    project="ComplEx on Os_to_GO_iric",
    
    # track hyperparameters and run metadata
    config={
    "architecture": "ComplEx",
    "dataset": "genes_to_phenotypes_iric.tsv",
    "epochs": epochs,
    'hidden_channels' : hidden_channels,
    'batch_size' : batch_size
    }
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


cuda


[34m[1mwandb[0m: Currently logged in as: [33mbutzelliot[0m ([33mesl2024[0m). Use [1m`wandb login --relogin`[0m to force relogin


# DATAS

## Reading and mapping graph

What we want : Create a Data object with all the properties I want to use later

	- x (tensorised and processed node attributes) (Not for now)
	- edge_index (a tensor of shape (num_edges, 2) indicates the source node index and the destination node index)
	- y (desired edge labels - optional, can be defined as node labels if needed) (Not for now)
	- any other things you want to use later

In [3]:
iric = pd.read_csv(file_path, delimiter='\t', names = ['subject', 'predicate','object'])

# Mapping entities and relations to ids
entity_set = set(iric['object']).union(set(iric['subject']))
entity_to_mapping = {entity: int(i) for i, entity in enumerate(entity_set)}
relation_set = set(iric['predicate'])
relation_to_mapping = {relation: int(i) for i, relation in enumerate(relation_set)}

iric['mapped_subject'] = iric['subject'].apply(lambda x: entity_to_mapping[x])
iric['mapped_predicate'] = iric['predicate'].apply(lambda x: relation_to_mapping[x])
iric['mapped_object'] = iric['object'].apply(lambda x: entity_to_mapping[x])

display(iric)

print("Minima in mappings :")
print('subject :', min(iric['mapped_subject']))
print('predicate :',min(iric['mapped_predicate']))
print('object :',min(iric['mapped_object']))

mapping_to_entity = {v: k for k, v in entity_to_mapping.items()}
mapping_to_relation = {v: k for k, v in relation_to_mapping.items()}

iric.to_csv(mapping_save_path, sep='\t', index=False)
display(pd.read_csv(mapping_save_path, sep = '\t'))

Unnamed: 0,subject,predicate,object,mapped_subject,mapped_predicate,mapped_object
0,OsNippo01g010050,gene ontology,GO:0031267,6427,0,14391
1,OsNippo01g010050,gene ontology,GO:0006886,6427,0,26909
2,OsNippo01g010050,gene ontology,GO:0005622,6427,0,5618
3,OsNippo01g010050,gene ontology,GO:0005623,6427,0,3370
4,OsNippo01g010050,gene ontology,GO:0090630,6427,0,18397
...,...,...,...,...,...,...
169243,OsNippo12g248550,gene ontology,GO:0009409,18344,0,20878
169244,OsNippo12g248550,gene ontology,GO:0001666,18344,0,11307
169245,OsNippo12g250550,gene ontology,GO:0008270,25594,0,805
169246,OsNippo12g255100,gene ontology,GO:0005576,21797,0,12546


Minima in mappings :
subject : 0
predicate : 0
object : 5


Unnamed: 0,subject,predicate,object,mapped_subject,mapped_predicate,mapped_object
0,OsNippo01g010050,gene ontology,GO:0031267,6427,0,14391
1,OsNippo01g010050,gene ontology,GO:0006886,6427,0,26909
2,OsNippo01g010050,gene ontology,GO:0005622,6427,0,5618
3,OsNippo01g010050,gene ontology,GO:0005623,6427,0,3370
4,OsNippo01g010050,gene ontology,GO:0090630,6427,0,18397
...,...,...,...,...,...,...
169243,OsNippo12g248550,gene ontology,GO:0009409,18344,0,20878
169244,OsNippo12g248550,gene ontology,GO:0001666,18344,0,11307
169245,OsNippo12g250550,gene ontology,GO:0008270,25594,0,805
169246,OsNippo12g255100,gene ontology,GO:0005576,21797,0,12546


## Building init vars for Data :

In [4]:
# # Initial nodes states :
# x = torch.ones(len(entity_set), 1)  # Chaque nœud a 1 pour état initial
# print('X : \n',x)

# Edges index
heads = list(iric['mapped_subject'])
tails = list(iric['mapped_object'])
edge_index = torch.tensor([heads,tails], dtype=torch.long)
print('\nEDGE INDEX : \n',edge_index)

# edges states
edge_attributes = torch.tensor(iric['mapped_predicate'])
print('\nEDGES ATTRIBUTES : \n',edge_attributes)

iric_pyg = Data(
                num_nodes = len(entity_set),
                edge_index = edge_index,
                edge_attr = edge_attributes
                )

print("\nDataset looks valid ? \n",iric_pyg.validate(raise_on_error=True))


EDGE INDEX : 
 tensor([[ 6427,  6427,  6427,  ..., 25594, 21797, 21797],
        [14391, 26909,  5618,  ...,   805, 12546, 14551]])

EDGES ATTRIBUTES : 
 tensor([0, 0, 0,  ..., 0, 0, 0])

Dataset looks valid ? 
 True


## Setting up datas and model


## Splitting dataset

In [5]:
from torch_geometric.transforms import RandomLinkSplit

transform = RandomLinkSplit(
                            num_val = 0.1,
                            num_test = 0.1,
                            is_undirected=False,
                            add_negative_train_samples=False,
                            )

train_data, val_data, test_data = transform(iric_pyg)

# print(f"Hole Dataset :\n {iric_pyg}\n\nTrain:\n{train_data}\n\nTest :\n{test_data}\n\nValidation :\n{val_data}")
# # Il ne faut pas regarder num_edges parce que RLS cache les arêtes mais ne les sort pas du graph.
# # print(f"Number of edges in datasets : \n  Hole Dataset : {iric_pyg.num_edges}\n\n  Train: {train_data.num_edges}\n\n  Test : {test_data.num_edges}\n\n  Validation : {val_data.num_edges}")
# print(f"Number of edges in datasets : \n  Train: {list(train_data.edge_label.size())[0]}\n\n  Test : {list(test_data.edge_label.size())[0]}\n\n  Validation : {list(val_data.edge_label.size())[0]}")

train_data = train_data.to(device)
val_data = val_data.to(device)
test_data = test_data.to(device)

print(type(train_data))

# print('\n\n',train_data.num_nodes)
# print(train_data.num_edge_types)
# print(train_data.__dict__)
# print(train_data.edge_index[0].size())
# print(train_data.edge_index[1].size())
# print(train_data.edge_attr.size())
# print(train_data.edge_attr)
# print(train_data.num_nodes)
# print(train_data.edge_index.size()[1])

<class 'torch_geometric.data.data.Data'>


In [6]:
def save_dataset(dataset: torch_geometric.data.data.Data, save_path: str):
    with open(save_path, 'wb') as f:
        pickle.dump(dataset.to_dict(), f, pickle.HIGHEST_PROTOCOL)

def load_dataset(dataset_path: str # Should lead to a  '.pickle' file
                 )-> torch_geometric.data.data.Data:
    with open(dataset_path, 'rb') as f:
        return torch_geometric.data.data.Data.from_dict(pickle.load(f))
    
save_dataset(dataset=val_data, save_path=val_path)
save_dataset(dataset=test_data, save_path=test_path)
save_dataset(dataset=train_data, save_path=train_path)


# MODELS


## Iniating models and loaders

In [7]:
# Initiating models

complex_model = ComplEx(
    num_nodes=train_data.num_nodes,
    num_relations = train_data.edge_index.size()[1],
    hidden_channels=hidden_channels,
).to(device)

# Initiaing loader
head_index = train_data.edge_index[0]
tail_index = train_data.edge_index[1]
rel_type = train_data.edge_attr

loader = complex_model.loader(
    head_index = head_index,
    tail_index = tail_index,
    rel_type = rel_type,
    batch_size=batch_size,
    shuffle=True,
)

# initiating optimizers
complex_optimizer = optim.Adam(complex_model.parameters())

print(batch_size)

4096


## Train and test functions

In [8]:
@torch.no_grad()
def test(data, model):
    model.eval()
    return model.test(
        head_index=test_data.edge_index[0],
        tail_index=test_data.edge_index[1],
        rel_type=test_data.edge_attr,
        batch_size=batch_size,
        k=10,
    )

def train(model, optimizer):
    model.train()
    total_loss = total_examples = 0
    for head_index, rel_type, tail_index in loader:
        optimizer.zero_grad()
        loss = model.loss(head_index, rel_type, tail_index)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * head_index.numel()
        total_examples += head_index.numel()
    return total_loss / total_examples

def plot_loss(loss_list):
    plt.plot(loss_list, label='Loss')
    plt.title('Evolution des Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

def plot_loss_log(loss_list):
    plt.semilogy(loss_list, label='Loss')
    plt.title('Evolution des Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

def running_mean(list,
                 half_window: int # Number of elements that the function will consider
                                  # ahead and behind the position X to calculate running mean at X.
                 ):
    running_means = []

    for i in range(0,len(list)):

        left_bound = max(0,i-half_window)
        right_bound = min(len(list)-1, i + half_window)
        sublist = list[left_bound:right_bound+1]
        running_means.append(sum(sublist)/len(sublist))

    return running_means

## Train and test

In [9]:
torch.set_grad_enabled(True)

complex_model.reset_parameters()
complex_model.to(device)

losses = []
for epoch in range(0, epochs+1):
    loss = train(model=complex_model, optimizer=complex_optimizer)
    losses.append(loss)
    wandb.log({"loss": loss})
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

    if epoch % epochs%1000 == 0:
        rank, mrr, hits = test(val_data, model=complex_model)
        print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}',
              f'Val MRR: {mrr:.4f}, Val Hits@10: {hits:.4f}')
        wandb.log({"Val Mean Rank" : rank, "Val MRR" : mrr, "hits@10": hits})


torch.set_grad_enabled(False)

Epoch: 000, Loss: 0.6931


100%|██████████| 152324/152324 [04:54<00:00, 516.44it/s]


Epoch: 000, Val Mean Rank: 13261.39 Val MRR: 0.0032, Val Hits@10: 0.0052
Epoch: 001, Loss: 0.6931
Epoch: 002, Loss: 0.6931
Epoch: 003, Loss: 0.6923
Epoch: 004, Loss: 0.6891
Epoch: 005, Loss: 0.6814
Epoch: 006, Loss: 0.6678
Epoch: 007, Loss: 0.6483
Epoch: 008, Loss: 0.6243
Epoch: 009, Loss: 0.5983
Epoch: 010, Loss: 0.5730
Epoch: 011, Loss: 0.5503
Epoch: 012, Loss: 0.5306
Epoch: 013, Loss: 0.5139
Epoch: 014, Loss: 0.5010
Epoch: 015, Loss: 0.4893
Epoch: 016, Loss: 0.4806
Epoch: 017, Loss: 0.4739
Epoch: 018, Loss: 0.4666
Epoch: 019, Loss: 0.4611
Epoch: 020, Loss: 0.4573
Epoch: 021, Loss: 0.4526
Epoch: 022, Loss: 0.4477
Epoch: 023, Loss: 0.4450
Epoch: 024, Loss: 0.4420
Epoch: 025, Loss: 0.4385
Epoch: 026, Loss: 0.4353
Epoch: 027, Loss: 0.4323
Epoch: 028, Loss: 0.4301
Epoch: 029, Loss: 0.4265
Epoch: 030, Loss: 0.4250
Epoch: 031, Loss: 0.4216
Epoch: 032, Loss: 0.4214
Epoch: 033, Loss: 0.4177
Epoch: 034, Loss: 0.4161
Epoch: 035, Loss: 0.4140
Epoch: 036, Loss: 0.4107
Epoch: 037, Loss: 0.4085
Ep

100%|██████████| 152324/152324 [05:03<00:00, 501.22it/s]

Epoch: 1000, Val Mean Rank: 270.13 Val MRR: 0.2643, Val Hits@10: 0.6243





<torch.autograd.grad_mode.set_grad_enabled at 0x7fd598404400>

In [10]:
wandb.finish()
print("WandB finished.")

torch.save(complex_model.state_dict(), params_save_path)
print("Model saved at", params_save_path)

0,1
Val MRR,▁█
Val Mean Rank,█▁
hits@10,▁█
loss,█▅▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Val MRR,0.26428
Val Mean Rank,270.13092
hits@10,0.6243
loss,0.04323


WandB finished.
Model saved at /home/elliot/Documents/ESL2024/data/mapping_datasets_and_model_for_genes_to_phenotypes_iric/PARAMS_ComplEx_HC_6_times_10_on_full_Os_GO
