In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.optim.lr_scheduler import _LRScheduler
import torch.utils.data as data

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

from sklearn import decomposition
from sklearn import manifold
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import numpy as np

import copy
from collections import namedtuple
import os
import random
import shutil
import time

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt

In [3]:
!mkdir data

In [5]:
# Fetch the "gan-getting-started" competition archive with the Kaggle CLI.
# The recorded output shows the CLI skipping the download when a more recent
# local copy of the zip already exists.
!kaggle competitions download -c gan-getting-started

# Root folder for the downloaded data (only referenced by commented-out code below).
ROOT = 'data'

gan-getting-started.zip: Skipping, found more recently modified local copy (use --force to force download)


In [7]:
!mkdir "data/gan-getting-started"

In [None]:
import zipfile

# Archive produced by the Kaggle download and the folder it is unpacked into.
zip_path = 'gan-getting-started.zip'
extract_path = 'data/gan-getting-started'

# Unpack every member below extract_path; the context manager closes the
# archive once extraction is finished.
with zipfile.ZipFile(zip_path) as zip_ref:
    zip_ref.extractall(path=extract_path)

In [None]:
# TRAIN_RATIO = 0.8

# data_dir = os.path.join(ROOT, 'CUB_200_2011')
# images_dir = os.path.join(data_dir, 'images')
# train_dir = os.path.join(data_dir, 'train')
# test_dir = os.path.join(data_dir, 'test')


In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt

In [34]:
# Pick the compute backend in the same order Config.device uses:
# Apple MPS first, then CUDA, then plain CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [58]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm

class Config:
    """Static settings read by the training loop (paths, sizes, schedule, device)."""

    # Folders holding the content images and the style images.
    content_dir = "data/gan-getting-started/photo_jpg"
    style_dir = "data/gan-getting-started/monet_jpg"

    # Side length handed to the dataset's default transform, and loader batch size.
    image_size = 256
    batch_size = 4

    # Number of epochs and the epoch interval between checkpoints.
    epochs = 1000
    save_every = 100

    # Backend preference, evaluated once at class-definition time:
    # Apple MPS, then CUDA, then CPU.
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

