In [1]:
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torchvision import transforms
import random
import os
from PIL import Image
import numpy as np
import torchvision.transforms.functional as F

# Model

In [2]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.norm1 = nn.InstanceNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding)
        self.norm2 = nn.InstanceNorm2d(out_channels)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.norm2(out)
        out = self.relu(out)

        #out += residual  # Add the residual
        return out

In [3]:
class AdvancedStyleTransferModel(nn.Module):
    def __init__(self):
        super(AdvancedStyleTransferModel, self).__init__()

        # Initial convolution input image size=3x64x64
        self.conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=1, padding=2)  # 32x64x64
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(32, 128, kernel_size=2, stride=1, padding=1)  # 128x32x32
        self.relu2 = nn.ReLU()
    

        self.conv_transpose1 = nn.ConvTranspose2d(128, 32, kernel_size=2, stride=1, padding=1)  # 32x64x64
        self.relu3 = nn.ReLU()

        self.conv_transpose2 = nn.ConvTranspose2d(32, 3, kernel_size=4, stride=1, padding=2)
        self.relu4 = nn.ReLU()
        
        # Residual blocks
        self.residual_blocks = ResidualBlock(32, 32)
                                        
        
        

    def forward(self, x):
        resu = x
        out = self.relu1(self.conv1(x))
        #out = self.batch_norm1(out)
        out = self.residual_blocks(out)
        out = self.relu2(self.conv2(out))
        #out = self.batch_norm2(out)
        #out = self.residual_blocks(out)
        out = self.relu3(self.conv_transpose1(out))
        out = self.conv_transpose2(out)
        #out = self.batch_norm2(out)
        # out += resu
        return out

In [4]:
model_url=r'models\model_final_train_loss_0.1172_val_loss_0.1181.pth'
test_folder_path=r'dataset\test'

In [5]:
model_dict = torch.load(model_url, map_location=torch.device('cpu'))
model = AdvancedStyleTransferModel()
model.load_state_dict(model_dict)

<All keys matched successfully>

# cut image

In [6]:
input_image_url=r'C:\Users\prabh\Desktop\pexels-suneo-103573.jpg'
input_image=Image.open(input_image_url).convert('RGB')

In [7]:
patch_size = [256, 256]

In [8]:
def de_norm(img):
  denormalize = transforms.Normalize(mean=[-m/s for m, s in zip(mean, std)], std=[1/s for s in std])
  return denormalize(img)

In [9]:
mean = [0.485, 0.456, 0.406]  # RGB mean values
std = [0.229, 0.224, 0.225]
normalize=transforms.Normalize(
    mean=mean, std=std
)

transform_test=transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    normalize
])

In [14]:
def delete_files_in_folder(folder_path):
    # Get a list of all files in the folder
    file_list = os.listdir(folder_path)

    # Iterate through the files and delete them
    for file_name in file_list:
        file_path = os.path.join(folder_path, file_name)
        try:
            if os.path.isfile(file_path):
                os.remove(file_path)
               
        except Exception as e:
            continue

In [10]:
output_folder = "output_patches"

# Create the output folder if it doesn't exist
os.makedirs(output_folder, exist_ok=True)

# Transformation to convert tensor to PIL Image
to_pil = transforms.ToPILImage()

# Loop through the patches and save each stylized patch to the output folder
for i in range(0, input_image.height, patch_size[1]):
    for j in range(0, input_image.width, patch_size[0]):
        patch = input_image.crop((j, i, j + patch_size[0], i + patch_size[1]))
        patch = transform_test(patch)
        stylized_patch = model(patch)
        stylized_patch = de_norm(stylized_patch)
        
        # Convert the tensor to a PIL Image
        stylized_patch_pil = to_pil(stylized_patch)

        # Specify the filename for each patch based on its position
        patch_filename = f"patch_{i}_{j}.png"
        
        # Save the stylized patch to the output folder
        stylized_patch_pil.save(os.path.join(output_folder, patch_filename))

# Now, you can assemble the cut images from the output_folder if needed
# The code to assemble the images remains the same as in your original code
output_image = Image.new("RGB", input_image.size)
patch_index = 0
for i in range(0, input_image.height, patch_size[1]):
    for j in range(0, input_image.width, patch_size[0]):
        patch_filename = f"patch_{i}_{j}.png"
        patch_path = os.path.join(output_folder, patch_filename)
        patch = Image.open(patch_path)
        output_image.paste(patch, (j, i))
        patch_index += 1

# Save the final assembled image
output_image.save("final_assembled_image.png")
delete_files_in_folder(r'output_patches')

In [13]:
delete_files_in_folder(r'output_patches')

Deleted: output_patches\patch_0_0.png
Deleted: output_patches\patch_0_1024.png
Deleted: output_patches\patch_0_1280.png
Deleted: output_patches\patch_0_1536.png
Deleted: output_patches\patch_0_1792.png
Deleted: output_patches\patch_0_2048.png
Deleted: output_patches\patch_0_2304.png
Deleted: output_patches\patch_0_256.png
Deleted: output_patches\patch_0_2560.png
Deleted: output_patches\patch_0_2816.png
Deleted: output_patches\patch_0_3072.png
Deleted: output_patches\patch_0_3328.png
Deleted: output_patches\patch_0_3584.png
Deleted: output_patches\patch_0_512.png
Deleted: output_patches\patch_0_768.png
Deleted: output_patches\patch_1024_0.png
Deleted: output_patches\patch_1024_1024.png
Deleted: output_patches\patch_1024_1280.png
Deleted: output_patches\patch_1024_1536.png
Deleted: output_patches\patch_1024_1792.png
Deleted: output_patches\patch_1024_2048.png
Deleted: output_patches\patch_1024_2304.png
Deleted: output_patches\patch_1024_256.png
Deleted: output_patches\patch_1024_2560.png