<a href="https://colab.research.google.com/github/ShaliniAnandaPhD/PIXEL-PIONEERS-TUTORIALS/blob/main/Vanishing_Gradients_Check.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

"""
Gradient Monitoring Tool for GANs

This code provides a PyTorch implementation of a gradient monitoring tool for Generative Adversarial Networks (GANs).
The tool monitors the gradients of the generator and discriminator networks during training and provides recommendations
for addressing vanishing gradients.

Vanishing gradients occur when the gradients become extremely small, making it difficult for the model to learn and update
its weights effectively. This can lead to slow convergence or even failure of the training process.

The `GradientMonitor` class captures the gradient of each trainable parameter using tensor hooks (`Tensor.register_hook`)
and analyzes the captured values at the end of each epoch. If the average gradient of a parameter falls below a specified
threshold, it is considered a vanishing gradient, and corresponding recommendations are generated.

The recommendations include suggestions such as adjusting the learning rate, modifying the architecture by adding skip
connections or changing activation functions, applying gradient clipping, or using alternative loss functions.

To use the `GradientMonitor`, create instances for both the generator and discriminator networks, register the backward
hooks, and call the `monitor` method at the end of each training epoch.

Note: This tool provides general recommendations based on common techniques for addressing vanishing gradients. The
effectiveness of the recommendations may vary depending on the specific GAN architecture, dataset, and training setup.
Experimentation and domain knowledge are crucial for successfully mitigating vanishing gradients.
"""

In [None]:


import torch
import torch.nn as nn
import numpy as np

class GradientMonitor:
    """Record per-parameter gradient magnitudes and flag vanishing gradients.

    A tensor hook is attached to every trainable parameter of ``model``. Each
    time a gradient is computed for a parameter, the mean absolute value of
    that gradient is appended to ``self.gradients[name]``. At the end of an
    epoch, ``monitor`` averages those values per parameter, records a
    recommendation when the average is below ``threshold``, and clears the
    buffer for the next epoch.
    """

    def __init__(self, model, max_epochs, threshold=1e-4):
        """
        Initialize the GradientMonitor.

        Args:
            model (nn.Module): The model to monitor gradients for.
            max_epochs (int): The maximum number of epochs for training.
            threshold (float): The threshold for detecting vanishing gradients.
        """
        self.model = model
        self.max_epochs = max_epochs
        self.threshold = threshold
        # Maps parameter name -> list of 0-dim tensors, one per backward pass,
        # each holding mean(|grad|) for that pass.
        self.gradients = {}
        self.recommendations = []
        # Handles returned by Tensor.register_hook, kept so hooks can be removed.
        self._hook_handles = []

    def register_hooks(self):
        """
        Register tensor hooks that capture the gradient of every trainable parameter.

        Calling this more than once is safe: hooks from an earlier call are
        removed first, so each backward pass is recorded exactly once.
        """
        self.remove_hooks()
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # `name=name` binds the current name at definition time
                # (avoids the late-binding closure pitfall). The hook returns
                # None, so the gradient itself is left unmodified.
                handle = param.register_hook(
                    lambda grad, name=name: self.store_gradient(name, grad)
                )
                self._hook_handles.append(handle)

    def remove_hooks(self):
        """
        Detach every hook previously installed by ``register_hooks``.
        """
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def store_gradient(self, name, grad):
        """
        Record the mean absolute value of ``grad`` under ``name``.

        Only a scalar summary is kept rather than a copy of the full gradient,
        so memory use does not grow with parameter size times the number of
        training steps in an epoch.

        Args:
            name (str): Parameter name used as the key in ``self.gradients``.
            grad (torch.Tensor): Gradient passed in by the tensor hook.
        """
        self.gradients.setdefault(name, []).append(grad.detach().abs().mean())

    def analyze_gradients(self, epoch):
        """
        Analyze the gradients to identify vanishing gradients and provide recommendations.

        A parameter is flagged when the average of its recorded gradient
        magnitudes is below ``self.threshold``. Magnitudes are averaged (not
        signed values) so that positive and negative entries cannot cancel
        out and make a healthy gradient look like a vanishing one.

        Args:
            epoch (int): Epoch number included in the recommendation text.
        """
        vanishing_gradients = []
        for name, grads in self.gradients.items():
            if not grads:
                continue
            # abs().mean() is a no-op on the scalars written by store_gradient
            # and also handles full gradient tensors placed here directly.
            per_step = [g.detach().abs().mean() for g in grads]
            avg_magnitude = torch.stack(per_step).mean()
            if avg_magnitude < self.threshold:
                vanishing_gradients.append(name)

        if vanishing_gradients:
            recommendation = f"Epoch {epoch}: Vanishing gradients detected in layers: {', '.join(vanishing_gradients)}. "
            recommendation += "Consider the following:\n"
            recommendation += "1. Adjust learning rate or use learning rate scheduling.\n"
            recommendation += "2. Modify the architecture by adding skip connections or changing activation functions.\n"
            recommendation += "3. Apply gradient clipping or use alternative loss functions.\n"
            self.recommendations.append(recommendation)

    def print_recommendations(self):
        """
        Print the collected recommendations.
        """
        if self.recommendations:
            print("Recommendations:")
            for rec in self.recommendations:
                print(rec)
        else:
            print("No vanishing gradients detected.")

    def monitor(self, epoch):
        """
        Monitor the gradients at the end of each epoch.

        Analyzes what was captured during the epoch, prints the accumulated
        recommendations on the final epoch (``max_epochs - 1``), and resets
        the gradient buffer.

        Args:
            epoch (int): Zero-based index of the epoch that just finished.
        """
        self.analyze_gradients(epoch)
        if epoch == self.max_epochs - 1:
            self.print_recommendations()

        # Clear the stored gradients for the next epoch
        self.gradients = {}

