In [1]:
# Display the value of every top-level expression in a cell, not only the last one.
from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.instance().ast_node_interactivity = "all"

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning import LightningModule

# Recommended option (original author's note). NOTE(review): "medium" presumably trades
# float32 matmul precision for speed -- confirm against the PyTorch documentation.
torch.set_float32_matmul_precision("medium")


def weighted_mse_loss(pred, target, weight_ratio=0.5):
    """Mean squared error that puts extra weight on the large-magnitude parts of ``target``.

    Every element gets the weight ``1 + weight_ratio * |target| / max|target|``,
    where the maximum is taken along ``dim=1`` separately for each sample.  The
    weights therefore lie in ``[1, 1 + weight_ratio]`` for ``weight_ratio >= 0``.

    The magnitude (not the signed value) is used on purpose: with signed values a
    negative target produces a weight below 1, and a curve whose maximum is
    ``<= 0`` is divided by the ``1e-6`` floor, giving huge negative weights and a
    negative loss.  For non-negative targets the result is unchanged.

    Args:
        pred: Predicted values, same shape as ``target``.
        target: Reference values; must have at least two dimensions because the
            per-sample maximum is taken along ``dim=1``.
        weight_ratio: Extra weight given to the element with the largest magnitude.

    Returns:
        Scalar tensor holding the weighted mean squared error.
    """
    # The weights are constants of the optimisation problem; keep them out of autograd.
    with torch.no_grad():
        magnitude = target.abs()
        # The floor avoids a division by zero for all-zero curves.
        max_val = magnitude.max(dim=1, keepdim=True).values.clamp(min=1e-6)
        weights = 1.0 + weight_ratio * (magnitude / max_val)
    return (weights * (pred - target) ** 2).mean()


# --- Mixture Decoder ---
class MixtureDecoder(nn.Module):
    """Decode hidden features into a softmax-weighted sum of unnormalised Gaussian bumps.

    A single linear layer maps each hidden vector to ``3 * n_components`` numbers,
    read as (mixing logit, centre, log-width) per component.  The output at a
    coordinate ``t`` is ``sum_k softmax(logit)_k * exp(-0.5 * ((t - centre_k) / width_k) ** 2)``.
    """

    def __init__(self, d_in, n_components=2):
        """Create the decoder.

        Args:
            d_in: Size of the last dimension of the hidden features.
            n_components: Number of Gaussian components.
        """
        super().__init__()
        self.n_components = n_components
        self.linear = nn.Linear(d_in, 3 * n_components)

    def forward(self, hidden, t):
        """Evaluate the mixture at the coordinates ``t``.

        Args:
            hidden: Features whose last dimension is ``d_in``.
            t: Rank-3 coordinate tensor that is expanded to ``n_components`` along
                its last dimension.

        Returns:
            Tensor with a trailing dimension of size 1 holding the mixture value.
        """
        _batch, _length, _ = t.shape  # the unpacking fails unless ``t`` is rank-3
        logits, centres, log_widths = torch.chunk(self.linear(hidden), 3, dim=-1)
        # Widths are kept inside [1e-2, 30] after exponentiation.
        widths = torch.clamp(torch.exp(log_widths), min=1e-2, max=30.0)
        mixing = torch.softmax(logits, dim=-1)
        positions = t.expand(-1, -1, self.n_components)
        standardised = (positions - centres) / widths
        bumps = torch.exp(-0.5 * standardised ** 2)
        return torch.sum(mixing * bumps, dim=-1, keepdim=True)


