# In [None]:
# Compact SchNet-style PyTorch script for predicting scene density (per-scene N/volume)
# - Pure PyTorch, no external GNN libraries
# - Generates random scenes, trains small SchNet-like model, plots training/validation loss
# - Saves the model checkpoint to ./schnet_density.pth
# Run where PyTorch and matplotlib are available.

import math, random, time, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# Restrict PyTorch intra-op CPU parallelism to one thread
# (presumably for predictable timings on shared machines — confirm intent).
torch.set_num_threads(1)

# Use the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)


# ---------------- Data: random scenes ----------------
class RandomSceneDataset(Dataset):
    """Synthetic dataset of random point clouds inside cubic boxes.

    Every scene gets an atom count drawn uniformly from
    ``[min_atoms, max_atoms]`` and a cube side drawn uniformly from
    ``[min_box, max_box]``; atoms are placed uniformly inside the cube. The
    regression target is the number density ``N / side**3``.

    Args:
        n_scenes: number of scenes to generate up front.
        min_atoms, max_atoms: inclusive bounds on atoms per scene.
        min_box, max_box: bounds on the cube side length.
        seed: when given, reseeds the *global* ``random`` and ``torch`` RNGs
            before generation (this also affects later global random draws).
    """

    def __init__(
        self, n_scenes, min_atoms=5, max_atoms=60, min_box=3.0, max_box=8.0, seed=None
    ):
        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)
        self.scenes = [
            self._sample_scene(min_atoms, max_atoms, min_box, max_box)
            for _ in range(n_scenes)
        ]

    @staticmethod
    def _sample_scene(min_atoms, max_atoms, min_box, max_box):
        """Draw one scene dict with keys pos, Z, side, volume, density."""
        # RNG call order (randint, uniform, torch.rand) fixes the data per seed.
        n_atoms = random.randint(min_atoms, max_atoms)
        side = random.uniform(min_box, max_box)
        volume = side**3
        return {
            "pos": torch.rand(n_atoms, 3) * side,
            # single dummy atom type (all zeros); extend here for multiple types
            "Z": torch.zeros(n_atoms, dtype=torch.long),
            "side": side,
            "volume": volume,
            "density": n_atoms / volume,
        }

    def __len__(self):
        return len(self.scenes)

    def __getitem__(self, idx):
        """Return ``(pos, Z, volume, density_tensor)`` for scene ``idx``."""
        scene = self.scenes[idx]
        target = torch.tensor([scene["density"]], dtype=torch.float32)
        return scene["pos"], scene["Z"], scene["volume"], target


def collate_fn(batch):
    """Merge per-scene samples into one flat batched graph.

    Args:
        batch: list of ``(pos, Z, volume, density)`` tuples.

    Returns:
        positions: all atom positions concatenated along dim 0.
        Zs: all atom types concatenated along dim 0.
        batch_idx: long tensor giving, for every atom, its scene's index.
        volumes: float32 tensor with one volume per scene.
        densities: per-scene targets concatenated along dim 0.
    """
    pos_list, z_list, vol_list, dens_list = [], [], [], []
    for pos, Z, vol, dens in batch:
        pos_list.append(pos)
        z_list.append(Z)
        vol_list.append(vol)
        dens_list.append(dens)

    positions = torch.cat(pos_list, dim=0)
    Zs = torch.cat(z_list, dim=0)
    # Repeat each scene index once per atom in that scene.
    atom_counts = torch.tensor([p.shape[0] for p in pos_list], dtype=torch.long)
    batch_idx = torch.repeat_interleave(torch.arange(len(pos_list)), atom_counts)
    volumes = torch.tensor(vol_list, dtype=torch.float32)
    densities = torch.cat(dens_list, dim=0)
    return positions, Zs, batch_idx, volumes, densities


# ---------------- radial basis (Gaussian RBF) ----------------
class GaussianRBF(nn.Module):
    """Expand scalar distances into Gaussian radial-basis features.

    Feature ``k`` for a distance ``d`` is ``exp(-gamma * (d - centers[k])**2)``.

    Args:
        centers: sequence or tensor of K basis centers; copied into a float32
            buffer (so it follows ``module.to(...)`` and is saved in state_dict).
        gamma: width parameter; larger values give narrower Gaussians.
    """

    def __init__(self, centers, gamma):
        super().__init__()
        # torch.as_tensor accepts lists and tensors alike without the
        # copy-construct UserWarning that torch.tensor(tensor) emits;
        # detach().clone() keeps the buffer independent of the caller's tensor.
        self.register_buffer(
            "centers", torch.as_tensor(centers, dtype=torch.float32).detach().clone()
        )
        self.gamma = float(gamma)

    def forward(self, d):
        """Return RBF features of shape ``d.shape + (K,)``."""
        c = self.centers.to(d.device)
        return torch.exp(-self.gamma * (d.unsqueeze(-1) - c) ** 2)


