In [1]:
import torch

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# GPU details if available
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))
    print(f"CUDA is available: {torch.cuda.is_available()}")

Using device: cuda
Tesla T4
CUDA is available: True


In [2]:
import glob
import os

# Define the paths to the dataset
GCS_PATH = '/kaggle/input/gan-getting-started'
monet_files_path = os.path.join(GCS_PATH, 'monet_tfrec', '*.tfrec')
photo_files_path = os.path.join(GCS_PATH, 'photo_tfrec', '*.tfrec')

# Load filenames using glob
MONET_FILENAMES = glob.glob(monet_files_path)
PHOTO_FILENAMES = glob.glob(photo_files_path)

print('Monet TFRecord Files:', len(MONET_FILENAMES))
print('Photo TFRecord Files:', len(PHOTO_FILENAMES))

Monet TFRecord Files: 5
Photo TFRecord Files: 20


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import itertools
import os
from PIL import Image
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
import torchvision.utils as vutils
from torchvision.utils import make_grid
import cv2
from torchvision.models import vgg16
import zipfile
import gc

In [4]:
# Constants
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 4
IMAGE_SIZE = 256
LEARNING_RATE = 0.0002
BETA1 = 0.5
BETA2 = 0.999
EPOCHS = 5
LAMBDA_CYCLE = 10.0
LAMBDA_IDENTITY = 5.0

# Visualization class
class VisualizeResults:
    def __init__(self, save_dir='training_progress4_1'):
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)
    
    @staticmethod
    def denormalize(tensor):
        """Convert normalized image tensor back to regular image tensor"""
        tensor = tensor.clone()
        tensor = tensor * 0.5 + 0.5
        return tensor.clamp(0, 1)
    
    def save_image_grid(self, real_A, fake_B, real_B, fake_A, epoch, batch_idx):
        """Save a grid of real and generated images"""
        # Denormalize images
        real_A = self.denormalize(real_A)
        fake_B = self.denormalize(fake_B)
        real_B = self.denormalize(real_B)
        fake_A = self.denormalize(fake_A)
        
        # Create image grid
        img_grid = make_grid([
            real_A[0], fake_B[0],  # Photo → Monet
            real_B[0], fake_A[0],  # Monet → Photo
        ], nrow=2)
        
        # Convert to numpy and transpose
        img_grid = img_grid.cpu().numpy().transpose((1, 2, 0))
        
        # Create figure
        plt.figure(figsize=(10, 10))
        plt.imshow(img_grid)
        plt.axis('off')
        
        # Add labels
        plt.text(32, 20, 'Real Photo', color='white', fontsize=10)
        plt.text(IMAGE_SIZE + 32, 20, 'Generated Monet', color='white', fontsize=10)
        plt.text(32, IMAGE_SIZE + 20, 'Real Monet', color='white', fontsize=10)
        plt.text(IMAGE_SIZE + 32, IMAGE_SIZE + 20, 'Generated Photo', color='white', fontsize=10)
        
        # Save figure
        plt.savefig(f'{self.save_dir}/epoch_{epoch}batch{batch_idx}.png', 
                   bbox_inches='tight', pad_inches=0.1)
        plt.close()
    
    def plot_losses(self, g_losses, d_losses):
        """Plot generator and discriminator losses"""
        plt.figure(figsize=(10, 5))
        plt.plot(g_losses, label='Generator Loss')
        plt.plot(d_losses, label='Discriminator Loss')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.title('Training Losses')
        plt.legend()
        plt.savefig(f'{self.save_dir}/training_losses.png')
        plt.close()


In [5]:
# Dataset Class
class MonetPhotoDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = glob(os.path.join(root_dir, "*.jpg"))
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        
        # Convert to numpy array
        img_np = np.array(image)
        
        # Apply histogram equalization to each channel
        for i in range(3):
            img_np[:,:,i] = cv2.equalizeHist(img_np[:,:,i])
            
        image = Image.fromarray(img_np)
        
        if self.transform:
            image = self.transform(image)
        
        return image


