In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import models, transforms, utils
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import glob
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
import csv


In [None]:
# --- 1. Core Components (Encoder and AdaIN) ---
# pre-trained VGG-19 for the encoder.

class VGGEncoder(nn.Module):
    """
    Frozen VGG-19 feature extractor.

    The pre-trained torchvision VGG-19 ``features`` stack is cut into four
    consecutive stages that end at relu1_1, relu2_1, relu3_1 and relu4_1.
    ``forward`` returns the activation produced by each stage; those
    activations are what the content and style losses are computed on.
    """
    def __init__(self):
        super().__init__()
        backbone = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features

        # Consecutive cuts of the feature stack: every stage starts exactly
        # where the previous one stopped.
        self.slice1 = backbone[0:2]    # ... relu1_1
        self.slice2 = backbone[2:7]    # ... relu2_1
        self.slice3 = backbone[7:12]   # ... relu3_1
        self.slice4 = backbone[12:21]  # ... relu4_1

        # The encoder is only a fixed feature extractor, so none of its
        # parameters should ever receive gradients.
        self.requires_grad_(False)

    def forward(self, x):
        """Return the tuple (relu1_1, relu2_1, relu3_1, relu4_1) for ``x``."""
        activations = []
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4):
            x = stage(x)
            activations.append(x)
        return tuple(activations)

def adaptive_instance_normalization(content_feat, style_feat):
    """
    Adaptive Instance Normalization (AdaIN).

    Strips the per-channel spatial mean/std from the content features and
    replaces them with the per-channel spatial mean/std of the style
    features. Both inputs must agree in their first two dimensions
    (batch and channel); the result has the shape of ``content_feat``.
    """
    assert content_feat.size()[:2] == style_feat.size()[:2]

    c_mean, c_std = calc_mean_std(content_feat)
    s_mean, s_std = calc_mean_std(style_feat)

    # The statistics are (N, C, 1, 1) and broadcast over the spatial dims.
    whitened = (content_feat - c_mean) / c_std
    return whitened * s_std + s_mean

def calc_mean_std(feat, eps=1e-5):
    """Calculate per-channel mean and standard deviation over the spatial dims.

    Args:
        feat: 4-D tensor; the last two dimensions are flattened and reduced,
            the first two (N, C) are kept.
        eps: Added to the variance before the square root so the returned
            standard deviation is never exactly zero.

    Returns:
        ``(mean, std)``, each of shape (N, C, 1, 1).
    """
    size = feat.size()
    assert len(size) == 4, "calc_mean_std expects a 4-D tensor"
    N, C = size[:2]

    # reshape (rather than view) so non-contiguous inputs, e.g. permuted or
    # channels-last tensors, are accepted too.
    flat = feat.reshape(N, C, -1)

    # The unbiased estimator divides by (count - 1). With a single spatial
    # element that is 0/0 = NaN, so use the biased estimator in that case.
    feat_var = flat.var(dim=2, unbiased=flat.size(2) > 1) + eps
    feat_std = feat_var.sqrt().reshape(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).reshape(N, C, 1, 1)
    return feat_mean, feat_std

# --- 2. Decoder Architectures ---
# Here we define the four different decoder architectures

class ConvReLUUpsample(nn.Module):
    """Optional 2x nearest-neighbour upsample, followed by Conv2d and ReLU."""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, upsample=True):
        super().__init__()
        # When False the block is a plain conv + ReLU at the input resolution.
        self.upsample = upsample
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        scaled = F.interpolate(x, scale_factor=2, mode='nearest') if self.upsample else x
        return self.relu(self.conv(scaled))

class BaselineDecoder(nn.Module):
    """
    Baseline decoder: a plain stack of conv/ReLU blocks.

    Takes a 512-channel feature map and produces a 3-channel image. Three of
    the four blocks upsample by 2, so the output is 8x the spatial size of
    the input (e.g. 32x32 -> 256x256).
    """
    def __init__(self):
        super().__init__()
        stages = [
            ConvReLUUpsample(512, 256, upsample=False),  # keeps resolution
            ConvReLUUpsample(256, 256),                  # x2
            ConvReLUUpsample(256, 128),                  # x4
            ConvReLUUpsample(128, 64),                   # x8
            # Final projection to RGB; deliberately no activation after it.
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
        ]
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        return self.decoder(x)

class ResNetBlock(nn.Module):
    """Residual block: conv -> ReLU -> conv, added back onto the input."""
    def __init__(self, channels):
        super().__init__()
        # Both convolutions keep channel count and spatial size, so the
        # skip connection can be a plain addition.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden) + x

