# Setup

## Imports

Credit to the original authors, whose code this notebook is adapted from: https://github.com/ndahlquist/pytorch-fourier-feature-networks

In [None]:
import torch
import rp
from tqdm.notebook import tqdm as tqdm
from IPython.display import clear_output
import icecream
from translator.pytorch_msssim import msssim
import numpy as np

In [None]:
from source.learnable_textures import LearnableImageFourier
from source.learnable_textures import LearnableImageRaster 
from source.learnable_textures import LearnableImageMLP    
from source.scene_reader       import extract_scene_uvs_and_scene_labels

## Other Setup

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Ask the inline backend for 'retina' (high-DPI) figure output.
%config InlineBackend.figure_format='retina'

# Fourier Feature Visualization

In [None]:
def visualize_features(number_of_features = 64):
    """Display the UV grid and a sample of the fourier feature maps of a 64x64 LearnableImageFourier.

    number_of_features: Maximum number of fourier feature maps to tile into the second image.
        Keep it small to avoid displaying an image that's too big and laggy.
    """
    
    fourier_image=LearnableImageFourier(height=64, width=64)
    
    uv_features = fourier_image.uv_grid[0]
    uv_features = rp.as_numpy_array(uv_features)
    
    print("Here are the two UV features used by the normal LearnableImageMLP (U and V on left and right, respectively):")
    icecream.ic(uv_features.shape, uv_features.min(), uv_features.max())
    rp.display_image(rp.tiled_images(uv_features))
    
    # NOTE(review): the meaning of this 4-element condition vector is defined by
    # LearnableImageFourier.get_features, which is not visible from this notebook.
    features=fourier_image.get_features(torch.Tensor([.1,.2,.3,.4]))
    
    # Only convert the maps that will actually be shown.
    # Each map is shifted from [-1, 1] to [0, 1] so the full range can be displayed.
    images = [(rp.as_numpy_array(feature) + 1) / 2
              for feature in features.squeeze(0)[:number_of_features]]
        
    print("Here's a sample of the %i fourier features used by LearnableImageFourier:" % (2*fourier_image.num_features))
    icecream.ic(features.shape, features.min(), features.max())
    rp.display_image(rp.tiled_images(images))
    
visualize_features()

In [None]:
def visualize_scene_features():
    """Show a sample UVL scene, then one random fourier feature map beside that feature evaluated on the scene's UVs."""
    uvl_map = rp.load_image('assets/mutant_alphadew_uvl_scene.exr')
    uvl_map[:,:,2] = 0 #Get rid of the blue channel for visualization; in this demo it's just distracting
    
    scene_images = rp.as_torch_images(uvl_map[None])
    scene_uvs, _ = extract_scene_uvs_and_scene_labels(scene_images, [0,127, 255])
    # Expect a single scene with exactly two (U and V) channels
    assert len(scene_uvs.shape)==4 and tuple(scene_uvs.shape[:2])==(1, 2)

    fourier_image = LearnableImageFourier()
    feature_extractor = fourier_image.feature_extractor

    # Take the first (only) batch entry of each, and convert range [-1, 1] to [0, 1]
    # so we can display the full range
    scene_features = (feature_extractor(scene_uvs)[0] + 1)/2
    feature_maps   = (fourier_image.features[0]       + 1)/2

    icecream.ic(uvl_map.shape, 
                scene_uvs.shape, 
                scene_features.shape, 
                feature_extractor.num_features, 
                fourier_image.features.shape,
                feature_maps.shape)

    print("A sample UVL scene:")
    rp.display_image(uvl_map)

    print("A random fourier feature of that scene:")

    feature_index = rp.random_index(feature_maps)

    # Left: the feature over the plain UV grid. Right: the same feature sampled at the scene's UVs.
    side_by_side = rp.horizontally_concatenated_images(
        rp.as_numpy_array(feature_maps  [feature_index]),
        rp.as_numpy_array(scene_features[feature_index]),
    )
    rp.display_image(side_by_side)
    
visualize_scene_features()

