In [None]:
class GatedFusion(nn.Module):
    """Gate CNN features with weights derived from a single scalar input.

    A small MLP (1 -> 64 -> 128 -> ``cnn_dim``, ReLU between layers, Tanh at
    the end) maps the scalar ``mfrac`` to one weight per CNN feature. The
    gated features are concatenated with ``mfrac`` scaled by a learnable
    factor (initialised to 3.0).
    """

    def __init__(self, cnn_dim):
        """Build the gate MLP and the learnable boost factor.

        Args:
            cnn_dim: Width of the CNN feature vector to be gated.
        """
        super().__init__()
        # Hidden widths of the gate MLP; every hidden layer is followed by ReLU.
        widths = (1, 64, 128)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        # Output layer: one weight per CNN feature, squashed into [-1, 1].
        layers.extend((nn.Linear(widths[-1], cnn_dim), nn.Tanh()))
        self.gate = nn.Sequential(*layers)
        self.mfrac_boost = nn.Parameter(torch.tensor(3.0))

    def forward(self, cnn_feat, mfrac):
        """Return ``(fused, gate_weights)``.

        Args:
            cnn_feat: Tensor with ``cnn_dim`` columns.
            mfrac: Single-column tensor (the gate MLP expects 1 input feature).

        Returns:
            fused: ``cnn_feat * gate_weights`` with the boosted ``mfrac``
                appended as the last column (concatenation along dim 1).
            gate_weights: Output of the gate MLP for ``mfrac``.
        """
        gate_weights = self.gate(mfrac)
        gated_features = cnn_feat * gate_weights
        scaled_mfrac = mfrac * self.mfrac_boost
        return torch.cat((gated_features, scaled_mfrac), dim=1), gate_weights

class MultiTask_CNN_GCN(nn.Module):
    """CNN + gated fusion + two-layer GCN with node-level and graph-level heads.

    Every row of ``data.x`` is read as a flattened single-channel patch of
    size ``input_dim`` followed by one trailing scalar column (called
    ``mfrac`` in the code). The patch goes through a small CNN, the result is
    gated by ``GatedFusion`` using the scalar, and the fused vector is fed to
    two ``GCNConv`` layers. One linear head predicts a value per node, the
    other a value per graph (after mean pooling over ``data.batch``).
    """

    def __init__(self, input_dim=(21, 21), gcn_hidden=GCN_HIDDEN, dropout_rate=DROPOUT):
        """Build all sub-modules.

        Args:
            input_dim: ``(height, width)`` of the per-node patch.
            gcn_hidden: Hidden width of both GCN layers.
            dropout_rate: Dropout probability applied between the GCN layers.
        """
        super().__init__()
        # Stored so that every reshape below uses the same patch size the CNN
        # output width was measured for (previously hard-coded to 21x21).
        self.input_dim = tuple(input_dim)

        # CNN for local features
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Measure the flattened CNN output width with a dummy patch.
        # randn is kept (rather than e.g. zeros) so the global RNG stream seen
        # by the layers constructed afterwards is the same as before.
        with torch.no_grad():
            dummy = torch.randn(1, 1, *self.input_dim)
            self.cnn_dim = self.cnn(dummy).view(-1).shape[0]

        # Fusion (gate)
        self.fusion = GatedFusion(self.cnn_dim)

        # GCN; the +1 accounts for the boosted scalar appended by the fusion
        self.conv1 = GCNConv(self.cnn_dim + 1, gcn_hidden)
        self.conv2 = GCNConv(gcn_hidden, gcn_hidden)
        self.dropout = nn.Dropout(dropout_rate)

        # Heads
        self.node_out = nn.Linear(gcn_hidden, 1)
        self.graph_out = nn.Linear(gcn_hidden, 1)

        # For visualization: gate output of the most recent forward() call
        # (stored as returned by the fusion module, i.e. not detached).
        self.gate_weights = None

    def _fuse(self, x_all):
        """Run CNN + gated fusion on the raw node-feature matrix.

        Returns:
            ``(fused, gate_weights)`` exactly as produced by ``self.fusion``.
        """
        x_cnn = self.get_cnn_features(x_all).view(x_all.size(0), -1)
        x_mfrac = x_all[:, -1].unsqueeze(1)
        return self.fusion(x_cnn, x_mfrac)

    def forward(self, data):
        """Compute node-level and graph-level predictions.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            ``(node_pred, graph_pred)``: one row per node and one row per
            graph respectively, each with a single column.
        """
        edge_index = data.edge_index
        batch = data.batch

        # CNN feature extraction + gate fusion
        fused, gate_weights = self._fuse(data.x)
        self.gate_weights = gate_weights

        # GCN layers
        x = F.relu(self.conv1(fused, edge_index))
        x = self.dropout(x)
        x = F.relu(self.conv2(x, edge_index))

        # Node output
        node_pred = self.node_out(x)

        # Graph output via pooling
        graph_repr = global_mean_pool(x, batch)
        graph_pred = self.graph_out(graph_repr)

        return node_pred, graph_pred

    def get_cnn_features(self, x):
        """Return the CNN feature maps (not flattened) for node features ``x``.

        The last column of ``x`` is dropped and the remainder is reshaped to
        ``(num_nodes, 1, *input_dim)`` before being passed to the CNN.
        """
        x_cnn = x[:, :-1].view(x.size(0), 1, *self.input_dim)
        return self.cnn(x_cnn)

    def get_node_embeddings(self, data):
        """Return the raw output of the second GCN layer for every node.

        Unlike ``forward``, no dropout is applied between the GCN layers, no
        ReLU is applied after the second layer, and ``self.gate_weights`` is
        left untouched.
        """
        fused, _ = self._fuse(data.x)

        x = F.relu(self.conv1(fused, data.edge_index))
        x = self.conv2(x, data.edge_index)
        return x