class ResNetDecoder(nn.Module):
    """
    ResNet-style decoder.

    Same conv/upsample backbone as the baseline decoder, with ``num_blocks``
    residual blocks inserted after every conv stage.
    """
    def __init__(self, num_blocks=2):
        super().__init__()
        # (in_channels, out_channels, upsample) for each conv stage.
        stage_plan = (
            (512, 256, False),
            (256, 256, True),
            (256, 128, True),
            (128, 64, True),
        )
        layers = []
        for in_ch, out_ch, upsample in stage_plan:
            layers.append(ConvReLUUpsample(in_ch, out_ch, upsample=upsample))
            for _ in range(num_blocks):
                layers.append(ResNetBlock(out_ch))
        # Final projection to a 3-channel image.
        layers.append(nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.decoder(x)

class SelfAttention(nn.Module):
    """
    Self-attention over all spatial positions of a feature map.

    Queries and keys are 1x1-conv projections to ``in_channels // 8``
    channels, values keep ``in_channels``. The attention map has one row and
    one column per spatial position (N x N with N = H * W). The attended
    values are scaled by the learnable ``gamma`` (initialised to zero, so the
    module starts out as the identity) and added to the input.
    """
    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 8
        self.query_conv = nn.Conv2d(in_channels, reduced, 1)
        self.key_conv = nn.Conv2d(in_channels, reduced, 1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n_batch, n_channels, dim_a, dim_b = x.size()
        n_positions = dim_a * dim_b

        queries = self.query_conv(x).view(n_batch, -1, n_positions)  # [B, C/8, N]
        keys = self.key_conv(x).view(n_batch, -1, n_positions)       # [B, C/8, N]
        values = self.value_conv(x).view(n_batch, -1, n_positions)   # [B, C,   N]

        # scores[b, i, j]: similarity of query position i with key position j,
        # normalised over j.
        scores = torch.bmm(queries.permute(0, 2, 1), keys)           # [B, N, N]
        weights = self.softmax(scores)

        attended = torch.bmm(values, weights.permute(0, 2, 1))       # [B, C, N]
        attended = attended.view(n_batch, n_channels, dim_a, dim_b)
        return self.gamma * attended + x

class AttentionDecoder(nn.Module):
    """
    Baseline-style decoder with two self-attention modules.

    Attention is applied once at the input resolution (256 channels) and
    once after the second upsample (128 channels). As in the baseline, the
    output is 8x the spatial size of the input (e.g. 32x32 -> 256x256).
    """
    def __init__(self):
        super().__init__()
        stages = [
            ConvReLUUpsample(512, 256, upsample=False),  # keeps resolution
            SelfAttention(256),
            ConvReLUUpsample(256, 256),                  # x2
            ConvReLUUpsample(256, 128),                  # x4
            SelfAttention(128),
            ConvReLUUpsample(128, 64),                   # x8
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
        ]
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        return self.decoder(x)

class UNetDecoder(nn.Module):
    """
    U-Net style decoder that uses skip connections from the encoder.

    ``forward`` takes the AdaIN output (512 channels) plus the tuple of
    encoder activations ``(relu1_1, relu2_1, relu3_1, relu4_1)``. After each
    decoder stage the feature map is upsampled to the resolution of the next
    shallower encoder activation and concatenated with it along the channel
    dimension. The result is a 3-channel image at the resolution of relu1_1.
    """
    def __init__(self):
        super(UNetDecoder, self).__init__()
        # Decoder blocks. Upsampling is handled explicitly in the forward pass.
        self.up_conv1 = ConvReLUUpsample(512, 256, upsample=False)
        self.up_conv2 = ConvReLUUpsample(256 + 256, 128, upsample=False)
        self.up_conv3 = ConvReLUUpsample(128 + 128, 64, upsample=False)
        self.up_conv4 = ConvReLUUpsample(64 + 64, 64, upsample=False)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)

    @staticmethod
    def _upsample_and_concat(x, skip):
        """Resize ``x`` to ``skip``'s spatial size and concatenate on channels."""
        # Upsampling to the skip tensor's own size (instead of a fixed
        # scale_factor=2) gives the same result when the skip is exactly 2x
        # larger, and additionally keeps torch.cat valid when the encoder's
        # pooling rounded an odd size down.
        x = F.interpolate(x, size=skip.shape[-2:], mode='nearest')
        return torch.cat([x, skip], dim=1)

    def forward(self, stylized_content_feat, encoder_feats):
        """
        Args:
            stylized_content_feat: 512-channel feature map to decode.
            encoder_feats: 4-tuple of encoder activations, shallowest first;
                the last entry is not used.

        Returns:
            3-channel tensor with the spatial size of ``encoder_feats[0]``.
        """
        e1, e2, e3, _ = encoder_feats

        # Shapes in the comments are for a 256x256 input image.
        d1 = self.up_conv1(stylized_content_feat)             # [B, 256, 32, 32]
        d2 = self.up_conv2(self._upsample_and_concat(d1, e3))  # [B, 128, 64, 64]
        d3 = self.up_conv3(self._upsample_and_concat(d2, e2))  # [B, 64, 128, 128]
        d4 = self.up_conv4(self._upsample_and_concat(d3, e1))  # [B, 64, 256, 256]

        return self.final_conv(d4)                             # [B, 3, 256, 256]

# --- 3. Data Loading and Helper Functions ---

class FlatFolderDataset(Dataset):
    """
    Dataset over the image files that sit directly inside one folder.

    Args:
        root: Directory to scan (non-recursive).
        transform: Callable applied to every loaded RGB ``PIL.Image``.
    """
    def __init__(self, root, transform):
        super(FlatFolderDataset, self).__init__()
        self.root = root
        # Keep regular files only: a sub-directory matched by '*' would make
        # Image.open fail later inside a DataLoader worker. Sorting makes the
        # index -> file mapping independent of filesystem listing order.
        self.paths = sorted(
            p for p in glob.glob(os.path.join(self.root, '*')) if os.path.isfile(p)
        )
        self.transform = transform

    def __getitem__(self, index):
        path = self.paths[index]
        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the context manager exits.
        with Image.open(path) as source:
            img = source.convert('RGB')
        return self.transform(img)

    def __len__(self):
        return len(self.paths)

def create_style_dataset(style_dir, artists, transform):
    """
    Create a StyleDataset from a specific list of artist subfolders.

    Args:
        style_dir: Directory that contains one sub-folder per artist.
        artists: Names of the sub-folders to include.
        transform: Transform handed to the resulting ``StyleDataset``.

    Returns:
        A ``StyleDataset`` over every regular file found directly inside the
        requested artist folders. Folders that do not exist are skipped with
        a printed warning.
    """
    all_style_paths = []
    for artist in artists:
        artist_path = os.path.join(style_dir, artist)
        if not os.path.isdir(artist_path):
            print(f"Warning: Artist directory not found at {artist_path}")
            continue
        # Regular files only (a nested directory would fail in Image.open),
        # sorted so the dataset order does not depend on the filesystem.
        all_style_paths.extend(sorted(
            p for p in glob.glob(os.path.join(artist_path, '*')) if os.path.isfile(p)
        ))

    return StyleDataset(all_style_paths, transform)

class StyleDataset(Dataset):
    """Dataset that loads RGB images from an explicit list of file paths."""
    def __init__(self, paths, transform):
        super().__init__()
        self.transform = transform
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Always hand the transform a 3-channel image, whatever the file holds.
        rgb = Image.open(self.paths[index]).convert('RGB')
        return self.transform(rgb)


def calculate_loss(gen_features, content_features_target, style_features_target):
    """
    Compute the content loss and the style loss (returned separately).

    Content loss: MSE between the deepest generated feature map and the
    deepest content target. Style loss: for every feature level, the MSE
    between the per-channel means plus the MSE between the per-channel
    standard deviations of the generated and style features, summed over
    levels. Weighting and adding the two terms is left to the caller.
    """
    def level_style_loss(gen_f, style_f):
        g_mean, g_std = calc_mean_std(gen_f)
        s_mean, s_std = calc_mean_std(style_f)
        return F.mse_loss(g_mean, s_mean) + F.mse_loss(g_std, s_std)

    style_loss = sum(
        level_style_loss(gen_f, style_f)
        for gen_f, style_f in zip(gen_features, style_features_target)
    )
    content_loss = F.mse_loss(gen_features[-1], content_features_target[-1])
    return content_loss, style_loss

def save_image_samples(decoder_name, step, content, style, output):
    """
    Write a content | style | output strip for the first item of a batch.

    The PNG goes to ``output/<decoder_name>/images/sample_step_<step>.png``;
    the directory is created when missing. The generated image is clamped to
    [0, 1] before saving.
    """
    output_dir = f"output/{decoder_name}/images"
    os.makedirs(output_dir, exist_ok=True)

    # Only the first sample of each batch is shown; the decoder output is not
    # bounded, so it is clamped to the displayable range first.
    panels = [
        content[0].cpu(),
        style[0].cpu(),
        torch.clamp(output, 0, 1)[0].cpu(),
    ]
    strip = utils.make_grid(panels, nrow=3, padding=5)

    utils.save_image(strip, os.path.join(output_dir, f'sample_step_{step:06d}.png'))

In [3]:
# --- 4. Main Training Script ---
def main(batch_size=1, num_epochs=150, style_weight=20.0):
    """
    Train every decoder architecture, one after another, on the local datasets.

    For each decoder (Baseline, ResNet, Attention, U-Net) a fresh Adam
    optimizer is created and the decoder is trained against the frozen VGG
    encoder. Per decoder, the following is written below
    ``output/<DecoderClassName>/``: ``loss_log.csv`` (one row every 100
    steps), sample images and checkpoints every 10000 steps, and
    ``weights/decoder_final.pth`` at the end.

    Args:
        batch_size: Batch size for both the content and the style loader.
        num_epochs: Number of passes over the content dataset per decoder.
        style_weight: Multiplier applied to the style loss before it is
            added to the content loss.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Configuration ---
    content_train_dir = "data/DIV2K/DIV2K_train_HR"
    style_dir = "data/images"
    artists_to_use = [
        "Vasiliy_Kandinskiy", "Claude_Monet", "Salvador_Dali", "Vincent_van_Gogh",
        "Gustav_Klimt", "Hieronymus_Bosch", "Pablo_Picasso", "Henri_Matisse",
        "Edvard_Munch", "Georges_Seurat", "Piet_Mondrian", "Jackson_Pollock",
        "Frida_Kahlo", "William_Turner", "Albrecht_Dürer"
    ]

    image_size = 256
    lr = 1e-4
    max_iter = 300 * num_epochs  # Hard cap on optimizer steps per decoder

    # --- Data Loading ---
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ])

    print("Loading datasets...")
    content_dataset = FlatFolderDataset(content_train_dir, transform)
    style_dataset = create_style_dataset(style_dir, artists_to_use, transform)

    # persistent_workers keeps the worker processes alive between epochs
    # instead of re-forking them for every pass over the data.
    loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=4,
                         persistent_workers=True)
    content_loader = DataLoader(content_dataset, **loader_kwargs)
    style_loader = DataLoader(style_dataset, **loader_kwargs)

    print(f"Found {len(content_dataset)} content images and {len(style_dataset)} style images.")

    # The encoder is frozen and identical for every decoder: build it once.
    encoder = VGGEncoder().to(device).eval()

    # Decoders are constructed lazily so that only the one currently being
    # trained occupies device memory.
    decoder_factories = [
        BaselineDecoder,
        lambda: ResNetDecoder(num_blocks=3),
        AttentionDecoder,
        UNetDecoder,
    ]

    # One style stream shared by all decoders; restarted whenever exhausted.
    style_iter = iter(style_loader)

    for make_decoder in decoder_factories:
        # --- Model Initialization ---
        decoder = make_decoder().to(device)
        decoder_name = decoder.__class__.__name__
        print(f"Selected Decoder: {decoder_name}")
        needs_skip_feats = isinstance(decoder, UNetDecoder)

        optimizer = optim.Adam(decoder.parameters(), lr=lr)

        # --- Setup Loss Logging ---
        output_dir = f"output/{decoder_name}"
        weights_dir = os.path.join(output_dir, 'weights')
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(weights_dir, exist_ok=True)

        log_file_path = os.path.join(output_dir, 'loss_log.csv')

        total_steps = 0
        with open(log_file_path, 'w', newline='') as log_file:
            log_writer = csv.writer(log_file)
            log_writer.writerow(['step', 'total_loss', 'content_loss', 'style_loss'])

            # --- Training Loop ---
            print("Starting training loop...")
            reached_cap = False

            for epoch in range(num_epochs):
                if reached_cap:
                    # The step cap ends training, not just the current epoch.
                    break
                print(f"--- Epoch {epoch+1}/{num_epochs} ---")
                for content_batch in content_loader:
                    if total_steps >= max_iter:
                        reached_cap = True
                        break

                    # Get a batch of style images
                    try:
                        style_batch = next(style_iter)
                    except StopIteration:
                        style_iter = iter(style_loader)
                        style_batch = next(style_iter)

                    # The last batch of either loader may be short; AdaIN
                    # requires equal batch sizes, so trim both to the smaller.
                    n_pairs = min(content_batch.size(0), style_batch.size(0))
                    content_batch = content_batch[:n_pairs].to(device)
                    style_batch = style_batch[:n_pairs].to(device)

                    optimizer.zero_grad()

                    # Get features
                    content_feats = encoder(content_batch)
                    style_feats = encoder(style_batch)

                    # Apply AdaIN on the deepest (relu4_1) features
                    stylized_feat = adaptive_instance_normalization(content_feats[-1], style_feats[-1])

                    # Generate image; the U-Net additionally consumes the
                    # content features as skip connections.
                    if needs_skip_feats:
                        output_img = decoder(stylized_feat, content_feats)
                    else:
                        output_img = decoder(stylized_feat)

                    # Get features of the generated image
                    output_feats = encoder(output_img)

                    # Calculate loss
                    content_loss, style_loss = calculate_loss(output_feats, content_feats, style_feats)
                    total_loss = content_loss + style_weight * style_loss

                    # Backpropagation
                    total_loss.backward()
                    optimizer.step()

                    if total_steps % 100 == 0:
                        loss_str = (f'Step {total_steps:05d} | Epoch {epoch+1} | '
                                    f'Total Loss: {total_loss.item():.4f} | '
                                    f'Content Loss: {content_loss.item():.4f} | '
                                    f'Style Loss: {style_loss.item():.4f}')
                        print(loss_str)
                        log_writer.writerow([total_steps, total_loss.item(), content_loss.item(), style_loss.item()])

                    if total_steps != 0 and total_steps % 10000 == 0:
                        save_image_samples(decoder_name, total_steps, content_batch, style_batch, output_img)
                        # Save model checkpoint
                        torch.save(decoder.state_dict(), f'{weights_dir}/decoder_step_{total_steps:06d}.pth')

                    total_steps += 1

        print("Training finished.")
        # Save a final sample (only if at least one step produced one) and
        # the final model.
        if total_steps > 0:
            save_image_samples(decoder_name, total_steps, content_batch, style_batch, output_img)
        torch.save(decoder.state_dict(), f'{weights_dir}/decoder_final.pth')



In [8]:
# Run the full decoder comparison with the defaults
# (batch_size=1, num_epochs=150, style_weight=20.0).
main()

Using device: cuda
Loading datasets...
Found 300 content images and 2788 style images.
Selected Decoder: BaselineDecoder
Starting training loop...
--- Epoch 1/150 ---
Step 00000 | Epoch 1 | Total Loss: 57.0218 | Content Loss: 4.8446 | Style Loss: 2.6089
Step 00100 | Epoch 1 | Total Loss: 32.2666 | Content Loss: 4.7727 | Style Loss: 1.3747
Step 00200 | Epoch 1 | Total Loss: 49.6182 | Content Loss: 6.9741 | Style Loss: 2.1322
--- Epoch 2/150 ---
Step 00300 | Epoch 2 | Total Loss: 13.4558 | Content Loss: 4.1299 | Style Loss: 0.4663
Step 00400 | Epoch 2 | Total Loss: 23.3579 | Content Loss: 6.3620 | Style Loss: 0.8498
Step 00500 | Epoch 2 | Total Loss: 19.1970 | Content Loss: 6.5180 | Style Loss: 0.6339
--- Epoch 3/150 ---
Step 00600 | Epoch 3 | Total Loss: 13.1288 | Content Loss: 6.1848 | Style Loss: 0.3472
Step 00700 | Epoch 3 | Total Loss: 17.6154 | Content Loss: 5.4803 | Style Loss: 0.6068
Step 00800 | Epoch 3 | Total Loss: 13.5145 | Content Loss: 4.7240 | Style Loss: 0.4395
--- Epoch 

In [5]:
# Style-weight sweep: same run with style_weight=10.0.
main(batch_size=1, num_epochs=150, style_weight=10.0)

Using device: cuda
Loading datasets...
Found 300 content images and 2788 style images.


1.1%

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /tmp/xdg-cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


100.0%


Selected Decoder: BaselineDecoder
Starting training loop...
--- Epoch 1/150 ---
Step 00000 | Epoch 1 | Total Loss: 44.7927 | Content Loss: 4.3052 | Style Loss: 4.0487
Step 00100 | Epoch 1 | Total Loss: 8.3183 | Content Loss: 4.2346 | Style Loss: 0.4084
Step 00200 | Epoch 1 | Total Loss: 15.5836 | Content Loss: 8.4504 | Style Loss: 0.7133
--- Epoch 2/150 ---
Step 00300 | Epoch 2 | Total Loss: 22.9608 | Content Loss: 8.9089 | Style Loss: 1.4052
Step 00400 | Epoch 2 | Total Loss: 7.2045 | Content Loss: 2.6980 | Style Loss: 0.4507
Step 00500 | Epoch 2 | Total Loss: 16.0074 | Content Loss: 10.7198 | Style Loss: 0.5288
--- Epoch 3/150 ---
Step 00600 | Epoch 3 | Total Loss: 11.6692 | Content Loss: 4.6802 | Style Loss: 0.6989
Step 00700 | Epoch 3 | Total Loss: 13.1690 | Content Loss: 5.4499 | Style Loss: 0.7719
Step 00800 | Epoch 3 | Total Loss: 11.3644 | Content Loss: 7.0482 | Style Loss: 0.4316
--- Epoch 4/150 ---
Step 00900 | Epoch 4 | Total Loss: 11.3198 | Content Loss: 5.7627 | Style Loss

In [4]:
# Style-weight sweep: same run with style_weight=7.0.
main(batch_size=1, num_epochs=150, style_weight=7.0)

Using device: cuda
Loading datasets...
Found 300 content images and 2788 style images.
Selected Decoder: BaselineDecoder
Starting training loop...
--- Epoch 1/150 ---
Step 00000 | Epoch 1 | Total Loss: 25.4561 | Content Loss: 2.9137 | Style Loss: 3.2203
Step 00100 | Epoch 1 | Total Loss: 16.9043 | Content Loss: 7.3476 | Style Loss: 1.3652
Step 00200 | Epoch 1 | Total Loss: 27.7351 | Content Loss: 6.1372 | Style Loss: 3.0854
--- Epoch 2/150 ---
Step 00300 | Epoch 2 | Total Loss: 14.7977 | Content Loss: 7.9169 | Style Loss: 0.9830
Step 00400 | Epoch 2 | Total Loss: 9.8139 | Content Loss: 3.9469 | Style Loss: 0.8381
Step 00500 | Epoch 2 | Total Loss: 8.9039 | Content Loss: 4.4673 | Style Loss: 0.6338
--- Epoch 3/150 ---
Step 00600 | Epoch 3 | Total Loss: 8.9257 | Content Loss: 4.3592 | Style Loss: 0.6524
Step 00700 | Epoch 3 | Total Loss: 6.6924 | Content Loss: 3.5702 | Style Loss: 0.4460
Step 00800 | Epoch 3 | Total Loss: 4.6626 | Content Loss: 2.8864 | Style Loss: 0.2537
--- Epoch 4/150

In [5]:
# Style-weight sweep: same run with style_weight=3.0.
main(batch_size=1, num_epochs=150, style_weight=3.0)

Using device: cuda
Loading datasets...
Found 300 content images and 2788 style images.
Selected Decoder: BaselineDecoder
Starting training loop...
--- Epoch 1/150 ---
Step 00000 | Epoch 1 | Total Loss: 18.3804 | Content Loss: 5.7802 | Style Loss: 4.2001
Step 00100 | Epoch 1 | Total Loss: 9.0266 | Content Loss: 6.4463 | Style Loss: 0.8601
Step 00200 | Epoch 1 | Total Loss: 6.5137 | Content Loss: 5.3701 | Style Loss: 0.3812
--- Epoch 2/150 ---
Step 00300 | Epoch 2 | Total Loss: 5.1376 | Content Loss: 3.4558 | Style Loss: 0.5606
Step 00400 | Epoch 2 | Total Loss: 8.3564 | Content Loss: 5.2278 | Style Loss: 1.0428
Step 00500 | Epoch 2 | Total Loss: 2.9704 | Content Loss: 2.1017 | Style Loss: 0.2895
--- Epoch 3/150 ---
Step 00600 | Epoch 3 | Total Loss: 8.8061 | Content Loss: 4.7199 | Style Loss: 1.3621
Step 00700 | Epoch 3 | Total Loss: 2.5771 | Content Loss: 1.5387 | Style Loss: 0.3461
Step 00800 | Epoch 3 | Total Loss: 3.6776 | Content Loss: 2.5558 | Style Loss: 0.3739
--- Epoch 4/150 --

In [4]:
# Style-weight sweep: same run with style_weight=30.0.
main(batch_size=1, num_epochs=150, style_weight=30.0)

Using device: cuda
Loading datasets...
Found 300 content images and 2788 style images.
Selected Decoder: BaselineDecoder
Starting training loop...
--- Epoch 1/150 ---
Step 00000 | Epoch 1 | Total Loss: 183.4737 | Content Loss: 2.9367 | Style Loss: 6.0179
Step 00100 | Epoch 1 | Total Loss: 42.3519 | Content Loss: 8.6045 | Style Loss: 1.1249
Step 00200 | Epoch 1 | Total Loss: 16.6842 | Content Loss: 6.1675 | Style Loss: 0.3506
--- Epoch 2/150 ---
Step 00300 | Epoch 2 | Total Loss: 57.3456 | Content Loss: 7.9699 | Style Loss: 1.6459
Step 00400 | Epoch 2 | Total Loss: 67.2252 | Content Loss: 6.5571 | Style Loss: 2.0223
Step 00500 | Epoch 2 | Total Loss: 23.0252 | Content Loss: 10.0043 | Style Loss: 0.4340
--- Epoch 3/150 ---
Step 00600 | Epoch 3 | Total Loss: 27.2864 | Content Loss: 4.2097 | Style Loss: 0.7692
Step 00700 | Epoch 3 | Total Loss: 35.0258 | Content Loss: 8.8324 | Style Loss: 0.8731
Step 00800 | Epoch 3 | Total Loss: 23.5563 | Content Loss: 10.9523 | Style Loss: 0.4201
--- Epo

In [5]:
def unet_train(batch_size=1, num_epochs=150, style_weight=10.0):
    """
    Train only the U-Net decoder on the local datasets.

    Writes to ``output/UNetDecoder/``: ``loss_log.csv`` (one row every 100
    steps), a sample image every 10000 steps and at the end, and
    ``weights/decoder_final.pth``. This is the same directory ``main`` uses
    for its U-Net run, so files from an earlier run are overwritten.

    Args:
        batch_size: Batch size for both the content and the style loader.
        num_epochs: Number of passes over the content dataset.
        style_weight: Multiplier applied to the style loss before it is
            added to the content loss.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Configuration ---
    content_train_dir = "data/DIV2K/DIV2K_train_HR"
    style_dir = "data/images"
    artists_to_use = [
        "Vasiliy_Kandinskiy", "Claude_Monet", "Salvador_Dali", "Vincent_van_Gogh",
        "Gustav_Klimt", "Hieronymus_Bosch", "Pablo_Picasso", "Henri_Matisse",
        "Edvard_Munch", "Georges_Seurat", "Piet_Mondrian", "Jackson_Pollock",
        "Frida_Kahlo", "William_Turner", "Albrecht_Dürer"
    ]

    image_size = 256
    lr = 1e-4
    max_iter = 300/batch_size * num_epochs  # Hard cap on optimizer steps

    # --- Data Loading ---
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ])

    print("Loading datasets...")
    content_dataset = FlatFolderDataset(content_train_dir, transform)
    style_dataset = create_style_dataset(style_dir, artists_to_use, transform)

    # persistent_workers keeps the worker processes alive between epochs
    # instead of re-forking them for every pass over the data.
    loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=4,
                         persistent_workers=True)
    content_loader = DataLoader(content_dataset, **loader_kwargs)
    style_loader = DataLoader(style_dataset, **loader_kwargs)

    print(f"Found {len(content_dataset)} content images and {len(style_dataset)} style images.")

    # --- Model Initialization ---
    encoder = VGGEncoder().to(device).eval()
    decoder = UNetDecoder().to(device)

    decoder_name = decoder.__class__.__name__
    print(f"Selected Decoder: {decoder_name}")

    optimizer = optim.Adam(decoder.parameters(), lr=lr)

    # --- Setup Loss Logging ---
    output_dir = f"output/{decoder_name}"
    weights_dir = os.path.join(output_dir, 'weights')
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(weights_dir, exist_ok=True)

    log_file_path = os.path.join(output_dir, 'loss_log.csv')

    total_steps = 0
    with open(log_file_path, 'w', newline='') as log_file:
        log_writer = csv.writer(log_file)
        log_writer.writerow(['step', 'total_loss', 'content_loss', 'style_loss'])

        # --- Training Loop ---
        print("Starting training loop...")
        style_iter = iter(style_loader)
        reached_cap = False

        for epoch in range(num_epochs):
            if reached_cap:
                # The step cap ends training, not just the current epoch.
                break
            print(f"--- Epoch {epoch+1}/{num_epochs} ---")
            for content_batch in content_loader:
                if total_steps >= max_iter:
                    reached_cap = True
                    break

                # Get a batch of style images
                try:
                    style_batch = next(style_iter)
                except StopIteration:
                    style_iter = iter(style_loader)
                    style_batch = next(style_iter)

                # The last batch of either loader may be short; AdaIN
                # requires equal batch sizes, so trim both to the smaller.
                n_pairs = min(content_batch.size(0), style_batch.size(0))
                content_batch = content_batch[:n_pairs].to(device)
                style_batch = style_batch[:n_pairs].to(device)

                optimizer.zero_grad()

                # Get features
                content_feats = encoder(content_batch)
                style_feats = encoder(style_batch)

                # Apply AdaIN on the deepest (relu4_1) features
                stylized_feat = adaptive_instance_normalization(content_feats[-1], style_feats[-1])

                # Generate image; the U-Net consumes the content features as
                # skip connections.
                output_img = decoder(stylized_feat, content_feats)

                # Get features of the generated image
                output_feats = encoder(output_img)

                # Calculate loss
                content_loss, style_loss = calculate_loss(output_feats, content_feats, style_feats)
                total_loss = content_loss + style_weight * style_loss

                # Backpropagation
                total_loss.backward()
                optimizer.step()

                if total_steps % 100 == 0:
                    loss_str = (f'Step {total_steps:05d} | Epoch {epoch+1} | '
                                f'Total Loss: {total_loss.item():.4f} | '
                                f'Content Loss: {content_loss.item():.4f} | '
                                f'Style Loss: {style_loss.item():.4f}')
                    print(loss_str)
                    log_writer.writerow([total_steps, total_loss.item(), content_loss.item(), style_loss.item()])

                # Periodic sample images only; intermediate checkpoints are
                # intentionally not written in this run, just the final model.
                if total_steps != 0 and total_steps % 10000 == 0:
                    save_image_samples(decoder_name, total_steps, content_batch, style_batch, output_img)

                total_steps += 1

    print("Training finished.")
    # Save a final sample (only if at least one step produced one) and the
    # final model.
    if total_steps > 0:
        save_image_samples(decoder_name, total_steps, content_batch, style_batch, output_img)
    torch.save(decoder.state_dict(), f'{weights_dir}/decoder_final.pth')
