# Sequence Learning - Direct - English - Testing Session - Phoneme Plots
20230921: Progression plots of single phonemes.  
Written after finding that the model failed to learn representations of phonemes and instead only remembered their order 
of occurrence. This is bad, since what we wanted to see was that the model could separate the different phonemes (at least phones).   
However, it remains unclear whether the model has learned the conditional and unconditional distributions of 
phonemes, despite not learning them separately. That is, we want to know whether the model has learned to place phonemes
that tend to occur in very distinct phonological contexts into distinct regions, OR whether the model can only distinguish them 
when they appear in context. For example: is /ng/ projected to the ONSET community when fed alone, or to the CODA community, 
as it is when fed within words? 

In [61]:
import torch
import torchaudio
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import random

In [62]:
from padding import generate_mask_from_lengths_mat, mask_it
from paths import *
from my_utils import *
from loss import *
from model import SimplerPhxLearner
from dataset import SeqDatasetAnno, MelSpecTransform
from my_dataset import DS_Tools
from reshandler import AnnoEncoderResHandler
from misc_progress_bar import draw_progress_bar

## Preps

### Dirs

In [63]:
# Checkpoint directory, plus the phone-segment annotation locations:
# log_path holds log.csv and rec_path is the root handed to SeqDatasetAnno (Dataset cell).
model_save_dir = model_eng_save_dir
log_path, rec_path = phone_seg_anno_log_path, phone_seg_anno_path

### Constants

In [64]:
# Run settings.
EPOCHS, BATCH_SIZE = 10, 1

# Layer widths: input/output feature width, then the encoder/decoder bottleneck chain.
INPUT_DIM = OUTPUT_DIM = 64
INTER_DIM_0, INTER_DIM_1, INTER_DIM_2 = 32, 16, 3

ENC_SIZE_LIST = [INPUT_DIM] + [INTER_DIM_0, INTER_DIM_1, INTER_DIM_2]
DEC_SIZE_LIST = [OUTPUT_DIM] + [INTER_DIM_0, INTER_DIM_1, INTER_DIM_2]

DROPOUT = 0.5

# Mel-spectrogram front end.
REC_SAMPLE_RATE, N_FFT, N_MELS = 16000, 400, 64

# DataLoader worker processes.
LOADER_WORKER = 16

## Model

### Model-related defs

In [65]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# reduction='none' keeps per-element losses; MaskedLoss presumably applies the
# padding mask before reducing (see loss.py — confirm).
recon_loss = nn.MSELoss(reduction="none")
masked_recon_loss = MaskedLoss(recon_loss)
model_loss = masked_recon_loss

# nn.Module.to() returns the module itself, so construction and placement chain.
model = SimplerPhxLearner(
    enc_size_list=ENC_SIZE_LIST,
    dec_size_list=DEC_SIZE_LIST,
    num_layers=2,
).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

### Load Model

In [66]:
# Checkpoint to load: training timestamp and the epoch at which it was saved.
load_ts, stop_epoch = "0918192113", "299"

In [67]:
model_raw_name = "PT_{}_{}_full".format(load_ts, stop_epoch)
model_name = model_raw_name + ".pt"
model_path = os.path.join(model_save_dir, model_name)
# map_location remaps tensors onto the current device, so a checkpoint saved from a
# GPU run still loads when `device` fell back to 'cpu'.
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
state = torch.load(model_path, map_location=device)

model.load_state_dict(state)
model.to(device)

SimplerPhxLearner(
  (encoder): RLEncoder(
    (rnn): LSTM(64, 16, num_layers=2, batch_first=True)
    (lin_2): LinearPack(
      (linear): Linear(in_features=16, out_features=3, bias=True)
      (relu): Tanh()
    )
  )
  (decoder): RALDecoder(
    (rnn): LSTM(64, 3, num_layers=2, batch_first=True)
    (attention): ScaledDotProductAttention(
      (w_q): Linear(in_features=3, out_features=3, bias=True)
      (w_k): Linear(in_features=3, out_features=3, bias=True)
      (w_v): Linear(in_features=3, out_features=3, bias=True)
    )
    (lin_3): LinearPack(
      (linear): Linear(in_features=3, out_features=64, bias=True)
      (relu): Tanh()
    )
  )
)

