# In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2
from PIL import Image
from torchvision import transforms

# Simple UNet-style encoder-decoder segmentation model for demonstration purposes (no skip connections)
class SimpleUNet(nn.Module):
    """Minimal encoder-decoder network for single-channel images.

    Two conv+ReLU stages, each followed by 2x max-pooling, then two
    nearest-neighbour upsampling + conv+ReLU stages. Unlike a real U-Net
    there are no skip connections. Because the last stage also applies ReLU,
    the output is non-negative.

    Input is (batch, 1, H, W); output is (batch, 1, H, W). The decoder
    upsamples back to the exact encoder sizes, so H and W do not need to be
    divisible by 4.
    """

    def __init__(self):
        super().__init__()
        # Encoder layers (3x3, padding=1 keeps H and W unchanged)
        self.enc_conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.enc_conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        # Decoder layers
        self.dec_conv1 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.dec_conv2 = nn.Conv2d(64, 1, kernel_size=3, padding=1)

    def forward(self, x):
        # Encoding path
        x1 = F.relu(self.enc_conv1(x))   # (B, 64, H, W)
        x2 = self.pool(x1)               # (B, 64, H//2, W//2)
        x3 = F.relu(self.enc_conv2(x2))  # (B, 128, H//2, W//2)
        x4 = self.pool(x3)               # (B, 128, H//4, W//4)

        # Decoding path. Upsample to the encoder's actual sizes instead of a
        # fixed factor of 2: pooling floors odd sizes, so doubling would not
        # restore them. For sizes divisible by 4 this is identical to
        # scale_factor=2 with nearest-neighbour interpolation.
        x5 = F.relu(self.dec_conv1(F.interpolate(x4, size=x3.shape[-2:], mode='nearest')))  # (B, 64, H//2, W//2)
        x6 = F.relu(self.dec_conv2(F.interpolate(x5, size=x.shape[-2:], mode='nearest')))   # (B, 1, H, W)

        return x6

# Hook function to capture and save feature maps
def save_feature_maps(module, input, output, layer_name, folder='layer_outputs', cmap='viridis'):
    """Save each channel of the first batch item of *output* as a PNG image.

    Meant to be called from a forward hook, hence the leading
    ``module, input, output`` parameters (``module`` and ``input`` are unused).

    Args:
        module: Module that produced *output* (unused).
        input: Inputs passed to *module* (unused).
        output: Layer output. ``output[0]`` is taken as the first batch item
            and iterated over its first dimension, one image per entry
            (assumes a (batch, channels, H, W) tensor - TODO confirm for
            non-conv layers).
        layer_name: Prefix for the generated file names.
        folder: Directory to write into; created if it does not exist.
        cmap: Matplotlib colormap used to render each channel.

    Writes ``<folder>/<layer_name>_feature_map_<i>.png`` for every channel i.
    """
    os.makedirs(folder, exist_ok=True)
    # Only the first item of the batch is visualised.
    feature_maps = output[0].detach().cpu().numpy()

    # Min-max normalise each channel independently to [0, 1]. A constant
    # channel (e.g. an all-zero map) would give 0/0 = NaN, so such channels
    # are mapped to 0 instead.
    spatial_axes = tuple(range(1, feature_maps.ndim))
    lows = feature_maps.min(axis=spatial_axes, keepdims=True)
    spans = feature_maps.max(axis=spatial_axes, keepdims=True) - lows
    feature_maps = np.divide(
        feature_maps - lows, spans,
        out=np.zeros_like(feature_maps), where=spans > 0,
    )

    for idx, channel in enumerate(feature_maps):
        # Use a dedicated figure so an already-open pyplot figure is never
        # drawn on or closed. Close it even if saving fails.
        fig, ax = plt.subplots()
        try:
            ax.imshow(channel, cmap=cmap, vmin=0.0, vmax=1.0)
            ax.axis('off')
            fig.savefig(os.path.join(folder, f'{layer_name}_feature_map_{idx}.png'))
        finally:
            plt.close(fig)

