In [1]:
! pip install torch torchvision numpy matplotlib pillow




[notice] A new release of pip available: 22.3 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os

In [3]:
# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Paths of the two input images used by the rest of the notebook.
style_img_path = 'Pictures/Pic1.jpg'
content_img_path = 'Pictures/Pic2.jpg'

Image Processing Functions

In [4]:
def load_image(img_path, max_size=512):
    """Load an image file as a batched RGB tensor on the module-level ``device``.

    The image is converted to RGB. If its longer side exceeds ``max_size`` it
    is downscaled with Lanczos resampling so that the longer side equals
    ``max_size`` and the aspect ratio is (up to integer truncation) preserved.

    Args:
        img_path: Path of the image file to read.
        max_size: Maximum allowed length, in pixels, of the longer image side.

    Returns:
        A ``(img_tensor, size)`` tuple. ``img_tensor`` is the ``ToTensor``
        result with a leading batch dimension added, moved to ``device``.
        ``size`` is the PIL ``(width, height)`` of the image after any resize.

    Raises:
        FileNotFoundError: If ``img_path`` does not point to an existing file.
    """
    if not os.path.isfile(img_path):
        raise FileNotFoundError(f"Cannot find image file: {img_path}")

    # convert() decodes the pixels into a new in-memory image, so the file
    # handle can be released immediately instead of lingering until GC.
    with Image.open(img_path) as source:
        image = source.convert('RGB')

    width, height = image.size
    if max(width, height) > max_size:
        # The shorter side is truncated to an int; clamp it to 1 so that very
        # elongated images do not request a zero-sized (invalid) resize.
        if width > height:
            new_size = (max_size, max(1, int(height * max_size / width)))
        else:
            new_size = (max(1, int(width * max_size / height)), max_size)
        image = image.resize(new_size, Image.LANCZOS)

    # ToTensor turns the PIL image into a float tensor; unsqueeze adds the batch axis.
    img_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)
    return img_tensor, image.size

In [5]:
def tensor_to_image(tensor):
    """Convert a single-image tensor into a channel-last NumPy array.

    The tensor is detached, copied to host memory, stripped of its leading
    axis (which must have length 1), reordered from axes ``(0, 1, 2)`` to
    ``(1, 2, 0)`` and clipped to the range ``[0, 1]``. The input tensor is
    left unmodified.
    """
    host_copy = tensor.detach().cpu().clone().numpy()
    channel_first = host_copy.squeeze(0)
    channel_last = np.transpose(channel_first, (1, 2, 0))
    return np.clip(channel_last, 0, 1)

In [6]:
def save_image(tensor, filename):
    """Save an image tensor to ``filename`` as an 8-bit image.

    The tensor is first converted with ``tensor_to_image`` (channel-last,
    clipped to ``[0, 1]``), scaled to ``[0, 255]`` and written with Pillow.

    Args:
        tensor: Image tensor accepted by ``tensor_to_image``.
        filename: Destination path handed to ``PIL.Image.Image.save``.
    """
    image = tensor_to_image(tensor)
    # Round to the nearest 8-bit level. A bare astype() truncates, so a value
    # that lands a hair below an integer (e.g. 127.99999 from float error)
    # would otherwise be stored one level too dark.
    image = (image * 255 + 0.5).astype(np.uint8)
    Image.fromarray(image).save(filename)

IEST Model Definition

In [7]:
class InstanceNorm(nn.Module):
    """Parameter-free normalisation of each (sample, channel) feature map.

    Every map is shifted by its own mean and divided by the square root of
    its variance plus ``epsilon``; both statistics are taken over dims 2 and 3
    using ``torch.var``'s default estimator.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        # Added to the variance before the square root to keep the
        # denominator away from zero for constant feature maps.
        self.epsilon = epsilon

    def forward(self, x):
        """Return ``x`` standardised over dims 2 and 3; the shape is unchanged."""
        spatial_dims = (2, 3)
        centered = x - x.mean(dim=spatial_dims, keepdim=True)
        variance = x.var(dim=spatial_dims, keepdim=True) + self.epsilon
        return centered / variance.sqrt()

In [8]:
class IESTTransferModule(nn.Module):
    """Blend instance-level and channel-level style information into content features.

    Two stylised versions of the content features are computed and mixed:

    * an instance-level one: the content features are normalised per channel
      and then re-scaled and shifted with the style's per-channel std and mean;
    * a channel-level one: the style channels are re-mixed with weights given
      by a softmax over the cosine similarity between each content channel
      and every style channel.

    Args:
        alpha: Weight of the instance-level term; the channel-level term is
            weighted by ``1 - alpha``. Defaults to ``0.6``.
    """

    def __init__(self, alpha=0.6):
        super(IESTTransferModule, self).__init__()
        # Not used by forward(); kept so the attribute remains available to
        # any existing code that accesses it.
        self.norm = InstanceNorm()
        self.alpha = alpha

    def forward(self, content_feat, style_feat):
        """Return the blended, stylised features.

        Args:
            content_feat: Content feature maps, indexed as
                ``(batch, channels, H, W)``.
            style_feat: Style feature maps. They are flattened using the
                content's batch and channel counts, and the flattened spatial
                size must equal the content's because the channel correlation
                is a batched matrix product over that axis.

        Returns:
            A tensor with the same shape as ``content_feat``.
        """
        stat_dims = (2, 3)
        c_mean = torch.mean(content_feat, dim=stat_dims, keepdim=True)
        c_std = torch.std(content_feat, dim=stat_dims, keepdim=True) + 1e-8
        s_mean = torch.mean(style_feat, dim=stat_dims, keepdim=True)
        s_std = torch.std(style_feat, dim=stat_dims, keepdim=True) + 1e-8

        # Instance-level branch: swap the content's channel statistics for the style's.
        c_normalized = (content_feat - c_mean) / c_std
        instance_enhanced = c_normalized * s_std + s_mean

        batch_size = content_feat.size(0)
        channels = content_feat.size(1)

        # reshape() rather than view(): view() raises on non-contiguous
        # inputs (e.g. transposed or sliced feature maps), reshape() copies
        # only when it has to.
        content_flat = content_feat.reshape(batch_size, channels, -1)
        style_flat = style_feat.reshape(batch_size, channels, -1)

        # Channel-level branch: unit-length rows make the product a cosine
        # similarity between content channels (rows) and style channels (cols).
        content_flat_norm = F.normalize(content_flat, dim=2)
        style_flat_norm = F.normalize(style_flat, dim=2)
        correlation = torch.bmm(content_flat_norm, style_flat_norm.transpose(1, 2))
        # Each content channel gets a weight distribution over the style channels.
        correlation = F.softmax(correlation, dim=2)

        channel_enhanced = torch.bmm(correlation, style_flat).reshape(content_feat.shape)

        return self.alpha * instance_enhanced + (1 - self.alpha) * channel_enhanced

In [9]:
class ConvLayer(nn.Module):
    """A 2-D convolution immediately followed by an in-place ReLU.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding. Defaults to ``0``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        # In-place activation: overwrites the convolution output buffer.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the convolution to ``x`` and rectify the result."""
        features = self.conv(x)
        return self.relu(features)

In [None]:
IEST Model Definition