In [6]:
# Generator Architecture
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3, groups=max(1, channels//4)),  # Grouped convolution for more texture
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),  # Slight dropout for variation
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels)
        )

    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self, input_channels=3, num_residual_blocks=9):
        super(Generator, self).__init__()
        
        # Initial convolution
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_channels, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        ]
        
        # Downsampling
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features * 2
        
        # Residual blocks
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(in_features)]
        
        # Upsampling
        out_features = in_features // 2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features // 2
        
        # Output layer
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, input_channels, 7),
            nn.Tanh()
        ]
        
        self.model = nn.Sequential(*model)
        self.instance_norm = nn.InstanceNorm2d(input_channels, affine=True)
    
    def forward(self, x):
        x = self.instance_norm(x)
        return self.model(x)

# Discriminator Architecture
class Discriminator(nn.Module):
    def __init__(self, input_channels=3):
        super(Discriminator, self).__init__()
        
        def discriminator_block(in_filters, out_filters, normalize=True):
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        self.model = nn.Sequential(
            *discriminator_block(input_channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1)
        )
    
    def forward(self, x):
        return self.model(x)


In [7]:
class StyleLoss:
    def __init__(self):
        vgg = vgg16(pretrained=True).features[:16].to(DEVICE)
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg
    
    def compute_style_loss(self, generated, style):
        """Compute style loss using Gram matrix"""
        def gram_matrix(x):
            b, c, h, w = x.size()
            features = x.view(b, c, h * w)
            G = torch.bmm(features, features.transpose(1, 2))
            return G.div(c * h * w)
        
        # Extract features
        gen_features = self.vgg(generated)
        style_features = self.vgg(style)
        
        # Compute Gram matrix style loss
        style_gram = gram_matrix(style_features)
        gen_gram = gram_matrix(gen_features)
        
        return nn.functional.mse_loss(gen_gram, style_gram)        

# Loss functions
class CycleGANLoss:
    def __init__(self):
        self.mae_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()
        self.style_loss = StyleLoss()
        self.style_weight = 500.0  # Adjust this to control style transfer intensity


    def get_style_transfer_loss(self, generated, style_reference):
        """Compute additional style transfer loss"""
        return self.style_loss.compute_style_loss(generated, style_reference) * self.style_weight    

    def get_color_consistency_loss(self, real, fake):
        real_mean = torch.mean(real, dim=[2, 3])
        fake_mean = torch.mean(fake, dim=[2, 3])
        return self.mae_loss(real_mean, fake_mean)    
    
    def get_gan_loss(self, pred, target_is_real):
        target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
        return self.mse_loss(pred, target)
    
    def get_cycle_loss(self, real, cycled):
        return self.mae_loss(real, cycled) * LAMBDA_CYCLE
    
    def get_identity_loss(self, real, same):
        return self.mae_loss(real, same) * LAMBDA_IDENTITY

