In [None]:
from notebooks.util_classifier import MultiHomogeneousGraphTripletDataset
from training_classifier import *
from util_classifier import *
from gat_models import *

import os
import random
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam
from torch.nn.modules.loss import TripletMarginLoss
import torch.nn as nn
from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

# Seed every random number generator used in this notebook (Python's `random`,
# NumPy, and PyTorch on CPU and on all CUDA devices) with one shared value.
SEED = 40
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [None]:
# Experiment hyper-parameters. The whole dict is passed to the data harvester and
# datasets, and is stored alongside the training results at the end of every epoch.
config = dict(
    experiment='GAT Classifier Training',
    max_hops=3,
    model_node_feature='feature_vec',  # Node feature to use for GAT encoder
    hidden_channels=32,
    out_channels=16,
    num_heads=8,
    margin=1.0,  # margin of the triplet loss
    optimizer='Adam',
    learning_rate=0.005,
    weight_decay=5e-4,
    num_epochs=10,
    batch_size=32,
)

In [None]:
# Graph sampling configurations
# Node properties requested from the graph sampler for every node.
node_properties = [
    'id',
    'feature_vec',
]

# Node types included in the sampled graphs.
node_spec = [
    NodeType.PUBLICATION
]

# Edge types included in the sampled graphs.
edge_spec = [
    EdgeType.SIM_ABSTRACT,
    EdgeType.SIM_TITLE,
    EdgeType.SIM_AUTHOR,
    #EdgeType.SAME_AUTHOR,
]

# Checkpoint paths of the pre-trained GAT encoders.
# NOTE(review): this list is ordered (title, abstract, sim_author) while `edge_spec`
# above is ordered (abstract, title, sim_author). If the consumer pairs the two lists
# by index, the title and abstract encoders are swapped - confirm against the consumer.
gat_list = [
    './data/results/homogeneous (title) compressed_emb linear_layer dropout/gat_encoder.pt',
    './data/results/homogeneous (abstract) compressed_emb linear_layer dropout/gat_encoder.pt',
    './data/results/homogeneous (sim_author) compressed_emb linear_layer dropout/gat_encoder.pt',
]

database = 'homogeneous-graph'
gs = GraphSampling(
    node_spec=node_spec,
    edge_spec=edge_spec,
    node_properties=node_properties,
    database=database
)

# Model configurations

