In [3]:
import torch
import lightning as L
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
import os

class Generator(torch.nn.Module):
    """Convolutional generator mapping a latent vector to a 3-channel image.

    A linear layer projects ``z`` to a 128-channel feature map of side
    ``img_size // 4``. Two 2x upsampling + conv stages bring it back to
    ``img_size``, and a final conv + Tanh produces the image, so output values
    lie in [-1, 1].

    Args:
        latent_dim: Length of the input noise vector (default 512).
        img_size: Side length of the square output image; must be a positive
            multiple of 4 because of the two 2x upsampling stages (default 64).

    Raises:
        ValueError: if ``latent_dim`` is not positive or ``img_size`` is not a
            positive multiple of 4.
    """

    def __init__(self, latent_dim: int = 512, img_size: int = 64):
        super().__init__()
        if latent_dim <= 0:
            raise ValueError(f"latent_dim must be positive, got {latent_dim}")
        if img_size <= 0 or img_size % 4 != 0:
            raise ValueError(
                f"img_size must be a positive multiple of 4, got {img_size}"
            )

        # Two 2x upsampling stages follow, so start at a quarter of the final size.
        self.init_size = img_size // 4
        self.latent_dim = latent_dim

        # Kept wrapped in a Sequential so state_dict keys stay ``l1.0.*``,
        # which existing checkpoints depend on.
        self.l1 = torch.nn.Sequential(
            torch.nn.Linear(self.latent_dim, 128 * self.init_size ** 2)
        )

        # NOTE: BatchNorm2d's second positional argument is ``eps``, so the
        # ``0.8`` below sets eps=0.8 (probably intended as momentum). It is kept
        # as-is because trained weights were produced with this setting.
        self.conv_blocks = torch.nn.Sequential(
            torch.nn.BatchNorm2d(128),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(128, 0.8),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(128, 64, 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(64, 0.8),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Conv2d(64, 3, 3, stride=1, padding=1),
            torch.nn.Tanh(),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Generate images from a ``(batch, latent_dim)`` noise tensor.

        Returns a ``(batch, 3, img_size, img_size)`` tensor with values in [-1, 1].
        """
        out = self.l1(z)
        # Reshape the flat projection into a 128-channel square feature map.
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

class Discriminator(torch.nn.Module):
    """Convolutional discriminator scoring each image with a value in (0, 1).

    Sized for 3-channel 64x64 inputs. Four stride-2 conv stages halve the
    spatial size each time (64 -> 32 -> 16 -> 8 -> 4). The final linear layer
    therefore expects 128 x 4 x 4 features, and a Sigmoid squashes its output.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels) for each stride-2 stage, in order.
        stage_channels = [(3, 16), (16, 32), (32, 64), (64, 128)]

        layers = []
        for stage_idx, (c_in, c_out) in enumerate(stage_channels):
            layers.extend([
                torch.nn.Conv2d(c_in, c_out, 3, 2, 1),
                torch.nn.LeakyReLU(0.2, inplace=True),
                torch.nn.Dropout2d(0.25),
            ])
            # Every stage except the first ends with BatchNorm. Note that the
            # second positional argument of BatchNorm2d is eps, so eps=0.8 here.
            if stage_idx > 0:
                layers.append(torch.nn.BatchNorm2d(c_out, 0.8))
        self.model = torch.nn.Sequential(*layers)

        # Spatial side after all stride-2 stages: 64 / 2**4 = 4.
        final_side = 64 // 2 ** len(stage_channels)
        self.adv_layer = torch.nn.Sequential(
            torch.nn.Linear(128 * final_side ** 2, 1),
            torch.nn.Sigmoid(),
        )

    def forward(self, img):
        """Return a ``(batch, 1)`` tensor of per-image scores in (0, 1)."""
        features = self.model(img)
        flat = features.view(features.shape[0], -1)
        return self.adv_layer(flat)

class GAN(L.LightningModule):
    """Lightning module holding a ``Generator`` / ``Discriminator`` pair.

    As shown here, it defines construction and ``forward`` (sampling from the
    generator) only. All constructor arguments are recorded via
    ``save_hyperparameters`` so ``GAN.load_from_checkpoint`` can rebuild it.

    Args:
        img_size: Recorded in the hyperparameters. NOTE(review): not forwarded
            to the sub-networks, which are built with their default sizes.
        latent_dim: Length of the noise vector fed to the generator; must match
            the generator's own ``latent_dim``.
        lr, b1, b2: Recorded in the hyperparameters (presumably optimizer
            learning rate and Adam betas); no optimizer is configured in this
            class as shown.
        n_critic: Recorded in the hyperparameters; unused in the code shown.

    Raises:
        ValueError: if ``latent_dim`` differs from the generator's ``latent_dim``.
    """

    def __init__(
        self,
        img_size: int = 64,
        latent_dim: int = 512,
        lr: float = 1e-5,
        b1: float = 0.5,
        b2: float = 0.999,
        n_critic: int = 5,
    ):
        super().__init__()
        # Stores every __init__ argument in self.hparams (and in checkpoints).
        self.save_hyperparameters()
        # Manual optimization: Lightning will not call backward/step itself.
        # NOTE(review): the training step that drives this is not shown here.
        self.automatic_optimization = False

        self.latent_dim = latent_dim
        self.generator = Generator()
        self.discriminator = Discriminator()

        # The generator is built with its own latent size. A mismatched
        # latent_dim would otherwise only surface later as an opaque shape
        # error when sampling z of the wrong width.
        if latent_dim != self.generator.latent_dim:
            raise ValueError(
                f"latent_dim={latent_dim} does not match the generator's "
                f"latent_dim={self.generator.latent_dim}"
            )

        # NOTE(review): Discriminator ends in Sigmoid, so applying
        # BCEWithLogitsLoss to its output would apply sigmoid twice. The
        # criterion is not used in the code visible here; confirm against the
        # training code.
        self.criterion = torch.nn.BCEWithLogitsLoss()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Generate images from latent vectors ``z`` using the generator."""
        return self.generator(z)

def _to_uint8_hwc(image):
    """Convert a float ``(C, H, W)`` tensor to an ``(H, W, C)`` uint8 NumPy array.

    Values are scaled by 255, clamped to [0, 255], then truncated (not rounded)
    by the uint8 cast. Inputs are expected to lie roughly in [0, 1].
    """
    return (image * 255).clamp(0, 255).permute(1, 2, 0).cpu().numpy().astype('uint8')


def generate_images(
    num_images=16,
    output_dir="generated_images",
    checkpoint_path="Weights/WGANGP_MNIST_final/epoch=570-step=2498696.ckpt",
    grid_size=(4, 4),  # Default 4x4 grid; grid_size[0] = images per row
    save_individual=False,
):
    """Sample images from a trained GAN checkpoint and save them as a PNG grid.

    The grid is written to ``<output_dir>/generated_grid.png``.

    Args:
        num_images: Number of latent vectors to sample; must be >= 1.
        output_dir: Output directory; created if it does not exist.
        checkpoint_path: Lightning checkpoint loaded via ``GAN.load_from_checkpoint``.
        grid_size: Only ``grid_size[0]`` is used, as ``make_grid``'s ``nrow``
            (images per row). The number of rows follows from ``num_images``.
        save_individual: If True, also save each sample as
            ``generated_image_{i}.png``. Defaults to False, which matches the
            previous behaviour where per-image saving was disabled.

    Raises:
        ValueError: if ``num_images`` < 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    os.makedirs(output_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model = GAN.load_from_checkpoint(checkpoint_path, map_location=device)
    # eval(): BatchNorm layers use running statistics instead of batch stats.
    model.eval()
    model = model.to(device)

    # Inverse of Normalize(mean=m, std=s): (x - (-m/s)) / (1/s) == x * s + m.
    # The m/s values are the common ImageNet statistics. They are assumed to
    # match the training transform; confirm against the training pipeline.
    denorm = transforms.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
        std=[1/0.229, 1/0.224, 1/0.225]
    )

    with torch.no_grad():
        z = torch.randn(num_images, model.latent_dim).to(device)
        generated_images = model(z)

        # Normalize accepts a (B, C, H, W) batch directly; no per-image loop needed.
        denormalized_images = denorm(generated_images)

        if save_individual:
            for i, img_tensor in enumerate(denormalized_images):
                Image.fromarray(_to_uint8_hwc(img_tensor)).save(
                    os.path.join(output_dir, f'generated_image_{i}.png')
                )

        grid = vutils.make_grid(
            denormalized_images,
            nrow=grid_size[0],
            padding=2,
            normalize=False
        )
        Image.fromarray(_to_uint8_hwc(grid)).save(
            os.path.join(output_dir, 'generated_grid.png')
        )

    print(f"Generated {num_images} images and grid in {output_dir}")

if __name__ == "__main__":
    # Script entry point: sample 16 images from the local checkpoint copy and
    # lay them out 4 per row in a single grid image.
    generate_images(
        checkpoint_path="Weights/epoch=570-step=2498696 (1).ckpt",
        output_dir="generated_images",
        num_images=16,
        grid_size=(4, 4),
    )

Generated 16 images and grid in generated_images
