In [1]:
import os

import torch as th
import torchvision
from utils import *
from loss import *
from model import *
import numpy as np
import matplotlib.pyplot as plt
from torchsummary import summary
from PIL import Image
from dataset import COCODataset
from torchvision import transforms
from vgg import Vgg16

In [2]:
# VGG-16 feature extractor, built with requires_grad=False (presumably so its
# weights stay frozen); it is only used to compute the perceptual losses.
model = Vgg16(requires_grad=False)
model = model.cuda()

# Image-transformation network: the module whose parameters the optimizer updates.
autoencoder = TransformerNet()
autoencoder = autoencoder.cuda()

In [3]:
# Optimizer, loss and hyperparameters.
# Use the explicitly imported alias `th`: the bare name `torch` is never imported
# in this notebook and only resolves if a star import happens to re-export it.
optim = th.optim.Adam(autoencoder.parameters(), lr=1e-3)
criterion = th.nn.MSELoss()
epochs = 2            # passes over the training set
bs = 4                # training batch size
imsize = 256          # images are resized to imsize x imsize for training/previews
content_weight = 1e5  # multiplier of the relu2_2 feature (content) loss
style_weight = 1e10   # multiplier of the Gram-matrix (style) loss

In [4]:
# Data loading. Every image goes through the same preprocessing: resize to
# size x size, convert to a float tensor, then scale by 255.
# NOTE(review): presumably normalize_batch (from utils) expects 0-255 inputs — confirm.
def _make_transform(size):
    """Return a Compose that resizes to (size, size), applies ToTensor and multiplies by 255."""
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])


def _load_image_batch(path, transform):
    """Open `path`, force 3-channel RGB, apply `transform`, return a 1-image CUDA batch."""
    # convert('RGB') guards against grayscale / RGBA / CMYK files, which would
    # otherwise produce a tensor with the wrong number of channels.
    with Image.open(path) as img:
        tensor = transform(img.convert('RGB'))
    return tensor.unsqueeze(0).cuda()


style_transform = _make_transform(imsize)
train_transform = _make_transform(imsize)
test_transform = _make_transform(imsize)

style_img = _load_image_batch('./images/skull.jpg', style_transform)
test_img = _load_image_batch('./images/amber.jpg', test_transform)
trans = transforms.ToPILImage()

# Set the COCO_ROOT environment variable to point at your copy of the dataset
# instead of editing this machine-specific default.
COCO_ROOT = os.environ.get('COCO_ROOT', '/home/harsh/Downloads/val2017/')
train_ds = COCODataset(root=COCO_ROOT, transform=train_transform)
train_dl = th.utils.data.DataLoader(train_ds, batch_size=bs)
train_dl = th.utils.data.DataLoader(train_ds,batch_size = bs)

In [5]:
# Training loop: `autoencoder` is optimized so that, measured through the VGG
# `model`, its output keeps the relu2_2 features of each training image
# (content loss) and matches the Gram matrices of the style image's feature
# maps (style loss).
model.eval()
autoencoder.train()

# The style targets are constant, so compute their Gram matrices once without
# tracking gradients. The batched copy gets its own name: rebinding `style_img`
# made this cell non-idempotent (a re-run repeated an already-repeated batch).
style_batch = style_img.repeat(bs, 1, 1, 1).cuda()
with th.no_grad():
    style_features = model(normalize_batch(style_batch))
    gram_style = [gram_matrix(y) for y in style_features]

log_interval = 250  # save a stylized preview of test_img every this many batches
os.makedirs('./images/generated', exist_ok=True)

for epoch in range(epochs):
    for batch, input_img in enumerate(train_dl):
        input_img = input_img.cuda()
        # train_dl keeps a final partial batch (DataLoader default drop_last=False),
        # so the bs precomputed style Grams are trimmed to the actual batch size;
        # otherwise the MSE operands would have mismatched shapes.
        n_batch = input_img.size(0)

        optim.zero_grad()
        generated = autoencoder(input_img)

        gen_features = model(normalize_batch(generated))
        input_features = model(normalize_batch(input_img))

        content_loss = content_weight * criterion(
            gen_features.relu2_2, input_features.relu2_2)
        # One Gram-matrix term per feature map returned by `model`.
        style_loss = style_weight * sum(
            criterion(gram_s[:n_batch], gram_matrix(gen_f))
            for gram_s, gen_f in zip(gram_style, gen_features)
        )

        loss = content_loss + style_loss
        loss.backward()
        optim.step()

        if (batch + 1) % log_interval == 0:
            autoencoder.eval()
            with th.no_grad():  # preview only: no autograd graph needed
                output = autoencoder(test_img)
            print('Saving Image test_{}_{}.jpg'.format(epoch, batch))
            output = trans(output[0].clamp(0, 255).permute(1, 2, 0)
                           .cpu().numpy().astype("uint8"))
            output.save('./images/generated/test_{}_{}.jpg'.format(epoch, batch))
            autoencoder.train()

In [6]:
# Persist the trained transformation network's weights (uses the imported alias
# `th`; the bare name `torch` is only available via a star-import leak).
th.save(autoencoder.state_dict(), './auto_skull.pth')

In [7]:
# Restore the saved weights. Tensors are loaded onto the CPU first;
# load_state_dict then copies them into the model's existing (CUDA) parameters,
# so no second GPU copy of the checkpoint is allocated.
autoencoder.load_state_dict(th.load('./auto_skull.pth', map_location='cpu'))

<All keys matched successfully>

In [8]:
def stylize(path, autoencoder, out_path='./images/generated/gen.jpg', size=(1080, 1080)):
    """Stylize a single image with a trained transformation network and save it.

    Args:
        path: path of the input image; it is converted to RGB before processing.
        autoencoder: the trained transformation network. Inference runs on the
            device of its parameters, in eval mode and without autograd; its
            previous train/eval mode is restored afterwards.
        out_path: file the stylized image is written to (parent dirs are created).
        size: (height, width) the input is resized to before stylization.

    Returns:
        The stylized image as a PIL ``Image``.
    """
    preprocess = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    # convert('RGB') guards against grayscale / RGBA / CMYK inputs.
    with Image.open(path) as img:
        content = preprocess(img.convert('RGB'))
    device = next(autoencoder.parameters()).device
    content = content.unsqueeze(0).to(device)

    # Same pattern the training cell uses for previews: eval mode, then restore.
    was_training = autoencoder.training
    autoencoder.eval()
    try:
        with th.no_grad():  # inference only; avoids holding a large autograd graph
            output = autoencoder(content)
    finally:
        autoencoder.train(was_training)

    pixels = output[0].clamp(0, 255).permute(1, 2, 0).cpu().numpy().astype('uint8')
    result = Image.fromarray(pixels)
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    result.save(out_path)
    return result
