# Imports, install and mount

In [32]:
# ! pip install cuda
# ! pip install torch_geometric
# ! pip install dgl
# ! pip install nxontology
# ! pip install tensordict
# ! pip install numpy==1.22.1
# ! pip install pandas
# ! pip install tensorflow
# ! pip install scipy
# ! pip install pydantic
# ! pip install matplotlib

import random
import torch
import torch.optim as optim
import torch_geometric
from torch_geometric.nn import ComplEx
from torch_geometric.data import Data
import pandas as pd
from torch_geometric.loader import DataLoader
from nxontology.imports import from_file
import pickle
import wandb
import torch.nn.functional as F
import numpy as np

# Settings 


In [48]:
# Datas
mapped_iric_path = '/home/elliot/Documents/ESL2024/data/little_mapped_Os_to_GO_iric.tsv'
datasets_save_path = '/home/elliot/Documents/ESL2024/data/mapping_datasets_and_model_for_genes_to_phenotypes_iric/little_dataset_'

val_path = datasets_save_path + 'VAL' +  '.pickle'
test_path = datasets_save_path + 'TEST' +  '.pickle'
train_path = datasets_save_path + 'TRAIN' +  '.pickle'

# Model to re-train :
hidden_channels = 5
batch_size = 4096
epochs = 500
lin_factor = 0.5

params_save_name = f"PARAMS_ComplEx_6_times_{hidden_channels}_HC_{epochs}_epochs_{batch_size}_BS_on_full_Os_GO"
model_parameters_path = "/home/elliot/Documents/ESL2024/data/mapping_datasets_and_model_for_genes_to_phenotypes_iric/"+params_save_name

# Ontology
url = "/home/elliot/Documents/ESL2024/data/go-basic.json.gz"

# Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu' # Tip : Use cpu for debugging
print(device)

cuda


# data loading


In [34]:
mapped_iric = pd.read_csv(mapped_iric_path, sep = '\t')
display(mapped_iric)

GO_to_map = mapped_iric.set_index('object')['mapped_object'].to_dict()
map_to_GO = {value: key for key, value in GO_to_map.items()}


# Checking dict :
looks_ok: bool = True
print(len(list(mapped_iric['object'])))
for i in range(len(list(mapped_iric['object']))):
    if GO_to_map[mapped_iric['object'][i]]!=mapped_iric['mapped_object'][i]:
        looks_ok = False
    
print('Dict looks ok :', looks_ok)
print(map_to_GO)
print(GO_to_map)

Unnamed: 0,subject,predicate,object,mapped_subject,mapped_predicate,mapped_object
0,OsNippo01g010050,gene ontology,GO:0031267,2270,0,2505
1,OsNippo01g010050,gene ontology,GO:0006886,2270,0,2022
2,OsNippo01g010050,gene ontology,GO:0005622,2270,0,2237
3,OsNippo01g010050,gene ontology,GO:0005623,2270,0,1720
4,OsNippo01g010050,gene ontology,GO:0090630,2270,0,76
...,...,...,...,...,...,...
9996,OsNippo01g223000,gene ontology,GO:0005784,1980,0,141
9997,OsNippo01g223050,gene ontology,GO:0005634,1922,0,2657
9998,OsNippo01g223050,gene ontology,GO:0005737,1922,0,160
9999,OsNippo01g223050,gene ontology,GO:0003676,1922,0,2811


