- title: "Neural style transfer 2: style"
- description: "Feature visualization with PyTorch"
- toc: false
- branch: master
- badges: true
- comments: true
- categories: [fastpages, jupyter]
- image: images/some_folder/your_image.png
- hide: false
- search_exclude: true
- metadata_key1: metadata_value1
- metadata_key2: metadata_value2

In [1]:
# Imports
import torch
import torchvision
from torch import nn
import skimage
from skimage import transform
from skimage import io
# from im_func import show_image, timer
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from time import time
import contextlib

@contextlib.contextmanager
def timer(msg='timer'):
    """Context manager that prints the wall-clock time spent in its body.

    Args:
        msg: Label printed in front of the elapsed time.

    On exit from the ``with`` block, prints ``"{msg}: {elapsed:.2f}"`` where
    ``elapsed`` is the number of seconds since the block was entered. The
    message is printed even if the body raises; the exception then propagates.
    """
    tic = time()
    try:
        yield
    finally:
        # Report in `finally` so the timing is not lost when the body raises.
        print(f"{msg}: {time() - tic:.2f}")

In [2]:
# Three-element (one value per RGB channel) mean and standard deviation.
# Both are passed to torchvision's Normalize in SmallNet.
rgb_mean = torch.tensor([0.485, 0.456, 0.406]) # Fixed values for PyTorch pretrained models
rgb_std = torch.tensor([0.229, 0.224, 0.225])

class Image(nn.Module):
    """Image tensor of shape ``[1, 3] + img_shape`` wrapped in a module.

    The image is either initialised with Gaussian noise (``img=None``) or
    built from a raw image through ``preprocess``. When ``optimizable`` is
    true the tensor is stored as an ``nn.Parameter`` so an optimizer can
    update it; otherwise it is stored as a buffer so that ``.to(device)``
    moves it together with the module.

    Args:
        img: Raw image accepted by ``torchvision.transforms.ToPILImage``,
            or ``None`` to start from random noise.
        optimizable: Whether the image is exposed as a trainable parameter.
        img_shape: Two-element sequence passed to ``Resize`` and appended
            to ``[1, 3]`` for the noise image.
        jit_max: Bound applied (with ``np.clip``) to the accumulated jitter.
        angle_max: Bound applied (with ``np.clip``) to the accumulated angle.
    """

    def __init__(self, img=None, optimizable=True, img_shape=(64, 64), jit_max=2, angle_max=2.0):
        super().__init__()

        # Copy into a fresh list: a shared mutable default would leak changes
        # between instances, and the shape is concatenated with a list below.
        self.img_shape = list(img_shape)

        if img is None:
            tensor = torch.randn([1, 3] + self.img_shape)
        else:
            tensor = self.preprocess(img)

        if optimizable:
            self.img = nn.Parameter(tensor)
        else:
            # A buffer (unlike a plain attribute) follows the module in .to(device).
            self.register_buffer('img', tensor)

        # Running state of the random walk used by jittered_image().
        self.jit_i = 0
        self.jit_j = 0
        self.jit_max = jit_max
        self.angle = 0.0
        self.angle_max = angle_max

    def preprocess(self, img=None):
        """Convert a raw image to a float tensor with a leading batch dimension.

        Applies ``ToPILImage`` -> ``Resize(self.img_shape)`` -> ``ToTensor``
        and then ``unsqueeze(0)``.

        Args:
            img: Raw image to convert; defaults to ``self.img`` when omitted.
        """
        if img is None:
            img = self.img
        with torch.no_grad():
            transforms = torchvision.transforms.Compose([
                torchvision.transforms.ToPILImage(),
                torchvision.transforms.Resize(self.img_shape),
                torchvision.transforms.ToTensor(),
            ])
            return transforms(img).unsqueeze(0)

    def postprocess(self):
        """Return the first image of the batch as a PIL image."""
        with torch.no_grad():
            # Detached CPU copy, so the conversion never touches the live tensor.
            img = self.img.detach()[0].cpu().clone()
        return torchvision.transforms.ToPILImage()(img)

    def jittered_image(self):
        """Return a randomly rotated/translated copy of the image.

        The translation and the angle each follow a clipped random walk:
        every call adds a normally distributed step to the stored values and
        clips them to ``[-jit_max, jit_max]`` and ``[-angle_max, angle_max]``.
        The result is computed under ``torch.no_grad()`` from ``self.img.data``.
        """
        with torch.no_grad():
            step = np.random.standard_normal(2) * 2.0
            self.jit_i += step[0]
            self.jit_j += step[1]

            self.angle += np.random.standard_normal(1)[0] * 1.0
            self.angle = np.clip(self.angle, -self.angle_max, self.angle_max)
            self.jit_i, self.jit_j = np.clip([self.jit_i, self.jit_j], -self.jit_max, self.jit_max)

            # NOTE(review): torchvision documents `translate` of the functional
            # affine as a shift in pixels; dividing by the image size makes the
            # shift a small fraction of a pixel - confirm this is intended.
            translate = (
                float(self.jit_i / self.img_shape[1]),
                float(self.jit_j / self.img_shape[0]),
            )
            return torchvision.transforms.functional.affine(
                self.img.data,
                angle=float(self.angle),
                translate=translate,
                scale=1.,
                shear=[0.0, 0.0],
            )

    def forward(self, jitter=False):
        """Return the stored image, or a jittered copy when ``jitter`` is true."""
        if jitter:
            return self.jittered_image()
        return self.img

