# In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, GATv2Conv, BatchNorm, global_add_pool, global_max_pool, global_mean_pool
from torch_geometric.utils import degree
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np
import h5py

# These files do not hold plain tensors: each element is read below as a graph
# object (``data.x``, ``data.edge_index``), so they are full pickles. Since
# PyTorch 2.6, ``torch.load`` defaults to ``weights_only=True``, which refuses
# to unpickle arbitrary classes, so that default is disabled explicitly here.
# SECURITY NOTE: weights_only=False executes pickle code; only load trusted files.
train_data = torch.load("./data/train.pt", weights_only=False)
test_data = torch.load("./data/test.pt", weights_only=False)

def enrich_graph(data):
    """Append a degree feature and a virtual node to a graph, in place.

    Steps:
      1. Append, as one extra feature column, the number of times each node
         appears in ``edge_index[0]`` (its out-degree under the source-row
         convention).
      2. Add one virtual node whose features are the column-wise mean of the
         degree-augmented real-node features (all zeros for an empty graph,
         where the mean would be NaN).
      3. Connect the virtual node to every real node in both directions.

    Args:
        data: graph object with ``x`` (one row of features per node) and
            ``edge_index`` (2 x E long tensor) attributes. Modified in place.

    Returns:
        The same ``data`` object. ``x`` gains one row and one column,
        ``edge_index`` gains ``2 * num_nodes`` virtual edges, and
        ``num_nodes`` is set to the new node count.
    """
    x = data.x
    edge_index = data.edge_index
    num_nodes = x.size(0)
    device = edge_index.device

    # Per-node count of appearances as a source; minlength keeps trailing
    # isolated nodes at 0. Cast to the default float dtype.
    deg = torch.bincount(edge_index[0], minlength=num_nodes)
    deg = deg.to(torch.get_default_dtype()).view(-1, 1)
    x = torch.cat([x, deg], dim=1)

    if num_nodes > 0:
        virtual_node = x.mean(dim=0, keepdim=True)
    else:
        virtual_node = x.new_zeros((1, x.size(1)))
    x = torch.cat([x, virtual_node], dim=0)

    # Row 0 = virtual node index (num_nodes), row 1 = every real node. Built on
    # edge_index's device so the concatenation below also works on GPU.
    virtual_edges = torch.stack([
        torch.full((num_nodes,), num_nodes, dtype=torch.long, device=device),
        torch.arange(num_nodes, dtype=torch.long, device=device),
    ])

    data.x = x
    data.edge_index = torch.cat([edge_index, virtual_edges, virtual_edges.flip(0)], dim=1)
    # Keep any explicitly stored node count in sync with x; a stale count
    # would mis-offset edges and the batch vector when graphs are collated.
    data.num_nodes = num_nodes + 1
    return data

train_data = [enrich_graph(data) for data in train_data]
test_data = [enrich_graph(data) for data in test_data]

# Standardise node features column-wise. The scaler is fitted on the training
# graphs only, so test-set statistics do not leak into preprocessing. The same
# fitted transform is then applied to every graph, train and test alike.
scaler = StandardScaler()
all_data = train_data + test_data
train_features = torch.cat([data.x for data in train_data], dim=0)
scaler.fit(train_features.numpy())
for data in all_data:
    data.x = torch.tensor(scaler.transform(data.x.numpy()), dtype=torch.float)

class GIN(torch.nn.Module):
    """Three-layer GIN that regresses one scalar per graph.

    Architecture:
      * ``conv1`` / ``conv2``: GINConv layers of width ``hidden_dim``, each
        followed by BatchNorm, GELU and dropout. ``conv2`` adds a residual
        connection from ``conv1``'s (dropped-out) output.
      * ``conv3``: GINConv projecting to ``head_dim``, then BatchNorm and GELU.
      * Readout: concatenation of sum-, max- and mean-pooling per graph.
      * Head: ``Linear(3 * head_dim, head_dim)`` -> GELU -> dropout ->
        ``Linear(head_dim, 1)``.

    Args:
        input_dim: number of node-feature columns.
        hidden_dim: width of the first two GIN layers (default 512).
        head_dim: width of the last GIN layer and of the MLP head (default 128).
        dropout: dropout probability used after conv1, conv2 and lin1.
    """

    def __init__(self, input_dim, hidden_dim=512, head_dim=128, dropout=0.45388):
        super().__init__()
        # All MLPs are created before they are wrapped in GINConv. This
        # construction order fixes the RNG draws used for weight
        # initialisation, so keep it stable for reproducibility.
        nn1 = self._mlp(input_dim, hidden_dim)
        nn2 = self._mlp(hidden_dim, hidden_dim)
        nn3 = self._mlp(hidden_dim, head_dim)
        self.conv1 = GINConv(nn1)
        self.bn1 = BatchNorm(hidden_dim)
        self.conv2 = GINConv(nn2)
        self.bn2 = BatchNorm(hidden_dim)
        self.conv3 = GINConv(nn3)
        self.bn3 = BatchNorm(head_dim)
        # The readout concatenates three pooled vectors (add, max, mean).
        self.lin1 = torch.nn.Linear(head_dim * 3, head_dim)
        self.lin2 = torch.nn.Linear(head_dim, 1)
        self.dropout = torch.nn.Dropout(dropout)
        self.act = torch.nn.GELU()

    @staticmethod
    def _mlp(in_dim, out_dim):
        """Return the ``Linear -> ReLU -> Linear`` MLP used inside each GINConv."""
        return torch.nn.Sequential(
            torch.nn.Linear(in_dim, out_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(out_dim, out_dim),
        )

    def forward(self, x, edge_index, batch):
        """Return one prediction per graph as a 1-D tensor.

        Args:
            x: node features with ``input_dim`` columns.
            edge_index: graph connectivity passed to each GINConv.
            batch: per-node graph-assignment vector (the training code passes
                ``batch.batch`` from a PyG DataLoader), used for pooling.
        """
        x = self.act(self.bn1(self.conv1(x, edge_index)))
        x = self.dropout(x)
        # The residual is taken after dropout and added after conv2's activation.
        res = x
        x = self.act(self.bn2(self.conv2(x, edge_index))) + res
        x = self.dropout(x)
        x = self.act(self.bn3(self.conv3(x, edge_index)))
        x_add = global_add_pool(x, batch)
        x_max = global_max_pool(x, batch)
        x_mean = global_mean_pool(x, batch)
        x = torch.cat([x_add, x_max, x_mean], dim=1)
        x = self.act(self.lin1(x))
        x = self.dropout(x)
        return self.lin2(x).view(-1)

# Use the GPU when present, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_dim = train_data[0].x.shape[1]
kf = KFold(n_splits=5, shuffle=True, random_state=42)
all_preds = []
epoch_logs = []

# The test set is the same for every fold, so its loader is built once.
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)


