In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import altair as alt
from ipywidgets import widgets

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Per-annotation ratings. Missing values are imputed with each column's median.
# NOTE(review): anot.median() covers every column, including the id columns
# (pianist_id, segment_id, annotator_id), so a row missing an id would get an
# invented one. Also, read_pickle can execute code; only load trusted files.
anot = pd.read_pickle("./data/task_2_annotations.pkl")
anot = anot.fillna(anot.median())

# Rating columns offered in the interactive explorer: skip the three leading id
# columns and the final column.
# NOTE(review): per anot.head() the final column is gemmes_wandering, so it is
# excluded from the dropdowns. Confirm that is intended.
feature_cols = anot.columns[3:-1].tolist()
anot.head()

Unnamed: 0,pianist_id,segment_id,annotator_id,arousal,valence,gems_wonder,gems_transcendence,gems_tenderness,gems_nostalgia,gems_peacefulness,gems_power,gems_joyful_activation,gems_tension,gems_sadness,gemmes_flow,gemmes_movement,gemmes_force,gemmes_interior,gemmes_wandering
0,1,0,91,1,-1,2,1,2,4,2.0,1,1,1,2,3,2,1,1.0,2
31,1,0,19,2,-1,3,3,3,4,4.0,1,2,3,3,3,2,2,3.0,3
62,1,0,189,2,0,2,1,2,1,4.0,2,2,1,1,3,2,1,1.0,4
93,1,0,126,2,2,4,5,2,3,5.0,2,4,1,3,5,1,2,2.0,5
124,1,0,26,4,2,3,5,2,3,3.0,1,3,4,1,4,1,2,3.0,1


In [3]:
def normalize(df: pd.DataFrame) -> pd.DataFrame:
    """
    Z-score every rating column separately for each annotator.

    Within each ``annotator_id`` group, every column other than the three id
    columns is shifted by that annotator's mean and divided by that annotator's
    sample standard deviation (ddof=1).

    Parameters
    ----------
    df : pd.DataFrame
        Ratings with ``pianist_id``, ``segment_id`` and ``annotator_id`` columns.
        Every other column is treated as a numeric rating. Not modified.

    Returns
    -------
    pd.DataFrame
        A new frame with the same index and column order. The id columns are
        unchanged and the rating columns become float. For an annotator who rated
        a column only once, or always with the same value, that column is NaN
        because the standard deviation is undefined or zero. Rows whose
        ``annotator_id`` is missing cannot be grouped and get NaN ratings.
    """
    id_cols = ["pianist_id", "segment_id", "annotator_id"]
    cols_to_norm = list(df.columns.difference(id_cols))

    # transform() broadcasts each annotator's statistics back onto that
    # annotator's rows. This avoids the old slice-and-write-back loop, which
    # relied on chained assignment and wrote floats into integer columns.
    grouped = df.groupby("annotator_id")[cols_to_norm]
    normed = (df[cols_to_norm] - grouped.transform("mean")) / grouped.transform("std")

    out = df.copy()
    out[cols_to_norm] = normed
    return out

# Per-annotator z-scores of every rating column (see normalize above).
anot_n = normalize(df=anot)
anot_n.head()

Unnamed: 0,pianist_id,segment_id,annotator_id,arousal,valence,gems_wonder,gems_transcendence,gems_tenderness,gems_nostalgia,gems_peacefulness,gems_power,gems_joyful_activation,gems_tension,gems_sadness,gemmes_flow,gemmes_movement,gemmes_force,gemmes_interior,gemmes_wandering
0,1,0,91,-1.933003,-0.77961,0.12238,-0.566947,-0.047883,1.598264,0.047883,-0.562236,-0.746937,-1.209073,0.279576,0.48107,-0.408248,-0.663212,-0.763581,0.755929
31,1,0,19,-1.179377,-0.901388,-0.53161,-0.880271,0.139087,0.457905,1.502731,-1.353404,-0.769534,-0.203906,0.538682,-0.457905,-0.969628,-0.806226,0.150493,1.483596
62,1,0,189,-0.503236,-0.354235,-0.346144,-0.884282,-0.094491,-0.834523,1.271936,-0.203906,-0.312641,-0.887114,-0.346144,-0.206623,-1.021857,-0.963624,-0.931133,0.901388
93,1,0,126,-1.140175,1.168566,1.351324,2.234648,-0.086189,0.781602,2.12132,0.0,1.285422,-0.551807,1.770291,1.316561,-0.609449,-0.240074,-0.043577,1.437735
124,1,0,26,0.563918,1.461206,0.203906,1.316561,0.0,0.849837,-0.240074,-1.203942,1.012757,1.031327,-0.566947,0.938912,-0.999829,-0.524531,0.524531,-0.69215