10001
Dict looks ok : True
{2505: 'GO:0031267', 2022: 'GO:0006886', 2237: 'GO:0005622', 1720: 'GO:0005623', 76: 'GO:0090630', 515: 'GO:0043087', 1419: 'GO:0005096', 1161: 'GO:0020037', 2316: 'GO:0016705', 2918: 'GO:0055114', 1094: 'GO:0004497', 3128: 'GO:0005506', 2524: 'GO:0009055', 2490: 'GO:0016722', 1806: 'GO:0005507', 2287: 'GO:0016491', 2415: 'GO:0005886', 1736: 'GO:0009506', 132: 'GO:0046658', 3092: 'GO:0016020', 2811: 'GO:0003676', 863: 'GO:0006412', 1710: 'GO:0019843', 2678: 'GO:0003723', 340: 'GO:0042542', 832: 'GO:0000028', 2190: 'GO:0022627', 2458: 'GO:0009651', 212: 'GO:0003735', 1821: 'GO:0009414', 997: 'GO:0009737', 2511: 'GO:0050832', 1178: 'GO:0003729', 1730: 'GO:0005840', 1120: 'GO:0015935', 181: 'GO:0005783', 2768: 'GO:0016829', 653: 'GO:0006888', 2364: 'GO:0051788', 3144: 'GO:0006629', 1913: 'GO:0006635', 1542: 'GO:0043161', 13: 'GO:0009751', 1022: 'GO:0030149', 2554: 'GO:0030170', 1261: 'GO:0009407', 2418: 'GO:0005789', 2876: 'GO:0003824', 2126: 'GO:0016831', 903: 

In [35]:
from tqdm import tqdm
tqdm.pandas()

mapped_iric['mapped_alt_tails'] = mapped_iric.progress_apply(
                                             lambda row: 
                                             mapped_iric.loc[(mapped_iric['mapped_subject'] == row['mapped_subject']) & 
                                                             (mapped_iric['mapped_predicate'] == row['mapped_predicate']) & 
                                                             (mapped_iric['mapped_object'] != row['mapped_object']), 
                                                             'mapped_object'].values, 
                                             axis=1)
display(mapped_iric)

100%|██████████| 10001/10001 [00:04<00:00, 2026.39it/s]


Unnamed: 0,subject,predicate,object,mapped_subject,mapped_predicate,mapped_object,mapped_alt_tails
0,OsNippo01g010050,gene ontology,GO:0031267,2270,0,2505,"[2022, 2237, 1720, 76, 515, 1419]"
1,OsNippo01g010050,gene ontology,GO:0006886,2270,0,2022,"[2505, 2237, 1720, 76, 515, 1419]"
2,OsNippo01g010050,gene ontology,GO:0005622,2270,0,2237,"[2505, 2022, 1720, 76, 515, 1419]"
3,OsNippo01g010050,gene ontology,GO:0005623,2270,0,1720,"[2505, 2022, 2237, 76, 515, 1419]"
4,OsNippo01g010050,gene ontology,GO:0090630,2270,0,76,"[2505, 2022, 2237, 1720, 515, 1419]"
...,...,...,...,...,...,...,...
9996,OsNippo01g223000,gene ontology,GO:0005784,1980,0,141,"[3024, 3140, 181, 2022, 3092, 2472, 100, 2418,..."
9997,OsNippo01g223050,gene ontology,GO:0005634,1922,0,2657,"[160, 2811, 362]"
9998,OsNippo01g223050,gene ontology,GO:0005737,1922,0,160,"[2657, 2811, 362]"
9999,OsNippo01g223050,gene ontology,GO:0003676,1922,0,2811,"[2657, 160, 362]"


In [None]:
print

In [36]:
mapped_alt_tails = {}

for index, row in mapped_iric.iterrows():
    key = (row['mapped_subject'], row['mapped_predicate'])
    if key not in mapped_alt_tails:
        mapped_alt_tails[key] = set()
    mapped_alt_tails[key].update(set(row['mapped_alt_tails']))

print(mapped_alt_tails)

for key, value in mapped_alt_tails.items():
    mapped_alt_tails[key]=np.array(list(value))

{(2270, 0): {515, 2022, 2505, 1419, 76, 1720, 2237}, (439, 0): {1094, 2918, 1161, 2316, 3128, 2524}, (1667, 0): {132, 2918, 1736, 1806, 2287, 2415, 2490}, (2502, 0): set(), (644, 0): set(), (3258, 0): set(), (3271, 0): {832, 1730, 2190, 2511, 340, 212, 1178, 2458, 1821, 863, 1120, 997, 1710, 2678, 2237}, (3266, 0): {1542, 903, 653, 13, 3092, 932, 2734, 181, 2364, 2876, 3144, 2126, 2511, 2768, 1261, 1520, 496, 2418, 368, 1913, 2554, 1022}, (1223, 0): {160, 2657, 2678, 2938}, (1635, 0): {2772, 2678, 343, 2745, 314, 2811}, (177, 0): {2657, 3138, 1835, 2192, 1140, 181, 1751, 1564, 2047}, (308, 0): set(), (1488, 0): {3138, 2498, 196, 1158, 3165, 1821, 3230, 286, 2549, 2876}, (310, 0): {2668, 1744, 1047, 2455, 3102}, (2704, 0): {3265, 2253, 1493, 2838, 347, 160, 2657, 2728, 1263, 1394, 2237}, (814, 0): {3138, 550, 1127, 2823, 1744, 1588}, (2791, 0): {3092, 181}, (753, 0): {3138, 1235, 3092, 343, 2520, 798, 3102, 418, 1960, 235, 1588, 247, 2168}, (2416, 0): {1826, 100, 1157, 1416, 3120, 538},

In [37]:
val_data = torch.load(val_path)
test_data = torch.load(test_path)
train_data = torch.load(train_path)

print("Datatsets look OK ? (val, train, test) :",
val_data.validate(),
test_data.validate(),
train_data.validate())

print(val_data)
print(test_data)
print(train_data)

Datatsets look OK ? (val, train, test) : True True True
Data(edge_index=[2, 8001], edge_attr=[8001], num_nodes=3343, edge_label=[2000], edge_label_index=[2, 2000])
Data(edge_index=[2, 9001], edge_attr=[9001], num_nodes=3343, edge_label=[2000], edge_label_index=[2, 2000])
Data(edge_index=[2, 8001], edge_attr=[8001], num_nodes=3343, edge_label=[8001], edge_label_index=[2, 8001])


In [38]:
subs = set(list(mapped_iric['mapped_subject']))
objs = set(list(mapped_iric['mapped_object']))

heads = set(train_data.edge_index[0].tolist())
tails = set(train_data.edge_index[1].tolist())

print('Number of heads in dataset  :', len(heads))
print('difference heads - subjects :', len(heads-subs))
print('difference heads - object   :', len(heads-objs))

print('Numer of tails in dataset   :', len(tails))
print('Difference tails - subjects :', len(tails-subs))
print('Differece tails - objects   :', len(tails-objs))

Number of heads in dataset  : 1312
difference heads - subjects : 0
difference heads - object   : 1312
Numer of tails in dataset   : 1769
Difference tails - subjects : 1769
Differece tails - objects   : 0


In [39]:
# Here I create a batch on wich i will test my losses. I need a loader and a model to create it.

complex_model = ComplEx(
    num_nodes=train_data.num_nodes,
    num_relations = train_data.edge_index.size()[1],
    hidden_channels=hidden_channels,
).to(device)

loader = complex_model.loader(
    head_index = train_data.edge_index[0],
    tail_index = train_data.edge_index[1],
    rel_type = train_data.edge_attr,
    batch_size=batch_size,
    shuffle=False,)

batchy = next(iter(loader))
hi, rt, ti = batchy[0], batchy[1], batchy[2]
print('\n',hi,'\n',rt,'\n', ti)


 tensor([2227,  902, 3106,  ..., 2998,  543, 2821]) 
 tensor([0, 0, 0,  ..., 0, 0, 0]) 
 tensor([2823, 1806, 2353,  ..., 3215, 2728, 1394])


# Defining LinLoss() functions

In [40]:
nxo = from_file('/home/elliot/Documents/ESL2024/data/go-basic.json.gz')
nxo.freeze()
print(nxo)

<nxontology.ontology.NXOntology object at 0x7f8a35412520>


In [41]:
def lin_sim_on_mapped_terms(mapped_term1, mapped_term2):
    term1 = map_to_GO[mapped_term1]
    term2 = map_to_GO[mapped_term2]
    if (term1 in nxo.graph._node and term2 in nxo.graph._node):
        sim = nxo.similarity(term1, term2).lin 
        return sim
    else:
        return 0  

def best_lim_sim_for_triple(head, rel, tail)-> torch.Tensor:
    max_lin_sim=0
    for alt_tail in mapped_alt_tails[(head, rel)]:

        if (map_to_GO[tail] in nxo.graph._node
            and
            map_to_GO[alt_tail] in nxo.graph._node):
                
                sim = nxo.similarity(map_to_GO[tail], map_to_GO[alt_tail]).lin # Pourrait être amélioré : actuellement, on calcule plein de similarités différentes.
                
                if max_lin_sim < sim < 1:
                    max_lin_sim = sim
    
    return max_lin_sim


def best_lin_sims_for_batch(head_index:torch.Tensor, rel_type:torch.Tensor, tail_index:torch.Tensor):

    batch = pd.DataFrame(torch.transpose(torch.stack((head_index,rel_type,tail_index)),
                                         0,1)
                        )
    
    return torch.Tensor(batch.apply(lambda row : best_lim_sim_for_triple(head=row[0],
                                                                         rel=row[1],
                                                                         tail=row[2]),
                                    axis=1))

def lin_sims_for_batch(term1: torch.Tensor, term2: torch.Tensor)->torch.Tensor:

    batch = pd.DataFrame(torch.transpose(torch.stack((term1, term2)),
                                         0,1)
                        )
    
    display(batch)
    
    return torch.Tensor(batch.apply(lambda row : best_lim_sim_for_triple(head=row[0],
                                                                         rel=row[1],
                                                                         tail=row[2]),
                                    axis=1))

In [42]:
def shuffle_tensor(t: torch.Tensor):
    '''
    Shuffles elments of a tensor.
    WARNING :
    shuffle_tensor(torch.tensor([[0,1,2,3,4,5],[6,7,8,9,0,1]]))
    returns :
    tensor([[0, 1, 2, 3, 4, 5],  OR tensor([[6, 7, 8, 9, 0, 1],
            [6, 7, 8, 9, 0, 1]])            [0, 1, 2, 3, 4, 5]])
    '''
    idx = torch.randperm(t.shape[0])
    return t[idx].view(t.size())

shuffle_tensor(torch.tensor([[0,1,2,3,4,5],[6,7,8,9,0,1]]))

tensor([[0, 1, 2, 3, 4, 5],
        [6, 7, 8, 9, 0, 1]])

In [43]:
class tail_only_ComplEx(ComplEx):

    '''
    Overwritting random_sample() to make negative triples by setting a random tail to each triple,
    instead of setting a random head or tail.
    '''
    @torch.no_grad()
    def random_sample(
        self,
        head_index: torch.Tensor,
        rel_type: torch.Tensor,
        tail_index: torch.Tensor,
        ) -> torch.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

        """
        Randomly samples negative triplets by replacing the tail.
        Args:
            head_index (torch.Tensor): The head indices.
            rel_type (torch.Tensor): The relation type.
            tail_index (torch.Tensor): The tail indices.
        """

        tail_index = shuffle_tensor(tail_index.clone())

        return head_index, rel_type, tail_index

to = tail_only_ComplEx(
    num_nodes=train_data.num_nodes,
    num_relations = train_data.edge_index.size()[1],
    hidden_channels=hidden_channels,
).to(device)



print("Heads are not touched :")
print(batchy[0])
print(to.random_sample(*batchy)[0])

print("\nBut tails are :")
print(batchy[2])
print(to.random_sample(*batchy)[2])

def tens_to_set(t):
    return set(t.tolist())

t, f = batchy[2], to.random_sample(*batchy)[2]
st, sf = tens_to_set(t), tens_to_set(f)
print('Shuffle ok :', not bool(len(st - sf)))



Heads are not touched :
tensor([2227,  902, 3106,  ..., 2998,  543, 2821])
tensor([2227,  902, 3106,  ..., 2998,  543, 2821])

But tails are :
tensor([2823, 1806, 2353,  ..., 3215, 2728, 1394])
tensor([2524,  997,  160,  ...,  863, 1288, 2309])
Shuffle ok : True


In [44]:
class LinSim_ComplEx(tail_only_ComplEx):
  def loss(
            self,
            head_index: torch.Tensor,
            rel_type: torch.Tensor,
            tail_index: torch.Tensor,
            ) -> torch.Tensor:
            
        '''
        tail_only_ComplEx.loss() modified to account a LinSim term : one simply withdraw mean(similarities(batch)) to the loss.
        '''

        pos = head_index, rel_type, tail_index

        false_head_index, false_rel_type, false_tail_index = self.random_sample(head_index, rel_type, tail_index)
        neg = false_head_index, false_rel_type, false_tail_index

        pos_score = self(*pos)
        neg_score = self(*neg)
        scores = torch.cat([pos_score, neg_score], dim=0)

        pos_target = torch.ones_like(pos_score) 
        neg_target = torch.zeros_like(neg_score)
        target = torch.cat([pos_target, neg_target], dim=0)

        # stacking true and falses tails in df :
        pos_and_neg_tails = pd.DataFrame(torch.stack((tail_index,false_tail_index)).transpose(0,1)).astype("int")

        # Calculating LinSim(positive_head, negative_head) : 
        similarities = torch.tensor(pos_and_neg_tails.apply(lambda row : lin_sim_on_mapped_terms(row[0], row[1]),
                                                      axis = 1).values
                                    )


        return F.binary_cross_entropy_with_logits(scores, target) - torch.mean(similarities)
  
class best_LinSim_ComplEx(tail_only_ComplEx):
  def loss(
            self,
            head_index: torch.Tensor,
            rel_type: torch.Tensor,
            tail_index: torch.Tensor,
            ) -> torch.Tensor:
            
        '''
        tail_only_ComplEx.loss() modified to account a LinSim term :
        one withdraw the mean(bests similarities between each false tail of a triple to its possible tails) to the loss.
        '''

        pos = head_index, rel_type, tail_index

        false_head_index, false_rel_type, false_tail_index = self.random_sample(head_index, rel_type, tail_index)
        neg = false_head_index, false_rel_type, false_tail_index

        pos_score = self(*pos)
        neg_score = self(*neg)
        scores = torch.cat([pos_score, neg_score], dim=0)

        pos_target = torch.ones_like(pos_score) 
        neg_target = torch.zeros_like(neg_score)
        target = torch.cat([pos_target, neg_target], dim=0)

        # Calculating LinSim(positive_head, negative_head) : 
        similarities = best_lin_sims_for_batch(head_index, rel_type, false_tail_index)

        return F.binary_cross_entropy_with_logits(scores, target) - torch.mean(similarities)

In [45]:
def getBack(var_grad_fn):
    print(var_grad_fn)
    for n in var_grad_fn.next_functions:
        if n[0]:
            try:
                tensor = getattr(n[0], 'variable')
                print(n[0])
                print('Tensor with grad found:\n', tensor)
                print(' - gradient:\n', tensor.grad)
                print()
            except AttributeError as e:
                getBack(n[0])

In [46]:
batchy

(tensor([2227,  902, 3106,  ..., 2998,  543, 2821]),
 tensor([0, 0, 0,  ..., 0, 0, 0]),
 tensor([2823, 1806, 2353,  ..., 3215, 2728, 1394]))

In [47]:
c = ComplEx(
    num_nodes=train_data.num_nodes,
    num_relations = train_data.edge_index.size()[1],
    hidden_channels=hidden_channels,
).to(device)

loss = c.loss(*batchy)
print(loss)

print('Tracing back tensors:')
loss.backward()
getBack(loss.grad_fn)

tensor(0.6931)
Tracing back tensors:


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [None]:
blc = best_LinSim_ComplEx(
    num_nodes=train_data.num_nodes,
    num_relations = train_data.edge_index.size()[1],
    hidden_channels=hidden_channels,
).to(device)

loss = blc.loss(*batchy)
print(loss)

print('Tracing back tensors:')
loss.backward()
getBack(loss.grad_fn)


tensor(0.3401, grad_fn=<SubBackward0>)
Tracing back tensors:
<SubBackward0 object at 0x7f8a4c0cd070>
<BinaryCrossEntropyWithLogitsBackward0 object at 0x7f8a8105cdf0>
<CatBackward0 object at 0x7f8a82252eb0>
<SubBackward0 object at 0x7f8aa7ae4880>
<AddBackward0 object at 0x7f8b782bef70>
<AddBackward0 object at 0x7f8b782be3d0>
<SumBackward1 object at 0x7f8a35405790>
<MulBackward0 object at 0x7f8a35405100>
<MulBackward0 object at 0x7f8a354059d0>
<EmbeddingBackward0 object at 0x7f8a35405130>
<AccumulateGrad object at 0x7f8a39c24190>
Tensor with grad found:
 Parameter containing:
tensor([[ 0.0163, -0.0224, -0.0084, -0.0395,  0.0165],
        [ 0.0298, -0.0268, -0.0156, -0.0270, -0.0048],
        [ 0.0292,  0.0388, -0.0003,  0.0113,  0.0225],
        ...,
        [ 0.0061,  0.0179,  0.0037, -0.0202,  0.0255],
        [ 0.0111, -0.0351,  0.0092,  0.0308,  0.0270],
        [-0.0116, -0.0379,  0.0056,  0.0030,  0.0159]], requires_grad=True)
 - gradient:
 tensor([[ 0.0000e+00,  0.0000e+00,  0.000

In [None]:
lc = LinSim_ComplEx(
    num_nodes=train_data.num_nodes,
    num_relations = train_data.edge_index.size()[1],
    hidden_channels=hidden_channels,
).to(device)

loss = lc.loss(*batchy)
print(loss)

print('Tracing back tensors:')
loss.backward()
getBack(loss.grad_fn)



tensor(0.5985, dtype=torch.float64, grad_fn=<SubBackward0>)
Tracing back tensors:
<SubBackward0 object at 0x7f8b7bb44340>
<BinaryCrossEntropyWithLogitsBackward0 object at 0x7f8aa7ad0af0>
<CatBackward0 object at 0x7f8aa7ad0a30>
<SubBackward0 object at 0x7f8aa7ad0cd0>
<AddBackward0 object at 0x7f8a39c246d0>
<AddBackward0 object at 0x7f8a39c24550>
<SumBackward1 object at 0x7f8a39c243d0>
<MulBackward0 object at 0x7f8a39c245b0>
<MulBackward0 object at 0x7f8a39c24610>
<EmbeddingBackward0 object at 0x7f8a39c24730>
<AccumulateGrad object at 0x7f8a39c24b20>
Tensor with grad found:
 Parameter containing:
tensor([[-0.0071,  0.0253,  0.0166,  0.0281,  0.0310],
        [ 0.0306,  0.0367,  0.0304, -0.0285,  0.0047],
        [ 0.0310,  0.0292,  0.0267,  0.0079, -0.0229],
        ...,
        [ 0.0230, -0.0340,  0.0070,  0.0414, -0.0046],
        [ 0.0206,  0.0354, -0.0172, -0.0133,  0.0063],
        [-0.0079,  0.0200, -0.0147,  0.0316,  0.0062]], requires_grad=True)
 - gradient:
 tensor([[ 0.0000e+00

# Defining training functions

In [72]:
def train(loader, model, optimizer, device):
    model.train()
    total_loss = total_examples = 0
    for head_index, rel_type, tail_index in loader:

        head_index, rel_type, tail_index = head_index.to(device), rel_type.to(device), tail_index.to(device)
        
        optimizer.zero_grad()
        loss = model.loss(head_index, rel_type, tail_index)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * head_index.numel()
        total_examples += head_index.numel()
    return total_loss / total_examples

@torch.no_grad()
def test(data, model,device):
    model.eval()
    return model.test(
        head_index=data.edge_index[0].to(device),
        tail_index=data.edge_index[1].to(device),
        rel_type=data.edge_attr.to(device),
        batch_size=batch_size,
        k=10, #The k in Hit@k
    )

def get_test_loss(loader, model, device):
    model.eval()
    total_loss = total_examples = 0
    for head_index, rel_type, tail_index in loader:

        head_index, rel_type, tail_index = head_index.to(device), rel_type.to(device), tail_index.to(device)
        loss = model.loss(head_index, rel_type, tail_index)
        total_loss += float(loss) * head_index.numel()
        total_examples += head_index.numel()

    return total_loss / total_examples


def train_and_test_complex(
                           model,
                           train_data: torch_geometric.data.data.Data,
                           test_data : torch_geometric.data.data.Data,
                           xp_name = '',
                           epochs: int = 1000, 
                           eval_period = 500,
                           reset_parameters = False, save_params = True,
                           use_wandb = False,
                           params_save_path = '',
                           device = 'cpu',
                           dataset_name = 'iric'
                           ):
    
    
    # ----------------------------------------------- Reset parameters
    if reset_parameters :
        model.reset_parameters()

    # ----------------------------------------------- Loader
    print('Init loader...')
    loader = model.loader(
                          head_index = train_data.edge_index[0],
                          tail_index = train_data.edge_index[1],
                          rel_type   = train_data.edge_attr,
                          batch_size = batch_size,
                          shuffle    = True)
    test_loader = model.loader(
                          head_index = test_data.edge_index[0],
                          tail_index = test_data.edge_index[1],
                          rel_type   = test_data.edge_attr,
                          batch_size = batch_size,
                          shuffle    = True)

    # ----------------------------------------------- Optimizer
    print('Init optimizer...')
    optimizer = optim.Adam(model.parameters())

    # ----------------------------------------------- WandB
    if use_wandb:
        print('Init wandb...')
        wandb.init(
            settings=wandb.Settings(start_method="fork"),
            project=xp_name,
            
            config={
            "architecture": str(type(model)),
            "dataset": dataset_name,
            "epochs": epochs,
            'hidden_channels' : hidden_channels,
            'batch_size' : batch_size
            }
        )

    # ----------------------------------------------- Train and eval
    print('Train...')

    torch.set_grad_enabled(True)
    model.to(device)

    train_losses = []
    test_losses = []
    for epoch in range(0, epochs+1):

        loss = train(loader = loader,
                     model  = model,
                     optimizer = optimizer,
                     device = device)
        test_loss = get_test_loss(
                            loader = test_loader,
                            model  = model,
                            device = device)
        
        train_losses.append(loss)
        test_losses.append(test_loss)
        

        if use_wandb : 
            wandb.log({"loss": loss, "loss on test": test_loss})
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Loss on test : {test_loss:.4f}')

    # ----------------------------------------------- Periodic Evaluation
        if eval_period:
            if epoch%eval_period == 0:
                print('Test...')
                rank, mrr, hits = test(test_data, model=model, device=device)

                print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}',
                    f'Val MRR: {mrr:.4f}, Val Hits@10: {hits:.4f}')
                if wandb:
                    wandb.log({"Val Mean Rank" : rank,
                            "Val MRR" : mrr,
                            "hits@10": hits})

    # ----------------------------------------------- End WandB
    if use_wandb:
        print('End wandb...')
        wandb.finish()
        print("WandB finished.")

    # ----------------------------------------------- Save model
    print('Save parameters...')
    if save_params:
        torch.save(model.state_dict(), params_save_path)
        print("Model saved at", params_save_path)

    print("End")

    return model, train_losses, train_losses

# Initiating models

In [63]:
# TODO Envoyer le calcul sur GPU.
# TODO Calculer sur le serveur.

print(f"Device : {device}, batch size : {batch_size}, epochs : {epochs}, hidden channels : {hidden_channels}")
print("DataSets :", '\n',train_data,'\n', test_data,'\n', val_data)

print("\nModel types :\n  -ComplEx ;\n  -Tail only Complex ;\n  -ComplEx with LinSim ;\n  -ComplEx with best possible LinSim")


usual_complex = ComplEx(
    num_nodes=train_data.num_nodes,
    num_relations = 1,#train_data.edge_index.size()[1],
    hidden_channels=hidden_channels,
)

to_ComplEx = tail_only_ComplEx(
    num_nodes=train_data.num_nodes,
    num_relations = train_data.edge_index.size()[1],
    hidden_channels=hidden_channels,
)

lin_complex = LinSim_ComplEx(
    num_nodes=train_data.num_nodes,
    num_relations = train_data.edge_index.size()[1],
    hidden_channels=hidden_channels,
)

best_lin_complex = best_LinSim_ComplEx(
    num_nodes=train_data.num_nodes,
    num_relations = train_data.edge_index.size()[1],
    hidden_channels=hidden_channels,
)

Device : cuda, batch size : 4096, epochs : 500, hidden channels : 5
DataSets : 
 Data(edge_index=[2, 8001], edge_attr=[8001], num_nodes=3343, edge_label=[8001], edge_label_index=[2, 8001]) 
 Data(edge_index=[2, 9001], edge_attr=[9001], num_nodes=3343, edge_label=[2000], edge_label_index=[2, 2000]) 
 Data(edge_index=[2, 8001], edge_attr=[8001], num_nodes=3343, edge_label=[2000], edge_label_index=[2, 2000])

Model types :
  -ComplEx ;
  -Tail only Complex ;
  -ComplEx with LinSim ;
  -ComplEx with best possible LinSim


# Train and test

In [73]:
usual_complex, usual_complex_losses = train_and_test_complex(model=usual_complex,
                                             train_data=train_data, test_data=test_data,
                                             xp_name='First try pipeline with tail only complex',
                                             epochs=1000,
                                             eval_period=500,
                                             reset_parameters=True,
                                             save_params=False,
                                             device=device,
                                             use_wandb = False
                                             )

Init loader...
Init optimizer...
Train...
Epoch: 000, Loss: 0.6932, Loss on test : 0.6931
Test...


100%|██████████| 9001/9001 [00:03<00:00, 2548.16it/s]


Epoch: 000, Val Mean Rank: 1612.94 Val MRR: 0.0027, Val Hits@10: 0.0037
Epoch: 001, Loss: 0.6931, Loss on test : 0.6931
Epoch: 002, Loss: 0.6931, Loss on test : 0.6931
Epoch: 003, Loss: 0.6931, Loss on test : 0.6930
Epoch: 004, Loss: 0.6930, Loss on test : 0.6930
Epoch: 005, Loss: 0.6930, Loss on test : 0.6930
Epoch: 006, Loss: 0.6930, Loss on test : 0.6930
Epoch: 007, Loss: 0.6929, Loss on test : 0.6929
Epoch: 008, Loss: 0.6929, Loss on test : 0.6929
Epoch: 009, Loss: 0.6929, Loss on test : 0.6928
Epoch: 010, Loss: 0.6928, Loss on test : 0.6928
Epoch: 011, Loss: 0.6928, Loss on test : 0.6928
Epoch: 012, Loss: 0.6927, Loss on test : 0.6927
Epoch: 013, Loss: 0.6927, Loss on test : 0.6927
Epoch: 014, Loss: 0.6926, Loss on test : 0.6926
Epoch: 015, Loss: 0.6925, Loss on test : 0.6925
Epoch: 016, Loss: 0.6925, Loss on test : 0.6925
Epoch: 017, Loss: 0.6924, Loss on test : 0.6924
Epoch: 018, Loss: 0.6923, Loss on test : 0.6923
Epoch: 019, Loss: 0.6922, Loss on test : 0.6922
Epoch: 020, Loss

100%|██████████| 9001/9001 [00:03<00:00, 2542.30it/s]


Epoch: 500, Val Mean Rank: 112.27 Val MRR: 0.1456, Val Hits@10: 0.3683
Epoch: 501, Loss: 0.2329, Loss on test : 0.2745
Epoch: 502, Loss: 0.2355, Loss on test : 0.2706
Epoch: 503, Loss: 0.2289, Loss on test : 0.2728
Epoch: 504, Loss: 0.2302, Loss on test : 0.2745
Epoch: 505, Loss: 0.2295, Loss on test : 0.2704
Epoch: 506, Loss: 0.2326, Loss on test : 0.2765
Epoch: 507, Loss: 0.2278, Loss on test : 0.2693
Epoch: 508, Loss: 0.2323, Loss on test : 0.2744
Epoch: 509, Loss: 0.2268, Loss on test : 0.2752
Epoch: 510, Loss: 0.2271, Loss on test : 0.2731
Epoch: 511, Loss: 0.2267, Loss on test : 0.2691
Epoch: 512, Loss: 0.2203, Loss on test : 0.2705
Epoch: 513, Loss: 0.2256, Loss on test : 0.2680
Epoch: 514, Loss: 0.2202, Loss on test : 0.2672
Epoch: 515, Loss: 0.2235, Loss on test : 0.2652
Epoch: 516, Loss: 0.2236, Loss on test : 0.2714
Epoch: 517, Loss: 0.2190, Loss on test : 0.2670
Epoch: 518, Loss: 0.2278, Loss on test : 0.2656
Epoch: 519, Loss: 0.2205, Loss on test : 0.2674
Epoch: 520, Loss:

100%|██████████| 9001/9001 [00:03<00:00, 2449.45it/s]


Epoch: 1000, Val Mean Rank: 85.98 Val MRR: 0.1841, Val Hits@10: 0.4478
Save parameters...
End


ValueError: too many values to unpack (expected 2)

In [67]:
To, To_losses = train_and_test_complex(model=to_ComplEx,
                                             train_data=train_data, test_data=test_data,
                                             xp_name='First try pipeline with tail only complex',
                                             epochs=1000,
                                             eval_period=500,
                                             reset_parameters=True,
                                             save_params=False,
                                             device=device,
                                             use_wandb = False
                                             )

Init loader...
Init optimizer...
Train...
Epoch: 000, Loss: 0.6931
Test...


100%|██████████| 9001/9001 [00:34<00:00, 263.53it/s]


Epoch: 000, Val Mean Rank: 1607.70 Val MRR: 0.0033, Val Hits@10: 0.0040
Epoch: 001, Loss: 0.6931
Epoch: 002, Loss: 0.6931
Epoch: 003, Loss: 0.6931
Epoch: 004, Loss: 0.6931
Epoch: 005, Loss: 0.6931
Epoch: 006, Loss: 0.6931
Epoch: 007, Loss: 0.6931
Epoch: 008, Loss: 0.6931
Epoch: 009, Loss: 0.6931
Epoch: 010, Loss: 0.6931
Epoch: 011, Loss: 0.6931
Epoch: 012, Loss: 0.6931
Epoch: 013, Loss: 0.6931
Epoch: 014, Loss: 0.6931
Epoch: 015, Loss: 0.6931
Epoch: 016, Loss: 0.6931
Epoch: 017, Loss: 0.6931
Epoch: 018, Loss: 0.6931
Epoch: 019, Loss: 0.6931
Epoch: 020, Loss: 0.6930
Epoch: 021, Loss: 0.6930
Epoch: 022, Loss: 0.6930
Epoch: 023, Loss: 0.6930
Epoch: 024, Loss: 0.6930
Epoch: 025, Loss: 0.6929
Epoch: 026, Loss: 0.6929
Epoch: 027, Loss: 0.6929
Epoch: 028, Loss: 0.6928
Epoch: 029, Loss: 0.6928
Epoch: 030, Loss: 0.6927
Epoch: 031, Loss: 0.6927
Epoch: 032, Loss: 0.6926
Epoch: 033, Loss: 0.6926
Epoch: 034, Loss: 0.6925
Epoch: 035, Loss: 0.6924
Epoch: 036, Loss: 0.6923
Epoch: 037, Loss: 0.6923
Epo

100%|██████████| 9001/9001 [00:34<00:00, 264.32it/s]


Epoch: 500, Val Mean Rank: 237.64 Val MRR: 0.1484, Val Hits@10: 0.3301
Epoch: 501, Loss: 0.4124
Epoch: 502, Loss: 0.4090
Epoch: 503, Loss: 0.4060
Epoch: 504, Loss: 0.4127
Epoch: 505, Loss: 0.4097
Epoch: 506, Loss: 0.4074
Epoch: 507, Loss: 0.4107
Epoch: 508, Loss: 0.4070
Epoch: 509, Loss: 0.4096
Epoch: 510, Loss: 0.4067
Epoch: 511, Loss: 0.4076
Epoch: 512, Loss: 0.4009
Epoch: 513, Loss: 0.4013
Epoch: 514, Loss: 0.4078
Epoch: 515, Loss: 0.3988
Epoch: 516, Loss: 0.4064
Epoch: 517, Loss: 0.4047
Epoch: 518, Loss: 0.4056
Epoch: 519, Loss: 0.4017
Epoch: 520, Loss: 0.4022
Epoch: 521, Loss: 0.4005
Epoch: 522, Loss: 0.4017
Epoch: 523, Loss: 0.3995
Epoch: 524, Loss: 0.3987
Epoch: 525, Loss: 0.3998
Epoch: 526, Loss: 0.4005
Epoch: 527, Loss: 0.3946
Epoch: 528, Loss: 0.3980
Epoch: 529, Loss: 0.3959
Epoch: 530, Loss: 0.3966
Epoch: 531, Loss: 0.3957
Epoch: 532, Loss: 0.3940
Epoch: 533, Loss: 0.3979
Epoch: 534, Loss: 0.3979
Epoch: 535, Loss: 0.3871
Epoch: 536, Loss: 0.3904
Epoch: 537, Loss: 0.3970
Epoc

100%|██████████| 9001/9001 [00:33<00:00, 264.79it/s]

Epoch: 1000, Val Mean Rank: 337.21 Val MRR: 0.0318, Val Hits@10: 0.0678
Save parameters...
End





In [69]:
blc, losses_blc= train_and_test_complex(model=best_lin_complex,
                                             train_data=train_data, test_data=test_data,
                                             xp_name='First try pipeline with best lin sim complex',
                                             epochs=1000,
                                             eval_period=500,
                                             reset_parameters=True,
                                             save_params=False,
                                             device=device,
                                             use_wandb = False
                                             )

Init loader...
Init optimizer...
Train...


TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.