In [None]:
def visualize_features_sine_cos():
    """Display one random sine feature, its matching cosine feature, and their combined magnitude."""
    print("This cell shows how the first half of the features are sines, and the second half are cosines")

    num_features = 128
    fourier_image = LearnableImageFourier(num_features=num_features)
    feature_extractor = fourier_image.feature_extractor
    feature_maps = fourier_image.features[0]

    # Split the feature maps at num_features: sines first, cosines second
    sines, cosines = feature_maps[:num_features], feature_maps[num_features:]

    magnitudes = (sines**2 + cosines**2) ** .5

    icecream.ic(
        feature_extractor.num_features, 
        fourier_image.features.shape,
        feature_maps.shape,
        sines.shape,
        cosines.shape,
        magnitudes.shape,
        sines.min(), sines.max(),
        cosines.min(), cosines.max(),
        magnitudes.min(), magnitudes.max(),
    )

    feature_index = rp.random_index(magnitudes)

    def as_display_image(maps):
        # Pick the chosen map and apply the same (x+1)/2 shift to all three images
        return (rp.as_numpy_array(maps[feature_index]) + 1)/2

    sine, cosine, magnitude = map(as_display_image, (sines, cosines, magnitudes))

    print("Here's a random feature map's sine and cosine respectively:")

    label_style = dict(position='bottom', size=20)
    equation_pieces = [
        rp.cv_text_to_image('sqrt('),
        rp.labeled_image(sine, '(A sine feature)', **label_style),
        rp.cv_text_to_image('^2  +  '),
        rp.labeled_image(cosine, '(its respective cosine)', **label_style),
        rp.cv_text_to_image('^2)  =  '),
        rp.labeled_image(magnitude, '(magnitude = 1)', **label_style),
    ]
    rp.display_image(rp.horizontally_concatenated_images(*equation_pieces))
    
visualize_features_sine_cos()

# Running Tests

## Target Image

In [None]:
# Available target images: short label -> image URL
target_image_choices = {
    'fox'      : 'https://live.staticflickr.com/7492/15677707699_d9d67acf9d_b.jpg',
    'magikarp' : 'https://static.pokemonpets.com/images/monsters-images-300-300/129-Magikarp.webp',
    'makeup'   : 'https://i.redd.it/vxvs3dgsbxw31.png',
    'snowflake': 'https://2s7gjr373w3x22jf92z99mgm5w-wpengine.netdna-ssl.com/wp-content/uploads/2020/11/snowflake_shutterstock_kichigin.jpg',
    'portal'   : 'https://static.wikia.nocookie.net/half-life/images/9/9d/Atlas_P-body_fhp2.jpg/revision/latest?cb=20110519013122&path-prefix=en',
    'uv'       : 'https://i.imgur.com/w9Mc6qN.png',
}

def display_target_images():
    """Display a labeled grid of 128x128 thumbnails, one for every entry in target_image_choices."""
    target_labels, target_images = zip(*target_image_choices.items())
    # use_cache=True, as load_target_image does, so that (presumably) each URL is only fetched once
    target_images = [rp.load_image(image, use_cache=True) for image in target_images]
    target_images = [rp.cv_resize_image(image,(128,128)) for image in target_images]
    target_images = rp.labeled_images(target_images, target_labels)
    target_images = rp.tiled_images(target_images)
    target_images = rp.labeled_image(target_images, "Choices", size=30)
    rp.display_image(target_images)

def load_target_image(target_image:str, height=None, width=None):
    """Load an image and return it as a float RGB image, center-cropped to height x width.

    target_image: Location of the image to load (the notebook passes URLs from target_image_choices).
    height, width: Size of the crop. Each defaults to the notebook-level
        target_height / target_width when left as None.

    Returns a fresh copy of the cropped image, so callers can't modify the cached original.
    """
    if height is None: height = target_height
    if width  is None: width  = target_width

    target_image = rp.load_image    (target_image, use_cache=True)
    target_image = rp.as_float_image(target_image)
    target_image = rp.as_rgb_image  (target_image)
    target_image = rp.crop_image    (target_image, height, width, origin='center')
    target_image = target_image.copy()
    return target_image

# Show every available choice before one is picked as the target image.
display_target_images()

In [None]:
# The target is cropped to a square of this many pixels per side
target_width  = 512
target_height = target_width

# Which of the target_image_choices to fit
target_image = target_image_choices['portal']

