In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torchvision.utils import save_image

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# (The VGG-19 feature extractor is built later by the `VGG` class; loading a second,
# unused copy of the pretrained network here only wasted time and memory.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



In [4]:
#create a function that should enable the computer or the program to understand and evaluate the input
def image_loader(path):
    image=Image.open(path)
    #defining the image transformation steps to be performed before feeding them to the model
    loader=transforms.Compose([transforms.Resize((512,512)), transforms.ToTensor()])
    #The preprocessing steps involves resizing the image and then converting it to a tensor
    image=loader(image).unsqueeze(0)
    return image.to(device,torch.float)

# Paths of the content ("original") and style images. Set the CONTENT_IMAGE /
# STYLE_IMAGE environment variables to use other files without editing this cell.
CONTENT_IMAGE_PATH = os.environ.get(
    "CONTENT_IMAGE",
    "/Users/paulivanespiritu/Downloads/131048452_4207893772561267_374415252907794725_n.jpg",
)
STYLE_IMAGE_PATH = os.environ.get(
    "STYLE_IMAGE",
    "/Users/paulivanespiritu/Downloads/H0027-L17051937.jpg",
)

# Load the content image and the style image that will be applied to it.
original_image = image_loader(CONTENT_IMAGE_PATH)
style_image = image_loader(STYLE_IMAGE_PATH)

# The generated image starts as a copy of the content image. Its pixels are the
# parameters being optimised, so it must track gradients.
generated_image = original_image.clone().requires_grad_(True)

