In [2]:
import torch
from torch.nn import MSELoss
from torch.optim import Adam
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SchNet, GCNConv, global_mean_pool
from torch import nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from torch.utils.data import random_split
from torch_geometric.datasets import QM9

In [3]:
from ray import tune
from ray.tune.schedulers import ASHAScheduler

## PREP

In [4]:
#root='data/QM9'
DATA_DIR = 'data/QM9'

In [5]:
class BasePreprocessor:
    """
    Base class for data preprocessing, handling common tasks like loading, 
    subsetting, splitting, and configuration.
    """
    def __init__(self, dataset_cls=QM9, root=DATA_DIR,  #data/QM9
                 transform=None, target=0, val_ratio=0.2,
                 seed=42, subset=None):

        self.dataset_cls = dataset_cls
        self.root = root
        self.transform = transform
        self.target = target
        self.val_ratio = val_ratio
        self.seed = seed
        self.subset = subset
        # Initialize internal storage for the dataset
        self._dataset = None

    # -------------------------
    # Loading dataset (Lazy-loaded)
    # -------------------------
    def _load_dataset(self):
        """Loads the dataset once and caches it."""
        if self._dataset is None:
            # Load the dataset (expensive I/O operation)
            dataset = self.dataset_cls(root=self.root, transform=self.transform)

            if self.subset:
                # Apply subset slicing
                dataset = dataset[:self.subset]
            
            self._dataset = dataset
            
        return self._dataset

    # -------------------------
    # Split train/val
    # -------------------------
    def split(self, processed):
        """Splits the processed dataset into train and validation sets."""
        n_val = int(len(processed) * self.val_ratio)
        # Ensure n_train is the remainder to cover the whole dataset
        n_train = len(processed) - n_val 

        gen = torch.Generator().manual_seed(self.seed)
        return random_split(processed, [n_train, n_val], generator=gen)

    # -------------------------
    # Abstract/Helper
    # -------------------------
    def _format_dataset(self, dataset, is_inference):
        """
        Abstract method for model-specific data formatting and target slicing.
        Must be implemented by subclasses.
          For MLP: extract z, and a column of y
          For GCN: extract z, pos, edge_index, y
          For SchNet: extract all fields
        """
        raise NotImplementedError

    # -------------------------
    # Concrete Workflow Methods (Handle Redundancy)
    # -------------------------
    
    def preprocess(self):
        """Workflow for train/validation: Load -> Format -> Split."""
        dataset = self._load_dataset()
        processed = self._format_dataset(dataset, is_inference=False)
        return self.split(processed)

    def preprocess_test(self):
        """Workflow for test: Load -> Format -> Return."""
        dataset = self._load_dataset()
        return self._format_dataset(dataset, is_inference=False)

    def preprocess_inference(self):
        """Workflow for inference: Load -> Format (no target) -> Return."""
        dataset = self._load_dataset()
        return self._format_dataset(dataset, is_inference=True)
    
    
    
class PreprocessorRegistry:
    _registry = {}

    @classmethod
    def register(cls, name):
        def decorator(prep_cls):
            cls._registry[name] = prep_cls
            return prep_cls
        return decorator

    @classmethod
    def create(cls, name, **kwargs):
        if name not in cls._registry:
            raise ValueError(f"Unknown preprocessor: {name}")
        return cls._registry[name](**kwargs)
    
    
@PreprocessorRegistry.register("mlp")
class MLPPreprocessor(BasePreprocessor):
    def __init__(self, **kwargs):
        # MLP doesn't need graph transforms
        super().__init__(transform=None, **kwargs)


    def _format_dataset(self, dataset, is_inference):
        """Extracts atomic numbers (z) and (optionally) the target (y)."""

        target_col = self.target

        out = []

        for d in dataset:
            d_new = d.clone()

            # remove unwanted fields
            for field in ["x", "edge_attr", "edge_index", "pos","name", "smiles", "idx"]:
                if hasattr(d_new, field):
                    delattr(d_new, field)

            # Explicitly set num_nodes
            d_new.num_nodes = d.z.size(0)

            # Handle inference/no-inference target
            if not is_inference:
                d_new.y = d.y[:, target_col].unsqueeze(1)
            else:
                if hasattr(d_new, "y"):
                    delattr(d_new, "y")

            out.append(d_new)

        return out
    
    
def loaders(batch_size, subset=10):

	model_type = 'mlp'


	prep = PreprocessorRegistry.create(
	    model_type,
	    target=0,
	    root=DATA_DIR,
	    subset=subset,
	)

	train_ds, val_ds = prep.preprocess()
	train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
	val_loader = DataLoader(val_ds, batch_size=batch_size)
	return train_loader, val_loader

## first model

In [None]:
class SimpleMLP(nn.Module):
    def __init__(self, num_atom_types=100, hidden=64):
        super().__init__()
        self.emb = nn.Embedding(num_atom_types, hidden)
        self.fc = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1)
        )
    def forward(self, batch):
        x = self.emb(batch.z)       # [num_nodes, hidden]
        x = global_mean_pool(x, batch.batch)  # [num_graphs, hidden]
        out = self.fc(x)            # [num_graphs, 1]
        return out



def run_epoch(loader, model, criterion, optimizer=None):
    model.train() if optimizer else model.eval()
    total_loss = 0

    for batch in loader:
        batch = batch.to(device)
        out = model(batch)
        pred = out.view(-1)
        target = batch.y.view(-1)
        loss = criterion(pred, target)

        if optimizer:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        total_loss += loss.item() * batch.num_graphs

    return total_loss / len(loader.dataset)



model_mlp = SimpleMLP().to(device)
optimizer = Adam(model_mlp.parameters(), lr=1e-3)
criterion = MSELoss()

# loaders
train_loader, val_loader = loaders(32)


# Train for a few epochs
for epoch in range(3):
    train_loss = run_epoch(train_loader, model_mlp, criterion, optimizer)
    val_loss = run_epoch(val_loader, model_mlp, criterion)
    print(f"[MLP] Epoch {epoch} | Train {train_loss:.4f} | Val {val_loss:.4f}")