In [None]:
# target_image starts out as a URL string and is replaced by the loaded image.
# Only load while it's still a string, so re-running this cell doesn't feed the
# already-loaded array back into load_target_image.
if isinstance(target_image, str):
    target_image = load_target_image(target_image)

assert rp.get_image_height(target_image) == target_height
assert rp.get_image_width (target_image) == target_width 

print("Target Image:")
icecream.ic(target_image.shape, target_image.dtype, type(target_image), target_image.max(), target_image.min())
rp.display_image(target_image)

## Testing Function

In [None]:
def run_test(method         = 'fourier',
             criterion      = 'mse'    ,
             num_features   = 128      ,
             hidden_dim     = 256      ,
             scale          = 10       ,
             iter_per_epoch = 100      ,
             learning_rate  = 1e-4     ,
             num_epochs     = 4        ,
             save_path      = None     ,
             load_path      = None     ):
    """Fit a learnable image to the notebook's target image, then plot the losses and show the snapshots.

    Reads the notebook-level variables target_image, target_height, target_width and device.

    Parameters:
        method:         'fourier', 'mlp' or 'raster' - which learnable image class gets trained.
        criterion:      Space-separated loss terms; a non-empty subset of 'mse msssim l1'.
                        The selected terms are added together.
        num_features:   Passed to LearnableImageFourier (unused by the other methods).
        hidden_dim:     Passed to LearnableImageFourier and LearnableImageMLP (unused by 'raster').
        scale:          Passed to LearnableImageFourier (unused by the other methods).
        iter_per_epoch: A snapshot is displayed once every this many optimizer steps.
        learning_rate:  Learning rate of the Adam optimizer.
        num_epochs:     The total number of optimizer steps is num_epochs*iter_per_epoch.
        save_path:      If given, the state dict is saved there at the end (also after an interrupt).
        load_path:      If given and the file exists, the state dict is loaded from it before training.

    Raises:
        KeyboardInterrupt: Re-raised after plotting and saving when training was interrupted,
            so that later cells don't keep running.
    """
    
    print('Method:',method)
    
    criteria = set(criterion.split())
    
    assert method   in     'fourier mlp raster'.split()
    assert criteria <= set('mse msssim l1'.split())
    assert criteria, 'criterion must name at least one of: mse msssim l1'

    target = rp.as_torch_image(target_image).to(device)
    
    print('Target:')
    rp.display_image(rp.as_numpy_image(target))
    
    
    if method=='mlp':
        learnable_image=LearnableImageMLP(target_height,
                                          target_width,
                                          hidden_dim=hidden_dim)
    elif method=='fourier':
        learnable_image=LearnableImageFourier(target_height,
                                              target_width,
                                              num_features=num_features,
                                              hidden_dim=hidden_dim,
                                              scale=scale)
    elif method=='raster':
        learnable_image=LearnableImageRaster(target_height,
                                             target_width)
        
    learnable_image.to(device)
    
    optimizer = torch.optim.Adam(learnable_image.parameters(), lr=learning_rate)
    
    images = []
    losses = []

    def display_current_image():
        #Record and display a snapshot of the image in its current state
        image = learnable_image.as_numpy_image()
        images.append(image)
        rp.display_image(image)
        return image
    
    if load_path:
        if rp.file_exists(load_path):
            #map_location lets a checkpoint saved on another device (e.g. cuda) load on this one
            state = torch.load(load_path, map_location=device)
            learnable_image.load_state_dict(state)

            print("Loaded state:")
            display_current_image()
            
        else:
            print(load_path, 'does not exist and cannot be loaded. Weights will be randomly initialized.')
    
    throw_error=False
    
    try:
        for iteration in tqdm(range(num_epochs*iter_per_epoch)):
            optimizer.zero_grad()

            generated = learnable_image()

            loss = 0
            if 'msssim' in criteria:
                loss += -msssim(target[None],generated[None],normalize=True)
                loss += 1 #Make the min loss 0, so we can do log plots of losses
            if 'mse'    in criteria:
                #This used to call l1_loss, which made 'mse' identical to 'l1'
                loss += torch.nn.functional.mse_loss(generated, target)
            if 'l1'     in criteria:
                loss += (generated-target).abs().mean()

            loss.backward()
            optimizer.step()

            losses.append(float(loss))

            if iteration % iter_per_epoch == 0:
                epoch = iteration // iter_per_epoch
                print('Epoch %d, loss = %.03f' % (epoch, float(loss)))
                display_current_image()
                
        #Record the result of the very last step. The last displayed snapshot was taken at
        #the start of the final epoch, and doesn't exist at all when there are no iterations.
        images.append(learnable_image.as_numpy_image())
    except KeyboardInterrupt:
        throw_error=True #Still show and save what we have, then re-raise at the end
            
    clear_output()
    icecream.ic(method, criterion, num_features, hidden_dim, scale, iter_per_epoch, learning_rate, num_epochs)
    rp.line_graph_via_bokeh(losses,xlabel='Iter',ylabel='Loss',title=method,logy=10)
    rp.display_image_slideshow(images)
    
    if save_path:
        state=learnable_image.state_dict()
        torch.save(state,save_path)
        print("Saving",save_path)
    
    if throw_error:
        raise KeyboardInterrupt

