In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from data import CIFData, collate_pool, get_train_val_test_loader
from model import CrystalGraphConvNet 

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(123)
np.random.seed(123)

In [3]:
dataset = CIFData(root_dir='F:\College\Research Paper Work\Superconductor\models\CGCNN\data', max_num_nbr=20, radius=15,
                  dmin=0, step=0.2, random_seed=123)


In [12]:
def collate_batch(batch):
    batch_atom_fea = []
    batch_nbr_fea = []
    batch_nbr_fea_idx = []
    batch_crystal_atom_idx = []
    batch_target = []
    batch_cif_ids = []

    base_idx = 0
    for (atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx), target, cif_id in batch:
        n_i = atom_fea.shape[0]

        batch_atom_fea.append(atom_fea)
        batch_nbr_fea.append(nbr_fea)
        batch_nbr_fea_idx.append(nbr_fea_idx)
        batch_crystal_atom_idx.extend([i + base_idx for i in crystal_atom_idx])
        batch_target.append(target.view(1))  # ensure shape [1]
        batch_cif_ids.append(cif_id)

        base_idx += n_i

    return (
        torch.cat(batch_atom_fea, dim=0),              # [total_atoms, atom_fea_len]
        torch.cat(batch_nbr_fea, dim=0),               # [total_atoms, max_num_nbr, nbr_fea_len]
        torch.cat(batch_nbr_fea_idx, dim=0),           # [total_atoms, max_num_nbr]
        torch.LongTensor(batch_crystal_atom_idx),      # [total_neighbors]
    ),  torch.tensor(batch_target, dtype=torch.float), batch_cif_ids   # targets as [batch_size]


In [14]:
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1
batch_size = 64
num_workers = 4
pin_memory = True  # or False depending on your CUDA usage

# Use ratio-based split (can also specify fixed sizes instead)
train_loader, val_loader, test_loader = get_train_val_test_loader(
    dataset=dataset,
    collate_fn=collate_batch, 
    batch_size=1,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    test_ratio=test_ratio,
    num_workers=0,
    pin_memory=pin_memory,
    train_size=None,  # or set to exact int like 4618 (0.8 * 5773)
    val_size=None,
    test_size=None,
    return_test=True
)

In [15]:
from model import CrystalGraphConvNet

model = CrystalGraphConvNet(
    orig_atom_fea_len=4,
    nbr_fea_len=3,
    atom_fea_len=64,
    n_conv=3,
    h_fea_len=128,
    n_h=1,
    classification=False  # change to True if you’re doing classification
)
model.to(device)


CrystalGraphConvNet(
  (embedding): Linear(in_features=4, out_features=64, bias=True)
  (convs): ModuleList(
    (0-2): 3 x ConvLayer(
      (fc_full): Linear(in_features=131, out_features=128, bias=True)
      (sigmoid): Sigmoid()
      (softplus1): Softplus(beta=1.0, threshold=20.0)
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (softplus2): Softplus(beta=1.0, threshold=20.0)
    )
  )
  (conv_to_fc): Linear(in_features=64, out_features=128, bias=True)
  (conv_to_fc_softplus): Softplus(beta=1.0, threshold=20.0)
  (fc_out): Linear(in_features=128, out_features=1, bias=True)
)

In [16]:
criterion = nn.MSELoss()  # for regression tasks
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

In [17]:
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0

    for (atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx), target, cif_id in train_loader:
        # Move tensors to device
        atom_fea = atom_fea.to(device)
        nbr_fea = nbr_fea.to(device)
        nbr_fea_idx = nbr_fea_idx.to(device)
        crystal_atom_idx = crystal_atom_idx.to(device)
        target = target.to(device)

        # Print tensor shapes for debugging
        print("atom_fea:", atom_fea.shape)
        print("nbr_fea:", nbr_fea.shape)
        print("nbr_fea_idx:", nbr_fea_idx.shape)
        print("crystal_atom_idx:", crystal_atom_idx.shape)
        print("target:", target.shape)

        # Flatten target to shape [batch_size]
        target = target.view(-1)

        optimizer.zero_grad()
        output = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
        print("Model output shape:",output.shape)

        # Flatten output to match target shape
        output = output.view(-1)

        # Compute loss
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)




def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for (atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx), target, _ in val_loader:
            atom_fea, nbr_fea, nbr_fea_idx = atom_fea.to(device), nbr_fea.to(device), nbr_fea_idx.to(device)
            target = target.to(device).float()
            output = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
            loss = criterion(output, target)
            total_loss += loss.item() * target.size(0)
    return total_loss / len(val_loader.dataset)


In [18]:
from sklearn.metrics import r2_score, mean_squared_error
import numpy as np

num_epochs = 2
train_losses, val_losses = [], []

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, all_preds, all_targets = validate(model, val_loader, criterion, device, return_preds=True)
    
    r2 = r2_score(all_targets, all_preds)
    mse = mean_squared_error(all_targets, all_preds)
    rmse = np.sqrt(mse)
    
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"R² Score: {r2:.4f} | MSE: {mse:.4f} | RMSE: {rmse:.4f}")



Epoch 1/2
atom_fea: torch.Size([13, 4])
nbr_fea: torch.Size([13, 20, 3])
nbr_fea_idx: torch.Size([13, 20])
crystal_atom_idx: torch.Size([260])
target: torch.Size([1])


  struct = parser.parse_structures(primitive=primitive)[0]


TypeError: len() of a 0-d tensor

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))
plt.plot(train_losses, label='Training Loss', marker='o')
plt.plot(val_losses, label='Validation Loss', marker='s')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training & Validation Loss over Epochs')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
model.eval()
all_preds, all_targets = [], []

with torch.no_grad():
    for (atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx), target, _ in test_loader:
        atom_fea = atom_fea.to(device)
        nbr_fea = nbr_fea.to(device)
        nbr_fea_idx = nbr_fea_idx.to(device)
        target = target.to(device)

        output = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
        all_preds.append(output.cpu().numpy())
        all_targets.append(target.cpu().numpy())

all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)


In [None]:
plt.figure(figsize=(8, 8))
plt.scatter(all_targets, all_preds, alpha=0.6, edgecolors='k')
plt.plot([all_targets.min(), all_targets.max()],
         [all_targets.min(), all_targets.max()], 'r--', lw=2)
plt.xlabel('True Tc')
plt.ylabel('Predicted Tc')
plt.title('True vs Predicted Critical Temperature (Tc)')
plt.grid(True)
plt.tight_layout()
plt.show()
