In [None]:
import os
import pickle

import pandas as pd
import networkx as nx

from torch.utils.data import Dataset, Subset
from torch_geometric.data import Data, DataLoader
from torch import nn

from sklearn.metrics import r2_score
from sklearn.model_selection import GroupShuffleSplit

import numpy as np


In [None]:
import ray, optuna
from ray import tune, air
from ray.tune.search.optuna import OptunaSearch
from ray.tune.schedulers import ASHAScheduler

In [None]:
from MetaNet import MetaNet, BioDegDataset       # your module
from train import build_loaders   

In [None]:
# Build the graph dataset from the cleaned CSV, rooted in the current directory.
# Alternative (class not imported in this notebook):
#   data = BioDegDatasetCached('biodataset_processed.pt')
csv_file, dataset_root = 'df_cleaned.csv', '.'
data = BioDegDataset(csv_file, dataset_root)

In [None]:
import torch

# Sanity check: can this kernel see a CUDA device?
cuda_visible = torch.cuda.is_available()
cuda_visible

In [None]:
# Number of visible GPUs; the num_gpus passed to ray.init should not exceed it.
gpu_count = torch.cuda.device_count()
gpu_count

In [None]:
# JUPYTER CELL 3
def train_one(config):
    #gpu_id = int(ray.get_gpu_ids()[0])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1── loaders (with scaling applied inside)
    train_ld, val_ld, test_ld= build_loaders(data,
                                          val_mols=2,
                                          test_mols=2,
                                          seed=42,
                                          bs_train=32,
                                          bs_eval=64)

    # 2── model
    model = MetaNet(hidden=config["hidden"],
                    n_layers=config["layers"]).to(device)
    opt   = torch.optim.AdamW(model.parameters(),
                              lr=config["lr"],
                              weight_decay=5e-4)
    lossf = torch.nn.SmoothL1Loss()

    best_val = -1e9
    for epoch in range(500):
        # ----- train ------------------------------------------------------
        model.train()
        for batch in train_ld:
            batch = batch.to(device)
            opt.zero_grad()
            loss = lossf(model(batch), batch.y.squeeze())
            loss.backward(); opt.step()
       # ----- validation R² --------------------------------------------
        model.eval(); y_val, p_val = [], []
        with torch.no_grad():
            for batch in val_ld:
                batch = batch.to(device)
                y_val.append(batch.y.cpu())
                p_val.append(model(batch).cpu())
        yv, pv = torch.cat(y_val), torch.cat(p_val)
        val_r2 = 1 - ((yv - pv) ** 2).sum() / ((yv - yv.mean()) ** 2).sum()
        best_val = max(best_val, val_r2.item())

        # ----- test R²  (logged, not optimised) --------------------------
        y_test, p_test = [], []
        with torch.no_grad():
            for batch in test_ld:
                batch = batch.to(device)
                y_test.append(batch.y.cpu())
                p_test.append(model(batch).cpu())
        yt, pt = torch.cat(y_test), torch.cat(p_test)
        test_r2 = 1 - ((yt - pt) ** 2).sum() / ((yt - yt.mean()) ** 2).sum()
            
            

        # report to Tune (for ASHA pruning & Optuna sampler)
        tune.report({"val_r2": val_r2.item(),"test_r2": test_r2.item(), "epoch": epoch})

        best = max(best_val, val_r2.item())
    return best


In [None]:
# Start from a clean Ray runtime (shutdown is harmless if nothing is running),
# then advertise 4 GPUs to the Ray scheduler.
# NOTE(review): 4 is hard-coded. Keep it equal to torch.cuda.device_count() on
# the target host; otherwise Ray presumably hands trials GPU ids that do not exist.
ray.shutdown()
ray.init(ignore_reinit_error=True, num_gpus=4)


In [None]:
# Optuna's TPE sampler proposes configs (seeded so the sampling is reproducible).
# ASHA prunes trials whose val_r2 lags. Its time unit is one tune.report call
# (one epoch in train_one): no pruning before 30, hard stop at 500, and each
# rung keeps the better half.
TUNE_METRIC, TUNE_MODE = "val_r2", "max"

search_alg = OptunaSearch(
    metric=TUNE_METRIC,
    mode=TUNE_MODE,
    sampler=optuna.samplers.TPESampler(seed=42),
)

scheduler = ASHAScheduler(
    metric=TUNE_METRIC,
    mode=TUNE_MODE,
    max_t=500,
    grace_period=30,
    reduction_factor=2,
)

In [None]:
# Reserve one whole GPU per trial, so concurrent trials never share a device.
trainable = tune.with_resources(train_one, resources={"gpu": 1})

In [None]:
# Search space sampled by Optuna for each trial.
param_space = {
    "hidden": tune.choice([32, 64, 128, 256]),
    "layers": tune.choice([3, 4, 5, 6, 8, 10]),
    "lr": tune.loguniform(1e-4, 3e-3),
    # optional per-trial init seed:
    # "trial_seed": tune.randint(0, 2**31 - 1),
}

# Up to 500 trials, proposed by search_alg and pruned by scheduler.
# NOTE(review): max_concurrent_trials=4 is meant to match the GPU count passed to
# ray.init; it is hard-coded rather than derived from the host.
tune_config = tune.TuneConfig(
    num_samples=500,
    search_alg=search_alg,
    scheduler=scheduler,
    max_concurrent_trials=4,
    reuse_actors=True,
)

# Results are written under ./ray_out/metanet_demo_622.
run_config = air.RunConfig(
    name="metanet_demo_622",
    storage_path=os.path.abspath("ray_out"),
)

tuner = tune.Tuner(
    trainable,
    tune_config=tune_config,
    run_config=run_config,
    param_space=param_space,
)

In [None]:
# Run the search. Trials report val_r2 every epoch and ASHA stops weak trials
# early, so a trial's *last* val_r2 is not necessarily its best. Rank trials over
# every reported epoch (scope="all"), then read the best value from that trial's
# full metric history. `.metrics` would only give the last epoch.
analysis = tuner.fit()
best_result = analysis.get_best_result(metric="val_r2", mode="max", scope="all")
best_val_r2 = best_result.metrics_dataframe["val_r2"].max()
print("Best R²:", best_val_r2)
print("Best config:", best_result.config)