# Exploring a latent representation for offline quality assessment of the laser melting process
## Additive Manufacturing Use Case - D. Winant


In [1]:
# Import all the necessary modules

import scipy.misc
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import fixed, Layout
import numpy as np
import argparse

In [5]:
# Useful functions for the model

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torch import from_numpy, split



def get_laserprofile_dataloader(args, path_to_data=r'C:\Users\dwinant\Documents\Projects\Additive Manufacturing\Data\AI_ini_cyls_input.npy'):
    """Build a DataLoader over the laser-profile images ((100, 120) greyscale).

    Parameters
    ----------
    args : namespace-like
        Must provide ``mb_size`` (batch size), ``shuffle`` (bool) and
        ``workers`` (number of DataLoader worker processes).
    path_to_data : str
        Path to the image data file, passed on to ``laserprofile``.

    Returns
    -------
    tuple
        ``(loader, n_features, n_channels)``. ``n_features`` is
        channels * height * width of one transformed image.
    """
    transform = transforms.Compose([transforms.ToTensor()])

    laserprofile_data = laserprofile(path_to_data, transform=transform)
    laserprofile_loader = DataLoader(laserprofile_data, batch_size=args.mb_size,
                                     shuffle=args.shuffle, pin_memory=True,
                                     num_workers=args.workers)
    # Read the shape from a single sample. Iterating the loader would start
    # the worker pool and assemble a whole batch just to learn the image
    # dimensions; the per-sample (C, H, W) shape is the same information.
    c, x, y = laserprofile_data[0][0].size()
    return laserprofile_loader, c * x * y, c


