In [None]:
import torch
from torchvision import datasets, transforms
import random

In [None]:
# Path to your dataset
dataset_path = r'C:\Users\dzmit\Downloads\cat_images\cats_128x128'

In [None]:
# import os
# import shutil
# 
# # Path to the dataset directory you care about
# dataset_path = r'C:\Users\dzmit\Downloads\Imageandvideodataset\image dataset\train\2'
# 
# # New subfolder path for the single class
# subfolder_path = os.path.join(dataset_path, "class1")
# 
# # Create the subfolder if it does not exist
# if not os.path.exists(subfolder_path):
#     os.makedirs(subfolder_path)
# 
# # Move all files into the subfolder
# for file in os.listdir(dataset_path):
#     file_path = os.path.join(dataset_path, file)
#     if os.path.isfile(file_path):
#         shutil.move(file_path, subfolder_path)
# 
# print("All files moved to:", subfolder_path)

In [None]:
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((64, 64)),  # Resize all images to 64x64
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


# Create the dataset
# dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)

In [None]:
# DataLoader
batch_size = 24  # Batch size
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
import torch
import torch.nn as nn

class MappingNetwork(nn.Module):
    def __init__(self, input_dim, feature_dim, num_layers=8):
        super().__init__()
        layers = [nn.Sequential(
            nn.Linear(input_dim, feature_dim),
            nn.LeakyReLU(0.2)
        )]
        for _ in range(num_layers - 1):
            layers.append(nn.Sequential(
                nn.Linear(feature_dim, feature_dim),
                nn.LeakyReLU(0.2)
            ))
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        w = self.model(z)
        return w


In [None]:
class SynthesisBlock(nn.Module):
    def __init__(self, in_channels, out_channels, size, w_dim, is_initial=False):
        super().__init__()
        if is_initial:
            self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        else:
            self.conv = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                nn.Conv2d(in_channels, out_channels, 3, padding=1)
            )
        # Transform w to the correct dimension for this block
        self.style_transform = nn.Linear(w_dim, out_channels)
        self.styles = nn.Linear(out_channels, out_channels)
        self.noise = nn.Parameter(torch.randn(1, 1, size, size))
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x, w):
        x = self.conv(x)
        
        # Transform style vector to match out_channels of this block
        w_transformed = self.style_transform(w)
        s = self.styles(w_transformed).unsqueeze(2).unsqueeze(3)
        
        x = x + self.noise * s
        x = self.activation(x)
        
        return x


In [None]:

class SynthesisNetwork(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.initial_block = SynthesisBlock(512, 512, 4, w_dim=512, is_initial=True).to(device)
        self.blocks = nn.ModuleList([
            SynthesisBlock(512, 256, 8, w_dim=512).to(device),
            SynthesisBlock(256, 128, 16, w_dim=512).to(device),
            SynthesisBlock(128, 64, 32, w_dim=512).to(device),
            SynthesisBlock(64, 3, 64, w_dim=512).to(device),
        ])

    def forward(self, w):
        x = torch.randn(1, 512, 4, 4).to(self.device)  # Initial noise vector
        x = self.initial_block(x, w)
        for block in self.blocks:
            x = block(x, w)
        return x


In [None]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            # input is (nc) x 64 x 64
            # nn.Conv2d(3, 32, 4, 2, 1, bias=False),
            # nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.model(input).view(-1, 1).squeeze(1)


# Debugging

In [None]:
# z = torch.randn(24, 512).to(device)
# w = mapping_network(z)

In [None]:
# fake_images = generator(w)
# print(fake_images.shape)

In [None]:
# output = discriminator(fake_images)
# print(output.shape)

In [None]:
# for i, (real_images, _) in enumerate(dataloader):
#             real_images = real_images.to(device)
#             real_labels = torch.ones(real_images.size(0), device=device)
#             fake_labels = torch.zeros(real_images.size(0), device=device)
#         
#             ### Train Discriminator
#             optimizer_D.zero_grad()
#             real_outputs = discriminator(real_images)
#             break

In [None]:
# init_block = SynthesisBlock(512, 512, 4, w_dim=512, is_initial=True)
# init_block(torch.randn(1, 512, 4, 4), torch.randn(24, 512)).shape

In [None]:
# blocks = nn.ModuleList([
#             SynthesisBlock(512, 256, 8, w_dim=512),
#             SynthesisBlock(256, 128, 16, w_dim=512),
#             SynthesisBlock(128, 64, 32, w_dim=512),
#             SynthesisBlock(64, 3, 64, w_dim=512),
#             # SynthesisBlock(32, 16, 128, w_dim=512)
#         ])

In [None]:
# x = torch.randn(1, 512, 4, 4)  # Initial noise vector
# 
# x = init_block(x, w)
# for block in blocks:
#     x = block(x, w)
#     print(x.shape)

# Device

In [None]:
# Check if CUDA is available, else use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

In [None]:
torch.cuda.is_available()

# Load the model

In [96]:
# Initialize the mapping network, synthesis network, and discriminator
mapping_network = MappingNetwork(input_dim=512, feature_dim=512).to(device)
generator = SynthesisNetwork(device=device).to(device)
discriminator = Discriminator().to(device)

# Optimizers

In [97]:
import torch.optim as optim

# Optimizers
optimizer_G = optim.Adam(
    list(mapping_network.parameters()) + list(generator.parameters()),
    lr=0.001,  # Updated learning rate as per the StyleGAN paper
    betas=(0.0, 0.99)  # Updated beta values as per the StyleGAN paper
)

optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=0.001,  # Updated learning rate as per the StyleGAN paper
    betas=(0.0, 0.99)  # Updated beta values as per the StyleGAN paper
)

