In [2]:
import torch
import torch.nn as nn
from src.networks import MAGANet  # Replace with your actual model import
import matplotlib.pyplot as plt

In [3]:
# Choose where tensors and the model live: the CUDA GPU when one is available, else the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Load the model
# FIXME: the last run of this cell failed with
#   TypeError: MAGANet.__init__() got an unexpected keyword argument 'in_channels'
# so MAGANet's constructor does not accept `in_channels`. Check its signature in
# src.networks and pass the arguments it actually takes (verify `latent_dim` too).
model = MAGANet(in_channels=1, latent_dim=10).to(device)  # Adjust parameters as per your model
# Restore trained weights; map_location remaps the saved tensors onto `device`
model.load_state_dict(torch.load("../../outputs/magan_model.pth", map_location=device))
# Switch to inference mode (changes the behaviour of layers such as dropout / batch norm, if present)
model.eval()

TypeError: MAGANet.__init__() got an unexpected keyword argument 'in_channels'

In [None]:
# Captured layer outputs, keyed by the name given to get_activation
intermediate_outputs = {}


def get_activation(name):
    """Build a forward hook that saves the hooked module's output.

    Args:
        name: Key under which the output is stored in ``intermediate_outputs``.

    Returns:
        A callable with the ``(module, input, output)`` forward-hook signature.
        Each call overwrites ``intermediate_outputs[name]`` with the latest output.
    """
    def _record(_module, _inputs, result):
        intermediate_outputs[name] = result

    return _record

# Attach hooks to the desired layers.
# register_forward_hook returns a removable handle. Keep the handles so that
# re-running this cell first removes the hooks it registered last time,
# instead of stacking duplicate hooks on the same modules.
for _old_handle in globals().get('hook_handles', []):
    _old_handle.remove()
hook_handles = []


def _attach_hook(module, name):
    """Register a `get_activation(name)` forward hook on `module` and remember its handle."""
    hook_handles.append(module.register_forward_hook(get_activation(name)))


# Affine layer
_attach_hook(model.decoder.affine, 'affine')

# Flow modules and their subcomponents
for i, flow_module in enumerate(model.decoder.flow_modules):
    _attach_hook(flow_module, f'flow_module_{i}')
    for j, flow_step in enumerate(flow_module.flow_steps):
        step_key = f'flow_module_{i}_step_{j}'
        _attach_hook(flow_step, step_key)
        _attach_hook(flow_step.act_norm, f'{step_key}_actnorm')
        _attach_hook(flow_step.inv_conv, f'{step_key}_invconv')
        _attach_hook(flow_step.coupling, f'{step_key}_coupling')

# Encoder convolutional layers
for i, layer in enumerate(model.encoder.conv):
    _attach_hook(layer, f'encoder_conv_{i}')

In [None]:
def capture_latents(model, x1, x2):
    """Encode (x1, x2) and return the encoder's z plus one sample from each posterior.

    Calls ``model.encoder(x1, x2)``, whose result is unpacked as
    ``(z, mu1, logvar1, mu2, logvar2)``. It then draws ``z1`` and ``z2`` with
    ``model.encoder.sample_z``. Gradient tracking is disabled throughout.

    NOTE(review): z1/z2 come from separate sample_z calls made after the encoder
    has already produced z. If sample_z is stochastic, they need not be the
    samples behind z (z may not equal z2 - z1); confirm against the encoder.

    Returns:
        Tuple ``(z, z1, z2)``.
    """
    encoder = model.encoder
    with torch.no_grad():
        z, mu1, logvar1, mu2, logvar2 = encoder(x1, x2)
        latent_samples = tuple(
            encoder.sample_z(mu, logvar) for mu, logvar in ((mu1, logvar1), (mu2, logvar2))
        )
    return (z, *latent_samples)

In [None]:
def visualize_image_output(output, name, max_channels=5):
    """Plot the first channels of the first sample of a captured activation.

    Args:
        output: The value captured for a layer.
            - A 4-D tensor is treated as [B, C, H, W]; up to ``max_channels``
              channels of its first sample are shown as grayscale images.
            - A tuple/list (a module returning several values) is visualized
              element by element under ``name[idx]``.
            - ``None`` (nothing captured), non-tensors, other ranks, and empty
              batch/channel dimensions are reported with a message instead of
              raising.
        name: Label used in plot titles and messages.
        max_channels: Maximum number of channels to plot (default 5).
    """
    if output is None:
        print(f"Cannot visualize {name}: no output was captured")
        return
    if isinstance(output, (tuple, list)):
        # Modules may return several values; show each under an indexed name.
        for idx, item in enumerate(output):
            visualize_image_output(item, f"{name}[{idx}]", max_channels)
        return
    if not isinstance(output, torch.Tensor):
        print(f"Cannot visualize {name}: unsupported type {type(output).__name__}")
        return
    if output.dim() != 4 or output.size(0) == 0 or output.size(1) == 0:
        print(f"Cannot visualize {name}: unexpected shape {output.shape}")
        return

    num_channels = min(max_channels, output.size(1))
    # squeeze=False keeps `axes` 2-D even when only one channel is plotted
    fig, axes = plt.subplots(1, num_channels, figsize=(15, 3), squeeze=False)
    sample = output[0].detach().cpu()
    for i, ax in enumerate(axes[0]):
        ax.imshow(sample[i].numpy(), cmap='gray')
        ax.set_title(f'{name} Ch{i}')
        ax.axis('off')
    fig.tight_layout()
    plt.show()

