In [12]:
import glob
import io
import itertools
import os
import zipfile

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image


In [13]:
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions with a skip connection.

    Each convolution keeps the channel count at ``dim`` and is followed by
    instance normalization; a ReLU sits between the two stages only. The
    block's input is added to the result, so the output shape equals the
    input shape.

    Args:
        dim: number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        stages = []
        for needs_activation in (True, False):
            # Reflection padding of 1 offsets the border lost to the 3x3 kernel.
            stages.append(nn.ReflectionPad2d(1))
            stages.append(nn.Conv2d(dim, dim, kernel_size=3))
            stages.append(nn.InstanceNorm2d(dim))
            if needs_activation:
                stages.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the transformed ``x``."""
        residual = self.block(x)
        return x + residual

# Generator: ResNet style
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: a 7x7 stem convolution to 64 channels, two stride-2 convolutions
    (64 -> 128 -> 256), ``n_residual_blocks`` residual blocks at 256 channels,
    two stride-2 transposed convolutions (256 -> 128 -> 64), and a 7x7 output
    convolution followed by ``Tanh`` (outputs lie in [-1, 1]).

    Args:
        input_nc: number of channels in the input image.
        output_nc: number of channels in the generated image.
        n_residual_blocks: how many ``ResidualBlock`` modules sit in the middle.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super().__init__()
        channels = 64

        # Stem: reflection padding of 3 offsets the border lost to the 7x7 kernel.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, channels, 7),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
        ]

        # Encoder: each step halves the resolution and doubles the channels.
        for _ in range(2):
            widened = channels * 2
            layers.extend([
                nn.Conv2d(channels, widened, 3, stride=2, padding=1),
                nn.InstanceNorm2d(widened),
                nn.ReLU(inplace=True),
            ])
            channels = widened

        # Bottleneck: residual blocks at the widest channel count.
        layers.extend(ResidualBlock(channels) for _ in range(n_residual_blocks))

        # Decoder: each step doubles the resolution and halves the channels.
        for _ in range(2):
            narrowed = channels // 2
            layers.extend([
                nn.ConvTranspose2d(channels, narrowed, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(narrowed),
                nn.ReLU(inplace=True),
            ])
            channels = narrowed

        # Head: back to image channels; channels is 64 again at this point.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels, output_nc, 7),
            nn.Tanh(),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full generator stack."""
        return self.model(x)

