In [1]:
import numpy as np
from torchvision.models import vgg19
from torch import nn
from torchvision.utils import save_image
import torch
import cv2
from tqdm import tqdm
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
# Configuration: locations of the two input images and of the VGG-19 weights.
content_path='content/content.jpg'  # content image; also used to initialise the generated image
style_path='style/style2.jpg'  # style image whose feature statistics are matched
vgg_path="vgg_para/vgg19.pth"  # state dict loaded into torchvision's vgg19

In [3]:
# Image loading
def load_image(path):
    """Read an image file into a float RGB tensor with a batch dimension.

    Args:
        path: Location of the image file, passed to ``cv2.imread``.

    Returns:
        A float tensor of shape (1, C, H, W) with channels in RGB order and
        the decoded pixel values divided by 255.

    Raises:
        FileNotFoundError: If OpenCV could not read an image from ``path``
            (the file is missing or cannot be decoded).
    """
    image = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface later as an opaque error inside cvtColor.
    if image is None:
        raise FileNotFoundError(f'Could not read image (missing or undecodable file): {path}')
    # OpenCV decodes to BGR channel order; convert to RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Scale the pixel values by 1/255.
    image = torch.from_numpy(image).float() / 255
    # (H, W, C) -> (C, H, W), then add the leading batch axis.
    image = image.permute(2, 0, 1).unsqueeze(0)
    return image


In [4]:
# Loss functions
def get_gram_matrix(features_map):
    """Compute the Gram matrix of a single feature map.

    Args:
        features_map: 4-D array-like of shape (1, C, H, W).

    Returns:
        The (C, C) matrix of inner products between the flattened channels.
        No normalisation is applied.

    Raises:
        ValueError: If the leading (batch) dimension is not 1.
    """
    batch, channels, _height, _width = features_map.shape
    if batch != 1:
        raise ValueError('Can not process more than one picture')
    # One row per channel, spatial positions flattened into the columns.
    flat = features_map.reshape(channels, -1)
    return flat @ flat.T

def style_loss(feature_bank_x,feature_bank_style):
    """Equal-weighted sum of per-layer Gram-matrix losses.

    For every layer the mean squared error between the Gram matrix of the
    generated features and that of the style features is divided by
    ``4 * C**2 * H**2 * W**2`` (taken from the generated feature's shape),
    multiplied by ``100**3`` and weighted by ``1 / number_of_layers``.

    Args:
        feature_bank_x: Per-layer feature maps of the generated image.
        feature_bank_style: Per-layer feature maps of the style image; its
            length determines how many layers are compared.

    Returns:
        The accumulated style loss.
    """
    total = 0
    layer_weight = 1 / len(feature_bank_style)
    # MSELoss holds no state, so one instance serves every layer.
    criterion = nn.MSELoss().to(device)
    scale = 100**3
    for idx, style_feature in enumerate(feature_bank_style):
        generated_feature = feature_bank_x[idx]
        dims = generated_feature.shape
        channels, height, width = int(dims[1]), int(dims[2]), int(dims[3])
        gram_generated = get_gram_matrix(generated_feature)
        gram_style = get_gram_matrix(style_feature)
        normaliser = 4 * channels**2 * height**2 * width**2
        total += layer_weight * criterion(gram_generated, gram_style) / normaliser * scale
    return total

def content_loss(out_x, out_content):
    """Scaled mean squared error between two feature maps.

    Args:
        out_x: Feature map of the generated image; dimension 1 supplies C.
        out_content: Feature map of the content image.

    Returns:
        ``MSE(out_x, out_content) / (2 * C**2) * 100**3``.
    """
    criterion = nn.MSELoss().to(device)
    channels = int(out_x.shape[1])
    mse_value = criterion(out_x, out_content)
    return mse_value / (2 * channels**2) * 100**3



