In [2]:
!pip install tensorflow-hub
import tensorflow_hub as hub
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
import cv2

Collecting tensorflow-hub
  Downloading tensorflow_hub-0.16.1-py2.py3-none-any.whl.metadata (1.3 kB)
Collecting tf-keras>=2.14.1 (from tensorflow-hub)
  Downloading tf_keras-2.17.0-py3-none-any.whl.metadata (1.6 kB)
Collecting protobuf>=3.19.6 (from tensorflow-hub)
  Using cached protobuf-4.25.5-cp310-abi3-win_amd64.whl.metadata (541 bytes)
Downloading tensorflow_hub-0.16.1-py2.py3-none-any.whl (30 kB)
Downloading tf_keras-2.17.0-py3-none-any.whl (1.7 MB)
   ---------------------------------------- 0.0/1.7 MB ? eta -:--:--
   ---------------------------------------- 0.0/1.7 MB ? eta -:--:--
    --------------------------------------- 0.0/1.7 MB 667.8 kB/s eta 0:00:03
   - -------------------------------------- 0.1/1.7 MB 656.4 kB/s eta 0:00:03
   - -------------------------------------- 0.1/1.7 MB 656.4 kB/s eta 0:00:03
   -- ------------------------------------- 0.1/1.7 MB 438.1 kB/s eta 0:00:04
   --- ------------------------------------ 0.1/1.7 MB 568.9 kB/s eta 0:00:03
   ---- ----

In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from PIL import Image


class StyleTransferNet(nn.Module):
    """Convolutional encoder-decoder mapping an RGB batch to an RGB batch.

    Encoder: a full-resolution 9x9 convolution (3 -> 64 channels) followed by two
    stride-2 3x3 convolutions (64 -> 128 -> 256), each followed by ReLU.
    Decoder: two stride-2 transposed convolutions (256 -> 128 -> 64) with ReLU,
    then a 9x9 convolution back to 3 channels squashed by Tanh.

    For inputs whose height and width are multiples of 4, the output has the
    same spatial size as the input.
    """

    def __init__(self):
        super().__init__()

        downsample = [
            nn.Conv2d(3, 64, kernel_size=9, stride=1, padding=4),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        ]
        upsample = [
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=9, stride=1, padding=4),
            # NOTE(review): Tanh yields values in [-1, 1], while the ToTensor-loaded
            # training images lie in [0, 1] — confirm this range is intended.
            nn.Tanh(),
        ]

        # The attribute names and layer order define the state_dict keys
        # (e.g. "encoder.0.weight"); keep them stable so saved weights still load.
        self.encoder = nn.Sequential(*downsample)
        self.decoder = nn.Sequential(*upsample)

    def forward(self, x):
        """Encode ``x`` to a 1/4-resolution, 256-channel map and decode it back to RGB."""
        return self.decoder(self.encoder(x))

# Function to extract features using VGG19
def get_features(x, model, layers):
    """Run ``x`` through ``model`` layer by layer and return selected activations.

    Every layer of ``model`` is applied in order, from the first up to the deepest
    requested index. The requested layers are not chained on their own.
    Previously, only the listed layers were applied, which skipped every
    intermediate ReLU/pool and crashed when a listed layer expected more input
    channels than the previous listed one produced.

    Args:
        x: input to the first layer of ``model``.
        model: an ordered, indexable stack of callables with ``len`` (for example
            an ``nn.Sequential`` such as VGG19's ``features``).
        layers: indices into ``model`` whose outputs are returned. Negative
            indices count from the end, as with list indexing. Duplicates allowed.

    Returns:
        A list with one tensor per entry of ``layers``, in the same order: the
        output of that layer after all preceding layers have been applied.

    Raises:
        IndexError: if an index lies outside ``model``.
    """
    n_layers = len(model)
    wanted = []
    for index in layers:
        resolved = index + n_layers if index < 0 else index
        if not 0 <= resolved < n_layers:
            raise IndexError(
                f"layer index {index} out of range for model with {n_layers} layers"
            )
        wanted.append(resolved)
    if not wanted:
        return []

    targets = set(wanted)
    deepest = max(wanted)
    captured = {}
    for position, layer in enumerate(model):
        x = layer(x)
        if position in targets:
            # Clone so a later in-place layer (e.g. ReLU(inplace=True)) cannot
            # overwrite the activation we are returning.
            captured[position] = x.clone()
        if position == deepest:
            break  # nothing deeper is needed
    return [captured[position] for position in wanted]

# Define the loss functions (content and style loss)
def content_loss(gen_features, content_features):
    """Mean squared error between generated and target activations (broadcasting allowed)."""
    residual = gen_features - content_features
    return residual.pow(2).mean()

def gram_matrix(feature_map):
    """Batched Gram matrix of a (batch, channels, height, width) feature map.

    Entry (i, j) of each (channels, channels) matrix is the inner product of
    channel i's and channel j's flattened activations, divided by
    channels * height * width so its scale does not grow with the layer size.
    """
    batch, channels, height, width = feature_map.size()
    flat = feature_map.view(batch, channels, -1)
    normaliser = channels * height * width
    return torch.bmm(flat, flat.transpose(1, 2)) / normaliser

def style_loss(gen_features, style_features):
    """Mean squared difference between the Gram matrices of two feature maps.

    The batch dimensions may differ when one side has batch size 1; the
    subtraction broadcasts.
    """
    gen_gram, style_gram = gram_matrix(gen_features), gram_matrix(style_features)
    delta = gen_gram - style_gram
    return delta.pow(2).mean()