## <font color='purple'> Tuning </font>

In [None]:


def train_ray(config):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # small subset for speed
    subset = config["subset"]

    train_loader, val_loader = loaders(
        batch_size=config["batch_size"],
        subset=subset
    )

    model = SimpleMLP(
        num_atom_types=100,
        hidden=config["hidden"]
    ).to(device)

    optimizer = Adam(model.parameters(), lr=config["lr"])
    criterion = MSELoss()

    for epoch in range(config["epochs"]):
        train_loss = run_epoch(train_loader, model, criterion, optimizer)
        val_loss = run_epoch(val_loader, model, criterion)

        # report to Ray Tune
        tune.report({"val_loss":val_loss})


        
        
search_space = {
    "hidden": tune.choice([32, 64, 128]),
    "lr": tune.loguniform(1e-4, 1e-2),
    "batch_size": tune.choice([16, 32]),
    "epochs": 3,
    "subset": 10,        # only 500 samples → very fast tuning
}


scheduler = ASHAScheduler(
    max_t=5,
    grace_period=1,
    reduction_factor=2,
)

tuner = tune.Tuner(
    train_ray,
    tune_config=tune.TuneConfig(
        metric="val_loss",
        mode="min",
        scheduler=scheduler,
        num_samples=2,   # small number of trials
    ),
    param_space=search_space,
)

results = tuner.fit()
best = results.get_best_result("val_loss", "min")
print("Best config:", best.config)


In [None]:
results

## GCN

In [15]:
from torch_geometric.transforms import RadiusGraph

In [20]:
def run_epoch(loader, model, criterion, optimizer=None):
    model.train() if optimizer else model.eval()
    total_loss = 0
    for batch in loader:
        batch = batch.to(device)
        out = model(batch)
        pred = out.squeeze(-1)
        target = batch.y.squeeze(-1)
        loss = criterion(pred, target)
        if optimizer:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(loader.dataset)

In [21]:
@PreprocessorRegistry.register("gcn")
class GCNPreprocessor(BasePreprocessor):
    def __init__(self, **kwargs):
        # Get user-supplied transform or use the default RadiusGraph
        radius = kwargs.pop("radius", 1.5)   # extracted only for GCN
        transform = RadiusGraph(r=radius)
        super().__init__(transform=transform, **kwargs)



    def _format_dataset(self, dataset, is_inference):
        """Extracts atomic numbers (z) and (optionally) the target (y)."""

        target_col = self.target

        out = []

        for d in dataset:
            d_new = d.clone()

            # remove unwanted fields
            for field in ["x", "edge_attr", "name", "smiles", "idx"]:
                if hasattr(d_new, field):
                    delattr(d_new, field)

            # Handle inference/no-inference target
            if not is_inference:
                d_new.y = d.y[:, target_col].unsqueeze(1)
            else:
                if hasattr(d_new, "y"):
                    delattr(d_new, "y")

            out.append(d_new)

        return out