## Tests

In [None]:
# Load IPython's autoreload extension; mode 2 reloads all modules before code is executed,
# so edits to the imported source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
#This cell tests the loading and saving capabilities
%mkdir untracked
run_test('fourier', 'msssim', 
         num_features=128, hidden_dim=256, scale=100, iter_per_epoch=100, num_epochs=3, learning_rate=1e-4,
         load_path='untracked/r.pt', save_path='untracked/r.pt')

In [None]:
# Fourier with the 'msssim' criterion: small hidden_dim (20) and a high learning rate
run_test(
    method='fourier', criterion='msssim',
    num_features=128, hidden_dim=20, scale=20,
    iter_per_epoch=100, num_epochs=30, learning_rate=1e-2,
)

In [None]:
# Fourier with the 'msssim' criterion: hidden_dim=256, scale=10, low learning rate
run_test(
    method='fourier', criterion='msssim',
    num_features=128, hidden_dim=256, scale=10,
    iter_per_epoch=100, num_epochs=30, learning_rate=1e-4,
)

In [None]:
# Fourier with the 'mse' criterion at scale=1
run_test(
    method='fourier', criterion='mse',
    num_features=128, hidden_dim=256, scale=1,
    iter_per_epoch=100, num_epochs=30, learning_rate=1e-4,
)

In [None]:
# Fourier with the combined 'mse msssim' criterion: 200 features, high learning rate
run_test(
    method='fourier', criterion='mse msssim',
    num_features=200, hidden_dim=256, scale=10,
    iter_per_epoch=100, num_epochs=30, learning_rate=1e-2,
)

In [None]:
# Fourier with the combined 'mse msssim' criterion at scale=1, low learning rate
run_test(
    method='fourier', criterion='mse msssim',
    num_features=128, hidden_dim=256, scale=1,
    iter_per_epoch=100, num_epochs=30, learning_rate=1e-4,
)

In [None]:
# Raster with the 'mse' criterion
run_test(
    method='raster', criterion='mse',
    iter_per_epoch=1000, num_epochs=7,
)

In [None]:
# Raster with the combined 'msssim mse' criterion
run_test(
    method='raster', criterion='msssim mse',
    iter_per_epoch=1000, num_epochs=7,
)

In [None]:
# Raster with the 'msssim' criterion
run_test(
    method='raster', criterion='msssim',
    iter_per_epoch=1000, num_epochs=7,
)

In [None]:
# Raster with the 'l1' criterion
run_test(
    method='raster', criterion='l1',
    iter_per_epoch=1000, num_epochs=7,
)

In [None]:
# Raster with the combined 'l1 mse' criterion
run_test(
    method='raster', criterion='l1 mse',
    iter_per_epoch=1000, num_epochs=7,
)

In [None]:
# Raster with the 'mse' criterion and a much higher learning rate
run_test(
    method='raster', criterion='mse',
    iter_per_epoch=1000, num_epochs=3, learning_rate=1e-1,
)

In [None]:
# MLP with the 'msssim' criterion and a wide hidden layer
run_test(
    method='mlp', criterion='msssim',
    hidden_dim=1024,
    iter_per_epoch=100, num_epochs=4,
)

