# Neural Style Transfer

***Neural style transfer implemented in TensorFlow, adapted from https://www.tensorflow.org/tutorials/generative/style_transfer to suit 3D T1 scans.***

In [None]:
from __future__ import print_function
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" 
import IPython.display as display
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import keras
from keras.models import Model
from keras.layers import Activation
from keras.utils.vis_utils import plot_model
from numba import cuda, jit, njit
import nibabel as nib
import numpy as np
import pandas as pd
from sklearn.utils import shuffle
import random
from scipy import ndimage
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.ticker import PercentFormatter
from nilearn import plotting

import SimpleITK as sitk
import sys
import scipy.ndimage
import skimage.transform as skTrans

from preprocessingfunctions import * #Custom 
from stylefunctions import * #Custom

***Random seed for repeatability***

In [None]:
import tensorflow

# One seed for every random number generator this notebook has imported.
# tensorflow.random.set_seed only fixes TensorFlow's global seed; Python's
# `random` and NumPy's global generator are seeded separately so that any
# shuffling or sampling done through them is repeatable as well.
SEED = 2022
random.seed(SEED)
np.random.seed(SEED)
tensorflow.random.set_seed(SEED)

## Import pre-trained model

In [None]:
# Load the pre-trained network from disk; replace 'model_path' with the
# location of the saved model before running this cell.
model_custom = keras.models.load_model(filepath='model_path')

## Import data

In [None]:
# GE table: one row per scan, ordered by filename. reset_index() keeps the
# pre-sort row labels in an 'index' column and renumbers the rows from 0.
source_directory_GE = 'insert path'
destination_directory_GE = 'insert path'
dfGE = (
    pd.read_csv(os.path.join(source_directory_GE, 'GE_CO.csv'))
    .sort_values('filename')
    .reset_index()
)

# Siemens table: loaded and ordered exactly like the GE table above.
source_directory_SE = 'insert path'
destination_directory_SE = 'insert path'
dfSIEMENS = (
    pd.read_csv(os.path.join(source_directory_SE, 'SE_CO.csv'))
    .sort_values('filename')
    .reset_index()
)

## Create arrays for ease of feeding into model

In [None]:
# Run process_scan over every file path and pack the results into one array
# per table: GE scans provide the content, Siemens scans provide the style.
Content_scans = np.array(list(map(process_scan, dfGE['FILEPATH'])))
Style_scans = np.array(list(map(process_scan, dfSIEMENS['FILEPATH'])))

In [None]:
# View (optional)
print(Style_scans.shape)
def imshow(image, title=None):
    """Display a single 2D slice, rotated by 90 degrees, with the 'jet' colormap.

    Args:
        image: array-like to display. A 2D input is shown as is. For inputs
            with more than two dimensions, singleton axes are dropped first;
            if the result still has more than two dimensions, index 10 along
            the third axis is shown.
        title: optional title for the plot; skipped when empty or None.
    """
    image = np.asarray(image)

    if image.ndim > 2:
        # Previously only inputs with more than 3 dimensions were sliced, and
        # slicing a 4D array still left a 3D array that plt.imshow cannot
        # draw as a single slice. Drop singleton axes, then slice what is
        # left down to 2D.
        image = np.squeeze(image)
    if image.ndim > 2:
        image = image[:, :, 10]

    image = ndimage.rotate(image, 90)
    plt.imshow(image, cmap='jet')
    if title:
        plt.title(title)
imshow(Style_scans[0,:,:,75],'a')

print(Style_scans[0,:,:,75].shape)

## Define the content and style losses

In [None]:
def style_content_loss(outputs):
    """Compute the weighted style loss, content loss and their sum.

    Reads the module-level names style_targets, content_targets, style_weight,
    content_weight, num_style_layers and num_content_layers.

    Args:
        outputs: mapping with 'style' and 'content' entries, each itself a
            mapping from layer name to the features extracted for that layer.

    Returns:
        Tuple (style_loss, content_loss, loss) where loss is the sum of the
        first two.
    """
    style_features = outputs['style']
    content_features = outputs['content']

    def _summed_mse(features, targets):
        # Mean squared difference per layer, added up over all layers.
        per_layer = [
            tf.reduce_mean((feature - targets[layer]) ** 2)
            for layer, feature in features.items()
        ]
        return tf.add_n(per_layer)

    # Each term is scaled by its weight divided by the number of layers.
    style_term = _summed_mse(style_features, style_targets)
    style_term = style_term * (style_weight / num_style_layers)

    content_term = _summed_mse(content_features, content_targets)
    content_term = content_term * (content_weight / num_content_layers)

    return style_term, content_term, style_term + content_term

### Optional: Clip image dynamic ranges

In [None]:
def clip_0_1(image, clip_value_min=None, clip_value_max=None):
    """Return `image` with its values clipped to [clip_value_min, clip_value_max].

    Args:
        image: tensor, variable or array to clip.
        clip_value_min: lower bound; defaults to the minimum value of `image`.
        clip_value_max: upper bound; defaults to the maximum value of `image`.

    Returns:
        The clipped tensor. The previous version computed it but never
        returned it, so every call evaluated to None.
    """
    # tf.reduce_min / tf.reduce_max are used instead of image.min() /
    # image.max(), which TensorFlow tensors and variables do not provide.
    if clip_value_min is None:
        clip_value_min = tf.reduce_min(image)
    if clip_value_max is None:
        clip_value_max = tf.reduce_max(image)
    return tf.clip_by_value(image,
                            clip_value_min=clip_value_min,
                            clip_value_max=clip_value_max)