# Long U-Net-only run: batch size 2, 1000 epochs, style_weight=10.0.
unet_train(batch_size=2, num_epochs=1000, style_weight=10.0)

Using device: cuda
Loading datasets...
Found 300 content images and 2788 style images.
Selected Decoder: UNetDecoder
Starting training loop...
--- Epoch 1/1000 ---
Step 00000 | Epoch 1 | Total Loss: 33.5273 | Content Loss: 3.7336 | Style Loss: 2.9794
Step 00100 | Epoch 1 | Total Loss: 22.8424 | Content Loss: 3.4431 | Style Loss: 1.9399
--- Epoch 2/1000 ---
Step 00200 | Epoch 2 | Total Loss: 15.2015 | Content Loss: 2.1734 | Style Loss: 1.3028
--- Epoch 3/1000 ---
Step 00300 | Epoch 3 | Total Loss: 8.9552 | Content Loss: 4.2022 | Style Loss: 0.4753
Step 00400 | Epoch 3 | Total Loss: 12.8629 | Content Loss: 3.3028 | Style Loss: 0.9560
--- Epoch 4/1000 ---
Step 00500 | Epoch 4 | Total Loss: 9.6054 | Content Loss: 3.5257 | Style Loss: 0.6080
--- Epoch 5/1000 ---
Step 00600 | Epoch 5 | Total Loss: 10.3259 | Content Loss: 3.4713 | Style Loss: 0.6855
Step 00700 | Epoch 5 | Total Loss: 10.1878 | Content Loss: 3.2080 | Style Loss: 0.6980
--- Epoch 6/1000 ---
Step 00800 | Epoch 6 | Total Loss: 12

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f0b9b5fd510>
Traceback (most recent call last):
  File "/home/apillai/.conda/envs/styleTransfer/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1663, in __del__
    self._shutdown_workers()
  File "/home/apillai/.conda/envs/styleTransfer/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1646, in _shutdown_workers
    if w.is_alive():
  File "/home/apillai/.conda/envs/styleTransfer/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


