In [30]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
import torchvision.datasets as datasets
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

In [62]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_image(path, img_transforms, size = (300,300)):
    """Load an image file, resize it, and return it as a batch of one on ``device``.

    Args:
        path: path to the image file.
        img_transforms: callable applied to the resized PIL image; presumably it
            returns a (C, H, W) tensor, since a batch dimension is prepended here.
        size: target (width, height) passed to PIL ``resize``.

    Returns:
        The transformed image with a leading batch dimension of 1, moved to ``device``.
    """
    # `with` closes the underlying file handle once the pixels are loaded.
    # convert("RGB") guarantees exactly 3 channels, so RGBA / greyscale / palette
    # files don't break a 3-channel Normalize step in img_transforms.
    with Image.open(path) as src:
        image = src.convert("RGB").resize(size, Image.Resampling.LANCZOS)
    return img_transforms(image).unsqueeze(0).to(device)
def get_gram(m):
    """Return the (unnormalized) Gram matrix of a single feature map.

    Args:
        m: tensor of shape (1, C, H, W); the batch dimension must be 1.

    Returns:
        (C, C) tensor ``F @ F.T`` where ``F`` is ``m`` flattened to (C, H*W).

    Raises:
        ValueError: if the batch dimension is not 1.
    """
    b, c, h, w = m.size()
    if b != 1:
        raise ValueError(f"get_gram expects a batch of 1, got shape {tuple(m.shape)}")
    # reshape (unlike view) also accepts non-contiguous feature maps.
    features = m.reshape(c, h * w)
    return torch.mm(features, features.t())
