<a href="https://colab.research.google.com/github/Shrinjita/Style-transfer-for-rooms/blob/main/photorealistic_style_transfer_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Step 1: Install Dependencies

In [1]:
!pip install tensorflow-datasets torch torchvision pillow matplotlib



Step 2: Import Libraries

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader
from PIL import Image
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


Step 3: Load Pre-trained VGG-19 for Feature Extraction

In [3]:
class VGGEncoder(nn.Module):
    """VGG-19 feature extractor truncated after relu4_1.

    ``models.vgg19().features`` children 0-20 run conv1_1 ... relu4_1. This span
    contains three 2x2 max-pools, so the output has 512 channels and is spatially
    downsampled by 8.

    Args:
        pretrained: load the ImageNet weights (default True, matching the previous
            behaviour). Pass False to build a randomly initialised encoder without
            downloading anything.
        freeze: disable gradients for the encoder weights (default True). Only the
            decoder is meant to be optimised, so computing encoder gradients is wasted
            work and memory.
    """

    def __init__(self, pretrained=True, freeze=True):
        super(VGGEncoder, self).__init__()
        # `pretrained=True` is deprecated in torchvision>=0.13. IMAGENET1K_V1 is the
        # weight set that flag used to select.
        weights = models.VGG19_Weights.IMAGENET1K_V1 if pretrained else None
        vgg = models.vgg19(weights=weights).features
        self.encoder = nn.Sequential(*list(vgg.children())[:21])  # Up to relu_4_1
        if freeze:
            for param in self.encoder.parameters():
                param.requires_grad_(False)

    def forward(self, x):
        """Return relu4_1 feature maps for a batch of images."""
        return self.encoder(x)

Step 4: Define Decoder Network (maps 512-channel relu4_1 features back to a 3-channel image)

