In [None]:
# Import Libraries
import itertools

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Normalize, RandomHorizontalFlip, Resize, ToTensor

# Set Device: prefer CUDA, then Apple-silicon MPS, otherwise fall back to CPU.
# `torch.has_mps` is deprecated; `torch.backends.mps.is_available()` is the supported
# check. The getattr guard keeps this working on torch builds without an mps backend.
if torch.cuda.is_available():
    device = 'cuda'
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# Seed the global torch RNG so weight init, random_split and DataLoader shuffling repeat.
# Trailing ';' suppresses the notebook echo of the returned torch.Generator object.
torch.manual_seed(237237);

In [None]:
# Get the Data
# Both domains are scaled to [-1, 1] (mean = std = 0.5 per channel). This matches the
# range of the generator's Tanh output, so the discriminator cannot separate real from
# fake by value range alone. `denormalize` inverts exactly this mapping for display.
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)

transform = Compose([
    Resize((256, 256)),
    ToTensor(),
    Normalize(NORM_MEAN, NORM_STD),
])
painting_transform = Compose([
    Resize((256, 256)),
    RandomHorizontalFlip(),  # augmentation applied to the paintings only
    ToTensor(),
    Normalize(NORM_MEAN, NORM_STD),
])
photo_dataset = ImageFolder(r'../data/photo_jpg', transform=transform)
monet_dataset = ImageFolder(r'../data/monet_jpg', transform=painting_transform)

# 80/20 train/validation split of the photos (uses the seeded global RNG).
train_size = int(0.8 * len(photo_dataset))
val_size = len(photo_dataset) - train_size
photo_train, photo_val = random_split(photo_dataset, [train_size, val_size])