# ---------------- radius graph (batched) ----------------
# --- radius_graph: every returned tensor lives on positions.device ---
def radius_graph(positions, batch, cutoff):
    """Build directed edges between atoms of the same graph within ``cutoff``.

    Graphs are processed one at a time with a dense pairwise-distance matrix,
    so cost is O(sum of n_b**2) over graphs.

    Args:
        positions: (TotalNodes, 3) atom coordinates.
        batch: (TotalNodes,) graph index of every node.
        cutoff: maximum (inclusive) distance for an edge.

    Returns:
        edge_index: (2, E) long tensor of (src, dst) node indices. Both
            directions of every pair are present; self-loops are excluded.
        edge_vec: (E, 3) vectors ``positions[dst] - positions[src]``.
        edge_dist: (E,) lengths of ``edge_vec``; 1e-8 is added under the
            square root so the gradient stays finite.
        All outputs are on ``positions.device``; the float outputs use
        ``positions.dtype`` (including the no-edge case).
    """
    device = positions.device
    cutoff2 = cutoff * cutoff
    srcs, dsts, vecs, dists = [], [], [], []
    for b in batch.unique(sorted=True):
        idx = (batch == b).nonzero(as_tuple=True)[0]
        n_b = idx.numel()
        if n_b <= 1:
            continue
        pos_b = positions[idx]
        diff = pos_b.unsqueeze(1) - pos_b.unsqueeze(0)  # diff[i, j] = pos_i - pos_j
        dist2 = (diff**2).sum(dim=-1)  # (n_b, n_b)
        not_self = ~torch.eye(n_b, dtype=torch.bool, device=device)
        src_local, dst_local = (not_self & (dist2 <= cutoff2)).nonzero(as_tuple=True)
        if src_local.numel() == 0:
            continue
        srcs.append(idx[src_local])
        dsts.append(idx[dst_local])
        # pos_dst - pos_src is exactly diff[dst, src]; reuse it instead of
        # gathering from `positions` again.
        vecs.append(diff[dst_local, src_local])
        # Squared distance is symmetric, so dist2[src, dst] is the same value.
        dists.append(torch.sqrt(dist2[src_local, dst_local] + 1e-8))
    if not srcs:
        # Match dtype/device of the non-empty path.
        return (
            torch.zeros((2, 0), dtype=torch.long, device=device),
            positions.new_zeros((0, 3)),
            positions.new_zeros((0,)),
        )
    edge_index = torch.stack([torch.cat(srcs), torch.cat(dsts)], dim=0)
    return edge_index, torch.cat(vecs), torch.cat(dists)