# Loss function
criterion = nn.BCELoss()


# Load the checkpoint

In [None]:
# checkpoint = torch.load('GAN_cat_299.pth')
# 
# # Assuming the generator and discriminator are already instantiated as per the saved model architecture
# generator.load_state_dict(checkpoint['generator_state_dict'])
# discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
# 
# # Assuming the optimizers are already instantiated with the parameters of their respective models
# optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
# optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
# 
# # If you saved the epoch number, you can also load this to know where to resume training
# epoch = checkpoint['epoch']


In [None]:
# for param_group in optimizer_G.param_groups:
#     param_group['lr'] *= 0.1
# 
# for param_group in optimizer_D.param_groups:
#     param_group['lr'] *= 0.1

# Plotting functions

In [104]:
import os
import matplotlib.pyplot as plt

def generate_and_plot_images(n_images=9, epoch=0, plot=True):
    """
    Generates and plots a grid of images using a trained generator model.

    Parameters:
    - generator: The trained generator model for generating images.
    - device: The device (e.g., 'cuda' or 'cpu') the model should run on.
    - n_images: The total number of images to generate and plot. Default is 9.
    """
    fig, axes = plt.subplots(5, 5, figsize=(9, 9))  # Create a 3x3 grid of subplots
    axes = axes.flatten()  # Flatten the 2D array of axes for easier iteration

    for i in range(n_images):
        # Generate random noise
        noise = torch.randn(1, 512).to(device)

        # Generate an image without updating gradients
        with torch.no_grad():
            generated_image = generator(mapping_network(noise))

        # Process the image for visualization
        generated_image = generated_image.to('cpu').clone().detach()
        generated_image = generated_image.numpy().squeeze(0)

        if generated_image.shape[0] == 3:  # Check if the image has 3 channels (RGB)
            generated_image = generated_image.transpose(1, 2, 0)  # Convert from CxHxW to HxWxC
            
        elif generated_image.shape[0] == 1:  # Check if the image has 3 channels (RGB)
            generated_image = generated_image.squeeze(0)  # Convert from CxHxW to HxWxC

        # Normalize the image data to [0, 1]
        generated_image = (generated_image + 1) / 2
        generated_image = generated_image.clip(0, 1)  # Ensure pixel values are within the expected range

        axes[i].imshow(generated_image, cmap='gray')
        axes[i].axis('off')  # Turn off the axis to make the images look cleaner

    plt.tight_layout()
    output_base_path = 'style_net/v1/'
    # Ensure base output directory exists
    os.makedirs(output_base_path, exist_ok=True)

    plt.savefig(output_base_path + f'generated_images_grid_{epoch}.png')

    if plot:
        plt.show()
        
    return None


In [98]:
def check_output(generated_image, epoch=0):
    generated_image = generated_image.to('cpu').clone().detach()  # Move to CPU and detach from the computation graph
    generated_image = generated_image.numpy()  # Convert to numpy array
    # generated_image = generated_image.squeeze(0)  # Remove the batch dimension, resulting in (channels, height, width)
    
    # If the image is in the format (C, H, W), convert it to (H, W, C)
    if generated_image.shape[0] == 3:  # Assuming 3 channels for RGB
        generated_image = generated_image.transpose(1, 2, 0)  # Reorder dimensions to (H, W, C)
    
    # Normalize the image to [0, 1] if it's not already
    generated_image = (generated_image + 1) / 2  # Assuming that the output is in the range [-1, 1]
    generated_image = generated_image.clip(0, 1)  # Ensure the values are within [0, 1]
    
    plt.imshow(generated_image)
    plt.savefig(f'output/check_generated_image_{epoch}.png')