In [5]:
# Model definitions
class VGG19(nn.Module):
    """Frozen VGG-19 feature extractor returning selected ReLU activations.

    The weights are read from the module-level ``vgg_path`` and loaded into
    torchvision's ``vgg19``; only its ``features`` stack is kept.

    Attributes:
        features: The convolutional ``features`` module of VGG-19.
        indexes: Array of layer positions inside ``features``. The first
            entry is the sentinel -1; the others are the positions of the
            1st, 3rd, 5th, 9th, 10th and 13th ReLU layers.
            NOTE(review): for the standard torchvision VGG-19 layout these
            should be relu1_1, relu2_1, relu3_1, relu4_1, relu4_2 and
            relu5_1 -- confirm against the checkpoint in use.
    """

    def __init__(self):
        super().__init__()
        vgg_model = vgg19()
        # Load onto the CPU so a checkpoint saved from a GPU also works on a
        # CPU-only machine; the caller moves the module with .to(device).
        pre_trained = torch.load(vgg_path, map_location='cpu')
        vgg_model.load_state_dict(pre_trained)
        self.features = vgg_model.features

        # Position of every ReLU in `features`, preceded by -1 so that the
        # first slice taken in forward() starts at layer 0.
        relu_positions = [-1]
        relu_positions.extend(
            i for i, layer in enumerate(self.features) if isinstance(layer, nn.ReLU)
        )
        # Entries of relu_positions whose activations feed the loss functions.
        selected_layer = [0, 1, 3, 5, 9, 10, 13]
        self.indexes = np.array(relu_positions)[selected_layer]

        # The network is a fixed feature extractor: freezing its weights stops
        # backward() from computing and accumulating gradients that no
        # optimizer ever uses. Gradients still flow to the input image.
        for param in self.parameters():
            param.requires_grad_(False)
        self.eval()

    def forward(self,input):
        """Run ``input`` through the network and collect the chosen activations.

        Args:
            input: Image batch fed to the first layer of ``features``.

        Returns:
            A tuple ``(features_bank, out)``. ``features_bank`` lists the
            activation after each selected ReLU, in network order; ``out`` is
            its second-to-last entry.
        """
        features_bank = []
        out = input
        # Each consecutive pair of indexes bounds one segment of the network;
        # a segment ends on (and includes) a selected ReLU.
        for start, stop in zip(self.indexes[:-1], self.indexes[1:]):
            out = self.features[int(start) + 1:int(stop) + 1](out)
            features_bank.append(out)
        return features_bank, features_bank[-2]


class GNet(nn.Module):
    """Wraps the image being optimised as the module's only parameter.

    The parameter starts as a detached copy of ``image``, so optimisation
    begins from that image and never writes into the caller's tensor.
    """

    def __init__(self, image):
        super().__init__()
        initial_pixels = image.detach().clone()
        self.image_g = nn.Parameter(initial_pixels)

    def forward(self):
        """Return the trainable image parameter itself (not a copy)."""
        return self.image_g



In [6]:
# Training setup: load both images and build the trainable image and the feature extractor.
content_img=load_image(content_path).to(device)
style_img=load_image(style_path).to(device)
# The generated image is initialised from the content image (GNet copies the tensor it is given).
g_net=GNet(content_img).to(device)
vgg_net= VGG19().to(device)
# The style and content targets never change, so compute them once without tracking gradients.
# train_loop reads features_bank_style and out_content; the other two results are not used in the cells shown here.
with torch.no_grad():
    features_bank_style, out_style=vgg_net(style_img)
    features_bank_content, out_content=vgg_net(content_img)

In [7]:
def train_loop(epoches, alpha, beta, learning_rate):
    """Optimise the generated image against the content and style targets.

    Uses the module-level ``g_net``, ``vgg_net``, ``features_bank_style`` and
    ``out_content`` objects.

    Args:
        epoches: Number of optimisation steps to run.
        alpha: Weight applied to the content loss.
        beta: Weight applied to the style loss.
        learning_rate: Initial Adam learning rate; it is multiplied by 0.95
            after every 100 steps.

    Side effects:
        On every 100th step (starting with step 0) prints the step index and
        the total, content and style losses, and writes a snapshot named
        ``<step // 100>.jpg`` (``0.jpg``, ``1.jpg``, ...) to the current
        working directory.
    """
    optimizer = torch.optim.Adam(g_net.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.95)
    for t in tqdm(range(epoches)):
        image_x = g_net()
        features_bank_x, out_x = vgg_net(image_x)

        # Weighted sum of the content and style terms.
        loss_s = style_loss(features_bank_x, features_bank_style)
        loss_c = content_loss(out_x, out_content)
        loss_total = alpha * loss_c + beta * loss_s

        # One optimisation step, then advance the learning-rate schedule.
        optimizer.zero_grad()
        loss_total.backward()
        optimizer.step()
        scheduler.step()

        # Periodic progress report and snapshot. image_x is the parameter
        # object itself, so the saved picture already includes this step's
        # update, while the printed losses were measured before it.
        if t % 100 == 0:
            print(t, loss_total.item(), loss_c.item(), loss_s.item())
            # Integer division keeps the snapshot index an int ("3.jpg");
            # true division would produce names such as "3.0.jpg".
            save_image(image_x, f'{t // 100}.jpg', padding=0, normalize=True, value_range=(0, 1))

In [None]:
train_loop(2000, 1e-3,1 ,0.01)