In [6]:
import os
os.environ["TORCHDYNAMO_DISABLE"] = "1"

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from PIL import Image
from PIL import ImageFilter
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import sys

In [7]:
# TransformerNEt Module
class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        padding = kernel_size // 2
        self.layer = nn.Sequential(
            nn.ReflectionPad2d(padding),
            nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        )

    def forward(self, x):
        return self.layer(x)

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            ConvLayer(channels, channels, 3, 1),
            nn.InstanceNorm2d(channels, affine=True),
            nn.ReLU(inplace=True),
            ConvLayer(channels, channels, 3, 1),
            nn.InstanceNorm2d(channels, affine=True),
        )

    def forward(self, x):
        return x + self.block(x)

class TransformerNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            ConvLayer(3, 32, 9, 1),
            nn.InstanceNorm2d(32),
            nn.ReLU(inplace=True),
            ConvLayer(32, 64, 3, 2),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            ConvLayer(64, 128, 3, 2),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(128) for _ in range(5)]
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(32),
            nn.ReLU(inplace=True),
            ConvLayer(32, 3, 9, 1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.res_blocks(x)
        x = self.decoder(x)
        return x

In [9]:
# VGG Feature Extractor
class VGGFeatures(nn.Module):
    def __init__(self, content_layers, style_layers):
        super().__init__()
        vgg = models.vgg16(pretrained=True).features[:23]
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg.eval()
        self.content_layers = content_layers
        self.style_layers = style_layers

    def forward(self, x):
        features = {}
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in self.content_layers + self.style_layers:
                features[name] = x
        return features

def gram_matrix(feat):
    B, C, H, W = feat.size()
    F = feat.view(B, C, H * W)
    G = torch.bmm(F, F.transpose(1, 2))
    return G / (C * H * W)

def total_variation_loss(img):
    return torch.sum(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:])) + \
           torch.sum(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]))

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = 256

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),  # Converts to [0, 1]
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# os.mkdir("/kaggle/working/output")
# os.mkdir("/kaggle/working/output/model4")
content_dataset = datasets.ImageFolder("/kaggle/input/mcoco-1500/mini_coco_sample", transform=transform)
content_loader = DataLoader(content_dataset, batch_size=4)

# Load style image

style_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

style_img_path = "/kaggle/input/starry-night-vangogh/1024px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg"#"/kaggle/input/oil-boat-monet/boat.jpg"
style_image = Image.open(style_img_path).convert("L").convert("RGB").crop((300, 300, 700, 700))#.filter(ImageFilter.GaussianBlur(radius=2))
style_image = style_transform(style_image).unsqueeze(0).to(device)


# Extract style features
vgg = VGGFeatures(content_layers=["21"], style_layers=["0", "5", "10", "19", "21"]).to(device)
style_features = vgg(style_image)
# style_grams = {k: gram_matrix(v) for k, v in style_features.items() if k in vgg.style_layers}
style_grams = {k: gram_matrix(v) for k, v in style_features.items()}

# Initialize model
model = TransformerNet().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Loss weights
style_weight = 1e5
content_weight = 1e0
tv_weight = 1e-6
identity_weight = 1

# Training loop
for epoch in range(50): 
    for i, (content_batch, _) in enumerate(content_loader):
        content_batch = content_batch.to(device)
        output = model(content_batch)

        output_features = vgg(output)
        content_features = vgg(content_batch)

        content_loss = torch.nn.functional.mse_loss(output_features["21"], content_features["21"])
        style_loss = 0
        for layer in vgg.style_layers:
            output_gram = gram_matrix(output_features[layer])
            style_gram = style_grams[layer].expand_as(output_gram)
            style_loss += torch.nn.functional.mse_loss(output_gram, style_gram)

        tv_loss = total_variation_loss(output)
        
        identity = model(content_batch)
        identity_loss = torch.nn.functional.mse_loss(identity, content_batch)

        loss = content_weight * content_loss + style_weight * style_loss + tv_weight * tv_loss 
        # if epoch > 2:
            # loss += identity_weight * identity_loss
            
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 50 == 0:
            print(f"Epoch {epoch}, Batch {i}, Loss: {loss.item():.2f}")
            # save_image(output[0].cpu().div(255), f"/kaggle/working/output/model2/output_{epoch}_{i}.jpg")
            # show_tensor_image(output[0].unsqueeze(0))

