In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models, transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import glob
device = torch.device('cuda' if torch.cuda.is_available() else 'mps')

In [2]:
#setting
size = (1080,1080)
path = './processed/'
img_transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])
content_weight , style_weight = 1,100
n_epoch = 500
sample_num = 50

In [3]:
def get_image(path, img_transform, size = (300,300)):
    image = Image.open(path)
    image = image.resize(size, Image.LANCZOS)
    image = img_transform(image).unsqueeze(0)
    return image.to(device)

def get_gram(m):
    _, c, h, w = m.size()
    m = m.view(c, h * w)
    m = torch.mm(m, m.t()) 
    return m

def denormalize_img(inp):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp

<h3>VGG</h3>

In [4]:
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        self.selected_layers = [3, 8, 15, 22]
        self.vgg = models.vgg16(pretrained=True).features
        
    def forward(self, x):
        layer_features = []
        for layer_number, layer in self.vgg._modules.items():
            x = layer(x)
            if int(layer_number) in self.selected_layers:
                layer_features.append(x)
        return layer_features
encoder = FeatureExtractor().to(device)



<h3>style mean</h3>

In [5]:
path_list = sorted(glob.glob('./processed/*.jpg'))
style_features = torch.tensor([
    0,0,0,0
],requires_grad=False)
n=0

In [6]:
sample = get_image(path_list[0],img_transform)
sample = encoder(sample)
style_features =[]
for i,feature in enumerate(sample):
    style_features.append(torch.zeros_like(get_gram(feature)))

In [7]:
for i in range(sample_num):
    n+=1
    idx = np.random.randint(0,len(path_list))
    style_img = get_image(path_list[idx],img_transform)
    style_feature = encoder(style_img)
    for index, feature in enumerate(style_feature):
        _,c,h,w  = feature.size()
        gram = get_gram(feature)
        style_features[index] = ((style_features[index]*((n-1)/n)) + (1/n)*gram)

In [8]:
content_img = get_image('content.jpg', img_transform)
generated_img = content_img.clone()    # or nn.Parameter(torch.FloatTensor(content_img.size()))
generated_img.requires_grad = True

optimizer = torch.optim.Adam([generated_img], lr=0.003, betas=[0.5, 0.999])
encoder = FeatureExtractor().to(device)

for p in encoder.parameters():
    p.requires_grad = False
    
for epoch in range(100):
    
    content_features = encoder(content_img) 
    generated_features = encoder(generated_img)
    
    content_loss = torch.mean((content_features[-1] - generated_features[-1])**2)  #MSE

    style_loss = 0
    for gf, sf in zip(generated_features, style_features):
        _, c, h, w = gf.size()
        gram_gf = get_gram(gf)
        gram_sf = sf.detach()
        style_loss += torch.mean((gram_gf - gram_sf)**2)  / (c * h * w)

    loss = content_weight * content_loss + style_weight * style_loss 
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


RuntimeError: MPS backend out of memory (MPS allocated: 6.32 GB, other allocations: 2.74 GB, max allowed: 9.07 GB). Tried to allocate 21.97 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [19]:
content_weight = 1
style_weight = 100
content_img = get_image('content.jpg',img_transform)
style_img=get_image('./processed/img0.jpg',img_transform)
generated_img = content_img.clone() 
generated_img.requires_grad = True
optimizer = torch.optim.Adam([generated_img], lr=0.003, betas=[0.5, 0.999])
encoder = FeatureExtractor().to(device)

for epoch in range(500):
    
    content_features = encoder(content_img)
    style_features = encoder(style_img)
    generated_features = encoder(generated_img)
    
    content_loss = torch.mean((content_features[-1] - generated_features[-1])**2)  

    style_loss = 0
    for gf, sf in zip(generated_features, style_features):
        _, c, h, w = gf.size()
        gram_gf = get_gram(gf)
        gram_sf = get_gram(sf)
        style_loss += torch.mean((gram_gf - gram_sf)**2)  / (c * h * w) 

    loss = content_weight * content_loss + style_weight * style_loss 
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()




RuntimeError: MPS backend out of memory (MPS allocated: 508.73 MB, other allocations: 8.55 GB, max allowed: 9.07 GB). Tried to allocate 21.97 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [26]:
from torch.utils.tensorboard import SummaryWriter