class StyleTransferDataset(Dataset):
    """Pairs content images with style images for style-transfer training.

    The dataset length is the size of the larger folder; the smaller folder is
    cycled with a modulo index so every item yields one image from each side.

    Args:
        content_dir: Folder containing the content images (.jpg / .png).
        style_dir: Folder containing the style images (.jpg / .png).
        transform: Callable applied to each loaded PIL image. When omitted,
            ``default_transform(image_size)`` is used.
        image_size: Side length used by the default transform.

    Raises:
        ValueError: If either folder contains no .jpg/.png files.
    """

    def __init__(self, content_dir, style_dir, transform=None, image_size=256):
        self.content_dir = content_dir
        self.style_dir = style_dir
        self.transform = transform or self.default_transform(image_size)
        self.content_images = self._list_images(content_dir)
        self.style_images = self._list_images(style_dir)

        # Fail early with a clear message: an empty folder would otherwise
        # surface later as a modulo-by-zero inside __getitem__.
        for label, folder, names in (
            ("content", content_dir, self.content_images),
            ("style", style_dir, self.style_images),
        ):
            if not names:
                raise ValueError(f"No .jpg/.png {label} images found in {folder!r}")

    @staticmethod
    def _list_images(directory):
        """Return the .jpg/.png file names in ``directory`` in sorted order.

        os.listdir returns entries in arbitrary order, so sorting is what makes
        the index -> file mapping stable between runs and machines.
        """
        return sorted(f for f in os.listdir(directory) if f.endswith(('.jpg', '.png')))

    @staticmethod
    def _load_rgb(path):
        """Open an image, convert it to RGB and release the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    @staticmethod
    def default_transform(image_size):
        """Resize, center-crop to ``image_size``, convert to a tensor and
        normalize with the mean/std commonly used for ImageNet-pretrained models."""
        return transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return max(len(self.content_images), len(self.style_images))

    def __getitem__(self, idx):
        """Return ``(content, style)`` after applying the transform to both."""
        # Wrap around independently so the shorter list is reused.
        content_name = self.content_images[idx % len(self.content_images)]
        style_name = self.style_images[idx % len(self.style_images)]

        content_img = self._load_rgb(os.path.join(self.content_dir, content_name))
        style_img = self._load_rgb(os.path.join(self.style_dir, style_name))

        return self.transform(content_img), self.transform(style_img)

class ConvBlock(nn.Module):
    """Reflection padding -> convolution -> instance norm -> ReLU.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Square kernel size; half of it (floored) is used as padding.
        stride: Convolution stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        # Padding is done by reflection ahead of the conv, so the conv itself
        # runs without any implicit zero padding.
        pad = nn.ReflectionPad2d(kernel_size // 2)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        norm = nn.InstanceNorm2d(out_channels)
        activation = nn.ReLU(inplace=True)
        # Order (and therefore the state_dict indices) is pad, conv, norm, act.
        self.layers = nn.Sequential(pad, conv, norm, activation)

    def forward(self, x):
        out = self.layers(x)
        return out

class ResidualBlock(nn.Module):
    """Two stacked 3x3, stride-1 ConvBlocks whose output is added to the input."""

    def __init__(self, channels):
        super().__init__()
        # Both convolutions keep the channel count, so the branch output can be
        # summed element-wise with the block input.
        branch = [ConvBlock(channels, channels, kernel_size=3, stride=1) for _ in range(2)]
        self.block = nn.Sequential(*branch)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class TransformerNet(nn.Module):
    """Feed-forward image transformation network.

    Layout: a 9x9 ConvBlock, two stride-2 ConvBlocks, five residual blocks at
    128 channels, two nearest-neighbour upsample + ConvBlock stages, and a
    final 9x9 convolution followed by Tanh (so outputs lie in [-1, 1]).
    """

    def __init__(self):
        super().__init__()
        # Encoder: 3 -> 32 -> 64 -> 128 channels, the last two with stride 2.
        self.conv1 = ConvBlock(3, 32, kernel_size=9, stride=1)
        self.conv2 = ConvBlock(32, 64, kernel_size=3, stride=2)
        self.conv3 = ConvBlock(64, 128, kernel_size=3, stride=2)

        # Bottleneck: five residual blocks at 128 channels.
        self.res_blocks = nn.Sequential(*(ResidualBlock(128) for _ in range(5)))

        # Decoder: each stage doubles the resolution, then convolves.
        self.deconv1 = self._upsample_stage(128, 64)
        self.deconv2 = self._upsample_stage(64, 32)

        # Projection back to 3 channels, squashed by Tanh.
        self.out = nn.Sequential(
            nn.ReflectionPad2d(4),
            nn.Conv2d(32, 3, kernel_size=9),
            nn.Tanh(),
        )

    @staticmethod
    def _upsample_stage(in_channels, out_channels):
        """Nearest-neighbour 2x upsampling followed by a 3x3 stride-1 ConvBlock."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            ConvBlock(in_channels, out_channels, kernel_size=3, stride=1),
        )

    def forward(self, x):
        # Run the stages strictly in order: encoder, residuals, decoder, head.
        pipeline = (
            self.conv1,
            self.conv2,
            self.conv3,
            self.res_blocks,
            self.deconv1,
            self.deconv2,
            self.out,
        )
        for stage in pipeline:
            x = stage(x)
        return x

class FeatureExtractor(nn.Module):
    """Frozen, truncated ResNet-18 used as a fixed feature extractor.

    Keeps the first seven child modules of torchvision's ``resnet18`` (in
    torchvision's definition: the stem through ``layer3``) loaded with ImageNet
    weights. All parameters are frozen and the sub-network is kept in eval mode
    permanently, so its BatchNorm statistics never change during training.
    """

    def __init__(self):
        super().__init__()
        # torchvision >= 0.13 selects weights through an enum; ``pretrained=True``
        # is the deprecated spelling of IMAGENET1K_V1. Fall back to it on older
        # torchvision releases that lack the enum.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            resnet = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet = models.resnet18(pretrained=True)
        self.features = nn.Sequential(*list(resnet.children())[:7]).eval()
        for param in self.features.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """Set the module's mode but always keep the frozen backbone in eval mode.

        Without this override, calling ``.train()`` on this module or on any
        parent that contains it would switch the BatchNorm layers back to
        batch statistics and start updating their running averages.
        """
        super().train(mode)
        self.features.eval()
        return self

    def forward(self, x):
        """Return the feature map of the truncated network for ``x``."""
        return self.features(x)

class StyleTransferLoss(nn.Module):
    """Content + style loss computed on features from a fixed extractor.

    Args:
        feature_extractor: Module mapping an image batch to a 4D feature map.
        content_weight: Multiplier for the content (feature MSE) term.
        style_weight: Multiplier for the style (Gram-matrix MSE) term.
    """

    def __init__(self, feature_extractor, content_weight=1.0, style_weight=1e5):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.content_weight = content_weight
        self.style_weight = style_weight
        self.mse_loss = nn.MSELoss()

    @staticmethod
    def gram_matrix(x):
        """Per-sample Gram matrix of a ``(b, c, h, w)`` feature map.

        Returns a ``(b, c, c)`` tensor normalized by ``c * h * w``. Each sample
        gets its own matrix; flattening batch and channel together would mix
        channel statistics of different images into one matrix and make the
        loss depend on which images happen to share a batch.
        """
        b, c, h, w = x.size()
        # reshape (rather than view) also accepts non-contiguous feature maps.
        features = x.reshape(b, c, h * w)
        gram = torch.bmm(features, features.transpose(1, 2))
        return gram.div(c * h * w)

    def forward(self, generated, content, style):
        """Return ``content_weight * content_loss + style_weight * style_loss``.

        Inputs may be single images (3D) or batches (4D); single images get a
        leading batch dimension added.

        NOTE(review): ``generated`` is passed to the extractor as-is. If it
        comes from a Tanh-bounded generator while ``content``/``style`` are
        mean/std-normalized, the three inputs are on different value scales —
        confirm this is intended.
        """
        # Ensure all inputs are 4D
        generated, content, style = (
            t.unsqueeze(0) if t.dim() == 3 else t
            for t in (generated, content, style)
        )

        # Extract features
        gen_features = self.feature_extractor(generated)
        content_features = self.feature_extractor(content)
        style_features = self.feature_extractor(style)

        # Content term: generated features should match the content features.
        content_loss = self.mse_loss(gen_features, content_features)

        # Style term: match second-order channel statistics (single layer).
        style_loss = self.mse_loss(
            self.gram_matrix(gen_features),
            self.gram_matrix(style_features),
        )

        return self.content_weight * content_loss + self.style_weight * style_loss

def train():
    """Train a TransformerNet with the settings in Config.

    Builds the paired dataset and loader, the transformer and its Adam
    optimizer, and the feature-based loss; then runs ``Config.epochs`` epochs.
    A checkpoint of the transformer weights is written every
    ``Config.save_every`` epochs and the final weights are saved at the end.
    """
    config = Config()
    print(f"Using device: {config.device}")
    target_device = config.device

    # Data: shuffled batches of (content, style) pairs.
    dataset = StyleTransferDataset(config.content_dir, config.style_dir, image_size=config.image_size)
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

    # Trainable network and its optimizer.
    transformer = TransformerNet().to(target_device)
    optimizer = optim.Adam(transformer.parameters(), lr=1e-3)

    # Loss built on the frozen feature extractor.
    criterion = StyleTransferLoss(FeatureExtractor().to(target_device))

    def run_step(content_batch, style_batch):
        """One optimization step; returns the loss tensor of this batch."""
        content_batch = content_batch.to(target_device)
        style_batch = style_batch.to(target_device)

        optimizer.zero_grad()
        stylized = transformer(content_batch)
        batch_loss = criterion(stylized, content_batch, style_batch)

        batch_loss.backward()
        optimizer.step()
        return batch_loss

    for epoch in range(1, config.epochs + 1):
        progress = tqdm(dataloader, desc=f"Epoch {epoch}/{config.epochs}")
        for content_imgs, style_imgs in progress:
            loss = run_step(content_imgs, style_imgs)
            progress.set_postfix(loss=f"{loss.item():.2f}")

        # Periodic checkpoint of the transformer weights.
        if epoch % config.save_every == 0:
            torch.save(transformer.state_dict(), f"checkpoint_epoch_{epoch}.pth")

    torch.save(transformer.state_dict(), "final_model.pth")


In [60]:
# Display the device Config picked when the class body was evaluated.
Config.device

device(type='mps')

In [None]:
# Start training with the settings in Config.
# NOTE(review): the recorded output for this cell ends in
# "NameError: name 'model' is not defined", yet the train() defined in this
# notebook never references a name `model` — the kernel was probably running a
# different or stale definition. Restart the kernel and run all cells to confirm.
train()

Using device: mps


0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
Epoch 1/1000:  29%|██▉       | 509/1760 [00:35<01:23, 15.07it/s, loss=0.01]

NameError: name 'model' is not defined