### Dataset

- Note that because the word and phone datasets are set up separately, we cannot restrict this test to samples the model 
has not been trained on. This remains to be fixed; see our first work for reference. 

In [8]:
mytrans = MelSpecTransform(sample_rate=REC_SAMPLE_RATE, n_fft=N_FFT, n_mels=N_MELS)
ds = SeqDatasetAnno(rec_path, os.path.join(log_path, "log.csv"), transform=mytrans)

# valid_ds_indices = DS_Tools.read_indices(os.path.join(model_save_dir, "valid_ds_{}.pkl".format(load_ts)))

# valid_ds = torch.utils.data.Subset(ds, valid_ds_indices)

# this is to reduce the size of the dataset when the training power is not sufficient
small_len = int(0.05 * len(ds))
other_len = len(ds) - small_len

# Randomly take a 5% subset. NOTE: it is drawn from the whole dataset, not from a held-out
# split (see the note above). A fixed seed keeps the subset, and therefore the results that
# infer() saves under model_raw_name, identical across runs.
split_seed = 0
valid_ds, other_ds = random_split(
    ds,
    [small_len, other_len],
    generator=torch.Generator().manual_seed(split_seed),
)

valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=LOADER_WORKER, collate_fn=SeqDatasetAnno.collate_fn)
valid_num = len(valid_loader.dataset)

In [9]:
# Number of batches in valid_loader; with BATCH_SIZE = 1 this equals the number of items in valid_ds.
len(valid_loader)

41640

## Inference

In [10]:
def infer(): 
    """Encode every item of ``valid_loader`` and save the per-frame hidden states.

    Batches are assumed to hold one item (``BATCH_SIZE`` is 1), so only the first
    token/name of each batch is used. The encoder output for each item is flattened
    to ``(frames, hidden_dim)`` and stacked. Every frame is labelled with its item's
    token and name, and the result is saved through ``AnnoEncoderResHandler`` under
    ``phone_plot_res_path`` / ``model_raw_name``. A guide CSV with one
    ``Name, Token`` row per item is also written next to it.
    """
    model.eval()
    reshandler = AnnoEncoderResHandler(whole_res_dir=phone_plot_res_path, file_prefix=model_raw_name)
    res_chunks = []  # one (frames, hidden_dim) array per item; concatenated once at the end
    all_token = []
    all_name = []

    log_token = []  # for making csv to map sound to filename
    log_name = []

    total = len(valid_loader)

    # Inference only: skip building an autograd graph for every batch.
    with torch.no_grad():
        for idx, (x, x_lens, token, name) in enumerate(valid_loader): 
            token = token[0]
            name = name[0]

            x_mask = generate_mask_from_lengths_mat(x_lens, device=device)
            x = x.to(device)

            hid_r = model.encode(x, x_lens, x_mask)
            hid_r = hid_r.cpu().detach().numpy()
            # Flatten explicitly to (frames, hidden_dim). A bare squeeze() would also drop
            # the time axis of a single-frame item, and the concatenation below would fail.
            # Assumes the output is (1, frames, hidden_dim) — TODO confirm against model.encode.
            hid_r = hid_r.reshape(-1, hid_r.shape[-1])

            length = hid_r.shape[0]

            res_chunks.append(hid_r)
            all_token += [token] * length
            all_name += [name] * length
            log_token.append(token)
            log_name.append(name)

            if idx % 100 == 0: 
                draw_progress_bar(idx, total)

    # A single concatenation avoids re-copying the growing array on every iteration.
    all_res = np.concatenate(res_chunks, axis=0) if res_chunks else np.empty((0, 3))

    reshandler.res = all_res
    reshandler.tok = all_token
    reshandler.name = all_name
    reshandler.save()

    # save guideline
    df = pd.DataFrame({'Name': log_name, 'Token': log_token})
    df.to_csv(os.path.join(phone_plot_res_path, model_raw_name + '_guide.csv'), index=False) 

