In [1]:
import os
import sys
import torch

import plotly
import numpy as np
import plotly.graph_objs as go
from sklearn.decomposition import PCA

src_path = os.path.join('..', '..')

sys.path.append(src_path)
from src.utils import int_to_pitch 



In [2]:
def display_pca_scatterplot(embedding, user_input=None, labels=None, opacity=None, colors=None, fontsize=18, markersize=20, textposition='top left', cam=(0,0,0), dims=3, avg=False):
    emb_vectors = np.array([
        embedding[0](
            torch.Tensor([note]).long().cuda()).cpu().detach().numpy().squeeze()
                  for note in user_input
    ])

    pca = PCA(random_state=0).fit_transform(emb_vectors)[:,:dims]

    data = []
  
    if avg:
        xx = []
        yy = []

        for i in range(12):
            xxs = np.array([])
            yys = np.array([])
                                        
            for j in range(len(pca[:,0]) // 12):
                xxs = np.append(xxs, pca[:,0][12*j+i])
                yys = np.append(yys, pca[:,1][12*j+i]) 
            
            xx.append(xxs.mean())
            yy.append(yys.mean())
        
        xx = np.array(xx)
        xx = np.append(pca[:, 0], xx)

        yy = np.array(yy)
        yy = np.append(pca[:, 1], yy)
        
        if dims == 3:
            zz = []
            
            for i in range(12):
                zzs = np.array([])
                
                for j in range(len(pca[:,0]) // 12):
                    zzs = np.append(xxs, pca[:,2][12*j+i])
                
                zz.append(zzs.mean())
   
            zz = np.array(zz)
            zz = np.append(pca[:, 2], zz)
    else:
        xx = pca[:, 0]
        yy = pca[:, 1]
        
        if dims == 3:
            zz = pca[:, 2]
    
    marker = {
        'size': markersize,
        'opacity': opacity if dims == 2 else 0.8,
        'color': colors
    }
    
    if dims == 2:
        trace = go.Scatter(
            x = xx, 
            y = yy,  
            text = labels,
            textposition = textposition,
            textfont_size = fontsize,
            mode = 'markers+text',
            marker = marker
        )
    else:      
        trace = go.Scatter3d(
            x = xx, 
            y = yy, 
            z = zz,
            text = labels,
            textposition = textposition,
            textfont_size = fontsize,
            mode = 'markers+text',
            marker = marker
        )

    data.append(trace)

    # Configure the layout

    layout = go.Layout(
        margin = {'l': 0, 'r': 0, 'b': 0, 't': 0},
        showlegend=False,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        legend=dict(
            x=1,
            y=0.5,
            font=dict(
                family="Courier New",
                size=20,
                color="black"
            )
        ),
        font = dict(
            family = "Courier New",
            size = 12
        ),
        autosize = False,
        width = 1000,
        height = 1000,
        scene_camera=dict(
            eye=dict(x=cam[0], y=cam[1], z=cam[2])
        )
    )

    plot_figure = go.Figure(data = data, layout = layout)
    plot_figure.show()

In [13]:
model_path = os.path.join(
    src_path,
    'mlruns',
#     '2/528338104be34c8b90cf934bc1cc75ea', 
#     '2/603b3504919e46b3980a9ddfcc110fcf',
#     '2/9d6f067d0c804a99af08c6e2a67bd688',
#     '5/8fcc0502375b4611861a8a532530f57c',
#     '2/9d6f067d0c804a99af08c6e2a67bd688', 
#     '2/e30ecc06a2854d58a731bed4821eff7b',
#     '5/0771fdebe48e4a27a7e227f90c5acea6',
#     '5/35a6d6d2fc364c7c89f121dca3d90a5a',
#     '5/8d5a71e46daa42788fbe1a701e882d14',
#     '5/6b32d9e292674cbb833655f161350381',
#     '5/72b4d266dbca40308624362723ce7f11',
#     '5/72b4d266dbca40308624362723ce7f11',
#     '6/59716f59ea91408d9c56b6b39004b160',
#      '5/10890460b0ea43fea7e57354d0835405',
     '10/f67d6b65ec6846b8971352b7367ab6d4',

    
    'artifacts',
#     '22_05_21_10_09_12_transpose_all_chord_extended_7_batchsize_64_seed_1234567890_early_stop.pt',
#     '22_05_22_00_52_58_transpose_all_chord_extended_7_batchsize_128_seed_1234567890.pt',
#     '22_05_20_13_49_37_transpose_all_chord_extended_7_batchsize_64_seed_1234567890.pt',
#     '22_05_28_20_53_27_transpose_all_chord_extended_7_batchsize_64_seed_1234567890.pt',
#     '22_05_20_13_49_37_transpose_all_chord_extended_7_batchsize_64_seed_1234567890.pt',
#     'transpose_all_chord_extended_7_batchsize_64_seed_1234567890_best_val.pt',
#     '22_05_30_14_10_32_transpose_all_chord_extended_7_batchsize_64_seed_1234567890_early_stop.pt',
#     '22_05_30_21_47_05_transpose_all_chord_extended_7_batchsize_64_seed_1234567890_early_stop.pt',
#     '22_05_30_22_52_17_transpose_all_chord_extended_7_batchsize_64_seed_1234567890.pt',
#     '22_05_31_13_54_18_transpose_all_chord_extended_7_batchsize_128_seed_1234567890.pt',
#     '22_06_01_23_14_12_transpose_all_chord_extended_7_batchsize_64_seed_1234567890_best_val.pt',
#     '22_06_01_23_14_12_transpose_all_chord_extended_7_batchsize_64_seed_1234567890_early_stop.pt',
#     '22_06_17_17_18_49_transpose_none_chord_extended_7_batchsize_128_seed_1234567890_best_val.pt'
#     '22_06_07_00_15_51_transpose_all_chord_extended_7_batchsize_64_seed_1234567890_best_val.pt',
    '22_06_20_12_03_43_transpose_all_chord_extended_7_batchsize_128_seed_1234567890_best_val.pt'
        
)

model = torch.load(open(model_path, 'rb'))

In [14]:
rng = range(48, 84)
octaves = len(rng) // 12

user_input = list(rng) + [128, 129]

opacity = list(np.repeat(1, 12*octaves)) + list(np.repeat(1, 2)) + list(np.repeat(0, 12))

# labels = ['C3' ,'Db3', 'D3', 'Eb3', 'E3', 'F3', 'F#3', 'G3', 'Ab3', 'A3', 'Bb3', 'B3'] +\
#          ['C4' ,'Db4', 'D4', 'Eb4', 'E4', 'F4', 'F#4', 'G4', 'Ab4', 'A4', 'Bb4', 'B4'] +\
#          ['C5' ,'Db5', 'D5', 'Eb5', 'E5', 'F5', 'F#5', 'G5', 'Ab5', 'A5', 'Bb5', 'B5'] +\
#          ['R', 'P'] + \
#          ['C' ,'Db', 'D', 'Eb', 'E', 'F', 'F#', 'G', 'Ab', 'A', 'Bb', 'B']

labels =  list(np.repeat("", 12*octaves)) + \
         [' R', 'P'] + \
         ['  C' ,'Db', '   D', 'Eb', 'E', ' F', '  F#', 'G', 'Ab', '  A', 'Bb', 'B']

colors = [(i % 12)*3 for i in range(12*octaves)] + ['grey', 'darkgrey'] + [(i % 12)*3 for i in range(12)]
# colors = user_input[:-2] + ['grey', 'darkgrey']


display_pca_scatterplot(model.pitch_encoder, user_input, labels, opacity, colors, 
                        fontsize=28, markersize=32, textposition='bottom right', cam=(0,0,0), dims=2, avg=True)

In [15]:
user_input = list(range(0, 48))

labels = list(np.repeat("", len(user_input))) + [f'{i%12}' for i in user_input]
opacity = list(np.repeat(1, 12*4))  + list(np.repeat(0, 12))
colors = [i%12 for i in user_input] + [i % 12 for i in range(12)]

display_pca_scatterplot(model.offset_encoder, user_input, labels, opacity, colors,
                        fontsize=32, markersize=32, textposition='bottom left', cam=(0.5,0,0), dims=2, avg=True)

In [16]:
durations = torch.Tensor(list(range(1, 75)))
idx = model.convert_durations_to_ids(durations).long().tolist()


labels = [f'{i}' if i <= 12 else "" for i in durations.long().tolist()]
offset = 0.8
colors = [i for i in idx] 

display_pca_scatterplot(model.duration_encoder, idx, labels, fontsize=32, textposition='top center', markersize=32, dims=2)

AttributeError: 'TimeStepFullModel' object has no attribute 'convert_durations_to_ids'

In [18]:
attacks = torch.Tensor(list(range(0, 4)))


labels = [f'{i}' for i in attacks]
colors = [i for i in attacks] 

display_pca_scatterplot(model.attack_encoder, attacks, labels, fontsize=32, textposition='top center', markersize=32, dims=2)