In [None]:
# MLP with the default criterion and a wide hidden layer
run_test(
    method='mlp',
    hidden_dim=1024,
    iter_per_epoch=100, num_epochs=4,
)

In [None]:
# MLP with the combined 'msssim mse' criterion and a wide hidden layer
run_test(
    method='mlp', criterion='msssim mse',
    hidden_dim=1024,
    iter_per_epoch=100, num_epochs=4,
)

In [None]:
# MLP with the default criterion and hidden_dim=256
run_test(
    method='mlp',
    hidden_dim=256,
    iter_per_epoch=100, num_epochs=4,
)

In [None]:
# MLP with the default criterion. run_test only hands hidden_dim to LearnableImageMLP,
# so num_features and scale have no effect on this run.
run_test(
    method='mlp',
    num_features=128, hidden_dim=256, scale=1,
    iter_per_epoch=100, num_epochs=4,
)

In [None]:
# Fourier with the default criterion at scale=1
run_test(
    method='fourier',
    num_features=128, hidden_dim=256, scale=1,
    iter_per_epoch=100, num_epochs=7,
)

In [None]:
# Fourier with only 2 features
run_test(
    method='fourier',
    num_features=2, hidden_dim=128, scale=1,
    iter_per_epoch=100, num_epochs=4,
)

In [None]:
# Fourier with many features but a very small hidden_dim (5), trained for longer
run_test(
    method='fourier',
    num_features=256, hidden_dim=5, scale=1,
    iter_per_epoch=2000, num_epochs=5,
)

In [None]:
# Fourier with 256 features and hidden_dim=256
run_test(
    method='fourier',
    num_features=256, hidden_dim=256, scale=1,
    iter_per_epoch=100, num_epochs=5,
)

# Multi-Image Tests

