# Models used for thesis visualizations

In [1]:
import torch
from torch.nn import functional as F

import torch.nn as nn
from torch import optim as optim

# wandb
import pprint

# misc
import numpy as np
from os.path import join
import time

# visualizations
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# modules
from src.main_utils import Configuration
from src.evaluation_utils import eval_epoch
from src.thirdHand_data_loader import get_min_max_from_dataset

ModuleNotFoundError: No module named 'torch'

---

# Basic model setup

In [3]:
# Project configuration object; judging by the cell output, constructing it
# also loads the training and validation data.
model_config = Configuration()

Data loaded from 11 filse, stored in a dataframe with shape (248486, 10)
Dataframe headers are: ['px', 'py', 'pz', 'v1x', 'v1y', 'v1z', 'v2x', 'v2y', 'v2z', 'hand']
Data loaded from 2 filse, stored in a dataframe with shape (54402, 10)
Dataframe headers are: ['px', 'py', 'pz', 'v1x', 'v1y', 'v1z', 'v2x', 'v2y', 'v2z', 'hand']


In [4]:
# Key under which each batch dict stores the model input; the rest of the
# notebook reads inputs as `d[model_config.data_item]`.
print(model_config.data_item)

X_centered_scaled


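### C-VAE Model

Note: the label `y` is passed to the encoder and decoder, but neither of them uses it yet, so the model currently behaves as an unconditional VAE.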

In [5]:
class Encoder(nn.Module):
    """Convolutional encoder of the VAE.

    Maps a motion batch to a latent sample ``z`` together with the ``mean``
    and ``log_var`` of the approximate posterior.

    The first ``nn.Conv1d`` has 20 input channels. The flattened size of the
    last feature map is computed as ``10 - (2*depth + 1)``, which matches an
    input length of 9 because each odd-kernel convolution with padding
    ``kernel_size//2 - 1`` shrinks the length by exactly 2.

    Args:
        device: torch device the layers are moved to.
        first_filter_size: log2 of the widest layer; channel counts are taken
            from ``2**first_filter_size, 2**(first_filter_size-1), ...``.
        kernel_size: convolution kernel size; must be odd and >= 3.
        depth: number of convolutional blocks.
        dropout: dropout probability used in every block.
        latent_dim: log2 of the latent size (the latent has ``2**latent_dim`` dims).

    Raises:
        ValueError: if the configuration cannot yield a consistent network
            (even/too small kernel, or a depth the input cannot support).
    """

    def __init__(self, device, first_filter_size, kernel_size, depth, dropout, latent_dim):
        super(Encoder, self).__init__()
        # Fail fast on configurations that would otherwise only crash later
        # with an opaque padding or shape-mismatch error.
        if kernel_size < 3 or kernel_size % 2 == 0:
            raise ValueError(
                "kernel_size must be an odd integer >= 3, got {}".format(kernel_size))
        if not 1 <= depth <= first_filter_size + 1:
            raise ValueError(
                "depth must be between 1 and first_filter_size + 1 ({}), got {}".format(
                    first_filter_size + 1, depth))
        if 10 - (depth * 2 + 1) < 1:
            raise ValueError(
                "depth {} shrinks the length-9 input to nothing".format(depth))

        self.device = device
        self.first_filter_size = first_filter_size
        self.kernel_size = kernel_size
        # With an odd kernel this padding shrinks the sequence by 2 per layer.
        self.encoder_padding = kernel_size // 2 - 1
        self.depth = depth
        self.latent_dim = 2 ** latent_dim
        # Channel counts from widest to narrowest, truncated to `depth` layers.
        self.filter_number = [2 ** i for i in range(first_filter_size + 1)]
        self.filter_number.reverse()
        self.filter_number = self.filter_number[:self.depth]
        self.last_encoder_filter_size = None

        self.dropout = dropout
        # NOTE: make_encoder() reverses self.filter_number in place, so after
        # this call filter_number[0] is the channel count of the LAST block.
        self.encoder_layers = self.make_encoder()

        self.last_filter_size = self.filter_number[0]
        # Input length 9 minus 2 per convolutional block.
        self.last_feature_size = (10 - (depth * 2 + 1))
        self.last_dim = self.last_filter_size * self.last_feature_size

        self.flatten_layer = nn.Flatten().to(device)
        # A single linear layer produces mean and log_var concatenated.
        self.convert_to_latent = nn.Linear(self.last_dim, 2 * self.latent_dim).to(device)

    def make_encoder(self):
        """Build the Conv1d -> BatchNorm1d -> ReLU -> Dropout blocks.

        Side effects: sets ``self.last_encoder_filter_size`` to the output
        channel count of the last block and reverses ``self.filter_number``
        in place (``__init__`` relies on this to read ``last_filter_size``).

        Returns:
            nn.ModuleList with one nn.Sequential per block.
        """
        encoder_cnn_blocks = []

        for i, out_dim in enumerate(self.filter_number):
            # The first block consumes the 20 input channels.
            in_dim = 20 if i == 0 else self.filter_number[i - 1]

            cnn_block = nn.Sequential(
                nn.Conv1d(in_channels=in_dim,
                          out_channels=out_dim,
                          kernel_size=self.kernel_size,
                          padding=self.encoder_padding),
                nn.BatchNorm1d(out_dim),
                nn.ReLU(),
                nn.Dropout(self.dropout),
            ).to(self.device)

            encoder_cnn_blocks.append(cnn_block)
            self.last_encoder_filter_size = out_dim

        self.filter_number.reverse()

        return nn.ModuleList(encoder_cnn_blocks)

    def reparametrization(self, mean, log_var):
        """Draw ``z ~ N(mean, exp(log_var))`` with the reparametrization trick.

        Args:
            mean: posterior means.
            log_var: posterior log-variances, same shape as ``mean``.

        Returns:
            ``mean + eps * std`` with ``eps`` drawn from a standard normal.
        """
        # exp(0.5 * log(var)) is the standard deviation, not the variance.
        std = torch.exp(log_var * 0.5)
        # Sample directly on std's device/dtype instead of on CPU plus a copy.
        epsilon = torch.randn_like(std)
        return mean + epsilon * std

    def forward(self, x, y):
        """Encode ``x`` into ``(z, mean, log_var)``.

        Args:
            x: input batch (20 channels; see the class docstring for length).
            y: conditioning labels; accepted for interface compatibility but
                not used by this encoder.
        """
        for block in self.encoder_layers:
            x = block(x)

        latent = self.convert_to_latent(self.flatten_layer(x))

        # First half of the features is the mean, second half the log-variance.
        mean = latent[:, :self.latent_dim]
        log_var = latent[:, self.latent_dim:]

        z = self.reparametrization(mean, log_var)

        return z, mean, log_var

In [6]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector back to a motion.

    A linear layer lifts ``z`` to a ``(last_filter_size, last_feature_size)``
    feature map, which then passes through ``len(filter_number) + 1``
    ``nn.ConvTranspose1d`` layers; the final one outputs 20 channels.
    No activation is placed between the transposed convolutions.

    Args:
        device: torch device the layers are moved to.
        first_filter_size: log2 of the widest layer.
        kernel_size: transposed-convolution kernel size.
        depth: how many of the largest powers of two up to
            ``2**first_filter_size`` are used as channel counts.
        latent_dim: log2 of the latent size.
        last_filter_size: channel count of the encoder's last block.
        last_feature_size: sequence length of the encoder's last block.
    """

    def __init__(self, device, first_filter_size, kernel_size, depth, latent_dim, last_filter_size, last_feature_size):
        super(Decoder, self).__init__()

        self.device = device
        self.first_filter_size = first_filter_size
        self.kernel_size = kernel_size
        # Same padding rule as the encoder; the final layer adds 1 to it.
        self.encoder_padding = kernel_size // 2 - 1
        self.depth = depth
        self.latent_dim = 2 ** latent_dim
        self.last_filter_size = last_filter_size
        self.last_feature_size = last_feature_size

        powers_of_two = [2 ** exponent for exponent in range(first_filter_size + 1)]
        # The `depth` largest channel counts, widest first.
        self.filter_number = list(reversed(powers_of_two[-self.depth:]))

        # Transposed convolutions are created before the linear layer.
        self.decoder_layers = self.make_decoder()
        self.z_to_decoder = nn.Linear(self.latent_dim, self.last_filter_size * self.last_feature_size).to(device)

    def make_decoder(self):
        """Build the chain of ConvTranspose1d layers.

        Channels flow ``last_filter_size -> filter_number[0] -> ... ->
        filter_number[-1] -> 20``. Every layer uses ``encoder_padding``
        except the final one, which uses ``encoder_padding + 1``; the padding
        of the most recently built layer is kept in ``self.decoder_padding``.

        Returns:
            nn.ModuleList with one nn.Sequential per layer.
        """
        final_index = len(self.filter_number)
        blocks = []

        for index in range(final_index + 1):
            self.decoder_padding = self.encoder_padding
            if index == 0:
                in_channels, out_channels = self.last_filter_size, self.filter_number[0]
            elif index == final_index:
                in_channels, out_channels = self.filter_number[-1], 20
                self.decoder_padding += 1
            else:
                in_channels, out_channels = self.filter_number[index - 1], self.filter_number[index]

            transposed = nn.ConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self.kernel_size,
                padding=self.decoder_padding,
            )
            blocks.append(nn.Sequential(transposed).to(self.device))

        return nn.ModuleList(blocks)

    def forward(self, z, y):
        """Decode latent vectors ``z``; ``y`` is accepted but not used."""
        features = self.z_to_decoder(z)
        decoded = features.view(-1, self.last_filter_size, self.last_feature_size)
        for layer in self.decoder_layers:
            decoded = layer(decoded)
        return decoded

In [7]:
class VAE_CNN(nn.Module):
    """Variational autoencoder built from the convolutional Encoder and Decoder.

    Args:
        device: torch device for all layers.
        first_filter_size, kernel_size, depth, dropout, latent_dim:
            architecture hyperparameters forwarded to Encoder/Decoder.
        rec_loss: "L1" selects L1 reconstruction loss; any other value selects MSE.
        reduction: reduction passed to the reconstruction loss ("sum", "mean", ...).
        kld_weight: multiplier applied to the KL-divergence term.
    """

    def __init__(self, device, first_filter_size, kernel_size, depth, dropout, latent_dim, rec_loss, reduction, kld_weight):
        super(VAE_CNN, self).__init__()

        # The decoder needs the encoder's final feature-map size, so the
        # encoder is built first.
        self.encoder = Encoder(device, first_filter_size, kernel_size, depth, dropout, latent_dim)
        self.decoder = Decoder(
            device,
            first_filter_size,
            kernel_size,
            depth,
            latent_dim,
            last_filter_size=self.encoder.last_filter_size,
            last_feature_size=self.encoder.last_feature_size,
        )

        self.reduction = reduction
        self.kld_weight = kld_weight
        self.rec_loss = rec_loss

    def vae_loss_function(self, x, x_rec, log_var, mean):
        """Return ``(total, reconstruction, weighted KLD)`` for one batch.

        The KL term is summed over axis 1 and averaged over axis 0; the
        reconstruction term uses ``self.reduction``.
        """
        loss_fn = F.l1_loss if self.rec_loss == "L1" else F.mse_loss
        rec_term = loss_fn(x_rec, x, reduction=self.reduction)
        kld_per_sample = -0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim=1)
        weighted_kld = torch.mean(kld_per_sample, dim=0) * self.kld_weight
        return rec_term + weighted_kld, rec_term, weighted_kld

    def forward(self, x, y):
        """Encode then decode ``x``; returns ``(x_rec, mean, log_var)``."""
        z, mean, log_var = self.encoder(x, y)
        return self.decoder(z, y), mean, log_var
        

In [8]:
def vae_loss_function(x, x_rec, log_var, mean, rec_loss, reduction, kld_weight):
    """Compute the VAE objective for one batch.

    Args:
        x: target batch.
        x_rec: reconstruction of ``x``.
        log_var: posterior log-variances, latent dimension on axis 1.
        mean: posterior means, same shape as ``log_var``.
        rec_loss: "L1" for L1 loss; any other value selects MSE.
        reduction: reduction mode passed to the reconstruction loss.
        kld_weight: multiplier for the KL-divergence term.

    Returns:
        Tuple ``(total, reconstruction, kld_weight * kld)``. The KL term is
        summed over axis 1 and averaged over axis 0, so with
        ``reduction="sum"`` the reconstruction term scales with the batch
        size while the KL term does not.
    """
    if rec_loss == "L1":
        reconstruction = F.l1_loss(x_rec, x, reduction=reduction)
    else:
        reconstruction = F.mse_loss(x_rec, x, reduction=reduction)

    kld_terms = 1 + log_var - mean ** 2 - log_var.exp()
    kld = torch.mean(-0.5 * torch.sum(kld_terms, dim=1), dim=0)
    weighted_kld = kld * kld_weight

    return reconstruction + weighted_kld, reconstruction, weighted_kld

## Training functions

---
### Epoch training loop

In [9]:
def train_epoch_single(model, model_config, optimizer):
    """Run one training epoch over ``model_config.train_iterator``.

    Each batch dict provides the model input under ``model_config.data_item``
    and the labels under ``"Y"``. Losses come from ``vae_loss_function`` with
    the model's ``rec_loss``, ``reduction`` and ``kld_weight``.

    Returns:
        ``[total, reconstruction, kld]``, each summed over batches and divided
        by ``len(model_config.train_iterator)``.
    """
    model.train()
    loss_sum = rec_sum = kld_sum = 0

    for batch in model_config.train_iterator:
        inputs = batch[model_config.data_item]
        labels = batch["Y"]

        optimizer.zero_grad()
        reconstruction, mean, log_var = model(inputs, labels)

        total, rec, kld = vae_loss_function(
            inputs,
            reconstruction,
            log_var,
            mean,
            rec_loss=model.rec_loss,
            reduction=model.reduction,
            kld_weight=model.kld_weight,
        )

        # .item() yields Python numbers, so the history keeps no graph alive.
        rec_sum += rec.item()
        kld_sum += kld.item()
        loss_sum += total.item()

        total.backward()
        optimizer.step()

    n_batches = len(model_config.train_iterator)
    return [loss_sum / n_batches, rec_sum / n_batches, kld_sum / n_batches]

### Main training loop

In [10]:
def train_model(model, epochs, batch_size, lr=0.0001, report_interval=50):
    """Train ``model`` with Adam and return the per-epoch loss histories.

    Uses and mutates the notebook-global ``model_config``: its ``batch_size``
    is overwritten and its dataloaders are rebuilt. Runs ``epochs + 1``
    epochs (indices 0..epochs inclusive); e.g. with ``epochs=300`` and
    ``report_interval=50``, epoch 300 is reported.

    Args:
        model: VAE exposing ``rec_loss``, ``reduction`` and ``kld_weight``.
        epochs: index of the last epoch to run.
        batch_size: batch size used to rebuild the dataloaders.
        lr: Adam learning rate (default matches the previous fixed value).
        report_interval: every this many epochs, progress is printed and
            ``eval_epoch`` is asked to save a plot.

    Returns:
        ``(train_losses, train_rec_losses, train_kld_losses, eval_losses)``.

    Raises:
        ValueError: if ``report_interval`` is smaller than 1.
    """
    if report_interval < 1:
        raise ValueError("report_interval must be >= 1, got {}".format(report_interval))

    start_time = time.time()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model_config.batch_size = batch_size
    model_config.process_dataset_dataloaders()

    train_losses = []
    train_rec_losses = []
    train_kld_losses = []
    eval_losses = []
    # Epoch 0 always reports, so a real path is set before eval_epoch first
    # saves; non-report epochs pass the last path with save_plot=False.
    path_to_save_plot = None

    for epoch in range(epochs + 1):
        train_loss, train_rec_loss, train_kld_loss = train_epoch_single(
            model,
            model_config,
            optimizer,
        )

        log_plot = epoch % report_interval == 0
        if log_plot:
            path_to_save_plot = "runs/progress/tmp_fig_{}.png".format(epoch)

        eval_loss, _, _ = eval_epoch(
            model,
            model_config,
            model_config.valid_iterator,
            loss_function=vae_loss_function,
            rec_loss=model.rec_loss,
            reduction=model.reduction,
            kld_weight=model.kld_weight,
            save_plot=log_plot,
            path_to_save_plot=path_to_save_plot,
            is_vae=True,
        )

        train_losses.append(train_loss)
        train_rec_losses.append(train_rec_loss)
        train_kld_losses.append(train_kld_loss)
        eval_losses.append(eval_loss)

        if log_plot:
            # Reported only after eval_epoch has been asked to save the plot.
            print("<<Image {} saved>>".format(epoch))
            print("{}:\tTotal: {:.5f}\tEval loss: {:.5f}\t Rec loss: {:.5f}\t KLD loss: {:.5f}\t time: {:.1f}s".format(
                epoch,
                train_loss,
                eval_loss,
                train_rec_loss,
                train_kld_loss,
                time.time() - start_time))

    quick_plot(train_losses, train_rec_losses, train_kld_losses, eval_losses)
    return train_losses, train_rec_losses, train_kld_losses, eval_losses

In [11]:
def quick_plot(train_losses=None, rec=None, KLD=None, eval_losses=None, log_scale=True):
    """Plot loss curves against epoch index in a single plotly figure.

    Each series may be a list, tuple or numpy array; ``None`` or empty series
    are skipped. Traces are added in the order training, validation,
    reconstruction, KLD.

    Args:
        train_losses: training loss per epoch.
        rec: reconstruction loss per epoch.
        KLD: KL-divergence loss per epoch.
        eval_losses: validation loss per epoch.
        log_scale: if True, use a logarithmic y axis.
    """
    series = [
        (train_losses, "training"),
        (eval_losses, "validation"),
        (rec, "reconstruction"),
        (KLD, "KLD"),
    ]
    # Explicit None/len checks: `if values:` raises for multi-element numpy arrays.
    items_to_plot = [
        go.Scatter(x=np.arange(len(values)), y=values, name=name, mode='lines')
        for values, name in series
        if values is not None and len(values) > 0
    ]

    fig = go.Figure(items_to_plot)

    if log_scale:
        fig.update_yaxes(type="log")

    fig.update_layout(template="plotly_dark")
    fig.show()

___

### Testing the parameters and workflow manually

The hyperparameters below are taken from the hyperparameter-optimization results:

In [12]:
# Hyperparameters taken from the optimization results.
# kernel_size: 3 or 5 both work.
kernel_size = 5
# first_filter_size: log2 of the widest conv layer; values between 5 and 10.
first_filter_size = 9
# depth: number of conv blocks; 2, 3 or 4.
depth = 2
dropout = 0.1
epochs = 300

batch_size = 128
# latent_dim: log2 of the latent size (2**8 = 256 latent dimensions).
latent_dim = 8

rec_loss = "L1"
reduction = "sum"
kld_weight = 1e-1

model = VAE_CNN(
    device=model_config.device,
    first_filter_size=first_filter_size,
    kernel_size=kernel_size,
    depth=depth,
    dropout=dropout,
    latent_dim=latent_dim,
    rec_loss=rec_loss,
    reduction=reduction,
    kld_weight=kld_weight,
)

In [13]:
# Runs epochs 0..epochs (inclusive) and returns the per-epoch loss histories.
train_losses ,train_rec_losses,train_kld_losses, eval_losses = train_model (model, epochs, batch_size) 

<<Image 0 saved>>
0:	Total: 5876.16468	Eval loss: 7837.97803	 Rec loss: 5869.79834	 KLD loss: 6.36631	 time: 5.7s
<<Image 50 saved>>
50:	Total: 482.77175	Eval loss: 438.45502	 Rec loss: 425.87397	 KLD loss: 56.89778	 time: 13.1s
<<Image 100 saved>>
100:	Total: 385.92646	Eval loss: 314.13008	 Rec loss: 329.41664	 KLD loss: 56.50983	 time: 20.4s
<<Image 150 saved>>
150:	Total: 344.19507	Eval loss: 285.50156	 Rec loss: 288.86668	 KLD loss: 55.32839	 time: 27.7s
<<Image 200 saved>>
200:	Total: 315.45881	Eval loss: 270.03328	 Rec loss: 261.01551	 KLD loss: 54.44330	 time: 35.0s
<<Image 250 saved>>
250:	Total: 315.26925	Eval loss: 269.17783	 Rec loss: 262.03925	 KLD loss: 53.22999	 time: 42.4s
<<Image 300 saved>>
300:	Total: 287.91034	Eval loss: 255.60143	 Rec loss: 234.91006	 KLD loss: 53.00029	 time: 49.7s


In [296]:
# Saves the whole module object rather than its state_dict.
# NOTE(review): per the PyTorch serialization docs, loading a pickled module
# requires these class definitions to be importable; also assumes ./models exists.
torch.save(model, "./models/model_for_all_initial_visualizations.pt")

In [14]:
# Total training vs. validation loss only.
quick_plot(train_losses=train_losses, eval_losses=eval_losses, KLD=None, rec=None)

In [15]:
# Same curves plus the training reconstruction and (weighted) KLD components.
quick_plot(train_losses=train_losses, eval_losses=eval_losses, KLD=train_kld_losses, rec= train_rec_losses)

In [278]:
def plot_layout_style(fig, plot_title):
    """Apply the shared dark 3-D layout used by the motion plots.

    Sets an 800 px tall ``plotly_dark`` layout with a cubic aspect ratio and
    a fixed camera, hides the legend and color scale, and strips tick labels,
    grid and zero lines from the x/y axes.

    Args:
        fig: plotly figure to style (modified in place).
        plot_title: text used as the figure title.

    Returns:
        The same ``fig`` object, for chaining.
    """
    camera = dict(up=dict(x=0, y=0, z=1), eye=dict(x=-1.5, y=1.5, z=1.5))
    scene = dict(aspectratio=dict(x=1, y=1, z=1), camera=camera)
    font = dict(family="Roboto, monospace", size=12, color="white")

    fig.update_layout(
        height=800,
        margin=dict(l=5, r=5, t=50, b=50),
        template="plotly_dark",
        title_text=plot_title,
        font=font,
        scene=scene,
        showlegend=False,
        coloraxis_showscale=False,
    )

    # The x and y axes share one identical, minimal style.
    bare_axis = dict(
        showticklabels=False,
        showgrid=False,
        zeroline=False,
        zerolinewidth=1,
        zerolinecolor='gray',
        fixedrange=True,
    )
    fig.update_xaxes(**bare_axis)
    fig.update_yaxes(**bare_axis)

    return fig

In [290]:
def compare_motion_data_plots(motion_data, index=0):
    """Overlay one sample from each motion batch in a single 3-D figure.

    The first batch is drawn as a width-4 'viridis' line (the reference);
    every following batch as a width-1 'Agsunset' line. Columns 0-2 of each
    sample are used as x, y, z, and each line is colored by time step.

    Args:
        motion_data: sequence of tensors indexed as ``batch[index, :, :]``.
        index: which sample of every batch to draw.

    Returns:
        The displayed plotly figure.
    """
    reference_width, generated_width = 4, 1
    all_plots = []

    for i, motion_batch in enumerate(motion_data):
        sample = motion_batch[index, :, :]
        # One color value per time step, whatever the sequence length.
        color = np.arange(sample.shape[0])
        if i == 0:
            line_style = dict(color=color, width=reference_width, colorscale='viridis')
        else:
            line_style = dict(color=color, width=generated_width, colorscale='Agsunset')

        all_plots.append(go.Scatter3d(
            x=sample[:, 0].detach().cpu().numpy(),
            y=sample[:, 1].detach().cpu().numpy(),
            z=sample[:, 2].detach().cpu().numpy(),
            mode='lines',
            line=line_style,
        ))

    # Two tiny markers, presumably to force a minimum extent of the scene axes.
    scaler_val = 0.1
    all_plots.append(go.Scatter3d(
        x=[-scaler_val, scaler_val],
        y=[-scaler_val, scaler_val],
        z=[0, scaler_val],
        mode="markers",
        marker=dict(size=0.1),
    ))

    fig = plot_layout_style(go.Figure(data=all_plots), "Compare Motion Plots")
    fig.show()
    return fig

In [291]:
def test_scaling_method():
    """Visually check that min-max scaling of the centered data is invertible.

    Takes the first training batch, undoes the scaling of
    ``model_config.data_item`` with the batch's ``centered_min_val`` /
    ``centered_max_val`` and overlays sample 0 with ``"X_centered"``.
    Does nothing if the training iterator yields no batch.
    """
    first_batch = next(iter(model_config.train_iterator), None)
    if first_batch is None:
        return

    original_centered_data = first_batch["X_centered"]
    sample_centered_scaled = first_batch[model_config.data_item]
    centered_max_val = first_batch["centered_max_val"][0]
    centered_min_val = first_batch["centered_min_val"][0]

    # Inverse of x_scaled = (x - min) / (max - min).
    rescaled_data = sample_centered_scaled * (centered_max_val - centered_min_val) + centered_min_val
    compare_motion_data_plots([original_centered_data, rescaled_data], 0)


test_scaling_method()

In [292]:
def generation_range_visualization(motion_data):
    """Draw a reference motion and its generated variants side by side in 3-D.

    ``motion_data[0]`` is the reference: drawn width-4 with 'plotly3' and not
    shifted. The remaining motions are drawn width-2 with 'teal' and shifted
    along x in steps of 0.05: indices ``1..N//2`` to the negative side and
    the rest to the positive side, where ``N = motion_data.shape[0]``.

    Args:
        motion_data: tensor of motions; ``motion_data[i][:, 0:3]`` are x, y, z.

    Returns:
        The displayed plotly figure.
    """
    motion_count = motion_data.shape[0]
    half = motion_count // 2
    step = 0.05

    # One color value per time step, whatever the sequence length.
    color = np.arange(motion_data.shape[1])
    reference_style = dict(color=color, width=4, colorscale='plotly3')
    generated_style = dict(color=color, width=2, colorscale='teal')

    all_plots = []
    for i, sample in enumerate(motion_data):
        if i == 0:
            line_style, shifter = reference_style, 0
        elif i <= half:
            # First half of the variants goes to the left of the reference...
            line_style, shifter = generated_style, (i - half - 1) * step
        else:
            # ...and the second half to the right.
            line_style, shifter = generated_style, (i - half) * step

        all_plots.append(go.Scatter3d(
            x=sample[:, 0].detach().cpu().numpy() + shifter,
            y=sample[:, 1].detach().cpu().numpy(),
            z=sample[:, 2].detach().cpu().numpy(),
            mode='lines',
            line=line_style,
        ))

    # Two tiny markers, presumably to force a minimum extent of the scene axes.
    scaler_val = 0.1
    all_plots.append(go.Scatter3d(
        x=[-scaler_val, scaler_val],
        y=[-scaler_val, scaler_val],
        z=[0, scaler_val],
        mode="markers",
        marker=dict(size=0.1),
    ))

    fig = plot_layout_style(go.Figure(data=all_plots), "Compare Motion Plots")
    fig.show()
    return fig

In [293]:
def test_generation_method(generation_size=8, noise_mean=.1, noise_std=.2):
    """Generate variants of one random training sample by perturbing its latent code.

    Picks a random batch among the first 10 training batches and a random
    sample within it, encodes the sample, adds Gaussian noise to
    ``generation_size`` copies of its latent vector and decodes them. The
    first half is decoded with the sample's labels, the second half with
    all-zero labels. The reference and all variants are scaled back with
    ``scale_back`` and drawn by ``generation_range_visualization``.

    NOTE(review): Encoder/Decoder in this notebook ignore ``y``, so the
    zero-label half differs from the first half only through its noise.
    NOTE(review): the model's train/eval mode is not set here, so BatchNorm
    and Dropout behave according to whatever mode the model was left in.

    Args:
        generation_size: number of noisy latent copies to decode.
        noise_mean: mean of the Gaussian latent noise.
        noise_std: standard deviation of the Gaussian latent noise.
    """
    batch_index = np.random.randint(0, 10)
    for i, batch in enumerate(model_config.train_iterator):
        if i == batch_index:
            break
    else:
        # The iterator ran out before reaching batch_index.
        return

    x_batch = batch[model_config.data_item]
    # Bound the sample index by the real batch length (the last batch may be short).
    sample_index = np.random.randint(0, x_batch.shape[0], size=1)
    x_samples = x_batch[sample_index, :, :]
    y_samples = batch['Y'][sample_index, :, :]
    print(torch.sum(y_samples))

    with torch.no_grad():
        z, _, _ = model.encoder(x_samples, y_samples)

        # Broadcast the single latent vector into `generation_size` noisy copies.
        noise = torch.normal(mean=noise_mean, std=noise_std,
                             size=(generation_size, z.shape[-1])).to(model_config.device)
        z = z + noise

        half = generation_size // 2
        # Give every latent copy its own label row; slicing the single-row
        # label tensor directly would leave the second half empty.
        y_repeated = y_samples.expand(generation_size, *y_samples.shape[1:])
        x_generated = model.decoder(z[:half], y_repeated[:half])
        x_opposite_hand_generated = model.decoder(z[half:], y_repeated[half:] * 0)

    # Undo the min-max scaling of the centered data before plotting.
    x_generated = scale_back(x_generated, batch)
    x_opposite_hand_generated = scale_back(x_opposite_hand_generated, batch)
    x_samples = scale_back(x_samples, batch)

    samples = torch.cat((x_samples, x_generated, x_opposite_hand_generated))
    generation_range_visualization(samples)


def scale_back(motion, d):
    """Invert min-max scaling with the batch's stored centered min/max values.

    Args:
        motion: scaled tensor.
        d: batch dict with ``"centered_max_val"`` and ``"centered_min_val"``;
            only the first entry of each is used.

    Returns:
        ``motion * (max - min) + min``.
    """
    centered_max_val = d["centered_max_val"][0]
    centered_min_val = d["centered_min_val"][0]
    return motion * (centered_max_val - centered_min_val) + centered_min_val


test_generation_method()

tensor(20., device='cuda:0')


tensor(0., device='cuda:0')


tensor(20., device='cuda:0')


tensor(0., device='cuda:0')


tensor(0., device='cuda:0')


In [283]:
def show_generated_motions(x_generated, plot_title):
    """Build a 3-D figure with one line per motion in ``x_generated``.

    Columns 0-2 of each motion are used as x, y, z; lines are colored by
    time step with the 'Agsunset' scale. The figure is returned, not shown.

    Args:
        x_generated: batch of motions indexed as ``x_generated[i, :, :]``.
        plot_title: title passed to ``plot_layout_style``.

    Returns:
        The styled plotly figure; call ``.show()`` on it to display it.
    """
    all_plots = []
    for index in range(x_generated.shape[0]):
        sample_generated = x_generated[index, :, :]
        all_plots.append(go.Scatter3d(
            x=sample_generated[:, 0].detach().cpu().numpy(),
            y=sample_generated[:, 1].detach().cpu().numpy(),
            z=sample_generated[:, 2].detach().cpu().numpy(),
            mode='lines',
            name="Generated",
            line=dict(color=np.arange(sample_generated.shape[0]),  # one color per time step
                      width=2,
                      colorscale='Agsunset'),
        ))

    fig = go.Figure(data=all_plots)
    # Display is left to the caller.
    return plot_layout_style(fig, plot_title)

In [284]:
# Encode 16 random samples from the first training batch and show the
# originals, their reconstructions, and decodings of noise-perturbed latents.
first_batch = next(iter(model_config.train_iterator))
x_batch = first_batch[model_config.data_item]
# Indices are drawn with replacement and bounded by the real batch length.
sample_size = np.random.randint(0, x_batch.shape[0], size=16)
x_samples = x_batch[sample_size, :, :]
y_samples = first_batch['Y'][sample_size, :, :]

# Inference only: no autograd graph is needed for plotting.
with torch.no_grad():
    z = model.encoder(x_samples, y_samples)[0]
    noise = torch.normal(mean=.1, std=.21, size=z.shape).to(model_config.device)
    z = z + noise
    x_generated = model.decoder(z, y_samples)
    x_rec = model(x_samples, y_samples)[0]

for motions, title in ((x_samples, "Original"),
                       (x_rec, "Reconstruction"),
                       (x_generated, "Generated")):
    show_generated_motions(motions, title).show()

In [285]:
# Check on training batch 7 that scaling can be inverted, then overlay the raw
# motion with the motion recovered from its centered + scaled form.
target_batch = None
for batch_index, batch in enumerate(model_config.train_iterator):
    if batch_index == 7:
        target_batch = batch
        break

if target_batch is not None:
    X = target_batch["X"][:3, :, :]
    print(X.shape)
    min_val, max_val = get_min_max_from_dataset(X)

    # Center each motion on its touching point, taken as time step 9
    # (xyz columns only).
    center_points = torch.zeros_like(X[:, 0:1, :])
    center_points[:, 0, :3] = X[:, 9, :3]
    X_centered = X - center_points

    # Round-trip the min-max scaling. Use the max absolute error: a signed
    # sum lets positive and negative errors cancel out.
    X_scaled = (X - min_val) / (max_val - min_val)
    print((X - (X_scaled * (max_val - min_val) + min_val)).abs().max() < 0.001)

    # Round-trip the centered scaling with the batch's stored min/max values.
    centered_min_val = target_batch["centered_min_val"][0]
    centered_max_val = target_batch["centered_max_val"][0]
    X_centered_scaled = (X_centered - centered_min_val) / (centered_max_val - centered_min_val)
    X_back_to_orig = X_centered_scaled * (centered_max_val - centered_min_val) + centered_min_val + center_points

    fig_01 = compare_motion_data_plots([X, X_back_to_orig], 0)

torch.Size([3, 20, 9])
tensor(True, device='cuda:0')


In [286]:
def stroke_visualizer(dataset, samples_to_display=64, scaled=False, centered=False):
    """Show randomly chosen motions from a dataset in one 3-D plot.

    Args:
        dataset: dataset supporting ``len()`` and iteration. Items are indexed
            positionally: ``item[0]`` raw motion, ``item[2]`` scaled,
            ``item[3]`` centered, ``item[4]`` scaled and centered. Each motion
            is a tensor whose columns 0, 1, 2 are drawn as x, y, z.
        samples_to_display (int, optional): number of distinct motions drawn
            at random (without replacement). Defaults to 64.
        scaled (bool, optional): draw the scaled variant. Defaults to False.
        centered (bool, optional): draw the centered variant. Defaults to False.
    """
    # Decide the title and item field before iterating, so the title exists
    # even when no motion is drawn (e.g. samples_to_display=0).
    plot_title = "{} Motions".format(samples_to_display)
    if scaled and centered:
        plot_title = "{}, Scaled, Centered".format(plot_title)
        field = 4
    elif scaled:
        plot_title = "{}, Scaled".format(plot_title)
        field = 2
    elif centered:
        plot_title = "{}, Centered".format(plot_title)
        field = 3
    else:
        field = 0

    # A set gives O(1) membership instead of scanning the index array per item.
    remaining = set(np.random.choice(len(dataset), samples_to_display, replace=False).tolist())

    test_scatters = []
    for i, d in enumerate(dataset):
        if i in remaining:
            remaining.discard(i)
            X = d[field]
            # more info here: https://plotly.com/python/3d-line-plots/
            test_scatters.append(go.Scatter3d(
                x=X[:, 0].detach().cpu().numpy(),
                y=X[:, 1].detach().cpu().numpy(),
                z=X[:, 2].detach().cpu().numpy(),
                mode='lines',
                line=dict(
                    color=np.arange(X.shape[0]),  # one color per time step
                    width=2,
                    colorscale='Agsunset',
                ),
            ))
        if not remaining:
            # Every selected motion is drawn; skip the rest of the dataset.
            break

    fig = go.Figure(data=test_scatters)

    fig.update_layout(height=600,
                      margin=dict(l=5, r=5, t=50, b=5),
                      template="plotly_dark",
                      title_text=plot_title,
                      font=dict(family="Roboto, monospace",
                                size=12,
                                color="white"),
                      scene=dict(aspectratio=dict(x=1, y=1, z=1)),
                      showlegend=False,
                      coloraxis_showscale=False)

    fig.update_xaxes(showticklabels=False,
                     showgrid=False,
                     zeroline=False,
                     zerolinewidth=1,
                     zerolinecolor='gray',
                     fixedrange=True)

    fig.update_yaxes(showticklabels=False,
                     showgrid=False,
                     zeroline=False,
                     zerolinewidth=1,
                     zerolinecolor='gray',
                     fixedrange=True)

    fig.show()