# Save final model
torch.save(model.state_dict(), "output/model4/oil_painting_model14.pth")

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 227MB/s]  


Epoch 0, Batch 0, Loss: 33.12
Epoch 0, Batch 50, Loss: 18.16
Epoch 0, Batch 100, Loss: 16.24
Epoch 0, Batch 150, Loss: 12.61
Epoch 0, Batch 200, Loss: 12.39
Epoch 0, Batch 250, Loss: 8.10
Epoch 0, Batch 300, Loss: 7.04
Epoch 0, Batch 350, Loss: 6.92
Epoch 1, Batch 0, Loss: 5.52
Epoch 1, Batch 50, Loss: 7.03
Epoch 1, Batch 100, Loss: 7.06
Epoch 1, Batch 150, Loss: 4.85
Epoch 1, Batch 200, Loss: 6.01
Epoch 1, Batch 250, Loss: 4.52
Epoch 1, Batch 300, Loss: 4.34
Epoch 1, Batch 350, Loss: 5.38
Epoch 2, Batch 0, Loss: 4.29
Epoch 2, Batch 50, Loss: 6.19
Epoch 2, Batch 100, Loss: 6.34
Epoch 2, Batch 150, Loss: 3.98
Epoch 2, Batch 200, Loss: 5.47
Epoch 2, Batch 250, Loss: 4.15
Epoch 2, Batch 300, Loss: 4.01
Epoch 2, Batch 350, Loss: 5.04
Epoch 3, Batch 0, Loss: 3.97
Epoch 3, Batch 50, Loss: 5.90
Epoch 3, Batch 100, Loss: 6.05
Epoch 3, Batch 150, Loss: 3.76
Epoch 3, Batch 200, Loss: 5.23
Epoch 3, Batch 250, Loss: 3.97
Epoch 3, Batch 300, Loss: 3.84
Epoch 3, Batch 350, Loss: 4.86
Epoch 4, Batch 

In [22]:
# Stylize the Image.

def load_image(image_path):
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0)  # Add batch dimension
    return image

def save_output(tensor, path):
    image = tensor.clone().detach().cpu().squeeze(0)
    image = (image + 1) / 2  # From [-1, 1] to [0, 1]
    image = image.clamp(0, 1)
    image = transforms.ToPILImage()(image)
    image.save(path)


MODEL_PATH = "/kaggle/working/output/model4/oil_painting_model11.pth"    # Path to trained model
INPUT_IMAGE = "/kaggle/input/dog-test-img/Golden-retriever-dog-1362597631o6g.jpg" # Content Image path
OUTPUT_IMAGE = "/kaggle/working/output/model4/stylized_img36.jpg"            # Output image path
IMAGE_SIZE = 256 

# Load image and preprocess
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Load model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerNet().to(device)
model.load_state_dict(torch.load(MODEL_PATH, weights_only=True)) #, weights_only=True
model.eval()

# Stylize Output.
content_image = load_image(INPUT_IMAGE).to(device)
with torch.no_grad():
    output = model(content_image)

save_output(output, OUTPUT_IMAGE)
print(f"Stylized image saved to: {OUTPUT_IMAGE}")
# plt.imshow(output)

Stylized image saved to: /kaggle/working/output/model4/stylized_img35.jpg


In [None]:
# Display Output Image
# Load the JPG image
img = mpimg.imread('/kaggle/working/output/model4/stylized_img27_dog.jpg')

# Display the image
plt.imshow(img)
plt.axis('off') 
plt.show()