In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# --------- Settings ---------
# Locations of the two input images and of the stylized result.
content_img_path = "E:/CV/content.jpg"
style_img_path = "E:/CV/style.jpg"
output_img_path = "E:/CV/output.jpg"

# All tensors and the network are placed on this device.
device = torch.device("cpu")

# Both images are resized to a square of this side length (in pixels).
image_size = 256

# Number of optimizer iterations applied to the target image.
num_steps = 1000

# Multiplier applied to the total-variation term of the loss.
tv_weight = 1e-6

# --------- Preprocessing ---------
# Steps, applied in order:
#   1. resize to an (image_size x image_size) square,
#   2. convert to a tensor,
#   3. keep only the first three channels,
#   4. normalize every channel as (x - 0.5) / 0.5.
_preprocess_steps = [
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: t[:3, :, :]),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocess_steps)

def load_image(path):
    """Load an image file and return it as a preprocessed batch tensor.

    Args:
        path: Location of the image file, passed straight to ``Image.open``.

    Returns:
        The image converted to RGB, run through the module-level
        ``transform`` pipeline, given a leading batch dimension of size 1
        and moved to the module-level ``device``.
    """
    # The context manager releases the underlying file handle as soon as
    # the pixel data has been converted, even if decoding raises.
    with Image.open(path) as raw:
        rgb = raw.convert('RGB')
    batch = transform(rgb).unsqueeze(0)
    return batch.to(device)

def im_convert(tensor):
    """Turn a normalized image tensor back into a PIL image.

    The input is copied and detached from autograd, so the caller's tensor
    is left untouched. Dimension 0 is squeezed away (when it has size 1),
    the values are mapped through ``x * 0.5 + 0.5`` to undo the
    normalization, clamped to ``[0, 1]`` and handed to ``ToPILImage``.

    Args:
        tensor: Image tensor to convert.

    Returns:
        The PIL image produced by ``transforms.ToPILImage``.
    """
    snapshot = tensor.detach().clone()
    # Undo the (x - 0.5) / 0.5 normalization, then cut off out-of-range values.
    pixels = snapshot.squeeze(0).mul(0.5).add(0.5).clamp(0, 1)
    to_pil = transforms.ToPILImage()
    return to_pil(pixels)

# --------- Slightly Modified CNN ---------
class _ResidualBlock(nn.Sequential):
    """Stack of layers whose output is added back onto its input.

    Subclassing ``nn.Sequential`` keeps the child modules registered under
    the same names (``0``, ``1``, ``2``) as a plain ``nn.Sequential`` would,
    so state-dict keys are unaffected by the added skip connection.
    """

    def forward(self, x):
        # Skip connection: this is what makes the block "residual".
        return x + super().forward(x)


class SlightlyModifiedCNN(nn.Module):
    """Convolutional encoder / residual / decoder network.

    Layout of ``self.model``:
      * a 9x9 convolution followed by two stride-2 3x3 convolutions
        (3 -> 32 -> 64 -> 128 channels), each followed by ReLU;
      * five residual blocks operating on 128 channels;
      * two stride-2 transposed convolutions (128 -> 64 -> 32 channels),
        each followed by ReLU;
      * a final 9x9 convolution back to 3 channels followed by ``Tanh``.

    When the input height and width are multiples of 4 the output has the
    same spatial size as the input.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=9, padding=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # Residual blocks
            self._residual_block(128),
            self._residual_block(128),
            self._residual_block(128),
            self._residual_block(128),
            self._residual_block(128),
            # Upsampling
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 3, kernel_size=9, padding=4),
            nn.Tanh()
        )

    @staticmethod
    def _residual_block(channels):
        """Build one residual block: conv -> ReLU -> conv, plus the input.

        Args:
            channels: Number of input and output channels of both 3x3
                convolutions (padding 1, so the spatial size is preserved
                and the sum with the input is well defined).

        Returns:
            A ``_ResidualBlock`` module.
        """
        return _ResidualBlock(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        )

    def forward(self, x):
        """Run the network on ``x``.

        The ``Tanh`` output lies in ``[-1, 1]``; it is scaled by 150 and
        shifted by 127.5, so returned values lie in ``[-22.5, 277.5]``.
        """
        return self.model(x) * 150 + 127.5

# --------- Gram Matrix ---------
def gram_matrix(tensor):
    """Compute the normalized Gram matrix of a 4-D feature tensor.

    Args:
        tensor: Feature tensor with four dimensions, unpacked as
            ``(b, c, h, w)``.

    Returns:
        A ``(b * c, b * c)`` tensor of inner products between the flattened
        ``h * w`` rows, divided by ``b * c * h * w``.

    Note:
        All ``b * c`` rows share one matrix, so for ``b > 1`` the result
        also contains products between rows of different batch entries.
    """
    b, c, h, w = tensor.size()
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # permuted or sliced feature maps; it copies only when it has to.
    features = tensor.reshape(b * c, h * w)
    G = torch.mm(features, features.t())
    return G / (b * c * h * w)

# --------- Load Images ---------
content = load_image(content_img_path)
style = load_image(style_img_path)

# --------- Extract Features ---------
# NOTE(review): no weights are loaded in the visible code, so the features
# appear to come from a randomly initialized network - confirm this is intended.
cnn = SlightlyModifiedCNN().to(device).eval()
# Only the target image is optimized. Freezing the network stops every
# backward() call from computing (and silently accumulating) weight
# gradients that no optimizer ever uses or zeroes.
cnn.requires_grad_(False)
with torch.no_grad():
    # Reference features are constants of the optimization.
    content_features = cnn(content)
    style_features = cnn(style)
    style_gram = gram_matrix(style_features)

# --------- Optimization ---------
# Start from the content image and optimize its pixels directly.
target = content.clone().requires_grad_(True)
optimizer = optim.Adam([target], lr=0.03)
content_weight = 1
style_weight = 1e5

for step in range(num_steps):
    optimizer.zero_grad()
    target_features = cnn(target)

    # Content term: mean squared difference between network outputs.
    c_loss = torch.mean((target_features - content_features) ** 2)
    # Style term: mean squared difference between Gram matrices.
    t_gram = gram_matrix(target_features)
    s_loss = torch.mean((t_gram - style_gram) ** 2)

    # Total variation: summed absolute differences between horizontally
    # and vertically adjacent pixels of the target image.
    tv_loss = (
        torch.sum(torch.abs(target[:, :, :, :-1] - target[:, :, :, 1:]))
        + torch.sum(torch.abs(target[:, :, :-1, :] - target[:, :, 1:, :]))
    )

    loss = content_weight * c_loss + style_weight * s_loss + tv_weight * tv_loss
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f"Step {step}, Loss: {loss.item():.2f}")

# --------- Save Result ---------
# im_convert un-normalizes and clamps the optimized tensor before saving.
output = im_convert(target)
output.save(output_img_path)
output.show()


Step 0, Loss: 0.06
Step 100, Loss: 0.03
Step 200, Loss: 0.04
Step 300, Loss: 0.05
Step 400, Loss: 0.05
Step 500, Loss: 0.03
Step 600, Loss: 0.04
Step 700, Loss: 0.04
Step 800, Loss: 0.06
Step 900, Loss: 0.06
