In [None]:
import sys
from pathlib import Path

# Environment detection: outside Google Colab there is nothing to mount and
# paths are relative; inside Colab, Google Drive is mounted and the project
# folder on Drive is exposed as DRIVE_FOLDER.
if 'google.colab' not in sys.modules:
    colab_path = ''
    colab = False
else:
    from google.colab import drive
    colab_path = '/content/'
    # force_remount=True re-mounts Drive even if a mount already exists
    drive.mount('/content/drive', force_remount=True)
    DRIVE_FOLDER = Path('/content/drive/MyDrive/DataExplorationProject/Skill_Ontology_GNN')
    colab = True
  
  
import torch
from matplotlib import pyplot as plt
import gc

class Trainer:
    """Minimal training harness around a model, a loss and an optimizer.

    Attributes:
        model: the ``torch.nn.Module`` being trained.
        criterion: callable ``criterion(output, target) -> loss``.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device every batch is moved to before the forward pass.
        metrics_history: ``{split: {metric: [values]}}`` for the splits
            ``'train'`` and ``'val'``.
        train_losses / val_losses: flat per-mini-batch loss lists used by
            ``plot_losses``.
        epoch: index of the next epoch to run; advanced by ``train`` so that
            repeated calls continue where the previous one stopped.
    """

    def __init__(self, model, criterion, optimizer, device, metrics=None):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.metrics_history = self.create_metrics_history(metrics)
        # These lists are appended to by train()/validate() and read by
        # plot_losses(); they must exist before any of those is called.
        self.train_losses = []
        self.val_losses = []
        self.epoch = 0

    def create_metrics_history(self, metrics):
        """Build the empty ``{split: {metric: []}}`` history structure.

        ``'epoch'``, ``'minibatch'``, ``'accuracy'`` and ``'loss'`` are always
        tracked; names in ``metrics`` (may be ``None``) are added on top,
        without duplicates.
        """
        tracked = ['epoch', 'minibatch', 'accuracy', 'loss']
        for metric in (metrics or []):
            if metric not in tracked:
                tracked.append(metric)
        return {
            split: {metric: [] for metric in tracked}
            for split in ('train', 'val')
        }

    def free_memory(self):
        """Triggers garbage collection and clears the GPU cache, to reduce OOMs."""
        # Collect first: only tensors that are already unreferenced can be
        # released from the CUDA caching allocator.
        gc.collect()
        torch.cuda.empty_cache()

    def train(self, dataloader, n_epochs, save_interval, save_path):
        """Train for ``n_epochs`` epochs, starting at ``self.epoch``.

        Args:
            dataloader: iterable yielding ``(data, target)`` tensor pairs.
            n_epochs: number of epochs to run.
            save_interval: a checkpoint is written every ``save_interval``
                mini-batches (including batch 0); a falsy value disables
                checkpointing.
            save_path: directory that receives the checkpoint files.
        """
        self.free_memory()
        self.model.train()
        first_epoch = self.epoch
        for epoch in range(first_epoch, first_epoch + n_epochs):
            print(f'=============== Epoch {epoch} ===============')
            for batch_idx, (data, target) in enumerate(dataloader):
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()

                loss_value = loss.item()
                self.train_losses.append(loss_value)
                # Mirror the loss into metrics_history, which is what the
                # checkpoints persist.
                history = self.metrics_history['train']
                history['epoch'].append(epoch)
                history['minibatch'].append(batch_idx)
                history['loss'].append(loss_value)

                if save_interval and batch_idx % save_interval == 0:
                    self.save_checkpoint(epoch, batch_idx, save_path)

                print(f'Mini-Batch {batch_idx}, Loss: {loss_value}')
            # Remember progress so a later train() call resumes after this epoch.
            self.epoch = epoch + 1

    def validate(self, dataloader):
        """Compute the loss on every batch of ``dataloader`` without gradients.

        The model is switched to eval mode for the duration of the call and
        returned to its previous mode afterwards.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for data, target in dataloader:
                    data, target = data.to(self.device), target.to(self.device)
                    output = self.model(data)
                    loss_value = self.criterion(output, target).item()
                    self.val_losses.append(loss_value)
                    self.metrics_history['val']['loss'].append(loss_value)
        finally:
            self.model.train(was_training)

    def save_checkpoint(self, epoch, batch_idx, save_path):
        """Write model/optimizer state and the metric history to
        ``{save_path}/checkpoint_{epoch}_{batch_idx}.pt``."""
        print('save')
        # torch.save does not create missing directories.
        Path(save_path).mkdir(parents=True, exist_ok=True)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'metrics_history': self.metrics_history,
            'epoch': epoch,
            'batch_idx': batch_idx,
        }, f'{save_path}/checkpoint_{epoch}_{batch_idx}.pt')

    def load_checkpoint(self, load_path):
        """Restore model, optimizer and (when present) metric history/epoch.

        Tensors are mapped onto ``self.device`` so a checkpoint written on a
        GPU machine can be loaded on a CPU-only one.
        """
        checkpoint = torch.load(load_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Older checkpoints may lack these keys; keep current values then.
        self.metrics_history = checkpoint.get('metrics_history', self.metrics_history)
        self.epoch = checkpoint.get('epoch', self.epoch)

    def plot_losses(self):
        """Plot the per-mini-batch training and validation losses."""
        plt.figure(figsize=(10,5))
        plt.title("Training and Validation Loss")
        plt.plot(self.train_losses,label="train")
        plt.plot(self.val_losses,label="val")
        plt.xlabel("iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()
        
        
# Colab-only environment bootstrap: installs the Python dependencies with pip
# and unpacks the dataset archive from the mounted Drive. Skipped entirely when
# `colab` is False.
if colab:
    # Install required packages.
    import os
    import torch
    # Export the torch version that is currently imported into the TORCH env var.
    os.environ['TORCH'] = torch.__version__
    print(torch.__version__)
    # Earlier install variant (kept for reference): per-package wheels resolved
    # against the running torch version.
    # !pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
    # !pip install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
    # !pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html
    # !pip install git+https://github.com/pyg-team/pytorch_geometric.git
    # NOTE(review): torch is installed unpinned from the cu118 index, while the
    # PyG extension wheels on the next line are pinned to torch-2.0.0+cu118 —
    # confirm that both resolve to compatible versions.
    # NOTE(review): torch was already imported above, so a version installed
    # here presumably only takes effect after a runtime restart — verify.
    !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
    !pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
    !pip install torch_geometric
    !pip install sentence-transformers
    !pip install torcheval
    !pip install matplotlib
    !pip install pandas
    # unpack datasets
    # The `unzipped` global acts as a per-kernel guard so re-running this cell
    # does not unzip the archive a second time.
    if not 'unzipped' in globals():
        !unzip /content/drive/MyDrive/DataExplorationProject/Skill_Ontology_GNN/neo4jgraph.zip
        unzipped =True
        

In [None]:
import os
import torch


# Load the cached HeteroData graph if it exists on disk; otherwise persist the
# `data` object built earlier in this kernel session so later runs can reuse it.
filename = 'Job_Skill_HeteroData_v1.pt'
data_file = './' + filename
if os.path.exists(data_file):
    data = torch.load(data_file)
    print('loading saved heterodata object')
elif 'data' in globals():
    torch.save(data, data_file)
else:
    # On a fresh kernel neither the cache file nor an in-memory graph exists;
    # fail with an actionable message instead of a NameError on `data`.
    raise FileNotFoundError(
        f"{data_file} not found and no in-memory `data` object to save; "
        "build the HeteroData graph first or place the cached file next to the notebook."
    )

In [None]:
from torch_geometric import seed_everything
import torch_geometric.transforms as T
# Fixed seed so the random link split below is reproducible.
seed_everything(4)

# Relations that are split into train/val/test supervision edges ...
split_edge_types = [
    ('Job', 'REQUIRES', 'Skill'),
    ('Skill', 'IS_SIMILAR_SKILL', 'Skill'),
    ('Job', 'IS_SIMILAR_JOB', 'Job'),
]
# ... and, in the same order, their reverse relations.
split_rev_edge_types = [
    ('Skill', 'rev_REQUIRES', 'Job'),
    ('Skill', 'rev_IS_SIMILAR_SKILL', 'Skill'),
    ('Job', 'rev_IS_SIMILAR_JOB', 'Job'),
]

transform = T.RandomLinkSplit(
    num_val=0.008,
    num_test=0.80,
    is_undirected=True,
    edge_types=split_edge_types,
    rev_edge_types=split_rev_edge_types,
    # Negative samples are only added for val and test here; negative training
    # samples are added by LinkNeighborLoader. This means the negatives differ
    # for each train batch, while for val and test they stay the same.
    add_negative_train_samples=False,
    neg_sampling_ratio=1.0,
    # Training edges are shared for message passing and supervision.
    disjoint_train_ratio=0,
)
train_data, val_data, test_data = transform(data)

In [None]:
from typing import Tuple, List, Union
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.sampler import NegativeSampling
from torch_geometric.data import HeteroData

def create_loader(data:HeteroData, edge_type:Tuple[str,str,str], num_neighbors:List[int], negative_sampling_amount:int, batch_size:int, is_training:bool)->LinkNeighborLoader:
    """Build a ``LinkNeighborLoader`` whose seed edges are ``edge_type``.

    Args:
        data: a split produced by ``RandomLinkSplit``; ``data[edge_type]``
            must carry ``edge_label_index``.
        edge_type: ``(src, relation, dst)`` triple used as supervision target.
        num_neighbors: passed through as the loader's ``num_neighbors``.
        negative_sampling_amount: negatives per positive edge when the loader
            samples negatives itself (ratio, like GraphSAGE).
        batch_size: number of seed edges per mini-batch.
        is_training: shuffles the seed edges and enables on-the-fly negative
            sampling.

    Returns:
        The configured ``LinkNeighborLoader``.

    For training, negatives are re-sampled for every batch. For evaluation
    splits that already carry ``edge_label`` (positives *and* fixed negatives
    added by ``RandomLinkSplit``), those labels are forwarded and no further
    negatives are sampled. Omitting ``edge_label`` would make the loader treat
    every seed edge, including the pre-sampled negatives, as a positive.
    ``negative_sampling_amount`` is therefore only used when the loader samples
    negatives itself.
    """
    edge_store = data[edge_type]
    has_fixed_labels = 'edge_label' in edge_store.keys()

    if is_training or not has_fixed_labels:
        edge_label = None
        negative_sampling = NegativeSampling(
            mode='binary',
            amount=negative_sampling_amount,  # ratio, like GraphSAGE
        )
    else:
        edge_label = edge_store.edge_label
        negative_sampling = None

    loader = LinkNeighborLoader(
        data,
        num_neighbors=num_neighbors,
        edge_label_index=(edge_type, edge_store.edge_label_index),
        edge_label=edge_label,
        neg_sampling=negative_sampling,
        batch_size=batch_size,
        shuffle=is_training,
        directed=True,  # contains only edges which are followed, False: contains full node induced subgraph
        pin_memory=True,  # faster data transfer to gpu
    )

    return loader


# Loader hyper-parameters, read as module-level globals by create_iterator().
# batch_size: number of seed (supervision) edges per mini-batch.
batch_size=64
# num_neighbors: forwarded to LinkNeighborLoader; presumably one fan-out per
# sampling hop (5 neighbours in the first hop, 4 in the second) — confirm
# against the PyG documentation.
num_neighbors = [5,4]

def create_iterator(data, is_training:bool):
    """Create a re-usable batch iterator over the supervision edge types of ``data``.

    One ``LinkNeighborLoader`` is built per target edge type, because a loader
    only supports a single target edge type. Currently only the first edge
    type of ``data`` is used as target; ``rev_`` relations are skipped since
    they duplicate their forward relation and are only needed for message
    passing.

    Args:
        data: the HeteroData split to draw mini-batches from.
        is_training: forwarded to ``create_loader``; also selects the
            negative sampling amount (20 for training, 1 otherwise).

    Returns:
        ``(iterator, n_batches)``. ``iterator`` is a zero-argument generator
        function; each call runs one pass over the longest loader and yields
        tuples holding one batch per loader. Shorter loaders are restarted
        when exhausted. ``n_batches`` is the length of the longest loader.

    Raises:
        ValueError: if no loader could be created (no usable edge type).
    """
    # Read the edge types from the split that was passed in rather than from
    # the global `train_data`, so the function does not depend on hidden state.
    target_edge_types = [
        edge_type
        for edge_type in [data.edge_types[0]]
        if 'rev_' not in edge_type[1]
    ]

    loaders = [
        create_loader(
            data=data,
            edge_type=edge_type,
            num_neighbors=num_neighbors,
            batch_size=batch_size,
            is_training=is_training,
            negative_sampling_amount=(20 if is_training else 1),
        )
        for edge_type in target_edge_types
    ]
    if not loaders:
        raise ValueError('create_iterator: no supervision edge type available to build a loader')

    # The longest loader drives the epoch; max() keeps the first one on ties.
    longest_index = max(range(len(loaders)), key=lambda i: len(loaders[i]))
    longest_loader = loaders.pop(longest_index)

    def iterator():
        # Fresh iterators per pass so every epoch starts each shorter loader
        # from its beginning instead of continuing where the last pass stopped.
        iterators = [iter(loader) for loader in loaders]
        for batch in longest_loader:
            batches = [batch]
            for i, loader in enumerate(loaders):
                try:
                    batches.append(next(iterators[i]))
                except StopIteration:
                    # shorter loader exhausted: restart it
                    iterators[i] = iter(loader)
                    batches.append(next(iterators[i]))
            yield tuple(batches)

    return iterator, len(longest_loader)
    
    

# Build one batch iterator per split. Each call returns a generator function
# plus the number of batches per pass (used later for progress reporting).
# Shell hint for monitoring shared-memory usage while the loaders run:
# watch -n 1 df -h /dev/shm
gc.collect()
train_iterator, train_batch_len = create_iterator(train_data, is_training=True)
val_iterator, val_batch_len = create_iterator(val_data, is_training=False)
test_iterator, test_batch_len = create_iterator(test_data, is_training=False)

In [None]:
from typing import Tuple, Union
from torch import Tensor
from torch_geometric.nn import to_hetero, HeteroDictLinear, Linear
from torch_geometric.nn.conv import GraphConv, SAGEConv, SimpleConv
import torch.nn.functional as F
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size




# PyG does not implement the exact max pooling aggregation as in the GraphSage paper
# with GraphConvWithPool we manually extend it by adding a linear layer on x before .propagate
# as our activation function is monotonically increasing, this modification corresponds to the max pooling aggregation

class GraphConvWithPool(GraphConv):
    """``GraphConv`` with an extra bias-free linear map applied to the source
    node features before they are propagated and aggregated.

    Args:
        in_channels: feature size, or a ``(source, target)`` pair of sizes for
            bipartite message passing.
        out_channels: output feature size.
        aggr: aggregation scheme handed to ``GraphConv`` (default ``'add'``).
            NOTE(review): the pooling aggregation described in the comment
            above this class presumably needs ``aggr='max'`` — confirm at the
            call site.
        bias: forwarded to ``GraphConv``.
        **kwargs: forwarded to ``GraphConv``.
    """

    def __init__(self, in_channels, out_channels: int, aggr: str = 'add', bias: bool = True, **kwargs):
        super().__init__(in_channels, out_channels, aggr, bias, **kwargs)
        # Only source-node features are transformed, so size the layer by the
        # source dimension when a (source, target) pair is given.
        src_channels = in_channels[0] if isinstance(in_channels, (tuple, list)) else in_channels
        self.linear = torch.nn.Linear(src_channels, src_channels, bias=False)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_weight: OptTensor = None, size: Size = None) -> Tensor:
        """Transform source features, aggregate them over ``edge_index`` and
        add the root (target-node) term.

        ``x`` may be a single tensor (same features for source and target) or
        a ``(x_source, x_target)`` pair; ``x_target`` may be ``None``.
        """
        if isinstance(x, Tensor):
            x = (x, x)

        # Apply the extra linear layer to the *source* features only. After
        # the conversion above ``x`` is a tuple, which torch.nn.Linear cannot
        # take directly; the root features stay untouched for ``lin_root``.
        x = (self.linear(x[0]), x[1])

        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                             size=size)
        out = self.lin_rel(out)

        x_r = x[1]
        if x_r is not None:
            out = out + self.lin_root(x_r)

        return out



class WeightedSkillSage(torch.nn.Module):
    """Node encoder: two linear layers, two edge-weighted ``GraphConv`` layers
    and a final linear projection, with ReLU after every layer except the last.

    The module is written for a single node/edge type and is converted to a
    heterogeneous model with ``to_hetero`` right after its definition.

    Args:
        hidden_channels: width of all intermediate layers.
        out_channels: size of the returned node embeddings.
    """
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        #self.linear1 = Linear(-1,-1)
        #self.conv1 = SimpleConv(aggr='sum')
        # in_channels=-1: presumably lets the PyG Linear infer the input size on
        # the first forward pass — confirm against the PyG documentation.
        self.linear1 = Linear(-1,hidden_channels)
        self.linear2 = Linear(-1,hidden_channels)
        self.conv1 = GraphConv(in_channels=hidden_channels, out_channels=hidden_channels)
        self.conv2 = GraphConv(in_channels=hidden_channels, out_channels=hidden_channels)
        self.linear3 = Linear(hidden_channels,out_channels)

    def forward(self, x: HeteroData, edge_index, edge_weight):
        """Encode node features ``x`` using ``edge_index`` and ``edge_weight``.

        NOTE(review): ``x`` is annotated as ``HeteroData`` but is fed straight
        into linear layers, so it looks like a feature tensor (the trainer
        passes ``batch.x_dict``, which ``to_hetero`` presumably splits per node
        type) — confirm and adjust the annotation.
        """
        # feature transformation before any message passing
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        # two rounds of edge-weighted message passing
        x = self.conv1(x, edge_index, edge_weight=edge_weight)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_weight=edge_weight)
        x = F.relu(x)
        # final projection to the embedding size (no activation)
        x = self.linear3(x)
        return x

# Instantiate the encoder with 64 hidden and 64 output channels, then convert
# it into a heterogeneous model based on the node/edge types of the training
# split; aggr='sum' is passed to to_hetero as its aggregation setting.
model = WeightedSkillSage(hidden_channels=64, out_channels=64)
model = to_hetero(model, train_data.metadata(), aggr='sum')


In [None]:
from torcheval.metrics import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score, BinaryAUPRC
from tqdm.auto import tqdm

class GNNTrainer(Trainer):
    """Trainer for link prediction with a heterogeneous GNN encoder.

    Each element produced by the train/val iterators is a tuple of sampled
    subgraph batches (one per supervision edge type). An edge is scored by the
    cosine similarity between the embeddings of its two end nodes.

    Attributes:
        train_batch_len / val_batch_len: number of batches per pass; only used
            for progress display and expected to be set by the caller.
    """

    def __init__(self, model, criterion, optimizer, device):
        super().__init__(model, criterion, optimizer, device, metrics=['f1','accuracy','precision','recall', 'aucpr'])
        self.train_batch_len = 0
        self.val_batch_len = 0

    def get_supervision_edge_type(self, heterodata):
        """Return the first edge type whose store has an ``'input_id'`` key,
        or ``None`` if there is none."""
        for edge_type in heterodata.edge_types:
            if 'input_id' in heterodata[edge_type].keys():
                return edge_type
        return None

    def _score_batch(self, batch):
        """Run ``self.model`` on one sampled batch.

        Returns ``(logits, edge_label)`` for the batch's supervision edges,
        where ``logits`` are cosine similarities between source and
        destination node embeddings.
        """
        batch = batch.to(self.device)
        # Use the trainer's own model, not a notebook-level global.
        hetero_out = self.model(batch.x_dict, batch.edge_index_dict, batch.edge_weight_dict)

        supervision_edge_type = self.get_supervision_edge_type(batch)
        if supervision_edge_type is None:
            raise ValueError("batch contains no edge type with an 'input_id' attribute")
        src_type, dst_type = supervision_edge_type[0], supervision_edge_type[2]
        edge_store = batch[supervision_edge_type]
        edge_label_index = edge_store.edge_label_index
        src_node_embeddings = hetero_out[src_type][edge_label_index[0]]
        dst_node_embeddings = hetero_out[dst_type][edge_label_index[1]]
        logits = F.cosine_similarity(src_node_embeddings, dst_node_embeddings, dim=-1)
        return logits, edge_store.edge_label

    def calculate_metrics(self, split_name, y_hat, y, print_=True):
        """Compute accuracy, precision, recall, F1 (threshold 0.5) and AUC-PR
        for scores ``y_hat`` against labels ``y`` and append them to
        ``metrics_history[split_name]``; optionally print them."""
        y = y.to(torch.int).cpu()
        y_hat = y_hat.cpu()
        acc = BinaryAccuracy(threshold=0.5).update(y_hat, y).compute().item()
        prec = BinaryPrecision(threshold=0.5).update(y_hat, y).compute().item()
        rec = BinaryRecall(threshold=0.5).update(y_hat, y).compute().item()
        f1 = BinaryF1Score(threshold=0.5).update(y_hat, y).compute().item()
        aucpr = BinaryAUPRC().update(y_hat, y).compute().item()

        history = self.metrics_history[split_name]
        history['accuracy'].append(acc)
        history['precision'].append(prec)
        history['recall'].append(rec)
        history['f1'].append(f1)
        history['aucpr'].append(aucpr)
        if print_:
            print(f'{split_name}: F1: {f1}, AUC-PR: {aucpr}, (acc: {acc}, prec: {prec}, rec: {rec})')

    def train(self, train_iterator, val_iterator, start_epoch, n_epochs, batch_save_interval, save_path):
        """Train for ``n_epochs`` epochs, numbered from ``start_epoch``.

        Args:
            train_iterator / val_iterator: zero-argument callables returning
                an iterable of batch tuples.
            start_epoch: number assigned to the first epoch of this call.
            n_epochs: number of epochs to run.
            batch_save_interval: every this many mini-batches (including
                batch 0) the model is validated and a checkpoint is saved.
            save_path: checkpoint directory handed to ``save_checkpoint``.
        """
        self.free_memory()

        for epoch in range(start_epoch, start_epoch+n_epochs):
            self.model.train()

            total_loss = 0
            for batch_idx, edge_batches in tqdm(enumerate(train_iterator()), total=self.train_batch_len):
                # Reset gradients for every optimizer step; doing this only
                # once per epoch would accumulate gradients across batches.
                self.optimizer.zero_grad()
                minibatch_loss = 0

                y_hat, y = [], []
                for batch in edge_batches:  # each batch here is one edge type, since we want to learn for all edge types
                    logits, edge_label = self._score_batch(batch)
                    minibatch_loss = minibatch_loss + self.criterion(logits, edge_label)

                    y_hat.append(torch.sigmoid(logits).detach().cpu())
                    y.append(edge_label.to(torch.int).detach().cpu())

                minibatch_loss.backward()
                self.optimizer.step()
                loss_value = minibatch_loss.item()
                total_loss += loss_value

                # save loss and metrics
                history = self.metrics_history['train']
                history['minibatch'].append(batch_idx)
                history['epoch'].append(epoch)
                history['loss'].append(loss_value)

                self.calculate_metrics('train', torch.cat(y_hat), torch.cat(y), print_=False)
                if batch_idx % batch_save_interval == 0:
                    # validate() restores training mode before returning
                    self.validate(val_iterator, epoch)
                    self.save_checkpoint(epoch, batch_idx, save_path)

            self.epoch = epoch + 1
            print(f'ep{epoch}, Loss: {total_loss}')

    def validate(self, val_iterator, epoch):
        """Score every validation batch without gradients and record the
        metrics under ``metrics_history['val']``.

        The model is put in eval mode for the pass and returned to whatever
        mode it was in before, so a validation run in the middle of training
        does not leave the model in eval mode.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                y_hat, y = [], []

                for i, edge_batches in enumerate(val_iterator()):
                    print(f'{i}/{self.val_batch_len-1}', end='\r')
                    for batch in edge_batches:  # each batch here is one edge type
                        logits, edge_label = self._score_batch(batch)
                        y_hat.append(torch.sigmoid(logits).detach().cpu())
                        y.append(edge_label.to(torch.int).detach().cpu())
                print('')

                # save metrics
                self.metrics_history['val']['epoch'].append(epoch)
                self.calculate_metrics('val', torch.cat(y_hat), torch.cat(y))
        finally:
            self.model.train(was_training)

In [None]:
import os
#os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Training entry point: move the model to the available device, build the
# optimizer/loss/trainer and start training.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


optimizer = torch.optim.Adam(model.parameters(), lr=0.00001) #2e-15
# The trainer produces one logit per candidate edge together with a 0/1 label,
# i.e. an independent binary decision per edge. BCEWithLogitsLoss matches that;
# CrossEntropyLoss on a 1-D logit vector would instead apply a softmax across
# all edges of the batch.
criterion = torch.nn.BCEWithLogitsLoss()
trainer = GNNTrainer(model, criterion, optimizer, device)
#trainer.load_checkpoint('./checkpoints/checkpoint_0_300.pt')


# for tqdm
trainer.train_batch_len = train_batch_len
trainer.val_batch_len = val_batch_len

# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first checkpoint is written.
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

trainer.train(train_iterator, val_iterator, start_epoch=300, n_epochs=200, batch_save_interval=100, save_path=checkpoint_dir)
# trainer.validate(val_dataloader)
# trainer.plot_losses()
# trainer.load_checkpoint('./checkpoints/checkpoint_100.pt')