In [3]:
import ast
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Dict
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

In [10]:
df = pd.read_csv("processed_data.csv")

In [13]:
df.head()

Unnamed: 0,SMILES,Melting Point(K),atoms,coords,mmff_charge,valenced,total
0,Br/C=C/Br,266.5,"{0: 'Br', 1: 'C', 2: 'C', 3: 'Br', 4: 'H', 5: ...","{0: (-2.2736879646642416, -0.3061462432742523,...","{0: -0.11, 1: -0.039999999999999994, 2: -0.039...","{0: 7, 1: 4, 2: 4, 3: 7, 4: 1, 5: 1}","{0: 35, 1: 6, 2: 6, 3: 35, 4: 1, 5: 1}"
1,Br/C=C/c1ccccc1,280.0,"{0: 'Br', 1: 'C', 2: 'C', 3: 'C', 4: 'C', 5: '...","{0: (4.593717257100764, -0.2625038911597898, -...","{0: -0.11, 1: -0.039999999999999994, 2: -0.178...","{0: 7, 1: 4, 2: 4, 3: 4, 4: 4, 5: 4, 6: 4, 7: ...","{0: 35, 1: 6, 2: 6, 3: 6, 4: 6, 5: 6, 6: 6, 7:..."
2,BrC(=C(Br)c1ccccc1)c1ccccc1,501.0,"{0: 'Br', 1: 'C', 2: 'C', 3: 'Br', 4: 'C', 5: ...","{0: (-1.9648385253421696, -1.6302938318552822,...","{0: -0.11, 1: 0.0816, 2: 0.0816, 3: -0.11, 4: ...","{0: 7, 1: 4, 2: 4, 3: 7, 4: 4, 5: 4, 6: 4, 7: ...","{0: 35, 1: 6, 2: 6, 3: 35, 4: 6, 5: 6, 6: 6, 7..."
3,BrC(=C(\Br)c1ccccc1)/C(Br)=C(\Br)c1ccccc1,509.0,"{0: 'Br', 1: 'C', 2: 'C', 3: 'Br', 4: 'C', 5: ...","{0: (-0.7143307734671797, 1.487459433609242, -...","{0: -0.11, 1: 0.11, 2: 0.0816, 3: -0.11, 4: 0....","{0: 7, 1: 4, 2: 4, 3: 7, 4: 4, 5: 4, 6: 4, 7: ...","{0: 35, 1: 6, 2: 6, 3: 35, 4: 6, 5: 6, 6: 6, 7..."
4,BrC(=C(c1ccccc1)c1ccccc1)c1ccccc1,389.0,"{0: 'Br', 1: 'C', 2: 'C', 3: 'C', 4: 'C', 5: '...","{0: (2.0146539884324786, 2.4494258938086277, -...","{0: -0.11, 1: 0.0816, 2: -0.0568, 3: 0.0284, 4...","{0: 7, 1: 4, 2: 4, 3: 4, 4: 4, 5: 4, 6: 4, 7: ...","{0: 35, 1: 6, 2: 6, 3: 6, 4: 6, 5: 6, 6: 6, 7:..."


In [4]:
!tensorboard --logdir lightning_logs

TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.20.0 at http://localhost:6007/ (Press CTRL+C to quit)
^C


In [9]:
def seed_all(seed: int = 3407):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [10]:
def parse_atoms(s: str) -> List[str]:
    d = ast.literal_eval(s)
    keys = sorted(d.keys(), key=lambda k: int(k))
    return [d[k] for k in keys]

In [11]:
def parse_coords(s) -> np.ndarray:
    """
    Parse the coords column into an (N, 3) float32 array ordered by index.

    Accepts either:
      - a dict-like STRING, e.g. "{0: (0.1, 0.2, 0.3), 1: (0.4, 0.5, 0.6)}"
      - an already-parsed dict, e.g. {0: (x,y,z), "1": [x,y,z], ...}
      - values per index may be tuple/list/ndarray OR a string like "(x,y,z)"
    """
    # If it's a string, parse to a dict first; otherwise assume it's already a dict-like
    d = ast.literal_eval(s) if isinstance(s, str) else s
    if not isinstance(d, dict):
        raise ValueError(f"coords must be a dict-like mapping index -> (x,y,z), got: {type(d)}")

    # sort by integer index (keys may be str or int)
    keys = sorted(d.keys(), key=lambda k: int(k))
    coords = []
    for k in keys:
        v = d[k]
        # If the value itself is a string like "(x,y,z)", parse it
        if isinstance(v, str):
            v = ast.literal_eval(v)
        # Accept tuple/list/ndarray; enforce length 3 and cast to float
        if isinstance(v, (list, tuple, np.ndarray)) and len(v) == 3:
            x, y, z = float(v[0]), float(v[1]), float(v[2])
            coords.append((x, y, z))
        else:
            raise ValueError(f"Invalid coord at index {k}: {v!r} (expected length-3 sequence)")
    arr = np.asarray(coords, dtype=np.float32)
    if arr.ndim != 2 or arr.shape[1] != 3:
        raise ValueError(f"Expected coords of shape (N,3), got {arr.shape}")
    return arr

In [12]:
def parse_charges_optional(df_row) -> Optional[np.ndarray]:
    cand = [c for c in df_row.index if str(c).lower() in
            {"partial_charges", "charges", "atom_charges", "charges_dict"}]
    if not cand:
        return None
    col = cand[0]
    raw = df_row[col]
    try:
        d = ast.literal_eval(raw)
    except Exception:
        return None
    keys = sorted(d.keys(), key=lambda k: int(k))
    vals = []
    for k in keys:
        v = d[k]
        try:
            vals.append(float(v))
        except Exception:
            vals.append(0.0)
    return np.asarray(vals, dtype=np.float32)

In [13]:
AROMATIC_LOWERCASE = {"c", "n", "o", "p", "s", "se", "as", "b"}
ATOM_FALLBACK_ORDER = [
    "H","C","N","O","F","P","S","Si","Cl","Br","I","B","Na","K","Ca",
    "c","n","o","p","s","se","as","b","cl","br",
    "<UNK>"
]

In [14]:
@dataclass
class MolRecord:
    smiles: str
    coords: np.ndarray
    atom_tokens: List[str]
    charges: Optional[np.ndarray]
    y: float

