In [1]:
import os
import torch
from torch import nn
import pandas as pd
import matplotlib.pyplot as plt
import logging
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
from functools import partial
import optuna
import gc
from typing import Literal
import torch.nn.functional as F

# Load utility functions from cloned repository
from src.loadData import GraphDataset
from src.utils import set_seed
from src.models import GNN


# Fix the random seed before any other cell runs.
# set_seed is the project helper imported from src.utils; it is called with no
# arguments, so whatever default seed that helper defines is the one in effect.
# NOTE(review): which libraries it seeds is not visible here; check src/utils.py.
set_seed()


In [2]:
def add_zeros(data):
    """Attach an all-zero integer feature vector to a graph sample.

    ``data.x`` is overwritten with a 1-D ``torch.long`` tensor holding one
    zero per node (its length is ``data.num_nodes``). The same object that
    was passed in is mutated and handed back, so the function can be used as
    a dataset transform.
    """
    node_count = data.num_nodes
    placeholder_features = torch.zeros(node_count, dtype=torch.long)
    data.x = placeholder_features
    return data

In [3]:
def train(
    data_loader,
    model,
    optimizer,
    criterion,
    device,
    save_checkpoints,
    checkpoint_path,
    current_epoch,
):
    """Run one training pass over ``data_loader`` and optionally checkpoint.

    Args:
        data_loader: Sized iterable of batches. Every batch must support
            ``.to(device)`` and expose its targets as ``.y``.
        model: Module invoked as ``model(batch)``. Its output is fed to
            ``criterion`` and arg-maxed over dimension 1 to get predictions.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once per
            batch.
        criterion: Callable ``criterion(output, batch.y)`` returning a scalar
            loss tensor.
        device: Device every batch is moved to before the forward pass.
        save_checkpoints: When truthy, the model's ``state_dict`` is written
            to disk after the pass.
        checkpoint_path: Path prefix of the checkpoint; the file written is
            ``f"{checkpoint_path}_epoch_{current_epoch + 1}.pth"``.
        current_epoch: Epoch index; ``current_epoch + 1`` is used in the
            checkpoint file name.

    Returns:
        tuple: ``(mean_loss, accuracy)`` where ``mean_loss`` is the sum of the
        per-batch losses divided by ``len(data_loader)`` and ``accuracy`` is
        the fraction of targets predicted correctly.

    Raises:
        ZeroDivisionError: If ``data_loader`` contains no batches.
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for data in data_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += (pred == data.y).sum().item()
        total += data.y.size(0)

    # Save checkpoints if required
    if save_checkpoints:
        checkpoint_file = f"{checkpoint_path}_epoch_{current_epoch + 1}.pth"
        # torch.save does not create missing parent directories, so make sure
        # the target folder exists (the other output helpers in this notebook
        # do the same for their folders). A bare file name has no directory
        # component and needs nothing created.
        checkpoint_dir = os.path.dirname(checkpoint_file)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(model.state_dict(), checkpoint_file)
        print(f"Checkpoint saved at {checkpoint_file}")

    return total_loss / len(data_loader), correct / total

In [4]:
def evaluate(data_loader, model, device, calculate_accuracy=False, criterion=None):
    """Run ``model`` over ``data_loader`` without gradients.

    The function has two modes, selected by ``calculate_accuracy``:

    * ``False`` (default): collect the arg-max prediction of every sample and
      return them as a flat list. Targets (``batch.y``) are never read.
    * ``True``: compare predictions with ``batch.y`` and return
      ``(mean_loss, accuracy)``.

    Args:
        data_loader: Sized iterable of batches supporting ``.to(device)``;
            batches must expose ``.y`` when ``calculate_accuracy`` is true.
        model: Module invoked as ``model(batch)``; predictions are the
            arg-max of its output over dimension 1.
        device: Device every batch is moved to.
        calculate_accuracy: Selects the return mode described above.
        criterion: Optional callable ``criterion(output, batch.y)`` used for
            the reported loss. Defaults to ``torch.nn.CrossEntropyLoss()``,
            which is what this function always used before the parameter
            existed; pass the training criterion here to get a loss that is
            comparable with the one reported by ``train``.

    Returns:
        list | tuple: The list of per-sample predictions, or
        ``(mean_loss, accuracy)`` where ``mean_loss`` is the summed per-batch
        loss divided by ``len(data_loader)``.

    Raises:
        ZeroDivisionError: In accuracy mode, if ``data_loader`` has no
            batches.
    """
    model.eval()
    correct = 0
    total = 0
    predictions = []
    total_loss = 0
    # Only build the default loss when a loss will actually be computed.
    if calculate_accuracy and criterion is None:
        criterion = torch.nn.CrossEntropyLoss()
    with torch.no_grad():
        for data in data_loader:
            data = data.to(device)
            output = model(data)
            pred = output.argmax(dim=1)

            if calculate_accuracy:
                correct += (pred == data.y).sum().item()
                total += data.y.size(0)
                total_loss += criterion(output, data.y).item()
            else:
                predictions.extend(pred.cpu().numpy())
    if calculate_accuracy:
        accuracy = correct / total
        return total_loss / len(data_loader), accuracy
    return predictions

In [5]:
def save_predictions(predictions, test_path):
    """Dump predictions to a CSV inside ``<cwd>/submission``.

    The file is named ``testset_<name>.csv`` where ``<name>`` is the name of
    the directory that contains ``test_path``. It has two columns: ``id``
    (the position of the prediction, counted from 0) and ``pred`` (the
    prediction itself). The ``submission`` folder is created on demand and
    the final path is printed.
    """
    out_dir = os.path.join(os.getcwd(), "submission")
    # e.g. ".../A/test.json.gz" -> "A"
    dataset_tag = os.path.basename(os.path.dirname(test_path))
    os.makedirs(out_dir, exist_ok=True)
    csv_path = os.path.join(out_dir, f"testset_{dataset_tag}.csv")

    frame = pd.DataFrame(
        {
            "id": [*range(len(predictions))],
            "pred": predictions,
        }
    )
    frame.to_csv(csv_path, index=False)
    print(f"Predictions saved to {csv_path}")

In [6]:
def plot_training_progress(train_losses, train_accuracies, output_dir):
    """Save a two-panel PNG of the training loss and accuracy curves.

    Args:
        train_losses: Sequence of loss values, one per epoch. Its length
            defines the x axis (epochs are numbered from 1).
        train_accuracies: Sequence of accuracy values plotted against the
            same epoch numbers.
        output_dir: Directory the image is written to as
            ``training_progress.png``; created if it does not exist.

    The figure is always closed, including when plotting or saving raises,
    so repeated calls cannot pile up open figures in pyplot's registry.
    """
    epochs = range(1, len(train_losses) + 1)
    fig, (loss_ax, accuracy_ax) = plt.subplots(1, 2, figsize=(12, 6))
    try:
        # Plot loss
        loss_ax.plot(epochs, train_losses, label="Training Loss", color="blue")
        loss_ax.set_xlabel("Epoch")
        loss_ax.set_ylabel("Loss")
        loss_ax.set_title("Training Loss per Epoch")

        # Plot accuracy
        accuracy_ax.plot(
            epochs, train_accuracies, label="Training Accuracy", color="green"
        )
        accuracy_ax.set_xlabel("Epoch")
        accuracy_ax.set_ylabel("Accuracy")
        accuracy_ax.set_title("Training Accuracy per Epoch")

        # Save the plot into output_dir
        os.makedirs(output_dir, exist_ok=True)
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, "training_progress.png"))
    finally:
        # Close this specific figure even if one of the steps above failed.
        plt.close(fig)

In [7]:
import gzip
import json
import os
from typing import List

def merge_datasets(
    version_number: int,
    dataset_parts: List[str],
    output_dir: str = "./datasets/merged",
) -> str:
    """Concatenate several gzipped JSON graph lists into one gzipped JSON list.

    Each input file is expected to hold a single JSON array. The arrays are
    written, in the order given, into
    ``<output_dir>/merged_dataset_v<version_number>.json.gz``. The resulting
    JSON text is identical to ``json.dump`` of the concatenated list.

    Parts are streamed into the output one at a time, so at most one part is
    held in memory instead of all of them together. The output is first
    written under a temporary name and renamed only once complete, so a
    failure (for example a missing or corrupt part) never leaves a truncated
    archive at the final path.

    Args:
        version_number: Number embedded in the output file name.
        dataset_parts: Paths of the ``.json.gz`` files to merge, in order.
        output_dir: Directory for the merged file; created if missing.

    Returns:
        str: Path of the merged file.

    Raises:
        OSError: If a part cannot be read or the output cannot be written.
        json.JSONDecodeError: If a part is not valid JSON.
    """
    output_filename = os.path.join(
        output_dir, f"merged_dataset_v{version_number}.json.gz"
    )
    parent_dir = os.path.dirname(output_filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    print(f"Writing merged dataset to: {output_filename}")

    partial_filename = f"{output_filename}.partial"
    total = 0
    try:
        with gzip.open(partial_filename, "wt", encoding="utf-8") as merged:
            # Emit the array by hand with json.dump's default ", " item
            # separator so the text matches a single json.dump of the list.
            merged.write("[")
            for part_path in dataset_parts:
                print(f"Reading: {part_path}")
                with gzip.open(part_path, "rt", encoding="utf-8") as f:
                    graphs = json.load(f)
                for graph in graphs:
                    if total:
                        merged.write(", ")
                    json.dump(graph, merged)
                    total += 1
                # Drop the reference before the next part is parsed so two
                # parts are never alive at the same time.
                del graphs
            merged.write("]")
        # Atomic on the same filesystem: readers see the old or the new file.
        os.replace(partial_filename, output_filename)
    except BaseException:
        if os.path.exists(partial_filename):
            os.remove(partial_filename)
        raise

    print(f"\nTotal number of graphs aggregated: {total}")
    print("🎉 Done.")
    return output_filename


In [None]:
# Build version 2 of the merged training set from the train splits of
# datasets A, B, C and D (in that order).
merge_datasets(
    version_number=2,
    dataset_parts=[f"./datasets/{subset}/train.json.gz" for subset in "ABCD"],
)


Reading: ./datasets/A/train.json.gz
Reading: ./datasets/B/train.json.gz
Reading: ./datasets/C/train.json.gz
Reading: ./datasets/D/train.json.gz