# --- Main Model ---
class LightningCurveRNN(LightningModule):
    """GRU regressor that predicts one curve value per grid point, conditioned on a descriptor.

    The grid coordinates are projected to ``hidden_size`` features and fed to a
    (optionally bidirectional) GRU whose initial hidden state, for every layer
    and direction, is the projected descriptor.  The GRU output is optionally
    layer-normalised, optionally combined with a residual copy of the projected
    grid, and decoded by either ``nn.Linear`` or ``MixtureDecoder``.

    Args:
        d_descriptor: Width of the descriptor vector.
        hidden_size: Width of both projections and of each GRU direction.
        n_layers: Number of stacked GRU layers.
        bidirectional: Run the GRU in both directions (doubles its output width).
        use_layer_norm: Apply ``LayerNorm`` to the GRU output.
        lr: Learning rate passed to Adam.
        weighted_loss: Train with ``weighted_mse_loss`` instead of plain MSE.
        weight_ratio: ``weight_ratio`` forwarded to ``weighted_mse_loss``.
        inject_descriptor: Also add the projected descriptor to the input of every step.
        use_mixture_decoder: Decode with ``MixtureDecoder`` instead of ``nn.Linear``.
        n_mixture_components: Number of components of the mixture decoder.
        use_residual: Add the projected grid features to the GRU output.
    """

    def __init__(
        self,
        d_descriptor=290,
        hidden_size=64,
        n_layers=2,
        bidirectional=True,
        use_layer_norm=True,
        lr=1e-3,
        weighted_loss=False,
        weight_ratio=0.5,
        inject_descriptor=False,
        use_mixture_decoder=False,
        n_mixture_components=10,
        use_residual=False,
    ):
        super().__init__()
        # Store *all* constructor arguments.  Several of them decide which
        # sub-modules exist (decoder type, number of mixture components, ...);
        # leaving them out of the checkpoint would make the model be rebuilt with
        # the defaults on load and its state dict would no longer match.
        self.save_hyperparameters()
        self.lr = lr
        self.weighted_loss = weighted_loss
        self.weight_ratio = weight_ratio
        self.inject_descriptor = inject_descriptor
        self.use_mixture_decoder = use_mixture_decoder
        self.use_residual = use_residual

        self.d_proj = nn.Linear(d_descriptor, hidden_size)
        self.t_proj = nn.Linear(1, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, batch_first=True, bidirectional=bidirectional)
        d_out = hidden_size * (2 if bidirectional else 1)
        self.layer_norm = nn.LayerNorm(d_out) if use_layer_norm else None
        self.decoder = (
            MixtureDecoder(d_out, n_components=n_mixture_components) if use_mixture_decoder else nn.Linear(d_out, 1)
        )
        # The residual branch has width ``hidden_size`` while a bidirectional GRU
        # outputs ``2 * hidden_size``; a projection is needed only in that case.
        # It is created last, and only when required, so that models which
        # already worked keep the same parameters and initialisation order.
        self.residual_proj = nn.Linear(hidden_size, d_out) if (use_residual and d_out != hidden_size) else None

    def forward(self, t_seq, D):
        """Predict the curve on the grid ``t_seq`` for the descriptors ``D``.

        Args:
            t_seq: Grid coordinates of shape ``(B, L, 1)``.
            D: Descriptors of shape ``(B, d_descriptor)``.

        Returns:
            Predictions of shape ``(B, L, 1)``.
        """
        _, seq_len, _ = t_seq.shape  # also enforces a rank-3 input
        base = torch.tanh(self.t_proj(t_seq))  # (B, L, H)
        d_feat = torch.tanh(self.d_proj(D))  # (B, H)

        t_feat = base
        if self.inject_descriptor:
            # Broadcast the descriptor features over the sequence dimension.
            t_feat = t_feat + d_feat.unsqueeze(1).expand(-1, seq_len, -1)

        # Every layer and direction starts from the descriptor features.
        num_dir = 2 if self.gru.bidirectional else 1
        h0 = d_feat.unsqueeze(0).repeat(self.gru.num_layers * num_dir, 1, 1)
        out, _ = self.gru(t_feat, h0)

        if self.layer_norm is not None:
            out = self.layer_norm(out)

        if self.use_residual:
            # Residual from the t-projected input, widened when the GRU is bidirectional.
            skip = base if self.residual_proj is None else self.residual_proj(base)
            out = out + skip

        return self.decoder(out, t_seq) if self.use_mixture_decoder else self.decoder(out)

    def _compute_loss(self, pred, target):
        """Return the training loss: weighted MSE if enabled, plain MSE otherwise."""
        if self.weighted_loss:
            return weighted_mse_loss(pred, target, weight_ratio=self.weight_ratio)
        return F.mse_loss(pred, target)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one ``(t, D, v)`` batch."""
        t, D, v = batch
        pred = self(t, D)
        loss = self._compute_loss(pred, v)
        self.log("train/loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the plain (unweighted) MSE for one ``(t, D, v)`` batch."""
        t, D, v = batch
        pred = self(t, D)
        # Validation always uses unweighted MSE, whatever ``weighted_loss`` is.
        loss = F.mse_loss(pred, v)
        self.log("val/loss", loss, prog_bar=True, on_epoch=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return ``(t, v_true, pred)`` for one batch so predictions can be plotted against truth."""
        t, D, v_true = batch
        pred = self(t, D)
        return t, v_true, pred

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [3]:
from typing import Tuple

import pandas as pd
import torch
from torch.utils.data import Dataset


class DOSDataset(Dataset):
    """Map-style dataset yielding ``(energy_grid, descriptor, dos_curve)`` float32 tensors.

    Only the keys present in the index of all three inputs are served.
    """

    def __init__(self, desc: pd.DataFrame, dos_energy: pd.Series, dos: pd.Series):
        """Store the inputs and work out which sample keys can be served.

        Args:
            desc: One descriptor row per sample key.
            dos_energy: One energy grid (sequence of numbers) per sample key.
            dos: One DOS curve (sequence of numbers) per sample key.
        """
        super().__init__()
        self.desc = desc
        self.dos_energy = dos_energy
        self.dos = dos
        # Keep only the samples whose key appears in every input.
        shared_keys = desc.index.intersection(dos.index).intersection(dos_energy.index)
        self.indices: list = list(shared_keys)
        self.d_descriptor: int = desc.shape[1]

    def __len__(self) -> int:
        """Number of samples that can be served."""
        return len(self.indices)

    @staticmethod
    def _as_column(values) -> torch.Tensor:
        """Convert a length-L sequence to a float32 tensor of shape (L, 1)."""
        return torch.tensor(values, dtype=torch.float32).unsqueeze(-1)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(energy_grid, descriptor, dos_curve)`` for the ``idx``-th served key."""
        sample_key = self.indices[idx]
        descriptor_row = self.desc.loc[sample_key]
        descriptor = torch.tensor(descriptor_row.values, dtype=torch.float32)
        energy_grid = self._as_column(self.dos_energy.loc[sample_key])  # (L, 1)
        dos_curve = self._as_column(self.dos.loc[sample_key])  # (L, 1)
        return energy_grid, descriptor, dos_curve

In [4]:
from typing import Optional

import lightning as pl
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader


class DOSDataModule(pl.LightningDataModule):
    """Lightning data module that builds train/val/test ``DOSDataset`` objects.

    The split is either taken from ``serial`` (a Series mapping sample key to
    ``"train"``, ``"val"`` or ``"test"``) or, when ``serial`` is ``None``, drawn
    at random (70 % / 10 % / 20 %) from the keys shared by all three inputs.

    Args:
        desc: One descriptor row per sample key.
        dos_energy: One energy grid per sample key.
        dos: One DOS curve per sample key.
        serial: Optional split label per sample key.
        batch_size: Batch size of every dataloader.
        random_seed: Seed of the random split (unused when ``serial`` is given).
        num_workers: ``num_workers`` forwarded to every ``DataLoader``; the
            default of 0 matches the ``DataLoader`` default.
    """

    def __init__(
        self,
        desc: pd.DataFrame,
        dos_energy: pd.Series,
        dos: pd.Series,
        serial: Optional[pd.Series] = None,  # index -> "train" / "val" / "test"
        batch_size: int = 32,
        random_seed: int = 42,
        num_workers: int = 0,
    ):
        super().__init__()
        self.desc = desc
        self.dos_energy = dos_energy
        self.dos = dos
        self.serial = serial
        self.batch_size = batch_size
        self.random_seed = random_seed
        self.num_workers = num_workers

    def _make_dataset(self, idx) -> "DOSDataset":
        """Build a dataset restricted to the sample keys in ``idx``."""
        return DOSDataset(self.desc.loc[idx], self.dos_energy.loc[idx], self.dos.loc[idx])

    def _make_loader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Wrap ``dataset`` in a ``DataLoader`` using the module-wide settings."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.num_workers)

    def setup(self, stage=None):
        """Create ``train_dataset``, ``val_dataset`` and ``test_dataset``."""
        # Only keys available in all three inputs can be turned into samples.
        common = self.desc.index.intersection(self.dos.index).intersection(self.dos_energy.index)
        if self.serial is not None:
            # Drop labels that refer to unavailable keys (``.loc`` would raise a
            # KeyError on them) while keeping the order given by ``serial``.
            usable = self.serial[self.serial.index.isin(common)]
            train_idx = usable[usable == "train"].index
            val_idx = usable[usable == "val"].index
            test_idx = usable[usable == "test"].index
        else:
            # Automatic, seeded random split: 70 % train, 10 % val, remainder test.
            all_idx = np.array(list(common))
            rng = np.random.RandomState(self.random_seed)
            perm = rng.permutation(len(all_idx))
            n = len(all_idx)
            n_train = int(n * 0.7)
            n_val = int(n * 0.1)
            train_idx = all_idx[perm[:n_train]]
            val_idx = all_idx[perm[n_train : n_train + n_val]]
            test_idx = all_idx[perm[n_train + n_val :]]
        self.train_dataset = self._make_dataset(train_idx)
        self.val_dataset = self._make_dataset(val_idx)
        self.test_dataset = self._make_dataset(test_idx)

    def train_dataloader(self):
        """Shuffled loader over the training set."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the validation set."""
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        """Ordered loader over the test set."""
        return self._make_loader(self.test_dataset)

In [5]:
import pandas as pd
from pymatgen.core import Composition

# Load the DOS table from a hard-coded absolute path (machine-specific).
qc_ac_te_mp_dos_data = pd.read_parquet("/data/foundation_model/data/qc_ac_te_mp_dos_reformat_20250529.pd.parquet")
# Replace each stored composition mapping by a pymatgen Composition, keeping only the
# entries whose amount is neither None nor <= 0.
# NOTE(review): this overwrites the column in place, so re-running the cell applies the
# conversion to already-converted values -- confirm that this is harmless.
qc_ac_te_mp_dos_data.composition = qc_ac_te_mp_dos_data.composition.apply(
    lambda x: Composition({k: v for k, v in x.items() if v is not None and v > 0})
)
# Load the descriptor table; the next cell selects its rows by the DOS table's index.
desc_trans = pd.read_parquet("/data/foundation_model/data/qc_ac_te_mp_dos_composition_desc_trans_20250529.pd.parquet")

In [6]:
# Keep the samples that have a DOS curve, then align every other input to that index.
dos = qc_ac_te_mp_dos_data["DOS density"].dropna()
# NOTE(review): in the cells visible here `dos_norm` is only displayed; training is fed `dos`.
dos_norm = qc_ac_te_mp_dos_data["DOS density (normalized)"].loc[dos.index]
dos_energy = qc_ac_te_mp_dos_data["DOS energy"].loc[dos.index]
desc = desc_trans.loc[dos.index]
split = qc_ac_te_mp_dos_data["split"].loc[dos.index]

# Each expression below is displayed because ast_node_interactivity is set to "all".
dos.head(3)
dos_norm.head(3)
dos_energy.head(3)
desc.head(3)
split.head(3)

id
mp-1184879    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
mp-1188721    [-3.9227656025467192, -1.8316655857718702, 0.4...
mp-81         [2.2133526301309794, 2.105725706060867, 2.0071...
Name: DOS density, dtype: object

id
mp-1184879    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
mp-1188721    [-0.09143959740688823, -0.04269609268478773, 0...
mp-81         [0.6810085317013201, 0.6478936757426565, 0.617...
Name: DOS density (normalized), dtype: object

id
mp-1184879    [-6.0, -5.946488294314381, -5.892976588628763,...
mp-1188721    [-6.0, -5.946488294314381, -5.892976588628763,...
mp-81         [-6.0, -5.946488294314381, -5.892976588628763,...
Name: DOS energy, dtype: object

Unnamed: 0_level_0,ave:atomic_number,ave:atomic_radius,ave:atomic_radius_rahm,ave:atomic_volume,ave:atomic_weight,ave:boiling_point,ave:bulk_modulus,ave:c6_gb,ave:covalent_radius_cordero,ave:covalent_radius_pyykko,...,min:num_s_valence,min:period,min:specific_heat,min:thermal_conductivity,min:vdw_radius,min:vdw_radius_alvarez,min:vdw_radius_mm3,min:vdw_radius_uff,min:sound_velocity,min:Polarizability
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
mp-1184879,0.651165,0.332235,-0.093343,1.599262,0.632693,-1.820146,-2.570454,-1.394964,-0.091408,-0.024667,...,0.965208,0.962933,2.219223,-0.812983,0.32616,0.001953,0.419739,2.107274,0.781697,-0.014997
mp-1188721,-1.102084,-2.989988,-1.429175,0.77488,-1.058183,-2.174474,-0.777649,-1.993339,-1.584199,-1.489223,...,0.965208,-0.877215,2.073631,-0.83405,-0.844011,-0.559261,-0.689554,1.401375,-1.132118,-0.946938
mp-81,2.264521,-0.294966,-0.062036,-1.40694,2.242643,1.15755,2.100947,-0.801386,0.31572,0.14753,...,-0.907212,2.16415,-1.296625,2.110449,1.745153,1.189419,1.35422,0.726264,0.263055,0.93407


id
mp-1184879      val
mp-1188721    train
mp-81         train
Name: split, dtype: object

In [7]:
import random

import matplotlib.pyplot as plt
import torch


def plot_prediction_pairs(samples, *, n=9, seed=42, title_prefix="Sample"):
    """Plot a reproducible random selection of curves, with predictions when available.

    Args:
        samples: Sequence of ``(t, v_true)`` or ``(t, v_true, v_pred)`` entries.
            Each element is a tensor with a trailing dimension that can be
            squeezed.  Whether a prediction is present is decided per entry.
        n: Number of samples to draw; clamped to ``len(samples)``.
        seed: Seed of the private random generator that picks the samples.
        title_prefix: Text placed before the sample index in each panel title.

    Returns:
        None.  Nothing is drawn when there is nothing to plot (no samples or ``n <= 0``).

    Raises:
        ValueError: If an entry has neither two nor three elements.
    """
    # Never request more panels than there are samples, and bail out before
    # creating a figure when there is nothing to show.
    n = min(n, len(samples))
    if n <= 0:
        return

    # A private generator picks the same indices as ``random.seed(seed)`` followed
    # by ``random.sample`` would, without reseeding the global RNG for the caller.
    indices = random.Random(seed).sample(range(len(samples)), n)
    n_cols = int(np.sqrt(n))
    n_rows = int(np.ceil(n / n_cols))

    # squeeze=False always yields a 2-D array of axes, so the single-panel case
    # (n == 1) can be flattened just like any other grid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3.5 * n_cols, 3 * n_rows), squeeze=False)
    axes = axes.flatten()

    for i, idx in enumerate(indices):
        entry = samples[idx]
        if len(entry) == 2:
            t, v_true = entry
            v_pred = None
        elif len(entry) == 3:
            t, v_true, v_pred = entry
        else:
            raise ValueError("Each sample must be (t, v_true) or (t, v_true, v_pred)")

        # detach() makes the conversion work for tensors that still track gradients.
        t = t.squeeze(-1).detach().cpu().numpy()
        v_true = v_true.squeeze(-1).detach().cpu().numpy()

        ax = axes[i]
        ax.plot(t, v_true, label="True", linewidth=1.5)
        if v_pred is not None:
            v_pred = v_pred.squeeze(-1).detach().cpu().numpy()
            ax.plot(t, v_pred, label="Pred", linewidth=1.5)
        ax.set_title(f"{title_prefix} #{idx}")
        ax.set_xlabel("Energy (eV)")
        ax.set_ylabel("DOS")
        ax.legend()

    # Hide the panels of the grid that received no sample.
    for j in range(n, len(axes)):
        axes[j].axis("off")

    fig.tight_layout()
    plt.show()

In [122]:
from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping

# 1. Initialise the data module (the split is taken from the `split` column)
dm = DOSDataModule(
    desc=desc,
    dos_energy=dos_energy,
    dos=dos,
    serial=split,
    batch_size=64,
    random_seed=42,
)
dm.setup()

# 2. Initialise the model
model = LightningCurveRNN(
    d_descriptor=desc.shape[1],
    hidden_size=128,
    n_layers=6,
    bidirectional=True,
    use_layer_norm=True,
    lr=1e-3,
    weighted_loss=False,
    weight_ratio=0.5,
    inject_descriptor=True,
    use_mixture_decoder=True,
    n_mixture_components=300,
    use_residual=False,
)

# 3. Initialise the trainer (early stopping on the validation loss)
trainer = Trainer(
    max_epochs=50,
    accelerator="auto",
    callbacks=[EarlyStopping(monitor="val/loss", patience=5, mode="min")],
    log_every_n_steps=100,
)

# 4. Train
trainer.fit(model, datamodule=dm)

# 5. Predict on the test split; each batch result is a (t, v_true, pred) tuple
preds = trainer.predict(model, dataloaders=dm.test_dataloader())

# 6. Visualise the predictions
# Flatten the per-batch tuples into per-sample (t, v_true, pred) triples and plot them
samples = [sample for batch in preds for sample in zip(*batch)]
plot_prediction_pairs(samples, n=9)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name       | Type           | Params | Mode 
------------------------------------------------------
0 | d_proj     | Linear         | 37.2 K | train
1 | t_proj     | Linear         | 256    | train
2 | gru        | GRU            | 1.7 M  | train
3 | layer_norm | LayerNorm      | 512    | train
4 | decoder    | MixtureDecoder | 231 K  | train
------------------------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.799     Total estimated model params size (MB)
6         Modules in train mode
0         Modules in eval mode


                                                                            

/data/foundation_model/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.
/data/foundation_model/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


Epoch 5:  26%|██▌       | 29/113 [00:00<00:02, 30.25it/s, v_num=26, train/loss_step=56.90, val/loss=81.00, train/loss_epoch=71.90] 


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined