In [1]:
import json
import torch
import matplotlib.pyplot as plt
import pandas as pd
from torch.nn import CrossEntropyLoss
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import GAT
import os
from torch_geometric.data import Data
from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr
from torch_geometric.data.storage import GlobalStorage
import torch.serialization
import optuna
from sklearn.model_selection import KFold
import numpy as np

import mlflow
import mlflow.pytorch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os

# Export CUDA_LAUNCH_BLOCKING=1 into this process's environment.
# NOTE(review): this is a CUDA debugging switch; presumably enabled so that
# device-side errors are reported at the call that caused them — confirm
# whether it should stay on for long training runs.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")


#### Model architecture

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool
from torch_geometric.utils import to_dense_adj, dense_to_sparse
from torch_scatter import scatter_add, scatter_mean
import math

class GlobalAttention(nn.Module):
    """Dense all-pairs attention over node features (DIFFormer-style).

    Queries, keys and values are linear projections of the input. When a
    ``batch`` vector is supplied, attention is computed independently inside
    each graph, and the result is returned in the same row order as the input.

    Args:
        hidden_dim: Width of the input, the Q/K/V projections and the output.
        kernel: ``'simple'`` (dot products scaled by ``sqrt(hidden_dim)``) or
            ``'sigmoid'`` (dot products scaled by a learnable ``sigma`` and
            squashed by a sigmoid). Both are followed by a row-wise softmax.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``kernel`` is not one of the supported kernels.
    """

    SUPPORTED_KERNELS = ('simple', 'sigmoid')

    def __init__(self, hidden_dim, kernel='simple', dropout=0.1):
        super(GlobalAttention, self).__init__()
        # Fail at construction time instead of with an UnboundLocalError deep
        # inside the first forward pass.
        if kernel not in self.SUPPORTED_KERNELS:
            raise ValueError(
                f"Unknown attention kernel: {kernel!r}; "
                f"expected one of {self.SUPPORTED_KERNELS}"
            )

        self.hidden_dim = hidden_dim
        self.kernel = kernel
        self.dropout = nn.Dropout(dropout)

        # Projection layers for queries, keys, and values
        self.W_q = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_k = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_v = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # Output projection
        self.W_o = nn.Linear(hidden_dim, hidden_dim)

        # Kernel-specific parameters
        if kernel == 'sigmoid':
            self.sigma = nn.Parameter(torch.tensor(1.0))

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise all projection weights and reset ``sigma`` to 1."""
        nn.init.xavier_uniform_(self.W_q.weight)
        nn.init.xavier_uniform_(self.W_k.weight)
        nn.init.xavier_uniform_(self.W_v.weight)
        nn.init.xavier_uniform_(self.W_o.weight)
        if hasattr(self, 'sigma'):
            nn.init.constant_(self.sigma, 1.0)

    def forward(self, x, batch=None):
        """
        Args:
            x: Node features [N, hidden_dim]
            batch: Batch vector [N] indicating which graph each node belongs to.
                When ``None`` all rows of ``x`` attend to each other.

        Returns:
            Attended node features [N, hidden_dim], row-aligned with ``x``.
        """
        # Compute queries, keys, values
        Q = self.W_q(x)  # [N, hidden_dim]
        K = self.W_k(x)  # [N, hidden_dim]
        V = self.W_v(x)  # [N, hidden_dim]

        if batch is not None:
            # Handle batch of graphs
            return self._batched_attention(Q, K, V, batch)
        # Single graph or set of instances
        return self._single_attention(Q, K, V)

    def _single_attention(self, Q, K, V):
        """Compute attention for single graph or set of instances"""
        scores = torch.mm(Q, K.transpose(0, 1))
        if self.kernel == 'simple':
            # Scaled dot-product attention.
            scores = scores / math.sqrt(self.hidden_dim)
        else:
            # 'sigmoid' kernel: learnable scale, then squash before the softmax.
            scores = torch.sigmoid(scores * self.sigma)

        attention_weights = self.dropout(F.softmax(scores, dim=-1))
        return self.W_o(torch.mm(attention_weights, V))

    def _batched_attention(self, Q, K, V, batch):
        """Compute attention per graph and return rows in the input node order.

        The per-graph results are produced graph by graph, so they are
        scattered back to each node's original position afterwards; this keeps
        the output aligned with ``x`` even if ``batch`` is not sorted.
        """
        if batch.numel() == 0:
            # Nothing to attend over (and ``batch.max()`` would raise).
            return torch.zeros_like(Q)

        num_graphs = int(batch.max().item()) + 1
        outputs = []
        node_positions = []

        for graph_id in range(num_graphs):
            idx = (batch == graph_id).nonzero(as_tuple=True)[0]
            if idx.numel() == 0:
                # Graph ids need not be contiguous; skip the gaps.
                continue
            outputs.append(self._single_attention(Q[idx], K[idx], V[idx]))
            node_positions.append(idx)

        if not outputs:
            return torch.zeros_like(Q)

        stacked = torch.cat(outputs, dim=0)
        # ``order[k]`` is the original node index of ``stacked[k]``; its argsort
        # is the inverse permutation that restores the input ordering.
        order = torch.cat(node_positions, dim=0)
        return stacked[torch.argsort(order)]

class DIFFormerLayer(nn.Module):
    """One DIFFormer block: global attention, optional GCN branch, then an FFN.

    The attention and GCN branches are each wrapped as
    ``LayerNorm(x + dropout(branch(x)))``; the feed-forward branch is added
    residually without a trailing normalisation.

    Args:
        hidden_dim: Feature width kept constant through the block.
        kernel: Attention kernel name forwarded to ``GlobalAttention``.
        dropout: Dropout probability shared by every dropout in the block.
        use_graph: When true, a ``GCNConv`` branch is created and applied
            whenever ``edge_index`` is supplied to ``forward``.
    """

    def __init__(self, hidden_dim, kernel='simple', dropout=0.1, use_graph=True):
        super().__init__()
        self.hidden_dim, self.use_graph = hidden_dim, use_graph

        # Attention branch.
        self.global_attention = GlobalAttention(hidden_dim, kernel, dropout)

        # Message-passing branch, only built when graph structure is used.
        if use_graph:
            self.gcn = GCNConv(hidden_dim, hidden_dim)

        # One normalisation per normalised branch (attention, GCN).
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)

        # Position-wise feed-forward network with a 4x expansion.
        expanded_dim = hidden_dim * 4
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, expanded_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(expanded_dim, hidden_dim),
            nn.Dropout(dropout),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index=None, batch=None):
        """Apply the block to a set of node features.

        Args:
            x: Node features [N, hidden_dim]
            edge_index: Edge connectivity [2, E] (optional)
            batch: Batch vector [N] (optional)

        Returns:
            Updated node features [N, hidden_dim].
        """
        # Attention branch: residual + post-norm.
        h = self.ln1(x + self.dropout(self.global_attention(x, batch)))

        # GCN branch runs only if it was built and edges were provided.
        if self.use_graph and edge_index is not None:
            h = self.ln2(h + self.dropout(self.gcn(h, edge_index)))

        # Feed-forward branch: plain residual.
        return h + self.dropout(self.ffn(h))

class DIFFormer_v2(nn.Module):
    """
    DIFFormer v2 implementation supporting batch of graphs
    
    Args:
        in_channels: Input feature dimension
        hidden_channels: Hidden feature dimension
        out_channels: Output feature dimension
        num_layers: Number of DIFFormer layers
        kernel: Attention kernel type ('simple' or 'sigmoid')
        dropout: Dropout rate
        use_graph: Whether to use graph structure
        pooling: Graph-level pooling method ('mean', 'max', 'sum')
    """
    
    def __init__(self, in_channels, hidden_channels, out_channels, 
                 num_layers=4, kernel='simple', dropout=0.1, 
                 use_graph=True, pooling='mean'):
        super(DIFFormer_v2, self).__init__()
        
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.use_graph = use_graph
        self.pooling = pooling
        
        # Input projection
        self.input_proj = nn.Linear(in_channels, hidden_channels)
        
        # DIFFormer layers
        self.layers = nn.ModuleList([
            DIFFormerLayer(hidden_channels, kernel, dropout, use_graph)
            for _ in range(num_layers)
        ])
        
        # Output projection
        self.output_proj = nn.Linear(hidden_channels, out_channels)
        
        # Layer normalization
        self.ln_final = nn.LayerNorm(hidden_channels)
        
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.input_proj.weight)
        nn.init.xavier_uniform_(self.output_proj.weight)
        nn.init.zeros_(self.input_proj.bias)
        nn.init.zeros_(self.output_proj.bias)
    
    def forward(self, x, edge_index=None, batch=None, n_nodes=None):
        """
        Args:
            x: Node features [N, in_channels]
            edge_index: Edge connectivity [2, E] (optional)
            batch: Batch vector [N] indicating graph membership (optional)
            n_nodes: Number of nodes per graph [batch_size] (optional)
        
        Returns:
            out: Node embeddings [N, out_channels] or graph embeddings [batch_size, out_channels]
        """
        # Input projection
        x = self.input_proj(x)
        
        # Apply DIFFormer layers
        for layer in self.layers:
            x = layer(x, edge_index, batch)
        
        # Final layer normalization
        x = self.ln_final(x)
        
        # Output projection
        out = self.output_proj(x)
        
        # Graph-level pooling if batch is provided
        if batch is not None and self.pooling is not None:
            if self.pooling == 'mean':
                out = global_mean_pool(out, batch)
            elif self.pooling == 'max':
                out = global_max_pool(out, batch)
            elif self.pooling == 'sum':
                out = scatter_add(out, batch, dim=0)
        
        return out
    
    def get_attention_weights(self, x, edge_index=None, batch=None, layer_idx=0):
        """Get attention weights from a specific layer"""
        # Forward pass up to the specified layer
        x = self.input_proj(x)
        
        for i, layer in enumerate(self.layers):
            if i == layer_idx:
                # Get attention weights from this layer
                Q = layer.global_attention.W_q(x)
                K = layer.global_attention.W_k(x)
                
                if layer.global_attention.kernel == 'simple':
                    attention_scores = torch.mm(Q, K.transpose(0, 1))
                    attention_weights = F.softmax(attention_scores / math.sqrt(self.hidden_channels), dim=-1)
                else:
                    attention_scores = torch.mm(Q, K.transpose(0, 1)) * layer.global_attention.sigma
                    attention_scores = torch.sigmoid(attention_scores)
                    attention_weights = F.softmax(attention_scores, dim=-1)
                
                return attention_weights
            else:
                x = layer(x, edge_index, batch)
        
        return None

# Example usage and utility functions
def create_difformer_model(task_type='node_classification', **kwargs):
    """
    Build a DIFFormer_v2 whose pooling mode matches the requested task.

    Node classification keeps node-level outputs (no pooling); graph
    classification and regression both use mean pooling.

    Args:
        task_type: 'node_classification', 'graph_classification', 'regression'
        **kwargs: Additional arguments for DIFFormer_v2

    Raises:
        ValueError: If ``task_type`` is not one of the names above.
    """
    # (task name, pooling mode) pairs, checked in order with ``==``.
    pooling_by_task = (
        ('node_classification', None),
        ('graph_classification', 'mean'),
        ('regression', 'mean'),
    )
    for known_task, pooling_mode in pooling_by_task:
        if task_type == known_task:
            return DIFFormer_v2(pooling=pooling_mode, **kwargs)

    raise ValueError(f"Unknown task type: {task_type}")

# # Example instantiation
# if __name__ == "__main__":
#     # Example: Node classification
#     model = DIFFormer_v2(
#         in_channels=128,
#         hidden_channels=256,
    #     out_channels=64,
    #     num_layers=4,
    #     kernel='simple',
    #     dropout=0.1,
    #     use_graph=True,
    #     pooling=None  # None for node-level tasks
    # )
    
    # print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")
    
    # # Example input
    # x = torch.randn(100, 128)  # 100 nodes, 128 features
    # edge_index = torch.randint(0, 100, (2, 200))  # 200 edges
    # batch = torch.zeros(100, dtype=torch.long)  # Single graph
    
    # # Forward pass
    # output = model(x, edge_index, batch)
    # print(f"Output shape: {output.shape}")
    
    # # Get attention weights
    # attention_weights = model.get_attention_weights(x, edge_index, batch, layer_idx=0)
    # print(f"Attention weights shape: {attention_weights.shape}")

In [4]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [5]:

# ----------------- Paths -----------------
# Built with os.path.join so that no backslash escape sequences are involved
# (the previous "training_data\D..." literal relied on the invalid escape "\D").
DATA_CHECKPOINT_PATH = os.path.join("training_data", "Datacheckpoint_latest_22")
LABEL_ENCODING_PATH = "label_encoding.json"
# NOTE(review): absolute, machine-specific location; consider making this
# relative to the notebook or configurable.
PROJECT_DIR = "C:\\Users\\User\\OneDrive\\Desktop\\GAT-model testing\\GAT-test"

# ----------------- Load data -----------------
# NOTE(review): weights_only=False performs a full unpickle and can execute
# arbitrary code, so only load checkpoints from a trusted source. The
# safe_globals allow-list presumably only matters for weights_only=True loads —
# confirm before relying on it.
with torch.serialization.safe_globals([Data]):
    # Map onto the selected device so the cell also works on CPU-only machines.
    data_list = torch.load(DATA_CHECKPOINT_PATH, map_location=device, weights_only=False)

if len(data_list) == 0:
    raise ValueError(f"No graphs found in checkpoint: {DATA_CHECKPOINT_PATH}")

# Use a context manager so the file handle is closed deterministically.
with open(LABEL_ENCODING_PATH) as label_file:
    labels = json.load(label_file)
batch_size = 1

# Hold out the last 20% of the graphs (in stored order) for validation.
train_split = int(len(data_list) * 0.8)
train_data = data_list[:train_split]
val_data = data_list[train_split:]

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size)

# Model dimensions are derived from the data and the label encoding.
in_channels = data_list[0].x.size(1)
# in_channels =18
num_classes = len(labels)

# ----------------- Output directories -----------------
model_dir = os.path.join(PROJECT_DIR, "models", "model_diffusion")
results_dir = os.path.join(PROJECT_DIR, "results")
plots_dir = os.path.join(PROJECT_DIR, "plots")
for output_dir in (model_dir, results_dir, plots_dir):
    os.makedirs(output_dir, exist_ok=True)

In [11]:
from sklearn.model_selection import KFold
import mlflow
import mlflow.pytorch




def _inverse_frequency_weights(graphs, n_classes):
    """Return normalised inverse-frequency class weights over the labels of ``graphs``."""
    all_labels = torch.cat([graph.y for graph in graphs])
    counts = torch.bincount(all_labels, minlength=n_classes).float()
    # The epsilon keeps the division finite for classes absent from ``graphs``.
    weights = 1.0 / (counts + 1e-6)
    return weights / weights.sum()


def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader``; trains when ``optimizer`` is given, else evaluates.

    Returns:
        (mean loss per batch, fraction of correctly classified labels).
    """
    is_training = optimizer is not None
    model.train(is_training)
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.set_grad_enabled(is_training):
        for data in loader:
            data = data.to(device)
            if is_training:
                optimizer.zero_grad()
            # Pass the batch vector so attention stays within each graph even
            # when batch_size > 1 (pooling is disabled, so outputs stay per node).
            out = model(data.x, data.edge_index, data.batch)
            loss = criterion(out, data.y)
            if is_training:
                loss.backward()
                optimizer.step()
            loss_sum += loss.item()
            n_correct += (out.argmax(dim=1) == data.y).sum().item()
            n_seen += data.y.size(0)

    # max(..., 1) guards against an empty loader.
    return loss_sum / max(len(loader), 1), n_correct / max(n_seen, 1)


