In [0]:
# Reload edited project modules (e.g. learnings_sampler_v1, models.*) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [0]:
if "DATABRICKS_RUNTIME_VERSION" in os.environ:
  #CUDA = 'cu121' 
  
  import os
  
  !pip install torch==2.1.0  torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
  import torch
  #os.environ['TORCH'] = torch.__version__
  #print(torch.__version__)
  #torch_version = '2.0.0+cu118'
  
  #!pip install pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html # torch_spline_conv
  !pip install torch_geometric
  !pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
  #!pip install torch_sparse -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html
  #!pip install torch_scatter -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html
  #!pip install pyg_lib -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html
  !pip install sentence-transformers
  !pip install torcheval
  !pip install matplotlib
  !pip install pandas
  !pip install tensorboard
  ROOT_FOLDER = 'dbfs:/FileStore/GraphNeuralNetworks/'
else:
  ROOT_FOLDER = ''

In [0]:
from learnings_sampler_v1 import (
    get_datasets,
    get_minibatch_count,
    uniform_hgt_sampler,
)

# Train / validation / test graph splits; edge attributes are not requested.
train_data, val_data, test_data = get_datasets(get_edge_attr=False)


In [0]:
import torch
from models.TransE import TransE
from models.HGT import HGT

class Model(torch.nn.Module):
    """Heterogeneous GNN encoder followed by a knowledge-graph-embedding (KGE) head.

    Args:
        gnn: module called as ``gnn(x_dict, edge_index_dict)`` that returns a dict mapping
            node type -> node embeddings (its output width should be ``ggn_output_dim``).
        head: name of the KGE head; only ``'TransE'`` is implemented.
        node_types: list of node-type names.
        edge_types: list of ``(src, relation, dst)`` triples. Relations whose name starts with
            ``'rev_'`` are dropped; the caller's list is left unchanged.
        ggn_output_dim: output width of ``gnn`` and input width of the head.

    Raises:
        NotImplementedError: for any ``head`` other than ``'TransE'``.
    """

    def __init__(self, gnn : torch.nn.Module, head : str, node_types, edge_types, ggn_output_dim):
        super().__init__()
        # NOTE(review): this embedding is not used in forward() - confirm it is still needed.
        self.node_type_embedding = torch.nn.Embedding(len(node_types), ggn_output_dim)

        # Build a filtered copy. Removing items from a list while iterating over it skips
        # elements, and mutating the caller's list would also alter metadata shared with the GNN.
        self.edge_types = self._forward_edge_types(edge_types)

        # edge type -> relation id consumed by the KGE head
        self.edgeindex_lookup = {edge_type: torch.tensor(i) for i, edge_type in enumerate(self.edge_types)}

        # Device placement is left to `Model.to(...)`, which moves every registered submodule.
        if head == 'TransE':
            self.head = TransE(len(node_types), len(self.edge_types), ggn_output_dim)  # KGE head with loss function
        else:
            raise NotImplementedError

        self.gnn = gnn

    @staticmethod
    def _forward_edge_types(edge_types):
        """Return a new list of ``edge_types`` without reverse relations (name starting with ``'rev_'``)."""
        return [edge_type for edge_type in edge_types if not edge_type[1].startswith('rev_')]

    def forward(self, hetero_data1, target_edge_type, edge_label_index, edge_label, hetero_data2=None):
        """Return the KGE head's loss for the labelled edges of ``target_edge_type``.

        Args:
            hetero_data1: minibatch providing the head (source) nodes, or both endpoints when
                ``hetero_data2`` is None.
            target_edge_type: ``(src_type, relation, dst_type)``; must be a key of ``edgeindex_lookup``.
            edge_label_index: row 0 indexes source nodes, row 1 indexes destination nodes.
            edge_label: labels forwarded unchanged to ``self.head.loss``.
            hetero_data2: optional second minibatch providing the tail nodes; only valid when
                source and destination node types differ.
        """
        src_type, dst_type = target_edge_type[0], target_edge_type[2]
        if hetero_data2 is not None:
            assert src_type != dst_type, 'when passing two data objects, the edge type has to contain two different node types'
            head_embeddings = self.gnn(hetero_data1.x_dict, hetero_data1.edge_index_dict)[src_type][edge_label_index[0, :]]
            tail_embeddings = self.gnn(hetero_data2.x_dict, hetero_data2.edge_index_dict)[dst_type][edge_label_index[1, :]]
        else:
            assert src_type == dst_type, 'when passing one data object, the edge type has to contain the same node types'
            embeddings = self.gnn(hetero_data1.x_dict, hetero_data1.edge_index_dict)
            head_embeddings = embeddings[src_type][edge_label_index[0, :]]
            tail_embeddings = embeddings[dst_type][edge_label_index[1, :]]

        # The lookup tensors are plain attributes (not buffers), so move the relation id to this
        # module's own device explicitly.
        edgeindex = self.edgeindex_lookup[target_edge_type].to(next(self.parameters()).device)
        return self.head.loss(head_embeddings, edgeindex, tail_embeddings, edge_label)
    
        