class laserprofile(Dataset):
    """Laser-profile dataset: (100, 120) greyscale images, two labels each.

    The label columns are the normalized speed and the normalized power,
    in the order given by ``lat_names``.
    """

    # Names of the label columns, in column order.
    lat_names = ('speed', 'power')

    # (height, width) of one image in the raw data file.
    img_shape = (100, 120)

    # Labels file used when ``path_to_labels`` is not given.
    default_labels_path = r'C:\Users\dwinant\Documents\Projects\Additive Manufacturing\Data\AI_ini_cyls_output.npy'

    def __init__(self, path_to_data, subsample=1, transform=None, path_to_labels=None):
        """
        Parameters
        ----------
        path_to_data : str
            File holding the images as raw float32 values. It is read with
            ``np.fromfile``, which does not parse an ``.npy`` header, so the
            file must be raw data despite its extension.
        subsample : int
            Only load every |subsample| number of images.
        transform : callable, optional
            Applied in ``__getitem__`` to each image (a float32 ndarray).
        path_to_labels : str, optional
            File holding the labels as raw float32 values, two per image.
            Defaults to ``default_labels_path``.

        Raises
        ------
        ValueError
            If a file cannot be reshaped into images / label pairs, or if the
            number of images and label pairs differ.
        """
        if path_to_labels is None:
            path_to_labels = self.default_labels_path

        # Grey scale images, one (100, 120) frame per sample.
        dataset = np.fromfile(path_to_data, dtype=np.float32).reshape(-1, *self.img_shape)
        # Respectively normalized speed and normalized power.
        labels = np.fromfile(path_to_labels, dtype=np.float32).reshape(-1, len(self.lat_names))
        if len(dataset) != len(labels):
            raise ValueError('got {} images but {} label pairs'.format(len(dataset), len(labels)))

        # NOTE(review): images are inverted here (1 - x) and later divided by
        # 255 in __getitem__. This is kept unchanged because a trained model
        # presumably expects this preprocessing. Confirm whether the raw data
        # lies in [0, 1] or [0, 255].
        self.imgs = 1 - dataset[::subsample]
        self.labels = labels[::subsample]
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is a float32 tensor."""
        sample = self.imgs[idx] / 255
        if self.transform:
            sample = self.transform(sample)
        else:
            # Without a transform the sample is still an ndarray, which has
            # no ``.float()`` method, so convert it to a tensor first.
            sample = torch.as_tensor(sample)
        return sample.float(), self.labels[idx]

def convert_to_imshow_format(image):
    """Convert a CHW image array into a layout accepted by ``plt.imshow``.

    Single-channel input of shape (1, H, W) is returned as (H, W). Any other
    input is transposed from CHW to HWC. If it contains negative values it is
    treated as lying in [-1, 1] and is first rescaled to [0, 1].

    Args:
        image: array of shape (C, H, W).

    Returns:
        Array of shape (H, W) when C == 1, otherwise (H, W, C).
    """
    if image.shape[0] == 1:
        return image[0, :, :]
    # Test the values themselves. Testing np.where's index arrays instead
    # would miss a negative value sitting only at index (0, 0, 0).
    if np.any(image < 0):
        # first convert back to [0,1] range from [-1,1] range
        image = image / 2 + 0.5
    return image.transpose(1, 2, 0)
    

In [16]:
def _set_slider_range(slider, lower, upper):
    """Set ``slider.min`` / ``slider.max`` without passing through min > max.

    ipywidgets rejects an assignment that would make min exceed max (or max
    fall below min), so the assignment order depends on the current range.
    """
    lower, upper = float(lower), float(upper)
    if lower > slider.max:
        slider.max = upper
        slider.min = lower
    else:
        slider.min = lower
        slider.max = upper


def create_interface(h_min,h_max):
    '''Create the interface to interact with the latent space.

        Args:
            h_min, h_max: per-component minimum and maximum of the hidden
                units (=principal components). They set the range of the
                sliders. At least two components are required.

        Returns:
            The two selectors choosing the principal component on each axis,
            the x/y sliders giving the position along each axis,
            and the selector for which labels colour the data points.

        Choosing a different principal component for an axis re-ranges the
        matching slider to that component's [min, max].
    '''
    # Make the figures large enough
    plt.rcParams['figure.figsize'] = [20, 10]

    # Offer the first 10 principal components, or fewer when the latent
    # space is smaller (larger indices would fail when indexing h).
    pc_options = np.arange(1, min(10, len(h_min)) + 1)

    # Choose principal components for the axes
    pc_axis_one = widgets.Select(
        options=pc_options,
        value=1,
        description='Principal Component on x-axis ',
        disabled=False,
        layout=Layout(width='500px'),
        style={'description_width': 'initial'},
    )
    pc_axis_two = widgets.Select(
        options=pc_options,
        value=2,
        description='Principal Component on y-axis',
        disabled=False,
        layout=Layout(width='500px'),
        style={'description_width': 'initial'},
    )

    # Sliders to explore the latent space (initially ranged on PC 1 and PC 2)
    x_pos = widgets.FloatSlider(
        value=0,
        min=float(h_min[0]),
        max=float(h_max[0]),
        step=0.05,
        description='x',
        disabled=False,
        continuous_update=True,
        orientation='horizontal',
        layout=Layout(width='600px'),
        readout=True,
        readout_format='.1f',
    )
    y_pos = widgets.FloatSlider(
        value=0,
        min=float(h_min[1]),
        max=float(h_max[1]),
        step=0.05,
        description='y',
        disabled=False,
        continuous_update=True,
        orientation='vertical',
        layout=Layout(height='400px'),
        readout=True,
        readout_format='.1f',
    )

    # Choose which labels to show the data
    labels_select = widgets.Select(
        options=['None', 'Speed', 'Power'],
        value='None',
        description='Labels:',
        disabled=False,
        style={'description_width': 'initial'},
    )

    # Keep each slider's range in sync with the component shown on its axis.
    def _link(axis_select, slider):
        def _on_change(change):
            if change['new'] is None:
                return
            component = int(change['new']) - 1
            _set_slider_range(slider, h_min[component], h_max[component])
        axis_select.observe(_on_change, names='value')

    _link(pc_axis_one, x_pos)
    _link(pc_axis_two, y_pos)

    return pc_axis_one, pc_axis_two, x_pos, y_pos, labels_select


def move_point(h_new, pc_axis_one,pc_axis_two,x_pos,y_pos):
    """Overwrite two coordinates of a latent point, in place.

    Args:
        h_new: point to move; only its first row (index 0) is written.
        pc_axis_one, pc_axis_two: 1-based principal-component indices that
            receive ``x_pos`` and ``y_pos`` respectively. If both name the
            same component, ``y_pos`` is written last and wins.
        x_pos, y_pos: new values for those components.

    Returns:
        The same ``h_new`` object, after modification.
    """
    assignments = ((pc_axis_one, x_pos), (pc_axis_two, y_pos))
    for component, coordinate in assignments:
        h_new[0, component - 1] = coordinate
    return h_new

def generate_point(h_new,U,rkm, height=100, width=120):
    """ Map a point from the latent space back to the input space

    Args:
        h_new: hidden point(s) of shape (n, k) to be decoded
        U: interconnection matrix of the Generative RKM; torch.mm(h_new, U.t())
            requires it to have shape (d, k)
        rkm: trained Generative RKM model; only ``rkm.decoder`` is used
        height, width: size of one generated image (default 100 x 120)

    Returns:
        x_new: numpy array of shape (-1, height, width) holding the decoded
        point(s) in the input space
    """
    # Inference only: no autograd graph is needed for the decoded image.
    with torch.no_grad():
        features = torch.mm(h_new, U.t()).float()
        x_new = rkm.decoder(features)
    return x_new.detach().cpu().numpy().reshape(-1, height, width)

def create_plots(pc_axis_one,pc_axis_two,x_pos,y_pos,labels_select, labels ,h_new, h,U,rkm):
    """ Plots the latent space and the newly generated point

    Args:
        pc_axis_one, pc_axis_two: 1-based principal components shown on the
            x and y axes of the latent-space plot
        x_pos, y_pos: value of the newly generated point along those components
        labels_select: 'None', 'Speed' or 'Power'; 'Speed' colours the data by
            labels[:, 0] and 'Power' by labels[:, 1]
        labels: label array of the data points. Assumes its rows align with
            the rows of h (TODO confirm).
        h_new: the newly generated point; it is modified in place
        h: hidden units to be plotted in the latent space
        U: interconnection matrix of the Generative RKM
        rkm: trained Generative RKM model
    """
    # Generate new point (move_point writes into h_new)
    h_new = move_point(h_new, pc_axis_one, pc_axis_two, x_pos, y_pos)
    x_new = generate_point(h_new, U, rkm)

    x_idx, y_idx = pc_axis_one - 1, pc_axis_two - 1

    # Column of `labels` used to colour the data points (None = no colouring)
    label_column = {'Speed': 0, 'Power': 1}.get(labels_select)
    colours = None if label_column is None else labels[:, label_column]

    # Plot latent space
    fig = plt.figure()
    latent_ax = fig.add_subplot(121)
    latent_ax.set_title("Latent Space")
    latent_ax.set_xlabel('PC {}'.format(pc_axis_one))
    latent_ax.set_ylabel('PC {}'.format(pc_axis_two))

    data_points = latent_ax.scatter(h[:, x_idx].detach().numpy(),
                                    h[:, y_idx].detach().numpy(),
                                    s=10, c=colours)
    if colours is not None:
        # Make the colour coding readable
        fig.colorbar(data_points, ax=latent_ax, label=labels_select)

    latent_ax.scatter(h_new[:, x_idx], h_new[:, y_idx], s=20, c='red')

    # Plot generated point in the input space (inverted back for display)
    image_ax = fig.add_subplot(122)
    image_ax.set_title("Generated Point")
    image_ax.imshow(convert_to_imshow_format(1 - x_new), cmap='Greys')

    plt.show()


def _load_rkm_checkpoint(model_file):
    """Load a Generative RKM checkpoint dict with all tensors on the CPU.

    The checkpoint stores a pickled model object under ``'rkm'``, so full
    unpickling is required. Only load checkpoints from a trusted source.
    """
    cpu = torch.device('cpu')
    try:
        # torch >= 2.6 defaults to weights_only=True, which cannot restore
        # the pickled model object stored in this checkpoint.
        return torch.load(model_file, map_location=cpu, weights_only=False)
    except TypeError:
        # Older torch versions do not accept the weights_only argument.
        return torch.load(model_file, map_location=cpu)


def _load_labels(path_to_labels):
    """Read the label pairs from a raw float32 file, one row per sample."""
    # Column 0 is the normalized speed, column 1 the normalized power.
    return np.fromfile(path_to_labels, dtype=np.float32).reshape(-1, 2)


def interactive_latent_space(path_to_model='AM_Trained_rkm_20211105-1058', path_to_labels=None):
    """ Create tool to interactively explore latent space

    Args:
        path_to_model: name of the trained Generative RKM checkpoint. The file
            loaded is ``out/AM/<path_to_model>.tar``.
        path_to_labels: raw float32 file with the (speed, power) labels of the
            AM data. Defaults to the original labels location.
    """
    if path_to_labels is None:
        path_to_labels = r'C:\Users\dwinant\Documents\Projects\Additive Manufacturing\Data\AI_ini_cyls_output.npy'

    # Load rkm model
    dataset_name = 'AM'
    sd_mdl = _load_rkm_checkpoint('out/{}/{}.tar'.format(dataset_name, path_to_model))

    rkm = sd_mdl['rkm']
    rkm.load_state_dict(sd_mdl['rkm_state_dict'])
    h = sd_mdl['h']
    U = sd_mdl['U']

    # Determine range of the hidden units for plotting the latent space
    h_min = torch.min(h, 0).values
    h_max = torch.max(h, 0).values

    # Load labels: 1 is the normalized power and 0 is normalized speed
    labels = _load_labels(path_to_labels)

    # Load sliders for interface
    pc_axis_one, pc_axis_two, x_pos, y_pos, labels_select = create_interface(h_min=h_min, h_max=h_max)

    # Latent point edited by the sliders, also bound as a module-level global.
    # It has one row and as many columns as U, as torch.mm(h_new, U.t())
    # in generate_point requires.
    global h_new
    h_new = torch.zeros((1, U.shape[1]), dtype=U.dtype)

    # Make interface interactive
    out = widgets.interactive_output(create_plots, {'pc_axis_one': pc_axis_one, 'pc_axis_two': pc_axis_two,
                                                    'x_pos': x_pos, 'y_pos': y_pos, 'labels_select': labels_select,
                                                    'labels': fixed(labels), 'h_new': fixed(h_new), 'h': fixed(h),
                                                    'U': fixed(U), 'rkm': fixed(rkm)})

    # Display interface
    explore_ui = widgets.VBox([widgets.HBox([y_pos, out]), x_pos])
    pc_ui = widgets.Box([pc_axis_one, pc_axis_two, labels_select])

    display(explore_ui)
    display(pc_ui)
    


In [17]:
# Launch the interactive explorer for the trained Generative RKM checkpoint
# stored at out/AM/AM_Trained_rkm_20211105-1058.tar.
interactive_latent_space(path_to_model='AM_Trained_rkm_20211105-1058')

VBox(children=(HBox(children=(FloatSlider(value=0.0, description='y', layout=Layout(height='400px'), max=0.658…

Box(children=(Select(description='Principal Component on x-axis ', layout=Layout(width='500px'), options=(1, 2…