In [1]:
import os
import sys
import torch

import plotly
import numpy as np
import plotly.graph_objs as go
from sklearn.decomposition import PCA

# Project root, one directory above this notebook. It is also put on sys.path
# so that the local `src` package can be imported below.
src_path = os.path.join('..')
sys.path.append(src_path)
from src.utils import int_to_pitch

In [2]:
# Early-stopped checkpoint logged as an MLflow artifact (experiment 2).
model_path = os.path.join(
    src_path,
    'mlruns', '2',
    '528338104be34c8b90cf934bc1cc75ea', 'artifacts',
    '22_05_21_10_09_12_transpose_all_chord_extended_7_batchsize_64_seed_1234567890_early_stop.pt'
)

# Use the GPU when present, otherwise fall back to CPU so the checkpoint can
# still be deserialised on a machine without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Passing the path (instead of an unclosed `open(...)` handle) lets torch open
# and close the file itself. The checkpoint is a pickled object, so loading it
# executes code from the file: only load checkpoints you trust.
# NOTE(review): newer torch releases default to `weights_only=True`, which may
# refuse a fully pickled model — pass `weights_only=False` there if needed.
model = torch.load(model_path, map_location=DEVICE)

In [81]:
def display_pca_scatterplot(embedding, user_input=None, labels=None, opacity=None, colors=None, fontsize=18, markersize=20, cam=(0,0,0), dims=3, avg=False):
    """Embed token ids, project them with PCA and show a plotly scatter plot.

    Each id in ``user_input`` is passed to ``embedding`` on its own (as a
    1-element LongTensor). The squeezed outputs are stacked, projected with
    ``PCA(random_state=0)``, and the first ``dims`` components are plotted.

    Args:
        embedding: Callable/module mapping a LongTensor of ids to embeddings
            (e.g. ``model.pitch_encoder``). Inputs are created on the device of
            its first parameter, or on CPU if it has none.
        user_input: Iterable of integer ids to embed. Required.
        labels: Per-point text labels (plotly ``text``).
        opacity: Marker opacity, scalar or per-point. Only used when
            ``dims == 2``; 3-D plots always use 0.8.
        colors: Marker colours (plotly ``marker.color``).
        fontsize: Font size of the point labels.
        markersize: Marker size.
        cam: ``(x, y, z)`` of the 3-D scene camera ``eye``.
        dims: Number of principal components to plot; must be 2 or 3.
        avg: If True, append 12 extra points where point ``i`` is the mean of
            projected rows ``i``, ``12 + i`` and ``24 + i``. This requires at
            least 36 ids. Labels/colours/opacities for these extra points belong
            at the end of the respective lists.

    Raises:
        ValueError: if ``user_input`` is missing or empty, ``dims`` is not 2 or 3,
            or ``avg`` is True with fewer than 36 ids.
    """
    notes = list(user_input) if user_input is not None else []
    if not notes:
        raise ValueError("user_input must be a non-empty sequence of token ids")
    if dims not in (2, 3):
        raise ValueError(f"dims must be 2 or 3, got {dims!r}")

    # Create inputs on the embedding's own device rather than hard-coding .cuda(),
    # so this also works when the model lives on the CPU.
    first_param = next(embedding.parameters(), None) if hasattr(embedding, 'parameters') else None
    device = first_param.device if first_param is not None else torch.device('cpu')

    # One id at a time; no autograd graph is needed for visualisation.
    with torch.no_grad():
        emb_vectors = np.array([
            embedding(torch.tensor([note], dtype=torch.long, device=device)).cpu().numpy().squeeze()
            for note in notes
        ])

    # Fit PCA on all components and keep the first `dims` (fixed seed for reproducibility).
    coords = PCA(random_state=0).fit_transform(emb_vectors)[:, :dims]

    if avg:
        if len(coords) < 36:
            raise ValueError(f"avg=True needs at least 36 ids, got {len(coords)}")
        # Average the three consecutive blocks of 12 rows position-wise and
        # append the 12 means after the original points.
        octave_means = (coords[0:12] + coords[12:24] + coords[24:36]) / 3
        coords = np.concatenate([coords, octave_means])

    marker = {
        'size': markersize,
        'opacity': opacity if dims == 2 else 0.8,
        'color': colors
    }

    trace_kwargs = dict(
        x=coords[:, 0],
        y=coords[:, 1],
        text=labels,
        textposition="top center",
        textfont_size=fontsize,
        mode='markers+text',
        marker=marker
    )

    if dims == 2:
        trace = go.Scatter(**trace_kwargs)
    else:
        trace = go.Scatter3d(z=coords[:, 2], **trace_kwargs)

    # Configure the layout: transparent background, fixed 1000x1000 canvas,
    # and the requested camera position for 3-D scenes.
    layout = go.Layout(
        margin = {'l': 0, 'r': 0, 'b': 0, 't': 0},
        showlegend=False,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        legend=dict(
            x=1,
            y=0.5,
            font=dict(
                family="Courier New",
                size=20,
                color="black"
            )
        ),
        font = dict(
            family = "Courier New",
            size = 12
        ),
        autosize = False,
        width = 1000,
        height = 1000,
        scene_camera=dict(
            eye=dict(x=cam[0], y=cam[1], z=cam[2])
        )
    )

    plot_figure = go.Figure(data=[trace], layout=layout)
    plot_figure.show()