# ========== Model initialization and training ==========
def init_weights(m):
    """Re-initialise a Linear/Conv2d module reproducibly; ignore anything else.

    Intended for ``model.apply(init_weights)``. Weights get Xavier-uniform
    values drawn right after the global RNG is reset to ``SEED``; biases, when
    present, are set to zero.

    NOTE(review): because the global seed is reset for every matching module,
    modules whose weights have the same shape receive identical initial
    values, and the global RNG state is rewound as a side effect.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return
    torch.manual_seed(SEED)
    nn.init.xavier_uniform_(m.weight)
    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)

# Build the network, re-initialise its layers through init_weights, and move
# it to the target device (Module.apply and Module.to both return the module).
model = MultiTask_CNN_GCN(input_dim=(21, 21)).apply(init_weights).to(device)

# Adam over every model parameter, plus the MSE criterion used by the loop
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
mse_loss = nn.MSELoss()

# Early-stopping bookkeeping: best weights seen so far, their validation
# loss, and the number of consecutive epochs without improvement
best_state, best_val_loss, early_stop_counter = None, float('inf'), 0

# Per-epoch history, one list per (split, task, metric) combination
train_node_losses = []
val_node_losses = []
train_graph_losses = []
val_graph_losses = []
train_node_r2s = []
val_node_r2s = []
train_graph_r2s = []
val_graph_r2s = []

def calculate_metrics(y_true, y_pred):
    """Return ``(rmse, mae, r2)`` for targets ``y_true`` and predictions ``y_pred``.

    Both inputs are converted with ``np.asarray`` and flattened, so column
    vectors of shape ``(N, 1)`` and plain Python sequences are accepted.

    Args:
        y_true: Ground-truth values (array-like).
        y_pred: Predicted values (array-like) with the same number of elements.

    Returns:
        Tuple ``(rmse, mae, r2)``. All three are NaN when either input is
        empty; ``r2`` alone is NaN when ``r2_score`` raises.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.size == 0 or y_pred.size == 0:
        return float('nan'), float('nan'), float('nan')
    y_true = y_true.reshape(-1)
    y_pred = y_pred.reshape(-1)
    residual = y_true - y_pred
    mse = np.mean(residual ** 2)
    rmse = np.sqrt(mse)
    mae = np.mean(np.abs(residual))
    # R2 is best-effort: any ordinary failure yields NaN instead of aborting
    # the caller. `except Exception` (not a bare except) lets
    # KeyboardInterrupt / SystemExit propagate.
    try:
        r2 = r2_score(y_true, y_pred)
    except Exception:
        r2 = float('nan')
    return rmse, mae, r2

