# In [None]:
# 2. Model Definition
class GatedFusion(nn.Module):
    """Gate flattened CNN features with a spatial map derived from ``mfrac``.

    A small MLP maps each node's scalar ``mfrac`` to one weight per CNN
    feature. Those weights are averaged over channels at each spatial
    position, so all channels at a position share one gate value. The result
    is scaled by 2 and multiplied into the CNN features. The gated features
    are then concatenated with ``mfrac`` multiplied by a learnable scalar
    (initialised to 3.0).

    Args:
        cnn_dim: Length of the flattened CNN feature vector per node. Must be
            a positive multiple of ``spatial_groups``.
        spatial_groups: Number of spatial positions (H * W) in the CNN's final
            feature map. The default of 25 corresponds to a 5x5 map.

    Raises:
        ValueError: If ``cnn_dim`` is not a positive multiple of a positive
            ``spatial_groups``.
    """

    def __init__(self, cnn_dim, spatial_groups=25):
        super().__init__()
        if spatial_groups <= 0 or cnn_dim <= 0 or cnn_dim % spatial_groups:
            raise ValueError(
                f"cnn_dim ({cnn_dim}) must be a positive multiple of "
                f"spatial_groups ({spatial_groups})"
            )
        self.gate = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            # One gate value per CNN feature. This used to be hard-coded to
            # 800, which only matched a 32-channel 5x5 feature map.
            nn.Linear(128, cnn_dim),
            # Tanh (not Sigmoid) lets gates be negative: range [-1, 1].
            nn.Tanh(),
        )
        self.mfrac_boost = nn.Parameter(torch.tensor(3.0))
        self.spatial_groups = spatial_groups
        self.channels_per_group = cnn_dim // spatial_groups

    def forward(self, cnn_feat, mfrac):
        """Apply the spatial gate and append the boosted ``mfrac``.

        Args:
            cnn_feat: Flattened CNN features, shape (N, cnn_dim). Assumed to be
                flattened channel-major from (N, C, H, W), so that feature
                index = c * spatial_groups + s. TODO: confirm for any caller
                other than Optimized_CNN_GCN.
            mfrac: Per-node scalar, shape (N, 1).

        Returns:
            Tuple ``(fused, gate_weights)``:
            - ``fused`` has shape (N, cnn_dim + 1).
            - ``gate_weights`` has shape (N, cnn_dim), values in [-2, 2].
        """
        gate_weights = self.gate(mfrac)
        # (N, channels, spatial) view of the per-feature gate values.
        grouped = gate_weights.view(-1, self.channels_per_group, self.spatial_groups)
        # Average over channels to get one weight per spatial position, then
        # broadcast that weight back to every channel at the position.
        spatial = grouped.mean(dim=1, keepdim=True)
        gate_weights = spatial.expand_as(grouped).reshape_as(gate_weights)
        # Widen the Tanh range from [-1, 1] to [-2, 2].
        gate_weights = gate_weights * 2.0
        boosted_mfrac = mfrac * self.mfrac_boost
        return torch.cat([cnn_feat * gate_weights, boosted_mfrac], dim=1), gate_weights