In [None]:
# generator(torch.randn(1, 512).to(device)).shape

In [None]:
# discriminator(generator(torch.randn(1, 512).to(device)).shape)

In [None]:
# generate_and_plot_images(n_images=25, epoch=200)

# Clear cache

In [108]:
import gc

torch.cuda.empty_cache()  # Clear cache
gc.collect()  # Collect garbage
generator.to('cpu')
discriminator.to('cpu')
mapping_network.to('cpu')

del generator, discriminator, optimizer_G, optimizer_D

In [102]:
num_epochs = 300
latent_size = 512

d_losses = []
g_losses = []

flip_prob = 0.0  # Probability of flipping labels

In [105]:
for epoch in range(26, num_epochs):
    try:
        for i, (real_images, _) in enumerate(dataloader):
            real_images = real_images.to(device)
            real_labels = torch.ones(real_images.size(0), device=device)
            fake_labels = torch.zeros(real_images.size(0), device=device)
        
            ### Train Discriminator
            optimizer_D.zero_grad()
            real_outputs = discriminator(real_images)
            d_loss_real = criterion(real_outputs, real_labels)
            d_loss_real.backward()
        
            # Generate fake images
            noise = torch.randn(real_images.size(0), latent_size, device=device)
            w = mapping_network(noise)  # Transform z to w
            fake_images = generator(w)  # Generate images from w
            fake_outputs = discriminator(fake_images.detach())
            d_loss_fake = criterion(fake_outputs, fake_labels)
            d_loss_fake.backward()
            optimizer_D.step()
        
            ### Train Generator
            optimizer_G.zero_grad()
            # Optionally regenerate fake images for freshness
            noise = torch.randn(real_images.size(0), latent_size, device=device)
            w = mapping_network(noise)  # Regenerate w for generator training
            fake_images = generator(w)
            output = discriminator(fake_images)
            
        
            # Optionally flip labels to introduce label smoothing
            if random.random() < flip_prob:
                g_loss = criterion(output, fake_labels)  # Flipped labels
            else:
                g_loss = criterion(output, real_labels)  # Normal training
        
            g_loss.backward()
            optimizer_G.step()

            if (i + 1) % 20 == 0:
                d_losses.append(d_loss_real.item() + d_loss_fake.item())
                g_losses.append(g_loss.item())
                print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], '
                      f'D Loss: {d_loss_real.item() + d_loss_fake.item():.4f}, G Loss: {g_loss.item():.4f}')

        if (epoch + 1) % 1 == 0:
            generate_and_plot_images(25, epoch=epoch, plot=False)
            
        if (epoch + 1) % 30 == 0:
            checkpoint = {
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'mapping_network_state_dict': mapping_network.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'epoch': epoch
            }
            
            torch.save(checkpoint, f'GAN_cat_{epoch}.pth')

    except OSError:
        print(f"An error occurred while processing the image. Epoch: {epoch}, batch: {i}")
        continue


Epoch [27/300], Step [20/417], D Loss: 0.0000, G Loss: 41.8047
Epoch [27/300], Step [40/417], D Loss: 3.0743, G Loss: 31.6451
Epoch [27/300], Step [60/417], D Loss: 0.0000, G Loss: 17.1550
Epoch [27/300], Step [80/417], D Loss: 0.0002, G Loss: 19.3336
Epoch [27/300], Step [100/417], D Loss: 0.0000, G Loss: 19.0992
Epoch [27/300], Step [120/417], D Loss: 1.5189, G Loss: 35.0321
Epoch [27/300], Step [140/417], D Loss: 0.1947, G Loss: 38.6341
Epoch [27/300], Step [160/417], D Loss: 0.0000, G Loss: 38.5516
Epoch [27/300], Step [180/417], D Loss: 0.0000, G Loss: 37.0060
Epoch [27/300], Step [200/417], D Loss: 0.0000, G Loss: 28.0326
Epoch [27/300], Step [220/417], D Loss: 0.0000, G Loss: 33.7269
Epoch [27/300], Step [240/417], D Loss: 0.0000, G Loss: 31.3070
Epoch [27/300], Step [260/417], D Loss: 0.0000, G Loss: 25.6777
Epoch [27/300], Step [280/417], D Loss: 0.0001, G Loss: 35.2162
Epoch [27/300], Step [300/417], D Loss: 0.0269, G Loss: 6.9806
Epoch [27/300], Step [320/417], D Loss: 0.011

  fig, axes = plt.subplots(5, 5, figsize=(9, 9))  # Create a 3x3 grid of subplots


