In [1]:
# === Standard Library ===
import os
import json
import warnings
from pathlib import Path
import psutil

# === Scientific Computing ===
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
from tqdm import tqdm

# === Machine Learning / Preprocessing ===
import joblib
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# === PyTorch & PyTorch Geometric ===
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, Subset
from torch_geometric.data import Dataset as GeoDataset, Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import CGConv, global_mean_pool

# === Pymatgen (Materials Science) ===
from pymatgen.core import Structure
from pymatgen.analysis.graphs import StructureGraph
from pymatgen.analysis.local_env import MinimumDistanceNN, CrystalNN

# === Parallel Processing ===
from joblib import Parallel, delayed

# === Device Configuration ===
# Run on the default CUDA device whenever PyTorch can see one; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Dataset Creation

##### Parsing input files (extracting node and edge features)

In [9]:
def process_file(path, out_dir):
    """Parse one raw materials JSON file into graph entries and save them as JSON.

    The raw file must contain a CIF string under ``"cif_structure"`` and a
    ``"GGA"`` section. Seebeck data is read from
    ``GGA["seebeck_doping"][carrier][T][doping]["eigs"]`` for carriers ``"n"``
    and ``"p"``. One output entry is written per (carrier, T, doping)
    combination that has eigenvalues. Each entry holds:

      * ``seebeck``      -- mean of the Seebeck eigenvalues (regression target),
      * ``global_feats`` -- ``[T, doping, carrier]``; carrier is 0 for n, 1 for p,
      * ``node_feats``   -- one 118-wide one-hot vector (by atomic number) per site,
      * ``edge_index``   -- ``[i, j]`` site pairs from the CrystalNN bond graph,
      * ``edge_attr``    -- ``[bond_length]`` per edge, measured to the bonded
                            periodic image.

    Args:
        path: Path to the raw JSON file.
        out_dir: Output directory (created if missing). The output file is
            named ``<stem of path>.json``. NOTE(review): inputs with the same
            stem in different subdirectories overwrite each other.

    Returns:
        The written output path as a string, or ``None`` if the file is not
        valid JSON, lacks required keys, cannot be parsed into an ordered,
        bonded structure, or yields no usable entries.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    try:
        with open(path, "r") as f:
            data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError):
        return None

    if "cif_structure" not in data or "GGA" not in data:
        return None

    try:
        structure = Structure.from_str(data["cif_structure"], fmt="cif")
        # Partially occupied sites have no single species, so `site.specie`
        # below would raise outside this try and abort the whole parallel job.
        if not structure.is_ordered:
            return None
        if len(structure) < 2:
            structure.make_supercell([2, 2, 2])
        structure.add_oxidation_state_by_guess()
        sg = StructureGraph.from_local_env_strategy(structure, CrystalNN())
    except Exception:
        # Best-effort parsing: any structure pymatgen cannot handle is skipped.
        return None

    # The graph depends only on the structure, so build it once and share it
    # across every (carrier, T, doping) entry instead of rebuilding it per entry.
    n_elements = 118  # one-hot width: atomic numbers 1..118
    node_feats = []
    for site in structure:
        z = site.specie.Z
        if not 1 <= z <= n_elements:
            # Dropping the site would shift every later node index out of sync
            # with edge_index, so reject the whole structure instead.
            return None
        one_hot = [0.0] * n_elements
        one_hot[z - 1] = 1.0
        node_feats.append(one_hot)

    edge_index = []
    edge_attr = []
    for i, j, attr in sg.graph.edges(data=True):
        # Bonds can point at a periodic image of site j (including i == j).
        # Measure to that image: without `jimage`, get_distance picks the
        # nearest image, which is 0.0 for self-bonds and can be wrong otherwise.
        bond_length = structure.get_distance(i, j, jimage=attr.get("to_jimage"))
        edge_index.append([i, j])
        edge_attr.append([float(bond_length)])

    if not edge_index:
        return None

    results = []
    seebeck_doping = data["GGA"].get("seebeck_doping", {})
    for carrier in ("n", "p"):
        carrier_flag = 0 if carrier == "n" else 1
        for T, dopings in seebeck_doping.get(carrier, {}).items():
            for doping, s_entry in dopings.items():
                s_eigs = s_entry.get("eigs", [])
                if not s_eigs:
                    continue
                results.append({
                    "seebeck": float(np.mean(s_eigs)),
                    "global_feats": [float(T), float(doping), carrier_flag],
                    "node_feats": node_feats,
                    "edge_index": edge_index,
                    "edge_attr": edge_attr,
                })

    if not results:
        return None

    out_path = out_dir / (Path(path).stem + ".json")
    with open(out_path, "w") as f_out:
        json.dump(results, f_out, indent=2)
    return str(out_path)


In [None]:
# Convert up to `max_files` raw JSON files into processed graph JSONs in parallel.
input_dir = "unzipped_json"
output_dir = "processed_json"
max_files = 1000  # subset size for quick iteration; set to None to process every file
n_jobs = 16       # joblib worker processes

# Path.rglob yields files in filesystem order, which differs between machines;
# sort before slicing so the same subset is selected on every run.
json_files = sorted(Path(input_dir).rglob("*.json"))[:max_files]

saved_paths = Parallel(n_jobs=n_jobs)(
    delayed(process_file)(path, output_dir)
    for path in tqdm(json_files, desc="Processing raw JSON")
)

# process_file returns None for skipped inputs; keep only files actually written.
saved_paths = [p for p in saved_paths if p]
print(f"Saved {len(saved_paths)} JSON files.")

##### Dataset Class

In [2]:
class ThermoGraphDataset(Dataset):
    """In-memory dataset of PyG ``Data`` graphs built from processed JSON files.

    Every ``*.json`` file under ``directory`` (searched recursively) must hold a
    list of entries with keys ``node_feats``, ``edge_index``, ``edge_attr``,
    ``seebeck`` and ``global_feats``. Each entry becomes one ``Data`` object:

      * ``x``          -- float32 node features, one row per node,
      * ``edge_index`` -- long tensor of shape ``(2, num_edges)``,
      * ``edge_attr``  -- float32 edge features, one row per edge,
      * ``y``          -- float32 tensor of shape ``(1,)`` with the Seebeck target,
      * ``u``          -- float32 global features of shape ``(1, num_global_feats)``.

    All graphs are built eagerly in ``__init__``.
    """

    def __init__(self, directory):
        self.directory = Path(directory)
        # rglob order depends on the filesystem; sort so graph indices, and hence
        # any seeded train/test split computed over them, are reproducible.
        self.json_files = sorted(self.directory.rglob("*.json"))
        self.graphs = []

        # Suppress warnings only while loading, rather than installing a global
        # "ignore" filter that would silence every later warning in the session.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for path in tqdm(self.json_files, desc="Loading graph JSONs"):
                with open(path, "r") as f:
                    entries = json.load(f)
                self.graphs.extend(self._entry_to_graph(entry) for entry in entries)

    @staticmethod
    def _entry_to_graph(entry):
        """Convert one processed JSON entry into a PyG ``Data`` object."""
        return Data(
            x=torch.tensor(entry["node_feats"], dtype=torch.float32),
            # Stored as a list of [src, dst] pairs; transpose to (2, num_edges).
            edge_index=torch.tensor(entry["edge_index"], dtype=torch.long).t().contiguous(),
            edge_attr=torch.tensor(entry["edge_attr"], dtype=torch.float32),
            y=torch.tensor([entry["seebeck"]], dtype=torch.float32),
            # Leading axis of size 1 so per-graph rows stack into (batch, n_global).
            u=torch.tensor(entry["global_feats"], dtype=torch.float32).unsqueeze(0),
        )

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx]

In [3]:
def standardize_graph_dataset(dataset):
    y_all = []
    u_all = []

    for data in dataset:
        y_all.append(data.y.numpy())
        u_all.append(data.u.numpy())

    y_all = np.vstack(y_all)
    u_all = np.vstack(u_all)

    scaler_y = StandardScaler()
    scaler_u = StandardScaler()

    y_scaled = scaler_y.fit_transform(y_all)
    u_scaled = scaler_u.fit_transform(u_all)

    joblib.dump(scaler_y, "scaler_y.pkl")
    joblib.dump(scaler_u, "scaler_u.pkl")

    # Update graphs
    new_graphs = []
    for i, data in enumerate(dataset):
        new_data = data.clone()
        new_data.y = torch.tensor(y_scaled[i], dtype=torch.float32)
        new_data.u = torch.tensor(u_scaled[i], dtype=torch.float32).unsqueeze(0)
        new_graphs.append(new_data)

    return new_graphs

In [None]:
if os.path.exists("ThermoGraphDataset.json"):
    dataset = torch.load("ThermoGraphDataset.json")
else:
    dataset_raw = ThermoGraphDataset("processed_json")
    dataset = standardize_graph_dataset(dataset_raw)
    torch.save(dataset, "ThermoGraphDataset.json")

indices = list(range(len(dataset)))
train_idx, test_idx = train_test_split(indices, test_size=0.2, random_state=42)

train_dataset = [dataset[i] for i in train_idx]
test_dataset = [dataset[i] for i in test_idx]

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Define Neural Network Architecture

In [5]:
class CGCNNModel(nn.Module):
    """Crystal-graph CNN regressor that combines pooled atom features with global state.

    The pipeline has four stages:
      1. A linear embedding of the node features.
      2. ``num_layers`` stacked ``CGConv`` layers conditioned on ``edge_attr``.
      3. Per-graph mean pooling, concatenated with an embedding of the per-graph
         global features ``data.u``.
      4. A two-layer MLP head producing ``output_dim`` values per graph.

    Args:
        node_feat_dim: Width of ``data.x``.
        edge_feat_dim: Width of ``data.edge_attr``.
        global_feat_dim: Width of ``data.u``.
        hidden_dim: Width of the node embeddings and of the global embedding.
        num_layers: Number of ``CGConv`` message-passing layers.
        output_dim: Number of regression outputs per graph.
    """

    def __init__(self,
                 node_feat_dim=118,
                 edge_feat_dim=1,
                 global_feat_dim=3,
                 hidden_dim=64,
                 num_layers=3,
                 output_dim=1):
        super().__init__()

        # Submodules are created in a fixed order: it determines both the
        # state_dict layout and the order in which weights draw from the RNG.
        self.embedding = nn.Linear(node_feat_dim, hidden_dim)
        self.convs = nn.ModuleList(
            CGConv(channels=hidden_dim, dim=edge_feat_dim) for _ in range(num_layers)
        )
        self.pool = global_mean_pool  # node features -> one vector per graph
        self.global_fc = nn.Sequential(
            nn.Linear(global_feat_dim, hidden_dim),
            nn.ReLU(),
        )
        # Input is [pooled graph embedding | global embedding], hence 2 * hidden_dim.
        self.regressor = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, data):
        h = self.embedding(data.x)
        for layer in self.convs:
            h = layer(h, data.edge_index, data.edge_attr)
        graph_repr = self.pool(h, data.batch)
        global_repr = self.global_fc(data.u)
        return self.regressor(torch.cat((graph_repr, global_repr), dim=1))

In [6]:
# Seed torch's RNG so weight initialisation and later DataLoader shuffling are reproducible.
SEED = 42
torch.manual_seed(SEED)
model = CGCNNModel().to(device)

In [None]:
# Smoke test: push one small batch through the model and print tensor shapes,
# to catch shape mismatches before training starts.
debug_loader = DataLoader(dataset[:5], batch_size=2, shuffle=False)
model.eval()

with torch.no_grad():
    batch = next(iter(debug_loader), None)  # only the first batch is inspected
    if batch is not None:
        print("🔍 DEBUGGING BATCH")
        shape_report = [
            ("Node features shape:", batch.x.shape),
            ("Edge index shape:", batch.edge_index.shape),
            ("Edge attr shape:", batch.edge_attr.shape),
            ("Batch vector shape:", batch.batch.shape),
            ("Target shape:", batch.y.shape),
        ]
        for caption, shape in shape_report:
            print(caption, shape)

        batch = batch.to(device)
        out = model(batch)

        print("Output predictions shape:", out.shape)
        print("Predictions:", out)
        print("Targets:", batch.y)
        print("✅ Forward pass successful!\n")


# Training and testing

In [8]:
def plot_errors(y_train, y_train_pred, y_test, y_test_pred, scaler_path="scaler_y.pkl",
                unscale=False, xlim=(-2, 2)):
    """Plot histograms and KDE curves of prediction errors for train and test sets.

    Errors are ``prediction - target``. By default the inputs are plotted as
    given, i.e. in the standardized units used for training.

    Args:
        y_train, y_train_pred: Train targets and predictions (1-D array-likes).
        y_test, y_test_pred: Test targets and predictions (1-D array-likes).
        scaler_path: Pickled target scaler. It is only read when ``unscale`` is True.
        unscale: If True, map all four arrays back through the scaler's
            ``inverse_transform`` before computing errors, so errors are shown
            in the target's original units.
        xlim: x-axis limits, or ``None`` for matplotlib's automatic range. The
            default ``(-2, 2)`` suits standardized errors. With ``unscale=True``
            pass ``None`` or wider limits.
    """
    y_train, y_train_pred, y_test, y_test_pred = (
        np.asarray(a).ravel() for a in (y_train, y_train_pred, y_test, y_test_pred)
    )

    if unscale:
        # joblib.load unpickles: only use it on a scaler file this notebook wrote.
        scaler_y = joblib.load(scaler_path)
        y_train, y_train_pred, y_test, y_test_pred = (
            scaler_y.inverse_transform(a.reshape(-1, 1)).ravel()
            for a in (y_train, y_train_pred, y_test, y_test_pred)
        )
        label = "Seebeck Coefficient (μV/K)"
    else:
        label = "Seebeck Coefficient (standardized)"

    error_train = y_train_pred - y_train
    error_test = y_test_pred - y_test

    train_color, test_color = "#138A07", "#bc4749"
    plt.figure(figsize=(8, 4))
    plt.title(f"{label} Error")
    plt.xlabel(r"$y_{pred} - y_{true}$")
    plt.ylabel("Density")

    plt.hist(error_train, bins=40, alpha=0.4, density=True, color=train_color, label="Train")
    plt.hist(error_test, bins=40, alpha=0.4, density=True, color=test_color, label="Test")

    xs = np.linspace(min(error_train.min(), error_test.min()) - 0.1,
                     max(error_train.max(), error_test.max()) + 0.1, 200)
    for errors, color in ((error_train, train_color), (error_test, test_color)):
        try:
            # A scalar bw_method fixes the bandwidth factor at 0.25.
            kde = gaussian_kde(errors, bw_method=0.25)
        except (ValueError, np.linalg.LinAlgError):
            # Too few points or zero-variance errors: the KDE is undefined,
            # so show the histogram only.
            continue
        plt.plot(xs, kde(xs), color=color, linewidth=3)

    plt.axvline(0, color="black", linestyle="--", linewidth=2)
    plt.legend()
    plt.tight_layout()
    if xlim is not None:
        plt.xlim(xlim)
    plt.show()

In [9]:
def train(model, loader, optimizer, criterion, epoch, device):
    """Run one training epoch and print loss, RMSE and R² over the epoch.

    Args:
        model: Module mapping a batched graph object to per-graph predictions.
        loader: Iterable of batches exposing ``.to(device)`` and ``.y``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Mean-reduced loss (e.g. ``nn.MSELoss()``) applied to flattened
            predictions and targets.
        epoch: Epoch number, used only in the progress-bar label.
        device: Device the batches are moved to.

    Returns:
        ``(preds, targets)`` as 1-D NumPy arrays. The predictions are collected
        while the weights are being updated, not from a final-weights pass.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_targets = 0
    all_preds, all_targets = [], []

    pbar = tqdm(loader, desc=f"Training Epoch {epoch}")
    for batch in pbar:
        batch = batch.to(device)
        optimizer.zero_grad()

        predictions = model(batch).view(-1)
        batch_targets = batch.y.view(-1)
        loss = criterion(predictions, batch_targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Weight each per-batch mean by its size so a short final batch does
        # not skew the epoch average (with MSELoss this equals the epoch MSE).
        total_loss += batch_loss * batch_targets.numel()
        n_targets += batch_targets.numel()
        pbar.set_postfix({"Batch Loss": batch_loss})

        all_preds.append(predictions.detach().cpu().numpy())
        all_targets.append(batch_targets.cpu().numpy())

    if n_targets == 0:
        raise ValueError("train() received an empty loader")

    avg_loss = total_loss / n_targets
    preds = np.concatenate(all_preds)
    targets = np.concatenate(all_targets)
    rmse = np.sqrt(mean_squared_error(targets, preds))
    r2 = r2_score(targets, preds)

    print(f"Train Loss: {avg_loss:.4f} | RMSE: {rmse:.4f} | R²: {r2:.4f}")
    return preds, targets


In [10]:
def test(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients and print loss, RMSE and R².

    Args:
        model: Module mapping a batched graph object to per-graph predictions.
        loader: Iterable of batches exposing ``.to(device)`` and ``.y``.
        criterion: Mean-reduced loss (e.g. ``nn.MSELoss()``) applied to flattened
            predictions and targets.
        device: Device the batches are moved to.

    Returns:
        ``(preds, targets)`` as 1-D NumPy arrays, in loader order.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    n_targets = 0
    all_preds, all_targets = [], []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            predictions = model(batch).view(-1)
            batch_targets = batch.y.view(-1)
            loss = criterion(predictions, batch_targets)
            # Weight each per-batch mean by its size so a short final batch
            # does not skew the average (with MSELoss this equals the full MSE).
            total_loss += loss.item() * batch_targets.numel()
            n_targets += batch_targets.numel()

            all_preds.append(predictions.cpu().numpy())
            all_targets.append(batch_targets.cpu().numpy())

    if n_targets == 0:
        raise ValueError("test() received an empty loader")

    avg_loss = total_loss / n_targets
    preds = np.concatenate(all_preds)
    targets = np.concatenate(all_targets)
    rmse = np.sqrt(mean_squared_error(targets, preds))
    r2 = r2_score(targets, preds)

    print(f"Test Loss: {avg_loss:.4f} | RMSE: {rmse:.4f} | R²: {r2:.4f}")
    return preds, targets

In [None]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

epochs = 20
plot_every = 5  # draw the error plot every N epochs; the final epoch is always plotted

for epoch in range(1, epochs + 1):
    preds_train, targets_train = train(model, train_loader, optimizer, criterion, epoch, device)
    preds_test, targets_test = test(model, test_loader, criterion, device)
    # Plotting every epoch floods the notebook with near-identical figures.
    if epoch % plot_every == 0 or epoch == epochs:
        plot_errors(targets_train, preds_train, targets_test, preds_test)