In [4]:
def aggregate_all_feature(df: pd.DataFrame) -> pd.DataFrame:
    """
    Collapse all annotations of the same (pianist_id, segment_id) pair into one row.

    Every rating column is summarised by its median and its sample variance
    (ddof=1) across annotators. ``annotator_id`` itself is not aggregated.

    Parameters
    ----------
    df : pd.DataFrame
        One row per annotation with ``pianist_id``, ``segment_id``,
        ``annotator_id`` and numeric rating columns. Not modified.

    Returns
    -------
    pd.DataFrame
        Flat-columned frame with ``pianist_id`` and ``segment_id`` first. Then,
        for each rating column in input order, ``<feature>_median`` and
        ``<feature>_var``. Rows are sorted by the two keys.

    Notes
    -----
    Any row with a NaN in *any* column is dropped before grouping, so a single
    missing value discards that whole annotation. A group left with one
    annotation gets NaN variance.
    """
    keys = ["pianist_id", "segment_id"]
    aggregated = (
        df.dropna()
        .drop(columns=["annotator_id"])
        .groupby(keys)
        # Pass "var" as a string. The np.var callable is mapped to the pandas
        # (ddof=1) implementation only through a deprecated alias; applied
        # directly it would compute the population variance (ddof=0).
        .agg(["median", "var"])
    )
    # ("arousal", "median") -> "arousal_median"
    aggregated.columns = [f"{feature}_{stat}" for feature, stat in aggregated.columns]
    return aggregated.reset_index()

# One row per (pianist, segment): median and variance of every normalised rating.
anot_na = aggregate_all_feature(df=anot_n)
# Aggregated feature columns: everything after pianist_id and segment_id.
feature_cols_na = anot_na.columns[2:]
anot_na.head()

Unnamed: 0,pianist_id,segment_id,arousal_median,arousal_var,valence_median,valence_var,gems_wonder_median,gems_wonder_var,gems_transcendence_median,gems_transcendence_var,...,gemmes_flow_median,gemmes_flow_var,gemmes_movement_median,gemmes_movement_var,gemmes_force_median,gemmes_force_var,gemmes_interior_median,gemmes_interior_var,gemmes_wandering_median,gemmes_wandering_var
0,1,0,-0.240074,1.280437,-0.267261,0.74298,0.12238,0.636281,-0.077904,1.182969,...,0.86536,0.401977,-0.609449,0.491158,-0.524531,0.688379,0.215473,1.102626,1.233794,0.477967
1,1,1,0.961763,0.455852,0.114991,0.760549,0.629257,0.632377,0.705082,0.434863,...,0.485513,0.385192,1.076967,0.22294,0.179319,0.910188,-0.497658,0.3178,-0.244964,0.731559
2,1,2,0.0,0.582619,-0.071626,0.746793,-0.293766,0.562037,-0.310362,0.515974,...,0.062594,0.681252,1.292837,0.898373,0.658281,0.324061,-0.763581,0.336886,-0.566947,1.069631
3,1,3,0.331777,1.184887,-0.087434,0.505226,-0.68702,1.024582,-0.460591,0.668752,...,-0.873065,0.741477,0.619255,1.289282,0.472411,0.669784,-1.197013,0.770359,-0.69562,0.597558
4,1,4,0.0,0.737823,-0.467426,1.368457,-0.086189,1.442232,-0.18545,0.817383,...,0.393398,0.909097,-0.609449,0.457542,-0.240074,0.869531,0.424212,0.946506,0.0,0.917043


In [5]:
def _median_columns(df: pd.DataFrame, prefix: str) -> list:
    """Return the ``<prefix>*_median`` column names of ``df`` in column order."""
    return [
        col
        for col in df.columns
        if isinstance(col, str) and col.startswith(prefix) and col.endswith("_median")
    ]


def _argmax_with_var(df: pd.DataFrame, median_cols: list) -> tuple:
    """
    Per row, return the name of the largest of ``median_cols`` and the value of
    its sibling ``*_var`` column (``gems_wonder_median`` -> ``gems_wonder_var``).
    """
    medians = df[median_cols]
    labels = medians.idxmax(axis=1)
    var_cols = [col[: -len("median")] + "var" for col in median_cols]
    variances = df[var_cols].to_numpy(dtype=float)
    # Positional lookup, so the result does not depend on df's index labels.
    pos = medians.columns.get_indexer(labels)
    picked = variances[np.arange(len(df)), pos]
    # get_indexer yields -1 when idxmax produced no column name (all-NaN row).
    picked[pos < 0] = np.nan
    return labels, pd.Series(picked, index=df.index)