Epoch [48/300], Step [20/417], D Loss: 0.0016, G Loss: 23.2631
Epoch [48/300], Step [40/417], D Loss: 0.0013, G Loss: 9.4849
Epoch [48/300], Step [60/417], D Loss: 0.0104, G Loss: 10.6804
Epoch [48/300], Step [80/417], D Loss: 0.0000, G Loss: 28.6086
Epoch [48/300], Step [100/417], D Loss: 0.0000, G Loss: 17.8450
Epoch [48/300], Step [120/417], D Loss: 0.0011, G Loss: 11.5185
Epoch [48/300], Step [140/417], D Loss: 0.0124, G Loss: 6.7696
Epoch [48/300], Step [160/417], D Loss: 0.2541, G Loss: 24.1531
Epoch [48/300], Step [180/417], D Loss: 0.0044, G Loss: 9.9759
Epoch [48/300], Step [200/417], D Loss: 0.0376, G Loss: 13.8527
Epoch [48/300], Step [220/417], D Loss: 0.0636, G Loss: 19.3898
Epoch [48/300], Step [240/417], D Loss: 0.0004, G Loss: 7.0103
Epoch [48/300], Step [260/417], D Loss: 0.7238, G Loss: 12.9389
Epoch [48/300], Step [280/417], D Loss: 0.0597, G Loss: 12.0384
Epoch [48/300], Step [300/417], D Loss: 0.0004, G Loss: 21.6528
Epoch [48/300], Step [320/417], D Loss: 0.0034, 

KeyboardInterrupt: 

Error in callback <function flush_figures at 0x0000023391F81360> (for post_execute), with arguments args (),kwargs {}:



KeyboardInterrupt



In [None]:
for epoch in range(0, num_epochs):
    try:
        for i, (real_images, _) in enumerate(dataloader):
            real_images = real_images.to(device)
            real_labels = torch.ones(real_images.size(0), device=device)
            fake_labels = torch.zeros(real_images.size(0), device=device)
        
            ### Train Discriminator
            optimizer_D.zero_grad()
            real_outputs = discriminator(real_images)
            d_loss_real = criterion(real_outputs, real_labels)
            d_loss_real.backward()
        
            noise = torch.randn(real_images.size(0), latent_size, 1, 1, device=device)
            fake_images = generator(noise)
            fake_outputs = discriminator(fake_images.detach())
            d_loss_fake = criterion(fake_outputs, fake_labels)
            d_loss_fake.backward()
            optimizer_D.step()
        
            ### Train Generator
            optimizer_G.zero_grad()
            # Optionally regenerate fake images for freshness
            noise = torch.randn(real_images.size(0), latent_size, 1, 1, device=device)
            fake_images = generator(noise)
            output = discriminator(fake_images)
        
            # Randomly decide whether to flip labels
            if random.random() < flip_prob:
                g_loss = criterion(output, fake_labels)  # Flipped labels
            else:
                g_loss = criterion(output, real_labels)  # Normal training
        
            g_loss.backward()
            optimizer_G.step()

    
            if (i + 1) % 20 == 0:
                d_losses.append(d_loss_real.item() + d_loss_fake.item())
                g_losses.append(g_loss.item())
                print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], '
                      f'D Loss: {d_loss_real.item() + d_loss_fake.item():.4f}, '
                      f'G Loss: {g_loss.item():.4f}'
                      f'; D Loss Real: {d_loss_real.item():.4f}, '
                      f'; D Loss Fake: {d_loss_fake.item():.4f}, '
                      )
                
        if (epoch + 1) % 1 == 0:
            # check_output(fake_images[0], epoch)
            generate_and_plot_images(25, epoch=epoch, plot=False)
            
        if (epoch + 1) % 30 == 0:
            checkpoint = {
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'epoch': epoch  # Optional, if you want to also save the epoch number
            }
            
            torch.save(checkpoint, f'GAN_cat_{epoch}.pth')

    except OSError:
        print(f"An error occurred while processing the image. Epoch: {epoch}, batch: {i}")
        continue


In [None]:
# Assuming 'generator' and 'discriminator' are your model instances
# And 'optimizer_G' and 'optimizer_D' are the optimizers for the generator and discriminator respectively

# Define checkpoint dictionary
checkpoint = {
    'generator_state_dict': generator.state_dict(),
    'discriminator_state_dict': discriminator.state_dict(),
    'optimizer_G_state_dict': optimizer_G.state_dict(),
    'optimizer_D_state_dict': optimizer_D.state_dict(),
    'epoch': epoch  # Optional, if you want to also save the epoch number
}