In [15]:
class StandardScalerTorch(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("mean", torch.zeros(1, 1))
        self.register_buffer("std", torch.ones(1, 1))

    @torch.no_grad()
    def fit(self, y: torch.Tensor):
        m = y.mean(dim=0, keepdim=True)
        s = y.std(dim=0, unbiased=False, keepdim=True).clamp(min=1e-8)
        self.mean.copy_(m)
        self.std.copy_(s)

    def transform(self, y: torch.Tensor) -> torch.Tensor:
        return (y - self.mean.to(dtype=y.dtype)) / self.std.to(dtype=y.dtype)

    def inverse_transform(self, y_scaled: torch.Tensor) -> torch.Tensor:
        return y_scaled * self.std.to(dtype=y_scaled.dtype) + self.mean.to(dtype=y_scaled.dtype)

In [22]:
class MolCoordsDataset(Dataset):
    def __init__(self, csv_path: str, split: str, split_fracs=(0.0, 0.0, 1.0), seed=42,
                 atom_vocab: Optional[List[str]] = None):
        import pandas as pd
        df = pd.read_csv(csv_path)

        required_cols = {"SMILES", "Melting Point(K)", "coords", "atoms"}
        missing = required_cols - set(df.columns)
        if missing:
            raise ValueError(f"CSV missing required columns: {missing}")

        rng = np.random.default_rng(seed)
        indices = np.arange(len(df))
        rng.shuffle(indices)

        n = len(df)
        n_train = int(split_fracs[0] * n)
        n_val = int(split_fracs[1] * n)
        n_test = n - n_train - n_val

        train_idx = indices[:n_train]
        val_idx = indices[n_train:n_train+n_val]
        test_idx = indices[n_train+n_val:]

        if split == "train":
            use_idx = train_idx
        elif split == "val":
            use_idx = val_idx
        elif split == "test":
            use_idx = test_idx
        else:
            raise ValueError("split must be one of: train, val, test")

        # case-sensitive vocab
        if atom_vocab is None:
            train_tokens = set()
            for i in train_idx:
                toks = parse_atoms(df.iloc[i]["atoms"])
                train_tokens.update(toks)
            vocab = [t for t in ATOM_FALLBACK_ORDER if t in train_tokens]
            for t in sorted(train_tokens):
                if t not in vocab:
                    vocab.append(t)
            if "<UNK>" not in vocab:
                vocab.append("<UNK>")
            self.atom_vocab = vocab
        else:
            self.atom_vocab = list(atom_vocab)
            if "<UNK>" not in self.atom_vocab:
                self.atom_vocab.append("<UNK>")

        self.records: List[MolRecord] = []
        for i in use_idx:
            row = df.iloc[i]
            coords = parse_coords(row["coords"])
            toks = parse_atoms(row["atoms"])  # case preserved
            smiles = str(row["SMILES"])
            if len(toks) != len(coords):
                raise ValueError(f"Inconsistent atoms/coords length at row {i}: {len(toks)} vs {len(coords)}")
            charges = parse_charges_optional(row)
            if charges is not None and len(charges) != len(coords):
                charges = None
            y = float(row["Melting Point(K)"])
            if not np.isfinite(y) or not np.isfinite(coords).all():
                continue
            self.records.append(MolRecord(smiles=smiles, coords=coords, atom_tokens=toks, charges=charges, y=y))

        self._targets = np.array([r.y for r in self.records], dtype=np.float32)

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        rec = self.records[idx]
        vocab = self.atom_vocab
        atom_ids = [vocab.index(t) if t in vocab else vocab.index("<UNK>") for t in rec.atom_tokens]
        atom_ids = torch.tensor(atom_ids, dtype=torch.long)
        coords = torch.from_numpy(rec.coords).float()
        y = torch.tensor([rec.y], dtype=torch.float32)

        aromatic = torch.tensor([1.0 if (t.lower() in AROMATIC_LOWERCASE and t == t.lower()) else 0.0
                                 for t in rec.atom_tokens], dtype=torch.float32).unsqueeze(-1)
        if rec.charges is not None:
            charges = torch.from_numpy(rec.charges).view(-1, 1).float()
        else:
            charges = torch.zeros((len(rec.atom_tokens), 1), dtype=torch.float32)

        return {
            "smiles": rec.smiles,
            "atom_ids": atom_ids, "coords": coords, "charges": charges, "aromatic": aromatic,
            "y": y, "num_nodes": torch.tensor(len(rec.atom_tokens), dtype=torch.long)
        }

In [17]:
def collate_graphs(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    B = len(batch)
    num_nodes = torch.tensor([b["num_nodes"] for b in batch], dtype=torch.long)
    Nmax = int(num_nodes.max().item())

    atom_ids = torch.zeros((B, Nmax), dtype=torch.long)
    coords   = torch.zeros((B, Nmax, 3), dtype=torch.float32)
    charges  = torch.zeros((B, Nmax, 1), dtype=torch.float32)
    aromatic = torch.zeros((B, Nmax, 1), dtype=torch.float32)
    mask     = torch.zeros((B, Nmax), dtype=torch.bool)
    y        = torch.zeros((B, 1), dtype=torch.float32)
    smiles   = []

    for i, b in enumerate(batch):
        n = int(b["num_nodes"])
        atom_ids[i, :n] = b["atom_ids"]
        coords[i, :n]   = b["coords"]
        charges[i, :n]  = b["charges"]
        aromatic[i, :n] = b["aromatic"]
        mask[i, :n]     = True
        y[i]            = b["y"]
        smiles.append(b["smiles"])

    return {"smiles": smiles, "atom_ids": atom_ids, "coords": coords, "charges": charges, "aromatic": aromatic,
            "mask": mask, "y": y}

In [18]:
class MolDataModule(pl.LightningDataModule):
    def __init__(self, csv_path: str, batch_size: int = 16, num_workers: int = 0,
                 split_fracs=(0.8, 0.1, 0.1), seed: int = 42):
        super().__init__()
        self.csv_path = csv_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.split_fracs = split_fracs
        self.seed = seed
        self.atom_vocab: Optional[List[str]] = None
        self.scaler = StandardScalerTorch()

    def setup(self, stage: Optional[str] = None):
        self.train_set = MolCoordsDataset(self.csv_path, split="train",
                                          split_fracs=self.split_fracs, seed=self.seed)
        self.atom_vocab = self.train_set.atom_vocab
        self.val_set = MolCoordsDataset(self.csv_path, split="val",
                                        split_fracs=self.split_fracs, seed=self.seed,
                                        atom_vocab=self.atom_vocab)
        self.test_set = MolCoordsDataset(self.csv_path, split="test",
                                         split_fracs=self.split_fracs, seed=self.seed,
                                         atom_vocab=self.atom_vocab)
        with torch.no_grad():
            y_train = torch.tensor(self.train_set._targets).unsqueeze(1)
            self.scaler.fit(y_train)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, collate_fn=collate_graphs, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, collate_fn=collate_graphs, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, collate_fn=collate_graphs, pin_memory=True)

In [19]:
def pairwise_squared_dist(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    B, N, _ = x.shape
    xi = x.unsqueeze(2)  # (B,N,1,3)
    xj = x.unsqueeze(1)  # (B,1,N,3)
    diff = xi - xj
    r2 = (diff * diff).sum(dim=-1)  # (B,N,N)

    mask_i = mask.unsqueeze(2)
    mask_j = mask.unsqueeze(1)
    valid = mask_i & mask_j
    r2 = r2.masked_fill(~valid, float('inf'))

    eye = torch.eye(N, device=x.device, dtype=torch.bool).unsqueeze(0)
    r2 = r2.masked_fill(eye, 0.0)   # self
    return r2

In [28]:
class NodeEncoder(nn.Module):
    def __init__(self, vocab_size: int, model_dim: int, atom_embed_dim: int = 192, aux_embed_dim: int = 64):
        super().__init__()
        self.atom_emb = nn.Embedding(vocab_size, atom_embed_dim)
        self.aux_mlp = nn.Sequential(
            nn.Linear(2, aux_embed_dim),  # [charge, aromatic_flag]
            nn.SiLU(),
            nn.Linear(aux_embed_dim, aux_embed_dim)
        )
        self.proj = nn.Linear(atom_embed_dim + aux_embed_dim, model_dim)

    def forward(self, atom_ids: torch.Tensor, charges: torch.Tensor, aromatic: torch.Tensor) -> torch.Tensor:
        a = self.atom_emb(atom_ids)               # (B,N,atom_embed_dim)
        aux = self.aux_mlp(torch.cat([charges, aromatic], dim=-1))  # (B,N,aux_embed_dim)
        return self.proj(torch.cat([a, aux], dim=-1))               # (B,N,model_dim)

In [29]:
class EGCL(nn.Module):
    def __init__(self, feat_dim: int, m_dim: int = 128, hidden_dim: int = 256, update_coords: bool = False):
        super().__init__()
        self.update_coords = update_coords
        self.phi_e = nn.Sequential(
            nn.Linear(feat_dim * 2 + 1, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, m_dim), nn.SiLU()
        )
        self.phi_x = nn.Sequential(
            nn.Linear(m_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, 1)
        )
        self.phi_h = nn.Sequential(
            nn.Linear(feat_dim + m_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, feat_dim)
        )
        self.h_norm = nn.LayerNorm(feat_dim)

    def forward(self, h: torch.Tensor, x: torch.Tensor, mask: torch.Tensor):
        B, N, F = h.shape
        r2 = pairwise_squared_dist(x, mask)
        adj = torch.isfinite(r2) & (~torch.eye(N, device=x.device, dtype=torch.bool).unsqueeze(0))

        r2_valid = torch.where(torch.isfinite(r2), r2, torch.zeros_like(r2))
        hi = h.unsqueeze(2).expand(B, N, N, F)
        hj = h.unsqueeze(1).expand(B, N, N, F)
        e_in = torch.cat([hi, hj, r2_valid.unsqueeze(-1)], dim=-1)
        m_ij = self.phi_e(e_in) * adj.unsqueeze(-1)

        m_i = m_ij.sum(dim=2)

        if self.update_coords:
            s_ij = self.phi_x(m_ij).squeeze(-1)
            s_ij = torch.tanh(s_ij) * 0.1              # tamed step size
            s_ij = s_ij * adj.float()
            diff = x.unsqueeze(2) - x.unsqueeze(1)
            deg = adj.sum(dim=-1).clamp(min=1).unsqueeze(-1)
            delta_x = (diff * s_ij.unsqueeze(-1)).sum(dim=2) / deg
            x = x + delta_x

        dh = self.phi_h(torch.cat([h, m_i], dim=-1))
        h = self.h_norm(h + dh)
        return h, x

In [30]:
class EGNNEmbedding(nn.Module):
    def __init__(self, feat_dim: int, depth: int = 2, m_dim: int = 128, hidden_dim: int = 256, update_coords: bool = False):
        super().__init__()
        self.layers = nn.ModuleList([EGCL(feat_dim=feat_dim, m_dim=m_dim, hidden_dim=hidden_dim, update_coords=update_coords)
                                     for _ in range(depth)])

    def forward(self, h: torch.Tensor, x: torch.Tensor, mask: torch.Tensor):
        for layer in self.layers:
            h, x = layer(h, x, mask)
        return h, x

In [31]:
class GraphTransformerLayer(nn.Module):
    def __init__(self, dim: int, heads: int = 8, attn_dropout: float = 0.1, resid_dropout: float = 0.1,
                 ff_mult: int = 4):
        super().__init__()
        assert dim % heads == 0
        self.dim = dim
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5

        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.Linear(dim, dim)

        self.attn_drop = nn.Dropout(attn_dropout)
        self.resid_drop = nn.Dropout(resid_dropout)

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.ff = nn.Sequential(
            nn.Linear(dim, ff_mult * dim),
            nn.GELU(),
            nn.Dropout(resid_dropout),
            nn.Linear(ff_mult * dim, dim)
        )

    def forward(self, h: torch.Tensor, coords: torch.Tensor, mask: torch.Tensor):
        B, N, C = h.shape
        eps = 1e-8

        # --- inverse-square, row-normalized prior ---
        r2 = pairwise_squared_dist(coords, mask)                         # (B,N,N)
        eye = torch.eye(N, device=h.device, dtype=torch.bool).unsqueeze(0)
        valid = torch.isfinite(r2)

        # per-row mean of finite off-diagonal r^2
        off = valid & (~eye)
        r2_off = r2.masked_fill(~off, 0.0)
        cnt_row = off.sum(dim=-1, keepdim=True).clamp_min(1)             # (B,N,1)
        r2_row_mean = (r2_off.sum(dim=-1, keepdim=True) / cnt_row).clamp_min(1e-4)  # (B,N,1)

        # avoid self singularity using row mean
        r2_eff = torch.where(eye, r2_row_mean.expand_as(r2), r2)
        r2_eff = torch.nan_to_num(r2_eff, posinf=1e8).clamp_(min=1e-6, max=1e8)

        w_raw = torch.where(valid, 1.0 / (r2_eff + eps), torch.zeros_like(r2_eff))   # (B,N,N)
        row_sum = w_raw.sum(dim=-1, keepdim=True).clamp_min(1e-12)                   # (B,N,1)
        w = w_raw / row_sum                                                          # row-normalized
        log_w = torch.log(w.clamp_min(1e-12)).unsqueeze(1)                           # (B,1,N,N)

        # --- attention ---
        x = self.norm1(h)
        q = self.to_q(x).view(B, N, self.heads, self.head_dim).transpose(1, 2)  # (B,H,N,D)
        k = self.to_k(x).view(B, N, self.heads, self.head_dim).transpose(1, 2)  # (B,H,N,D)
        v = self.to_v(x).view(B, N, self.heads, self.head_dim).transpose(1, 2)  # (B,H,N,D)

        attn_logits = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_logits = attn_logits + log_w.expand(B, self.heads, N, N)

        # Robust masking
        mask_q = mask.unsqueeze(1).unsqueeze(-1).expand(B, self.heads, N, 1)
        mask_k = mask.unsqueeze(1).unsqueeze(2).expand(B, self.heads, 1, N)
        pair_valid = valid.unsqueeze(1) & mask_q & mask_k
        attn_logits = attn_logits.masked_fill(~pair_valid, -1e9)
        attn_logits = attn_logits.masked_fill(~mask_q, 0.0)

        attn = torch.softmax(attn_logits, dim=-1)
        attn = torch.nan_to_num(attn, nan=0.0)     # final guard
        attn = self.attn_drop(attn)

        out = torch.matmul(attn, v)                        # (B,H,N,D)
        out = out * mask_q.float()
        out = out.transpose(1, 2).contiguous().view(B, N, C)
        out = self.to_out(out)
        h = h + self.resid_drop(out)

        y = self.norm2(h)
        y = self.ff(y)
        h = h + self.resid_drop(y)
        return h

In [32]:
class GraphTransformer(nn.Module):
    def __init__(self, vocab_size: int, dim: int = 256, depth: int = 6, heads: int = 8,
                 attn_dropout: float = 0.1, resid_dropout: float = 0.1, ff_mult: int = 4,
                 egnn_depth: int = 2, egnn_update_coords: bool = False):
        super().__init__()
        self.encoder = NodeEncoder(vocab_size=vocab_size, model_dim=dim,
                                   atom_embed_dim=192, aux_embed_dim=64)
        self.egnn = EGNNEmbedding(feat_dim=dim, depth=egnn_depth, m_dim=128, hidden_dim=dim,
                                  update_coords=egnn_update_coords)
        self.layers = nn.ModuleList([
            GraphTransformerLayer(dim=dim, heads=heads, attn_dropout=attn_dropout,
                                  resid_dropout=resid_dropout, ff_mult=ff_mult)
            for _ in range(depth)
        ])
        self.mlp_out = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, 1)
        )

    def forward(self, atom_ids, charges, aromatic, coords, mask):
        h = self.encoder(atom_ids, charges, aromatic)
        h, coords = self.egnn(h, coords, mask)   # EGNN embedding
        for layer in self.layers:
            h = layer(h, coords, mask)
        h_sum = (h * mask.unsqueeze(-1)).sum(dim=1)
        denom = mask.sum(dim=1, keepdim=True).clamp(min=1)
        h_pool = h_sum / denom
        return self.mlp_out(h_pool)