In [None]:
# Load torchvision's VGG16 with pretrained weights (the full model is kept here).
pretrained_net = torchvision.models.vgg16(pretrained=True)#.features.to(device).eval()
# Show the layers contained in the `features` sub-module.
display(pretrained_net.features)
# Indices into `pretrained_net.features` of the layers whose outputs are extracted;
# the last entry also determines where SmallNet truncates the network.
content_layer = [7]

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))

In [None]:
class SmallNet(nn.Module):
    """Truncated copy of a pretrained feature stack with input normalization in front.

    ``self.net`` is an ``nn.Sequential`` made of a ``Normalize`` layer (using the
    module-level ``rgb_mean`` / ``rgb_std``) followed by
    ``pretrained_net.features[0 .. last_layer]``. It is moved to the
    module-level ``device`` and switched to eval mode.
    """

    def __init__(self, pretrained_net, last_layer):
        super().__init__()
        normalize = torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)
        kept_layers = [pretrained_net.features[idx] for idx in range(last_layer + 1)]
        stack = nn.Sequential(normalize, *kept_layers)
        self.net = stack.to(device).eval()

    def forward(self, X, extract_layers, gram=False):
        """Run ``X`` through the network and collect selected layer outputs.

        Args:
            X: Input tensor fed to the first layer of ``self.net``.
            extract_layers: Indices, in the numbering of the pretrained
                ``features`` stack, of the layers whose outputs are collected.
            gram: When true, return the Gram matrices of the collected
                feature maps instead of the feature maps themselves.

        Returns:
            A list with one tensor per collected layer. The tensors are
            detached when the input does not require gradients.
        """
        # Inputs without gradients (e.g. a fixed target image) yield detached maps.
        keep_graph = X.requires_grad
        collected = []
        for position, layer in enumerate(self.net):
            X = layer(X)
            # self.net has one extra (normalization) layer at index 0, hence the -1.
            if (position - 1) not in extract_layers:
                continue
            snapshot = X.clone()
            collected.append(snapshot if keep_graph else snapshot.detach())

        return self._get_gram(collected) if gram else collected

    def _get_gram(self, feature_maps):
        """Return the normalised Gram matrix of every 4-D feature map in the list."""
        grams = []
        for fmap in feature_maps:
            n_batch, n_chan, height, width = fmap.size()
            # Flatten to (batch * channels, height * width) before the product.
            flat = fmap.view(n_batch * n_chan, height * width)
            gram = torch.mm(flat, flat.t())
            # Normalise by the total number of elements of the feature map.
            grams.append(gram.div(n_batch * n_chan * height * width))
        return grams