# Model instantiation example
# model = Optimized_CNN_GCN(
#     filters=32,
#     kernel_size=5,
#     dense_units=128,
#     dropout_rate=0.2,
#     gcn_hidden_dim=32,
#     learning_rate=0.0015
# ).to(device)
    
    
class Optimized_CNN_GCN(nn.Module):
    """CNN + gated fusion + two-layer GCN regressor with one output per graph.

    Each row of ``data.x`` is one node. The first ``H * W`` columns are a
    flattened ``input_dim`` image and the last column is a scalar ``mfrac``.
    The model runs the following stages:

    1. The image passes through a small CNN.
    2. ``GatedFusion`` gates the CNN output using ``mfrac``.
    3. Two GCN layers run over ``data.edge_index``.
    4. Node embeddings are mean-pooled per graph and mapped to one value.

    Args:
        input_dim: (H, W) of the per-node image.
        filters, kernel_size, dense_units: Not used by this class. The CNN
            hard-codes 16/32 channels and 3x3 kernels. These arguments are
            kept so existing call sites keep working.
        dropout_rate: Dropout probability applied between the GCN layers.
        gcn_hidden_dim: Width of both GCN layers.
        learning_rate: Only stored as ``self.learning_rate``. Presumably the
            training code reads it when building the optimizer; verify.

    NOTE(review): GatedFusion groups features into 25 spatial positions by
    default. That matches the 5x5 map this CNN produces for a 21x21 input;
    other ``input_dim`` values may not be supported by the fusion module.
    """

    def __init__(self, input_dim=(21, 21), filters=32, kernel_size=5,
                 dense_units=128, dropout_rate=0.2, gcn_hidden_dim=32,
                 learning_rate=0.0015):
        super(Optimized_CNN_GCN, self).__init__()
        # Image size used to un-flatten node features in forward().
        self.input_dim = tuple(input_dim)
        # CNN feature extractor: two (3x3 conv, pad 1) -> ReLU -> 2x2 max-pool stages.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # Infer the flattened CNN output size with a dummy pass. torch.randn is
        # kept (rather than zeros) so the RNG stream, and therefore the weight
        # initialisation of the layers created below, is unchanged.
        with torch.no_grad():
            dummy_input = torch.randn(1, 1, *self.input_dim)
            dummy_output = self.cnn(dummy_input)
            self.cnn_output_dim = dummy_output.view(-1).shape[0]

        # Gated fusion of CNN features with mfrac; adds one column (boosted mfrac).
        self.fusion = GatedFusion(self.cnn_output_dim)
        # GCN layers.
        self.conv1 = GCNConv(self.cnn_output_dim + 1, gcn_hidden_dim)
        self.conv2 = GCNConv(gcn_hidden_dim, gcn_hidden_dim)
        self.dropout = nn.Dropout(dropout_rate)
        # Output layer: one regression value per graph.
        self.linear = nn.Linear(gcn_hidden_dim, 1)
        self.learning_rate = learning_rate

        # Populated during forward() for visualization.
        self.gate_weights = None
        self.attn_weights = None

    @staticmethod
    def _mean_pool(x, batch):
        """Average node features per graph; treat all nodes as one graph if ``batch`` is None."""
        if batch is None or batch.numel() == 0:
            return x.mean(dim=0)
        num_graphs = int(batch.max()) + 1
        sums = x.new_zeros(num_graphs, x.size(1)).index_add(0, batch, x)
        counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
        return sums / counts.unsqueeze(1).to(x.dtype)

    def forward(self, data):
        """Predict one value per graph.

        Args:
            data: Graph object exposing ``x`` (num_nodes, H*W + 1),
                ``edge_index``, ``num_nodes`` and, for batched input, a
                node-to-graph index ``batch``. Assumed to be a PyG
                ``Data``/``Batch``; confirm.

        Returns:
            Tensor of shape (num_graphs,). It is squeezed to a 0-d scalar
            when there is a single graph, which matches the previous output
            for one graph. Previously every node in a multi-graph batch was
            averaged into a single scalar.
        """
        x_all = data.x
        num_nodes = data.num_nodes

        # Split input into image pixels and mfrac (last column).
        x_cnn = x_all[:, :-1].reshape(num_nodes, 1, *self.input_dim)
        x_mfrac = x_all[:, -1].unsqueeze(1)

        # CNN processing, flattened per node.
        x_cnn = self.cnn(x_cnn).reshape(num_nodes, -1)

        # Fusion processing; keep gate weights for visualization.
        x, gate_weights = self.fusion(x_cnn, x_mfrac)
        self.gate_weights = gate_weights

        edge_index = data.edge_index

        # Graph convolutional layers.
        x = F.relu(self.conv1(x, edge_index))
        x = self.dropout(x)
        x = F.relu(self.conv2(x, edge_index))

        # Per-graph mean pooling: (num_graphs, hidden), or (hidden,) when
        # `data` carries no batch vector.
        x = self._mean_pool(x, getattr(data, "batch", None))

        # Regression head, then drop size-1 dimensions.
        x = self.linear(x)
        return x.squeeze()
    
# Instant model test: one forward pass on the first graph to check output shape.
sample_data = graphs[0]  # Test with first sample
model = Optimized_CNN_GCN()
# Disable dropout and gradient tracking so the smoke test is deterministic and
# cheap, then restore training mode (the default after construction) so later
# cells that train `model` behave as before.
model.eval()
with torch.no_grad():
    output = model(sample_data)
model.train()
print("\nModel test output:")
print(output.shape)
print(output)

# Check model structure
print("\nModel structure:")
print(model)