In [1]:
import torch
from torchvision import datasets, transforms
import random

# Load dataset

In [63]:
# Path to your dataset
dataset_path = r'C:\Users\dzmit\Downloads\cat_faces\Cat-faces-dataset-master'

In [64]:
transform = transforms.Compose([
    # transforms.Resize((64, 64)),  # Resize all images to 64x64
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


# Create the dataset
# dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)

In [65]:
# DataLoader
batch_size = 128  # Batch size

In [66]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [8]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, latent_size, ngf, output_channels_size):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(latent_size, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, output_channels_size, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (output_channels_size) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)


In [9]:
class Discriminator(nn.Module):
    def __init__(self, output_channels_size, ndf):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # input is (output_channels_size) x 64 x 64
            nn.Conv2d(output_channels_size, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)


# Latent vector

In [3]:
def create_noise(size):
    # Generate random noise directly with specified mean and std
    center = 0
    std = 0.2 
    
    return torch.normal(center, std, size=(size, latent_size, 1, 1)).to(device)

In [13]:
import torch
import numpy as np

def sample_spherical(npoints, ndim=3):
    vec = np.random.randn(ndim, npoints)
    vec /= np.linalg.norm(vec, axis=0)
    return vec

def interpolate_along_great_circle(latent_dim, num_steps, device='cpu'):
    """
    Generates a series of latent vectors interpolated along a great circle in the latent space.
    
    Parameters:
        latent_dim (int): Dimension of the latent space.
        num_steps (int): Number of interpolation steps along the great circle.
        device (str): Device to which the latent vectors will be sent ('cpu' or 'cuda').
    
    Returns:
        torch.Tensor: Interpolated latent vectors shaped (num_steps, latent_dim, 1, 1).
    """
    # Generate two points on the unit sphere in latent space
    points = torch.tensor(sample_spherical(2, latent_dim), dtype=torch.float, device=device).t()
    start, end = points[0], points[1]
    
    # Compute the angle between them
    dot = torch.dot(start, end)
    theta = torch.acos(dot)
    
    # Generate the steps
    steps = torch.linspace(0, 1, num_steps, device=device)
    sin_t = torch.sin(theta)
    
    # Perform the interpolation
    latent_vectors = []
    for step in steps:
        alpha = torch.sin((1 - step) * theta) / sin_t
        beta = torch.sin(step * theta) / sin_t
        interpolated_point = alpha * start + beta * end
        latent_vectors.append(interpolated_point.unsqueeze(0).unsqueeze(-1).unsqueeze(-1))
    
    # Concatenate all interpolated points
    return torch.cat(latent_vectors, dim=0)


# Device

In [4]:
# Check if CUDA is available, else use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

Using device: cuda


In [5]:
torch.cuda.is_available()

True

# Load the model

In [6]:
latent_size = 100
ngf = 64
output_channels_size = 3
ndf = 64

In [67]:
generator = Generator(latent_size, ngf, output_channels_size).to(device)
discriminator = Discriminator(output_channels_size, ndf).to(device)

In [11]:
discriminator(torch.randn(1, 3, 64, 64, device=device))

tensor([[[[0.6423]]]], device='cuda:0', grad_fn=<SigmoidBackward0>)

In [28]:
create_noise(3).shape

torch.Size([3, 100, 1, 1])

In [27]:
generator(create_noise(3)).shape

torch.Size([3, 3, 64, 64])

In [29]:
# Example usage:
num_steps = 10  # Number of interpolation steps
latent_vectors = interpolate_along_great_circle(latent_size, 3, device)

In [30]:
latent_vectors.shape

torch.Size([3, 100, 1, 1])

In [31]:
generator(latent_vectors).shape

torch.Size([3, 3, 64, 64])

# Optimizers

In [68]:
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Loss function
criterion = nn.BCELoss()

# Load the checkpoint

In [14]:
# checkpoint = torch.load('GAN_cat_299.pth')
# 
# # Assuming the generator and discriminator are already instantiated as per the saved model architecture
# generator.load_state_dict(checkpoint['generator_state_dict'])
# discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
# 
# # Assuming the optimizers are already instantiated with the parameters of their respective models
# optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
# optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
# 
# # If you saved the epoch number, you can also load this to know where to resume training
# epoch = checkpoint['epoch']


In [15]:
# for param_group in optimizer_G.param_groups:
#     param_group['lr'] *= 0.1
# 
# for param_group in optimizer_D.param_groups:
#     param_group['lr'] *= 0.1

# Plotting functions

In [32]:
import matplotlib.pyplot as plt
import os

def save_images(fake_images, epoch, prefix='gan', folder='generated_images', n_images=25):
    """
    Saves a grid of generated images to a file.

    Parameters:
    - fake_images: Tensor of images generated by the GAN.
    - epoch: Current epoch number, used for naming the output file.
    - prefix: Prefix string for the filename.
    - folder: Output directory for saving the images.
    - n_images: Number of images to save. Default is 25.
    """
    if not os.path.exists(folder):
        os.makedirs(folder)

    # Select the first n_images from the batch
    images_to_save = fake_images[:n_images]
    
    fig, axes = plt.subplots(5, 5, figsize=(10, 10))  # Setting up a 5x5 grid
    axes = axes.flatten()
    
    for idx, img in enumerate(images_to_save):
        img = img.to('cpu').detach().numpy()  # Convert tensor to numpy array
        # if img.shape[0] == 3:  # If there are 3 channels (RGB)
        img = img.transpose(1, 2, 0)  # Change from CxHxW to HxWxC
        # else:
        #     img = img.squeeze(0)  # If grayscale, remove channel dimension
        
        # Normalize image to [0, 1]
        img = (img + 1) / 2
        img = img.clip(0, 1)  # Ensure pixel values are within the [0, 1] range
        
        axes[idx].imshow(img, cmap='gray')
        axes[idx].axis('off')  # Hide axes to enhance visual appeal

    plt.tight_layout()
    # Construct the filename using the prefix and epoch
    filename = f"{prefix}_epoch_{epoch}.png"
    plt.savefig(os.path.join(folder, filename))
    plt.close(fig)  # Close the plot to free memory

# Example usage (assuming `fake_images` is your batch of generated images and `epoch` is your current epoch):
# save_images(fake_images, epoch, prefix='myGAN', folder='my_images')


In [17]:
# generate_and_plot_images(n_images=25, epoch=200)

# Clear cache

In [62]:
# import gc
# 
# torch.cuda.empty_cache()  # Clear cache
# gc.collect()  # Collect garbage
# generator.to('cpu')
# discriminator.to('cpu')
# 
# del generator, discriminator, optimizer_G, optimizer_D

In [69]:
num_epochs = 300

d_losses = []
g_losses = []

In [70]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Assuming 'generator' and 'discriminator' are your models
# 'optimizer_G' and 'optimizer_D' are the respective optimizers

for epoch in range(num_epochs):
    try:
        for batch_number, (real_images, _) in enumerate(dataloader):
    
            real_images = real_images.to(device)
            
            batch_size = real_images.size(0)
            # Add tiny random noise around 0.9 for the real labels
            noise_epsilon = 0.01  # Standard deviation of noise
            real_labels = torch.full((batch_size,), 0.9, device=device) + torch.randn(batch_size, device=device) * noise_epsilon
            fake_labels = torch.zeros(batch_size, device=device) + torch.randn(batch_size, device=device) * noise_epsilon
    
            # Train Discriminator
            optimizer_D.zero_grad()
            outputs_real = discriminator(real_images)
            loss_real = -torch.mean(torch.log(outputs_real + 1e-8))
            loss_real.backward()
    
            # noise = torch.randn(batch_size, latent_size, 1, 1, device=device)
            latent_vector = interpolate_along_great_circle(latent_size, batch_size, device)

            fake_images = generator(latent_vector)
            
            outputs_fake = discriminator(fake_images.detach())
            loss_fake = -torch.mean(torch.log(1 - outputs_fake + 1e-8))
            loss_fake.backward()
            optimizer_D.step()
    
            # Train Generator with possible label flipping
            optimizer_G.zero_grad()
            # Randomly decide whether to flip labels
            flip = np.random.rand() < 0.1  # 10% chance to flip labels
            if flip:
                # Train generator to produce 'fake' labeled as 'real', but use 'fake' label
                gen_labels = fake_labels
            else:
                # Normal training, train generator to produce 'fake' labeled as 'real'
                gen_labels = real_labels
    
            outputs_fake_for_gen = discriminator(fake_images).squeeze()
            # print(outputs_fake_for_gen.shape, gen_labels.shape)
            loss_G = criterion(outputs_fake_for_gen, gen_labels)
            loss_G.backward()
            optimizer_G.step()

            # Logging and validation here (if applicable)
            if (batch_number + 1) % 20 == 0:
                d_losses.append(loss_real.item() + loss_fake.item())
                g_losses.append(loss_G.item())
                print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{batch_number + 1}/{len(dataloader)}], '
                      f'D Loss: {loss_real.item() + loss_fake.item():.4f}, '
                      f'G Loss: {loss_G.item():.4f}'
                      f'; D Loss Real: {loss_real.item():.4f}, '
                      f'; D Loss Fake: {loss_fake.item():.4f}, '
                      )
    
        # if (epoch + 1) % 1 == 0:
            # check_output(fake_images[0], epoch)
        save_images(fake_images, epoch, prefix='gan', folder='generated_images', n_images=25)
            
        if (epoch + 1) % 20 == 0:
            checkpoint = {
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'epoch': epoch  # Optional, if you want to also save the epoch number
            }
            
            torch.save(checkpoint, f'dcgan_cat_{epoch}.pth')
            
    except OSError:
        print(f"An error occurred while processing the image. Epoch: {epoch}, batch: {batch_number}")
        continue