In [11]:
# Run inference over valid_loader, saving the encoder outputs and the Name/Token guide CSV.
# Inside a notebook __name__ is "__main__", so this cell calls infer() directly.
if __name__ == "__main__": 
    infer()



## Plotter

In [68]:
def manyOuts2progFrame(manyOuts, groups): 
    """Build a frame of stacked trajectories with a per-recording ``timestep``.

    Args:
        manyOuts: 2-D array-like with three columns, one row per frame, in which the
            frames of each recording appear in temporal order.
        groups: per-row recording name, the same length as ``manyOuts``.

    Returns:
        DataFrame with ``dim_0``, ``dim_1``, ``dim_2``, ``name`` and a 1-based
        ``timestep`` counted within each ``name``. Rows are grouped by name.
    """
    df = pd.DataFrame(manyOuts, columns=["dim_0", "dim_1", "dim_2"])
    df["name"] = groups

    # The sort must be stable: frames of one recording have to keep their original
    # (temporal) order, or cumcount() below attaches timesteps to the wrong frames.
    # pandas' default kind ("quicksort") is not stable.
    df = df.sort_values(by=["name"], kind="mergesort")
    # Number the frames within each recording, starting at 1.
    df['timestep'] = df.groupby("name").cumcount() + 1
    return df

In [69]:
def oneOut2ProgFrame(oneOut): 
    """Turn one recording's hidden trajectory into a frame with a 0-based ``timestep``.

    ``oneOut`` is a 2-D array-like with three columns and one row per frame. The result
    has the columns ``timestep, dim_0, dim_1, dim_2``, with timestep equal to the row
    position.
    """
    frame = pd.DataFrame(oneOut, columns=["dim_0", "dim_1", "dim_2"])
    frame.insert(0, "timestep", frame.index)
    return frame
def minmax(arr, a=-1, b=1): 
    """Linearly rescale ``arr`` so its minimum maps to ``a`` and its maximum to ``b``.

    Works element-wise on anything that supports ``.min()``, ``.max()`` and arithmetic
    (NumPy arrays, pandas Series). When all values are equal the range is zero. Instead
    of dividing by zero (NaN/inf output), every element is then mapped to the midpoint
    ``(a + b) / 2``.
    """
    # lo/hi rather than min/max, so the builtins are not shadowed.
    lo = arr.min()
    hi = arr.max()
    span = hi - lo
    if span == 0:
        # arr * 0 keeps the container type (ndarray / Series) and shape.
        return arr * 0 + (a + b) / 2
    return (b - a) * ((arr - lo) / span) + a
def operate_on(arr): 
    """Hook applied to every ``dim_*`` column before plotting.

    This is currently the identity: the same object is returned unchanged. Swap in e.g.
    ``minmax(arr)`` to rescale each column into [-1, 1].
    """
    result = arr  # normalisation disabled: plots use the raw hidden values
    return result
def framify(these_hids): 
    """Wrap hidden states in a DataFrame and add ``dim_{0,1,2}_norm`` plot columns.

    ``these_hids`` is either a DataFrame that already has ``dim_0``/``dim_1``/``dim_2``
    columns (any other columns, such as ``timestep`` or ``name``, are kept) or a raw
    ``(N, 3)`` array. For a raw array, the positional columns 0/1/2 are renamed to
    ``dim_0``..``dim_2``, the same way ``framify_group`` does.
    Each ``*_norm`` column is ``operate_on`` applied to the matching ``dim_*`` column.
    """
    df = pd.DataFrame(data=these_hids)
    # rename() ignores labels that are absent, so this is a no-op for frames that
    # already use dim_* names and only affects raw arrays with positional columns.
    df = df.rename(columns={0: "dim_0", 1: "dim_1", 2: "dim_2"})
    df['dim_0_norm'] = operate_on(df['dim_0'])
    df['dim_1_norm'] = operate_on(df['dim_1'])
    df['dim_2_norm'] = operate_on(df['dim_2'])
    return df

