In [17]:
from notebooks.util_classifier import MultiHomogeneousGraphTripletDataset
from training_classifier import *
from util_classifier import *
from gat_models import *

import os
import random
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam
from torch.nn.modules.loss import TripletMarginLoss
import torch.nn as nn
from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

random.seed(40)
np.random.seed(40)
torch.manual_seed(40)
torch.cuda.manual_seed_all(40)

In [18]:
config = {
    'experiment': 'GAT Classifier Training With same author edges and no gat dropout',
    'max_hops': 2,
    'model_node_feature': 'feature_vec',  # Node feature to use for GAT encoder
    'hidden_channels': 32,
    'out_channels': 16,
    'num_heads': 8,
    'classifier_in_channels': 4 * 16,
    'classifier_hidden_channels': 128,
    'classifier_out_channels': 16,
    'classifier_dropout': 0.2,
    'margin': 1.0,
    'optimizer': 'Adam',
    'learning_rate': 0.005,
    'weight_decay': 5e-4,
    'num_epochs': 10,
    'batch_size': 32,
}

In [19]:
# Graph sampling configurations
node_properties = [
    'id',
    'feature_vec',
]

node_spec = [
    NodeType.PUBLICATION
]

edge_spec = [
    EdgeType.SIM_ABSTRACT,
    EdgeType.SIM_TITLE,
    EdgeType.SIM_AUTHOR,
    EdgeType.SAME_AUTHOR,
]

gat_list = {
    EdgeType.SIM_TITLE: './data/results/homogeneous (title) full_emb linear_layer dropout/gat_encoder.pt',
    EdgeType.SIM_ABSTRACT: './data/results/homogeneous (abstract) full_emb linear_layer dropout/gat_encoder.pt',
    EdgeType.SIM_AUTHOR: './data/results/homogeneous (similar co-authors) full_emb linear_layer dropout/gat_encoder.pt',
    EdgeType.SAME_AUTHOR: './data/results/homogeneous (same author) full_emb linear_layer dropout/gat_encoder.pt',
}


database = 'homogeneous-graph'
gs = GraphSampling(
    node_spec=node_spec,
    edge_spec=edge_spec,
    node_properties=node_properties,
    database=database
)

# Model configurations
# Load pre-trained GAT encoders
gat_encoders = {}
for edge_key, gat_path in gat_list.items():
    gat_encoder = HomoGATEncoderLinearDropout(config['hidden_channels'], config['out_channels'], num_heads=config['num_heads'], dropout_p=0).to(device)
    gat_encoder.load_state_dict(torch.load(gat_path))
    gat_list[edge_key] = gat_encoder