## Define gradient function

In [None]:
# @tf.function()
def train_step(Tobe_styled_image,e,m):
    """Apply one optimiser update to the image variable and log the loss.

    Reads the module-level names extractor, opt, total_variation_weight and
    wandb, and calls style_content_loss and total_variation_loss.

    Args:
        Tobe_styled_image: variable holding the image being optimised; it is
            updated in place by the optimiser.
        e: value logged to wandb under "Epochs".
        m: value logged to wandb under "Steps per Epoch".

    Returns:
        The result of assigning the variable to itself (its values are left
        as the optimiser step produced them).
    """
    with tf.GradientTape() as gradient_tape:
        features = extractor(Tobe_styled_image)
        _, _, total_loss = style_content_loss(features)
        variation = total_variation_loss(Tobe_styled_image)
        total_loss = total_loss + total_variation_weight * variation

    image_gradient = gradient_tape.gradient(total_loss, Tobe_styled_image)
    opt.apply_gradients([(image_gradient, Tobe_styled_image)])

    # The logged loss is the total loss divided by total_variation_weight.
    metrics = {
        "Loss": total_loss.numpy() / total_variation_weight,
        "Epochs": e,
        "Steps per Epoch": m,
    }
    wandb.log(metrics)
    return Tobe_styled_image.assign(Tobe_styled_image)

## Run style transfer model

In [None]:
import wandb
wandb.init(name='NST_vanila',
          project='Neural Style Transfer')

import time
from time import sleep
import tqdm.notebook as tq
from tqdm.auto import tqdm, trange
start = time.time()
from skimage.exposure import match_histograms

# Set some options here...
epochs = 1
steps_per_epoch = 100
total_variation_weight = 10e3  # NOTE: 10e3 is 10000.0
styled_image_path = 'insert path'  # directory the styled NIfTI files are written to

# Choose content and style layers; refer to the model output above for the correct layer names.
# These do not change between scans, so they are defined once, outside the loop.
content_layers = ['conv3d'] #The first layer to preserve structural integrity of the GE scans
style_layers = ['conv3d','conv3d_1','conv3d_2','conv3d_3', 'conv3d_4','conv3d_5','conv3d_6','conv3d_7','conv3d_8',
       'conv3d_9','conv3d_10','conv3d_11','conv3d_12','conv3d_13', 'conv3d_14','conv3d_15','conv3d_16','conv3d_17','conv3d_18','conv3d_19'] #Every layer chosen so that maximum style can be extracted

num_content_layers = len(content_layers)
num_style_layers = len(style_layers)

bar = trange(len(dfGE))
for i in bar:
    print('Running style transfer on GE scan -->',i+1)
    # The content scan, with a leading batch axis, is the variable being optimised.
    Tobe_styled_image = tf.Variable(Content_scans[i,:,:,:][None,:,:,:])

    # Reuse the model loaded once in the "Import pre-trained model" cell
    # instead of reloading it from disk for every scan (the reload also
    # referred to `model_path`, a name this notebook never assigns).
    prediction_probabilities = model_custom(Tobe_styled_image)

    print('------------------------------------------')

    # NOTE(review): content scan i is paired with style scan i, so this
    # presumably expects at least as many Siemens scans as GE scans - confirm.
    style_image = Style_scans[i,:,:,:][None,:,:,:]
    style_extractor = model_layers(style_layers)
    style_outputs = style_extractor(style_image)  # not used further in this cell

    print('Number of content layers', num_content_layers)
    print('Number of style layers', num_style_layers)

    # Targets are fixed per scan: style features of the Siemens scan and
    # content features of the unmodified GE scan.
    extractor = StyleContentModel(style_layers, content_layers)
    style_targets = extractor(style_image)['style']
    content_targets = extractor(Tobe_styled_image)['content']

    print('------------------------------------------')
    for e in range(epochs):
        # wandb.log({"Epochs":e})
        for m in tq.tqdm(range(steps_per_epoch)):
            train_step(Tobe_styled_image,e,m) #Tobe_styled_image is updated in place

    # Styled GE scan converted back to NIfTI format, reusing the affine of the source file.
    # NOTE(review): if process_scan resamples the volume, this affine may no
    # longer match the voxel grid - confirm.
    ni_styled_image = nib.Nifti1Image(np.squeeze(Tobe_styled_image.numpy()), nib.load(dfGE['FILEPATH'][i]).affine)
    end = time.time()
    print("Total time: {:.1f}".format(end-start))
    print('--------------------------------------------')

    # os.path.join inserts the path separator that plain string
    # concatenation left out between the directory and the file name.
    styled_image_name = dfGE["filename"][i]
    nib.save(ni_styled_image, os.path.join(styled_image_path, styled_image_name))