In [70]:
def framify_group(these_hids, these_tags): 
    """Build a plotting frame from hidden states and per-row class tags.

    The positional columns 0/1/2 of ``these_hids`` become ``dim_0``..``dim_2``. Then
    ``dim_*_norm`` columns (``operate_on`` of each dim) are added, followed by a ``Tag``
    column holding ``these_tags``.
    """
    frame = pd.DataFrame(data=these_hids).rename(columns={0: "dim_0", 1: "dim_1", 2: "dim_2"})
    for axis_idx in range(3):
        frame["dim_{}_norm".format(axis_idx)] = operate_on(frame["dim_{}".format(axis_idx)])
    frame["Tag"] = these_tags
    return frame
def plot3dGroup(X, y): 
    """Return embeddable HTML for a 3-D scatter of hidden states ``X`` coloured by ``y``.

    ``X`` and ``y`` go through ``framify_group``, and points are coloured by the ``Tag``
    column. All three axes share a fixed [-1, 1] range, which matches the Tanh output
    layer in the model printout, so separate plots stay comparable.
    """
    export_cfg = {
        'toImageButtonOptions': {
            'format': 'png',  # one of png, svg, jpeg, webp
            'filename': 'custom_image',
            'height': 1280,
            'width': 1280,
            'scale': 1,  # Multiply title/legend/axis/canvas sizes by this factor
        }
    }
    frame = framify_group(X, y)
    fig = px.scatter_3d(frame, x="dim_0_norm", y="dim_1_norm", z="dim_2_norm", color='Tag')
    fig.update_traces(marker=dict(size=2), selector=dict(mode='markers'))

    fig.update_layout(scene={axis: dict(nticks=8, range=[-1, 1]) for axis in ("xaxis", "yaxis", "zaxis")})
    fig.update_layout(legend={'itemsizing': 'constant'})
    fig.update_layout(legend_title_text='Class')
    legend_font = dict(family="Times New Roman", size=36, color="black")
    fig.update_layout(legend=dict(x=0, y=1,
                                  title_font_family="Times New Roman",
                                  font=legend_font,
                                  bordercolor="Black",
                                  borderwidth=1))
    fig.update_layout(margin=dict(l=0, r=0, t=0, b=0))
    # Look straight down the z axis.
    fig.update_layout(scene_camera=dict(eye=dict(x=0., y=0., z=2.5)))
    return fig.to_html(full_html=False, config=export_cfg)

In [71]:
def plot3dtrajectory(X): 
    """Return embeddable HTML for one 3-D trajectory, drawn as red markers joined in row order.

    ``X`` goes through ``framify`` and must provide ``dim_0``..``dim_2`` plus a
    ``timestep`` column, which is shown on hover. All axes are fixed to [-1, 1].
    """
    export_cfg = {
        'toImageButtonOptions': {
            'format': 'png',  # one of png, svg, jpeg, webp
            'filename': 'custom_image',
            'height': 1280,
            'width': 1280,
            'scale': 1,  # Multiply title/legend/axis/canvas sizes by this factor
        }
    }
    frame = framify(X)
    fig = px.line_3d(frame, x="dim_0_norm", y="dim_1_norm", z="dim_2_norm",
                     hover_data=["timestep"], markers=True)
    fig.update_traces(marker=dict(size=2, color="red"))

    fig.update_layout(scene={axis: dict(nticks=8, range=[-1, 1]) for axis in ("xaxis", "yaxis", "zaxis")})
    legend_font = dict(family="Times New Roman", size=36, color="black")
    fig.update_layout(legend=dict(x=0, y=1,
                                  title_font_family="Times New Roman",
                                  font=legend_font,
                                  bordercolor="Black",
                                  borderwidth=1))
    fig.update_layout(margin=dict(l=0, r=0, t=0, b=0))
    # Look straight down the z axis.
    fig.update_layout(scene_camera=dict(eye=dict(x=0., y=0., z=2.5)))
    return fig.to_html(full_html=False, config=export_cfg)

