In [1]:
import sklearn
import os
import torch
import pandas as pd
import matplotlib.pyplot as plt
import logging
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
from functools import partial
import optuna
import gc
from typing import Literal
import torch.nn.functional as F

# Load utility functions from cloned repository
from src.loadData import GraphDataset
from src.utils import set_seed
from src.models import GNN


# Set the random seed through the project helper `src.utils.set_seed`.
# It is called without arguments, so whatever default seed the helper
# defines is the one in effect (helper body not shown here — TODO confirm
# which libraries it seeds).
set_seed()



In [2]:
def add_zeros(data):
    """Give ``data`` an all-zero integer node feature vector.

    Used as a dataset transform: the object is modified in place (its ``x``
    attribute is overwritten) and the same object is returned.

    Args:
        data: Any object exposing ``num_nodes``; receives a new ``x``.

    Returns:
        The same ``data`` object, with ``x`` set to a 1-D ``torch.long``
        tensor of ``num_nodes`` zeros.
    """
    node_count = data.num_nodes
    placeholder = torch.zeros(node_count, dtype=torch.long)
    data.x = placeholder
    return data

In [None]:
def load_dataloader(
    dataset_name: Literal["A", "B", "C", "D"],
    default_batch_size=32,
    val_fraction=0.2,
    split_seed=12,
):
    """Build the training and validation loaders for one dataset.

    Reads ``./datasets/<dataset_name>/train.json.gz`` through ``GraphDataset``
    (with ``add_zeros`` as transform), splits it into a training and a
    validation subset using a dedicated seeded generator, and wraps each
    subset in a ``DataLoader``.

    Args:
        dataset_name: Name of the dataset folder under ``./datasets``.
        default_batch_size: Batch size used by both loaders.
        val_fraction: Share of the samples held out for validation. The
            validation size is ``int(val_fraction * len(dataset))``; the
            remainder goes to training. Must lie in ``[0, 1)``.
        split_seed: Seed of the generator that drives the random split, so
            the same seed reproduces the same train/validation partition.

    Returns:
        A ``(train_loader, val_loader)`` tuple. The training loader shuffles
        its samples; the validation loader does not.

    Raises:
        ValueError: If ``val_fraction`` is outside ``[0, 1)``.
    """
    if not 0.0 <= val_fraction < 1.0:
        raise ValueError(
            f"val_fraction must be in the range [0, 1), got {val_fraction!r}"
        )

    train_path = f"./datasets/{dataset_name}/train.json.gz"
    full_dataset = GraphDataset(train_path, transform=add_zeros)

    # Truncate the validation share; leftover samples stay in training.
    val_size = int(val_fraction * len(full_dataset))
    train_size = len(full_dataset) - val_size

    # A private generator keeps the split independent of the global RNG state.
    generator = torch.Generator().manual_seed(split_seed)
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, val_size], generator=generator
    )

    train_loader = DataLoader(
        train_dataset,  # type:ignore
        batch_size=default_batch_size,
        shuffle=True,
    )
    val_loader = DataLoader(
        val_dataset,  # type:ignore
        batch_size=default_batch_size,
        shuffle=False,
    )
    return train_loader, val_loader

In [None]:
# Build the loaders for dataset "A". Only the training loader is kept;
# the validation loader (second return value) is discarded here.
train,_ = load_dataloader("A")

In [9]:
# Look at the first sample of the training subset to see which attributes
# a single graph carries.
train.dataset[0]

Data(edge_index=[2, 3746], edge_attr=[3746, 7], y=[1], num_nodes=300, x=[300])

In [15]:
# Tally how many graphs of the training subset carry each class label.
# Keys appear in the order the labels are first encountered.
labels = {}
for graph in train.dataset:
    label = graph.y.item()
    labels[label] = labels.get(label, 0) + 1
labels

{1: 1531, 2: 2684, 4: 1558, 3: 1585, 5: 595, 0: 1071}

In [21]:
# Show the tally as (label, count) pairs, in the dict's insertion order.
list(labels.items())

[(1, 1531), (2, 2684), (4, 1558), (3, 1585), (5, 595), (0, 1071)]

In [27]:
# Sort the (label, count) pairs by label, turn them into an unsigned
# integer array, and pick out the rows at positions 0, 2 and 5.
torch.tensor(
    sorted(labels.items()),
    dtype=torch.uint32,
).numpy()[[0, 2, 5]]

array([[   0, 1071],
       [   2, 2684],
       [   5,  595]], dtype=uint32)