--- Epoch 471/1000 ---
Step 70500 | Epoch 471 | Total Loss: 9.6699 | Content Loss: 4.3949 | Style Loss: 0.5275
Step 70600 | Epoch 471 | Total Loss: 5.8993 | Content Loss: 2.8442 | Style Loss: 0.3055
--- Epoch 472/1000 ---
Step 70700 | Epoch 472 | Total Loss: 4.7440 | Content Loss: 2.8630 | Style Loss: 0.1881


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f0b9b5fd510>
Traceback (most recent call last):
  File "/home/apillai/.conda/envs/styleTransfer/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1663, in __del__
    self._shutdown_workers()
  File "/home/apillai/.conda/envs/styleTransfer/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1646, in _shutdown_workers
    if w.is_alive():
  File "/home/apillai/.conda/envs/styleTransfer/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


--- Epoch 473/1000 ---
Step 70800 | Epoch 473 | Total Loss: 7.3105 | Content Loss: 4.2121 | Style Loss: 0.3098
Step 70900 | Epoch 473 | Total Loss: 9.9619 | Content Loss: 5.0189 | Style Loss: 0.4943
--- Epoch 474/1000 ---
Step 71000 | Epoch 474 | Total Loss: 3.9856 | Content Loss: 1.8770 | Style Loss: 0.2109


Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f0b9b5fd510><function _MultiProcessingDataLoaderIter.__del__ at 0x7f0b9b5fd510>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/apillai/.conda/envs/styleTransfer/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1663, in __del__
  File "/home/apillai/.conda/envs/styleTransfer/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1663, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/apillai/.conda/envs/styleTransfer/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1646, in _shutdown_workers
  File "/home/apillai/.conda/envs/styleTransfer/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1646, in _shutdown_workers
    if w.is_alive():    