# Save checkpoint
torch.save(checkpoint, f'GAN_checkpoint_main_128_{epoch}.pth')


In [None]:
generate_and_plot_images(25, epoch=1000, plot=True)

In [None]:
# check_output(real_images[-1])

In [None]:
for i in range(1301, 1320):
    generate_and_plot_images(25, epoch=1301)

In [None]:
# Generate random noise
noise = torch.randn(1, 100, 1, 1, device=device)  # Batch size of 1, latent vector size of 100

# Generate an image
with torch.no_grad():  # Temporarily set all the requires_grad flags to false
    generated_image = generator(noise)

In [None]:

# generated_image = fake_images[23]

generated_image = generated_image.to('cpu').clone().detach()  # Move to CPU and detach from the computation graph
generated_image = generated_image.numpy()  # Convert to numpy array
generated_image = generated_image.squeeze(0)  # Remove the batch dimension, resulting in (channels, height, width)

# If the image is in the format (C, H, W), convert it to (H, W, C)
if generated_image.shape[0] == 3:  # Assuming 3 channels for RGB
    generated_image = generated_image.transpose(1, 2, 0)  # Reorder dimensions to (H, W, C)

# Normalize the image to [0, 1] if it's not already
generated_image = (generated_image + 1) / 2  # Assuming that the output is in the range [-1, 1]
generated_image = generated_image.clip(0, 1)  # Ensure the values are within [0, 1]


In [None]:
import matplotlib.pyplot as plt

plt.imshow(generated_image)
plt.axis('off')  # Turn off axis numbers and ticks
plt.show()


In [None]:
for num, generated_image in enumerate(fake_images):
    # Generate random noise
    noise = torch.randn(1, 100, 1, 1, device=device) / 100  # Batch size of 1, latent vector size of 100
    
    # Generate an image
    with torch.no_grad():  # Temporarily set all the requires_grad flags to false
        generated_image = generator(noise)

    generated_image = generated_image.to('cpu').clone().detach()  # Move to CPU and detach from the computation graph
    generated_image = generated_image.numpy()  # Convert to numpy array
    generated_image = generated_image.squeeze(0)  # Remove the batch dimension, resulting in (channels, height, width)
    
    # If the image is in the format (C, H, W), convert it to (H, W, C)
    if generated_image.shape[0] == 3:  # Assuming 3 channels for RGB
        generated_image = generated_image.transpose(1, 2, 0)  # Reorder dimensions to (H, W, C)
    
    # Normalize the image to [0, 1] if it's not already
    generated_image = (generated_image + 1) / 2  # Assuming that the output is in the range [-1, 1]
    generated_image = generated_image.clip(0, 1)  # Ensure the values are within [0, 1]

    
    plt.imshow(generated_image)
    plt.savefig(f'output/generated_imag_{num}.png')


In [None]:
for num, generated_image in enumerate(fake_images):

    generated_image = generated_image.to('cpu').clone().detach()  # Move to CPU and detach from the computation graph
    generated_image = generated_image.numpy()  # Convert to numpy array
    # generated_image = generated_image.squeeze(0)  # Remove the batch dimension, resulting in (channels, height, width)
    
    # If the image is in the format (C, H, W), convert it to (H, W, C)
    if generated_image.shape[0] == 3:  # Assuming 3 channels for RGB
        generated_image = generated_image.transpose(1, 2, 0)  # Reorder dimensions to (H, W, C)
    
    # Normalize the image to [0, 1] if it's not already
    generated_image = (generated_image + 1) / 2  # Assuming that the output is in the range [-1, 1]
    generated_image = generated_image.clip(0, 1)  # Ensure the values are within [0, 1]
    
    plt.imshow(generated_image)
    plt.savefig(f'output/fake_imag_{num}.png')


In [None]:
for num, real_image in enumerate(real_images):

    real_image = real_image.to('cpu').clone().detach()  # Move to CPU and detach from the computation graph
    real_image = real_image.numpy()  # Convert to numpy array
    # generated_image = generated_image.squeeze(0)  # Remove the batch dimension, resulting in (channels, height, width)
    
    # If the image is in the format (C, H, W), convert it to (H, W, C)
    if real_image.shape[0] == 3:  # Assuming 3 channels for RGB
        real_image = real_image.transpose(1, 2, 0)  # Reorder dimensions to (H, W, C)
    
    # Normalize the image to [0, 1] if it's not already
    real_image = (real_image + 1) / 2  # Assuming that the output is in the range [-1, 1]
    real_image = real_image.clip(0, 1)  # Ensure the values are within [0, 1]
    
    plt.imshow(real_image)
    break