# Training class
class CycleGAN:
    def __init__(self):
        # Initialize generators and discriminators
        self.G_AB = Generator().to(DEVICE)
        self.G_BA = Generator().to(DEVICE)
        self.D_A = Discriminator().to(DEVICE)
        self.D_B = Discriminator().to(DEVICE)
        
        # Initialize optimizers
        self.optimizer_G = optim.Adam(
            itertools.chain(self.G_AB.parameters(), self.G_BA.parameters()),
            lr=LEARNING_RATE, betas=(BETA1, BETA2)
        )
        self.optimizer_D = optim.Adam(
            itertools.chain(self.D_A.parameters(), self.D_B.parameters()),
            lr=LEARNING_RATE, betas=(BETA1, BETA2)
        )
        
        # Initialize loss functions
        self.criterion = CycleGANLoss()
        
        # Initialize learning rate schedulers
        self.scheduler_G = optim.lr_scheduler.LinearLR(self.optimizer_G, 
                                                      start_factor=1.0, 
                                                      end_factor=0.0,
                                                      total_iters=EPOCHS)
        self.scheduler_D = optim.lr_scheduler.LinearLR(self.optimizer_D,
                                                      start_factor=1.0,
                                                      end_factor=0.0,
                                                      total_iters=EPOCHS)
    
    def train_step(self, real_A, real_B):
        # Generate fake images
        fake_B = self.G_AB(real_A)
        fake_A = self.G_BA(real_B)
        
        # Reconstruct images
        cycle_A = self.G_BA(fake_B)
        cycle_B = self.G_AB(fake_A)
        
        # Identity mapping
        same_A = self.G_BA(real_A)
        same_B = self.G_AB(real_B)
        
        # Train Generators
        self.optimizer_G.zero_grad()

         # Compute additional style transfer loss
        style_loss_AB = self.criterion.get_style_transfer_loss(fake_B, real_B)
        style_loss_BA = self.criterion.get_style_transfer_loss(fake_A, real_A)

        # Color consistency loss
        loss_color = (self.criterion.get_color_consistency_loss(real_A, fake_A) + 
                     self.criterion.get_color_consistency_loss(real_B, fake_B)) * 5.0
        loss_G = loss_color

        # Identity loss
        loss_identity_A = self.criterion.get_identity_loss(real_A, same_A)
        loss_identity_B = self.criterion.get_identity_loss(real_B, same_B)

        # GAN loss
        loss_GAN_AB = self.criterion.get_gan_loss(self.D_B(fake_B), True)
        loss_GAN_BA = self.criterion.get_gan_loss(self.D_A(fake_A), True)

        # Cycle loss
        loss_cycle_A = self.criterion.get_cycle_loss(real_A, cycle_A)
        loss_cycle_B = self.criterion.get_cycle_loss(real_B, cycle_B)

         # Combined generator loss - add style transfer loss
        loss_G += (loss_identity_A + loss_identity_B + 
                   loss_GAN_AB + loss_GAN_BA + 
                   loss_cycle_A + loss_cycle_B + 
                   style_loss_AB + style_loss_BA)

        loss_G.backward()
        torch.nn.utils.clip_grad_norm_(
            itertools.chain(self.G_AB.parameters(), self.G_BA.parameters()), 
            max_norm=1.0
        )
        self.optimizer_G.step()

        # Train Discriminators
        self.optimizer_D.zero_grad()

        # Discriminator A loss
        loss_D_A_real = self.criterion.get_gan_loss(self.D_A(real_A), True)
        loss_D_A_fake = self.criterion.get_gan_loss(self.D_A(fake_A.detach()), False)
        loss_D_A = (loss_D_A_real + loss_D_A_fake) * 0.5

        # Discriminator B loss
        loss_D_B_real = self.criterion.get_gan_loss(self.D_B(real_B), True)
        loss_D_B_fake = self.criterion.get_gan_loss(self.D_B(fake_B.detach()), False)
        loss_D_B = (loss_D_B_real + loss_D_B_fake) * 0.5

        # Combined discriminator loss
        loss_D = loss_D_A + loss_D_B

        loss_D.backward()
        torch.nn.utils.clip_grad_norm_(
            itertools.chain(self.G_AB.parameters(), self.G_BA.parameters()), 
            max_norm=1.0
        )
        self.optimizer_D.step()

        return {
            'loss_G': loss_G.item(),
            'loss_D': loss_D.item(),
            'fake_A': fake_A,
            'fake_B': fake_B
        }
    
    def save_models(self, epoch):
        torch.save(self.G_AB.state_dict(), f'generator_AB_epoch_{epoch}.pth')
        torch.save(self.G_BA.state_dict(), f'generator_BA_epoch_{epoch}.pth')
        torch.save(self.D_A.state_dict(), f'discriminator_A_epoch_{epoch}.pth')
        torch.save(self.D_B.state_dict(), f'discriminator_B_epoch_{epoch}.pth')


In [8]:
def train():
    # Data preprocessing
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
    ])
    
    # Load datasets
    monet_dataset = MonetPhotoDataset("/kaggle/input/gan-getting-started/monet_jpg", transform=transform)
    photo_dataset = MonetPhotoDataset("/kaggle/input/gan-getting-started/photo_jpg", transform=transform)
    
    monet_loader = DataLoader(monet_dataset, batch_size=BATCH_SIZE, shuffle=True)
    photo_loader = DataLoader(photo_dataset, batch_size=BATCH_SIZE, shuffle=True)
    
    # Initialize model and visualizer
    model = CycleGAN()
    visualizer = VisualizeResults()
    
    # Track losses for plotting
    g_losses = []
    d_losses = []
    
    # Training loop
    for epoch in range(EPOCHS):
        for i, (monet_imgs, photo_imgs) in enumerate(zip(monet_loader, photo_loader)):
            monet_imgs = monet_imgs.to(DEVICE)
            photo_imgs = photo_imgs.to(DEVICE)
            
            # Train step
            results = model.train_step(photo_imgs, monet_imgs)
            
            # Record losses
            g_losses.append(results['loss_G'])
            d_losses.append(results['loss_D'])
            
            # Print progress
            if i % 100 == 0:
                print(f"Epoch [{epoch}/{EPOCHS}] Batch [{i}/{len(monet_loader)}] "
                      f"Loss G: {results['loss_G']:.4f}, Loss D: {results['loss_D']:.4f}")
            
            # Save images for visualization every 100th batch
            if i % 100 == 0:
                visualizer.save_image_grid(photo_imgs, results['fake_B'], 
                                         monet_imgs, results['fake_A'], 
                                         epoch, i)
        
        # Save model checkpoints at regular intervals
        if (epoch + 1) % 5 == 0:
            model.save_models(epoch + 1)
        
        # Plot and save losses
        visualizer.plot_losses(g_losses, d_losses)
    
    print("Training completed.")

