In [1]:
# cnn_gnn_graph_to_image_lpips.ipynb

import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import lpips  # Make sure to install lpips via: pip install lpips

# Enable anomaly detection for better debugging (this may slow down training)
torch.autograd.set_detect_anomaly(True)

# Import our unified CNN-GNN model and necessary utilities
from models.cnn_gnn_model import CNN_GNN_Model
from core.graph import Graph, GraphBatch
from core.dataloader import GraphDataset, GraphDataLoader
from core.utils import get_device

#############################################
#   LPIPS Loss Setup
#############################################
print("Setting up [LPIPS] perceptual loss...")
lpips_criterion = lpips.LPIPS(net='alex')
device = get_device()  # obtain device early
lpips_criterion = lpips_criterion.to(device)  # move LPIPS model to device

def lpips_loss_fn(pred, target):
    # Normalize images to [-1, 1] using tanh (assuming inputs are roughly centered)
    pred_norm = torch.tanh(pred)
    target_norm = torch.tanh(target)
    return lpips_criterion(pred_norm, target_norm).mean()

#############################################
#   Custom Collate Function for (Graph, label) pairs
#############################################
def graph_label_collate_fn(batch):
    """Collate a list of (Graph, label) tuples into a batched graph and label tensor."""
    graphs, labels = zip(*batch)
    return GraphBatch(graphs), torch.stack(labels)

#############################################
#   Synthetic Dataset Generation
#############################################
# Parameters
num_graphs = 50                # total number of graphs
num_nodes = 16                 # each graph has 16 nodes (patches)
num_edges = 40                 # arbitrary number of edges per graph
raw_node_shape = (3, 32, 32)     # each node is a 32x32 RGB patch (raw image)
cnn_out_dim = 64              # embedding dimension output by the CNN encoder
output_image_shape = (3, 128, 128)  # desired output image: 128x128 RGB
num_classes = output_image_shape[0] * output_image_shape[1] * output_image_shape[2]
num_epochs = 1000
lr = 0.0001  # Lower learning rate to help stabilize training

# Additional parameters for dropout, batch norm, and residual connections:
cnn_dropout = 0.3    # Dropout probability for CNN layers
gnn_dropout = 0.2    # Dropout probability for GNN layers
use_batchnorm = True
use_residual = True

graphs = []
labels_list = []
for i in range(num_graphs):
    node_features = torch.randn(num_nodes, *raw_node_shape)
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    graph = Graph(node_features, edge_index)
    graphs.append(graph)
    target_image = torch.randn(*output_image_shape)
    labels_list.append(target_image)

# Move graphs and labels to device
for graph in graphs:
    graph.to(device)
labels_list = [lbl.to(device) for lbl in labels_list]

# Create dataset and dataloader using our custom collate function
dataset = GraphDataset(list(zip(graphs, labels_list)))
dataloader = GraphDataLoader(dataset, batch_size=5, shuffle=True, collate_fn=graph_label_collate_fn)

#############################################
#   Model, Loss, and Optimizer Setup
#############################################
# Define CNN encoder parameters (with dropout and batch norm)
cnn_params = {
    'in_channels': raw_node_shape[0],
    'out_features': cnn_out_dim,
    'num_layers': 4,
    'hidden_channels': 32,
    'input_spatial_size': raw_node_shape[1],
    'dropout_prob': cnn_dropout,      # e.g., 0.3
    'use_batchnorm': use_batchnorm
}

# Instantiate the unified CNN-GNN model with dropout & residuals for the GNN part.
# (Assuming your CNN_GNN_Model is updated to accept 'gnn_dropout' and 'residual' parameters)
model = CNN_GNN_Model(
    cnn_params=cnn_params,
    gnn_in_dim=cnn_out_dim,
    gnn_hidden_dim=(64,),
    num_classes=num_classes,
    num_gnn_layers=3,
    gnn_dropout=gnn_dropout,  # e.g., 0.2
    residual=use_residual     # e.g., True
)

model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

#############################################
#   Training Loop with Enhanced Debugging
#############################################
loss_values = []

print("\nTraining CNN-GNN model (4-depth CNN encoder with dropout, 3 GNN layers with dropout/residual) using LPIPS loss...")
for epoch in range(1, num_epochs + 1):
    model.train()
    epoch_loss = 0.0
    for batch, batch_labels in dataloader:
        optimizer.zero_grad()
        logits = model(batch.node_features, batch.edge_index, batch.edge_features, batch.batch)

        # Debug: check model output statistics before reshaping
        with torch.no_grad():
            out_mean = logits.mean().item()
            out_std = logits.std().item()
            out_min = logits.min().item()
            out_max = logits.max().item()
            # Uncomment below to print these stats for every batch
            # print(f"Pre-reshape: mean={out_mean:.4f}, std={out_std:.4f}, min={out_min:.4f}, max={out_max:.4f}")

        logits = logits.view(-1, *output_image_shape)

        # Debug: check image logits stats
        with torch.no_grad():
            img_mean = logits.mean().item()
            img_std = logits.std().item()
            img_min = logits.min().item()
            img_max = logits.max().item()
            # Uncomment below to print these stats for every batch
            # print(f"Post-reshape: mean={img_mean:.4f}, std={img_std:.4f}, min={img_min:.4f}, max={img_max:.4f}")

        loss = lpips_loss_fn(logits, batch_labels)

        if torch.isnan(loss):
            print("Detected NaN loss!")
            print(f"Logits stats: mean={img_mean:.4f}, std={img_std:.4f}, min={img_min}, max={img_max}")
            break

        loss.backward()
        # Apply gradient clipping to avoid exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.9)
        optimizer.step()
        epoch_loss += loss.item() * batch_labels.size(0)
    epoch_loss /= num_graphs
    loss_values.append(epoch_loss)
    if epoch % 50 == 0:
        print(f"Epoch {epoch}: Loss = {epoch_loss:.4f}")
    if math.isnan(epoch_loss):
        print("NaN detected in epoch loss. Stopping training for debugging.")
        break

#############################################
#   Plotting Learning Curve (Loss)
#############################################
plt.figure(figsize=(8, 6))
plt.plot(range(1, len(loss_values) + 1), loss_values, linewidth=2)
plt.title("LPIPS Loss over Epochs")
plt.xlabel("Epoch")
plt.ylabel("LPIPS Loss")
plt.grid(True)
plt.show()

print("Training complete!")


Setting up [LPIPS] perceptual loss...
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: C:\Users\arash\anaconda3\envs\tgraphx\lib\site-packages\lpips\weights\v0.1\alex.pth


TypeError: __init__() got an unexpected keyword argument 'input_spatial_size'