In [1]:
import pandas as pd
import numpy as np

import torch
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.loader import DataLoader
import torch.optim as optim
from torch_geometric.nn import ComplEx
import random

import wandb
from tqdm import tqdm

# Specifique au dataset iric
import os
import sys
sys.path.append('/home/ebutz/ESL2024/code/utils') # Ajouter le dossier contenant constants.py
import constants as c

In [7]:
# Datas :
iric_csv_path = "/home/ebutz/ESL2024/data/full_iric/iric.csv"
test_ratio = 0.1
val_ratio  = 0.1

# ComplEx embeddings :
hidden_channels = 1
batch_size = 4096
epochs = 3
neg_per_pos = 1 #Number of negatives per positive during training
K = 10 #K from Hit@K

device = 'cuda' if torch.cuda.is_available() else 'cpu'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [9]:
def triples_from_csv(path_to_csv, columns_to_use = c.IricNode.features.value):
    """
    Creates triples from a CSV.

    Parameters:
    - path_to_csv (str): Filepath or buffer to the input CSV file.
    - columns_to_use (list): A list of the columns to consider.

    Returns:
    - triples (pandas.DataFrame): Output DataFrame in triple format.
                                  Subjects are index items, predicates are column names from columns_to_use, objects are non-NaN values in columns.
    """

    df = pd.read_csv(filepath_or_buffer=path_to_csv, sep = ',', index_col = 0)
    df.columns = df.columns.str.lower()
    
    # Create a list of triples
    triples = []
    # Drop feature columns
    columns_to_drop = [col for col in columns_to_use if col in df.columns]
    df.drop(columns=columns_to_drop, inplace=True)
    df = df.replace({np.nan:None})
    
    for index, row in df.iterrows():
        for column in df.columns:
            if row[column] is not None:
                for predicate in row[column].split('|'):
                    triples.append([index, column, predicate])

    # Create a dataframe from the list of triples
    return pd.DataFrame(triples, columns=['subject', 'predicate', 'object'])

# Extracting triples from original csv :
iric_triples = triples_from_csv(path_to_csv = iric_csv_path)

# Mapping entities and relations to integers ids :
entity_set = set(iric_triples['object']).union(set(iric_triples['subject']))
entity_to_mapping = {entity: int(i) for i, entity in enumerate(entity_set)}
relation_set = set(iric_triples['predicate'])
relation_to_mapping = {relation: int(i) for i, relation in enumerate(relation_set)}

# Triples to mapped triples :
iric_triples['mapped_subject'] = iric_triples['subject'].apply(lambda x: entity_to_mapping[x])
iric_triples['mapped_predicate'] = iric_triples['predicate'].apply(lambda x: relation_to_mapping[x])
iric_triples['mapped_object'] = iric_triples['object'].apply(lambda x: entity_to_mapping[x])
display(iric_triples)
print(min(iric_triples['mapped_subject']))

  df = pd.read_csv(filepath_or_buffer=path_to_csv, sep = ',', index_col = 0)


Unnamed: 0,subject,predicate,object,mapped_subject,mapped_predicate,mapped_object
0,GO:0000001,is_a,GO:0048311,49137,7,52855
1,GO:0000001,is_a,GO:0048308,49137,7,74528
2,GO:0000002,is_a,GO:0007005,1676,7,46182
3,GO:0000003,is_a,GO:0008150,49508,7,46766
4,GO:0000006,is_a,GO:0005385,7766,7,58756
...,...,...,...,...,...,...
1452516,OsNippo12g255000,interacts_with,OsNippo07g025800,1963,8,65147
1452517,OsNippo12g255000,interacts_with,OsNippo07g207000,1963,8,68417
1452518,OsNippo12g255000,interacts_with,OsNippo07g207600,1963,8,49405
1452519,OsNippo12g255000,interacts_with,OsNippo10g150350,1963,8,34132


0


In [16]:
ontology = iric_triples[iric_triples['predicate']=='is_a']
GO_terms = list(set(list(ontology['mapped_subject']+list(ontology['mapped_object']))))
GO_terms[:100]

[131073,
 131088,
 131089,
 131096,
 131099,
 131103,
 131105,
 131110,
 131113,
 131114,
 131115,
 131120,
 131122,
 131131,
 131132,
 131134,
 131145,
 131149,
 131151,
 131160,
 131163,
 131166,
 131170,
 131173,
 131174,
 131178,
 131181,
 131182,
 131186,
 131189,
 131201,
 131204,
 131207,
 131212,
 131219,
 131222,
 131225,
 131226,
 131231,
 131236,
 131237,
 131239,
 131241,
 131246,
 131249,
 131253,
 131254,
 131256,
 131260,
 131266,
 131270,
 131279,
 131280,
 131284,
 131295,
 131300,
 131310,
 131313,
 131314,
 131319,
 131320,
 131322,
 131324,
 131325,
 131327,
 131332,
 131336,
 131337,
 131345,
 131346,
 131354,
 131355,
 131356,
 131357,
 131359,
 131360,
 131361,
 292,
 131366,
 131367,
 131369,
 131370,
 131371,
 131377,
 131378,
 131379,
 131381,
 131388,
 131394,
 131395,
 131397,
 131398,
 131400,
 131402,
 131407,
 131410,
 131411,
 131412,
 131414,
 131419]