In [None]:
class Losses(nn.Module):
    """Weighted sum of a feature loss and three image regularizers.

    The total loss returned by ``forward`` is::

        content_weight * content_loss
        + int_weight   * int_loss      (intensity regularizer, exponent alpha)
        + tv_weight    * tv_loss       (total-variation regularizer, exponent beta)
        + clamp_weight * clamp_loss    (penalty on values outside [0, 1])

    The most recent value of every term is kept on the instance so that
    ``adjust_weights`` can rescale the regularizer weights between steps.

    Args:
        img_ref: Image used to compute the reference intensity ``B`` and the
            reference total variation ``V`` that normalise the regularizers.
        content_weight: Weight of the feature (content) loss.
        tv_weight_ini: Initial weight of the total-variation regularizer.
        clamp_weight_ini: Initial weight of the clamp penalty.
        int_weight_ini: Initial weight of the intensity regularizer.
        alpha: Exponent of the intensity regularizer.
        beta: Exponent of the total-variation regularizer.
        weight_adjust_bounds: ``(lower, upper)`` bounds, as fractions of the
            weighted content loss, outside which a weight is rescaled.
        weight_adjust_fac: ``(factor_below, factor_above)`` multipliers
            applied when a weighted loss is below / above those bounds.
    """

    def __init__(self, img_ref,
                 content_weight=1.0, tv_weight_ini=0.0, clamp_weight_ini=1.0, int_weight_ini=0.0,
                 alpha=6, beta=1.5,
                 weight_adjust_bounds=(0.01, 0.1),
                 weight_adjust_fac=(1.5, 0.75)):
        super().__init__()
        self.content_weight = content_weight
        self.tv_weight = tv_weight_ini
        self.int_weight = int_weight_ini
        self.clamp_weight = clamp_weight_ini

        # Latest loss values; initialised so adjust_weights() is safe to call
        # before the first forward pass.
        self.content_loss = 0.0
        self.tv_loss = 0.0
        self.int_loss = 0.0
        self.clamp_loss = 0.0
        self.total_loss = 0.0

        self.alpha = alpha
        self.beta = beta

        self.B, self.V = self.get_regularizer_refs(img_ref)

        self.weight_adjust_bounds = weight_adjust_bounds
        self.weight_adjust_fac = weight_adjust_fac

    @staticmethod
    def _pixel_norm(img, eps=1e-10):
        """Per-pixel L2 norm over the first three channels of ``img``."""
        # eps keeps the sqrt argument strictly positive.
        return torch.sqrt(img[:, 0, :, :]**2 + img[:, 1, :, :]**2 + img[:, 2, :, :]**2 + eps)

    @staticmethod
    def _gradient_norm(img, eps=1e-10):
        """Per-element L2 norm of the finite differences along dims 2 and 3."""
        d_dx = img[:, :, 1:, :] - img[:, :, :-1, :]
        d_dy = img[:, :, :, 1:] - img[:, :, :, :-1]
        # The first row / column is dropped so both differences have matching
        # sizes. They are therefore not strictly co-located, which is accepted
        # as an approximation because neighbouring pixels are correlated.
        return torch.sqrt(d_dx[:, :, :, 1:]**2 + d_dy[:, :, 1:, :]**2 + eps)

    def get_content_loss(self, Y_hat, Y, reduction='mean'):
        """Sum of the MSE losses between paired tensors of ``Y_hat`` and ``Y``.

        The targets ``Y`` are detached; ``reduction`` is forwarded to
        ``F.mse_loss``.
        """
        loss = 0
        for y_hat, y in zip(Y_hat, Y):
            loss += F.mse_loss(y_hat, y.detach(), reduction=reduction)
        return loss

    @staticmethod
    def get_style_loss(gram_style_image, gram_generated_image):
        """Sum of squared differences between paired Gram matrices.

        Declared static because it uses no instance state; it can be called
        on the class or on an instance.
        """
        loss = 0
        for A, G in zip(gram_style_image, gram_generated_image):
            loss += F.mse_loss(G, A.detach(), reduction='sum')
        return loss

    def get_regularizer_refs(self, img):
        """Return ``(B, V)``: mean pixel norm and mean gradient norm of ``img``."""
        B = self._pixel_norm(img).mean()
        V = self._gradient_norm(img).mean()
        return B, V

    def get_int_loss(self, img):
        """Intensity regularizer: pixel norms to the power ``alpha``, normalised by ``H*W*B**alpha``."""
        H = img.shape[2]
        W = img.shape[3]
        L2 = self._pixel_norm(img)
        return 1./H/W/(self.B**self.alpha) * torch.sum(L2**self.alpha)

    def get_clamp_loss(self, img):
        """Squared distance of out-of-range values to the interval [0, 1], divided by ``H*W``."""
        H = img.shape[2]
        W = img.shape[3]
        above = torch.sum(torch.abs(img[img > 1.0] - 1.0)**2)
        below = torch.sum(torch.abs(img[img < 0.0])**2)
        return 1.0/H/W * (above + below)

    def get_TV_loss(self, img):
        """Total-variation regularizer: gradient norms to the power ``beta``, normalised by ``H*W*V**beta``."""
        H = img.shape[2]
        W = img.shape[3]
        L2 = self._gradient_norm(img)
        TV = torch.sum(L2**self.beta)
        return 1./H/W/(self.V**self.beta) * TV

    def _adjust_weight(self, weight, loss):
        """Return ``weight`` rescaled so ``weight*loss`` moves towards the configured band.

        The band is ``[lower, upper]`` times the weighted content loss.
        """
        lb, ub = self.weight_adjust_bounds
        lfac, ufac = self.weight_adjust_fac
        # A (near-)zero weight means the term is switched off; never re-enable
        # it by accident through rounding errors.
        if weight > 1e-10:
            if weight*loss < lb*(self.content_weight*self.content_loss):
                weight *= lfac
            if weight*loss > ub*(self.content_weight*self.content_loss):
                weight *= ufac
        return weight

    def adjust_weights(self):
        """Rescale the regularizer weights from the most recently stored losses."""
        self.tv_weight = self._adjust_weight(self.tv_weight, self.tv_loss)
        self.int_weight = self._adjust_weight(self.int_weight, self.int_loss)

        # The clamp weight is only touched while some value is out of range.
        if self.clamp_loss > 1e-10:
            self.clamp_weight = self._adjust_weight(self.clamp_weight, self.clamp_loss)

    def forward(self, img, feature_map, feature_map_target):
        """Compute, store and return the weighted total loss.

        Args:
            img: Image the regularizers are evaluated on.
            feature_map: List of tensors computed from the generated image.
            feature_map_target: List of target tensors (detached internally).
        """
        self.content_loss = self.get_content_loss(feature_map, feature_map_target, reduction='sum')
        self.int_loss = self.get_int_loss(img)
        self.tv_loss = self.get_TV_loss(img)
        self.clamp_loss = self.get_clamp_loss(img)

        self.total_loss = (self.content_weight*self.content_loss
                           + self.int_weight*self.int_loss
                           + self.tv_weight*self.tv_loss
                           + self.clamp_weight*self.clamp_loss)

        return self.total_loss