In [22]:
class SimpleGCN(nn.Module):
    def __init__(self, hidden=64, num_atom_types=100):
        super().__init__()
        self.emb = nn.Embedding(num_atom_types, hidden)
        self.conv1 = GCNConv(hidden, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.fc = nn.Linear(hidden, 1)

    def forward(self, batch):
        x = self.emb(batch.z)           # [num_nodes, hidden]
        x = self.conv1(x, batch.edge_index) # using edge_index here
        x = torch.relu(x)
        x = self.conv2(x, batch.edge_index)
        x = torch.relu(x)
        x = global_mean_pool(x, batch.batch)  # [num_graphs, hidden]
        out = self.fc(x)                        # [num_graphs, 1]
        return out


In [23]:
def loadersGCN(batch_size, subset=10):

	model_type = 'gcn'


	prep = PreprocessorRegistry.create(
	    model_type,
	    target=0,
	    root=DATA_DIR,
	    subset=subset,
	)
    

	train_ds, val_ds = prep.preprocess() # -> this result goers to ray.put
    
    
	train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
	val_loader = DataLoader(val_ds, batch_size=batch_size)
	return train_loader, val_loader

In [24]:


def train_ray(config):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # small subset for speed
    subset = config["subset"]

    train_loader, val_loader = loadersGCN(
        batch_size=config["batch_size"],
        subset=subset
    )

    model = SimpleGCN(
        num_atom_types=100,
        hidden=config["hidden"]
    ).to(device)

    optimizer = Adam(model.parameters(), lr=config["lr"])
    criterion = MSELoss()

    for epoch in range(config["epochs"]):
        train_loss = run_epoch(train_loader, model, criterion, optimizer)
        val_loss = run_epoch(val_loader, model, criterion)

        # report to Ray Tune
        tune.report({"val_loss":val_loss})


        
        
search_space = {
    "hidden": tune.choice([32, 64, 128]),
    "lr": tune.loguniform(1e-4, 1e-2),
    "batch_size": tune.choice([16, 32]),
    "epochs": 3,
    "subset": 10,        # only 500 samples → very fast tuning
}


scheduler = ASHAScheduler(
    max_t=5,
    grace_period=1,
    reduction_factor=2,
)

tuner = tune.Tuner(
    train_ray,
    tune_config=tune.TuneConfig(
        metric="val_loss",
        mode="min",
        scheduler=scheduler,
        num_samples=2,   # small number of trials
    ),
    param_space=search_space,
)

results = tuner.fit()
best = results.get_best_result("val_loss", "min")
print("Best config:", best.config)


0,1
Current time:,2025-11-21 16:38:22
Running for:,00:02:43.00
Memory:,4.2/9.6 GiB

Trial name,status,loc,batch_size,hidden,lr,iter,total time (s),val_loss
train_ray_42844_00000,TERMINATED,200.40.20.243:10414,32,64,0.0014255,3,153.757,0.578099
train_ray_42844_00001,TERMINATED,200.40.20.243:10415,32,64,0.00252536,2,159.178,0.782058


[36m(train_ray pid=10415)[0m Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip
[36m(train_ray pid=10414)[0m Extracting data/QM9/raw/qm9.zip
[36m(train_ray pid=10414)[0m Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip
[36m(train_ray pid=10414)[0m Downloading https://ndownloader.figshare.com/files/3195404
[36m(train_ray pid=10414)[0m Processing...
  0%|          | 0/133885 [00:00<?, ?it/s]
  0%|          | 141/133885 [00:00<01:35, 1406.02it/s]
  0%|          | 287/133885 [00:00<01:33, 1433.44it/s]
  0%|          | 431/133885 [00:00<01:38, 1358.22it/s]
  0%|          | 568/133885 [00:00<01:38, 1349.77it/s]
[36m(train_ray pid=10415)[0m Extracting data/QM9/raw/qm9.zip
  1%|          | 704/133885 [00:00<01:42, 1302.98it/s]
  1%|          | 845/133885 [00:00<01:39, 1335.63it/s]
  1%|          | 989/133885 [00:00<01:37, 1366.86it/s]
  1%|          | 1127/133885 [00:00<01:37, 1357.97it/s]
  1%|    

 94%|█████████▍| 126416/133885 [02:02<00:06, 1162.27it/s]
 95%|█████████▍| 126533/133885 [02:03<00:06, 1146.34it/s]
 95%|█████████▍| 126648/133885 [02:03<00:06, 1142.50it/s]
 95%|█████████▍| 126764/133885 [02:03<00:06, 1145.03it/s]
 95%|█████████▍| 126886/133885 [02:03<00:06, 1164.64it/s]
 95%|█████████▍| 127003/133885 [02:03<00:06, 1146.97it/s]
 95%|█████████▍| 127118/133885 [02:03<00:05, 1146.74it/s]
 95%|█████████▌| 127242/133885 [02:03<00:05, 1173.64it/s]
 95%|█████████▌| 127367/133885 [02:03<00:05, 1194.18it/s]
 95%|█████████▌| 127487/133885 [02:03<00:05, 1185.13it/s]
 95%|█████████▌| 127619/133885 [02:04<00:05, 1223.85it/s]
 95%|█████████▌| 127743/133885 [02:04<00:05, 1226.64it/s]
 96%|█████████▌| 127871/133885 [02:04<00:04, 1240.77it/s]
 96%|█████████▌| 128004/133885 [02:04<00:04, 1266.31it/s]
 96%|█████████▌| 128131/133885 [02:04<00:04, 1238.03it/s]
 96%|█████████▌| 128260/133885 [02:04<00:04, 1252.99it/s]
 91%|█████████▏| 122239/133885 [01:59<00:11, 976.14it/s]
 96%|█████████▌

 99%|█████████▉| 132348/133885 [02:09<00:01, 1390.11it/s]
 99%|█████████▉| 132489/133885 [02:09<00:01, 1327.77it/s]
 99%|█████████▉| 132624/133885 [02:09<00:00, 1272.33it/s]
 99%|█████████▉| 132757/133885 [02:09<00:00, 1285.85it/s]
 99%|█████████▉| 132902/133885 [02:10<00:00, 1331.22it/s]
 99%|█████████▉| 133061/133885 [02:10<00:00, 1404.72it/s]
 99%|█████████▉| 133203/133885 [02:10<00:00, 1375.99it/s]
100%|█████████▉| 133358/133885 [02:10<00:00, 1423.47it/s]
100%|█████████▉| 133505/133885 [02:10<00:00, 1434.11it/s]
100%|█████████▉| 133649/133885 [02:10<00:00, 1366.12it/s]
100%|█████████▉| 133787/133885 [02:10<00:00, 1298.31it/s]
100%|██████████| 133885/133885 [02:10<00:00, 1024.00it/s]
[36m(train_ray pid=10415)[0m Done!
2025-11-21 16:38:22,289	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/marcos/ray_results/train_ray_2025-11-21_16-35-39' in 0.0040s.
2025-11-21 16:38:22,295	INFO tune.py:1041 -- Total run time: 163.03 seconds (162.99 

Best config: {'hidden': 64, 'lr': 0.001425497839559118, 'batch_size': 32, 'epochs': 3, 'subset': 10}


In [25]:
results

ResultGrid<[
  Result(
    metrics={'val_loss': 0.5780990719795227},
    path='/home/marcos/ray_results/train_ray_2025-11-21_16-35-39/train_ray_42844_00000_0_batch_size=32,hidden=64,lr=0.0014_2025-11-21_16-35-39',
    filesystem='local',
    checkpoint=None
  ),
  Result(
    metrics={'val_loss': 0.7820579409599304},
    path='/home/marcos/ray_results/train_ray_2025-11-21_16-35-39/train_ray_42844_00001_1_batch_size=32,hidden=64,lr=0.0025_2025-11-21_16-35-39',
    filesystem='local',
    checkpoint=None
  )
]>

### Schnet

In [26]:
@PreprocessorRegistry.register("schnet")
class SchNetPreprocessor(BasePreprocessor):
    def __init__(self, **kwargs):
        # SchNet often handles connectivity internally (e.g., using radius graph
        # and pos), so we might not need an explicit transform here, depending on 
        # how the model is implemented. Sticking with None for consistency.
        cutoff = kwargs.pop("cutoff", 10.0)
        super().__init__(transform=None, **kwargs)
        
    
    def _format_dataset(self, dataset, is_inference):
        target_col = self.target

        processed = []
        for d in dataset:
            d_new = d.clone()  # <-- keeps all fields: pos, z, edge_index, etc

            if is_inference:
                # remove y safely if exists
                if hasattr(d_new, "y"):
                    del d_new.y
            else:
                d_new.y = d.y[:, target_col].unsqueeze(1)

            processed.append(d_new)

        return processed

In [38]:
def loaders_schnet(batch_size, subset=10):

	model_type = 'schnet'


	prep = PreprocessorRegistry.create(
	    model_type,
	    target=0,
	    root=DATA_DIR,
	    subset=subset,
	)
    

	train_ds, val_ds = prep.preprocess() # -> this result goers to ray.put
    
    
	train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
	val_loader = DataLoader(val_ds, batch_size=batch_size)
	return train_loader, val_loader

In [34]:
import torch
import torch.nn as nn
from torch_geometric.nn import SchNet

class TunableSchNet(nn.Module):
    def __init__(self, hidden_channels, num_filters, num_interactions,
                 num_gaussians, cutoff):
        super().__init__()

        self.model = SchNet(
            hidden_channels=hidden_channels,
            num_filters=num_filters,
            num_interactions=num_interactions,
            num_gaussians=num_gaussians,
            cutoff=cutoff,
            readout='add'
        )

        # IMPORTANT: output dim = hidden_channels
        self.regressor = nn.Linear(hidden_channels, 1)

    def forward(self, z, pos, batch):
        x = self.model(z, pos, batch)   # graph embeddings
        return self.regressor(x).squeeze(-1)

    
    
def run_epoch(loader, model, criterion, optimizer=None, train=True):
    if train:
        model.train()
    else:
        model.eval()

    total = 0
    for batch in loader:
        batch = batch.to(model.regressor.weight.device)

        pred = model(batch.z, batch.pos, batch.batch)
        loss = criterion(pred, batch.y)

        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        total += float(loss)

    return total / len(loader)



def train_ray(config):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # small subset for speed
    subset = config["subset"]

    train_loader, val_loader = loadersSchNet(
        batch_size=config["batch_size"],
        subset=subset
    )

    model = TunableSchNet(
        hidden_channels=config["hidden_channels"],
        num_filters=config["num_filters"],
        num_interactions=config["num_interactions"],
        num_gaussians=config["num_gaussians"],
        cutoff=config["cutoff"]
    ).to(device)
    
    

    
    optimizer = Adam(list(model.parameters()), lr=config["lr"])
    criterion = MSELoss()



    for epoch in range(1, 11):
        train_loss = run_epoch(train_loader, model, criterion, optimizer, train=True)
        val_loss = run_epoch(val_loader, model, criterion, optimizer, train=False)
        print(f"Epoch {epoch:02d} | Train: {train_loss:.4f} | Val: {val_loss:.4f}")

        # report to Ray Tune
        tune.report({"val_loss":val_loss})


        
        
search_space = {
    "hidden_channels": tune.choice([64, 128, 256]),
    "num_filters": tune.choice([64, 128, 256]),
    "num_interactions": tune.choice([3, 4, 6]),
    "num_gaussians": tune.choice([25, 50, 75]),
    "cutoff": tune.choice([5.0, 6.0, 8.0, 10.0]),

    # training hyperparams
    "lr": tune.loguniform(1e-4, 3e-3),
    "batch_size": tune.choice([16, 32]),
    "epochs": 3,
    "subset": 10,
}


scheduler = ASHAScheduler(
    max_t=5,
    grace_period=1,
    reduction_factor=2,
)

tuner = tune.Tuner(
    train_ray,
    tune_config=tune.TuneConfig(
        metric="val_loss",
        mode="min",
        scheduler=scheduler,
        num_samples=2,   # small number of trials
    ),
    param_space=search_space,
)

results = tuner.fit()
best = results.get_best_result("val_loss", "min")
print("Best config:", best.config)



0,1
Current time:,2025-11-21 17:05:51
Running for:,00:02:56.78
Memory:,4.3/9.6 GiB

Trial name,# failures,error file
train_ray_118ca_00000,1,"/tmp/ray/session_2025-11-21_16-29-26_941264_9190/artifacts/2025-11-21_17-02-55/train_ray_2025-11-21_17-02-55/driver_artifacts/train_ray_118ca_00000_0_batch_size=32,cutoff=6.0000,hidden_channels=256,lr=0.0007,num_filters=128,num_gaussians=75,num_interaction_2025-11-21_17-02-55/error.txt"
train_ray_118ca_00001,1,"/tmp/ray/session_2025-11-21_16-29-26_941264_9190/artifacts/2025-11-21_17-02-55/train_ray_2025-11-21_17-02-55/driver_artifacts/train_ray_118ca_00001_1_batch_size=32,cutoff=5.0000,hidden_channels=256,lr=0.0009,num_filters=256,num_gaussians=75,num_interaction_2025-11-21_17-02-55/error.txt"

Trial name,status,loc,batch_size,cutoff,hidden_channels,lr,num_filters,num_gaussians,num_interactions
train_ray_118ca_00000,ERROR,200.40.20.243:12306,32,6,256,0.000684638,128,75,3
train_ray_118ca_00001,ERROR,200.40.20.243:12307,32,5,256,0.000935475,256,75,6


[36m(train_ray pid=12307)[0m Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip
[36m(train_ray pid=12307)[0m Extracting data/QM9/raw/qm9.zip
[36m(train_ray pid=12306)[0m Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip
[36m(train_ray pid=12307)[0m Downloading https://ndownloader.figshare.com/files/3195404
[36m(train_ray pid=12307)[0m Processing...
  0%|          | 0/133885 [00:00<?, ?it/s]
  0%|          | 144/133885 [00:00<01:33, 1431.03it/s]
  0%|          | 291/133885 [00:00<01:31, 1453.68it/s]
  0%|          | 437/133885 [00:00<01:35, 1395.11it/s]
[36m(train_ray pid=12306)[0m Extracting data/QM9/raw/qm9.zip
  0%|          | 577/133885 [00:00<01:42, 1300.85it/s]
  1%|          | 708/133885 [00:00<01:49, 1217.00it/s]
  1%|          | 836/133885 [00:00<01:47, 1236.55it/s]
  1%|          | 986/133885 [00:00<01:40, 1316.72it/s]
  1%|          | 1120/133885 [00:00<01:40, 1322.58it/s]
  1%|    

 95%|█████████▌| 127832/133885 [02:09<00:05, 1181.15it/s]
 92%|█████████▏| 122589/133885 [02:07<00:12, 914.67it/s]
 96%|█████████▌| 127953/133885 [02:09<00:04, 1187.57it/s]
 92%|█████████▏| 122681/133885 [02:07<00:12, 892.81it/s]
 96%|█████████▌| 128072/133885 [02:09<00:04, 1183.71it/s]
 92%|█████████▏| 122788/133885 [02:07<00:11, 943.23it/s]
 96%|█████████▌| 128191/133885 [02:09<00:04, 1181.70it/s]
 92%|█████████▏| 122887/133885 [02:07<00:11, 955.52it/s]
 96%|█████████▌| 128314/133885 [02:09<00:04, 1195.56it/s]
 92%|█████████▏| 122983/133885 [02:07<00:11, 954.37it/s]
 96%|█████████▌| 128434/133885 [02:09<00:04, 1192.49it/s]
 92%|█████████▏| 123079/133885 [02:07<00:11, 934.78it/s]
 96%|█████████▌| 128554/133885 [02:09<00:04, 1163.62it/s]
 92%|█████████▏| 123173/133885 [02:07<00:11, 929.46it/s]
 96%|█████████▌| 128671/133885 [02:09<00:04, 1085.41it/s]
 92%|█████████▏| 123267/133885 [02:07<00:12, 864.83it/s]
 96%|█████████▌| 128781/133885 [02:10<00:04, 1068.53it/s]
 92%|█████████▏| 12337

 99%|█████████▉| 132692/133885 [02:17<00:01, 1154.75it/s]
 99%|█████████▉| 132809/133885 [02:17<00:00, 1151.93it/s]
 99%|█████████▉| 132960/133885 [02:17<00:00, 1255.78it/s]
 99%|█████████▉| 133099/133885 [02:17<00:00, 1293.11it/s]
100%|█████████▉| 133229/133885 [02:17<00:00, 1293.14it/s]
100%|█████████▉| 133382/133885 [02:17<00:00, 1360.72it/s]
100%|█████████▉| 133521/133885 [02:17<00:00, 1366.64it/s]
100%|█████████▉| 133658/133885 [02:17<00:00, 1298.38it/s]
100%|█████████▉| 133789/133885 [02:18<00:00, 1244.73it/s]
100%|██████████| 133885/133885 [02:18<00:00, 969.57it/s] 
2025-11-21 17:05:51,891	ERROR tune_controller.py:1331 -- Trial task failed for trial train_ray_118ca_00000
Traceback (most recent call last):
  File "/home/marcos/.local/lib/python3.10/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/home/marcos/.local/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapp

RuntimeError: No best trial found for the given metric: val_loss. This means that no trial has reported this metric, or all values reported for this metric are NaN. To not ignore NaN values, you can set the `filter_nan_and_inf` arg to False.

In [45]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import MSELoss
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SchNet

from ray import tune
from ray.tune.schedulers import ASHAScheduler

# ============================================================
#                   Tunable SchNet Model
# ============================================================

class TunableSchNet(nn.Module):
    def __init__(self, hidden_channels, num_filters,
                 num_interactions, num_gaussians, cutoff):
        super().__init__()

        self.model = SchNet(
            hidden_channels=hidden_channels,
            num_filters=num_filters,
            num_interactions=num_interactions,
            num_gaussians=num_gaussians,
            cutoff=cutoff,
            readout='add'
        )

        # keep regressor but use conditionally
        self.regressor = nn.Linear(hidden_channels, 1)

    def forward(self, batch):
        # call SchNet with explicit z, pos, batch
        x = self.model(batch.z, batch.pos, batch.batch)

        # possible outputs:
        # - shape [B, hidden_channels]  -> apply regressor
        # - shape [B, 1] or [B]         -> already scalar per graph
        if x.dim() == 1:                # [B]
            return x
        if x.dim() == 2 and x.size(1) == 1:  # [B,1]
            return x.view(-1)
        # otherwise assume [B, hidden_channels]
        return self.regressor(x).squeeze(-1)

        

# ============================================================
#                   Dataset Loader
# ============================================================



# ============================================================
#                   Training Helpers
# ============================================================

def run_epoch(loader, model, criterion, optimizer=None, train=True):
    model.train() if train else model.eval()

    total = 0
    for batch in loader:
        batch = batch.to(next(model.parameters()).device)

        pred = model(batch)             
        #loss = criterion(pred, batch.y)
        target = batch.y.view(-1)   # [B]
        loss = criterion(pred, target)


        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        total += float(loss)

    return total / len(loader)


# ============================================================
#                   Ray Tune Trainable
# ============================================================

def train_ray(config):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # loader size is part of config
    train_loader, val_loader = loaders_schnet(
        batch_size=config["batch_size"],
        subset=config["subset"]
    )

    model = TunableSchNet(
        hidden_channels=config["hidden_channels"],
        num_filters=config["num_filters"],
        num_interactions=config["num_interactions"],
        num_gaussians=config["num_gaussians"],
        cutoff=config["cutoff"]
    ).to(device)

    optimizer = Adam(model.parameters(), lr=config["lr"])
    criterion = MSELoss()

    # train loop
    for epoch in range(config["epochs"]):
        train_loss = run_epoch(train_loader, model, criterion, optimizer, train=True)
        val_loss = run_epoch(val_loader, model, criterion, train=False)

        # must report for Ray
        tune.report({"val_loss":val_loss})


# ============================================================
#                   Search Space + Tuner
# ============================================================

search_space = {
    # SchNet architecture
    "hidden_channels": tune.choice([64, 128, 256]),
    "num_filters": tune.choice([64, 128, 256]),
    "num_interactions": tune.choice([3, 4, 6]),
    "num_gaussians": tune.choice([25, 50, 75]),
    "cutoff": tune.choice([5.0, 6.0, 8.0, 10.0]),

    # Training hyperparams
    "lr": tune.loguniform(1e-4, 3e-3),
    "batch_size": tune.choice([16, 32]),
    "epochs": 2,
    "subset": 10,
}

scheduler = ASHAScheduler(
    max_t=5,
    grace_period=1,
    reduction_factor=2,
)

tuner = tune.Tuner(
    train_ray,
    tune_config=tune.TuneConfig(
        metric="val_loss",
        mode="min",
        scheduler=scheduler,
        num_samples=1,
    ),
    param_space=search_space,
)

results = tuner.fit()

best = results.get_best_result(metric="val_loss", mode="min")
print("\nBest config:")
print(best.config)


0,1
Current time:,2025-11-21 17:37:44
Running for:,00:02:42.36
Memory:,3.9/9.6 GiB

Trial name,status,loc,batch_size,cutoff,hidden_channels,lr,num_filters,num_gaussians,num_interactions,iter,total time (s),val_loss
train_ray_8e635_00000,TERMINATED,200.40.20.243:15421,16,5,256,0.000284235,64,50,4,2,158.123,1.76922


[36m(train_ray pid=15421)[0m Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip
[36m(train_ray pid=15421)[0m Extracting data/QM9/raw/qm9.zip
[36m(train_ray pid=15421)[0m Downloading https://ndownloader.figshare.com/files/3195404
[36m(train_ray pid=15421)[0m Processing...
  0%|          | 0/133885 [00:00<?, ?it/s]
  0%|          | 144/133885 [00:00<01:33, 1436.59it/s]
  0%|          | 291/133885 [00:00<01:32, 1449.45it/s]
  0%|          | 436/133885 [00:00<01:37, 1366.82it/s]
  0%|          | 574/133885 [00:00<01:39, 1345.57it/s]
  1%|          | 709/133885 [00:00<01:42, 1293.46it/s]
  1%|          | 849/133885 [00:00<01:40, 1326.09it/s]
  1%|          | 989/133885 [00:00<01:38, 1348.65it/s]
  1%|          | 1125/133885 [00:00<01:39, 1335.61it/s]
  1%|          | 1259/133885 [00:00<01:41, 1311.69it/s]
  1%|          | 1391/133885 [00:01<01:41, 1304.47it/s]
  1%|          | 1522/133885 [00:01<01:41, 1298.15it/s]
  1%|          | 1659/13388

 13%|█▎        | 16919/133885 [00:15<01:49, 1071.73it/s]
 13%|█▎        | 17027/133885 [00:15<01:53, 1025.44it/s]
 13%|█▎        | 17130/133885 [00:15<01:54, 1020.59it/s]
 13%|█▎        | 17235/133885 [00:15<01:53, 1026.99it/s]
 13%|█▎        | 17338/133885 [00:15<01:54, 1022.06it/s]
 13%|█▎        | 17443/133885 [00:15<01:53, 1029.98it/s]
 13%|█▎        | 17661/133885 [00:15<01:49, 1060.70it/s]
 13%|█▎        | 17768/133885 [00:16<03:05, 624.54it/s] 
 13%|█▎        | 17867/133885 [00:16<02:46, 697.17it/s]
 13%|█▎        | 17970/133885 [00:16<02:30, 769.77it/s]
 14%|█▎        | 18076/133885 [00:16<02:18, 838.77it/s]
 14%|█▎        | 18179/133885 [00:16<02:10, 886.01it/s]
 14%|█▎        | 18284/133885 [00:16<02:04, 928.04it/s]
 14%|█▎        | 18387/133885 [00:16<02:00, 955.68it/s]
 14%|█▍        | 18496/133885 [00:16<01:56, 991.95it/s]
 14%|█▍        | 18602/133885 [00:17<01:54, 1011.02it/s]
 14%|█▍        | 18710/133885 [00:17<01:51, 1028.56it/s]
 14%|█▍        | 18815/133885 [00:17<0

 26%|██▌       | 34214/133885 [00:30<01:37, 1024.00it/s]
 26%|██▌       | 34317/133885 [00:30<01:37, 1018.78it/s]
 26%|██▌       | 34421/133885 [00:30<01:37, 1024.24it/s]
 26%|██▌       | 34526/133885 [00:31<01:36, 1029.38it/s]
 26%|██▌       | 34633/133885 [00:31<01:35, 1040.47it/s]
 26%|██▌       | 34738/133885 [00:31<01:35, 1036.64it/s]
 26%|██▌       | 34842/133885 [00:31<01:35, 1033.34it/s]
 26%|██▌       | 34946/133885 [00:31<01:36, 1021.59it/s]
 26%|██▌       | 35049/133885 [00:31<01:36, 1020.08it/s]
 26%|██▋       | 35155/133885 [00:31<01:35, 1031.06it/s]
 26%|██▋       | 35261/133885 [00:31<01:35, 1038.15it/s]
 26%|██▋       | 35366/133885 [00:31<01:34, 1040.47it/s]
 26%|██▋       | 35471/133885 [00:31<01:34, 1036.99it/s]
 27%|██▋       | 35575/133885 [00:32<01:35, 1033.46it/s]
 27%|██▋       | 35684/133885 [00:32<01:33, 1047.73it/s]
 27%|██▋       | 35790/133885 [00:32<01:33, 1049.04it/s]
 27%|██▋       | 35896/133885 [00:32<01:33, 1049.95it/s]
 27%|██▋       | 36002/133885 [

 37%|███▋      | 49384/133885 [00:46<01:18, 1080.78it/s]
 37%|███▋      | 49493/133885 [00:46<01:19, 1061.65it/s]
 37%|███▋      | 49600/133885 [00:46<01:21, 1039.38it/s]
 37%|███▋      | 49707/133885 [00:46<01:20, 1046.55it/s]
 37%|███▋      | 49812/133885 [00:46<01:20, 1043.66it/s]
 37%|███▋      | 49925/133885 [00:46<01:18, 1068.74it/s]
 37%|███▋      | 50041/133885 [00:46<01:16, 1093.53it/s]
 37%|███▋      | 50151/133885 [00:46<01:17, 1087.39it/s]
 38%|███▊      | 50260/133885 [00:46<01:18, 1063.44it/s]
 38%|███▊      | 50367/133885 [00:46<01:19, 1050.03it/s]
 38%|███▊      | 50474/133885 [00:47<01:19, 1054.78it/s]
 38%|███▊      | 50581/133885 [00:47<01:18, 1059.14it/s]
 38%|███▊      | 50687/133885 [00:47<01:20, 1033.12it/s]
 38%|███▊      | 50791/133885 [00:47<01:21, 1021.70it/s]
 38%|███▊      | 50908/133885 [00:47<01:17, 1064.02it/s]
 38%|███▊      | 51017/133885 [00:47<01:17, 1070.80it/s]
 38%|███▊      | 51125/133885 [00:47<01:18, 1049.44it/s]
 38%|███▊      | 51231/133885 [

 49%|████▊     | 64958/133885 [01:01<01:06, 1042.64it/s]
 49%|████▊     | 65064/133885 [01:01<01:05, 1044.91it/s]
 49%|████▊     | 65169/133885 [01:01<01:07, 1021.56it/s]
 49%|████▉     | 65272/133885 [01:01<01:07, 1010.17it/s]
 49%|████▉     | 65375/133885 [01:01<01:07, 1015.13it/s]
 49%|████▉     | 65482/133885 [01:01<01:06, 1028.88it/s]
 49%|████▉     | 65585/133885 [01:01<01:07, 1006.99it/s]
 49%|████▉     | 65686/133885 [01:01<01:07, 1006.28it/s]
 49%|████▉     | 65793/133885 [01:01<01:06, 1024.03it/s]
 49%|████▉     | 65905/133885 [01:02<01:04, 1050.93it/s]
 49%|████▉     | 66011/133885 [01:02<01:05, 1044.02it/s]
 49%|████▉     | 66120/133885 [01:02<01:04, 1056.59it/s]
 49%|████▉     | 66226/133885 [01:02<01:04, 1044.42it/s]
 50%|████▉     | 66331/133885 [01:02<01:04, 1044.71it/s]
 50%|████▉     | 66437/133885 [01:02<01:04, 1047.97it/s]
 50%|████▉     | 66542/133885 [01:02<01:04, 1041.19it/s]
 50%|████▉     | 66647/133885 [01:02<01:05, 1022.04it/s]
 50%|████▉     | 66750/133885 [

 60%|█████▉    | 80077/133885 [01:16<00:56, 946.79it/s]
 60%|█████▉    | 80178/133885 [01:16<00:55, 962.73it/s]
 60%|█████▉    | 80284/133885 [01:17<00:54, 989.85it/s]
 60%|██████    | 80388/133885 [01:17<00:53, 1001.48it/s]
 60%|██████    | 80489/133885 [01:17<00:54, 984.10it/s] 
 60%|██████    | 80589/133885 [01:17<00:53, 987.05it/s]
 60%|██████    | 80688/133885 [01:17<00:54, 984.93it/s]
 60%|██████    | 80791/133885 [01:17<00:53, 997.73it/s]
 60%|██████    | 80891/133885 [01:17<00:53, 997.10it/s]
 60%|██████    | 80991/133885 [01:17<00:53, 979.83it/s]
 61%|██████    | 81090/133885 [01:17<00:54, 974.73it/s]
 61%|██████    | 81190/133885 [01:17<00:53, 979.37it/s]
 61%|██████    | 81288/133885 [01:18<00:53, 978.09it/s]
 61%|██████    | 81386/133885 [01:18<00:54, 965.37it/s]
 61%|██████    | 81483/133885 [01:18<00:54, 954.04it/s]
 61%|██████    | 81582/133885 [01:18<00:54, 962.30it/s]
 61%|██████    | 81679/133885 [01:18<00:55, 944.48it/s]
 61%|██████    | 81774/133885 [01:18<00:55, 94

 71%|███████   | 94801/133885 [01:31<00:43, 905.54it/s]
 71%|███████   | 94892/133885 [01:32<00:43, 895.78it/s]
 71%|███████   | 94987/133885 [01:32<00:42, 911.10it/s]
 71%|███████   | 95079/133885 [01:32<01:45, 369.57it/s]
 71%|███████   | 95177/133885 [01:32<01:24, 457.85it/s]
 71%|███████   | 95273/133885 [01:32<01:11, 543.33it/s]
 71%|███████   | 95365/133885 [01:33<01:02, 616.85it/s]
 71%|███████▏  | 95466/133885 [01:33<00:54, 702.58it/s]
 71%|███████▏  | 95559/133885 [01:33<00:50, 756.28it/s]
 71%|███████▏  | 95662/133885 [01:33<00:46, 824.76it/s]
 72%|███████▏  | 95762/133885 [01:33<00:44, 865.12it/s]
 72%|███████▏  | 95858/133885 [01:33<00:42, 889.72it/s]
 72%|███████▏  | 95956/133885 [01:33<00:41, 914.28it/s]
 72%|███████▏  | 96062/133885 [01:33<00:39, 954.03it/s]
 72%|███████▏  | 96161/133885 [01:33<00:40, 939.16it/s]
 72%|███████▏  | 96259/133885 [01:34<00:39, 949.29it/s]
 72%|███████▏  | 96356/133885 [01:34<00:39, 949.26it/s]
 72%|███████▏  | 96454/133885 [01:34<00:39, 956.

 82%|████████▏ | 109568/133885 [01:47<00:25, 963.89it/s]
 82%|████████▏ | 109665/133885 [01:47<00:25, 956.37it/s]
 82%|████████▏ | 109767/133885 [01:47<00:24, 973.13it/s]
 82%|████████▏ | 109865/133885 [01:47<00:25, 951.04it/s]
 82%|████████▏ | 109961/133885 [01:47<00:25, 939.26it/s]
 82%|████████▏ | 110058/133885 [01:47<00:25, 945.68it/s]
 82%|████████▏ | 110157/133885 [01:48<00:24, 957.83it/s]
 82%|████████▏ | 110256/133885 [01:48<00:24, 966.78it/s]
 82%|████████▏ | 110353/133885 [01:48<00:24, 962.13it/s]
 82%|████████▏ | 110450/133885 [01:48<00:25, 902.73it/s]
 83%|████████▎ | 110542/133885 [01:48<00:26, 893.25it/s]
 83%|████████▎ | 110632/133885 [01:48<00:25, 894.47it/s]
 83%|████████▎ | 110728/133885 [01:48<00:25, 912.56it/s]
 83%|████████▎ | 110827/133885 [01:48<00:24, 934.56it/s]
 83%|████████▎ | 110923/133885 [01:48<00:24, 941.02it/s]
 83%|████████▎ | 111018/133885 [01:48<00:24, 929.32it/s]
 83%|████████▎ | 111112/133885 [01:49<00:24, 925.97it/s]
 83%|████████▎ | 111205/133885 

 93%|█████████▎| 124124/133885 [02:02<00:08, 1134.10it/s]
 93%|█████████▎| 124243/133885 [02:02<00:08, 1149.86it/s]
 93%|█████████▎| 124371/133885 [02:02<00:08, 1186.34it/s]
 93%|█████████▎| 124490/133885 [02:02<00:08, 1174.24it/s]
 93%|█████████▎| 124608/133885 [02:02<00:08, 1149.24it/s]
 93%|█████████▎| 124724/133885 [02:02<00:07, 1149.40it/s]
 93%|█████████▎| 124844/133885 [02:02<00:07, 1163.35it/s]
 93%|█████████▎| 124961/133885 [02:03<00:07, 1147.99it/s]
 93%|█████████▎| 125084/133885 [02:03<00:07, 1170.84it/s]
 94%|█████████▎| 125206/133885 [02:03<00:07, 1182.48it/s]
 94%|█████████▎| 125325/133885 [02:03<00:07, 1179.44it/s]
 94%|█████████▎| 125449/133885 [02:03<00:07, 1196.49it/s]
 94%|█████████▍| 125569/133885 [02:03<00:06, 1189.20it/s]
 94%|█████████▍| 125692/133885 [02:03<00:06, 1200.56it/s]
 94%|█████████▍| 125817/133885 [02:03<00:06, 1212.38it/s]
 94%|█████████▍| 125939/133885 [02:03<00:06, 1180.09it/s]
 94%|█████████▍| 126058/133885 [02:04<00:06, 1154.82it/s]
 94%|█████████


Best config:
{'hidden_channels': 256, 'num_filters': 64, 'num_interactions': 4, 'num_gaussians': 50, 'cutoff': 5.0, 'lr': 0.0002842349613009927, 'batch_size': 16, 'epochs': 2, 'subset': 10}


In [46]:
results

ResultGrid<[
  Result(
    metrics={'val_loss': 1.7692179679870605},
    path='/home/marcos/ray_results/train_ray_2025-11-21_17-35-02/train_ray_8e635_00000_0_batch_size=16,cutoff=5.0000,hidden_channels=256,lr=0.0003,num_filters=64,num_gaussians=50,num_interactions_2025-11-21_17-35-02',
    filesystem='local',
    checkpoint=None
  )
]>

In [None]:
"""
from torch import nn
from torch_geometric.nn.models.schnet import SchNet

class TunableSchNet(nn.Module):
    def __init__(
        self,
        hidden_channels=64,
        num_filters=64,
        num_interactions=3,
        num_gaussians=50,
        cutoff=10.0,
        readout="add"
    ):
        super().__init__()

        self.model = SchNet(
            hidden_channels=hidden_channels,
            num_filters=num_filters,
            num_interactions=num_interactions,
            num_gaussians=num_gaussians,
            cutoff=cutoff,
            readout=readout
        )

        # SchNet outputs graph embeddings of size hidden_channels
        self.regressor = nn.Linear(hidden_channels, 1)

    def forward(self, batch):
        x = self.model(batch)          # [num_graphs, hidden]
        out = self.regressor(x)        # [num_graphs, 1]
        return out




search_space = {
    "hidden_channels": tune.choice([64, 128, 256]),
    "num_filters": tune.choice([64, 128, 256]),
    "num_interactions": tune.choice([3, 4, 6]),
    "num_gaussians": tune.choice([25, 50, 75]),
    "cutoff": tune.choice([5.0, 6.0, 8.0, 10.0]),

    # training hyperparams
    "lr": tune.loguniform(1e-4, 3e-3),
    "batch_size": tune.choice([16, 32]),
    "epochs": 3,
    "subset": 10,
}


def train_ray(config):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    train_loader, val_loader = loadersGCN(
        batch_size=config["batch_size"],
        subset=config["subset"]
    )

    model = TunableSchNet(
        hidden_channels=config["hidden_channels"],
        num_filters=config["num_filters"],
        num_interactions=config["num_interactions"],
        num_gaussians=config["num_gaussians"],
        cutoff=config["cutoff"]
    ).to(device)

    optimizer = Adam(model.parameters(), lr=config["lr"])
    criterion = MSELoss()

    for epoch in range(config["epochs"]):
        train_loss = run_epoch(train_loader, model, criterion, optimizer)
        val_loss = run_epoch(val_loader, model, criterion)

        tune.report({"val_loss": val_loss})


"""

now, withgout downlading data everytime

In [None]:

# ---- GLOBAL PRELOAD (runs once) ----
"""
prep = PreprocessorRegistry.create("mlp", root="data/QM9", subset=50)
dataset = prep.preprocess()          # cached on disk after first run

def loadersv2(batch_size=32, subset=None):
    global dataset                 # ← reuse the preloaded dataset

    ds = dataset
    if subset is not None:
        ds = ds[:subset]

    train_ds, val_ds = split(ds)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)
    return train_loader, val_loader

"""