# Graph schema as (node_types, edge_types); add an explicit (t, 'self_loop', t) relation for
# every node type before building the GNN.
metadata = train_data.metadata()
metadata[1].extend((nt, 'self_loop', nt) for nt in train_data.node_types)

out_channels = 64  # GNN embedding width; also the KGE head's input width
gnn = HGT(
    hidden_channels=64,
    out_channels=out_channels,
    num_heads=2,
    num_layers=2,
    node_types=train_data.node_types,
    data_metadata=metadata,
)

model = Model(gnn, head='TransE', node_types=metadata[0], edge_types=metadata[1], ggn_output_dim=out_channels)
model.to('cuda:0')

In [0]:
# List the visible CUDA devices; the list is empty (instead of raising) on a CPU-only machine.
import torch

print(torch.cuda.device_count())
device_names = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
device_names

In [0]:
# Re-apply device placement (idempotent); fall back to CPU when CUDA is unavailable.
model.to(torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'))


In [0]:
# Smoke test: run the HGT + TransE model over the training sampler

In [0]:
# Sanity check: dtype and value range of every node type's feature matrix.
for nt in train_data.node_types:
    feats = train_data[nt].x
    print(nt, feats.dtype)
    print(feats.min(), feats.max())

In [0]:
# One pass over the training sampler with the HGT + TransE model, printing the loss per step.
batch_size = 32
num_relationships = len(train_data.edge_types)
# Neighbour budgets of 25 (1st hop) and 25*10 (2nd hop) per seed, split evenly across relation types.
one_hop_neighbors = (25 * batch_size) // num_relationships  # per relationship type
two_hop_neighbors = (25 * 10 * batch_size) // num_relationships  # per relationship type
num_neighbors = [one_hop_neighbors, two_hop_neighbors]
print('num_neighbors', num_neighbors)
device = next(model.parameters()).device  # send minibatches wherever the model lives
sampler = uniform_hgt_sampler(train_data, batch_size, True, 'binary', 1, num_neighbors)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-7)
model.train()

for i, (same_nodetype, target_edge_type, batch) in enumerate(sampler):
    optimizer.zero_grad()
    print(target_edge_type)
    # Batching differs depending on whether the edge's endpoints share a node type.
    if same_nodetype:
        minibatch, edge_label_index, edge_label, input_edge_ids = batch
        loss = model(minibatch.to(device), target_edge_type, edge_label_index.to(device), edge_label.to(device))
    else:
        minibatchpart1, minibatchpart2, edge_label_index, edge_label, input_edge_id = batch
        loss = model(minibatchpart1.to(device), target_edge_type, edge_label_index.to(device),
                     edge_label.to(device), minibatchpart2.to(device))

    loss.backward()  # without this, optimizer.step() has no gradients to apply
    optimizer.step()
    print(loss.item())
    

    
    

In [0]:
import os
import torch
from torch_geometric.data import HeteroData


from typing import Tuple, List
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.sampler import NegativeSampling
from torch_geometric.data import HeteroData
import gc
import multiprocessing as mp