Epoch [1/300], Step [20/234], D Loss: 0.1642, G Loss: 5.8180; D Loss Real: 0.1574, ; D Loss Fake: 0.0068, 
Epoch [1/300], Step [40/234], D Loss: 0.0745, G Loss: 7.7050; D Loss Real: 0.0707, ; D Loss Fake: 0.0038, 
Epoch [1/300], Step [60/234], D Loss: 0.0674, G Loss: 11.1617; D Loss Real: 0.0674, ; D Loss Fake: 0.0000, 
Epoch [1/300], Step [80/234], D Loss: 2.2898, G Loss: -0.0137; D Loss Real: 0.0078, ; D Loss Fake: 2.2820, 
Epoch [1/300], Step [100/234], D Loss: 0.0629, G Loss: 11.0330; D Loss Real: 0.0629, ; D Loss Fake: 0.0000, 
Epoch [1/300], Step [120/234], D Loss: 0.4125, G Loss: 15.0104; D Loss Real: 0.4125, ; D Loss Fake: 0.0000, 
Epoch [1/300], Step [140/234], D Loss: 0.0765, G Loss: 13.3617; D Loss Real: 0.0765, ; D Loss Fake: 0.0000, 
Epoch [1/300], Step [160/234], D Loss: 0.0086, G Loss: 9.9969; D Loss Real: 0.0085, ; D Loss Fake: 0.0000, 
Epoch [1/300], Step [180/234], D Loss: 0.0519, G Loss: 5.0293; D Loss Real: 0.0266, ; D Loss Fake: 0.0253, 
Epoch [1/300], Step [200/23

In [55]:
fake_images.shape

torch.Size([13, 3, 64, 64])

In [56]:
real_images.shape

torch.Size([13, 3, 64, 64])

In [None]:
import matplotlib.pyplot as plt

def generate_and_plot_images(n_images=9, epoch=0, plot=True):
    """
    Generates and plots a grid of images using a trained generator model.

    Parameters:
    - generator: The trained generator model for generating images.
    - device: The device (e.g., 'cuda' or 'cpu') the model should run on.
    - n_images: The total number of images to generate and plot. Default is 9.
    """
    fig, axes = plt.subplots(5, 5, figsize=(9, 9))  # Create a 3x3 grid of subplots
    axes = axes.flatten()  # Flatten the 2D array of axes for easier iteration

    for i in range(n_images):
        # Generate random noise
        noise = torch.randn(1, 256, 1, 1, device=device) / 100

        # Generate an image without updating gradients
        with torch.no_grad():
            generated_image = generator(noise)

        # Process the image for visualization
        generated_image = generated_image.to('cpu').clone().detach()
        generated_image = generated_image.numpy().squeeze(0)

        if generated_image.shape[0] == 3:  # Check if the image has 3 channels (RGB)
            generated_image = generated_image.transpose(1, 2, 0)  # Convert from CxHxW to HxWxC
            
        elif generated_image.shape[0] == 1:  # Check if the image has 3 channels (RGB)
            generated_image = generated_image.squeeze(0)  # Convert from CxHxW to HxWxC

        # Normalize the image data to [0, 1]
        generated_image = (generated_image + 1) / 2
        generated_image = generated_image.clip(0, 1)  # Ensure pixel values are within the expected range

        axes[i].imshow(generated_image, cmap='gray')
        axes[i].axis('off')  # Turn off the axis to make the images look cleaner

    plt.tight_layout()
    plt.savefig(f'output/generated_images_grid_{epoch}.png')

    if plot:
        plt.show()


In [20]:
for epoch in range(num_epochs):
    try:
        for i, (real_images, _) in enumerate(dataloader):
            real_images = real_images.to(device)
            
            # Labels for your batches
            real_labels = torch.full((batch_size,), 0.9, device=device)  # Real labels smoothed to 0.9
            # real_labels = torch.ones(real_images.size(0), device=device)
            fake_labels = torch.zeros(real_images.size(0), device=device)
        
            ### Train Discriminator
            optimizer_D.zero_grad()
            real_outputs = discriminator(real_images)
            d_loss_real = criterion(real_outputs, real_labels)
            d_loss_real.backward()
        
            noise = torch.randn(real_images.size(0), latent_size, 1, 1, device=device)
            fake_images = generator(noise)
            fake_outputs = discriminator(fake_images.detach())
            d_loss_fake = criterion(fake_outputs, fake_labels)
            d_loss_fake.backward()
            optimizer_D.step()
        
            ### Train Generator
            optimizer_G.zero_grad()
            # Optionally regenerate fake images for freshness
            noise = torch.randn(real_images.size(0), latent_size, 1, 1, device=device)
            fake_images = generator(noise)
            output = discriminator(fake_images)
        
            # Randomly decide whether to flip labels
            if random.random() < flip_prob:
                g_loss = criterion(output, fake_labels)  # Flipped labels
            else:
                g_loss = criterion(output, real_labels)  # Normal training
        
            g_loss.backward()
            optimizer_G.step()

    
            if (i + 1) % 20 == 0:
                d_losses.append(d_loss_real.item() + d_loss_fake.item())
                g_losses.append(g_loss.item())
                print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], '
                      f'D Loss: {d_loss_real.item() + d_loss_fake.item():.4f}, '
                      f'G Loss: {g_loss.item():.4f}'
                      f'; D Loss Real: {d_loss_real.item():.4f}, '
                      f'; D Loss Fake: {d_loss_fake.item():.4f}, '
                      )
                
        if (epoch + 1) % 1 == 0:
            # check_output(fake_images[0], epoch)
            generate_and_plot_images(25, epoch=epoch, plot=False)
            
        if (epoch + 1) % 30 == 0:
            checkpoint = {
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'epoch': epoch  # Optional, if you want to also save the epoch number
            }
            
            torch.save(checkpoint, f'GAN_cat_{epoch}.pth')

    except OSError:
        print(f"An error occurred while processing the image. Epoch: {epoch}, batch: {i}")
        continue


Epoch [1/300], Step [20/313], D Loss: 2.7489, G Loss: 5.7095; D Loss Real: 2.7414, ; D Loss Fake: 0.0075, 



KeyboardInterrupt



In [42]:
# Assuming 'generator' and 'discriminator' are your model instances
# And 'optimizer_G' and 'optimizer_D' are the optimizers for the generator and discriminator respectively

# Define checkpoint dictionary
checkpoint = {
    'generator_state_dict': generator.state_dict(),
    'discriminator_state_dict': discriminator.state_dict(),
    'optimizer_G_state_dict': optimizer_G.state_dict(),
    'optimizer_D_state_dict': optimizer_D.state_dict(),
    'epoch': epoch  # Optional, if you want to also save the epoch number
}

# Save checkpoint
torch.save(checkpoint, f'GAN_checkpoint_main_128_{epoch}.pth')