# Function to register hooks on encoder and decoder layers, and order the saving process
def register_hooks(model):
    """Attach feature-map-saving forward hooks to the encoder and decoder convs.

    Hooks fire in forward-pass order regardless of registration order. The
    saved images are later turned into a video by sorting file names, and
    plain names would sort ``dec_*`` before ``enc_*``. Each layer name is
    therefore prefixed with its position in encoder-then-decoder order
    (``00_enc_conv1``, ``01_enc_conv2``, ``02_dec_conv1``, ``03_dec_conv2``),
    so the files sort encoder-first.

    Layers listed below that *model* does not have are silently skipped.

    Args:
        model: Module whose ``named_modules()`` may contain the layers
            ``enc_conv1``, ``enc_conv2``, ``dec_conv1`` and ``dec_conv2``.

    Returns:
        list of hook handles; call ``.remove()`` on each to detach them.
    """
    # Encoder layers first, followed by decoder layers.
    layer_order = ['enc_conv1', 'enc_conv2', 'dec_conv1', 'dec_conv2']
    modules = dict(model.named_modules())

    hooks = []
    for position, name in enumerate(layer_order):
        layer = modules.get(name)
        if layer is None:
            continue
        prefixed_name = f'{position:02d}_{name}'
        # Bind prefixed_name as a default argument to avoid late binding of
        # the loop variable inside the lambda.
        hooks.append(layer.register_forward_hook(
            lambda module, input, output, layer_name=prefixed_name:
                save_feature_maps(module, input, output, layer_name)
        ))
    return hooks

# Function to convert images to a video using OpenCV
def create_video_from_images(image_folder, video_name='feature_maps_video.mp4', fps=2):
    """Encode every ``.png`` in *image_folder* as one frame of an MP4 video.

    Frames are ordered by a natural sort of their file names, so
    ``x_feature_map_2.png`` comes before ``x_feature_map_10.png``; a plain
    ``sorted`` would put 10 first. The first frame fixes the video size.
    Later frames of a different size are resized to it, because
    ``cv2.VideoWriter`` may silently drop frames that do not match.

    Args:
        image_folder: Directory containing the PNG frames.
        video_name: Output video path.
        fps: Frames per second of the output video.

    Raises:
        FileNotFoundError: if *image_folder* contains no ``.png`` files.
        ValueError: if an image cannot be decoded by OpenCV.
        RuntimeError: if the video writer cannot be opened.
    """
    import re

    def natural_key(filename):
        # re.split with a capture group alternates text / digit runs
        # (text at even indices), so numeric runs compare as integers.
        parts = re.split(r'(\d+)', filename)
        return [int(p) if i % 2 else p for i, p in enumerate(parts)], filename

    images = sorted(
        (name for name in os.listdir(image_folder) if name.endswith(".png")),
        key=natural_key,
    )
    if not images:
        raise FileNotFoundError(f"No .png images found in {image_folder!r}")

    writer = None
    try:
        for name in images:
            frame = cv2.imread(os.path.join(image_folder, name))
            if frame is None:
                raise ValueError(f"Could not read image {name!r} in {image_folder!r}")
            if writer is None:
                # The first readable frame determines the video size.
                height, width = frame.shape[:2]
                writer = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
                if not writer.isOpened():
                    raise RuntimeError(f"Could not open video writer for {video_name!r}")
            elif frame.shape[:2] != (height, width):
                frame = cv2.resize(frame, (width, height))
            writer.write(frame)
    finally:
        if writer is not None:
            writer.release()

# Example usage
if __name__ == "__main__":
    # Initialize the model. eval() changes nothing for this architecture (it
    # has no dropout or batch-norm layers) but marks this as inference only.
    model = SimpleUNet().eval()

    # Register hooks to save feature maps at each layer (first encoder, then decoder)
    hooks = register_hooks(model)
    try:
        # Load your own image: grayscale, resized to 256x256, converted to a (1, 256, 256) tensor
        transform = transforms.Compose([transforms.Grayscale(), transforms.Resize((256, 256)), transforms.ToTensor()])
        img = Image.open("test.jpg")  # Replace with your image
        input_image = transform(img).unsqueeze(0)  # Add batch dimension -> (1, 1, 256, 256)

        # Forward pass; the hooks write one PNG per channel into 'layer_outputs'
        with torch.no_grad():
            output = model(input_image)
    finally:
        # Remove the hooks even if loading or the forward pass fails, so later
        # calls to `model` do not keep writing PNGs.
        for hook in hooks:
            hook.remove()

    # Convert the saved images to a video
    create_video_from_images('layer_outputs', video_name='feature_maps_video_1.mp4', fps=2)