# Example usage
class Generator(nn.Module):
    """Fully connected generator network.

    Maps a 100-feature input through hidden widths 256, 512 and 1024 (each
    followed by ReLU) to a 784-feature output squashed by Tanh.
    """

    def __init__(self):
        super().__init__()
        # Input width followed by the hidden widths; consecutive pairs define
        # the Linear layers that are each followed by a ReLU.
        widths = (100, 256, 512, 1024)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        # Output projection with Tanh, so outputs lie in [-1, 1].
        layers.append(nn.Linear(1024, 784))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        generated = self.main(x)
        return generated

class Discriminator(nn.Module):
    """Fully connected discriminator network.

    Maps a 784-feature input through hidden widths 512 and 256 (each followed
    by LeakyReLU with negative slope 0.2) to a single Sigmoid output.
    """

    def __init__(self):
        super().__init__()
        negative_slope = 0.2
        stack = [
            nn.Linear(784, 512),
            nn.LeakyReLU(negative_slope),
            nn.Linear(512, 256),
            nn.LeakyReLU(negative_slope),
            # Single output unit; Sigmoid keeps it in [0, 1].
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        score = self.main(x)
        return score

# Single source of truth for the epoch count. The monitors print their report
# when `epoch == max_epochs - 1`, so `max_epochs` must match the loop length.
NUM_EPOCHS = 100

# Create instances of the generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Create gradient monitors for generator and discriminator
generator_monitor = GradientMonitor(generator, max_epochs=NUM_EPOCHS)
discriminator_monitor = GradientMonitor(discriminator, max_epochs=NUM_EPOCHS)

# Register the backward hooks
generator_monitor.register_hooks()
discriminator_monitor.register_hooks()

# Training loop
for epoch in range(NUM_EPOCHS):
    # Train the GAN for one epoch
    # ...
    # NOTE: the training step above is a placeholder. The hooks only record
    # gradients when a backward pass runs, so until real training code is
    # added here the monitors have nothing to analyze and the final report
    # is "No vanishing gradients detected.".

    # Monitor gradients at the end of each epoch
    generator_monitor.monitor(epoch)
    discriminator_monitor.monitor(epoch)