def objective(trial):
    """Optuna objective: mean 5-fold validation accuracy of a DIFFormer_v2.

    For each fold a fresh model is built from the trial's hyperparameters and
    trained with a class-weighted cross-entropy loss, Adam, a
    ReduceLROnPlateau schedule and early stopping on the validation loss.
    Per-epoch metrics are logged to MLflow and reported to Optuna for pruning.

    Args:
        trial: The ``optuna.trial.Trial`` supplying hyperparameter suggestions.

    Returns:
        Mean over folds of the validation accuracy recorded at each fold's
        best (lowest) validation loss.

    Raises:
        optuna.exceptions.TrialPruned: When the pruner stops the trial early.
    """
    max_epochs = 501
    patience = 30       # epochs without val-loss improvement before stopping
    min_delta = 1e-4    # minimum val-loss decrease that counts as improvement

    # ----- Suggest Hyperparameters -----
    # Keys match the Optuna parameter names so MLflow and Optuna agree.
    config = {
        'hidden_channels': trial.suggest_categorical("hidden_channels", [128, 256, 512]),
        'num_layers': trial.suggest_int("num_layers", 2, 6),
        'dropout': trial.suggest_float("dropout", 0.1, 0.5),
        'lr': trial.suggest_float("lr", 1e-4, 1e-2, log=True),
        'use_graph': trial.suggest_categorical("use_graph", [True, False]),
    }

    with mlflow.start_run(run_name=f"trial_{trial.number}", nested=True):
        mlflow.log_params(config)
        mlflow.set_tag("cv_strategy", "KFold")
        mlflow.set_tag("model", "DIFFormer_v2")

        kf = KFold(n_splits=5, shuffle=True, random_state=42)
        data_indices = list(range(len(data_list)))
        fold_val_acc = []

        for fold, (train_idx, val_idx) in enumerate(kf.split(data_indices)):
            train_data = [data_list[i] for i in train_idx]
            val_data = [data_list[i] for i in val_idx]

            train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_data, batch_size=batch_size)

            # Build the model from the suggested values; hard-coding them here
            # would leave every hyperparameter except ``lr`` untuned.
            model = DIFFormer_v2(
                in_channels=in_channels,
                hidden_channels=config['hidden_channels'],
                out_channels=num_classes,
                num_layers=config['num_layers'],
                kernel='simple',
                dropout=config['dropout'],
                use_graph=config['use_graph'],
                pooling=None  # None for node-level tasks
            ).to(device)

            class_weights = _inverse_frequency_weights(train_data, num_classes).to(device)
            criterion = CrossEntropyLoss(weight=class_weights)
            optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], weight_decay=5e-4)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

            best_val_acc = 0
            best_val_loss = float('inf')
            wait = 0

            for epoch in range(max_epochs):
                avg_train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
                avg_val_loss, val_acc = _run_epoch(model, val_loader, criterion)

                # Log metrics per epoch
                mlflow.log_metric(f"fold{fold+1}_train_loss", avg_train_loss, step=epoch)
                mlflow.log_metric(f"fold{fold+1}_val_loss", avg_val_loss, step=epoch)
                mlflow.log_metric(f"fold{fold+1}_val_acc", val_acc, step=epoch)
                mlflow.log_metric(f"fold{fold+1}_train_acc", train_acc, step=epoch)

                scheduler.step(avg_val_loss)
                # Optuna keeps only the first value reported for a given step,
                # so offset the step by fold to keep later folds visible to the pruner.
                trial.report(val_acc, fold * max_epochs + epoch)

                print(f"Epoch {epoch+1:03d} | Fold {fold+1} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Train loss: {avg_train_loss:.4f} Val Loss: {avg_val_loss:.4f}")

                if avg_val_loss < best_val_loss - min_delta:
                    best_val_loss = avg_val_loss
                    best_val_acc = val_acc
                    wait = 0
                else:
                    wait += 1
                    if wait >= patience:
                        print(f" Early stopping at epoch {epoch+1} for fold {fold+1}")
                        break

                if trial.should_prune():
                    raise optuna.exceptions.TrialPruned()

            print(f" Fold {fold+1}: Best Val Acc = {best_val_acc:.4f}")
            fold_val_acc.append(float(best_val_acc))

        mean_val_acc = float(np.mean(fold_val_acc))
        mlflow.log_metric("mean_val_acc", mean_val_acc)

        return mean_val_acc