In [None]:
# Device that the images are moved to below; SmallNet also reads this global.
device = 'cpu'

# Images
# content_im = skimage.io.imread("https://github.com/scijs/baboon-image/blob/master/baboon.png?raw=true")
# Download the reference image and show it.
content_im = skimage.io.imread("https://github.com/abauville/Neural_style_transfer/blob/main/Data/matisse_cat.jpeg?raw=true")
fig, ax = plt.subplots(1,1,figsize=[5,5])
_ = plt.imshow(content_im); plt.title("content"); _ = plt.axis("off")

# img_content: fixed reference image (not optimizable).
# img_gen: image initialised with random noise (img=None) and exposed as a parameter.
img_content = Image(img=content_im, optimizable=False).to(device)
img_gen = Image(None, optimizable=True).to(device)

# SmallNet
# Keep the pretrained layers up to (and including) the last extracted layer.
net = SmallNet(pretrained_net, content_layer[-1])

# Losses
# The reference image provides the normalisation constants of the regularizers.
loss_fn = Losses(img_content(), 
                 tv_weight_ini=1e3, 
                 int_weight_ini=0.0,
                 clamp_weight_ini=1e1,
                 weight_adjust_bounds=[0.01, 0.1],
                 weight_adjust_fac=[1.5, 0.75])