In [159]:
# Pitch embeddings: two octaves of pitch ids 60-83 plus ids 128 and 129,
# labelled 'R' and 'P' (NOTE(review): presumably rest/pad tokens — confirm
# against the model's vocabulary).
N_PITCHES = 12 * 2
NOTE_NAMES = ['C', 'Db', 'D', 'Eb', 'E', 'F', 'F#', 'G', 'Ab', 'A', 'Bb', 'B']

user_input = list(range(60, 60 + N_PITCHES)) + [128, 129]

# Per-point styling, laid out as: the 24 pitches, the 2 special tokens, then 12
# entries for the averaged points display_pca_scatterplot appends when avg=True.
# Every list now follows this same layout (opacity previously lacked the R/P slots).
# NOTE(review): the trailing 12 entries are never drawn here. avg=True needs 36
# ids before any special tokens, and this input has only 24 pitches.
opacity = [0.2] * N_PITCHES + [1, 1] + [1] * 12
labels = [''] * N_PITCHES + ['R', 'P'] + NOTE_NAMES
colors = [(i % 12) * 3 for i in range(N_PITCHES)] + ['grey', 'darkgrey'] + [i * 3 for i in range(12)]

# NOTE(review): cam=(0,0,0) puts the camera eye at the origin; confirm this is the intended view.
display_pca_scatterplot(model.pitch_encoder, user_input, labels, opacity, colors,
                        fontsize=32, markersize=32, cam=(0,0,0), dims=3, avg=False)

In [150]:
# Offset embeddings for ids 0-47, labelled with the id and coloured by id modulo 12.
user_input = list(range(48))
labels = [str(offset_id) for offset_id in user_input]
opacity = 0.8
colors = [offset_id % 12 for offset_id in user_input]

display_pca_scatterplot(
    model.offset_encoder,
    user_input,
    labels=labels,
    opacity=opacity,
    colors=colors,
    fontsize=12,
    markersize=14,
    cam=(0.5, 0, 0),
    dims=2,
    avg=False,
)

In [97]:
# Duration embeddings: durations 1-74 are mapped to vocabulary ids by the model,
# and each id is plotted with the duration it came from as its label.
durations = torch.Tensor(list(range(1, 75)))
idx = model.convert_durations_to_ids(durations).long().tolist()

labels = [str(duration) for duration in durations.long().tolist()]

# Style the points like the offset-encoder plot. `opacity` and `colors` are
# passed explicitly, so points are coloured by their duration id.
opacity = 0.8
colors = list(idx)

display_pca_scatterplot(model.duration_encoder, idx, labels, opacity=opacity, colors=colors,
                        markersize=12, dims=2)