In [33]:
class GraphTransformerPredictor(pl.LightningModule):
    def __init__(self,
                 vocab_size: int, dim: int = 256, depth: int = 6, heads: int = 8,
                 attn_dropout: float = 0.1, resid_dropout: float = 0.1, ff_mult: int = 4,
                 lr: float = 3e-4, weight_decay: float = 1e-6,
                 t_max: int = 100,                      # used by cosine if per-epoch
                 scaler: Optional[StandardScalerTorch] = None,
                 egnn_depth: int = 2, egnn_update_coords: bool = False,
                 # ==== NEW: scheduler & loss config ====
                 scheduler: str = "cosine",             # ["cosine", "onecycle"]
                 warmup_ratio: float = 0.1,             # for cosine
                 steps_per_epoch: Optional[int] = None, # required for onecycle/per-step cosine
                 max_epochs: Optional[int] = None,      # required for onecycle/per-step cosine
                 loss_name: str = "focal_huber",              # ["mse","mae","huber","focal_huber","logcosh"]
                 huber_delta: float = 1.0,              # in *scaled* units (≈1 std)
                 focal_gamma: float = 0.8,              # 0.5~1.0 recommended
                 focal_wmax: float = 5.0):              # cap for focal weights
        super().__init__()
        self.model = GraphTransformer(vocab_size=vocab_size, dim=dim, depth=depth, heads=heads,
                                      attn_dropout=attn_dropout, resid_dropout=resid_dropout, ff_mult=ff_mult,
                                      egnn_depth=egnn_depth, egnn_update_coords=egnn_update_coords)
        self.lr = lr
        self.weight_decay = weight_decay
        self.t_max = t_max
        self.scaler = scaler or StandardScalerTorch()
        # store
        self.scheduler_name = scheduler
        self.warmup_ratio = warmup_ratio
        self.steps_per_epoch = steps_per_epoch
        self.max_epochs_cfg = max_epochs
        self.loss_name = loss_name
        self.huber_delta = huber_delta
        self.focal_gamma = focal_gamma
        self.focal_wmax = focal_wmax

    # ---------- losses (computed on *scaled* targets) ----------
    def _huber_per_sample(self, err_abs: torch.Tensor, delta: float) -> torch.Tensor:
        # err_abs: (B,1)
        small = 0.5 * (err_abs ** 2)
        large = delta * (err_abs - 0.5 * delta)
        return torch.where(err_abs <= delta, small, large)

    def _loss(self, y_hat_scaled: torch.Tensor, y_scaled: torch.Tensor) -> torch.Tensor:
        err = y_hat_scaled - y_scaled
        err_abs = err.abs()
        if self.loss_name == "mse":
            return (err ** 2).mean()
        if self.loss_name == "mae":
            return err_abs.mean()
        if self.loss_name == "logcosh":
            return torch.log(torch.cosh(err)).mean()
        if self.loss_name == "huber":
            per = self._huber_per_sample(err_abs, self.huber_delta)
            return per.mean()
        if self.loss_name == "focal_huber":
            per = self._huber_per_sample(err_abs, self.huber_delta)       # bounded gradient
            # batch-median normalized focal weight
            med = err_abs.detach().median().clamp_min(1e-8)
            w = ((err_abs / med) ** self.focal_gamma).clamp_max(self.focal_wmax)
            return (w * per).mean()
        raise ValueError(f"Unknown loss_name: {self.loss_name}")

    def forward(self, batch):
        return self.model(batch["atom_ids"], batch["charges"], batch["aromatic"],
                          batch["coords"], batch["mask"])

    def _compute_loss_and_metrics(self, batch, stage: str):
        y = batch["y"]
        y_scaled = self.scaler.transform(y)
        y_hat_scaled = self.forward(batch)

        loss = self._loss(y_hat_scaled, y_scaled)
        with torch.no_grad():
            y_hat = self.scaler.inverse_transform(y_hat_scaled)
            mae = F.l1_loss(y_hat, y)
            rmse = torch.sqrt(F.mse_loss(y_hat, y))
        self.log(f"{stage}_loss", loss, prog_bar=True, on_epoch=True, batch_size=y.shape[0])
        self.log(f"{stage}_mae_K", mae,  prog_bar=True, on_epoch=True, batch_size=y.shape[0])
        self.log(f"{stage}_rmse_K", rmse, prog_bar=True, on_epoch=True, batch_size=y.shape[0])
        return loss

    def training_step(self, batch, batch_idx):
        return self._compute_loss_and_metrics(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._compute_loss_and_metrics(batch, "val")

    def test_step(self, batch, batch_idx):
        self._compute_loss_and_metrics(batch, "test")

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        # ---- Scheduler selection ----
        if self.scheduler_name == "cosine":
            # per-step linear warmup then cosine
            if (self.steps_per_epoch is not None) and (self.max_epochs_cfg is not None):
                total_steps = int(self.steps_per_epoch * self.max_epochs_cfg)
                warmup_steps = max(1, int(self.warmup_ratio * total_steps))
                from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
                warm = LinearLR(opt, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
                cos  = CosineAnnealingLR(opt, T_max=max(1, total_steps - warmup_steps))
                sched = SequentialLR(opt, schedulers=[warm, cos], milestones=[warmup_steps])
                return {
                    "optimizer": opt,
                    "lr_scheduler": {"scheduler": sched, "interval": "step"}
                }
            else:
                # epoch-level fallback (still fine with early stopping)
                cos = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=self.t_max)
                return {"optimizer": opt, "lr_scheduler": cos}

        elif self.scheduler_name == "onecycle":
            assert self.steps_per_epoch is not None and self.max_epochs_cfg is not None, \
                "OneCycleLR requires steps_per_epoch and max_epochs."
            total_steps = int(self.steps_per_epoch * self.max_epochs_cfg)
            # set optimizer initial lr lower than max_lr via div_factor
            div_factor = 25.0
            final_div = 1e4
            for g in opt.param_groups:
                g["lr"] = self.lr / div_factor
            one = torch.optim.lr_scheduler.OneCycleLR(
                opt, max_lr=self.lr, total_steps=total_steps,
                pct_start=0.3, anneal_strategy="cos",
                div_factor=div_factor, final_div_factor=final_div
            )
            return {
                "optimizer": opt,
                "lr_scheduler": {"scheduler": one, "interval": "step"}
            }

        else:
            return {"optimizer": opt}

In [34]:
import matplotlib.pyplot as plt
import argparse

In [35]:
from pathlib import Path

In [36]:
out_dir = Path("")

In [25]:
seed_all(3407)

dm = MolDataModule(csv_path="processed_data.csv", batch_size=16, num_workers=2, seed=3407)
dm.setup()
steps_per_epoch = len(dm.train_dataloader())

vocab_size = len(dm.train_set.atom_vocab)

model = GraphTransformerPredictor(
    vocab_size=vocab_size,
    dim=256, depth=6, heads=8,
    lr=1e-4, weight_decay=1e-6, t_max=200,
    scaler=dm.scaler,
    egnn_depth=2, egnn_update_coords=True,
    scheduler="cosine", warmup_ratio=0.1, steps_per_epoch=steps_per_epoch, max_epochs=200,
    loss_name="focal_huber", huber_delta=1.0, focal_gamma=0.8, focal_wmax=4.0
)

ckpt_cb = pl.callbacks.ModelCheckpoint(monitor="val_mae_K", mode="min", save_top_k=1,
                                           dirpath=str(out_dir), filename="best-{epoch:03d}-{val_mae_K:.2f}")
early = pl.callbacks.EarlyStopping(monitor="val_mae_K", mode="min", patience=15, min_delta=0.01)
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="epoch")