In [96]:
def plot3dtrajectoryAndGroup(X): 
    """Return embeddable HTML for several 3-D trajectories, one line per ``name``.

    ``X`` goes through ``framify`` and must provide ``dim_0``..``dim_2``, ``timestep``
    (shown on hover) and ``name`` (one trace and colour per recording). All axes are
    fixed to [-1, 1].
    """
    export_cfg = {
        'toImageButtonOptions': {
            'format': 'png',  # one of png, svg, jpeg, webp
            'filename': 'custom_image',
            'height': 1280,
            'width': 1280,
            'scale': 1,  # Multiply title/legend/axis/canvas sizes by this factor
        }
    }
    frame = framify(X)
    fig = px.line_3d(frame, x="dim_0_norm", y="dim_1_norm", z="dim_2_norm",
                     hover_data=["timestep"], color="name", markers=True)
    # Only shrink the markers. Forcing a single marker colour here would override the
    # per-name colours assigned by color="name" and make recordings indistinguishable.
    fig.update_traces(marker=dict(size=2))

    fig.update_layout(scene={axis: dict(nticks=8, range=[-1, 1]) for axis in ("xaxis", "yaxis", "zaxis")})
    legend_font = dict(family="Times New Roman", size=36, color="black")
    fig.update_layout(legend=dict(x=0, y=1,
                                  title_font_family="Times New Roman",
                                  font=legend_font,
                                  bordercolor="Black",
                                  borderwidth=1))
    fig.update_layout(margin=dict(l=0, r=0, t=0, b=0))
    # Look straight down the z axis.
    fig.update_layout(scene_camera=dict(eye=dict(x=0., y=0., z=2.5)))
    return fig.to_html(full_html=False, config=export_cfg)

In [72]:
def save_html_traj(htmlplot, save_name, token, model_serialnum=""): 
    """Write a trajectory plot to ``phone_plot_path`` as an HTML page.

    The file is named ``traj_{token}-{model_serialnum}_{save_name}.html``. The page
    starts with a small header (recording name and token) followed by ``htmlplot``.
    """
    save_html_path = os.path.join(phone_plot_path, "traj_{}-{}_{}.html".format(token, model_serialnum, save_name))
    # The page declares UTF-8, so write it as UTF-8 regardless of the platform's default encoding.
    with open(save_html_path, "w", encoding="utf-8") as f: 
        f.write('<meta charset="UTF-8">')
        f.write("<h3>Rec: {}</h3>".format(save_name))
        f.write("<h3>Token: {}</h3>".format(token))
        f.write("<hr>")
        f.write(htmlplot)

In [73]:
def save_html_phone(htmlplot, save_name, token, model_serialnum=""): 
    """Write a phoneme scatter plot to ``phone_plot_path`` as an HTML page.

    The file is named ``phone_{save_name}_{model_serialnum}.html``. The page starts with
    a small header (save name and token) followed by ``htmlplot``.
    """
    save_html_path = os.path.join(phone_plot_path, "phone_{}_{}.html".format(save_name, model_serialnum))
    # The page declares UTF-8, so write it as UTF-8 regardless of the platform's default encoding.
    with open(save_html_path, "w", encoding="utf-8") as f: 
        f.write('<meta charset="UTF-8">')
        f.write("<h3>Rec: {}</h3>".format(save_name))
        f.write("<h3>Token: {}</h3>".format(token))
        f.write("<hr>")
        f.write(htmlplot)

# Selective Plot

### Load Res