# Optimizer
# Only the generated image's parameters are optimized.
optimizer = torch.optim.LBFGS(img_gen.parameters(),lr=1.0)
# Stop when the total loss falls below this value.
abs_loss_limit = 1e-3
# Stop when the absolute change of the total loss between two iterations falls below this value.
rel_loss_limit = 1e-7

# Jitter the input image?
# Number of initial iterations whose target features come from a jittered reference image.
jitter_nsteps = 30

In [None]:
# Figure that is redrawn at every iteration of the optimisation loop
fig, ax = plt.subplots(1, 1, figsize=[10, 10])

# for sanity: a negative number of jitter steps is treated as zero
jitter_nsteps = max(jitter_nsteps, 0)

def closure():
    """Loss re-evaluation callback passed to ``optimizer.step``.

    Clears the gradients, recomputes the Gram matrices of the generated
    image, evaluates the total loss against the current global
    ``fm_content`` target, backpropagates, and returns the loss.
    """
    optimizer.zero_grad()
    gram_generated = net(img_gen(), content_layer, gram=True)
    total = loss_fn(img_gen(), gram_generated, fm_content)
    total.backward()
    return total

# Total loss of the previous iteration (a large float before the first one).
last_loss = 1e10
frame = 0
for i in range(1000):
    # Target Gram matrices: recomputed from a jittered reference image during the
    # first `jitter_nsteps` iterations, then computed once from the unjittered
    # image and reused for the remaining iterations.
    if i<jitter_nsteps:
        fm_content = net(img_content(jitter=True), content_layer, gram=True)
    elif i==jitter_nsteps:
        fm_content = net(img_content(jitter=False), content_layer, gram=True)
    optimizer.step(closure)
    # Rescale the regularizer weights from the losses stored by the last forward pass.
    loss_fn.adjust_weights()

    # i%1 is always 0, so the plot is refreshed at every iteration.
    if i%1==0:
        with torch.no_grad():
            plt.clf()
            plt.imshow(img_gen.postprocess())
            
            # The title reports each weighted loss term, the total, and the change since the previous iteration.
            plt.title(f"epoch {i:02}, content, tv, intensity, clamp losses:" + 
                      f"{loss_fn.content_weight*loss_fn.content_loss:.2e}, " + 
                      f"{loss_fn.tv_weight*loss_fn.tv_loss:.2e}, " +
                      f"{loss_fn.int_weight*loss_fn.int_loss:.2e}, " + 
                      f"{loss_fn.clamp_weight*loss_fn.clamp_loss:.2e}, "+
                      f"total:{loss_fn.total_loss:.2e}, abs: {torch.abs(last_loss-loss_fn.total_loss):.2e}")
            
            clear_output(wait = True)
            display(fig)
            
            # plt.savefig(f"./Output/Frame{frame:05d}")
            # frame += 1

            # Both convergence tests additionally require a (near-)zero clamp loss.
            if loss_fn.total_loss<abs_loss_limit and loss_fn.clamp_loss<1e-10:
                clear_output(wait = True)
                print(f'success: absolute loss limit ({abs_loss_limit:.1e}) reached')
                break
            # Compares the absolute difference between consecutive total losses to rel_loss_limit.
            if torch.abs(last_loss-loss_fn.total_loss)<rel_loss_limit and loss_fn.clamp_loss<1e-10:
                clear_output(wait = True)
                print(f'stopped because relative loss limit ({rel_loss_limit:.1e})  was reached')
                break

            if loss_fn.total_loss.isnan():
                print(f'stopped because loss is NaN')
                break
                
    last_loss = loss_fn.total_loss

    # Clamp the generated image to [0, 1] in place after each iteration.
    with torch.no_grad():
        img_gen.img.data.clamp_(0,1)