loss_fn = TripletMarginLoss(margin=config['margin'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TODO: Adjust result folder name!
result_folder_name = 'classifier full_emb (abstract, title, sim_author edges)'
result_folder_path = f'./data/results/{result_folder_name}'
# makedirs also creates missing parent folders (e.g. ./data/results on a fresh checkout)
# and does not fail when the folder already exists, so the cell can be re-run safely.
os.makedirs(result_folder_path, exist_ok=True)

## Embedding Network
**This network takes in the stacked GAT node embeddings and outputs a lower-dimensional embedding.**

In [None]:
class EmbeddingNet(nn.Module):
    """Small MLP (Linear -> ReLU -> Linear) that projects an input vector into an embedding space.

    Args:
        input_size: Number of features of each input vector.
        hidden_size: Width of the hidden layer.
        embedding_size: Number of features of the produced embedding.
    """

    def __init__(self, input_size, hidden_size = 128, embedding_size = 16):
        super().__init__()
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embedding_size),
        ]
        # Stored under the attribute name `fc`, so parameter names are `fc.0.*` and `fc.2.*`.
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the embedding of `x` produced by the MLP."""
        return self.fc(x)

## Triplet Network
**This network takes three inputs (an anchor, a positive example, and a negative example) and returns the embedding of each. It simplifies training with the triplet loss.**

In [None]:
class TripletNet(nn.Module):
    """Applies one shared embedding network to the three members of a triplet.

    Args:
        embedding_net: Module used to embed the anchor, the positive and the negative input.
    """

    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

    def forward(self, anchor, positive, negative):
        """Return the tuple `(anchor_emb, positive_emb, negative_emb)`."""
        # The same network embeds all three inputs, in anchor/positive/negative order.
        return tuple(
            self.embedding_net(member)
            for member in (anchor, positive, negative)
        )

    def get_embedding(self, x):
        """Embed a single input with the shared embedding network."""
        embedding = self.embedding_net(x)
        return embedding

## Pair Classifier
**This network will be used for the actual classification task (the AND pipeline).**

In [None]:
class PairClassifier(nn.Module):
    """Scores how similar two inputs are, based on the distance between their embeddings.

    Args:
        embedding_net: Module used to embed both members of a pair.
    """

    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

    def forward(self, embedding_1, embedding_2):
        """Return `sigmoid(-d)`, where `d` is the pairwise distance of the two embedded inputs.

        NOTE(review): the pairwise distance is non-negative, so the returned score lies
        in (0, 0.5] and reaches 0.5 only for (near-)identical embeddings. Any decision
        threshold applied to this score has to account for that range.
        """
        first, second = (self.embedding_net(item) for item in (embedding_1, embedding_2))

        # A smaller distance between the embeddings gives a larger similarity score.
        return torch.sigmoid(F.pairwise_distance(first, second).neg())

    def get_embedding(self, x):
        """Embed a single input with the underlying embedding network."""
        embedding = self.embedding_net(x)
        return embedding

## Training

In [None]:
# Connect to the graph database and harvest the (anchor, positive, negative) triplets.
db = DatabaseWrapper(database=database)
data_harvester = ClassifierTripletDataHarvester(db=db, gs=gs, edge_spec=edge_spec, config=config)

# Split the triplets: 85% train, 10% test, the remainder is used for evaluation.
num_triplets = len(data_harvester.triplets)
train_size = int(0.85 * num_triplets)
test_size = int(0.1 * num_triplets)
eval_size = num_triplets - train_size - test_size

# Harvest the evaluation triplets first, since triplets are ordered by author. This will ensure that the evaluation set has authors not seen in the training set.
eval_triplets = data_harvester.triplets[:eval_size]

# The slice is a copy, so shuffling it leaves `data_harvester.triplets` untouched.
train_test_triplets = data_harvester.triplets[eval_size:]
random.shuffle(train_test_triplets)

train_triplets = train_test_triplets[:train_size]
test_triplets = train_test_triplets[train_size:]
config['train_size'] = len(train_triplets)
config['test_size'] = len(test_triplets)
config['eval_size'] = len(eval_triplets)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Create the datasets from the triplets (distinct triplets for training, testing and evaluation)
train_dataset = MultiHomogeneousGraphTripletDataset(train_triplets, gs, edge_spec=edge_spec, config=config)
test_dataset = MultiHomogeneousGraphTripletDataset(test_triplets, gs, edge_spec=edge_spec, config=config)
eval_dataset = MultiHomogeneousGraphTripletDataset(eval_triplets, gs, edge_spec=edge_spec, config=config)

# Create the DataLoaders (only the training data is shuffled)
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=custom_triplet_collate)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)
eval_dataloader = DataLoader(eval_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)

# Record the graph schema used for this run in the config that is saved with the results.
metadata = (
    [n.value for n in node_spec],
    [edge_pyg_key_vals[r] for r in edge_spec]
)
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]

# Create the classifier model.
# The previous GAT-style call `TripletNet(metadata, hidden_channels, out_channels, num_heads=...)`
# was removed: the TripletNet defined in this notebook only accepts an embedding network,
# so that call raised a TypeError, and its result was overwritten below anyway.
input_size = 512  # dimension of stacked embeddings
embedding_size = 64  # size of the classifier embedding space
config['classifier_input_size'] = input_size
config['classifier_embedding_size'] = embedding_size

# `embedding_size` must be passed by keyword: the second positional parameter of
# EmbeddingNet is `hidden_size`, so a positional 64 would leave the embedding at its default 16.
embedding_net = EmbeddingNet(input_size, embedding_size=embedding_size)
# NOTE(review): assumes the train/test/evaluate helpers move the batches to the same device - confirm.
model = TripletNet(embedding_net).to(device)

# Loss (margin taken from the config so the saved config matches the run)
triplet_loss = nn.TripletMarginLoss(margin=config['margin'], p=2)

# Optimizer, driven by the config so that the stored hyper-parameters are the ones actually used
optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

In [None]:
# Training loop.
# Twice per epoch (before the first batch and at the middle batch) the model is scored on
# the test and evaluation sets; loss/accuracy curves are re-plotted to the result folder.
num_epochs = config['num_epochs']
num_batches = len(train_dataloader)
# With fewer than two batches this is 0 and only the first-batch checkpoint fires.
mid_batch = num_batches // 2
# NOTE(review): the file name says "gat_encoder" although this is the classifier's state dict;
# kept unchanged because loaders elsewhere may rely on it.
checkpoint_path = result_folder_path + '/gat_encoder.pt'

train_losses = []

test_losses = []
test_accuracies = []
test_correct_pos = []
test_correct_neg = []

eval_losses = []
eval_accuracies = []
eval_correct_pos = []
eval_correct_neg = []

for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    # Counts only the batches that are actually trained on (skipped batches do not advance it).
    current_batch = 1
    for batch_anchor, batch_pos, batch_neg in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        # Skip batches for which no usable triplet data is available.
        if batch_anchor is None or batch_pos is None or batch_neg is None:
            continue

        if current_batch == 1 or current_batch == mid_batch:
            print(f"___ Current Batch {current_batch}/{num_batches} _________________________")
            # Model testing
            print("    Test Results:")
            test_loss, test_num_correct, test_correct_pos_val, test_correct_neg_val = test(
                model=model,
                loss_fn=loss_fn,
                dataloader=test_dataloader,
                margin=config['margin']
            )
            test_losses.append(test_loss)
            test_accuracies.append(test_num_correct)
            test_correct_pos.append(test_correct_pos_val)
            test_correct_neg.append(test_correct_neg_val)

            # Save the model if the test loss has decreased. Saving right after the test pass
            # guarantees that the stored weights are exactly the ones that achieved this loss
            # (previously the save happened at the end of the epoch, after further updates).
            if len(test_losses) > 1 and test_loss < min(test_losses[:-1]):
                print(f"Saving model at epoch {epoch}...")
                torch.save(model.state_dict(), checkpoint_path)

            plot_loss(test_losses, epoch_len=2, plot_title='Test Loss', plot_avg=False, plot_file=result_folder_path + '/test_loss.png')
            plot_loss(
                test_accuracies,
                epoch_len=2,
                plot_title='Test Accuracy',
                plot_avg=False,
                x_label='Test Iterations',
                y_label='Accuracy',
                line_label='Accuracy',
                plot_file=result_folder_path + '/test_accuracy.png'
            )

            # Model evaluation
            print("    Eval Results:")
            eval_loss, eval_num_correct, eval_correct_pos_val, eval_correct_neg_val = evaluate(
                model=model,
                loss_fn=loss_fn,
                dataloader=eval_dataloader,
                margin=config['margin']
            )
            eval_losses.append(eval_loss)
            eval_accuracies.append(eval_num_correct)
            eval_correct_pos.append(eval_correct_pos_val)
            eval_correct_neg.append(eval_correct_neg_val)

            plot_loss(eval_losses, epoch_len=2, plot_title='Evaluation Loss', plot_avg=False, plot_file=result_folder_path + '/eval_loss.png')
            plot_loss(
                eval_accuracies,
                epoch_len=2,
                plot_title='Evaluation Accuracy',
                plot_avg=False,
                x_label='Eval Iterations',
                y_label='Accuracy',
                line_label='Accuracy',
                plot_file=result_folder_path + '/eval_accuracy.png'
            )

        # One optimisation step on the current triplet batch.
        loss = train(
            model=model,
            loss_fn=loss_fn,
            batch_anchor=batch_anchor,
            batch_pos=batch_pos,
            batch_neg=batch_neg,
            optimizer=optimizer
        )
        train_losses.append(loss)

        plot_loss(train_losses, epoch_len=num_batches, plot_title='Training Loss', plot_avg=True, plot_file=result_folder_path + '/train_loss.png')
        current_batch += 1

    # Save config and training results
    eval_results = {
        'eval_losses': eval_losses,
        'eval_accuracies': eval_accuracies,
        'eval_correct_pos': eval_correct_pos,
        'eval_correct_neg': eval_correct_neg
    }
    save_training_results(train_losses, test_losses, eval_results, config, result_folder_path + '/training_data.json')