loss_fn = TripletMarginLoss(margin=config['margin'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TODO: Adjust result folder name!
result_folder_name = 'classifier full_emb (abstract, title, sim_author, same_author edges) no_gat_dropout'
result_folder_path = f'./data/results/{result_folder_name}'
if not os.path.exists(result_folder_path):
    os.mkdir(result_folder_path)

Using default edge type: SimilarAbstract for homogeneous graph sampling.


## Embedding Network
**This network takes in the stacked GAT node embeddings and outputs a lower-dimensional embedding.**

In [20]:
class EmbeddingNet(nn.Module):
    def __init__(self, input_size, hidden_size = 128, embedding_size = 16, dropout = 0.2):
        super(EmbeddingNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, embedding_size)
        )
    
    def forward(self, x):
        output = self.fc(x)
        return output

## Triplet Network
**This network takes in three inputs: an anchor, a positive example, and a negative example. It outputs the embeddings of the three inputs. It is used to simplify triplet loss training**

In [21]:
class TripletNet(nn.Module):
    def __init__(self, embedding_net: EmbeddingNet, edge_spec: [EdgeType], gat_encoders: dict[EdgeType, nn.Module]):
        super(TripletNet, self).__init__()
        self.edge_spec = edge_spec
        self.gat_encoders = gat_encoders
        self.embedding_net = embedding_net
        
        for gat in self.gat_encoders.values():
            gat.eval()
            for param in gat.parameters():
                param.requires_grad = False

    def forward(self, data_dict: dict):
        anchor = []
        positive = []
        negative = []
        
        for edge_type in self.edge_spec:
            # Anchor node embedding for the edge type 
            anchor_graph = data_dict[edge_type][0]
            anchor_gat_emb = self.gat_encoders[edge_type](anchor_graph)
            anchor.append(anchor_gat_emb[anchor_graph.central_node_id])
            
            # Positive node embedding for the edge type
            positive_graph = data_dict[edge_type][1]
            positive_gat_emb = self.gat_encoders[edge_type](positive_graph)
            positive.append(positive_gat_emb[positive_graph.central_node_id])
            
            # Negative node embedding for the edge type
            negative_graph = data_dict[edge_type][2]
            negative_gat_emb = self.gat_encoders[edge_type](negative_graph)
            negative.append(negative_gat_emb[negative_graph.central_node_id])
            
        anchor = torch.cat(anchor, dim=1)
        positive = torch.cat(positive, dim=1)
        negative = torch.cat(negative, dim=1)
        
        output_anchor = self.embedding_net(anchor)
        output_positive = self.embedding_net(positive)
        output_negative = self.embedding_net(negative)
        
        return output_anchor, output_positive, output_negative

    def get_embedding(self, x):
        return self.embedding_net(x)

## Pair Classifier
**This network will be used for the actual classification task (the AND pipeline).**

In [22]:
class PairClassifier(nn.Module):
    def __init__(self, embedding_net):
        super(PairClassifier, self).__init__()
        self.embedding_net = embedding_net

    def forward(self, embedding_1, embedding_2):
        out_1 = self.embedding_net(embedding_1)
        out_2 = self.embedding_net(embedding_2)
        
        # Compute probability of the pair being similar by computing pairwise distance
        distance = F.pairwise_distance(out_1, out_2)
        similarity_prediction = torch.sigmoid(-distance)
        
        return similarity_prediction

    def get_embedding(self, x):
        return self.embedding_net(x)

## Training

In [23]:
db = DatabaseWrapper(database=database)
data_harvester = ClassifierTripletDataHarvester(db=db, gs=gs, edge_spec=edge_spec, config=config, valid_triplets_save_file='valid_triplets_classifier', transformer_model='sentence-transformers/all-MiniLM-L6-v2')


# Split the pairs into train and test
train_size = int(0.85 * len(data_harvester.triplets))
test_size = int(0.1 * len(data_harvester.triplets))
eval_size = len(data_harvester.triplets) - train_size - test_size

# Harvest the evaluation triplets first, since triplets are ordered by author. This will ensure that the evaluation set has authors not seen in the training set.
eval_triplets = data_harvester.triplets[:eval_size]

train_test_triplets = data_harvester.triplets[eval_size:]
random.shuffle(train_test_triplets)

train_triplets = train_test_triplets[:train_size]
test_triplets = train_test_triplets[train_size:]
config['train_size'] = len(train_triplets)
config['test_size'] = len(test_triplets)
config['eval_size'] = len(eval_triplets)

print(f"Train size: {len(train_triplets)}, Test size: {len(test_triplets)}, Eval size: {len(eval_triplets)}")

# Create the datasets from the pairs (distinct pairs for training and testing)
train_dataset = MultiHomogeneousGraphTripletDataset(train_triplets, gs, edge_spec=edge_spec, config=config)
test_dataset = MultiHomogeneousGraphTripletDataset(test_triplets, gs, edge_spec=edge_spec, config=config)
eval_dataset = MultiHomogeneousGraphTripletDataset(eval_triplets, gs, edge_spec=edge_spec, config=config)

# Create the DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=custom_triplet_collate)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)
eval_dataloader = DataLoader(eval_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_triplet_collate)

# Create models
metadata = (
    [n.value for n in node_spec],
    [edge_pyg_key_vals[r] for r in edge_spec]
)
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]

# Embedding model
embedding_net = EmbeddingNet(
    input_size=config['classifier_in_channels'], 
    hidden_size=config['classifier_hidden_channels'],
    embedding_size=config['classifier_out_channels'],
    dropout=config['classifier_dropout']
).to(device)

# Triplet training classifier model
triplet_net = TripletNet(
    embedding_net=embedding_net,
    edge_spec=edge_spec,
    gat_encoders=gat_list
).to(device)

