In [1]:
import os
import random
import itertools
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import torchvision


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.utils import save_image
from torchvision import transforms

In [2]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels),
        )
        
    def forward(self, x):
        return x + self.block(x)

In [3]:
class Generator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=9):
        super().__init__()

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, 64, kernel_size=7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        ]

        # Downsampling
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
            out_features *= 2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        out_features = in_features // 2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
            out_features //= 2

        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, out_channels, kernel_size=7),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


In [4]:
# Assuming you have a Generator class defined
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pre-trained generator model
G_AB = Generator().to(device)
G_AB.load_state_dict(torch.load("G_BA_monet.pth", map_location=device))
G_AB.eval()  # Set model to evaluation mode

transform = T.Compose([
    T.Resize(256),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [5]:
input_folder = "datasets/real_imagesv2/photo_jpg"
output_folder = "outputs/monet_stylev2"
os.makedirs(output_folder, exist_ok=True)


In [6]:
for filename in os.listdir(input_folder):
    if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
        img_path = os.path.join(input_folder, filename)
        img = Image.open(img_path).convert('RGB')
        img_tensor = transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            fake_img = G_AB(img_tensor)

        # Denormalize for saving
        fake_img = 0.5 * (fake_img + 1.0)
        save_path = os.path.join(output_folder, filename)
        save_image(fake_img, save_path)
        print(f"Saved stylized image to: {save_path}")

Saved stylized image to: outputs/monet_stylev2\00068bc07f.jpg
Saved stylized image to: outputs/monet_stylev2\000910d219.jpg
Saved stylized image to: outputs/monet_stylev2\000ded5c41.jpg
Saved stylized image to: outputs/monet_stylev2\00104fd531.jpg
Saved stylized image to: outputs/monet_stylev2\001158d595.jpg
Saved stylized image to: outputs/monet_stylev2\0033c5f971.jpg
Saved stylized image to: outputs/monet_stylev2\0039ebb598.jpg
Saved stylized image to: outputs/monet_stylev2\003aab6fdd.jpg
Saved stylized image to: outputs/monet_stylev2\003c6c30e0.jpg
Saved stylized image to: outputs/monet_stylev2\00479e2a21.jpg
Saved stylized image to: outputs/monet_stylev2\005f987f56.jpg
Saved stylized image to: outputs/monet_stylev2\0080f94ebc.jpg
Saved stylized image to: outputs/monet_stylev2\00882b7e1d.jpg
Saved stylized image to: outputs/monet_stylev2\009d534136.jpg
Saved stylized image to: outputs/monet_stylev2\009ddaed1f.jpg
Saved stylized image to: outputs/monet_stylev2\00aeb60e25.jpg
Saved st

In [10]:
from torch.utils.data import Dataset
from PIL import Image
import os

class FlatFolderDataset(Dataset):
    def __init__(self, folder, transform=None):
        self.files = sorted([
            os.path.join(folder, f) for f in os.listdir(folder)
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        ])
        self.transform = transform
    
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        img = Image.open(self.files[idx]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, 0  # dummy label


In [None]:
import torch
from torch.utils.data import DataLoader
from torchmetrics.image.fid import FrechetInceptionDistance
import torchvision.transforms as T

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = T.Compose([
    T.Resize((256, 256)),  # match InceptionNet input
    T.PILToTensor()
])

real_dataset = FlatFolderDataset("datasets/real_imagesv2/photo_jpg", transform)
fake_dataset = FlatFolderDataset("outputs/monet_stylev2", transform)

real_loader = DataLoader(real_dataset, batch_size=16, shuffle=False, num_workers=0)
fake_loader = DataLoader(fake_dataset, batch_size=16, shuffle=False, num_workers=0)

fid = FrechetInceptionDistance(feature=2048).to(device)

# Feed real images
for batch, _ in real_loader:
    fid.update(batch.to(device), real=True)

# Feed fake images
for batch, _ in fake_loader:
    fid.update(batch.to(device), real=False)

# Compute FID score
fid_score = fid.compute().item()
print(f"FID Score: {fid_score:.4f}")


FID Score: 46.7764