In [10]:
import mlflow

# Point MLflow at the tracking server first: set_experiment() resolves (and, if
# needed, creates) the experiment against whatever tracking URI is active at
# the time it is called, so the URI must be configured before it.
mlflow.set_tracking_uri("http://127.0.0.1:5000")

# Set MLflow experiment name — all trials/logs will be grouped under this
mlflow.set_experiment("DIFFormer_v2_Optuna_Experiment_1")

# Create and run the Optuna study (maximising mean validation accuracy)
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=70)

print("Best trial:")
print("Accuracy:", study.best_trial.value)
print("Params:", study.best_trial.params)


[I 2025-07-16 12:02:16,641] A new study created in memory with name: no-name-7bfb48dd-ac15-4fc4-9e5a-ba730521da7a
[W 2025-07-16 12:02:21,359] Trial 0 failed with parameters: {'hidden_channels': 512, 'num_layers': 3, 'dropout': 0.3488230276100335, 'lr': 0.0001518084130835547, 'use_graph': False} because of the following error: TypeError("DIFFormer_v2.forward() got an unexpected keyword argument 'edge_weight'").
Traceback (most recent call last):
  File "c:\Users\User\anaconda3\envs\ENVGAT\lib\site-packages\optuna\study\_optimize.py", line 201, in _run_trial
    value_or_values = func(trial)
  File "C:\Users\User\AppData\Local\Temp\ipykernel_35248\452603071.py", line 89, in objective
    out = model(data.x, data.edge_index, edge_weight=data.edge_attr)
  File "c:\Users\User\anaconda3\envs\ENVGAT\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "c:\Users\User\anaconda3\envs\ENVGAT\lib\site-packages\torch\nn\m