# ---------------- SchNet-like blocks ----------------
class InteractionBlock(nn.Module):
    """One SchNet-style continuous-filter convolution with a residual update.

    For each edge (src -> dst) a filter is generated from the edge's RBF
    features; the message ``dense(x[src]) * filter`` is summed into ``dst``,
    and every node becomes ``x + update([x, aggregated_messages])``.

    Args:
        hidden_dim: node feature width H.
        rbf_dim: number K of RBF features per edge.
    """

    def __init__(self, hidden_dim, rbf_dim):
        super().__init__()
        # Submodule creation order/names are fixed: they determine parameter
        # initialisation order and the state_dict keys.
        self.filter_net = nn.Sequential(
            nn.Linear(rbf_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.dense = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.update = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x, edge_index, rbf):
        """Apply one interaction.

        Args:
            x: (N, H) node features.
            edge_index: (2, E) long tensor of (src, dst) node indices.
            rbf: (E, K) radial-basis features per edge.

        Returns:
            (N, H) updated node features.
        """
        src, dst = edge_index
        edge_filters = self.filter_net(rbf)
        messages = self.dense(x[src]) * edge_filters
        # Sum-aggregate messages into their destination nodes.
        aggregated = x.new_zeros(x.shape).index_add_(0, dst, messages)
        return x + self.update(torch.cat((x, aggregated), dim=-1))


class SchNetDensity(nn.Module):
    """SchNet-like model that predicts one number density per scene.

    Atoms are embedded by type, refined by ``n_interactions`` interaction
    blocks over a radius graph, and mapped to a scalar contribution each.
    Contributions are summed per scene and divided by that scene's volume.

    Args:
        n_atom_types: size of the atom-type embedding table.
        hidden_dim: width of the atom features.
        rbf_centers: Gaussian RBF centers (default: 24 points on [0, 6]).
        gamma: RBF width parameter.
        n_interactions: number of interaction blocks.
        cutoff: radius-graph cutoff distance.
    """

    def __init__(
        self,
        n_atom_types=1,
        hidden_dim=64,
        rbf_centers=None,
        gamma=10.0,
        n_interactions=3,
        cutoff=2.0,
    ):
        super().__init__()
        if rbf_centers is None:
            rbf_centers = torch.linspace(0.0, 6.0, 24)
        self.embed = nn.Embedding(n_atom_types, hidden_dim)
        self.rbf = GaussianRBF(rbf_centers, gamma)
        self.interactions = nn.ModuleList(
            [
                InteractionBlock(hidden_dim, len(rbf_centers))
                for _ in range(n_interactions)
            ]
        )
        self.atom_out = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, 1),
        )
        self.cutoff = cutoff

    def forward(self, positions, Z, batch, volumes):
        """Predict per-scene densities.

        Args:
            positions: (TotalNodes, 3) atom coordinates.
            Z: (TotalNodes,) atom-type indices.
            batch: (TotalNodes,) scene index of every atom.
            volumes: (B,) scene volumes.

        Returns:
            ``(pred_density, per_atom)``: per-scene predictions and the
            (TotalNodes,) per-atom contributions they were summed from.
        """
        device = positions.device
        x = self.embed(Z)  # (N, H)
        edge_index, _, edge_dist = radius_graph(positions, batch, self.cutoff)
        if edge_index.numel() > 0:
            rbf = self.rbf(edge_dist)  # (E, K)
            for block in self.interactions:
                x = block(x, edge_index, rbf)
        per_atom = self.atom_out(x).squeeze(-1)  # (N,)

        if batch.numel() == 0:
            return per_atom.new_zeros((0,)), per_atom

        volumes = volumes.to(device)
        # Size the accumulator from both the highest scene index and the
        # number of volumes: scenes with no atoms at the end of the batch then
        # get their own zero slot instead of silently broadcasting against
        # `volumes`.
        n_graphs = max(int(batch.max().item()) + 1, volumes.numel())
        # Accumulate in per_atom's dtype so index_add also works after .double().
        graph_sum = per_atom.new_zeros(n_graphs).index_add(
            0, batch.to(device), per_atom
        )
        pred_density = graph_sum / volumes
        return pred_density, per_atom


# ---------------- training utilities ----------------
def _model_device(model):
    """Return the device of ``model``'s first parameter (CPU if it has none)."""
    first = next(model.parameters(), None)
    return first.device if first is not None else torch.device("cpu")


def _move_batch(batch, device):
    """Move one collated batch to ``device`` and flatten targets to (B,)."""
    positions, Z, batch_idx, volumes, dens = batch
    return (
        positions.to(device),
        Z.to(device),
        batch_idx.to(device),
        volumes.to(device),
        # reshape(-1) rather than squeeze(-1): squeeze turns a single-scene
        # batch of shape (1,) into a 0-d tensor, breaking shape[0] below.
        dens.to(device).reshape(-1),
    )


def _run_epoch(model, loader, criterion, opt=None):
    """Run one pass over ``loader``, stepping ``opt`` when it is given.

    Returns:
        Average of ``criterion`` over the epoch, with each batch's loss
        weighted by its number of targets; NaN if the loader is empty.
    """
    device = _model_device(model)
    total, count = 0.0, 0
    for batch in loader:
        positions, Z, batch_idx, volumes, dens = _move_batch(batch, device)
        if opt is not None:
            opt.zero_grad()
        pred, _ = model(positions, Z, batch_idx, volumes)
        loss = criterion(pred, dens)
        if opt is not None:
            loss.backward()
            opt.step()
        total += loss.item() * dens.numel()
        count += dens.numel()
    return total / count if count else float("nan")