In [74]:
# Load what infer() saved (same directory and file prefix it used). read() presumably
# fills reshandler.res / .tok / .name, which the cells below use — confirm in reshandler.py.
reshandler = AnnoEncoderResHandler(whole_res_dir=phone_plot_res_path, file_prefix=model_raw_name)
reshandler.read()

In [76]:
def select(data, guide, selector):
    """Return the rows of ``data`` whose label in ``guide`` is one of ``selector``.

    ``guide`` holds one label per row of ``data`` (first axis). ``data`` must support
    boolean-mask indexing, e.g. a NumPy array.

    Raises:
        ValueError: if ``guide`` and ``data`` differ in length.
    """
    expected = len(guide)
    if data.shape[0] != expected:
        raise ValueError("The length of guide must match the number of items in the data.")
    return data[np.isin(guide, selector)]

### Plot by Phoneme

In [50]:
# get usable cluster groups
cluster_groups = ["aa", "uw"]

hidr_cs = select(data=reshandler.res, 
                 guide=reshandler.tok, 
                 selector=cluster_groups)
tags_cs = select(data=np.array(reshandler.tok), 
                 guide=reshandler.tok, 
                 selector=cluster_groups)

htmlplot = plot3dGroup(hidr_cs, tags_cs)
save_html_phone(htmlplot=htmlplot, save_name="{}".format("-".join(cluster_groups)), token=" ".join(cluster_groups), model_serialnum=model_raw_name)

### Plot by rec

In [36]:
# Guide written by infer(): one Name/Token row per encoded item.
guide_df = pd.read_csv(os.path.join(phone_plot_res_path, model_raw_name + "_guide.csv"))
total_len = len(guide_df)

all_name, all_token = guide_df["Name"], guide_df["Token"]

In [60]:
# Pick one recording uniformly at random. randrange excludes its upper bound, whereas
# random.randint(0, total_len) could return total_len, which is past the last row.
randidx = random.randrange(total_len)
save_name = all_name[randidx]
token = all_token[randidx]
# All frames of that recording, plotted as a single trajectory.
hidr_cs = select(data=reshandler.res, 
                 guide=reshandler.name, 
                 selector=[save_name])
res_df = oneOut2ProgFrame(hidr_cs)
res_df = framify(res_df)
htmlplot = plot3dtrajectory(res_df)
save_html_traj(htmlplot, save_name, token=token, model_serialnum=model_raw_name)

### Plot by phoneme and rec

In [89]:
# Guide written by infer(); names are kept as a NumPy array so they can be passed to select().
guide_df = pd.read_csv(os.path.join(phone_plot_res_path, model_raw_name + "_guide.csv"))
total_len = len(guide_df)

all_name, all_token = guide_df["Name"].to_numpy(), guide_df["Token"]

In [106]:
target_token = "aan"

# Names of every recording whose guide token equals target_token.
names_by_token = select(guide=all_token, selector=[target_token], data=all_name)
names_by_token.shape

(8,)

In [107]:
subset_size = 8
# Sampling without replacement cannot draw more names than exist. Cap the request so a
# token with fewer recordings plots all of them instead of raising ValueError.
subset_size = min(subset_size, len(names_by_token))
random_subset_names_by_token = np.random.choice(names_by_token, size=subset_size, replace=False)

In [108]:
# Frames of the sampled recordings, plus the recording name of each frame.
hidr_cs, tags_cs = (
    select(data=reshandler.res, guide=reshandler.name, selector=random_subset_names_by_token),
    select(data=np.array(reshandler.name), guide=reshandler.name, selector=random_subset_names_by_token),
)

# One row per frame, with a 1-based timestep counted within each recording.
res_df = manyOuts2progFrame(hidr_cs, tags_cs)

In [109]:
# Add the *_norm plot columns, draw one coloured line per recording, and save the page.
res_df = framify(res_df)
htmlplot = plot3dtrajectoryAndGroup(res_df)
save_html_traj(
    htmlplot=htmlplot,
    save_name=f"rand-{subset_size}",
    token=target_token,
    model_serialnum=model_raw_name,
)