In [None]:
def run_multi_image_test(method         = 'fourier',
                         criterion      = 'mse'    ,
                         image_names    = 'portal makeup fox magikarp uv snowflake',
                         num_features   = 128      ,
                         hidden_dim     = 256      ,
                         scale          = 10       ,
                         iter_per_epoch = 100      ,
                         learning_rate  = 1e-4     ,
                         num_epochs     = 4        ,
                         save_path      = None     ,
                         load_path      = None     ):
    """Fit ONE learnable image with 3 channels per target to several targets at once.

    The chosen targets are concatenated along the channel axis, and the learnable image is
    created with num_channels=3*num_images. Reads the notebook-level variables
    target_image_choices, target_height, target_width and device.

    Parameters:
        method:         'fourier', 'mlp' or 'raster' - which learnable image class gets trained.
        criterion:      Space-separated loss terms; a non-empty subset of 'mse msssim l1'.
                        The selected terms are added together.
        image_names:    Space-separated keys of target_image_choices to fit.
        num_features:   Passed to LearnableImageFourier (unused by the other methods).
        hidden_dim:     Passed to LearnableImageFourier and LearnableImageMLP (unused by 'raster').
        scale:          Passed to LearnableImageFourier (unused by the other methods).
        iter_per_epoch: A snapshot is displayed once every this many optimizer steps.
        learning_rate:  Learning rate of the Adam optimizer.
        num_epochs:     The total number of optimizer steps is num_epochs*iter_per_epoch.
        save_path:      If given, the state dict is saved there at the end (also after an interrupt).
        load_path:      If given and the file exists, the state dict is loaded from it before training.

    Raises:
        KeyboardInterrupt: Re-raised after plotting and saving when training was interrupted,
            so that later cells don't keep running.
    """
    
    print('Method:',method)
    
    criteria = set(criterion.split())
    
    assert method in 'fourier mlp raster'.split()
    assert criteria <= set('mse msssim l1'.split())
    assert criteria, 'criterion must name at least one of: mse msssim l1'

    image_names = image_names.split()
    num_images = len(image_names)
    images = [load_target_image(target_image_choices[image_name]) for image_name in image_names]
    target_image = np.concatenate(images, axis=2) #Concatenate all images into channels
    
    target = rp.as_torch_image(target_image).to(device)
    icecream.ic(num_images, target.shape)
    
    if method=='mlp':
        learnable_image=LearnableImageMLP(target_height,
                                          target_width,
                                          hidden_dim=hidden_dim,
                                          num_channels=num_images*3)
    elif method=='fourier':
        learnable_image=LearnableImageFourier(target_height,
                                              target_width,
                                              num_features=num_features,
                                              hidden_dim=hidden_dim,
                                              scale=scale,
                                              num_channels=num_images*3)
    elif method=='raster':
        learnable_image=LearnableImageRaster(target_height,
                                             target_width,
                                             num_channels=num_images*3)
        
    learnable_image.to(device)
    
    optimizer = torch.optim.Adam(learnable_image.parameters(), lr=learning_rate)
    
    output_images = []
    losses = []
            
    def get_current_multi_image():
        #Our learnable_image has 3*num_images channels
        #This function takes learnable_image() and returns an image as defined by rp.is_image,
        #that displays all of those images at once.
            
        multi_image = rp.as_numpy_image(learnable_image())
            
        #Each consecutive group of 3 channels is one RGB image
        tiles = [multi_image[:,:,i*3:i*3+3] for i in range(num_images)]
        
        output = rp.tiled_images(tiles)
            
        assert rp.is_image(output)
        return output
            
    def display_current_image():
        #Record and display a snapshot of all the images in their current state
        image = get_current_multi_image()
        output_images.append(image)
        rp.display_image(image)
        return image
    
    if load_path:
        if rp.file_exists(load_path):
            #map_location lets a checkpoint saved on another device (e.g. cuda) load on this one
            state = torch.load(load_path, map_location=device)
            learnable_image.load_state_dict(state)

            print("Loaded state:")
            display_current_image()
            
        else:
            print(load_path, 'does not exist and cannot be loaded. Weights will be randomly initialized.')
    
    throw_error = False
    
    try:
        for iteration in tqdm(range(num_epochs*iter_per_epoch)):
            optimizer.zero_grad()

            generated = learnable_image()

            loss = 0
            if 'msssim' in criteria:
                loss += -msssim(target[None],generated[None],normalize=True)
                loss += 1 #Make the min loss 0, so we can do log plots of losses
            if 'mse'    in criteria:
                #This used to call l1_loss, which made 'mse' identical to 'l1'
                loss += torch.nn.functional.mse_loss(generated, target)
            if 'l1'     in criteria:
                loss += (generated-target).abs().mean()

            loss.backward()
            optimizer.step()

            losses.append(float(loss))

            if iteration % iter_per_epoch == 0:
                epoch = iteration // iter_per_epoch
                print('Epoch %d, loss = %.03f' % (epoch, float(loss)))
                display_current_image()
                
        #Record the result of the very last step. The last displayed snapshot was taken at
        #the start of the final epoch, and doesn't exist at all when there are no iterations.
        output_images.append(get_current_multi_image())
    except KeyboardInterrupt:
        throw_error = True #Don't continue to other cells after this one
            
    clear_output()
    icecream.ic(method, criterion, num_features, hidden_dim, scale, iter_per_epoch, learning_rate, num_epochs)
    rp.line_graph_via_bokeh(losses,xlabel='Iter',ylabel='Loss',logy=10,
                  title='multi '+method+' '+criterion+'\n'+' '.join(image_names))
    rp.display_image_slideshow(output_images)
    
    if save_path:
        state=learnable_image.state_dict()
        torch.save(state,save_path)
        print("Saving",save_path)
        
    if throw_error:
        raise KeyboardInterrupt

In [None]:
# Multi-image fourier: 256 features, hidden_dim=256, scale=10, 20 short epochs
run_multi_image_test(
    method='fourier', criterion='mse msssim',
    num_features=256, hidden_dim=256, scale=10,
    iter_per_epoch=100, num_epochs=20, learning_rate=1e-2,
)

In [None]:
# Multi-image fourier: 128 features, hidden_dim=128, scale=10, 16 short epochs
run_multi_image_test(
    method='fourier', criterion='mse msssim',
    num_features=128, hidden_dim=128, scale=10,
    iter_per_epoch=100, num_epochs=16, learning_rate=1e-2,
)

In [None]:
# Multi-image fourier: 128 features, hidden_dim=128, scale=10, 16 long epochs
run_multi_image_test(
    method='fourier', criterion='mse msssim',
    num_features=128, hidden_dim=128, scale=10,
    iter_per_epoch=1000, num_epochs=16, learning_rate=1e-2,
)