# Load dataset
# Resize every image to 256x256 and convert it to a float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Content images in torchvision ImageFolder layout (images inside class
# sub-folders; the labels are not used for training). Set the CONTENT_DIR
# environment variable instead of editing a machine-specific absolute path.
CONTENT_DIR = os.environ.get("CONTENT_DIR", r"C:\Users\Sasindu\Downloads\style\content_images")
dataset = datasets.ImageFolder(root=CONTENT_DIR, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)


# Load the single style image as a (1, 3, 256, 256) batch. convert("RGB")
# guarantees 3 channels: a grayscale or RGBA file would otherwise produce a 1- or
# 4-channel tensor that VGG's 3-channel first convolution rejects. The `with`
# block closes the file handle once the pixels have been read.
STYLE_IMAGE_PATH = os.environ.get("STYLE_IMAGE", "style_image.jpg")
with Image.open(STYLE_IMAGE_PATH) as img:
    style_image = transform(img.convert("RGB")).unsqueeze(0)

# Train on the GPU when one is available; the style image must live on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
style_image = style_image.to(device)

# Load the pretrained VGG19 convolutional stack as a fixed feature extractor
# (eval mode, on the training device). It is only used to compute losses, so its
# parameters are frozen. Otherwise every backward() would also compute and
# accumulate gradients for VGG's weights, which the optimizer never updates.
# Gradients still flow through VGG's activations into the generator.
vgg = models.vgg19(pretrained=True).features.eval().to(device)
vgg.requires_grad_(False)

# Indices into VGG19's `features` stack (torchvision layout) whose activations
# drive the losses: 21 is conv4_2 for content; 0, 5, 10, 19 and 28 are conv1_1,
# conv2_1, conv3_1, conv4_1 and conv5_1 for style.
content_layer = [21]
style_layers = [0, 5, 10, 19, 28]

# Build the transformer network on the training device. The optimizer only
# receives its parameters, so VGG is never updated.
model = StyleTransferNet()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

# Training loop
# Number of passes over the content images, and the weight of the style term
# relative to the content term.
num_epochs = 2
STYLE_WEIGHT = 1e5

# NOTE(review): images reach VGG as raw [0, 1] tensors (and the generator's Tanh
# output lies in [-1, 1]), whereas ImageNet-pretrained VGG weights expect
# mean/std-normalised inputs. Consider normalising both before extraction.

# The style image never changes, so extract its target activations once,
# outside the loop and without building an autograd graph.
with torch.no_grad():
    style_features = get_features(style_image, vgg, style_layers)

n_content = len(content_layer)
for epoch in range(num_epochs):
    for content_images, _ in dataloader:
        content_images = content_images.to(device)

        # Forward pass through the model
        generated_images = model(content_images)

        # Content targets come from the real images and need no gradient.
        with torch.no_grad():
            content_features = get_features(content_images, vgg, content_layer)

        # A single VGG pass over the generated images gives both the content and
        # the style activations. They stay in the graph so both losses
        # back-propagate into `model`.
        gen_features = get_features(generated_images, vgg, content_layer + style_layers)
        gen_content, gen_style = gen_features[:n_content], gen_features[n_content:]

        # Content: compare VGG activations (the original compared raw pixels with
        # conv4_2 activations, whose shapes do not even match).
        # Style: sum Gram-matrix losses over every style layer, not just the first.
        c_loss = sum(content_loss(g, t) for g, t in zip(gen_content, content_features))
        s_loss = sum(style_loss(g, t) for g, t in zip(gen_style, style_features))
        total_loss = c_loss + STYLE_WEIGHT * s_loss

        # Backpropagation and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss.item():.4f}")

# Save the trained weights (state_dict only, not the module object).
torch.save(model.state_dict(), 'style_transfer_model.pth')


Found 161 images belonging to 1 classes.


  self._warn_if_super_not_called()


Epoch 1/10
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m611s[0m 121s/step - loss: 0.6763
Epoch 2/10
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m587s[0m 95s/step - loss: 0.3029
Epoch 3/10
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m533s[0m 105s/step - loss: 0.1445
Epoch 4/10
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m546s[0m 87s/step - loss: 0.0594
Epoch 5/10
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m571s[0m 91s/step - loss: 0.0433
Epoch 6/10
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m543s[0m 86s/step - loss: 0.0362
Epoch 7/10
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m542s[0m 89s/step - loss: 0.0321
Epoch 8/10
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m545s[0m 88s/step - loss: 0.0309
Epoch 9/10
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m510s[0m 81s/step - loss: 0.0306
Epoch 10/10
[1m6/6[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m609s[0m 103s/step - loss: 0.0279




In [4]:
import shutil  # Import shutil
import os
from IPython.display import FileLink
# Copy the trained weights to a downloadable folder. The training cell saves
# 'style_transfer_model.pth'; the old '.h5' name was left over from a Keras run
# and does not exist after Restart & Run All. Set EXPORT_DIR to override the
# default of the current user's Downloads folder.
MODEL_PATH = 'style_transfer_model.pth'
EXPORT_DIR = os.environ.get('EXPORT_DIR', os.path.join(os.path.expanduser('~'), 'Downloads'))
# Without this, shutil.copy to a missing directory would create a *file* with that name.
os.makedirs(EXPORT_DIR, exist_ok=True)
shutil.copy(MODEL_PATH, EXPORT_DIR)

# Create a download link
FileLink(MODEL_PATH)