trainer = pl.Trainer(
    max_epochs=200,
    accelerator="auto", devices=1,
    callbacks=[ckpt_cb, early, lr_monitor],
    deterministic=True,
    gradient_clip_val=1.0, precision="32-true"
)

trainer.fit(model, datamodule=dm, ckpt_path = "best-epoch=070-val_mae_K=28.03.ckpt")

best_ckpt = ckpt_cb.best_model_path if ckpt_cb.best_model_path else None
if best_ckpt:
    model = GraphTransformerPredictor.load_from_checkpoint(
        best_ckpt, vocab_size=vocab_size, dim=256, depth=6, heads=8,
        lr=1e-4, weight_decay=1e-6, t_max=200, scaler=dm.scaler,
        egnn_depth=2, egnn_update_coords=True,
        scheduler="cosine", warmup_ratio=0.1, steps_per_epoch=steps_per_epoch, max_epochs=200,
        loss_name="focal_huber", huber_delta=1.0, focal_gamma=0.8, focal_wmax=4.0
    )

# Save weights
weights_path = out_dir / "model_weights.pt"
torch.save(model.state_dict(), weights_path)
print(f"[Saved] {weights_path}")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/venv/main/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA H200') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/venv/ma

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[Saved] model_weights.pt


In [47]:
model = GraphTransformerPredictor.load_from_checkpoint(
        "best-epoch=070-val_mae_K=28.03.ckpt", vocab_size=vocab_size, dim=256, depth=6, heads=8,
        lr=1e-4, weight_decay=1e-6, t_max=200, scaler=dm.scaler,
        egnn_depth=2, egnn_update_coords=True,
        scheduler="cosine", warmup_ratio=0.1, steps_per_epoch=steps_per_epoch, max_epochs=200,
        loss_name="focal_huber", huber_delta=1.0, focal_gamma=0.8, focal_wmax=4.0
    )