def max_label(df: pd.DataFrame) -> pd.DataFrame:
    """
    Tag each row with its highest-median GEMS and GEMMES category.

    Among the ``gems_*_median`` columns, and separately among the
    ``gemmes_*_median`` columns, the largest value wins. Its column name is stored
    in ``label_gems`` / ``label_gemmes``. The same row's matching ``*_var`` value
    is stored in ``var_gems`` / ``var_gemmes``.

    Parameters
    ----------
    df : pd.DataFrame
        Aggregated frame in which every ``<feature>_median`` column has a sibling
        ``<feature>_var`` column. Columns are found by name, so any index works.

    Returns
    -------
    pd.DataFrame
        A copy of ``df`` with ``label_gems``, ``var_gems``, ``label_gemmes`` and
        ``var_gemmes`` appended in that order.
    """
    gems_label, gems_var = _argmax_with_var(df, _median_columns(df, "gems_"))
    gemmes_label, gemmes_var = _argmax_with_var(df, _median_columns(df, "gemmes_"))
    return df.assign(
        label_gems=gems_label,
        var_gems=gems_var,
        label_gemmes=gemmes_label,
        var_gemmes=gemmes_var,
    )

def show_plot(df: pd.DataFrame, category: str, size_field=None):
    """
    Valence/arousal scatter of the aggregated segments, coloured by ``category``.

    A clickable legend to the right selects category values. Unselected points
    are hidden in the scatter and greyed out in the legend.

    Parameters
    ----------
    df : pd.DataFrame
        Frame with ``valence_median``, ``arousal_median`` and ``category`` columns.
    category : str
        Nominal column used for colour and for the legend selection.
    size_field : str, optional
        Quantitative column mapped to point size. By default a ``label_<x>``
        category is paired with ``var_<x>`` when that column exists
        (``label_gemmes`` -> ``var_gemmes``). Otherwise ``var_gems`` is used.

    Returns
    -------
    Altair chart
        The layered scatter (with zero reference lines) horizontally
        concatenated with the legend chart.
    """
    if size_field is None:
        # Previously always var_gems, which mismatched label_gemmes plots.
        size_field = "var_gems"
        if category.startswith("label_"):
            paired = "var_" + category[len("label_"):]
            if paired in df.columns:
                size_field = paired

    selection = alt.selection_multi(fields=[category])
    color = alt.condition(selection, alt.Color(f"{category}:N", legend=None), alt.value("lightgray"))
    opacity = alt.condition(selection, alt.value(1), alt.value(0))

    scatter = alt.Chart(df).mark_point(size=60).encode(
        x="valence_median:Q",
        y="arousal_median:Q",
        color=color,
        opacity=opacity,
        size=f"{size_field}:Q",
    ).interactive()

    # Clicking a legend entry drives `selection` for both charts.
    legend = alt.Chart(df).mark_point().encode(
        y=alt.Y(f"{category}:N", axis=alt.Axis(orient="right")),
        color=color,
    ).add_selection(selection)

    # Reference rules at valence = 0 and arousal = 0.
    line_y = alt.Chart(pd.DataFrame({"y": [0]})).mark_rule().encode(y="y")
    line_x = alt.Chart(pd.DataFrame({"x": [0]})).mark_rule().encode(x="x")

    return (line_y + line_x + scatter) | legend

# Highest-median GEMS category per segment, on the valence/arousal plane.
anot_labeled = max_label(anot_na)
show_plot(anot_labeled, category="label_gems")

In [6]:
@widgets.interact(x=feature_cols, y=feature_cols)
def show_plot(x, y):
    """
    Interactive bubble chart of two rating columns from the global ``anot`` frame.

    ``x`` and ``y`` are chosen from ``feature_cols`` via dropdowns. Each bubble is
    one (x, y) value combination, sized by the number of rows sharing it.

    NOTE(review): this reuses the name ``show_plot`` from the previous cell, so
    that valence/arousal scatter is shadowed after this cell runs.
    NOTE(review): ``stdev()`` names no field, so it is unclear what the colour
    channel encodes. Confirm the intended column.
    """
    encoding = dict(
        x=alt.X(x + ":O"),
        y=alt.Y(y + ":O"),
        size="count():Q",
        color="stdev():Q",
    )
    return alt.Chart(anot).mark_circle().encode(**encoding).configure_axis(grid=True)

interactive(children=(Dropdown(description='x', options=('arousal', 'valence', 'gems_wonder', 'gems_transcende…