In [5]:
#Defining a class that for the model
class VGG(nn.Module):
    """Truncated, frozen VGG-19 feature extractor used for style transfer.

    ``forward`` returns the activations of the layers listed in ``req_features``
    (indices into ``torchvision.models.vgg19().features``), in layer order.
    """

    def __init__(self, req_features=('0', '5', '10', '19', '28')):
        """Build the extractor.

        Args:
            req_features: Indices (ints or digit strings) of the ``features`` layers
                whose outputs ``forward`` collects. The default keeps the original
                five layers.
        """
        super(VGG, self).__init__()
        # Stored as strings so forward() can compare against str(layer index).
        self.req_features = [str(i) for i in req_features]
        # Layers after the deepest requested one are never needed, so drop them.
        # With the default indices this keeps features[:29].
        last = max((int(i) for i in self.req_features), default=-1)
        self.model = models.vgg19(pretrained=True).features[:last + 1]
        # The network is a fixed feature extractor: only the generated image is
        # optimised. Freezing the weights stops autograd from computing and
        # accumulating gradients for them on every backward pass.
        self.model.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the truncated network and collect the requested activations.

        Args:
            x: Input image batch (presumably ``(B, 3, H, W)``, as produced by
                ``image_loader``; TODO confirm for other callers).

        Returns:
            A list of tensors, one per requested layer, ordered by layer index.
        """
        wanted = set(self.req_features)
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            # NOTE(review): torchvision's VGG-19 follows each conv with
            # ReLU(inplace=True), so a conv output appended here may be overwritten in
            # place by the next layer. The stored tensors would then be post-ReLU,
            # except for the final layer. Confirm this is intended.
            if str(layer_num) in wanted:
                features.append(x)
        return features

In [6]:
def calc_content_loss(gen_feat, orig_feat):
    """Content loss for one layer: mean squared difference of the two feature maps."""
    squared_diff = (gen_feat - orig_feat).pow(2)
    return squared_diff.mean()

In [7]:
#Calculating the gram matrix for the style and the generated image
def calc_style_loss(gen,style):
    batch_size,channel,height,width=gen.shape

    G=torch.mm(gen.view(channel,height*width),gen.view(channel,height*width).t())
    A=torch.mm(style.view(channel,height*width),style.view(channel,height*width).t())
        
    #Calcultating the style loss of each layer by calculating the MSE between the gram matrix of the style image and the generated image and adding it to style loss
    style_l=torch.mean((G-A)**2)
    return style_l

In [8]:
def calculate_loss(gen_features, orig_feautes, style_featues, content_weight=None, style_weight=None):
    """Total style-transfer loss summed over all extracted layers.

    Args:
        gen_features: Per-layer activations of the generated image.
        orig_feautes: Per-layer activations of the content (original) image, same order.
        style_featues: Per-layer activations of the style image, same order.
        content_weight: Weight of the content term; defaults to the global ``alpha``.
        style_weight: Weight of the style term; defaults to the global ``beta``.

    Returns:
        ``content_weight * content_loss + style_weight * style_loss``.

    Raises:
        ValueError: If the three feature sequences differ in length (``zip`` would
            otherwise silently ignore the extra layers).
    """
    gen_features = list(gen_features)
    orig_feautes = list(orig_feautes)
    style_featues = list(style_featues)
    if not len(gen_features) == len(orig_feautes) == len(style_featues):
        raise ValueError(
            "feature lists differ in length: "
            f"{len(gen_features)}, {len(orig_feautes)}, {len(style_featues)}"
        )
    # Fall back to the notebook-level hyper-parameters when no weights are given.
    if content_weight is None:
        content_weight = alpha
    if style_weight is None:
        style_weight = beta

    content_loss = style_loss = 0
    for gen, cont, style in zip(gen_features, orig_feautes, style_featues):
        content_loss += calc_content_loss(gen, cont)
        style_loss += calc_style_loss(gen, style)

    # Weighted total loss for the current optimisation step.
    return content_weight * content_loss + style_weight * style_loss

In [9]:
# Optimisation hyper-parameters.
epoch, lr = 7000, 0.004   # number of optimisation steps, Adam learning rate
alpha, beta = 8, 70       # weights of the content loss and the style loss

# Truncated VGG-19 feature extractor on the selected device, in evaluation mode.
model = VGG().to(device).eval()

# Adam updates the pixels of the generated image, not the network's parameters.
optimizer = optim.Adam([generated_image], lr=lr)

In [10]:
#training
for e in range (epoch):
    #extracting the features of generated, content and the original required for calculating the loss
    gen_features=model(generated_image)
    orig_feautes=model(original_image)
    style_featues=model(style_image)
    
    #iterating over the activation of each layer and calculate the loss and add it to the content and the style loss
    total_loss=calculate_loss(gen_features, orig_feautes, style_featues)
    #optimize the pixel values of the generated image and backpropagate the loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    #print the image and save it after each 100 epoch
    if(e/100):
        print(total_loss)
        
        save_image(generated_image,"Number2.png")    

tensor(8.6265e+09, grad_fn=<AddBackward0>)
tensor(7.3327e+09, grad_fn=<AddBackward0>)
tensor(6.3574e+09, grad_fn=<AddBackward0>)
tensor(5.6187e+09, grad_fn=<AddBackward0>)
tensor(5.0296e+09, grad_fn=<AddBackward0>)
tensor(4.5764e+09, grad_fn=<AddBackward0>)
tensor(4.2244e+09, grad_fn=<AddBackward0>)
tensor(3.9287e+09, grad_fn=<AddBackward0>)
tensor(3.6713e+09, grad_fn=<AddBackward0>)
tensor(3.4459e+09, grad_fn=<AddBackward0>)
tensor(3.2411e+09, grad_fn=<AddBackward0>)
tensor(3.0491e+09, grad_fn=<AddBackward0>)
tensor(2.8668e+09, grad_fn=<AddBackward0>)
tensor(2.6933e+09, grad_fn=<AddBackward0>)
tensor(2.5277e+09, grad_fn=<AddBackward0>)
tensor(2.3701e+09, grad_fn=<AddBackward0>)
tensor(2.2209e+09, grad_fn=<AddBackward0>)
tensor(2.0798e+09, grad_fn=<AddBackward0>)
tensor(1.9462e+09, grad_fn=<AddBackward0>)
tensor(1.8195e+09, grad_fn=<AddBackward0>)
tensor(1.6994e+09, grad_fn=<AddBackward0>)
tensor(1.5854e+09, grad_fn=<AddBackward0>)
tensor(1.4771e+09, grad_fn=<AddBackward0>)
tensor(1.37

tensor(80823664., grad_fn=<AddBackward0>)
tensor(80588160., grad_fn=<AddBackward0>)
tensor(80354528., grad_fn=<AddBackward0>)
tensor(80123456., grad_fn=<AddBackward0>)
tensor(79893512., grad_fn=<AddBackward0>)
tensor(79666200., grad_fn=<AddBackward0>)
tensor(79440840., grad_fn=<AddBackward0>)
tensor(79216560., grad_fn=<AddBackward0>)
tensor(78994704., grad_fn=<AddBackward0>)
tensor(78774512., grad_fn=<AddBackward0>)
tensor(78556464., grad_fn=<AddBackward0>)
tensor(78340152., grad_fn=<AddBackward0>)
tensor(78125192., grad_fn=<AddBackward0>)
tensor(77912112., grad_fn=<AddBackward0>)
tensor(77700920., grad_fn=<AddBackward0>)
tensor(77491352., grad_fn=<AddBackward0>)
tensor(77283296., grad_fn=<AddBackward0>)
tensor(77077048., grad_fn=<AddBackward0>)
tensor(76871832., grad_fn=<AddBackward0>)
tensor(76668248., grad_fn=<AddBackward0>)
tensor(76466512., grad_fn=<AddBackward0>)
tensor(76266416., grad_fn=<AddBackward0>)
tensor(76068016., grad_fn=<AddBackward0>)
tensor(75871240., grad_fn=<AddBack

tensor(54795736., grad_fn=<AddBackward0>)
tensor(54717048., grad_fn=<AddBackward0>)
tensor(54638812., grad_fn=<AddBackward0>)
tensor(54560668., grad_fn=<AddBackward0>)
tensor(54482716., grad_fn=<AddBackward0>)
tensor(54405396., grad_fn=<AddBackward0>)
tensor(54328152., grad_fn=<AddBackward0>)
tensor(54251348., grad_fn=<AddBackward0>)
tensor(54174616., grad_fn=<AddBackward0>)
tensor(54098460., grad_fn=<AddBackward0>)
tensor(54022420., grad_fn=<AddBackward0>)
tensor(53946604., grad_fn=<AddBackward0>)
tensor(53870980., grad_fn=<AddBackward0>)
tensor(53795724., grad_fn=<AddBackward0>)
tensor(53720824., grad_fn=<AddBackward0>)
tensor(53646144., grad_fn=<AddBackward0>)
tensor(53571584., grad_fn=<AddBackward0>)
tensor(53497180., grad_fn=<AddBackward0>)
tensor(53423272., grad_fn=<AddBackward0>)
tensor(53349340., grad_fn=<AddBackward0>)
tensor(53275864., grad_fn=<AddBackward0>)
tensor(53202908., grad_fn=<AddBackward0>)
tensor(53129884., grad_fn=<AddBackward0>)
tensor(53057328., grad_fn=<AddBack

tensor(43219284., grad_fn=<AddBackward0>)
tensor(43178968., grad_fn=<AddBackward0>)
tensor(43141788., grad_fn=<AddBackward0>)
tensor(43106400., grad_fn=<AddBackward0>)
tensor(43075300., grad_fn=<AddBackward0>)
tensor(43047348., grad_fn=<AddBackward0>)
tensor(43025852., grad_fn=<AddBackward0>)
tensor(43000196., grad_fn=<AddBackward0>)
tensor(42969612., grad_fn=<AddBackward0>)
tensor(42909192., grad_fn=<AddBackward0>)
tensor(42827308., grad_fn=<AddBackward0>)
tensor(42738348., grad_fn=<AddBackward0>)
tensor(42679988., grad_fn=<AddBackward0>)
tensor(42656380., grad_fn=<AddBackward0>)
tensor(42637544., grad_fn=<AddBackward0>)
tensor(42599680., grad_fn=<AddBackward0>)
tensor(42536468., grad_fn=<AddBackward0>)
tensor(42467596., grad_fn=<AddBackward0>)
tensor(42413604., grad_fn=<AddBackward0>)
tensor(42379700., grad_fn=<AddBackward0>)
tensor(42352136., grad_fn=<AddBackward0>)
tensor(42316364., grad_fn=<AddBackward0>)
tensor(42270032., grad_fn=<AddBackward0>)
tensor(42214040., grad_fn=<AddBack

tensor(35633752., grad_fn=<AddBackward0>)
tensor(35637692., grad_fn=<AddBackward0>)
tensor(35623964., grad_fn=<AddBackward0>)
tensor(35574160., grad_fn=<AddBackward0>)
tensor(35505280., grad_fn=<AddBackward0>)
tensor(35458052., grad_fn=<AddBackward0>)
tensor(35440280., grad_fn=<AddBackward0>)
tensor(35429116., grad_fn=<AddBackward0>)
tensor(35406216., grad_fn=<AddBackward0>)
tensor(35364280., grad_fn=<AddBackward0>)
tensor(35312452., grad_fn=<AddBackward0>)
tensor(35260948., grad_fn=<AddBackward0>)
tensor(35221060., grad_fn=<AddBackward0>)
tensor(35193724., grad_fn=<AddBackward0>)
tensor(35173308., grad_fn=<AddBackward0>)
tensor(35156224., grad_fn=<AddBackward0>)
tensor(35140500., grad_fn=<AddBackward0>)
tensor(35132320., grad_fn=<AddBackward0>)
tensor(35124968., grad_fn=<AddBackward0>)
tensor(35122812., grad_fn=<AddBackward0>)
tensor(35099620., grad_fn=<AddBackward0>)
tensor(35055996., grad_fn=<AddBackward0>)
tensor(34968256., grad_fn=<AddBackward0>)
tensor(34878392., grad_fn=<AddBack

tensor(29789242., grad_fn=<AddBackward0>)
tensor(29822064., grad_fn=<AddBackward0>)
tensor(29926432., grad_fn=<AddBackward0>)
tensor(30115730., grad_fn=<AddBackward0>)
tensor(30507350., grad_fn=<AddBackward0>)
tensor(30837856., grad_fn=<AddBackward0>)
tensor(31142450., grad_fn=<AddBackward0>)
tensor(30416590., grad_fn=<AddBackward0>)
tensor(29618442., grad_fn=<AddBackward0>)
tensor(29792998., grad_fn=<AddBackward0>)
tensor(30203558., grad_fn=<AddBackward0>)
tensor(29894964., grad_fn=<AddBackward0>)
tensor(29489742., grad_fn=<AddBackward0>)
tensor(29787424., grad_fn=<AddBackward0>)
tensor(29818904., grad_fn=<AddBackward0>)
tensor(29439358., grad_fn=<AddBackward0>)
tensor(29566116., grad_fn=<AddBackward0>)
tensor(29657970., grad_fn=<AddBackward0>)
tensor(29364074., grad_fn=<AddBackward0>)
tensor(29430906., grad_fn=<AddBackward0>)
tensor(29507164., grad_fn=<AddBackward0>)
tensor(29272602., grad_fn=<AddBackward0>)
tensor(29318252., grad_fn=<AddBackward0>)
tensor(29369710., grad_fn=<AddBack

tensor(24964364., grad_fn=<AddBackward0>)
tensor(24916956., grad_fn=<AddBackward0>)
tensor(24933348., grad_fn=<AddBackward0>)
tensor(24920700., grad_fn=<AddBackward0>)
tensor(24863072., grad_fn=<AddBackward0>)
tensor(24816332., grad_fn=<AddBackward0>)
tensor(24810318., grad_fn=<AddBackward0>)
tensor(24810060., grad_fn=<AddBackward0>)
tensor(24781764., grad_fn=<AddBackward0>)
tensor(24734998., grad_fn=<AddBackward0>)
tensor(24694958., grad_fn=<AddBackward0>)
tensor(24676834., grad_fn=<AddBackward0>)
tensor(24669384., grad_fn=<AddBackward0>)
tensor(24659266., grad_fn=<AddBackward0>)
tensor(24643684., grad_fn=<AddBackward0>)
tensor(24617134., grad_fn=<AddBackward0>)
tensor(24585732., grad_fn=<AddBackward0>)
tensor(24548942., grad_fn=<AddBackward0>)
tensor(24513274., grad_fn=<AddBackward0>)
tensor(24483066., grad_fn=<AddBackward0>)
tensor(24460874., grad_fn=<AddBackward0>)
tensor(24443936., grad_fn=<AddBackward0>)
tensor(24431456., grad_fn=<AddBackward0>)
tensor(24427796., grad_fn=<AddBack

tensor(21140994., grad_fn=<AddBackward0>)
tensor(21178352., grad_fn=<AddBackward0>)
tensor(21028556., grad_fn=<AddBackward0>)
tensor(21097968., grad_fn=<AddBackward0>)
tensor(21015838., grad_fn=<AddBackward0>)
tensor(20978566., grad_fn=<AddBackward0>)
tensor(20978872., grad_fn=<AddBackward0>)
tensor(20906028., grad_fn=<AddBackward0>)
tensor(20922114., grad_fn=<AddBackward0>)
tensor(20849696., grad_fn=<AddBackward0>)
tensor(20838934., grad_fn=<AddBackward0>)
tensor(20827184., grad_fn=<AddBackward0>)
tensor(20765758., grad_fn=<AddBackward0>)
tensor(20769720., grad_fn=<AddBackward0>)
tensor(20718810., grad_fn=<AddBackward0>)
tensor(20706660., grad_fn=<AddBackward0>)
tensor(20685800., grad_fn=<AddBackward0>)
tensor(20640528., grad_fn=<AddBackward0>)
tensor(20632694., grad_fn=<AddBackward0>)
tensor(20596008., grad_fn=<AddBackward0>)
tensor(20576592., grad_fn=<AddBackward0>)
tensor(20560142., grad_fn=<AddBackward0>)
tensor(20523752., grad_fn=<AddBackward0>)
tensor(20509872., grad_fn=<AddBack

tensor(17554168., grad_fn=<AddBackward0>)
tensor(17558298., grad_fn=<AddBackward0>)
tensor(17484586., grad_fn=<AddBackward0>)
tensor(17509438., grad_fn=<AddBackward0>)
tensor(17461726., grad_fn=<AddBackward0>)
tensor(17439390., grad_fn=<AddBackward0>)
tensor(17445102., grad_fn=<AddBackward0>)
tensor(17397950., grad_fn=<AddBackward0>)
tensor(17415006., grad_fn=<AddBackward0>)
tensor(17415542., grad_fn=<AddBackward0>)
tensor(17445794., grad_fn=<AddBackward0>)
tensor(17567890., grad_fn=<AddBackward0>)
tensor(17774612., grad_fn=<AddBackward0>)
tensor(18016256., grad_fn=<AddBackward0>)
tensor(18334998., grad_fn=<AddBackward0>)
tensor(18048216., grad_fn=<AddBackward0>)
tensor(17509240., grad_fn=<AddBackward0>)
tensor(17292204., grad_fn=<AddBackward0>)
tensor(17533108., grad_fn=<AddBackward0>)
tensor(17672762., grad_fn=<AddBackward0>)
tensor(17422384., grad_fn=<AddBackward0>)
tensor(17216918., grad_fn=<AddBackward0>)
tensor(17371500., grad_fn=<AddBackward0>)
tensor(17405394., grad_fn=<AddBack

tensor(14632270., grad_fn=<AddBackward0>)
tensor(14615334., grad_fn=<AddBackward0>)
tensor(14601529., grad_fn=<AddBackward0>)
tensor(14588502., grad_fn=<AddBackward0>)
tensor(14575321., grad_fn=<AddBackward0>)
tensor(14563312., grad_fn=<AddBackward0>)
tensor(14553811., grad_fn=<AddBackward0>)
tensor(14545531., grad_fn=<AddBackward0>)
tensor(14540426., grad_fn=<AddBackward0>)
tensor(14538684., grad_fn=<AddBackward0>)
tensor(14544557., grad_fn=<AddBackward0>)
tensor(14550002., grad_fn=<AddBackward0>)
tensor(14555378., grad_fn=<AddBackward0>)
tensor(14540688., grad_fn=<AddBackward0>)
tensor(14507530., grad_fn=<AddBackward0>)
tensor(14454166., grad_fn=<AddBackward0>)
tensor(14409040., grad_fn=<AddBackward0>)
tensor(14392169., grad_fn=<AddBackward0>)
tensor(14398910., grad_fn=<AddBackward0>)
tensor(14406521., grad_fn=<AddBackward0>)
tensor(14395862., grad_fn=<AddBackward0>)
tensor(14368132., grad_fn=<AddBackward0>)
tensor(14331948., grad_fn=<AddBackward0>)
tensor(14307263., grad_fn=<AddBack

KeyboardInterrupt: 