def generate_images(num_images=7100):
    # Load trained generator
    generator = Generator().to(DEVICE)
    generator.load_state_dict(torch.load('generator_AB_epoch_5.pth'))  # Load the last saved model
    generator.eval()
    
    # Create a zip file to store the images
    zip_filename = '/kaggle/working/images.zip'
    with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
    
        # Data preprocessing
        transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        
        # Load photo dataset
        photo_dataset = MonetPhotoDataset("/kaggle/input/gan-getting-started/photo_jpg", transform=transform)
        photo_loader = DataLoader(photo_dataset, batch_size=1, shuffle=False)
        
        # Generate images
        with torch.no_grad():
            for i, photo in enumerate(itertools.islice(photo_loader, num_images)):
                photo = photo.to(DEVICE)
                fake_monet = generator(photo)
                
                # Save individual generated image to the zip file
                fake_monet_img = VisualizeResults.denormalize(fake_monet)  # Use as a static method
                fake_monet_img = fake_monet_img.cpu().squeeze(0).permute(1, 2, 0).numpy()
                generated_image = Image.fromarray((fake_monet_img * 255).astype(np.uint8))
                with zipf.open(f'generated_{i}.jpg', 'w') as imgf:
                    generated_image.save(imgf, format='JPEG')

                # Explicitly clear tensors and collect garbage to free memory
                del photo, fake_monet, fake_monet_img
                gc.collect()
                torch.cuda.empty_cache()

                if i % 100 == 0:
                    print(f"Generated {i} images")
                
    print(f"All images have been saved to {zip_filename}")


In [9]:
if __name__ == "__main__":
    print(f"Using device: {DEVICE}")
    
    # Train the model
    print("Starting training...")
    train()
    print("Training completed!")
    
    # Generate images
    print("Generating Monet-style images...")
    generate_images()
    print("Image generation completed!")

Using device: cuda
Starting training...


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 213MB/s]


Epoch [0/5] Batch [0/75] Loss G: 21.9818, Loss D: 1.2837
Epoch [1/5] Batch [0/75] Loss G: 6.2680, Loss D: 0.4404
Epoch [2/5] Batch [0/75] Loss G: 8.2299, Loss D: 0.5152
Epoch [3/5] Batch [0/75] Loss G: 5.4436, Loss D: 0.4353
Epoch [4/5] Batch [0/75] Loss G: 6.2761, Loss D: 0.3710
Training completed.
Training completed!
Generating Monet-style images...


  generator.load_state_dict(torch.load('generator_AB_epoch_5.pth'))  # Load the last saved model


Generated 0 images
Generated 100 images
Generated 200 images
Generated 300 images
Generated 400 images
Generated 500 images
Generated 600 images
Generated 700 images
Generated 800 images
Generated 900 images
Generated 1000 images
Generated 1100 images
Generated 1200 images
Generated 1300 images
Generated 1400 images
Generated 1500 images
Generated 1600 images
Generated 1700 images
Generated 1800 images
Generated 1900 images
Generated 2000 images
Generated 2100 images
Generated 2200 images
Generated 2300 images
Generated 2400 images
Generated 2500 images
Generated 2600 images
Generated 2700 images
Generated 2800 images
Generated 2900 images
Generated 3000 images
Generated 3100 images
Generated 3200 images
Generated 3300 images
Generated 3400 images
Generated 3500 images
Generated 3600 images
Generated 3700 images
Generated 3800 images
Generated 3900 images
Generated 4000 images
Generated 4100 images
Generated 4200 images
Generated 4300 images
Generated 4400 images
Generated 4500 images