In [None]:
# Multi-image fourier, same settings again: 128 features, hidden_dim=128, scale=10, 16 long epochs
run_multi_image_test(
    method='fourier', criterion='mse msssim',
    num_features=128, hidden_dim=128, scale=10,
    iter_per_epoch=1000, num_epochs=16, learning_rate=1e-2,
)

In [None]:
# Multi-image fourier: 256 features, hidden_dim=256, scale=20, 16 long epochs
run_multi_image_test(
    method='fourier', criterion='mse msssim',
    num_features=256, hidden_dim=256, scale=20,
    iter_per_epoch=1000, num_epochs=16, learning_rate=1e-2,
)

In [None]:
# Multi-image fourier, quick run: 64 features, hidden_dim=128, only 3 short epochs
run_multi_image_test(
    method='fourier', criterion='mse msssim',
    num_features=64, hidden_dim=128, scale=10,
    iter_per_epoch=100, num_epochs=3, learning_rate=1e-2,
)

# Conditional-Image Tests

In [None]:
def run_condi_image_test(method         = 'fourier',
                         criterion      = 'mse'    ,
                         image_names    = 'portal makeup fox magikarp uv snowflake',
                         num_features   = 128      ,
                         hidden_dim     = 256      ,
                         scale          = 10       ,
                         iter_per_epoch = 100      ,
                         learning_rate  = 1e-4     ,
                         num_epochs     = 4        ,
                         save_path      = None     ,
                         load_path      = None     ):
    """Fit ONE 3-channel learnable image to several targets, selected by a one-hot condition vector.

    For the i-th target, the learnable image is called with a condition of length num_images
    that is 1 at index i and 0 elsewhere. The outputs for all conditions are concatenated and
    compared against the channel-concatenated targets. Reads the notebook-level variables
    target_image_choices, target_height, target_width and device.

    Parameters:
        method:         'fourier' ('mlp' is accepted but not implemented yet).
        criterion:      Space-separated loss terms; a non-empty subset of 'mse msssim l1'.
                        The selected terms are added together.
        image_names:    Space-separated keys of target_image_choices to fit.
        num_features:   Passed to LearnableImageFourier.
        hidden_dim:     Passed to LearnableImageFourier.
        scale:          Passed to LearnableImageFourier.
        iter_per_epoch: A snapshot is displayed once every this many optimizer steps.
        learning_rate:  Learning rate of the Adam optimizer.
        num_epochs:     The total number of optimizer steps is num_epochs*iter_per_epoch.
        save_path:      If given, the state dict is saved there at the end (also after an interrupt).
        load_path:      If given and the file exists, the state dict is loaded from it before training.

    Raises:
        NotImplementedError: If method is 'mlp'.
        KeyboardInterrupt: Re-raised after plotting and saving when training was interrupted,
            so that later cells don't keep running.
    """
    
    print('Method:',method)
    
    criteria = set(criterion.split())
    
    assert method in 'fourier mlp'.split()
    assert criteria <= set('mse msssim l1'.split())
    assert criteria, 'criterion must name at least one of: mse msssim l1'

    image_names = image_names.split()
    num_images = len(image_names)
    images = [load_target_image(target_image_choices[image_name]) for image_name in image_names]
    target_image = np.concatenate(images, axis=2) #Concatenate all images into channels
    
    target = rp.as_torch_image(target_image).to(device)
    icecream.ic(num_images, target.shape)
    
    if method=='mlp':
        #Raised explicitly (instead of 'assert False') so it can't be stripped by python -O
        raise NotImplementedError('Not yet implemented in learnable_textures.py')
    elif method=='fourier':
        learnable_image=LearnableImageFourier(target_height,
                                              target_width,
                                              num_features=num_features,
                                              hidden_dim=hidden_dim,
                                              scale=scale,
                                              num_channels=3)
        
    learnable_image.to(device)
    
    optimizer = torch.optim.Adam(learnable_image.parameters(), lr=learning_rate)
    
    output_images = []
    losses = []
    
    #Row i is the one-hot condition that selects image i. Built once instead of on every call.
    conditions = torch.eye(num_images, device=device)
    
    def get_learnable_images():
        #Generate one image per condition and concatenate them along the first axis,
        #matching the channel-concatenated target
        return torch.cat([learnable_image(condition=condition) for condition in conditions])
            
    def get_current_multi_image():
        #get_learnable_images() concatenates the num_images outputs (3 channels each)
        #This function takes that and returns an image as defined by rp.is_image,
        #that displays all of those images at once.
            
        multi_image = rp.as_numpy_image(get_learnable_images())
            
        #Each consecutive group of 3 channels is one RGB image
        tiles = [multi_image[:,:,i*3:i*3+3] for i in range(num_images)]
        
        output = rp.tiled_images(tiles)
            
        assert rp.is_image(output)
        return output
            
    def display_current_image():
        #Record and display a snapshot of all the images in their current state
        image = get_current_multi_image()
        output_images.append(image)
        rp.display_image(image)
        return image
    
    if load_path:
        if rp.file_exists(load_path):
            #map_location lets a checkpoint saved on another device (e.g. cuda) load on this one
            state = torch.load(load_path, map_location=device)
            learnable_image.load_state_dict(state)

            print("Loaded state:")
            display_current_image()
            
        else:
            print(load_path, 'does not exist and cannot be loaded. Weights will be randomly initialized.')
    
    throw_error = False
    
    try:
        for iteration in tqdm(range(num_epochs*iter_per_epoch)):
            optimizer.zero_grad()

            generated = get_learnable_images()

            loss = 0
            if 'msssim' in criteria:
                loss += -msssim(target[None],generated[None],normalize=True)
                loss += 1 #Make the min loss 0, so we can do log plots of losses
            if 'mse'    in criteria:
                #This used to call l1_loss, which made 'mse' identical to 'l1'
                loss += torch.nn.functional.mse_loss(generated, target)
            if 'l1'     in criteria:
                loss += (generated-target).abs().mean()

            loss.backward()
            optimizer.step()

            losses.append(float(loss))

            if iteration % iter_per_epoch == 0:
                epoch = iteration // iter_per_epoch
                print('Epoch %d, loss = %.03f' % (epoch, float(loss)))
                display_current_image()
                
        #Record the result of the very last step. The last displayed snapshot was taken at
        #the start of the final epoch, and doesn't exist at all when there are no iterations.
        output_images.append(get_current_multi_image())
    except KeyboardInterrupt:
        throw_error = True #Don't continue to other cells after this one
            
    clear_output()
    icecream.ic(method, criterion, num_features, hidden_dim, scale, iter_per_epoch, learning_rate, num_epochs)
    rp.line_graph_via_bokeh(losses,xlabel='Iter',ylabel='Loss',logy=10,
                  title='cond '+method+' '+criterion+'\n'+' '.join(image_names))
    rp.display_image_slideshow(output_images)
    
    if save_path:
        state=learnable_image.state_dict()
        torch.save(state,save_path)
        print("Saving",save_path)
        
    if throw_error:
        raise KeyboardInterrupt

In [None]:
# Conditional fourier: 256 features, hidden_dim=256, scale=10, 16 short epochs
run_condi_image_test(
    method='fourier', criterion='mse msssim',
    num_features=256, hidden_dim=256, scale=10,
    iter_per_epoch=100, num_epochs=16, learning_rate=1e-2,
)

In [None]:
# Conditional fourier: 128 features, hidden_dim=128, scale=10, 16 long epochs
run_condi_image_test(
    method='fourier', criterion='mse msssim',
    num_features=128, hidden_dim=128, scale=10,
    iter_per_epoch=1000, num_epochs=16, learning_rate=1e-2,
)

In [None]:
# Conditional fourier, same settings again: 128 features, hidden_dim=128, scale=10, 16 long epochs
run_condi_image_test(
    method='fourier', criterion='mse msssim',
    num_features=128, hidden_dim=128, scale=10,
    iter_per_epoch=1000, num_epochs=16, learning_rate=1e-2,
)

In [None]:
# Conditional fourier: 128 features, hidden_dim=128, scale=10, 16 short epochs
run_condi_image_test(
    method='fourier', criterion='mse msssim',
    num_features=128, hidden_dim=128, scale=10,
    iter_per_epoch=100, num_epochs=16, learning_rate=1e-2,
)