🏃 View run trial_0 at: http://127.0.0.1:5000/#/experiments/128168505728161317/runs/1b85a257d37146f39c819d0cbda5a0f9
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/128168505728161317


TypeError: DIFFormer_v2.forward() got an unexpected keyword argument 'edge_weight'

### Final model training

In [None]:
import os
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
import mlflow
import mlflow.pytorch
from torch.nn import CrossEntropyLoss
from torch_geometric.loader import DataLoader

# ----------------- SET MLFLOW CONFIG ----------------------
mlflow.set_tracking_uri("http://127.0.0.1:5000")
mlflow.set_experiment("GATv2_Final_Training")

# ----------------- PERSIST BEST HYPERPARAMETERS ----------------------
# Take the winning trial's parameters from the in-memory study and save a copy
# next to the model so the run can be reproduced later.
best_params_path = os.path.join(model_dir, "best_params.json")
best_params = study.best_trial.params
with open(best_params_path, "w") as f:
    json.dump(best_params, f, indent=4)

# ----------------- SETUP FINAL TRAINING ----------------------
# Only the DIFFormer model is built. The former GAT construction was dead code
# (overwritten immediately) and indexed best_params['heads'], a key the study
# never defines, which raised KeyError.
final_model = DIFFormer_v2(
    in_channels=in_channels,
    hidden_channels=best_params['hidden_channels'],
    out_channels=num_classes,
    num_layers=best_params['num_layers'],
    kernel='simple',
    dropout=best_params['dropout'],
    use_graph=best_params['use_graph'],
    pooling=None  # None for node-level tasks
).to(device)