batch_size = 16
photo_loader_train = DataLoader(photo_train, batch_size=batch_size, shuffle=True)
photo_loader_val = DataLoader(photo_val, batch_size=batch_size, shuffle=False)
monet_loader = DataLoader(monet_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Display Monet Images and Input Photos
def show_images(loader, title, num_images=5):
    """Plot up to `num_images` images from the first batch of `loader` in one row.

    Accepts image tensors in [0, 1] (plain ToTensor) or in [-1, 1] (after
    Normalize with mean = std = 0.5); the latter are mapped back to [0, 1] so that
    imshow does not clip them. An empty loader plots nothing.

    Args:
        loader: iterable of (images, labels) batches; images assumed (B, C, H, W).
        title: figure title.
        num_images: number of panels in the row (batches smaller than this
            simply leave trailing panels blank).
    """
    first_batch = next(iter(loader), None)
    if first_batch is None:
        return
    images, _ = first_batch

    # Plain ToTensor output is never negative, so negative values mean the batch was
    # normalized to [-1, 1]; undo that for display.
    if images.min() < 0:
        images = images * 0.5 + 0.5
    images = images.clamp(0, 1).numpy().transpose((0, 2, 3, 1))  # -> B x H x W x C

    fig, axes = plt.subplots(1, num_images, figsize=(15, 3), squeeze=False)
    for ax in axes[0]:
        ax.axis('off')
    # zip stops at the shorter sequence, so small batches no longer raise IndexError.
    for ax, image in zip(axes[0], images):
        ax.imshow(image)
    fig.suptitle(title)
    plt.show()

# Preview one batch from each domain: held-out photos, then Monet paintings.
for preview_loader, preview_title in (
    (photo_loader_val, 'Input Images'),
    (monet_loader, 'Monet Paintings'),
):
    show_images(preview_loader, preview_title)

In [None]:
# Function to denormalize image for display
def denormalize(image):
    """Undo a mean=0.5 / std=0.5 normalization for display.

    Maps values from [-1, 1] back to [0, 1] via x * 0.5 + 0.5, then clamps the
    result to [0, 1] so it is safe to pass to imshow.
    """
    return image.mul(0.5).add(0.5).clamp(0, 1)

In [None]:
# Function to display images
def show_transformed_images(generator, photo_loader, device, denormalize_func, num_images=5):
    """Plot original photos next to the generator's translation of them.

    Shows the first image of each of the first `num_images` batches from
    `photo_loader`. The generator runs in eval mode without gradients. Its previous
    train/eval mode is restored afterwards, so calling this in the middle of training
    does not leave BatchNorm layers frozen on their running statistics.

    Args:
        generator: module assumed to map a (B, 3, H, W) batch to a batch of the same shape.
        photo_loader: iterable of (images, labels) batches.
        device: device the generator's parameters live on.
        denormalize_func: maps a CxHxW tensor to [0, 1] for display.
        num_images: number of batches (one image each) to display; defaults to 5.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            for batch_idx, (photos, _) in enumerate(photo_loader):
                if batch_idx >= num_images:
                    break

                # Only the first image is shown. In eval mode BatchNorm uses running
                # stats, so translating just that image (kept 4D) gives the same output.
                original = photos[:1].to(device)
                transformed = generator(original).cpu()

                plt.figure(figsize=(10, 4))
                plt.subplot(1, 2, 1)
                plt.axis("off")
                plt.title("Original Image")
                plt.imshow(denormalize_func(original[0].cpu()).permute(1, 2, 0).numpy())

                plt.subplot(1, 2, 2)
                plt.axis("off")
                plt.title("Transformed Image")
                plt.imshow(denormalize_func(transformed[0]).permute(1, 2, 0).numpy())

                plt.show()
    finally:
        generator.train(was_training)

In [None]:
# Define the Generator
class Generator(nn.Module):
    """Encoder-decoder generator: (N, 3, H, W) -> (N, 3, H, W) with values in [-1, 1].

    Encoder: four stride-2 4x4 convolutions (3 -> 64 -> 128 -> 256 -> 512 channels),
    each halving the spatial size (256 -> 16 for a 256x256 input). BatchNorm follows
    every encoder conv except the first.
    Decoder: four stride-2 4x4 transposed convolutions (512 -> 256 -> 128 -> 64 -> 3)
    that double the spatial size back to the input size, with BatchNorm + ReLU
    between them and a final Tanh.

    Layers are appended in a fixed order so `self.model` indices (and therefore
    state_dict keys) match the original flat nn.Sequential definition.
    """

    def __init__(self):
        super().__init__()

        # First encoder stage: no BatchNorm.
        layers = [nn.Conv2d(3, 64, 4, 2, 1, bias=False), nn.ReLU(inplace=True)]

        # Remaining encoder stages: 64x128x128 -> 128x64x64 -> 256x32x32 -> 512x16x16.
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])

        # Decoder stages: 512x16x16 -> 256x32x32 -> 128x64x64 -> 64x128x128.
        for in_ch, out_ch in ((512, 256), (256, 128), (128, 64)):
            layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ])

        # Output stage: 64x128x128 -> 3x256x256, squashed to [-1, 1].
        layers.extend([nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False), nn.Tanh()])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
# Define the Discriminator
class Discriminator(nn.Module):
    """Convolutional discriminator: (N, 3, 256, 256) -> (N, 1, 1, 1) probabilities.

    Spatial path for a 256x256 input:
    256 -conv-> 128 -pool-> 64 -conv-> 32 -pool-> 16 -conv-> 8 -conv-> 4 -conv-> 2,
    then a 2x2 convolution with no padding collapses the 1024x2x2 map to a single
    logit, which Sigmoid turns into a probability. The final 2x2 kernel only yields
    1x1 for a 256x256 input.

    Layers are appended in a fixed order so `self.model` indices (and therefore
    state_dict keys) match the original flat nn.Sequential definition.
    """

    def __init__(self):
        super().__init__()

        # Two conv + max-pool stages: 3x256x256 -> 64x64x64 -> 128x16x16.
        layers = [
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2, 2),
        ]

        # Strided conv stages: 128x16x16 -> 256x8x8 -> 512x4x4 -> 1024x2x2.
        for in_ch, out_ch in ((128, 256), (256, 512), (512, 1024)):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        # Head: 1024x2x2 -> 1x1x1 probability.
        layers.extend([nn.Conv2d(1024, 1, 2, 1, 0, bias=False), nn.Sigmoid()])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
# Build both networks (generator first, so the seeded weight-init order is unchanged),
# then move their parameters to the selected device. nn.Module.to() works in place.
generator = Generator()
discriminator = Discriminator()
for net in (generator, discriminator):
    net.to(device)

In [None]:
# Define Loss Functions and Optimizers
# Binary cross-entropy on the discriminator's Sigmoid output. Both networks use Adam
# with lr=2e-4 and beta1=0.5 (the values popularized by DCGAN).
criterion = nn.BCELoss()
adam_settings = dict(lr=0.0002, betas=(0.5, 0.999))
optimizer_G = optim.Adam(generator.parameters(), **adam_settings)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_settings)

In [None]:
# Training Loop
def _infinite_batches(loader):
    """Yield batches from `loader` forever, starting a fresh pass each time it is exhausted.

    Unlike itertools.cycle, this does not cache the first pass. Each pass therefore gets
    a new shuffle and new random augmentation, and old batches are not kept in memory.

    Raises:
        ValueError: if a full pass over `loader` yields no batches (avoids spinning forever).
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError('loader produced no batches')


num_epochs = 100
monet_batches = _infinite_batches(monet_loader)

for epoch in range(num_epochs):
    # Validation and show_transformed_images switch models to eval mode; BatchNorm must
    # use batch statistics again while training.
    generator.train()
    discriminator.train()

    for i, (photos, _) in enumerate(photo_loader_train):
        monets, _ = next(monet_batches)
        photos = photos.to(device)
        monets = monets.to(device)

        # Labels are sized per tensor: the last photo batch and a Monet batch may differ.
        real_label = torch.ones(monets.size(0), device=device)
        fake_label = torch.zeros(photos.size(0), device=device)

        # === Discriminator Training ===
        optimizer_D.zero_grad()

        # "Real" samples are the Monet paintings: the target style the generator imitates.
        output = discriminator(monets).view(-1)
        lossD_real = criterion(output, real_label)
        lossD_real.backward()

        fake_monets = generator(photos)
        output = discriminator(fake_monets.detach()).view(-1)
        lossD_fake = criterion(output, fake_label)
        lossD_fake.backward()

        lossD = lossD_real + lossD_fake
        optimizer_D.step()

        # === Generator Training ===
        # The generator is rewarded when D labels its translated photos as real.
        optimizer_G.zero_grad()
        output = discriminator(fake_monets).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        lossG.backward()
        optimizer_G.step()

        if (i + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(photo_loader_train)}], D Loss: {lossD.item():.4f}, G Loss: {lossG.item():.4f}')

    # Validation: discriminator BCE on Monet paintings (real) vs. translated held-out
    # photos (fake). Eval mode keeps BatchNorm running stats untouched by validation data.
    generator.eval()
    discriminator.eval()
    val_loss_sum = 0.0
    val_count = 0
    with torch.no_grad():
        for photos, _ in photo_loader_val:
            monets, _ = next(monet_batches)
            photos = photos.to(device)
            monets = monets.to(device)
            fake_monets = generator(photos)

            combined_images = torch.cat([monets, fake_monets], dim=0)
            combined_labels = torch.cat([
                torch.ones(monets.size(0), device=device),
                torch.zeros(fake_monets.size(0), device=device),
            ])
            output = discriminator(combined_images).view(-1)
            loss_val = criterion(output, combined_labels)

            # Weight by sample count so the report is a per-sample mean over the whole split.
            val_loss_sum += loss_val.item() * combined_labels.numel()
            val_count += combined_labels.numel()

    val_loss = val_loss_sum / val_count if val_count else float('nan')
    print(f"Epoch {epoch+1}/{num_epochs} - Validation Error: {val_loss:.4f}")

    if (epoch + 1) % 10 == 0 or epoch < 10:
        show_transformed_images(generator, photo_loader_val, device, denormalize)

In [None]:
# Display Original and Transformed Images
# Side-by-side view of the first image from each of the first 5 validation batches.
# The image must keep its batch dimension when passed to the generator: its
# BatchNorm2d layers only accept 4D (N, C, H, W) input, so indexing a single CxHxW
# image out of the batch first would fail. show_transformed_images handles this and
# runs the generator in eval mode under no_grad.
show_transformed_images(generator, photo_loader_val, device, denormalize)

In [None]:
# Save the model weights. Only the state_dicts are stored, so the Generator and
# Discriminator classes must be rebuilt before loading them back.
checkpoints = {
    'generator.ckpt': generator,
    'discriminator.ckpt': discriminator,
}
for ckpt_path, net in checkpoints.items():
    torch.save(net.state_dict(), ckpt_path)