# Shared-memory usage can be monitored from a shell with: watch -n 1 df -h /dev/shm
# Free unreachable objects left over from earlier cells before building the trainer.
gc.collect()


from torcheval.metrics import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score, BinaryAUPRC, BinaryAUROC
from tqdm.auto import tqdm
import gc
from torch.utils.tensorboard import SummaryWriter
import os
from pathlib import Path
from datetime import datetime
import torch
import numpy as np
from sklearn.metrics import f1_score


class GNNTrainer():
    """Train/validate loop for a heterogeneous link-prediction GNN with TensorBoard logging.

    Inputs, as used below:
      * ``train_iterator`` / ``val_iterator`` are zero-argument callables returning an iterable
        of minibatches. Each minibatch is an iterable of ``(supervision_edge_type, batch)`` pairs.
      * ``batch`` exposes ``x_dict``, ``edge_index_dict``, ``edge_weight_dict``,
        ``num_sampled_edges_dict``, ``num_sampled_nodes_dict`` and, per supervision edge type,
        ``edge_label`` and ``edge_label_index``.
      * ``criterion(src_embeddings, dst_embeddings, edge_label)`` returns ``(loss, y_pred)``.
        ``y_pred`` is presumably a probability, since it is thresholded in (0, 1) for F1.
    """

    # Key of the aggregate (all edge types) entry in the per-edge-type metric dicts.
    _ALL_EDGES = ('', 'all', '')

    def __init__(self, model, criterion, optimizer, device, log_folder):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.train_batch_size = 0  # for tqdm, for logging
        self.train_n_mini_in_batch = 0  # for tqdm
        self.val_n_mini_in_batch = 0  # for tqdm
        self.log_folder = log_folder
        self.writer = None  # SummaryWriter, created by create_logfolders()
        self.checkpoint_folder = None  # created by create_logfolders()

    def free_memory(self):
        """Clears the GPU cache and triggers garbage collection, to reduce OOMs."""
        torch.cuda.empty_cache()
        gc.collect()

    def _run_model(self, batch):
        """Forward ``batch`` through ``self.model``; returns the per-node-type output dict."""
        return self.model(batch.x_dict, batch.edge_index_dict, batch.edge_weight_dict,
                          batch.num_sampled_edges_dict, batch.num_sampled_nodes_dict)

    @staticmethod
    def _is_metrics_step(batch_idx, every_n_batches):
        """True for 1-based minibatch 1, n+1, 2n+1, ...; i.e. every minibatch when n == 1."""
        return (batch_idx - 1) % every_n_batches == 0

    @staticmethod
    def _best_f1_threshold(y_hat, y, thresholds=None):
        """Return ``(threshold, f1)`` maximising binary F1 of ``y_hat > threshold`` against labels ``y``."""
        if thresholds is None:
            thresholds = np.arange(0.001, 1, 0.001)
        # sklearn's signature is f1_score(y_true, y_pred).
        f1s = np.array([f1_score(y, y_hat > threshold, average='binary') for threshold in thresholds])
        best = int(np.argmax(f1s))
        return float(thresholds[best]), float(f1s[best])

    def write_calc_metrics(self, split_name:str, y_hat, y, y_per_edgetype, y_hat_per_edgetype, epoch:int, is_epoch:bool, minibatch:int=0, loss:int=0, loss_per_edgetype=None, print_=True):
        """Compute binary link-prediction metrics and write them to TensorBoard.

        Writes the following per edge type, plus an aggregate entry ``('', 'all', '')`` built
        from ``y_hat`` / ``y``:
          * the best-F1 threshold and the F1 at it;
          * accuracy, precision and recall at that threshold;
          * AUPRC and AUROC;
          * for training only, the loss.

        Args:
            split_name: ``'train'`` or another split name (e.g. ``'val'``).
            y_hat, y: predicted scores and binary labels over all edge types.
            y_per_edgetype, y_hat_per_edgetype: the same, keyed by edge type. They are not
                modified.
            epoch: 1-based epoch.
            is_epoch: if True, log against the epoch number (``epoch_*`` tags). If False, log
                against the approximate number of training samples seen (``samples_*`` tags).
                Epoch-level training logs report only the aggregate entry.
            minibatch: 1-based minibatch index within the epoch.
            loss: aggregate training loss. Must be non-zero for ``'train'``.
            loss_per_edgetype: per-edge-type training losses. Required for ``'train'``; must
                be ``None`` otherwise.
            print_: print a one-line summary of the aggregate metrics.
        """
        assert epoch >= 1, 'Epoch must be >= 1'
        assert minibatch >= 1, "minibatch must be >=1"

        is_train = split_name == 'train'
        if is_train:
            assert self.train_batch_size != 0
            assert self.train_n_mini_in_batch != 0
            assert loss != 0, "loss can't be 0"
            if is_epoch:
                loss_per_edgetype, y_per_edgetype, y_hat_per_edgetype = {}, {}, {}
            assert loss_per_edgetype is not None
        else:
            assert loss_per_edgetype is None

        if is_epoch:
            split_name = 'epoch_' + split_name
            index = epoch
        else:
            split_name = 'samples_' + split_name
            no_minibatches = (epoch - 1) * self.train_n_mini_in_batch + minibatch
            index = no_minibatches * self.train_batch_size  # approximate number of samples seen

        # Work on copies so the aggregate entry is not written into the caller's dicts.
        y_per_edgetype = {**y_per_edgetype, self._ALL_EDGES: y}
        y_hat_per_edgetype = {**y_hat_per_edgetype, self._ALL_EDGES: y_hat}
        if is_train:
            loss_per_edgetype = {**loss_per_edgetype, self._ALL_EDGES: loss}

        for edgetype, labels in y_per_edgetype.items():
            labels = labels.to(torch.int).detach().cpu()
            scores = y_hat_per_edgetype[edgetype].detach().cpu()
            edge_name = '-'.join(edgetype).strip('-')

            optimal_threshold, f1 = self._best_f1_threshold(scores.numpy(), labels.numpy())
            acc = BinaryAccuracy(threshold=optimal_threshold).update(scores, labels).compute().item()
            prec = BinaryPrecision(threshold=optimal_threshold).update(scores, labels).compute().item()
            rec = BinaryRecall(threshold=optimal_threshold).update(scores, labels).compute().item()
            auprc = BinaryAUPRC().update(scores, labels).compute().item()
            auroc = BinaryAUROC().update(scores, labels).compute().item()

            scalars = {
                'f1threshold': optimal_threshold,
                'accuracy': acc,
                'precision': prec,
                'recall': rec,
                'f1': f1,
                'auprc': auprc,
                'auroc': auroc,
            }
            if is_train:
                scalars['loss'] = loss_per_edgetype[edgetype]
            for metric_name, value in scalars.items():
                self.writer.add_scalar(f'{split_name}_{metric_name}_{edge_name}', value, index)

            if print_ and edgetype == self._ALL_EDGES:
                out_of = f'/{self.train_n_mini_in_batch:06d}' if is_train else ''
                no_samples = f'|samples:{index}' if is_train else ''
                loss_to_show = f'loss:{loss:.4f},' if is_train else ''
                print(f'{split_name}|{int(minibatch):04d}{out_of}|{epoch:04d}{no_samples}|{loss_to_show} F1: {f1:.6f}, AUC-PR: {auprc:.6f}, (auroc: {auroc:.6f}, acc: {acc:.6f}, prec: {prec:.6f}, rec: {rec:.6f})')
        self.writer.flush()

    def create_logfolders(self, run_folder=None):
        """Create ``<log_folder>/<run_folder>_tensorboard`` and ``..._checkpoints`` and open a SummaryWriter.

        ``run_folder`` defaults to a timestamp (``run_%d%m%Y_%H%M%S``). A writer left open by a
        previous call is closed first.
        """
        if run_folder is None:
            run_folder = datetime.now().strftime('run_%d%m%Y_%H%M%S')

        log_root = Path(self.log_folder)
        tensorboard_folder = log_root / (run_folder + '_tensorboard')
        self.checkpoint_folder = log_root / (run_folder + '_checkpoints')
        self.checkpoint_folder.mkdir(parents=True, exist_ok=True)
        tensorboard_folder.mkdir(parents=True, exist_ok=True)

        if getattr(self, 'writer', None) is not None:
            self.writer.close()
        self.writer = SummaryWriter(log_dir=tensorboard_folder)
        print(f'run folder is {run_folder}')

    def train(self, train_iterator, val_iterator, start_epoch, n_epochs, run_folder=None, save_metrics_after_n_batches=100):
        """Run ``n_epochs`` epochs starting at the 1-based ``start_epoch``.

        Each minibatch sums the criterion loss over its edge-type batches and takes one
        optimizer step.

        Every ``save_metrics_after_n_batches`` minibatches (starting with the first), that
        minibatch's train metrics are logged and a full validation pass is run.

        After each epoch:
          * a checkpoint is saved;
          * epoch-level metrics are logged;
          * validation runs again.

        Raises:
            RuntimeError: if ``train_iterator()`` yields no non-empty minibatch in an epoch.
        """
        assert start_epoch >= 1, "Epoch must be >= 1"
        self.free_memory()
        self.create_logfolders(run_folder)

        self.model.train()

        print(f'Number of parameters: {sum(p.numel() for p in self.model.parameters())}')
        print(f'Number of learnable parameters: {sum(p.numel() for p in self.model.parameters()  if p.requires_grad)}')

        for epoch in range(start_epoch, start_epoch + n_epochs):
            epoch_loss = 0
            batch_idx = 0
            last_y_hat = last_y = None
            batches = tqdm(train_iterator(), total=self.train_n_mini_in_batch, desc='train epoch')
            for batch_idx, edge_batches in enumerate(batches, start=1):
                self.optimizer.zero_grad()  # empty gradients
                minibatch_loss = 0
                loss_per_edgetype = {}
                y_hat, y = [], []
                y_hat_per_edgetype, y_per_edgetype = {}, {}
                # Each "batch" is one edge type, so every minibatch learns all edge types.
                for supervision_edge_type, batch in edge_batches:
                    batch = batch.to(self.device)
                    hetero_out = self._run_model(batch)

                    src_type, dst_type = supervision_edge_type[0], supervision_edge_type[2]
                    edge_label = batch[supervision_edge_type].edge_label
                    edge_label_index = batch[supervision_edge_type].edge_label_index
                    src_node_embeddings = hetero_out[src_type][edge_label_index[0]]
                    dst_node_embeddings = hetero_out[dst_type][edge_label_index[1]]

                    loss, y_pred = self.criterion(src_node_embeddings, dst_node_embeddings, edge_label)
                    minibatch_loss = minibatch_loss + loss

                    # collect data for metrics
                    loss_per_edgetype[supervision_edge_type] = loss.detach().item()
                    y_hat_per_edgetype[supervision_edge_type] = y_pred.detach().cpu()
                    y_per_edgetype[supervision_edge_type] = edge_label.to(torch.int).detach().cpu()
                    y_hat.append(y_hat_per_edgetype[supervision_edge_type])
                    y.append(y_per_edgetype[supervision_edge_type])

                if not y_hat:
                    continue  # empty minibatch: no loss tensor to back-propagate

                minibatch_loss.backward()
                self.optimizer.step()
                minibatch_loss = minibatch_loss.detach().item()
                epoch_loss += minibatch_loss

                last_y_hat, last_y = torch.cat(y_hat), torch.cat(y)
                if self._is_metrics_step(batch_idx, save_metrics_after_n_batches):
                    self.write_calc_metrics('train', last_y_hat, last_y, y_per_edgetype, y_hat_per_edgetype, epoch=epoch, minibatch=batch_idx, loss=minibatch_loss, loss_per_edgetype=loss_per_edgetype, is_epoch=False, print_=False)
                    self.validate(val_iterator, epoch, batch_idx, is_epoch=False)
                    self.model.train()  # validate() leaves the model in eval mode

            if last_y_hat is None:
                raise RuntimeError(f'train_iterator yielded no non-empty minibatches in epoch {epoch}')

            self.save_checkpoint(epoch, batch_idx)
            # NOTE(review): epoch-level train metrics are computed from the LAST minibatch's
            # predictions only (the loss is the epoch sum) - confirm this is intended.
            self.write_calc_metrics('train', last_y_hat, last_y, y_per_edgetype={}, y_hat_per_edgetype={}, loss_per_edgetype={}, epoch=epoch, minibatch=batch_idx, loss=epoch_loss, is_epoch=True, print_=True)
            self.validate(val_iterator, epoch, batch_idx, is_epoch=True)
            self.model.train()  # validate() leaves the model in eval mode

    def validate(self, val_iterator, epoch, batch_idx, is_epoch):
        """Evaluate over the whole validation iterator and log metrics per edge type and overall.

        Edge-type batches whose model output lacks the source or destination node type are
        reported and skipped. If nothing could be evaluated, no metrics are written. The
        model is left in eval mode.
        """
        self.model.eval()
        y_hat_per_edgetype, y_per_edgetype = {}, {}
        with torch.no_grad():
            for edge_batches in val_iterator():
                for supervision_edge_type, batch in edge_batches:
                    batch = batch.to(self.device)
                    hetero_out = self._run_model(batch)

                    src_type, dst_type = supervision_edge_type[0], supervision_edge_type[2]
                    edge_label = batch[supervision_edge_type].edge_label
                    edge_label_index = batch[supervision_edge_type].edge_label_index
                    if src_type not in hetero_out.keys() or dst_type not in hetero_out.keys():
                        print('eval failed on one minibatch part, skipping')
                        print('Supervision edge type:', supervision_edge_type)
                        print('one type is missing in model output', src_type, dst_type)
                        print(batch.x_dict)
                        print(hetero_out.keys())
                        print(batch)
                        continue

                    src_node_embeddings = hetero_out[src_type][edge_label_index[0]]
                    dst_node_embeddings = hetero_out[dst_type][edge_label_index[1]]
                    _, y_pred = self.criterion(src_node_embeddings, dst_node_embeddings, edge_label)

                    y_hat_per_edgetype.setdefault(supervision_edge_type, []).append(y_pred.detach().cpu())
                    y_per_edgetype.setdefault(supervision_edge_type, []).append(edge_label.to(torch.int).detach().cpu())

        if not y_hat_per_edgetype:
            print('validation produced no predictions, skipping metrics')
            return

        y_hat_per_edgetype = {key: torch.cat(parts) for key, parts in y_hat_per_edgetype.items()}
        y_per_edgetype = {key: torch.cat(parts) for key, parts in y_per_edgetype.items()}
        y_hat = torch.cat(list(y_hat_per_edgetype.values()))
        y = torch.cat(list(y_per_edgetype.values()))
        self.write_calc_metrics('val', y_hat, y, y_per_edgetype, y_hat_per_edgetype, epoch=epoch, is_epoch=is_epoch, print_=True, minibatch=batch_idx)

    def save_checkpoint(self, epoch, batch_idx):
        """Save model/optimizer state to ``<checkpoint_folder>/checkpoint_ep{epoch}.pt`` and return that path."""
        checkpoint_path = Path(self.checkpoint_folder) / f'checkpoint_ep{epoch}.pt'
        print(f'save checkpoint {checkpoint_path}')
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epoch': epoch,
            'batch_idx': batch_idx,
        }, checkpoint_path)
        return checkpoint_path

    def load_checkpoint(self, load_path):
        """Restore model and optimizer state saved by ``save_checkpoint``, mapping tensors onto ``self.device``."""
        checkpoint = torch.load(load_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        
        
import os
import torch


# from models.WeightedSkillGAT import weightedSkillGAT_lr_2emin6_16hiddenchannels_8heads_128out_2layers_edgeweights_checkpoints
# model = weightedSkillGAT_lr_2emin6_16hiddenchannels_8heads_128out_2layers_edgeweights_checkpoints()
from models.WeightedSkillSAGE import weightedSkillSAGE_lr_2emin7_1lin_1lin_256dim_edgeweight_noskillskillpred_checkpoints
from models.WeightedSkillSAGE import weightedSkillSAGE_lr_2emin7_0lin_256dim_edgeweight_prelu_batchnorm_checkpoints
from models.WeightedSkillSAGE import weightedSkillSAGE_lr_2emin7_0lin_132dim_edgeweight_prelu_batchnorm_checkpoints
from models.WeightedSkillSAGE import skillsage_388_prelu_batchnorm_edgeweight
# Model used for this run (other SAGE variants are imported above for quick swapping).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = skillsage_388_prelu_batchnorm_edgeweight().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-7)
def graphSAGE_loss(u, v, y_label):
    """GraphSAGE-style link loss over paired endpoint embeddings.

    Args:
        u, v: endpoint embeddings. Row i of ``u`` pairs with row i of ``v``, and the dot
            product is taken over the last dimension.
        y_label: 1 for positive edges, 0 for negative edges.

    Returns:
        ``(loss, y_hat)``:
          * ``loss = -sum_i log(sigmoid(s_i * <u_i, v_i>))``, summed over all examples, with
            ``s_i = +1`` for positives and ``-1`` for negatives;
          * ``y_hat = sigmoid(<u_i, v_i>)``, detached and used only for metrics.
    """
    # Map labels {0, 1} -> signs {-1, +1}.
    signs = (y_label - 1 + y_label).squeeze()
    logits = torch.sum(torch.mul(u, v), dim=-1)
    # logsigmoid is the numerically stable form of log(sigmoid(x)); the naive form
    # underflows to log(0) = -inf for large negative arguments.
    loss = -torch.sum(torch.nn.functional.logsigmoid(torch.mul(logits, signs)))
    y_hat = torch.sigmoid(logits)
    return loss, y_hat.detach()

# Fail fast if the data-loading names this cell needs have not been defined. None of them is
# assigned in the cells shown above, so they must come from another cell.
_missing = [name for name in ('train_iterator', 'val_iterator', 'train_batch_len', 'val_batch_len')
            if name not in globals()]
if _missing:
    raise NameError(f'Define {", ".join(_missing)} (e.g. in a data-loading cell) before starting training.')

criterion = graphSAGE_loss
trainer = GNNTrainer(model, criterion, optimizer, device, log_folder='runs')

# Sizes used for the tqdm progress bars and the samples-seen logging index.
trainer.train_batch_size = batch_size
trainer.train_n_mini_in_batch = train_batch_len
trainer.val_n_mini_in_batch = val_batch_len

trainer.train(
    train_iterator,
    val_iterator,
    start_epoch=1,
    n_epochs=200,
    run_folder='skillsage_388_prelu_batchnorm_edgeweight_jsssjj_fulldsv2',
    save_metrics_after_n_batches=1000,
)



In [0]:
# `datetime` is the class here (as imported earlier in the notebook), so call datetime.now().
from datetime import datetime

# Time how long the sampler takes to produce its first minibatch, and show that minibatch.
batch_size = 32
num_relationships = len(train_data.edge_types)
one_hop_neighbors = (25 * batch_size) // num_relationships  # per relationship type
two_hop_neighbors = (25 * 10 * batch_size) // num_relationships  # per relationship type
num_neighbors = [one_hop_neighbors, two_hop_neighbors]
print('num_neighbors', num_neighbors)

sampler = uniform_hgt_sampler(train_data, batch_size, True, 'binary', 1, num_neighbors)
start = datetime.now()
print(start)
print()
for i, (same_nodetype, target_edge_type, batch) in enumerate(sampler):
    # Batching differs depending on whether the edge's endpoints share a node type.
    if same_nodetype:
        minibatch, edge_label_index, edge_label, input_edge_ids = batch
        print(minibatch)
    else:
        minibatchpart1, minibatchpart2, edge_label_index, edge_label, input_edge_id = batch
        print(minibatchpart1)

    print(i, target_edge_type)
    break  # only the first minibatch is inspected

end = datetime.now()
print()
print(end - start)