# Class-weighted loss: inverse label frequency, normalised to sum to 1.
all_labels = torch.cat([data.y for data in train_loader.dataset])
class_counts = torch.bincount(all_labels, minlength=num_classes)
class_weights = 1.0 / (class_counts.float() + 1e-6)
class_weights = class_weights / class_weights.sum()
class_weights = class_weights.to(device)

criterion = CrossEntropyLoss(weight=class_weights)
# Train with the tuned learning rate rather than a hard-coded one.
optimizer = torch.optim.Adam(final_model.parameters(), lr=best_params['lr'], weight_decay=5e-4)

train_losses = []
val_losses = []
best_val_loss = float('inf')
best_model_path = os.path.join(model_dir, "GAT_full_model_best_001.pt")

# ----------------- TRAINING LOOP ----------------------
for epoch in range(500):
    final_model.train()
    train_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        # DIFFormer_v2.forward takes (x, edge_index, batch); it has no
        # edge_weight argument, so passing one raised TypeError.
        out = final_model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    avg_train_loss = train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validation
    final_model.eval()
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            out = final_model(data.x, data.edge_index, data.batch)
            loss = criterion(out, data.y)
            val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    # Save best weights (lowest validation loss seen so far)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(final_model.state_dict(), best_model_path)

    print(f"[FINAL TRAIN] Epoch {epoch+1:03d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

# Save full model (.pt) — this is the last-epoch model, pickled as a whole object.
full_model_path = os.path.join(model_dir, "GAT_final_model.pt")
torch.save(final_model, full_model_path)
print(f"Full model saved to {full_model_path}")
print(f" Best model (lowest val loss) saved to {best_model_path}")

# ----------------- PLOT LOSS CURVE ----------------------
plt.figure(figsize=(8,5))
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Final Model Training and Validation Loss')
plt.legend()
plt.grid(True)
plt.tight_layout()
plot_path = os.path.join(plots_dir, "final_train_val_loss.png")
plt.savefig(plot_path)
plt.show()

# ----------------- MLFLOW LOGGING ----------------------
# Load best weights before logging; map_location keeps this working if the
# checkpoint is reloaded on a machine with a different device.
final_model.load_state_dict(torch.load(best_model_path, map_location=device))

with mlflow.start_run(run_name="final_retrain"):
    mlflow.log_params(best_params)
    mlflow.log_artifact(best_params_path)
    mlflow.log_artifact(plot_path)

    for epoch, (train_l, val_l) in enumerate(zip(train_losses, val_losses)):
        mlflow.log_metric("train_loss", train_l, step=epoch)
        mlflow.log_metric("val_loss", val_l, step=epoch)

    # ✅ Log model
    mlflow.pytorch.log_model(final_model, artifact_path="final_model")
    print(" Final model logged to MLflow.")