In [45]:
df = pd.read_csv("dataset.csv")
df["Melting Point(K)"] += 273.15
df.to_csv("dataset.csv", index = False)

In [46]:
dm = MolDataModule(csv_path="dataset.csv", batch_size=16, num_workers=2, seed=3407)
dm.setup()
steps_per_epoch = len(dm.train_dataloader())

vocab_size = len(dm.train_set.atom_vocab)

In [48]:
def to_device(batch, device):
    if torch.is_tensor(batch):
        return batch.to(device, non_blocking=True)
    if isinstance(batch, dict):
        return {k: to_device(v, device) for k, v in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(to_device(v, device) for v in batch)
    return batch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Evaluate + print metrics + top-5 worst and plot
preds, targets, smiles_all = [], [], []
model.eval()
with torch.no_grad():
    for batch in dm.test_dataloader():
        batch = to_device(batch, device)
        y_hat_scaled = model(batch)
        y_hat = dm.scaler.inverse_transform(y_hat_scaled)
        preds.append(y_hat.squeeze(-1).cpu().numpy())
        targets.append(batch["y"].squeeze(-1).cpu().numpy())
        smiles_all += batch["smiles"]

In [49]:
import numpy as np
preds = np.concatenate(preds, axis=0) if preds else np.array([])
targets = np.concatenate(targets, axis=0) if targets else np.array([])

In [50]:
if len(preds) > 0:
    err = preds - targets
    mae = float(np.mean(np.abs(err)))
    rmse = float(np.sqrt(np.mean(err**2)))
    y_mean = float(np.mean(targets))
    ss_res = float(np.sum((targets - preds)**2))
    ss_tot = float(np.sum((targets - y_mean)**2)) if len(targets) > 1 else 0.0
    r2 = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else float("nan")
    print(f"[Test] MAE = {mae:.3f} K | RMSE = {rmse:.3f} K | R^2 = {r2:.4f}")

    order = np.argsort(-np.abs(err))[:min(5, len(err))]
    print("\nTop-5 largest absolute errors (SMILES | Actual K | Predicted K | AbsErr):")
    for idx in order:
        print(f"{smiles_all[idx]} | {targets[idx]:.2f} | {preds[idx]:.2f} | {abs(err[idx]):.2f}")
else:
    print("[Test] No predictions to evaluate.")

import matplotlib.pyplot as plt
# (re)compute R^2 here in case it's not in scope
if len(preds) > 1:
    y_mean = float(np.mean(targets))
    ss_res = float(np.sum((targets - preds) ** 2))
    ss_tot = float(np.sum((targets - y_mean) ** 2))
    r2_plot = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else float("nan")
else:
    r2_plot = float("nan")

# Build a concise model title from args
model_title = (
    f"Graph Transformer + EGNN Embedding"
)

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(targets, preds, s=5, alpha=0.8, edgecolors="none")  # smaller points for large N

if len(targets) > 0:
    lo = float(np.min([targets.min(), preds.min()]))
    hi = float(np.max([targets.max(), preds.max()]))
    ax.plot([lo, hi], [lo, hi], ls="--", lw=1.0, color="k")

ax.set_xlabel("Actual Melting Point (K)")
ax.set_ylabel("Predicted Melting Point (K)")
ax.set_title(model_title)

# Add R^2 (TeX mathtext; uppercase R)
ax.text(0.05, 0.95, rf"$R^2={r2_plot:.6f}$",
        transform=ax.transAxes, ha="left", va="top")

fig_path = out_dir / "pred_vs_actual.png"
fig.savefig(fig_path, dpi=180, bbox_inches="tight")
plt.close(fig)
print(f"[Saved] {fig_path}")

[Test] MAE = 24.840 K | RMSE = 35.222 K | R^2 = 0.8651

Top-5 largest absolute errors (SMILES | Actual K | Predicted K | AbsErr):
COC(=O)C1CCCN1C(=O)COC(=O)CC2=CC(C=CC=C2)NC3=C(C=CC=CC3Cl)Cl | 722.15 | 385.09 | 337.06
CS(C)(C)I | 490.65 | 235.55 | 255.10
CCOCCCO | 433.65 | 181.06 | 252.59
C=C(C)/C=CC | 348.65 | 128.62 | 220.03
O=C(O)C[C@H](O)C[C@H](O)CC[C@H]2[C@@H](C)C=CC1=C[C@@H](O)C[C@H](OC(=O)[C@@H](C)CC)[C@@H]12 | 599.15 | 393.03 | 206.12
[Saved] pred_vs_actual.png


In [4]:
def load_base_module(base_script: str):
    """
    Try to import graph_transformer_fc_egnn_stable; if not on PYTHONPATH,
    load it directly from the file path provided.
    """
    try:
        import graph_transformer_fc_egnn_stable as base
        return base
    except Exception:
        pass
 
    base_path = Path(base_script)
    if not base_path.exists():
        raise FileNotFoundError(f"Base script not found: {base_path}")
    spec = importlib.util.spec_from_file_location("gt_egnn_base", str(base_path))
    base = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    spec.loader.exec_module(base)
    return base

In [5]:
# -------------- Fine-tune LightningModule --------------
 
class FineTunePredictor(pl.LightningModule):
    """
    Subclass the base GraphTransformerPredictor to:
    - use robust losses (Huber / Focal-Huber) on scaled targets
    - provide per-step warmup+cosine or OneCycleLR
    - support parameter freezing with custom param groups
    """
    def __init__(self, base_module, *,  # module loaded via load_base_module
                 vocab_size: int,
                 dim: int = 256, depth: int = 6, heads: int = 8,
                 lr: float = 1e-4, weight_decay: float = 1e-6,
                 max_epochs: int = 30, steps_per_epoch: int = 100,
                 scheduler: str = "cosine", warmup_ratio: float = 0.1,
                 loss: str = "huber", huber_delta: float = 1.0,
                 focal_gamma: float = 0.8, focal_wmax: float = 4.0,
                 egnn_depth: int = 2, egnn_update_coords: bool = False,
                 scaler=None,
                 freeze: str = "encoder_egnn", bottomk: int = 0,
                 backbone_lr_mult: float = 1.0):
        super().__init__()
        # compose the base model
        self.model = base_module.GraphTransformer(
            vocab_size=vocab_size, dim=dim, depth=depth, heads=heads,
            attn_dropout=0.1, resid_dropout=0.1, ff_mult=4,
            egnn_depth=egnn_depth, egnn_update_coords=egnn_update_coords
        )
        self.scaler = scaler or base_module.StandardScalerTorch()
        self.lr = lr
        self.weight_decay = weight_decay
        self.max_epochs_cfg = max_epochs
        self.steps_per_epoch = steps_per_epoch
        self.scheduler_name = scheduler
        self.warmup_ratio = warmup_ratio
        self.loss_name = loss
        self.huber_delta = huber_delta
        self.focal_gamma = focal_gamma
        self.focal_wmax = focal_wmax
        self.backbone_lr_mult = backbone_lr_mult
 
        # FREEZE according to policy
        self._apply_freeze_policy(freeze=freeze, bottomk=bottomk)
 
        # build param groups (head vs backbone if needed)
        self._param_groups = self._build_param_groups()
 
        # for nice plot title
        self.title = f"FT: FC GTransformer+EGNN (dim={dim}, depth={depth}, heads={heads}" \
                     f"{', coord↑' if egnn_update_coords else ''})"
 
    # ---- freezing logic ----
    def _apply_freeze_policy(self, freeze: str, bottomk: int):
        """
        freeze:
          - 'none'          : train all
          - 'egnn_only'     : freeze EGNN; train encoder + transformer + head
          - 'encoder_egnn'  : freeze node encoder + EGNN
          - 'bottomk'       : freeze first bottomk transformer layers (plus you can combine by setting bottomk>0)
        """
        # helpers
        enc = self.model.encoder
        egnn = self.model.egnn
        blocks = list(self.model.layers)
 
        def set_requires_grad(module, flag: bool):
            for p in module.parameters():
                p.requires_grad = flag
 
        if freeze == "egnn_only":
            set_requires_grad(egnn, False)
        elif freeze == "encoder_egnn":
            set_requires_grad(enc, False)
            set_requires_grad(egnn, False)
        elif freeze == "none":
            pass
        else:
            # allow 'bottomk' or any string; we'll treat below
            pass
 
        if bottomk and bottomk > 0:
            for i in range(min(bottomk, len(blocks))):
                set_requires_grad(blocks[i], False)
 
    def _build_param_groups(self):
        head_params = list(self.model.mlp_out.parameters())
        head_ids = set(id(p) for p in head_params)
 
        backbone = [p for p in self.model.parameters() if id(p) not in head_ids and p.requires_grad]
        head = [p for p in head_params if p.requires_grad]
 
        # If everything was frozen except head, this still works
        groups = []
        if backbone:
            groups.append({"params": backbone, "lr": self.lr * self.backbone_lr_mult})
        if head:
            groups.append({"params": head, "lr": self.lr})
        return groups if groups else [{"params": [p for p in self.model.parameters() if p.requires_grad], "lr": self.lr}]
 
    # ---- robust losses on scaled targets ----
    def _huber_per_sample(self, err_abs: torch.Tensor, delta: float) -> torch.Tensor:
        small = 0.5 * (err_abs ** 2)
        large = delta * (err_abs - 0.5 * delta)
        return torch.where(err_abs <= delta, small, large)
 
    def _loss(self, y_hat_scaled: torch.Tensor, y_scaled: torch.Tensor) -> torch.Tensor:
        err = y_hat_scaled - y_scaled
        err_abs = err.abs()
        if self.loss_name == "mse":
            return (err ** 2).mean()
        if self.loss_name == "mae":
            return err_abs.mean()
        if self.loss_name == "logcosh":
            return torch.log(torch.cosh(err)).mean()
        if self.loss_name == "huber":
            per = self._huber_per_sample(err_abs, self.huber_delta)
            return per.mean()
        if self.loss_name == "focal_huber":
            per = self._huber_per_sample(err_abs, self.huber_delta)
            med = err_abs.detach().median().clamp_min(1e-8)
            w = ((err_abs / med) ** self.focal_gamma).clamp_max(self.focal_wmax)
            return (w * per).mean()
        raise ValueError(f"Unknown loss: {self.loss_name}")
 
    # ---- Lightning hooks ----
    def forward(self, batch):
        return self.model(batch["atom_ids"], batch["charges"], batch["aromatic"],
                          batch["coords"], batch["mask"])
 
    def _compute_loss_and_metrics(self, batch, stage: str):
        y = batch["y"]
        y_scaled = self.scaler.transform(y)
        y_hat_scaled = self.forward(batch)
 
        loss = self._loss(y_hat_scaled, y_scaled)
        with torch.no_grad():
            y_hat = self.scaler.inverse_transform(y_hat_scaled)
            mae = F.l1_loss(y_hat, y)
            rmse = torch.sqrt(F.mse_loss(y_hat, y))
        self.log(f"{stage}_loss", loss, prog_bar=True, on_epoch=True, batch_size=y.shape[0])
        self.log(f"{stage}_mae_K", mae,  prog_bar=True, on_epoch=True, batch_size=y.shape[0])
        self.log(f"{stage}_rmse_K", rmse, prog_bar=True, on_epoch=True, batch_size=y.shape[0])
        return loss
 
    def training_step(self, batch, batch_idx):
        return self._compute_loss_and_metrics(batch, "train")
 
    def validation_step(self, batch, batch_idx):
        self._compute_loss_and_metrics(batch, "val")
 
    def test_step(self, batch, batch_idx):
        self._compute_loss_and_metrics(batch, "test")
 
    def configure_optimizers(self):
        opt = torch.optim.AdamW(self._param_groups, weight_decay=self.weight_decay)
 
        if self.scheduler_name == "cosine":
            total_steps = int(self.steps_per_epoch * self.max_epochs_cfg)
            warmup_steps = max(1, int(self.warmup_ratio * total_steps))
            from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
            warm = LinearLR(opt, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
            cos  = CosineAnnealingLR(opt, T_max=max(1, total_steps - warmup_steps))
            sched = SequentialLR(opt, schedulers=[warm, cos], milestones=[warmup_steps])
            return {"optimizer": opt, "lr_scheduler": {"scheduler": sched, "interval": "step"}}
 
        elif self.scheduler_name == "onecycle":
            total_steps = int(self.steps_per_epoch * self.max_epochs_cfg)
            div_factor = 25.0; final_div = 1e4
            one = torch.optim.lr_scheduler.OneCycleLR(
                opt, max_lr=self.lr, total_steps=total_steps,
                pct_start=0.3, anneal_strategy="cos",
                div_factor=div_factor, final_div_factor=final_div
            )
            return {"optimizer": opt, "lr_scheduler": {"scheduler": one, "interval": "step"}}
 
        return {"optimizer": opt}

In [34]:
import importlib
base = load_base_module("jolnon.py")

# Data
dm = base.MolDataModule(csv_path="dataset.csv", batch_size=16,
                        num_workers=2, seed=3407)
dm.setup()
vocab_size = len(dm.train_set.atom_vocab)
steps_per_epoch = len(dm.train_dataloader())

# Build fine-tune module
ft = FineTunePredictor(
    base_module=base,
    vocab_size=vocab_size,
    dim=256, depth=6, heads=8,
    lr=1e-4, weight_decay=1e-6,
    max_epochs=200, steps_per_epoch=steps_per_epoch,
    scheduler="3407", warmup_ratio=0.1,
    loss="huber", huber_delta=1.0,
    focal_gamma=0.8, focal_wmax=4,
    egnn_depth=2, egnn_update_coords=True,
    scaler=dm.scaler,
    freeze="encoder_egnn", bottomk=0,
    backbone_lr_mult=1.0
)

# Load pretrained weights
pre = Path("model_weights.pt")
if not pre.exists():
    raise FileNotFoundError(f"Pretrained weights not found: {pre}")
if pre.suffix == ".ckpt":
    # load lightning checkpoint
    ckpt = torch.load(pre, map_location="cpu")
    state = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    missing, unexpected = ft.load_state_dict(state, strict=False)
else:
    # assume plain state_dict
    state = torch.load(pre, map_location="cpu")
    missing, unexpected = ft.load_state_dict(state, strict=False)
print(f"[load_state_dict] missing={len(missing)}, unexpected={len(unexpected)}")
if missing:
    print("  missing keys (first few):", missing[:10])
if unexpected:
    print("  unexpected keys (first few):", unexpected[:10])

out_dir = Path("finetune_run2"); out_dir.mkdir(parents=True, exist_ok=True)

# Trainer
ckpt_cb = pl.callbacks.ModelCheckpoint(monitor="val_mae_K", mode="min", save_top_k=1,
                                       dirpath=str(out_dir), filename="ft-best-{epoch:03d}-{val_mae_K:.2f}")
early = pl.callbacks.EarlyStopping(monitor="val_mae_K", mode="min", patience=10, min_delta=0.01)
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
trainer = pl.Trainer(
    max_epochs=200,
    accelerator="auto", devices=1,
    callbacks=[ckpt_cb, early, lr_monitor],
    log_every_n_steps=5, deterministic=True,
    gradient_clip_val=1.0, precision="32-true"
)

# Fit
trainer.fit(ft, datamodule=dm)

# Load best
best = ckpt_cb.best_model_path if ckpt_cb.best_model_path else None
if best:
    ft = FineTunePredictor.load_from_checkpoint(
        best, base_module=base,
        vocab_size=vocab_size,
        dim=256, depth=6, heads=8,
        lr=3e-4, weight_decay=1e-6,
        max_epochs=200, steps_per_epoch=steps_per_epoch,
        scheduler="3407", warmup_ratio=0.1,
        loss="huber", huber_delta=1.0,
        focal_gamma=0.8, focal_wmax=4,
        egnn_depth=2, egnn_update_coords=True,
        scaler=dm.scaler,
        freeze="encoder_egnn", bottomk=0,
        backbone_lr_mult=1.0
    )

# Save fine-tuned state_dict
weights_path = out_dir / "finetuned_weights.pt"
torch.save(ft.state_dict(), weights_path)
print(f"[Saved] {weights_path}")

# Test + metrics + Top-5 + plot
preds, targets, smiles_all = [], [], []
ft.eval()
with torch.no_grad():
    for batch in dm.test_dataloader():
        batch = to_device(batch, device)
        y_hat_scaled = ft(batch)
        y_hat = dm.scaler.inverse_transform(y_hat_scaled)
        preds.append(y_hat.squeeze(-1).cpu().numpy())
        targets.append(batch["y"].squeeze(-1).cpu().numpy())
        smiles_all += batch["smiles"]
preds = np.concatenate(preds, axis=0) if preds else np.array([])
targets = np.concatenate(targets, axis=0) if targets else np.array([])

if len(preds) > 0:
    err = preds - targets
    mae = float(np.mean(np.abs(err)))
    rmse = float(np.sqrt(np.mean(err**2)))
    y_mean = float(np.mean(targets))
    ss_res = float(np.sum((targets - preds)**2))
    ss_tot = float(np.sum((targets - y_mean)**2)) if len(targets) > 1 else 0.0
    r2 = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else float("nan")
    print(f"[Test] MAE = {mae:.3f} K | RMSE = {rmse:.3f} K | R^2 = {r2:.4f}")

    order = np.argsort(-np.abs(err))[:min(5, len(err))]
    print("\nTop-5 largest absolute errors (SMILES | Actual K | Predicted K | AbsErr):")
    for idx in order:
        print(f"{smiles_all[idx]} | {targets[idx]:.2f} | {preds[idx]:.2f} | {abs(err[idx]):.2f}")
else:
    print("[Test] No predictions to evaluate.")

# Plot with model title + R^2
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(targets, preds, s=4, alpha=0.8, edgecolors="none")  # smaller points for large N
if len(targets) > 0:
    lo = float(np.min([targets.min(), preds.min()]))
    hi = float(np.max([targets.max(), preds.max()]))
    ax.plot([lo, hi], [lo, hi], ls="--", lw=1.0, color="k")
ax.set_xlabel("Actual Melting Point (K)")
ax.set_ylabel("Predicted Melting Point (K)")
ax.set_title(ft.title)
# TeX-style R^2
ax.text(0.05, 0.95, rf"$R^2={r2:.6f}$", transform=ax.transAxes, ha="left", va="top")
fig_path = Path("finetune_run") / "pred_vs_actual_finetune.png"
fig.savefig(fig_path, dpi=180, bbox_inches="tight")
plt.close(fig)
print(f"[Saved] {fig_path}")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


[load_state_dict] missing=0, unexpected=0


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type                | Params | Mode 
-------------------------------------------------------
0 | model  | GraphTransformer    | 5.6 M  | train
1 | scaler | StandardScalerTorch | 0      | train
-------------------------------------------------------
4.8 M     Trainable params
797 K     Non-trainable params
5.6 M     Total params
22.409    Total estimated model params size (MB)
130       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[Saved] finetune_run2/finetuned_weights.pt
[Test] MAE = 21.669 K | RMSE = 32.083 K | R^2 = 0.8881

Top-5 largest absolute errors (SMILES | Actual K | Predicted K | AbsErr):
COC(=O)C1CCCN1C(=O)COC(=O)CC2=CC(C=CC=C2)NC3=C(C=CC=CC3Cl)Cl | 722.15 | 403.45 | 318.70
CS(C)(C)I | 490.65 | 233.69 | 256.96
CCOCCCO | 433.65 | 185.72 | 247.93
CC(Cl)(Cl)CCC | 401.65 | 187.80 | 213.85
C=C(C)/C=CC | 348.65 | 148.53 | 200.12
[Saved] finetune_run/pred_vs_actual_finetune.png


In [30]:
def to_device(batch, device):
    if torch.is_tensor(batch):
        return batch.to(device, non_blocking=True)
    if isinstance(batch, dict):
        return {k: to_device(v, device) for k, v in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(to_device(v, device) for v in batch)
    return batch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [32]:
import matplotlib.pyplot as plt

In [35]:
# Test + metrics + Top-5 + plot
preds, targets, smiles_all = [], [], []
ft.eval()
with torch.no_grad():
    for batch in dm.test_dataloader():
        batch = to_device(batch, device)
        y_hat_scaled = ft(batch)
        y_hat = dm.scaler.inverse_transform(y_hat_scaled)
        preds.append(y_hat.squeeze(-1).cpu().numpy())
        targets.append(batch["y"].squeeze(-1).cpu().numpy())
        smiles_all += batch["smiles"]
preds = np.concatenate(preds, axis=0) if preds else np.array([])
targets = np.concatenate(targets, axis=0) if targets else np.array([])

if len(preds) > 0:
    err = preds - targets
    mae = float(np.mean(np.abs(err)))
    rmse = float(np.sqrt(np.mean(err**2)))
    y_mean = float(np.mean(targets))
    ss_res = float(np.sum((targets - preds)**2))
    ss_tot = float(np.sum((targets - y_mean)**2)) if len(targets) > 1 else 0.0
    r2 = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else float("nan")
    print(f"[Test] MAE = {mae:.3f} K | RMSE = {rmse:.3f} K | R^2 = {r2:.4f}")

    order = np.argsort(-np.abs(err))[:min(5, len(err))]
    print("\nTop-5 largest absolute errors (SMILES | Actual K | Predicted K | AbsErr):")
    for idx in order:
        print(f"{smiles_all[idx]} | {targets[idx]:.2f} | {preds[idx]:.2f} | {abs(err[idx]):.2f}")
else:
    print("[Test] No predictions to evaluate.")

# Plot with model title + R^2
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(targets, preds, s=4, alpha=0.8, edgecolors="none")  # smaller points for large N
if len(targets) > 0:
    lo = float(np.min([targets.min(), preds.min()]))
    hi = float(np.max([targets.max(), preds.max()]))
    ax.plot([lo, hi], [lo, hi], ls="--", lw=1.0, color="k")
ax.set_xlabel("Actual Melting Point (K)")
ax.set_ylabel("Predicted Melting Point (K)")
ax.set_title(ft.title)
# TeX-style R^2
ax.text(0.05, 0.95, rf"$R^2={r2:.6f}$", transform=ax.transAxes, ha="left", va="top")
fig_path = Path("finetune_run") / "pred_vs_actual_finetune.png"
fig.savefig(fig_path, dpi=180, bbox_inches="tight")
plt.close(fig)
print(f"[Saved] {fig_path}")

[Test] MAE = 21.669 K | RMSE = 32.083 K | R^2 = 0.8881

Top-5 largest absolute errors (SMILES | Actual K | Predicted K | AbsErr):
COC(=O)C1CCCN1C(=O)COC(=O)CC2=CC(C=CC=C2)NC3=C(C=CC=CC3Cl)Cl | 722.15 | 403.45 | 318.70
CS(C)(C)I | 490.65 | 233.69 | 256.96
CCOCCCO | 433.65 | 185.72 | 247.93
CC(Cl)(Cl)CCC | 401.65 | 187.80 | 213.85
C=C(C)/C=CC | 348.65 | 148.53 | 200.12
[Saved] finetune_run/pred_vs_actual_finetune.png