# Optimizer
optimizer = torch.optim.Adam(triplet_net.embedding_net.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

2024-11-24 14:14:25,298 - DatabaseWrapper - INFO - Connecting to the database ...
2024-11-24 14:14:25,299 - DatabaseWrapper - INFO - Database ready.


Preparing triplets...
Loading triplets...
Loaded 14923 triplets.
Train size: 12684, Test size: 1492, Eval size: 747


In [24]:
num_epochs = config['num_epochs']
train_losses = []

test_losses = []
test_accuracies = []
test_correct_pos = []
test_correct_neg = []

eval_losses = []
eval_accuracies = []
eval_correct_pos = []
eval_correct_neg = []

current_batch = 1

for epoch in range(1, num_epochs + 1):
    print(f"=== Epoch {epoch}/{num_epochs} ======================")
    epoch_marker_pos = list(range(0, len(train_dataloader) * epoch, len(train_dataloader)))
    current_batch = 1
    for data_dict in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        if current_batch == 1 or current_batch == len(train_dataloader) // 2:
            print(f"___ Current Batch {current_batch}/{len(train_dataloader)} _________________________")
            # Model testing
            print("    Test Results:")
            test_loss, test_num_correct, test_correct_pos_val, test_correct_neg_val = test(
                triplet_classifier_model=triplet_net,
                loss_fn=loss_fn,
                dataloader=test_dataloader,
                margin=config['margin']
            )
            test_losses.append(test_loss)
            test_accuracies.append(test_num_correct)
            test_correct_pos.append(test_correct_pos_val)
            test_correct_neg.append(test_correct_neg_val)
    
            plot_loss(test_losses, epoch_len=2, plot_title='Test Loss', plot_avg=False, plot_file=result_folder_path + '/test_loss.png')
            plot_loss(
                test_accuracies,
                epoch_len=2,
                plot_title='Test Accuracy',
                plot_avg=False, 
                x_label='Test Iterations',
                y_label='Accuracy',
                line_label='Accuracy',
                plot_file=result_folder_path + '/test_accuracy.png'
            )
            
            # Model evaluation
            print("    Eval Results:")
            eval_loss, eval_num_correct, eval_correct_pos_val, eval_correct_neg_val = evaluate(
                triplet_classifier_model=triplet_net,
                loss_fn=loss_fn,
                dataloader=eval_dataloader,
                margin=config['margin']
            )
            eval_losses.append(eval_loss)
            eval_accuracies.append(eval_num_correct)
            eval_correct_pos.append(eval_correct_pos_val)
            eval_correct_neg.append(eval_correct_neg_val)
            
            plot_loss(eval_losses, epoch_len=2, plot_title='Evaluation Loss', plot_avg=False, plot_file=result_folder_path + '/eval_loss.png')
            plot_loss(
                eval_accuracies, 
                epoch_len=2, 
                plot_title='Evaluation Accuracy', 
                plot_avg=False, 
                x_label='Eval Iterations',
                y_label='Accuracy',
                line_label='Accuracy',
                plot_file=result_folder_path + '/eval_accuracy.png'
            )
        
        loss = train(
            triplet_classifier_model=triplet_net,
            loss_fn=loss_fn,
            data_dict=data_dict,
            optimizer=optimizer
        )
        train_losses.append(loss)
        
        plot_loss(train_losses, epoch_len=len(train_dataloader), plot_title='Training Loss', plot_avg=True, plot_file=result_folder_path + '/train_loss.png')
        current_batch += 1
        
    # Save config and training results
    eval_results = {
        'eval_losses': eval_losses,
        'eval_accuracies': eval_accuracies,
        'eval_correct_pos': eval_correct_pos,
        'eval_correct_neg': eval_correct_neg
    }
    save_training_results(train_losses, test_losses, eval_results, config, result_folder_path + '/training_data.json')
    
    # Save model if loss has decreased
    if len(test_losses) > 1 and test_losses[-1] < min(test_losses[:-1]):
        print(f"Saving model at epoch {epoch}...")
        torch.save(triplet_net.embedding_net.state_dict(), result_folder_path + '/embedding_net.pt')




Epoch 1/10:   0%|          | 0/397 [00:00<?, ?it/s]

___ Current Batch 1/397 _________________________
    Test Results:
        Correct positive: 1492 (100.00%), Correct negative: 75 (5.03%)
        Total correct: 1567 (52.51%)
        Test Loss: 0.7169, Test Accuracy: 0.5251
    Eval Results:
        Correct positive: 747 (100.00%), Correct negative: 0 (0.00%)
        Total correct: 747 (50.00%)
        Eval Loss: 0.9749, Eval Accuracy: 0.5000
___ Current Batch 198/397 _________________________
    Test Results:
        Correct positive: 913 (61.19%), Correct negative: 1405 (94.17%)
        Total correct: 2318 (77.68%)
        Test Loss: 0.1895, Test Accuracy: 0.7768
    Eval Results:
        Correct positive: 330 (44.18%), Correct negative: 586 (78.45%)
        Total correct: 916 (61.31%)
        Eval Loss: 0.9042, Eval Accuracy: 0.6131
Saving model at epoch 1...


Epoch 2/10:   0%|          | 0/397 [00:00<?, ?it/s]

___ Current Batch 1/397 _________________________
    Test Results:
        Correct positive: 819 (54.89%), Correct negative: 1408 (94.37%)
        Total correct: 2227 (74.63%)
        Test Loss: 0.1814, Test Accuracy: 0.7463
    Eval Results:
        Correct positive: 320 (42.84%), Correct negative: 606 (81.12%)
        Total correct: 926 (61.98%)
        Eval Loss: 0.9948, Eval Accuracy: 0.6198
___ Current Batch 198/397 _________________________
    Test Results:
        Correct positive: 980 (65.68%), Correct negative: 1395 (93.50%)
        Total correct: 2375 (79.59%)
        Test Loss: 0.1659, Test Accuracy: 0.7959
    Eval Results:
        Correct positive: 391 (52.34%), Correct negative: 304 (40.70%)
        Total correct: 695 (46.52%)
        Eval Loss: 1.2191, Eval Accuracy: 0.4652


KeyboardInterrupt: 