# Discriminator: PatchGAN
class Discriminator(nn.Module):
    """PatchGAN-style discriminator that outputs a map of per-patch scores.

    Five 4x4 convolutions: channels go ``input_nc`` -> 64 -> 128 -> 256 -> 512
    -> 1. The first three use stride 2, the last two stride 1. Instance
    normalization is applied after the three middle convolutions, and
    LeakyReLU(0.2) follows every convolution except the last. There is no
    final activation, so the output is an unbounded score map.

    Args:
        input_nc: number of channels in the input image.
    """

    def __init__(self, input_nc):
        super().__init__()
        widths = (64, 128, 256, 512)
        strides = (2, 2, 1)

        # First layer has no normalization.
        layers = [
            nn.Conv2d(input_nc, widths[0], 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Middle layers: double the channels each time; the last one keeps resolution.
        for c_in, c_out, stride in zip(widths[:-1], widths[1:], strides):
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1))
            layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Single-channel score map.
        layers.append(nn.Conv2d(widths[-1], 1, 4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the patch score map for ``x``."""
        return self.model(x)

In [14]:
class MonetDataset(Dataset):
    """Dataset over the ``*.jpg`` files directly inside one directory.

    Paths are collected once at construction (non-recursive) and sorted.
    Each item is the image opened with PIL and converted to RGB, passed
    through ``transform`` when one is given.

    Args:
        img_dir: directory to scan for ``*.jpg`` files.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, img_dir, transform=None):
        pattern = os.path.join(img_dir, '*.jpg')
        self.img_paths = sorted(glob.glob(pattern))
        self.transform = transform

    def __len__(self):
        """Number of image files found."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, transformed if a transform is set."""
        rgb_image = Image.open(self.img_paths[idx]).convert("RGB")
        if not self.transform:
            return rgb_image
        return self.transform(rgb_image)


In [15]:
# Preprocessing shared by both image domains: resize to 256x256, convert to a
# tensor, then map each RGB channel with (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # One mean/std per RGB channel (same values the single-element tuples broadcast to).
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Dataset root. Set the MONET_DATA_DIR environment variable to run on another
# machine; the default is the original author's local checkout.
DATA_DIR = os.environ.get(
    "MONET_DATA_DIR",
    "/Users/navin/Documents/Anik/GitHub/kaggle-playground/notebooks/I’m Something of a Painter Myself/datasets",
)
# The trailing "" keeps the trailing path separator of the original path strings.
photo_path = os.path.join(DATA_DIR, "photo_jpg", "")
monet_path = os.path.join(DATA_DIR, "monet_jpg", "")

photo_dataset = MonetDataset(photo_path, transform=transform)
monet_dataset = MonetDataset(monet_path, transform=transform)

# Fail early with a readable message; an empty dataset would otherwise make the
# shuffling DataLoader raise a less obvious error.
for _name, _dataset, _path in (
    ("photo", photo_dataset, photo_path),
    ("monet", monet_dataset, monet_path),
):
    if len(_dataset) == 0:
        raise FileNotFoundError(
            f"No .jpg images found for the {_name} dataset in {_path!r}; "
            "set MONET_DATA_DIR to the folder containing photo_jpg/ and monet_jpg/."
        )

photo_loader = DataLoader(photo_dataset, batch_size=1, shuffle=True)
monet_loader = DataLoader(monet_dataset, batch_size=1, shuffle=True)

In [16]:
# Pick the compute device by name: Apple MPS first, then CUDA, then CPU.
# torch.backends.mps.is_available() is the documented availability check;
# getattr guards PyTorch builds that ship without the MPS backend module.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [17]:
device = torch.device(device)
print("Using device:", device)

# Two generators and two discriminators:
#   G maps domain X (batches from photo_loader) to domain Y (batches from monet_loader),
#   F maps Y back to X; D_X / D_Y score images of domain X / Y.
G = Generator(3, 3).to(device)
F = Generator(3, 3).to(device)
D_X = Discriminator(3).to(device)
D_Y = Discriminator(3).to(device)

# Training hyperparameters.
LEARNING_RATE = 0.0002
ADAM_BETAS = (0.5, 0.999)
NUM_EPOCHS = 100
LOG_EVERY = 100  # print losses every this many batches

# Both generators share one optimizer; each discriminator has its own.
optimizer_G = optim.Adam(itertools.chain(G.parameters(), F.parameters()), lr=LEARNING_RATE, betas=ADAM_BETAS)
optimizer_D_X = optim.Adam(D_X.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
optimizer_D_Y = optim.Adam(D_Y.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

criterion_GAN = nn.MSELoss()    # least-squares adversarial loss
criterion_cycle = nn.L1Loss()   # cycle-consistency reconstruction loss

lambda_cycle = 10.0  # weight of the cycle loss relative to the adversarial loss

for epoch in range(1, NUM_EPOCHS + 1):
    # zip() stops at the shorter loader, so an epoch covers that many batches.
    for i, (real_X, real_Y) in enumerate(zip(photo_loader, monet_loader)):
        real_X = real_X.to(device)
        real_Y = real_Y.to(device)

        # Generators: translate each domain and cycle back to the source.
        fake_Y = G(real_X)
        rec_X = F(fake_Y)
        fake_X = F(real_Y)
        rec_Y = G(fake_X)

        pred_fake_Y = D_Y(fake_Y)
        pred_fake_X = D_X(fake_X)

        # Adversarial targets take the discriminator's actual output shape,
        # dtype and device instead of a hard-coded 30x30 map, which only
        # matches 256x256 inputs. Both domains are assumed to share one
        # image size and batch size, as in the data setup above.
        valid = torch.ones_like(pred_fake_Y)
        fake = torch.zeros_like(pred_fake_Y)

        loss_G = criterion_GAN(pred_fake_Y, valid) + criterion_GAN(pred_fake_X, valid)
        loss_cycle = criterion_cycle(rec_X, real_X) + criterion_cycle(rec_Y, real_Y)
        loss_total_G = loss_G + lambda_cycle * loss_cycle

        optimizer_G.zero_grad()
        loss_total_G.backward()
        optimizer_G.step()

        # Discriminators: generated images are detached so these losses do not
        # backpropagate into the generators. zero_grad() comes first to clear
        # the gradients the generator step left on the discriminators.
        optimizer_D_X.zero_grad()
        loss_D_X = criterion_GAN(D_X(real_X), valid) + criterion_GAN(D_X(fake_X.detach()), fake)
        loss_D_X.backward()
        optimizer_D_X.step()

        optimizer_D_Y.zero_grad()
        loss_D_Y = criterion_GAN(D_Y(real_Y), valid) + criterion_GAN(D_Y(fake_Y.detach()), fake)
        loss_D_Y.backward()
        optimizer_D_Y.step()

        if i % LOG_EVERY == 0:
            print(f"Epoch {epoch}, Batch {i}: G_loss: {loss_total_G.item():.4f}, D_X: {loss_D_X.item():.4f}, D_Y: {loss_D_Y.item():.4f}")

Using device: mps
Epoch 1, Batch 0: G_loss: 11.9555, D_X: 1.3480, D_Y: 1.5696
Epoch 1, Batch 100: G_loss: 6.8848, D_X: 0.2377, D_Y: 0.2568
Epoch 1, Batch 200: G_loss: 7.0345, D_X: 0.3929, D_Y: 0.4980
Epoch 2, Batch 0: G_loss: 6.5519, D_X: 0.4686, D_Y: 0.3899
Epoch 2, Batch 100: G_loss: 6.3739, D_X: 0.3846, D_Y: 0.7823
Epoch 2, Batch 200: G_loss: 7.5516, D_X: 0.2855, D_Y: 0.4029
Epoch 3, Batch 0: G_loss: 5.4740, D_X: 0.4405, D_Y: 0.3090
Epoch 3, Batch 100: G_loss: 5.6143, D_X: 0.2801, D_Y: 0.5408
Epoch 3, Batch 200: G_loss: 5.9411, D_X: 0.4042, D_Y: 0.4016
Epoch 4, Batch 0: G_loss: 5.4786, D_X: 0.1599, D_Y: 0.3949
Epoch 4, Batch 100: G_loss: 7.7105, D_X: 0.5595, D_Y: 0.4573
Epoch 4, Batch 200: G_loss: 5.6536, D_X: 0.1406, D_Y: 0.7757
Epoch 5, Batch 0: G_loss: 5.7694, D_X: 0.1026, D_Y: 0.2331
Epoch 5, Batch 100: G_loss: 6.3901, D_X: 0.2870, D_Y: 1.0039
Epoch 5, Batch 200: G_loss: 6.2900, D_X: 0.1968, D_Y: 0.4952
Epoch 6, Batch 0: G_loss: 6.9517, D_X: 0.4389, D_Y: 0.4500
Epoch 6, Batch 10

In [18]:
# Translate photos with G and write them as JPEGs into images.zip.
MAX_IMAGES = 7000  # stop after this many generated images

G.eval()
num_written = 0
with zipfile.ZipFile("images.zip", mode="w") as zipf, torch.no_grad():
    for batch in photo_loader:
        if num_written >= MAX_IMAGES:
            break
        # The generator ends in Tanh, so map its [-1, 1] output to [0, 1] for saving.
        generated = (G(batch.to(device)).cpu() + 1) / 2

        # Write one file per image so a loader with batch_size > 1 is not
        # collapsed into a single grid image.
        for image in generated:
            if num_written >= MAX_IMAGES:
                break
            buffer = io.BytesIO()
            save_image(image, buffer, format='JPEG')
            zipf.writestr(f"{num_written}.jpg", buffer.getvalue())
            num_written += 1