In [4]:
class Decoder(nn.Module):
    """Map 512-channel relu4_1 VGG-19 features back to a 3-channel image.

    The encoder downsamples by 8 (three max-pools before relu4_1). The decoder
    therefore upsamples by 2 three times so the output matches the input
    resolution. Without this, the output would be 1/8 the size of the input image,
    and a pixel-wise reconstruction loss against that image could not be computed.

    Nearest-neighbour upsampling followed by a convolution is used instead of a
    strided transposed convolution. This combination is commonly preferred because
    it reduces checkerboard artifacts.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        """Decode (B, 512, h, w) features into a (B, 3, 8h, 8w) image tensor."""
        return self.decoder(x)

Step 5: Style Transfer Model (encoder → per-channel mean/std feature matching → decoder)

In [5]:
class StyleTransferModel(nn.Module):
    """Encoder -> bottleneck feature-statistics transfer -> decoder.

    Args:
        encoder: module mapping images to feature maps shaped (B, C, H, W).
        decoder: module mapping feature maps back to images.
    """

    def __init__(self, encoder, decoder):
        super(StyleTransferModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, content_img, style_img):
        """Re-render content_img using the feature statistics of style_img."""
        content_features = self.encoder(content_img)
        style_features = self.encoder(style_img)

        # Feature Aggregation at the Bottleneck
        styled_features = self.whitening_coloring_transform(content_features, style_features)

        return self.decoder(styled_features)

    def whitening_coloring_transform(self, content_feat, style_feat, eps=1e-5):
        """Give content_feat the per-channel mean and std of style_feat.

        NOTE: despite its name, this is the diagonal (AdaIN-style) variant. It matches
        only per-channel mean and standard deviation, not the full channel covariance
        that a true whitening-coloring transform would match.

        Args:
            content_feat: (B, C, H, W) content feature maps.
            style_feat: (B, C, H', W') or (1, C, H', W') style feature maps. The spatial
                size may differ from the content features.
            eps: added to the content std so that constant channels (e.g. all-zero
                after a ReLU) do not divide by zero and turn into NaNs.

        Returns:
            Tensor with the same shape as content_feat.
        """
        # keepdim=True keeps the statistics as (B, C, 1, 1) so they broadcast over H, W.
        c_mean = content_feat.mean(dim=(2, 3), keepdim=True)
        c_std = content_feat.std(dim=(2, 3), keepdim=True)
        s_mean = style_feat.mean(dim=(2, 3), keepdim=True)
        s_std = style_feat.std(dim=(2, 3), keepdim=True)

        normalized_content = (content_feat - c_mean) / (c_std + eps)
        stylized_content = normalized_content * s_std + s_mean

        return stylized_content


Step 6: Image Preprocessing

In [6]:
def process_image(img_pil, max_size=512):
    """Convert a PIL image into an ImageNet-normalized, batched tensor on `device`.

    Args:
        img_pil: PIL image in any mode (RGB, L, RGBA, ...). It is converted to RGB first.
        max_size: side length of the square output. The image is resized to exactly
            max_size x max_size, so its aspect ratio is NOT preserved.

    Returns:
        Float tensor of shape (1, 3, max_size, max_size) on the global `device`.
    """
    transform = transforms.Compose([
        transforms.Resize((max_size, max_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Normalize below uses 3-channel statistics. Grayscale (L) or RGBA inputs would
    # fail there, so force RGB first.
    img_tensor = transform(img_pil.convert("RGB")).unsqueeze(0)
    return img_tensor.to(device)

def im_convert(tensor):
    """Turn a normalized (1, 3, H, W) image tensor into an (H, W, 3) array for imshow.

    Undoes the ImageNet mean/std normalization and clips the result to [0, 1].
    The input tensor is copied and left untouched.
    """
    chw = tensor.detach().cpu().clone().numpy().squeeze(0)
    hwc = chw.transpose(1, 2, 0)
    std = [0.229, 0.224, 0.225]
    mean = [0.485, 0.456, 0.406]
    return (hwc * std + mean).clip(0, 1)


Step 7: Load Content and Style Images from the MS-COCO 2017 Test Split (the first `tfds.load` call downloads the full ~25 GiB dataset)

In [None]:
def get_coco_images(dataset, num_images=2):
    """Return the first `num_images` samples of a TFDS dataset as PIL images.

    Args:
        dataset: dataset supporting `.take(n)` whose elements are dicts holding an
            'image' tensor (as returned by `tfds.load` here).
        num_images: how many leading samples to take.

    Returns:
        list of PIL images, in dataset order.
    """
    return [Image.fromarray(sample['image'].numpy()) for sample in dataset.take(num_images)]

# Load the MS-COCO 2017 test split (the first run downloads the whole dataset).
dataset, info = tfds.load('coco/2017', split='test', with_info=True)

# The first sample serves as the content image, the second as the style image.
images = get_coco_images(dataset, num_images=2)
content_img_pil = images[0]
style_img_pil = images[1]

# Model-ready, normalized tensors on `device`.
content_img = process_image(content_img_pil).to(device)
style_img = process_image(style_img_pil).to(device)

# Show the raw PIL images side by side.
fig, (ax_content, ax_style) = plt.subplots(1, 2, figsize=(10, 5))
for ax, pil_img, title in ((ax_content, content_img_pil, "Content Image"),
                           (ax_style, style_img_pil, "Style Image")):
    ax.imshow(pil_img)
    ax.set_title(title)

plt.show()




Downloading and preparing dataset 25.20 GiB (download: 25.20 GiB, generated: Unknown size, total: 25.20 GiB) to /root/tensorflow_datasets/coco/2017/1.1.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Step 8: Define the Decoder Training Loop

In [None]:
def train_model(model, dataset, transform, content_img, style_img, num_epochs=10,
                steps_per_epoch=16, lr=1e-4):
    """Train the decoder as an image autoencoder on samples from `dataset`.

    For each dataset image, features are taken from `model.encoder` without
    gradients. They are decoded by `model.decoder` and compared to the input image
    with a pixel-wise MSE, computed in whatever space `transform` produces. Only
    decoder parameters are optimised.

    Args:
        model: object exposing `.encoder` and `.decoder` modules (e.g. StyleTransferModel).
        dataset: dataset supporting `.take(n)` whose elements are dicts with an
            'image' tensor.
        transform: callable mapping a PIL image to an image tensor. It may return
            (C, H, W) or an already batched (1, C, H, W) tensor.
        content_img, style_img: kept for backward compatibility and not used for
            training. The decoder learns reconstruction only. Fitting it to a fixed
            stylized pair with the content image as the target would teach it to
            undo the style.
        num_epochs: number of passes over the first `steps_per_epoch` samples.
        steps_per_epoch: number of dataset samples per epoch (limits training size).
        lr: Adam learning rate.

    Returns:
        list with the mean training loss of each epoch (NaN if an epoch saw no samples).
    """
    optimizer = optim.Adam(model.decoder.parameters(), lr=lr)
    model.decoder.train()
    epoch_losses = []

    for epoch in range(num_epochs):
        running_loss, num_steps = 0.0, 0
        for batch in dataset.take(steps_per_epoch):  # Limit dataset size for training
            image_pil = Image.fromarray(batch['image'].numpy())
            images = transform(image_pil)
            if images.dim() == 3:  # add a batch dimension only if transform did not
                images = images.unsqueeze(0)
            images = images.to(device)

            # The encoder is a fixed feature extractor; no gradients are needed there.
            with torch.no_grad():
                features = model.encoder(images)
            reconstructed = model.decoder(features)

            loss = ((reconstructed - images) ** 2).mean()  # Reconstruction loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_steps += 1

        # Guard against an empty dataset, which would otherwise leave `loss` undefined.
        epoch_loss = running_loss / num_steps if num_steps else float('nan')
        epoch_losses.append(epoch_loss)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}')

    return epoch_losses

Step 9: Define the Result-Visualization Helper

In [None]:
def visualize_results(model, content_img, style_img):
    """Stylise content_img with style_img and plot input, style and output in one row.

    Puts `model` into eval mode (and leaves it there) and runs the forward pass
    without autograd.
    """
    model.eval()
    with torch.no_grad():
        stylized = model(content_img, style_img).detach()

    panels = [
        (content_img, "Content Image"),
        (style_img, "Style Image"),
        (stylized, "Styled Image"),
    ]

    plt.figure(figsize=(10, 10))
    for position, (tensor, label) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(im_convert(tensor))
        plt.title(label)

    plt.show()


Step 10: Build the Model and Train the Decoder

In [None]:
# Pretrained encoder (in eval mode) plus a freshly initialised decoder.
encoder = VGGEncoder().to(device).eval()
decoder = Decoder().to(device)
model = StyleTransferModel(encoder, decoder).to(device)

# Train the model
train_model(
    model,
    dataset,
    process_image,
    content_img,
    style_img,
    num_epochs=5,
)

Step 11: Visualize the Final Styled Image

In [None]:
# Render content, style and stylised output side by side using the trained model.
visualize_results(model, content_img=content_img, style_img=style_img)