In [263]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torchvision.datasets as datasets
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch
from torch.utils.data import Subset
from PIL import Image

mse = nn.MSELoss()
device = torch.device('cuda' if torch.cuda.is_available else 'cpu')
device

device(type='cuda')

In [264]:
def load_and_prepare_img(content_img_path, style_img_path, target_size=(512, 512), return_tensor = True):
    """
    Загружает и подготавливает изображения
    
    Параметры:
        content_img_path: путь к контентному изображению
        style_img_path: путь к изображению стиля
        target_size: желаемый размер (ширина, высота)
        return_tensor: если True - возвращает тензоры, если False - PIL Images
    
    Возвращает:
        кортеж (content_image, style_image)
    """
    
    common_transforms = [
        transforms.Resize(target_size)
    ]
    
    if return_tensor:
        tensor_transforms = [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ]
        transform = transforms.Compose(common_transforms + tensor_transforms)
    else:
        transform = transforms.Compose(common_transforms)

    def load_img(img_path):
        img = Image.open(img_path).convert('RGB')
        transformed = transform(img)
        return transformed.unsqueeze(0) if return_tensor else transformed
        
    content_image = load_img(content_img_path)
    style_image = load_img(style_img_path)
    
    return content_image, style_image

In [265]:
cont, style = load_and_prepare_img(
    "C:\\Users\\Admin\\Desktop\\AIM\\aim_2_hw\\DL_summer\\feedforward_data\\cat.jpg",
    "C:\\Users\\Admin\\Desktop\\AIM\\aim_2_hw\\DL_summer\\feedforward_data\\fff.jpg",
    (800,680),
    True
)

In [266]:
cont = cont.to(device)
style = style.to(device)

In [267]:
from torchvision.models import vgg19
model = vgg19(pretrained=True).to(device)

for param in model.parameters():
    param.requires_grad = False

In [268]:
class FeatureExtractor(nn.Module):
    def __init__(self, features, layers_list):
        super().__init__()

        self.features = features
        self.layers = layers_list
        self.layer_map = {
            1: 'relu1_1',
            6: 'relu2_1',
            11: 'relu3_1',
            20: 'relu4_1',
            22: 'relu4_2',
            29: 'relu5_1'
        }

    def forward(self, x):
        features_dict = {}
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in self.layers:
                name = self.layer_map[i]
                features_dict[name] = x

        return features_dict

In [269]:
extractor = FeatureExtractor(model.features, [1,6,11,20,22,29])
content_features = extractor(cont)
style_features = extractor(style)

In [270]:
def gram_matrix(tensor):
    B, C, H, W = tensor.size()
    features = tensor.view(B, C, H*W)
    G = torch.bmm(features, features.transpose(1, 2))
    return G / (C*H*W)

style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

In [271]:
target = cont.clone().requires_grad_(True).to(device)
optimizer = torch.optim.LBFGS([target])
style_layers = ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1']

content_weight = 1
style_weight = 1e4

for step in range(50):
    print(f"Step {step}")
    def closure():
        optimizer.zero_grad()
        target_features = extractor(target)

        content_loss = mse(target_features['relu4_2'], content_features['relu4_2'])

        style_loss = 0
        for layer in style_layers:
            G_t = gram_matrix(target_features[layer])
            G_s = style_grams[layer]
            style_loss += mse(G_t, G_s)

        total_loss = content_weight * content_loss + style_weight * style_loss
        total_loss.backward()
        return total_loss

    optimizer.step(closure)

Step 0
Step 1
Step 2
Step 3
Step 4
Step 5
Step 6
Step 7
Step 8
Step 9
Step 10
Step 11
Step 12
Step 13
Step 14
Step 15
Step 16
Step 17
Step 18
Step 19
Step 20
Step 21
Step 22
Step 23
Step 24
Step 25
Step 26
Step 27
Step 28
Step 29
Step 30
Step 31
Step 32
Step 33
Step 34
Step 35
Step 36
Step 37
Step 38
Step 39
Step 40
Step 41
Step 42
Step 43
Step 44
Step 45
Step 46
Step 47
Step 48
Step 49


In [272]:
from torchvision.utils import save_image

# Обратная нормализация
def denormalize(tensor):
    mean = torch.tensor([0.485, 0.456, 0.406]).to(device).view(1,3,1,1)
    std = torch.tensor([0.229, 0.224, 0.225]).to(device).view(1,3,1,1)
    return tensor * std + mean

result = denormalize(target.clone().detach())
save_image(result, "cat_fff.jpg")