In [18]:
from util import *
from gat_models import *

import random
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.optim import Adam
from sklearn.metrics.pairwise import cosine_distances

from src.shared.database_wrapper import DatabaseWrapper
from src.shared.graph_schema import *
from src.shared.graph_sampling import GraphSampling

random.seed(40)
np.random.seed(40)
torch.manual_seed(40)
torch.cuda.manual_seed_all(40)

### Configurations

In [19]:
# Graph sampling configurations
node_spec = [
    NodeType.PUBLICATION,
]

edge_spec = [
    #EdgeType.SIM_VENUE,
    EdgeType.SIM_AUTHOR,
]

node_properties = [
    'id',
    'title',
    'abstract',
    'venue',
    'title_emb',
    'abstract_emb',
    'venue_emb',
    'true_author',
]

database = 'homogeneous-graph-compressed-emb'
gs = GraphSampling(
    node_spec=node_spec,
    edge_spec=edge_spec,
    node_properties=node_properties,
    database=database,
)

# Model configurations
config = {
    'experiment': 'GATv2 encoder (with linear layer + dropout) trained on graph (publication nodes with title and abstract, similarity and co-author edges) using Pairwise Contrastive Loss and dimension reduced embeddings',
    'max_hops': 3,
    'model_node_feature': 'feature_vec',  # Node feature to use for GAT encoder
    'hidden_channels': 64,
    'out_channels': 16,
    'num_heads': 8,
    'margin': 1.0,
    'optimizer': 'Adam',
    'learning_rate': 0.005,
    'weight_decay': 5e-4,
    'num_epochs': 20,
    'batch_size': 32,
}

model_class = HeteroGATEncoderLinearDropout
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


save_file_postfix = "hetero_edges_compressed_emb_linear_layer_dropout"

### Training functions

In [20]:
def train(model, batch1, batch2, labels, optimizer):
    model.train()
    
    optimizer.zero_grad()
    
    batch1 = batch1.to(device)
    batch2 = batch2.to(device)
    labels = labels.to(device)

    embeddings1 = model(batch1)
    embeddings2 = model(batch2)

    embeddings1_central = embeddings1[NodeType.PUBLICATION.value][batch1.central_node_id]
    embeddings2_central = embeddings2[NodeType.PUBLICATION.value][batch2.central_node_id]
    
    loss, _ = contrastive_loss(embeddings1_central, embeddings2_central, labels, config['margin'])
    
    loss.backward()
    optimizer.step()
        
    batch_loss = loss.item()
    #print(f"Batch loss: {batch_loss:.4f}")
    return batch_loss

def test(model, dataloader):
    model.eval()
    total_loss = 0
    distances = []
    labels_list = []
    with torch.no_grad():
        for batch1, batch2, labels in dataloader:
            batch1 = batch1.to(device)
            batch2 = batch2.to(device)
            labels = labels.to(device)
    
            embeddings1 = model(batch1)
            embeddings2 = model(batch2)
            
            embeddings1_central = embeddings1[NodeType.PUBLICATION.value][batch1.central_node_id]
            embeddings2_central = embeddings2[NodeType.PUBLICATION.value][batch2.central_node_id]
            
            loss, dist = contrastive_loss(embeddings1_central, embeddings2_central, labels, config['margin'])
            total_loss += loss.item()
            distances.extend(dist.cpu().numpy())
            labels_list.extend(labels.cpu().numpy())
        
    # Compute accuracy
    distances = np.array(distances)
    labels_list = np.array(labels_list).astype(int)
    predictions = (distances <= config['margin']).astype(int)
    accuracy = accuracy_score(labels_list, predictions)
    
    # Compute average loss    
    avg_loss = total_loss / len(dataloader)
    print(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.4f}")
    return avg_loss, accuracy
    

### Training Configuration

In [21]:
db = DatabaseWrapper(database=database)
data_harvester = PairDataHarvester(db=db, gs=gs, edge_spec=edge_spec, config=config, save_file_postfix=save_file_postfix)


# Split the pairs into train and test
random.shuffle(data_harvester.pairs)
train_size = int(0.95 * len(data_harvester.pairs))
train_pairs = data_harvester.pairs[:train_size]
test_pairs = data_harvester.pairs[train_size:]
config['train_size'] = train_size
config['test_size'] = len(data_harvester.pairs) - train_size

# Create the datasets from the pairs (distinct pairs for training and testing)
train_dataset = GraphPairDataset(train_pairs, gs, config=config)
test_dataset = GraphPairDataset(test_pairs, gs, config=config)

# Create the DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=custom_pair_collate)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=custom_pair_collate)

# Create model
metadata = (
    [n.value for n in node_spec],
    [edge_pyg_key_vals[r] for r in edge_spec]
)
config['node_spec'] = metadata[0]
config['edge_spec'] = metadata[1]
model = model_class(metadata, config['hidden_channels'], config['out_channels'], num_heads=config['num_heads']).to(device)
optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

2024-10-14 17:24:15,625 - DatabaseWrapper - INFO - Connecting to the database ...
2024-10-14 17:24:15,625 - DatabaseWrapper - INFO - Database ready.


Preparing pairs...
Loading pairs...
Loaded 6006 pairs.


### Training Loop

In [None]:
num_epochs = config['num_epochs']
margin = 1.0  # Margin for contrastive loss
train_losses = []
test_losses = []
test_accuracies = []

for epoch in range(1, num_epochs + 1):
    epoch_marker_pos = list(range(0, len(train_dataloader) * epoch, len(train_dataloader)))
    
    for data1, data2, labels in tqdm(train_dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        if data1 is None or data2 is None:
            continue
        
        if len(train_losses) % 10 == 0:
            test_loss, test_accuracy = test(model, test_dataloader)
            test_losses.append(test_loss)
            test_accuracies.append(test_accuracy)
            test_epoch_marker_pos = [marker/10 for marker in epoch_marker_pos if marker != 0]
            plot_loss(test_losses, epoch_marker_pos=test_epoch_marker_pos, plot_title='Test Loss', plot_avg=True, plot_file=f'./data/losses/test_loss_{save_file_postfix}.png')
            plot_loss(test_accuracies, epoch_marker_pos=test_epoch_marker_pos, plot_title='Test Accuracy', plot_avg=True, plot_file=f'./data/losses/test_accuracy_{save_file_postfix}.png')
            
        loss = train(model, data1, data2, labels, optimizer)
        train_losses.append(loss)
        
        plot_loss(train_losses, epoch_marker_pos=epoch_marker_pos, plot_title='Training Loss', plot_file=f'./data/losses/train_loss_{save_file_postfix}.png')
    
    # Save config and training results
    save_training_results(train_losses, test_losses, None, config, f'./data/results/training_results_{save_file_postfix}.json')
    
    # Save model if loss has decreased
    if len(test_losses) > 1 and test_losses[-1] < min(test_losses[:-1]):
        print(f"Saving model at epoch {epoch}...")
        torch.save(model.state_dict(), f'./data/models/gat_encoder_{save_file_postfix}.pt')

Epoch 1/20:   0%|          | 0/179 [00:00<?, ?it/s]

Test Loss: 0.3913, Test Accuracy: 0.4917
Test Loss: 0.2324, Test Accuracy: 0.6412
Test Loss: 0.2100, Test Accuracy: 0.5648
Test Loss: 0.2037, Test Accuracy: 0.6213
Test Loss: 0.1995, Test Accuracy: 0.6146
Test Loss: 0.1964, Test Accuracy: 0.6080
Test Loss: 0.1869, Test Accuracy: 0.6080
Test Loss: 0.1797, Test Accuracy: 0.6047
