# In [None]:
import os
import torch
from torch import nn
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

class Generator(nn.Module):
    """Generator that upsamples a flat input vector into an image.

    Each input row (e.g. noise concatenated with a one-hot class vector) is
    reshaped into a ``(input_dim, 1, 1)`` seed. It is then grown by four
    transposed-convolution stages (no padding). With the kernel/stride values
    used below, the spatial size goes 1 -> 3 -> 6 -> 13 -> 28.

    Args:
        input_dim: number of elements in each input vector.
        im_chan: channel count of the generated image.
        hidden_dim: base width; hidden stages use 4x, 2x and 1x this value.
    """

    def __init__(self, input_dim=10, im_chan=1, hidden_dim=64):
        super().__init__()
        self.input_dim = input_dim
        wide, mid, narrow = hidden_dim * 4, hidden_dim * 2, hidden_dim
        # NOTE: the attribute name ``gen`` and the order of the stages define
        # the state_dict keys (``gen.<stage>.<layer>.*``). Saved checkpoints
        # are loaded against them, so keep both stable.
        stages = [
            self.make_gen_block(input_dim, wide),
            self.make_gen_block(wide, mid, kernel_size=4, stride=1),
            self.make_gen_block(mid, narrow),
            self.make_gen_block(narrow, im_chan, kernel_size=4, final_layer=True),
        ]
        self.gen = nn.Sequential(*stages)

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        """Build one upsampling stage.

        Hidden stages are ConvTranspose2d -> BatchNorm2d -> ReLU. The final
        stage is ConvTranspose2d -> Tanh, so the image values lie in [-1, 1].
        """
        layers = [nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)]
        if final_layer:
            layers.append(nn.Tanh())
        else:
            layers += [nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True)]
        return nn.Sequential(*layers)

    def forward(self, noise):
        """Map a batch of input vectors to a batch of images.

        ``noise`` must have the batch as its first dimension, and each row
        must hold exactly ``input_dim`` elements. The result has shape
        ``(batch, im_chan, H, W)``.
        """
        seed = noise.view(len(noise), self.input_dim, 1, 1)
        return self.gen(seed)

def combine_vectors(x, y):
    """Concatenate two tensors along dimension 1 after casting both to float32.

    Used by ``generate_single_image`` to append a one-hot class vector to a
    noise vector.
    """
    parts = [x.float(), y.float()]
    return torch.cat(parts, dim=1)

def generate_single_image(generator, class_label, z_dim, n_classes):
    """Sample one image of the requested class from a conditional generator.

    A standard-normal noise vector of length ``z_dim`` is concatenated with a
    one-hot encoding of ``class_label``. The result is passed through
    ``generator`` without gradient tracking. Put the generator in eval mode
    first if it contains BatchNorm layers.

    Args:
        generator: callable (normally an ``nn.Module``) accepting a
            ``(1, z_dim + n_classes)`` float tensor. Its output is assumed to
            lie in [-1, 1] (e.g. a Tanh final layer).
        class_label: class index in ``[0, n_classes)``.
        z_dim: length of the noise vector.
        n_classes: length of the one-hot class vector.

    Returns:
        numpy array of the generated image with singleton dimensions
        squeezed out, linearly mapped from [-1, 1] to [0, 1].

    Raises:
        IndexError: if ``class_label`` is outside ``[0, n_classes)``.
    """
    # Reject negative labels explicitly: indexing with e.g. -1 would silently
    # pick the *last* class instead of failing.
    if not 0 <= class_label < n_classes:
        raise IndexError(
            f"class_label must be in [0, {n_classes - 1}], got {class_label}"
        )

    # Create the inputs on the generator's device so a GPU-resident model
    # works too. Fall back to CPU for non-modules or parameter-less modules.
    params = generator.parameters() if isinstance(generator, nn.Module) else iter(())
    first_param = next(params, None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    noise = torch.randn(1, z_dim, device=device)
    one_hot_label = torch.zeros(1, n_classes, device=device)
    one_hot_label[0, class_label] = 1
    noise_and_label = combine_vectors(noise, one_hot_label)

    with torch.no_grad():
        fake_image = generator(noise_and_label)

    # .cpu() is required before .numpy() when the generator ran on a GPU.
    # Rescale from the Tanh range [-1, 1] to [0, 1] for display.
    return (fake_image.squeeze().cpu().numpy() + 1) / 2

def display_image(img, vmin=0.0, vmax=1.0):
    """Show one image with matplotlib in grayscale, with the axes hidden.

    Args:
        img: image array, expected to already be scaled to ``[vmin, vmax]``.
            ``generate_single_image`` returns values in [0, 1].
        vmin, vmax: fixed intensity range mapped to black/white. Pinning them
            stops ``imshow`` from stretching each image to its own min/max,
            which would hide how dark or bright the sample really is. Pass
            ``None`` for both to restore per-image autoscaling.
    """
    plt.imshow(img, cmap='gray', vmin=vmin, vmax=vmax)
    plt.axis('off')  # hide ticks and frame; only the image is of interest
    plt.show()

if __name__ == "__main__":
    # Must match the values the checkpoints were trained with: the generator
    # input is noise (z_dim) concatenated with a one-hot label (n_classes).
    z_dim = 64
    n_classes = 10
    # strip() so stray whitespace does not end up in the checkpoint filename.
    epoch = input("Enter the epoch number you want to load (e.g., 1, 11, 21...): ").strip()
    label_text = input(f"Enter the number you want to generate (0 to {n_classes - 1}): ").strip()

    # Paths (modify these according to your folder structure)
    checkpoint_dir = '/path_to_checkpoints_folder'  # Folder where the checkpoints are stored

    checkpoint_path = os.path.join(checkpoint_dir, f'generator_epoch_{epoch}.pth')

    try:
        class_label = int(label_text)
    except ValueError:
        class_label = None

    if class_label is None or not 0 <= class_label < n_classes:
        print(f"Invalid class label {label_text!r}; expected an integer from 0 to {n_classes - 1}.")
    elif not os.path.isfile(checkpoint_path):
        print(f"Checkpoint {checkpoint_path} not found.")
    else:
        # Load the generator model for the specified epoch
        generator = Generator(input_dim=z_dim + n_classes)
        # map_location="cpu" lets a checkpoint saved from a GPU run load on a
        # CPU-only machine; the model itself is built on CPU.
        # NOTE(review): torch.load unpickles the file. Only load trusted
        # checkpoints, or pass weights_only=True on torch >= 1.13.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        generator.load_state_dict(state_dict)
        generator.eval()  # BatchNorm must use running stats for a batch of one

        # Generate the image for the specified class label
        generated_image = generate_single_image(generator, class_label, z_dim, n_classes)

        # Display the generated image in the notebook
        display_image(generated_image)