In [None]:
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
import torch.nn.functional as F
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
import numpy as np
from torch_geometric.nn import GCNConv, global_mean_pool, GINConv, global_add_pool
from torch_geometric.nn import MLP as pyg_MLP
import argparse
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from torch.optim import Adam
from torch_geometric.loader import DataLoader
from torch.utils.data import Dataset
import os
from torch_geometric.nn import GCNConv, GINConv, GATConv, global_mean_pool, global_add_pool
from torch.nn import Linear
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.utils.data import random_split
from tqdm import tqdm

# File descriptor 1 is the process's standard output.
STDOUT_FD = 1
# True when stdout is not attached to a terminal (for example under nohup or
# when output is redirected to a file); the training/evaluation loops pass it
# as tqdm's `disable` flag so progress bars are not written in that case.
is_non_interactive = not os.isatty(STDOUT_FD)

In [None]:
class GCN(nn.Module):
    """Two-layer GCN graph classifier with a mean-pooled readout.

    Args:
        in_feats: Number of input node features. Also used as the width of
            the intermediate embedding produced by ``lin1``.
        hidden_size: Output width of both ``GCNConv`` layers.
        out_feats: Number of outputs produced by ``lin2`` (the logits).
        dropout: Dropout probability applied to the pooled graph embedding
            while the module is in training mode. Defaults to 0.3, the value
            that was previously hard-coded.

    ``forward`` returns a tuple ``(logits, x_fea)``, where ``x_fea`` is the
    ``lin1`` output (before the ReLU) with width ``in_feats``.
    """

    def __init__(self, in_feats, hidden_size, out_feats, dropout=0.3):
        super().__init__()
        self.conv1 = GCNConv(in_feats, hidden_size)
        self.conv2 = GCNConv(hidden_size, hidden_size)
        # lin1 maps the pooled embedding to `in_feats` dims; that output is
        # returned from forward() as the feature vector `x_fea`.
        self.lin1 = Linear(hidden_size, in_feats)
        self.lin2 = Linear(in_feats, out_feats)
        self.dropout = dropout

    def forward(self, data):
        """Compute graph-level logits and features for a batch of graphs.

        Args:
            data: Batch object exposing ``x``, ``edge_index``, ``edge_attr``
                and ``batch`` (graph assignment of each node).

        Returns:
            ``(logits, x_fea)``: per-graph outputs of ``lin2`` and ``lin1``.
        """
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
        # NOTE(review): edge_attr is used as GCNConv's edge_weight, which PyG
        # expects to hold one scalar per edge -- confirm the dataset stores it
        # as a 1-D tensor.
        x = F.relu(self.conv1(x, edge_index, edge_weight))
        # Defensive float cast kept from the original implementation.
        x = F.relu(self.conv2(x.float(), edge_index, edge_weight))
        # Average node embeddings into one vector per graph.
        x = global_mean_pool(x, data.batch)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x_fea = self.lin1(x)
        logits = self.lin2(F.relu(x_fea))
        return logits, x_fea

In [None]:
def load_data(path='../data/ppmi_curv_pyg.pth', batchsize=16, num_workers=8, seed=None):
    """Load a saved dataset and wrap an 80/10/10 random split in data loaders.

    Args:
        path: File written with ``torch.save`` containing the dataset
            (presumably a list of PyG ``Data`` graphs -- TODO confirm).
        batchsize: Batch size used by all three loaders.
        num_workers: Worker processes per loader (previously fixed at 8).
        seed: Optional integer seed for the split. ``None`` keeps the old
            behavior of drawing from torch's global RNG; an int makes the
            train/val/test partition reproducible.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles; validation and test batches are produced in a fixed order.
    """
    # The file holds pickled Python objects rather than plain tensors, so
    # torch>=2.6's weights_only=True default would refuse to load it. Being
    # explicit keeps older and newer torch consistent. Unpickling can execute
    # arbitrary code, so only load trusted files here.
    dataset = torch.load(path, weights_only=False)
    generator = torch.default_generator if seed is None else torch.Generator().manual_seed(seed)
    train_set, val_set, test_set = random_split(dataset, [0.8, 0.1, 0.1], generator=generator)
    print(f"Train set: {len(train_set)}, Val set: {len(val_set)}, Test set: {len(test_set)}")
    loader_kwargs = dict(batch_size=batchsize, num_workers=num_workers)
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    # Evaluation metrics do not depend on order, so avoid needless shuffling.
    val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)
    return train_loader, val_loader, test_loader

In [None]:
def train(model, loader, criterion, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module called as ``model(data)``. It may return either a logits
            tensor or a tuple whose first element is the logits (``GCN``
            returns ``(logits, x_fea)``); only the logits reach ``criterion``.
        loader: Iterable of batches exposing ``.to(device)`` and ``.y``.
        criterion: Loss function called as ``criterion(logits, data.y)``.
        optimizer: Optimizer stepped once per batch.
        device: Target device for each batch.

    Returns:
        Average of ``loss.item()`` over batches, or ``nan`` if the loader
        yielded no batches (instead of raising ZeroDivisionError).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for data in tqdm(loader, desc="Training Batches", leave=False, dynamic_ncols=True, disable=is_non_interactive):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        # Models such as GCN return (logits, features); the loss needs logits.
        if isinstance(out, tuple):
            out = out[0]
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else float('nan')

def test(model, loader, device):
    model.eval()
    correct = 0
    preds = []
    gts = []
    with torch.no_grad():
        for data in tqdm(loader, desc="Testing Batches", leave=False, dynamic_ncols=True, disable=is_non_interactive):
            data = data.to(device)
            out = model(data)
            pred = out.argmax(dim=-1)
            correct += int((pred == data.y).sum())
            preds.append(pred.cpu().numpy())
            gts.append(data.y.cpu().numpy())
    preds = np.concatenate(preds, axis=0)
    gts = np.concatenate(gts, axis=0)
    accuracy = accuracy_score(gts, preds)
    precision = precision_score(gts, preds, average='weighted', zero_division=0)
    recall = recall_score(gts, preds, average='weighted', zero_division=0)
    f1 = f1_score(gts, preds, average='weighted', zero_division=0)
    return accuracy, precision, recall, f1