In [6]:
# Import Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, ToTensor
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from PIL import Image

# Set Device
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.has_mps:
    device = 'mps'
torch.manual_seed(237237)

<torch._C.Generator at 0x1b7b80a5f10>

In [7]:
# Get the Data
transform = Compose([Resize((256, 256)), ToTensor()])
photo_dataset = ImageFolder(r'../data', transform=transform)
monet_dataset = ImageFolder(r'../data', transform=transform)

batch_size = 32
photo_loader = DataLoader(photo_dataset, batch_size=batch_size, shuffle=True)
monet_loader = DataLoader(monet_dataset, batch_size=batch_size, shuffle=True)

In [8]:
# Define the Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            # Input is a latent vector Z
            nn.ConvTranspose2d(100, 1024, 4, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            # State size: (1024 x 4 x 4)
            nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # State size: (512 x 8 x 8)
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # State size: (256 x 16 x 16)
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # State size: (128 x 32 x 32)
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # State size: (64 x 64 x 64)
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
            # Output state size: (3 x 128 x 128)
        )
    
    def forward(self, x):
        # Implement the forward pass
        return self.model(x)

In [28]:
# Define the Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            # Input size: (3 x 256 x 256)
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: (64 x 128 x 128)
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: (128 x 64 x 64)
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: (256 x 32 x 32)
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: (512 x 16 x 16)
            nn.Conv2d(512, 1024, 4, 2, 1, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            # State size: (1024 x 8 x 8)
            nn.Conv2d(1024, 1, 4, 1, 0, bias=False),
            nn.Flatten(),
            nn.Sigmoid()
            # Output state size: (1 x 1 x 1)
        )
    
    def forward(self, x):
        return self.model(x)

In [29]:
# Initialize the models
generator = Generator().to(device)
discriminator = Discriminator().to(device)

In [30]:
# Define Loss Functions and Optimizers
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [32]:
# Training Loop
num_epochs = 100
for epoch in range(num_epochs):
    for i, (photo_data, monet_data) in enumerate(zip(photo_loader, monet_loader)):
        # get size of current batch
        current_batch_size = photo_data[0].size(0)
        
        # Get images and move them to device
        photos = photo_data[0].to(device)
        monets = monet_data[0].to(device)
        
        # Correctly sized labels
        real_label = torch.ones(current_batch_size, 1, device=device)
        fake_label = torch.zeros(current_batch_size, 1, device=device)
        
        # === Discriminator Training ===
        optimizer_D.zero_grad()
        
        # Train with real images
        output = discriminator(photos).view(-1)
        print("Discriminator Output Shape:", output.shape)
        print("label shape", real_label.shape)
        lossD_real = criterion(output, real_label.view(-1))
        lossD_real.backward()
        
        # Train with fake images
        fake_monets = generator(photos)
        output = discriminator(fake_monets.detach()).view(-1)
        lossD_fake = criterion(output, fake_label.view(-1))
        lossD_fake.backward()
        
        # Update Discriminator
        optimizer_D.step()
        
        # === Train Generator ===
        optimizer_G.zero_grad()
        output = discriminator(fake_monets).view(-1)
        lossG = criterion(output, real_label)
        lossG.backward()
        
        # Update Generator
        optimizer_G.step()
        
        # Display Progress
        if (i + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(photo_loader)}]')

Discriminator Output Shape: torch.Size([800])
label shape torch.Size([32, 1])


ValueError: Using a target size (torch.Size([32])) that is different to the input size (torch.Size([800])) is deprecated. Please ensure they have the same size.

In [None]:
# Function to denormalize image for display
def denormalize(image):
    image = image * 0.5 + 0.5  # Assuming images were normalized in range [-1, 1]
    return image.clamp(0, 1)

In [None]:
# Display Original and Transformed Images
generator.eval()

# Display first 5 images
for i, (photo_data, _) in enumerate(photo_loader):
    if i >= 5:  # Only display first 5 images
        break
    
    # Original image
    original_image = photo_data[0]

    # Generate transformed image
    with torch.no_grad():
        transformed_image = generator(original_image.to(device)).cpu()

    # Plotting
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("Original Image")
    plt.imshow(denormalize(original_image[0]).permute(1, 2, 0))  # Convert from CxHxW to HxWxC

    plt.subplot(1, 2, 2)
    plt.axis("off")
    plt.title("Transformed Image")
    plt.imshow(denormalize(transformed_image[0]).permute(1, 2, 0))  # Convert from CxHxW to HxWxC

    plt.show()

In [None]:
# Save the model
torch.save(generator.state_dict(), 'generator.ckpt')
torch.save(discriminator.state_dict(), 'discriminator.ckpt')