def denormalize_img(inp, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization and return a displayable image array.

    Args:
        inp: image tensor of shape (C, H, W), or (1, C, H, W) with a batch of 1.
            It may live on any device and may require grad.
        mean: per-channel means that were subtracted during normalization.
        std: per-channel standard deviations that were divided out.

    Returns:
        Float numpy array of shape (H, W, C) with ``std * x + mean`` applied
        per channel and values clipped to [0, 1].

    Raises:
        ValueError: if a 4-D input has a batch dimension other than 1.
    """
    # detach/cpu make .numpy() legal for tensors on GPU or attached to a graph.
    t = inp.detach().cpu()
    if t.dim() == 4:
        if t.size(0) != 1:
            raise ValueError(f"expected a batch of 1, got shape {tuple(t.shape)}")
        t = t[0]
    img = t.numpy().transpose((1, 2, 0))  # C,H,W -> H,W,C
    img = np.asarray(std) * img + np.asarray(mean)
    return np.clip(img, 0, 1)


In [31]:
class FeaturesExtractor(nn.Module):
    """Collect intermediate activations from torchvision's pretrained VGG16 ``features`` stack.

    Args:
        selected_layers: indices into ``vgg16().features`` whose outputs are
            returned. The default ``(3, 8, 15, 22)`` should correspond to the ReLU
            outputs relu1_2, relu2_2, relu3_3 and relu4_3 in torchvision's VGG16
            layout. Caution: torchvision's VGG uses in-place ReLUs, so selecting a
            conv index would store a tensor that the following ReLU overwrites.
    """

    def __init__(self, selected_layers=(3, 8, 15, 22)):
        super().__init__()
        self.selected_layers = list(selected_layers)
        # NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour
        # of `weights=models.VGG16_Weights.IMAGENET1K_V1`; kept here for older versions.
        self.vgg = models.vgg16(pretrained=True).features

    def forward(self, x):
        """Run ``x`` through VGG16 and return the activations at ``selected_layers``.

        Returns:
            List of tensors, one per selected layer, in forward (increasing index) order.
        """
        layer_feats = []
        if not self.selected_layers:
            return layer_feats
        deepest = max(self.selected_layers)
        for idx, layer in enumerate(self.vgg):
            x = layer(x)
            if idx in self.selected_layers:
                layer_feats.append(x)
            if idx >= deepest:
                # Later layers cannot contribute to the output; skip their compute.
                break
        return layer_feats

In [19]:
# ToTensor turns the PIL image into a float (C, H, W) tensor; Normalize then
# standardizes every channel with the per-channel mean/std below.
img_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)


In [40]:
# Both images go through the same loader, so they share size and normalization.
style_image = get_image('style.jpg', img_transforms)
content_image = get_image('content.jpg', img_transforms)

# The generated image starts as a copy of the content image; it is the only
# tensor handed to the optimizer.
generate_image = content_image.clone().requires_grad_(True)
optimizer = torch.optim.Adam([generate_image], lr=3e-3, betas=[0.5, 0.999])
encoder = FeaturesExtractor().to(device)

In [41]:
# Freeze every encoder weight: only generate_image should receive gradients.
# encoder.eval() alone would NOT achieve this; it only switches train/eval mode
# (dropout / batch-norm behaviour) and leaves requires_grad untouched.
encoder.requires_grad_(False)
encoder.eval()
    

In [67]:
content_weight = 1
style_weight = 100
num_steps = 1200  # optimization steps (logged as "Epoch")
log_every = 10

# NOTE: re-running this cell continues optimizing the current generate_image with
# the existing optimizer state; re-run the setup cell first for a fresh start.

# Targets depend only on the fixed content/style images and the frozen encoder,
# so compute them once rather than on every step; no_grad skips graph building.
with torch.no_grad():
    content_target = encoder(content_image)[-1]
    style_grams = [get_gram(sf) for sf in encoder(style_image)]

for epoch in range(num_steps):
    generate_features = encoder(generate_image)

    # Content loss uses only the last selected layer.
    content_loss = torch.mean((content_target - generate_features[-1]) ** 2)

    # Style loss: mean squared Gram difference per layer, scaled by that layer's C*H*W.
    style_loss = 0
    for gf, gram_sf in zip(generate_features, style_grams):
        _, c, h, w = gf.size()
        gram_gf = get_gram(gf)
        style_loss += torch.mean((gram_gf - gram_sf) ** 2) / (c * h * w)

    loss = content_weight * content_loss + style_weight * style_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % log_every == 0:
        print ('Epoch [{}]\tContent Loss: {:.4f}\tStyle Loss: {:.4f}'.format(epoch, content_loss.item(), style_loss.item()))

Epoch [0]	Content Loss: 2.8667	Style Loss: 306.9382
Epoch [10]	Content Loss: 2.8731	Style Loss: 296.0386
Epoch [20]	Content Loss: 2.8790	Style Loss: 285.5115
Epoch [30]	Content Loss: 2.8844	Style Loss: 275.3465
Epoch [40]	Content Loss: 2.8894	Style Loss: 265.5317
Epoch [50]	Content Loss: 2.8945	Style Loss: 256.0580
Epoch [60]	Content Loss: 2.8986	Style Loss: 246.9178
Epoch [70]	Content Loss: 2.9029	Style Loss: 238.1077
Epoch [80]	Content Loss: 2.9075	Style Loss: 229.6163
Epoch [90]	Content Loss: 2.9121	Style Loss: 221.4362
Epoch [100]	Content Loss: 2.9163	Style Loss: 213.5622
Epoch [110]	Content Loss: 2.9202	Style Loss: 205.9820
Epoch [120]	Content Loss: 2.9240	Style Loss: 198.6911
Epoch [130]	Content Loss: 2.9274	Style Loss: 191.6786
Epoch [140]	Content Loss: 2.9303	Style Loss: 184.9445
Epoch [150]	Content Loss: 2.9328	Style Loss: 178.4767
Epoch [160]	Content Loss: 2.9355	Style Loss: 172.2651
Epoch [170]	Content Loss: 2.9377	Style Loss: 166.3086
Epoch [180]	Content Loss: 2.9394	Style 

In [68]:
# Detach from autograd, move to CPU, and drop only the batch dimension
# (squeeze() with no argument would also drop any other size-1 dimension).
inputs = generate_image.detach().cpu().squeeze(0)
print(inputs.shape)
# denormalize_img yields an (H, W, C) float array clipped to [0, 1].
inputs = denormalize_img(inputs)

torch.Size([3, 300, 300])


In [1]:
print(inputs.shape)
# `inputs` is an (H, W, C) float array in [0, 1]. PIL's RGB images are 8-bit, and
# passing float data with mode "RGB" reinterprets raw float bytes as pixels, so
# scale to 0-255 and cast to uint8 first (PIL infers RGB from an (H, W, 3) uint8 array).
rgb_uint8 = (inputs * 255).round().astype(np.uint8)
im = Image.fromarray(rgb_uint8)
im.save("your_file.jpeg")

NameError: name 'inputs' is not defined

In [None]:
# Show the final image without axes; the trailing `;` suppresses the text repr.
fig, ax = plt.subplots()
ax.imshow(inputs)
ax.set_title("Generated (style-transferred) image")
ax.axis("off");