if w.is_alive():  File "/home/apillai/.conda/envs/styleTransfer/lib/python3.10/multiprocessing/process.py", line 160, i

--- Epoch 475/1000 ---
Step 71100 | Epoch 475 | Total Loss: 3.7730 | Content Loss: 1.8962 | Style Loss: 0.1877
Step 71200 | Epoch 475 | Total Loss: 9.9044 | Content Loss: 4.1807 | Style Loss: 0.5724
--- Epoch 476/1000 ---
Step 71300 | Epoch 476 | Total Loss: 5.4254 | Content Loss: 3.0180 | Style Loss: 0.2407
--- Epoch 477/1000 ---
Step 71400 | Epoch 477 | Total Loss: 4.2087 | Content Loss: 3.0259 | Style Loss: 0.1183
Step 71500 | Epoch 477 | Total Loss: 6.6132 | Content Loss: 4.1304 | Style Loss: 0.2483
--- Epoch 478/1000 ---
Step 71600 | Epoch 478 | Total Loss: 7.0818 | Content Loss: 2.9199 | Style Loss: 0.4162
--- Epoch 479/1000 ---
Step 71700 | Epoch 479 | Total Loss: 7.5594 | Content Loss: 4.9789 | Style Loss: 0.2580
Step 71800 | Epoch 479 | Total Loss: 6.5399 | Content Loss: 3.6347 | Style Loss: 0.2905
--- Epoch 480/1000 ---
Step 71900 | Epoch 480 | Total Loss: 5.0872 | Content Loss: 2.6004 | Style Loss: 0.2487
--- Epoch 481/1000 ---
Step 72000 | Epoch 481 | Total Loss: 5.9983 | C