In [None]:
def visualize_latent(z, name, bins=30):
    """Plot a latent tensor: a bar chart for a single vector, a histogram otherwise.

    Args:
        z: Latent tensor.
            - A 1-D tensor is drawn as one bar per dimension.
            - A tensor of rank 2 or higher (e.g. a batch of vectors) is
              flattened and drawn as a histogram of its values.
            - A 0-D tensor is reported with a message instead of an empty plot.
        name: Plot title.
        bins: Number of histogram bins used for rank >= 2 inputs (default 30).
    """
    values = z.detach().cpu().numpy()
    if values.ndim == 0:
        print(f"Cannot visualize {name}: scalar latent {values.item()}")
        return

    fig, ax = plt.subplots(figsize=(8, 4))
    if values.ndim == 1:
        ax.bar(range(len(values)), values)
        ax.set(title=name, xlabel='Dimension', ylabel='Value')
    else:
        # Rank >= 2 previously produced a blank figure; flatten so every value is shown.
        ax.hist(values.flatten(), bins=bins)
        ax.set(title=name, xlabel='Value', ylabel='Frequency')
    plt.show()

In [None]:
def visualize_model_outputs(model, x1, x2):
    """Run the model on (x1, x2) and plot the activations captured by the forward hooks.

    Relies on the notebook-level ``intermediate_outputs`` dict, which the
    previously registered forward hooks fill. Also uses ``capture_latents``,
    ``visualize_image_output`` and ``visualize_latent``.

    Args:
        model: The model; called as ``model(x1, x2)``.
        x1: First input tensor, passed to the model and its encoder.
        x2: Second input tensor, passed to the model and its encoder.
    """
    # Clear previous outputs
    intermediate_outputs.clear()

    # Run the forward pass
    with torch.no_grad():
        generated_x2 = model(x1, x2)  # Standard forward pass
        # NOTE(review): this runs the encoder a second time. If its sampling is
        # stochastic, these latents may differ from the ones behind generated_x2.
        z, z1, z2 = capture_latents(model, x1, x2)  # Capture latent vectors

    # Snapshot the captures in insertion order. A key is inserted the first time
    # its hook fires, so this follows execution order and keeps numeric indices
    # in order. A lexicographic sort would put '..._10' before '..._2' and
    # 'coupling' before 'invconv'.
    captured = list(intermediate_outputs.items())

    # Visualize encoder convolutional outputs
    print("Encoder Convolutional Outputs:")
    for key, value in captured:
        if key.startswith('encoder_conv_'):
            visualize_image_output(value, key)

    # Visualize latent vectors
    print("Latent Vectors:")
    visualize_latent(z1, 'z1')
    visualize_latent(z2, 'z2')
    visualize_latent(z, 'z (z2 - z1)')

    # Visualize decoder outputs
    print("Decoder Outputs:")
    # Affine layer (guarded: its hook may not have fired)
    if 'affine' in intermediate_outputs:
        visualize_image_output(intermediate_outputs['affine'], 'affine')
    else:
        print("Cannot visualize affine: no output was captured")

    # Flow modules, flow steps and their ActNorm / Invertible1x1Conv / coupling outputs
    for key, value in captured:
        if key.startswith('flow_module_'):
            visualize_image_output(value, key)

    # Optionally visualize the final generated image
    print("Generated Output:")
    visualize_image_output(generated_x2, 'generated_x2')

In [None]:
# Example with dummy data
SEED = 0
BATCH_SIZE, CHANNELS, HEIGHT, WIDTH = 1, 1, 64, 64  # [B, C, H, W]

# Seed torch's RNGs (CPU and CUDA) so the dummy inputs are identical on every
# run. The inputs are created on the CPU and then moved, so the values do not
# depend on `device`.
torch.manual_seed(SEED)
x1 = torch.randn(BATCH_SIZE, CHANNELS, HEIGHT, WIDTH).to(device)
x2 = torch.randn(BATCH_SIZE, CHANNELS, HEIGHT, WIDTH).to(device)
visualize_model_outputs(model, x1, x2)