In [None]:
import math
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch.utils.data import Subset
from sklearn.model_selection import KFold
from sklearn.metrics import r2_score
import numpy as np
import matplotlib.pyplot as plt
from gnn_dataset_generation import PolymerDataset
from torch_geometric.nn import GINConv, global_add_pool
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU

# Load the full polymer dataset and define a reproducible, shuffled 5-fold split.
dataset = PolymerDataset(root='.')
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# Per-fold validation metrics, plus the out-of-fold predictions/targets that the
# cross-validation loop appends one array per fold.
fold_rmse, fold_r2 = [], []
all_preds, all_targets = [], []

class GIN(torch.nn.Module):
    """Graph Isomorphism Network regressor producing one scalar per graph.

    Three ``GINConv`` layers, each wrapping a two-layer MLP, update the node
    embeddings. ``global_add_pool`` sums them per graph, and a small MLP head
    maps each pooled vector to a single prediction.

    Args:
        in_dim: Number of input node features (last dimension of ``data.x``).
        hidden_dim: Width of every hidden layer.
        dropout: Dropout probability applied before the output layer.
            Defaults to 0.5, the value previously hard-coded in ``forward``.
    """

    def __init__(self, in_dim, hidden_dim, dropout=0.5):
        super().__init__()

        def make_nn(in_f):
            # The MLP ends in ReLU, so every GINConv output is already
            # non-negative. forward() therefore applies no extra activation
            # after the convolutions (ReLU is idempotent).
            return Sequential(
                Linear(in_f, hidden_dim), BatchNorm1d(hidden_dim), ReLU(),
                Linear(hidden_dim, hidden_dim), ReLU(),
            )

        self.conv1 = GINConv(make_nn(in_dim))
        self.conv2 = GINConv(make_nn(hidden_dim))
        self.conv3 = GINConv(make_nn(hidden_dim))
        self.lin1 = Linear(hidden_dim, hidden_dim)
        self.lin2 = Linear(hidden_dim, 1)
        # A plain float, not a parameter, so state_dict layout is unchanged.
        self.dropout = dropout

    def forward(self, data):
        """Return a 1-D tensor with one prediction per graph in ``data``."""
        x, edge_index, batch = data.x, data.edge_index, data.batch
        h = self.conv1(x, edge_index)
        h = self.conv2(h, edge_index)
        h = self.conv3(h, edge_index)
        # Sum node embeddings into one vector per graph: (num_graphs, hidden_dim).
        h = global_add_pool(h, batch)
        h = self.lin1(h).relu()
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.lin2(h).view(-1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training hyperparameters (same values as before, named so they are easy to tune).
EPOCHS = 1500
BATCH_SIZE = 16
HIDDEN_DIM = 128
LR = 5e-5
WEIGHT_DECAY = 5e-4
SEED = 42
# Print the mean training loss every LOG_EVERY epochs (0 disables logging).
# A previous run ended with R2 ~ 0, i.e. predictions collapsed to the training
# mean, and was interrupted without feedback. The loss curve shows whether the
# model is fitting at all.
LOG_EVERY = 100

# Read the node-feature width from the data instead of hard-coding 3, so the
# model follows any change to the featurisation.
IN_DIM = dataset[0].x.size(-1)

for fold, (train_idx, val_idx) in enumerate(kf.split(dataset), 1):
    # KFold is seeded, but model init and batch shuffling were not. Seed per
    # fold so each fold's run is reproducible.
    torch.manual_seed(SEED + fold)

    train_subset = Subset(dataset, train_idx)
    val_subset = Subset(dataset, val_idx)

    # Normalisation statistics come from the training split only, so no
    # validation information leaks into training.
    targets = torch.tensor(
        [dataset[int(i)].y.view(-1)[0].item() for i in train_idx],
        dtype=torch.float32,
    )
    target_mean = targets.mean()
    target_std = targets.std()
    # std is ~0 for constant targets and NaN for a single sample (unbiased
    # estimator). NaN > 1e-6 is False, so both cases fall back to 1.
    if not target_std > 1e-6:
        target_std = torch.tensor(1.0)

    train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE)

    model = GIN(in_dim=IN_DIM, hidden_dim=HIDDEN_DIM).to(device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY
    )
    target_mean = target_mean.to(device)
    target_std = target_std.to(device)

    def train_epoch():
        """Run one epoch of MSE training on normalised targets; return mean batch loss."""
        model.train()
        total_loss = 0.0
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            out = model(data)
            # First target value of each graph, matching `targets` above.
            y = data.y.view(data.num_graphs, -1)[:, 0]
            y_norm = (y - target_mean) / target_std
            loss = F.mse_loss(out, y_norm)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        return total_loss / len(train_loader)

    @torch.no_grad()
    def validate():
        """Evaluate the current model on the validation split in original units.

        Returns:
            (rmse, r2, predictions, targets), where the last two are NumPy
            arrays concatenated over all validation batches.
        """
        model.eval()
        sum_sq = 0.0
        total = 0
        fold_preds = []
        fold_targets = []
        for data in val_loader:
            data = data.to(device)
            # Undo the target normalisation applied during training.
            out_denorm = model(data) * target_std + target_mean
            y = data.y.view(data.num_graphs, -1)[:, 0]
            sum_sq += ((out_denorm - y) ** 2).sum().item()
            total += data.num_graphs
            fold_preds.append(out_denorm.cpu().numpy())
            fold_targets.append(y.cpu().numpy())
        rmse_val = math.sqrt(sum_sq / total)
        preds_np = np.concatenate(fold_preds)
        targets_np = np.concatenate(fold_targets)
        r2_val = r2_score(targets_np, preds_np)
        return rmse_val, r2_val, preds_np, targets_np

    for epoch in range(EPOCHS):
        train_loss = train_epoch()
        if LOG_EVERY and (epoch + 1) % LOG_EVERY == 0:
            print(f'Fold {fold} epoch {epoch + 1}/{EPOCHS}: train loss {train_loss:.4f}')

    rmse, r2, preds_fold, targets_fold = validate()
    fold_rmse.append(rmse)
    fold_r2.append(r2)
    all_preds.append(preds_fold)
    all_targets.append(targets_fold)
    print(f'Fold {fold} RMSE: {rmse:.4f}, R2: {r2:.4f}')

# Report each metric as mean ± population std (np.std default) across folds.
for metric_name, metric_scores in (('RMSE', fold_rmse), ('R2', fold_r2)):
    print(f'CV {metric_name}: {np.mean(metric_scores):.4f} ± {np.std(metric_scores):.4f}')

# Merge the per-fold out-of-fold arrays into single flat arrays.
all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)

# Parity plot: each point is one validation graph; the dashed y = x diagonal
# spans the combined range of actual and predicted values.
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(all_targets, all_preds, alpha=0.5)
lims = [
    min(all_targets.min(), all_preds.min()),
    max(all_targets.max(), all_preds.max()),
]
ax.plot(lims, lims, linestyle='--')
ax.set_xlabel('Actual Values')
ax.set_ylabel('Predicted Values')
ax.set_title('Actual vs. Predicted Area')
fig.tight_layout()
fig.savefig('area.png', dpi=300)
plt.show()

Fold 1 RMSE: 2301.3926, R2: -0.0005
Fold 2 RMSE: 2433.1876, R2: -0.0224


KeyboardInterrupt: 