# Colab specific header

In [None]:
from IPython.display import JSON
from google.colab import output
from subprocess import getoutput
import os

def shell(command):
  """Run one line typed into the Colab Shell widget and return its output.

  ``cd`` is handled in-process with ``os.chdir`` so that the directory change
  persists between calls; every other command is handed to ``getoutput``.

  Args:
    command: the raw command line as a string.

  Returns:
    ``JSON`` wrapping a one-element list: ``''`` after a successful ``cd``,
    otherwise whatever ``getoutput`` returned for the command.

  Raises:
    OSError: if ``os.chdir`` rejects the ``cd`` target (e.g. missing directory).
  """
  words = command.strip().split(maxsplit=1)
  # Compare the first word exactly: a prefix test would also capture commands
  # that merely start with "cd".
  if words and words[0] == 'cd':
    # A bare `cd` goes to the home directory instead of raising IndexError.
    target = words[1] if len(words) > 1 else '~'
    os.chdir(os.path.expanduser(target))
    return JSON([''])
  return JSON([getoutput(command)])
# Expose `shell` under the name 'shell'; the terminal widget's JavaScript invokes it by that name.
output.register_callback('shell', shell)

In [None]:
#@title Colab Shell
%%html
<div id=term_demo></div>
<script src="https://code.jquery.com/jquery-latest.js"></script>
<script src="https://cdn.jsdelivr.net/npm/jquery.terminal/js/jquery.terminal.min.js"></script>
<link href="https://cdn.jsdelivr.net/npm/jquery.terminal/css/jquery.terminal.min.css" rel="stylesheet"/>
<script>
  // Interactive terminal widget: every non-empty command is forwarded to the
  // Python callback registered under the name 'shell' and its output is echoed.
  $('#term_demo').terminal(async function(command) {
      if (command !== '') {
          try {
              let res = await google.colab.kernel.invokeFunction('shell', [command])
              // The callback answers with a JSON list whose first item is the output text.
              let out = res.data['application/json'][0]
              this.echo(new String(out))
          } catch(e) {
              // Failures of the kernel call are shown as terminal errors.
              this.error(new String(e));
          }
      } else {
          this.echo('');
      }
  }, {
      greetings: 'Welcome to Colab Shell',
      name: 'colab_demo',
      height: 250,
      prompt: 'colab > '
  });
</script>

In [None]:
# MOUNT GDRIVE TO COLAB
# Colab is detected by looking for 'google.colab' in the string form of the IPython shell object.
IN_COLAB = 'google.colab' in str(get_ipython())
if IN_COLAB:
  # Imported only on the Colab branch, so non-Colab runs never touch google.colab.
  from google.colab import drive
  drive.mount('/content/drive')

# For all Jupyter-like environments

In [None]:
# UNCOMMENT REQUIREMENTS TO BE INSTALLED
# from IPython.display import clear_output
# !pip install -r requirements_colab.txt
# !pip install -r requirements.txt
# clear_output(wait=False)

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config IPCompleter.greedy=True

In [None]:
# MODIFY LIBRARY PATH
import os
from pathlib import Path
# Root directory of the CategoryLearning checkout; a later cell builds the dataset path from it.
lib_path = Path('/content/drive/MyDrive/Colab Notebooks/catlearn/CategoryLearning')
# lib_path = Path('/home/fastuser/CategoryLearning')
# Make lib_path the working directory (presumably so the `catlearn` imports below resolve — TODO confirm).
os.chdir(lib_path)
# Show the resulting working directory and whether lib_path is a directory.
print(os.getcwd())
print(os.path.isdir(lib_path))

In [None]:
from typing import Callable, Iterable, Any
import sys
import torch
import random
import numpy as np
import networkx as nx
import collections
from tqdm import (trange, tqdm)
from matplotlib import pyplot as plt
from catlearn.data.dataset import Dataset
from catlearn.tensor_utils import (Tsor, DEFAULT_EPSILON)
from catlearn.graph_utils import (DirectedGraph,
                                    uniform_sample,
                                    random_walk_edge_sample,
                                    random_walk_vertex_sample,
                                    clean_selfloops,
                                    augment_graph,
                                    create_revers_rels)
from catlearn.algebra_models import (Algebra, VectAlgebra, VectMultAlgebra)
from catlearn.composition_graph import CompositeArrow
from catlearn.categorical_model import (TrainableDecisionCatModel, RelationModel,
                                        ScoringModel)

In [None]:
# > VS CODE SPECIFIC <
# Uncomment if run in the VS Code embedded notebook
# import warnings
# Needed to show warnings in the VS Code integrated Jupyter
# warnings.simplefilter(action="default")

In [None]:
# Enable reproducibility
# https://pytorch.org/docs/stable/notes/randomness.html
SEED = 42
# Seed every global RNG this notebook imports: stdlib `random`, NumPy and PyTorch.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Deterministic mode stays off because non-deterministic algorithms are used by CUDA.
# torch.set_deterministic() was superseded by torch.use_deterministic_algorithms();
# call whichever one the installed PyTorch provides.
if hasattr(torch, 'use_deterministic_algorithms'):
    torch.use_deterministic_algorithms(False)
else:
    torch.set_deterministic(False)

In [None]:
# CHECK GPU
!nvidia-smi
print(torch.cuda.device(0))
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))
print(torch.cuda.is_available())
# CHECK KERNEL PYTHON VERSION
print(sys.version)

For Weights & Biases integration, run in a terminal:  
```wandb login```
When asked to enter the API key, paste the key from the website.  
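Do not store the API key in this notebook: copy it from https://wandb.ai/authorize when prompted. The key previously recorded here was exposed and should be revoked.  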
Check that you are logged as ```arreason-labs```

In [None]:
from wandb_logger import log_results, save_params

# IMPORTANT SET BELOW ```default_tensor = torch.cuda.FloatTensor``` FOR GPU


In [None]:
# Select default type: https://pytorch.org/docs/stable/tensors.html
# Use the CUDA float tensor type when a GPU is available and fall back to the CPU
# float tensor type otherwise, so this cell does not fail on CPU-only runtimes.
if torch.cuda.is_available():
    default_tensor = torch.cuda.FloatTensor
else:
    default_tensor = torch.FloatTensor
torch.set_default_tensor_type(default_tensor)
# Last expression of the cell: displays the dtype of the selected tensor type.
default_tensor.dtype

In [None]:
# Specify dataset path
# ds_path_wn18 = './Datasets/wn18rr/text'
# Built from lib_path, so it does not depend on the current working directory.
ds_path_wn18 = lib_path/'Datasets/wn18rr/text'
# ds_path_fb15 = './Datasets/fb15k-237'

In [None]:
# Load the dataset found at ds_path_wn18.
# NOTE(review): node_vec_dim=10 presumably sets the entity vector size — confirm against Dataset.
ds_wn18 = Dataset(path=ds_path_wn18, ds_name='wn18', node_vec_dim=10)

In [None]:
# MODIFY BELOW TO USE RIGHT DATASET
# `ds` is the name every later cell reads the dataset from.
ds = ds_wn18

# Check dataset has multiconnections

In [None]:
# ds_l = list(ds_wn18.train)
# ds_l_edges = [(tpl[0], tpl[1]) for tpl in ds_l]
# cnt_edges = collections.Counter(ds_l_edges)
# cnt_edges = sorted(cnt_edges.items(), key=lambda item: item[1], reverse=True)
# cnt_edges_multi = [tpl for tpl in cnt_edges if tpl[1] > 1]
# ds.train = ds_l
# print(f'Edges with 2 relations: {len(cnt_edges_multi)}')
# print(f'Max number of multirelations: {max([cnt for edge, cnt in cnt_edges_multi])}')

## Create training graph

In [None]:
# Build the directed graph from the training split of the selected dataset.
graph_labels = DirectedGraph(ds.train)

## Clean graph

In [None]:
# The return value is not used, so graph_labels is presumably modified in place — confirm in catlearn.graph_utils.
clean_selfloops(graph_labels)

In [None]:
# nx.info() was removed in NetworkX 3.0; print an equivalent one-line summary there.
if hasattr(nx, 'info'):
    print(nx.info(graph_labels))
else:
    print(f'{type(graph_labels).__name__} with {graph_labels.number_of_nodes()} nodes '
          f'and {graph_labels.number_of_edges()} edges')

# Augment graph

In [None]:
# Maps each relation name to the name of its reverse relation, or to None.
# NOTE(review): relations mapped to themselves are presumably symmetric and None presumably
# means "no reverse relation" — confirm against create_revers_rels.
relation_revers = {
    '_hypernym': None,
    '_derivationally_related_form': '_derivationally_related_form',
    '_member_meronym': None,
    '_has_part': None,
    '_synset_domain_topic_of': None,
    '_instance_hypernym' : None,
    '_also_see': '_also_see',
    '_verb_group': '_verb_group',
    '_member_of_domain_usage': None,
    '_member_of_domain_region': None,
    '_similar_to': '_similar_to',
}

In [None]:
# Derive the augmented relation-to-id map, the relation-id-to-vector map and the reverse-relation
# description from the table above and the dataset's own relation ids.
relation2id_augmented, relation_id2vec_augmented, revers_rels = create_revers_rels(relation_revers, ds.relation2id)

In [None]:
# The return value is not used, so graph_labels is presumably augmented in place — confirm in catlearn.graph_utils.
augment_graph(graph_labels, revers_rels)

# Print graphs stats

In [None]:
# nx.info() was removed in NetworkX 3.0; print an equivalent one-line summary there.
if hasattr(nx, 'info'):
    print(nx.info(graph_labels))
else:
    print(f'{type(graph_labels).__name__} with {graph_labels.number_of_nodes()} nodes '
          f'and {graph_labels.number_of_edges()} edges')

# Define Relation model

In [None]:
class CustomRelation(RelationModel):
    """Relation model made of a single linear layer.

    The layer reads the concatenation of two feature vectors and one label
    vector (``2 * nb_features + nb_labels`` inputs) and produces
    ``algebra.flatdim`` outputs.
    """
    def __init__(self, nb_features: int, nb_labels: int, algebra: Algebra) -> None:
        input_dim = 2 * nb_features + nb_labels
        self.linear = torch.nn.Linear(input_dim, algebra.flatdim)

    def named_parameters(self, recurse: bool = True) -> Iterable[Tsor]:
        """Expose the named parameters of the underlying linear layer."""
        layer = self.linear
        return layer.named_parameters(recurse=recurse)

    def __call__(self, x: Tsor, y: Tsor, l: Tsor) -> Tsor:
        """ Compute x R y """
        # Join the three inputs along the last dimension before the linear map.
        stacked = torch.cat([x, y, l], dim=-1)
        return self.linear(stacked)

# Define Score model

In [None]:
class CustomScore(ScoringModel):
    """Scoring model: one linear layer followed by a softmax.

    The linear layer reads ``2 * nb_features + algebra.flatdim`` inputs and
    emits ``nb_scores + 1`` values. The softmax is taken over the last
    dimension and its final component is dropped before returning.
    """
    def __init__(
            self,
            nb_features: int,
            nb_scores: int,
            algebra: Algebra) -> None:
        input_dim = 2 * nb_features + algebra.flatdim
        output_dim = nb_scores + 1
        self.linear = torch.nn.Linear(input_dim, output_dim)
        self.softmax = torch.nn.Softmax(dim=-1)

    def named_parameters(self, recurse: bool = True) -> Iterable[Tsor]:
        """Expose the named parameters of the underlying linear layer."""
        layer = self.linear
        return layer.named_parameters(recurse=recurse)

    def __call__(self, src: Tsor, dst: Tsor, rel: Tsor) -> Tsor:
        """ Compute S(src, dst, rel) """
        features = torch.cat([src, dst, rel], dim=-1)
        probabilities = self.softmax(self.linear(features))
        # Drop the last softmax component; only the first nb_scores entries are returned.
        return probabilities[..., :-1]

# Create training models

In [None]:
# Algebra model built from the dataset's entity vector dimension.
algebra = VectMultAlgebra(ds.entity_vec_dim)

In [None]:
# Relation model sized by the entity vector dimension and by the number of relations
# in the dataset's own (non-augmented) relation map.
relation_model = CustomRelation(
    nb_features=ds.entity_vec_dim,
    nb_labels=len(ds.relation2id),
    algebra=algebra
)

In [None]:
# Check that the relation set was not modified by the augmentation
assert relation2id_augmented == ds.relation2id

In [None]:
# Scoring model sized by the entity vector dimension and by the number of relations
# in the augmented relation map.
scoring_model = CustomScore(
    nb_features=ds.entity_vec_dim,
    nb_scores=len(relation2id_augmented),
    algebra=algebra
)

In [None]:
# Gradient scaler from torch.cuda.amp.
# NOTE(review): `scaler` is not referenced by any other cell visible here — confirm it is still needed.
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Assemble the trainable model from the relation model, scoring model and algebra defined above.
# The optimizer is passed as a class (torch.optim.Adam), not as an instance.
model = TrainableDecisionCatModel(
    relation_model=relation_model,
    label_universe=relation_id2vec_augmented,
    scoring_model=scoring_model,
    algebra_model=algebra,
    optimizer=torch.optim.Adam,
    epsilon=DEFAULT_EPSILON
)

In [None]:
# DEBUG NOTE: datatype comparison
# Dataset iterable format [src:int, tgt:int, lbl: {id:int: None}]
# CompositeArrow data format:  [[src: int, tgt: int], [label: int]]

# Create training loop

In [None]:
def graph_to_nodes_edges(graph: DirectedGraph):
    """Split a graph into two lazy streams: edge endpoints and edge label lists.

    Returns:
        A pair of generators. The first yields one ``(src, dst)`` tuple per item
        of ``graph.edges(data=False)``; the second yields, per item of
        ``graph.edges(data=True)``, the list of keys of that edge's data mapping.
    """
    plain_edges = graph.edges(data=False)
    labelled_edges = graph.edges(data=True)
    endpoint_pairs = ((source, target) for source, target in plain_edges)
    label_lists = (list(attributes.keys()) for _src, _dst, attributes in labelled_edges)
    return endpoint_pairs, label_lists

In [None]:
def nodes_edges_to_arrows(nodes, edges):
    """Build one single-arrow ``CompositeArrow`` per (node pair, label) combination.

    ``nodes`` and ``edges`` are walked in lockstep; every label found in an
    entry of ``edges`` yields its own arrow over the matching entry of ``nodes``.
    """
    arrows = []
    for node_pair, labels in zip(nodes, edges):
        for label in labels:
            arrows.append(CompositeArrow(nodes=node_pair, arrows=[label]))
    return arrows

In [None]:
def plot_subgraphs(subgraph, plot_nth=5, graph_info=True, step=None):
    """Plot every ``plot_nth``-th subgraph for debugging.

    Args:
        subgraph: graph drawn with ``nx.draw_networkx``.
        plot_nth: draw only when the step counter is a multiple of this value.
        graph_info: when true, also print a one-line summary of the subgraph.
        step: current iteration number. When omitted, the module-level loop
            variable ``i`` is read instead, which is what existing callers rely on.
    """
    if step is None:
        # Backward-compatible fallback to the global loop counter.
        step = i
    if step % plot_nth:
        return
    nx.draw_networkx(subgraph)
    plt.show()
    if graph_info:
        # nx.info() was removed in NetworkX 3.0; build the summary by hand there.
        if hasattr(nx, 'info'):
            print(nx.info(subgraph))
        else:
            print(f'{type(subgraph).__name__} with {subgraph.number_of_nodes()} nodes '
                  f'and {subgraph.number_of_edges()} edges')

In [None]:
# NOTE: for large graphs, random_walk functions family can be used to sub-sample graph
# while preserving its topology
# One seeded RNG shared by all iterations: building a fresh, unseeded random.Random()
# on every step made the sampled subgraphs differ from run to run.
sampling_rng = random.Random(42)
for i in trange(100, desc='1st epoque'):
    sampled_subgraph = uniform_sample(graph=graph_labels, sample_vertices_size=256, rng=sampling_rng, with_edges=True)
    # plot_subgraphs reads the loop counter `i` from module scope, so the loop variable must keep this name.
    plot_subgraphs(sampled_subgraph, plot_nth=20)
    nodes, edges = graph_to_nodes_edges(sampled_subgraph)
    # One arrow is created per (edge, label) pair, so the number of arrows is >= the number
    # of edges whenever some edges carry several labels.
    arrows = nodes_edges_to_arrows(nodes, edges)
    cache, matches = model.train(
        data_points = ds.entity_id2vec,
        relations = arrows,
        # NOTE: Labels could be a complete graph, a subgraph from random_walk or a sub-sub-graph used to create a batch
        labels = sampled_subgraph,
        step = True,
        match_negatives=False
    )
    log_results(cache, matches)

In [None]:
# Hand the trained model to save_params (imported from wandb_logger above).
save_params(model)