In [5]:
# Triples to pyg framework :

# Edges index :
heads = list(iric_triples['mapped_subject'])
tails = list(iric_triples['mapped_object'])
edge_index = torch.tensor([heads,tails], dtype=torch.long)
edge_attributes = torch.tensor(iric_triples['mapped_predicate'])

iric_pyg = Data(
                num_nodes = len(entity_set),
                edge_index = edge_index,
                edge_attr = edge_attributes
                )

print(iric_pyg)

print("\nDataset looks valid :",iric_pyg.validate(raise_on_error=True))

transform = RandomLinkSplit(
                            num_val = val_ratio,
                            num_test = test_ratio,
                            is_undirected=False,
                            add_negative_train_samples=False,
                            )

train_data, val_data, test_data = transform(iric_pyg)
print("Train, test, val sets look valid :",train_data.validate(raise_on_error=True), test_data.validate(raise_on_error=True), val_data.validate(raise_on_error=True))

Data(edge_index=[2, 1427], edge_attr=[1427], num_nodes=1684)

Dataset looks valid : True
Train, test, val sets look valid : True True True


In [6]:
# Initiating model :
to_complex = ComplEx(
    num_nodes=train_data.num_nodes,
    num_relations = train_data.edge_index.size()[1],
    hidden_channels=hidden_channels,
).to(device)
to_complex.reset_parameters()
to_complex.to(device)

# Initiaing loader :
head_index = train_data.edge_index[0]
tail_index = train_data.edge_index[1]
rel_type = train_data.edge_attr

loader = to_complex.loader(
    head_index = head_index,
    tail_index = tail_index,
    rel_type = rel_type,
    batch_size=batch_size,
    shuffle=True,
)
print("Loader type :", type(loader))

# initiating optimizers :
complex_optimizer = optim.Adam(to_complex.parameters())

# Defining test and train functions :
@torch.no_grad()
def test(data, model):
    model.eval()
    return model.test(
        head_index=data.edge_index[0],
        tail_index=data.edge_index[1],
        rel_type=data.edge_attr,
        batch_size=batch_size, # No need for Tail_Only_ComplEx because one use only 1000 random sample instead of the full dataset.
        k=K, #The k in Hit@k
    )

def train(loader, model, optimizer):
    model.train()
    total_loss = total_examples = 0
    for head_index, rel_type, tail_index in loader:
        optimizer.zero_grad()
        loss = model.loss(head_index, rel_type, tail_index)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * head_index.numel()
        total_examples += head_index.numel()
    return total_loss / total_examples

Loader type : <class 'torch_geometric.nn.kge.loader.KGTripletLoader'>


In [7]:
# Running XP :

wandb.init(
    settings=wandb.Settings(start_method="fork"),
    # set the wandb project where this run will be logged
    project="ComplEx on Iric",
    
    # track hyperparameters and run metadata
    config={
    "architecture": "Tail_Only_ComplEx",
    "dataset": "Iric",
    "epochs": epochs,
    'hidden_channels' : hidden_channels,
    'batch_size' : batch_size
    }
)

losses = []
for epoch in range(1, epochs+1):
    loss = train(model=to_complex, loader = loader, optimizer=complex_optimizer)
    losses.append(loss)
    wandb.log({"loss": loss})

    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

    rank, mrr, hit = test(val_data, model=to_complex)
    print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}', f'Val MRR: {mrr:.4f}, Val Hits@10: {hit:.4f}')

    wandb.log({"Val Mean Rank" : rank, "Val MRR" : mrr, "hits@10": hit})


wandb.finish()


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbutzelliot[0m ([33mesl2024[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011301602925070459, max=1.0…

Epoch: 001, Loss: 0.6931


100%|██████████| 1143/1143 [00:00<00:00, 1241.49it/s]


Epoch: 001, Val Mean Rank: 808.90 Val MRR: 0.0055, Val Hits@10: 0.0105
Epoch: 002, Loss: 0.6931


100%|██████████| 1143/1143 [00:00<00:00, 1990.98it/s]


Epoch: 002, Val Mean Rank: 793.60 Val MRR: 0.0060, Val Hits@10: 0.0114
Epoch: 003, Loss: 0.6931


100%|██████████| 1143/1143 [00:00<00:00, 1818.26it/s]


Epoch: 003, Val Mean Rank: 778.51 Val MRR: 0.0066, Val Hits@10: 0.0122


VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Val MRR,▁▄█
Val Mean Rank,█▄▁
hits@10,▁▅█
loss,█▃▁

0,1
Val MRR,0.00656
Val Mean Rank,778.50745
hits@10,0.01225
loss,0.69314