def _train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader`` on the module-level ``device``.

    Returns the mean per-graph L1 loss. Each batch's mean loss is weighted by
    ``batch.num_graphs``, so a short final batch is not over-counted.
    """
    model.train()
    total_loss = 0.0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = F.l1_loss(out, batch.y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(loader.dataset)


@torch.no_grad()
def _evaluate(model, loader):
    """Return the mean per-graph L1 loss of ``model`` on ``loader`` (eval mode, no grad)."""
    model.eval()
    total_loss = 0.0
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)
        total_loss += F.l1_loss(out, batch.y.view(-1)).item() * batch.num_graphs
    return total_loss / len(loader.dataset)


@torch.no_grad()
def _predict(model, loader):
    """Return ``model``'s predictions over ``loader``, concatenated on CPU in loader order."""
    model.eval()
    preds = []
    for batch in loader:
        batch = batch.to(device)
        preds.append(model(batch.x, batch.edge_index, batch.batch).cpu())
    return torch.cat(preds)


for fold, (train_idx, val_idx) in enumerate(kf.split(train_data)):
    print(f"Fold {fold+1}")
    train_split = [train_data[i] for i in train_idx]
    val_split = [train_data[i] for i in val_idx]
    train_loader = DataLoader(train_split, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_split, batch_size=32, shuffle=False)

    model = GIN(input_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001133, weight_decay=1.66e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)
    best_val_loss = float('inf')
    best_state = None
    patience_counter = 0

    train_loss_history, val_loss_history = [], []

    for epoch in range(1, 101):
        train_loss = _train_one_epoch(model, train_loader, optimizer)
        train_loss_history.append(train_loss)

        val_loss = _evaluate(model, val_loader)
        val_loss_history.append(val_loss)
        scheduler.step(val_loss)

        print(f"Epoch {epoch:03d}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Keep an in-memory copy, so restoring below cannot pick up a
            # stale checkpoint left on disk by an earlier run.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(model.state_dict(), f"best_model_fold{fold}.pt")
        else:
            patience_counter += 1

        # Stop after 10 consecutive epochs without a validation improvement.
        if patience_counter >= 10:
            print("Early stopping triggered.")
            break

    epoch_logs.append((train_loss_history, val_loss_history))

    # Restore the best-validation weights before exporting and predicting, so
    # the .h5 export and the test predictions come from the same model.
    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print(f"Warning: validation loss never improved in fold {fold+1}; using final weights.")

    # Export the full state dict: parameters plus BatchNorm running statistics,
    # which named_parameters() would omit.
    with h5py.File(f"model_fold{fold}.h5", 'w') as hf:
        for name, tensor in model.state_dict().items():
            hf.create_dataset(name, data=tensor.detach().cpu().numpy())

    all_preds.append(_predict(model, test_loader))

# Overlay each fold's train and validation loss curves on the current pyplot
# axes, then write the figure to disk and release it.
for fold_number, (train_curve, val_curve) in enumerate(epoch_logs, start=1):
    plt.plot(train_curve, label=f"Fold {fold_number} Train")
    plt.plot(val_curve, label=f"Fold {fold_number} Val")
plt.gca().set(xlabel="Epoch", ylabel="Loss", title="Training and Validation Loss per Fold")
plt.legend()
plt.savefig("training_validation_loss.png")
plt.close()

# Ensemble: element-wise mean of the per-fold test predictions.
fold_stack = torch.stack(all_preds)
final_preds = fold_stack.mean(dim=0).numpy()
# Fill the template's "labels" column with the ensembled predictions and save.
sample_submission = pd.read_csv("./data/sample_submission.csv").assign(labels=final_preds)
sample_submission.to_csv("submission.csv", index=False)
print("Submission saved to submission.csv")