def train_epoch(model, loader, opt, criterion):
    """Train ``model`` for one epoch; return the sample-weighted mean loss."""
    model.train()
    return _run_epoch(model, loader, criterion, opt)


def eval_epoch(model, loader, criterion):
    """Evaluate ``model`` without gradients; return the sample-weighted mean loss."""
    model.eval()
    with torch.no_grad():
        return _run_epoch(model, loader, criterion)


# ---------------- create datasets & loaders ----------------
# Train and validation scenes share the generator settings and differ only in
# seed. (Passing a seed also reseeds the global random/torch RNGs.)
SCENE_KWARGS = dict(min_atoms=5, max_atoms=60, min_box=3.0, max_box=12.0)
train_ds = RandomSceneDataset(1000, seed=0, **SCENE_KWARGS)
val_ds = RandomSceneDataset(1000, seed=1, **SCENE_KWARGS)

BATCH_SIZE = 16
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn
)
val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn
)

# ---------------- instantiate and train model ----------------
# nn.Module.to moves parameters/buffers in place and returns the module itself.
model = SchNetDensity(
    n_atom_types=1,
    hidden_dim=48,
    rbf_centers=torch.linspace(0.0, 6.0, 20),
    gamma=8.0,
    n_interactions=2,
    cutoff=2.5,
).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-6)
criterion = nn.MSELoss()

epochs = 30
train_losses, val_losses = [], []
t0 = time.time()
for ep in range(1, epochs + 1):
    tr = train_epoch(model, train_loader, opt, criterion)
    va = eval_epoch(model, val_loader, criterion)
    train_losses.append(tr)
    val_losses.append(va)
    print(f"Epoch {ep}/{epochs}  train={tr:.6e}  val={va:.6e}")
elapsed = time.time() - t0
print(f"Total training time: {elapsed:.1f}s")

# ---------------- plot losses ----------------
# Both curves hold the per-sample loss from train_epoch / eval_epoch with
# criterion = nn.MSELoss(), i.e. mean squared error in density units squared
# (not a relative RMSE). Epochs are numbered from 1, matching the log above.
plt.figure(figsize=(6, 3.5))
plt.plot(range(1, len(train_losses) + 1), train_losses, label="train")
plt.plot(range(1, len(val_losses) + 1), val_losses, label="val")
plt.xlabel("Epoch")
plt.ylabel("MSE loss")
plt.title("Density training")
plt.legend()
plt.grid(True)
plt.show()

# ---------------- example predictions ----------------
# Show predictions for the first validation batch. N_est converts the
# predicted density back to an atom count (density * volume); N_true is the
# actual number of atoms in that scene, read from batch_idx.
model.eval()
with torch.no_grad():
    positions, Z, batch_idx, volumes, dens = next(iter(val_loader))
    pred, per_atom = model(
        positions.to(device), Z.to(device), batch_idx.to(device), volumes.to(device)
    )
    print("\nExample predictions (first batch):")
    for i in range(min(8, pred.shape[0])):
        vol_i = volumes[i].item()
        pred_i = pred[i].item()
        n_true = int((batch_idx == i).sum().item())
        # :.0f prints a plain rounded number (and "nan" safely) instead of the
        # tensor repr the previous tensor-valued expression produced.
        print(
            f"pred={pred_i:.6f}  true={dens[i].item():.6f}  side={vol_i ** (1 / 3):.3f}"
            f"  N_est={pred_i * vol_i:.0f}  N_true={n_true}"
        )

# ---------------- save model ----------------
# Save the weights together with every constructor argument, read back from
# the trained model. The checkpoint can then be rebuilt with
# SchNetDensity(**ckpt["args"]) + load_state_dict(ckpt["model_state"]) without
# shape/key mismatches (this run uses non-default rbf_centers, gamma,
# n_interactions and cutoff).
path = "./schnet_density.pth"
model_args = {
    "n_atom_types": model.embed.num_embeddings,
    "hidden_dim": model.embed.embedding_dim,
    "rbf_centers": model.rbf.centers.cpu().tolist(),
    "gamma": model.rbf.gamma,
    "n_interactions": len(model.interactions),
    "cutoff": model.cutoff,
}
torch.save({"model_state": model.state_dict(), "args": model_args}, path)
print("Saved model to", path)