print("Starting training...")

# Training loop: per epoch, one optimisation pass over train_loader followed
# by one evaluation pass over test_loader. Training stops once the combined
# validation loss has not improved for PATIENCE consecutive epochs.
# NOTE(review): validation metrics and early stopping are computed on
# test_loader -- confirm that final results are reported on data that did not
# drive model selection.
for epoch in range(1, EPOCHS + 1):
    model.train()
    train_node_loss_total = 0.0
    train_graph_loss_total = 0.0
    train_samples = 0

    train_node_preds_epoch = []
    train_node_trues_epoch = []
    train_graph_preds_epoch = []
    train_graph_trues_epoch = []

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        
        node_pred, graph_pred = model(data)

        y_node_true = data.y_node.view(-1, 1).to(device)
        y_graph_true = data.y_graph.view(-1, 1).to(device)

        node_loss = mse_loss(node_pred, y_node_true)
        graph_loss = mse_loss(graph_pred, y_graph_true)

        loss = NODE_GRAPH_LOSS_WEIGHT * node_loss + GRAPH_LOSS_WEIGHT * graph_loss
        loss.backward()
        
        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Batch-mean losses are weighted by the number of graphs in the batch
        # (this applies to the node loss as well).
        train_node_loss_total += node_loss.item() * data.num_graphs
        train_graph_loss_total += graph_loss.item() * data.num_graphs
        train_samples += data.num_graphs

        train_node_preds_epoch.append(node_pred.detach().cpu().numpy())
        train_node_trues_epoch.append(y_node_true.detach().cpu().numpy())
        train_graph_preds_epoch.append(graph_pred.detach().cpu().numpy())
        train_graph_trues_epoch.append(y_graph_true.detach().cpu().numpy())

    # max(1, ...) avoids a division by zero when the loader yielded nothing
    train_node_loss_avg = train_node_loss_total / max(1, train_samples)
    train_graph_loss_avg = train_graph_loss_total / max(1, train_samples)

    # Validation phase
    model.eval()
    val_node_loss_total = 0.0
    val_graph_loss_total = 0.0
    val_samples = 0
    all_node_preds = []
    all_node_trues = []
    all_graph_preds = []
    all_graph_trues = []

    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            node_pred, graph_pred = model(data)
            y_node_true = data.y_node.view(-1, 1).to(device)
            y_graph_true = data.y_graph.view(-1, 1).to(device)

            node_loss = mse_loss(node_pred, y_node_true)
            graph_loss = mse_loss(graph_pred, y_graph_true)

            val_node_loss_total += node_loss.item() * data.num_graphs
            val_graph_loss_total += graph_loss.item() * data.num_graphs
            val_samples += data.num_graphs

            all_node_preds.append(node_pred.cpu().numpy())
            all_node_trues.append(y_node_true.cpu().numpy())
            all_graph_preds.append(graph_pred.cpu().numpy())
            all_graph_trues.append(y_graph_true.cpu().numpy())

    val_node_loss_avg = val_node_loss_total / max(1, val_samples)
    val_graph_loss_avg = val_graph_loss_total / max(1, val_samples)
    val_combined_loss = NODE_GRAPH_LOSS_WEIGHT * val_node_loss_avg + GRAPH_LOSS_WEIGHT * val_graph_loss_avg

    all_node_preds = np.concatenate(all_node_preds, axis=0) if len(all_node_preds) > 0 else np.array([])
    all_node_trues = np.concatenate(all_node_trues, axis=0) if len(all_node_trues) > 0 else np.array([])
    all_graph_preds = np.concatenate(all_graph_preds, axis=0) if len(all_graph_preds) > 0 else np.array([])
    all_graph_trues = np.concatenate(all_graph_trues, axis=0) if len(all_graph_trues) > 0 else np.array([])

    node_rmse, node_mae, node_r2 = calculate_metrics(all_node_trues, all_node_preds)
    graph_rmse, graph_mae, graph_r2 = calculate_metrics(all_graph_trues, all_graph_preds)

    # Record metrics
    train_node_losses.append(train_node_loss_avg)
    val_node_losses.append(val_node_loss_avg)
    train_graph_losses.append(train_graph_loss_avg)
    val_graph_losses.append(val_graph_loss_avg)

    # Training R2 is best-effort. np.concatenate raises ValueError for an
    # empty list (no training batches) or mismatched shapes; only that is
    # treated as "metric unavailable". A bare except here would also swallow
    # KeyboardInterrupt in the middle of a long training run.
    try:
        train_node_preds_epoch = np.concatenate(train_node_preds_epoch, axis=0)
        train_node_trues_epoch = np.concatenate(train_node_trues_epoch, axis=0)
        tn_rmse, tn_mae, tn_r2 = calculate_metrics(train_node_trues_epoch, train_node_preds_epoch)
        train_node_r2s.append(tn_r2)
    except ValueError:
        train_node_r2s.append(float('nan'))
    val_node_r2s.append(node_r2)

    try:
        train_graph_preds_epoch = np.concatenate(train_graph_preds_epoch, axis=0)
        train_graph_trues_epoch = np.concatenate(train_graph_trues_epoch, axis=0)
        tg_rmse, tg_mae, tg_r2 = calculate_metrics(train_graph_trues_epoch, train_graph_preds_epoch)
        train_graph_r2s.append(tg_r2)
    except ValueError:
        train_graph_r2s.append(float('nan'))
    val_graph_r2s.append(graph_r2)

    # Print progress
    if epoch % 5 == 0 or epoch == 1:
        print(f"Epoch {epoch}/{EPOCHS} | "
              f"Train Node Loss: {train_node_loss_avg:.6f}, Train Graph Loss: {train_graph_loss_avg:.6f} | "
              f"Val Node Loss: {val_node_loss_avg:.6f}, Val Graph Loss: {val_graph_loss_avg:.6f} | "
              f"Val Combined Loss: {val_combined_loss:.6f}")
        print(f"    Val Node -> RMSE: {node_rmse:.6f}, MAE: {node_mae:.6f}, R2: {node_r2:.6f}")
        print(f"    Val Graph-> RMSE: {graph_rmse:.6f}, MAE: {graph_mae:.6f}, R2: {graph_r2:.6f}")

    # Early stopping: an epoch counts as an improvement only if it beats the
    # best combined validation loss by more than 1e-8.
    if val_combined_loss < best_val_loss - 1e-8:
        best_val_loss = val_combined_loss
        # deepcopy: snapshot the weights so later updates do not alter it
        best_state = copy.deepcopy(model.state_dict())
        early_stop_counter = 0
    else:
        early_stop_counter += 1

    if early_stop_counter >= PATIENCE:
        print(f"Early stopping at epoch {epoch}. Best combined val loss: {best_val_loss:.6f}")
        break

# Save best model (the snapshot with the lowest combined validation loss)
# together with the seed and the main hyper-parameters.
if best_state is not None:
    torch.save({
        'model_state_dict': best_state,
        'seed': SEED,
        'hyperparameters': {
            'batch_size': BATCH_SIZE,
            'learning_rate': LR,
            'gcn_hidden': GCN_HIDDEN,
            'dropout': DROPOUT
        }
    }, 'best_model.pth')
    print("Best model saved to best_model.pth")
else:
    print("No best state found, model not saved.")