# Sequence Learning - Direct - English - Testing Session - Plots
In this session, we check how well our LSTM autoencoder (AE) is working and plot:  
- the progression plot (how the hidden representations `hid_r` progress along the timeline)

In [1]:
import torch
import torchaudio
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px

In [2]:
from padding import generate_mask_from_lengths_mat, mask_it
from paths import *
from my_utils import *
from loss import *
from model import SimplerPhxLearner
from seqinfo_dataset import SeqDatasetInfo, collate_fn, MyTransform
from my_dataset import DS_Tools
from reshandler import EncoderResHandler

### Dirs

In [3]:
# Directory aliases; the right-hand names come from the star-imported `paths` module.
# Checkpoints (PT_*.pt) and the saved validation indices (valid_ds_*.pkl) are read from here.
model_save_dir = model_eng_save_dir

# Word-segment data folder and the folder containing its log.csv.
random_log_path, random_path = word_seg_anno_log_path, word_seg_anno_path
# NOTE(review): despite its name this is bound to phone_seg_anno_path (not a *_log_path)
# and it is not used in the visible cells — confirm this is intended.
anno_log_path = phone_seg_anno_path

### Constants

In [4]:
# NOTE(review): EPOCHS and DROPOUT are not used in the visible cells of this testing notebook.
EPOCHS = 10
BATCH_SIZE = 1  # infer() only reads item [0] of every batch

# Encoder input width and decoder-side width.
INPUT_DIM, OUTPUT_DIM = 39, 13

# Intermediate widths; the last one (3) is the size of the trajectory plotted in 3-D.
INTER_DIM_0, INTER_DIM_1, INTER_DIM_2 = 64, 16, 3

ENC_SIZE_LIST = [INPUT_DIM, INTER_DIM_0, INTER_DIM_1, INTER_DIM_2]
# Same intermediate widths as the encoder, but starting from OUTPUT_DIM.
DEC_SIZE_LIST = [OUTPUT_DIM] + ENC_SIZE_LIST[1:]

DROPOUT = 0.5

# Audio front-end settings handed to MyTransform.
REC_SAMPLE_RATE, N_FFT = 16000, 400

# DataLoader num_workers (0 = load in the main process).
LOADER_WORKER = 0

In [5]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Element-wise MSE (reduction='none') wrapped by the project's MaskedLoss,
# presumably so padded frames can be excluded — TODO confirm in loss.py.
recon_loss = nn.MSELoss(reduction='none')
model_loss = masked_recon_loss = MaskedLoss(recon_loss)

# Rebuild the architecture; trained weights are loaded in a later cell.
model = SimplerPhxLearner(
    enc_size_list=ENC_SIZE_LIST,
    dec_size_list=DEC_SIZE_LIST,
    num_layers=2,
).to(device)
# NOTE(review): never stepped in the visible cells; presumably kept to mirror the training setup.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [6]:
# Checkpoint to evaluate: run identifier (used in the PT_*.pt and valid_ds_*.pkl
# file names) and the epoch number of the saved weights.
load_ts, stop_epoch = "0908015948", "248"

In [7]:
# Load the trained weights of run `load_ts` at epoch `stop_epoch` into `model`.
model_raw_name = "PT_{}_{}_full".format(load_ts, stop_epoch)
model_name = model_raw_name + ".pt"
model_path = os.path.join(model_save_dir, model_name)
# map_location remaps tensors that were saved on a GPU onto `device`; without it
# torch.load raises on a CPU-only machine, even though `device` falls back to CPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state = torch.load(model_path, map_location=device)

model.load_state_dict(state)
model.to(device)

SimplerPhxLearner(
  (encoder): RLEncoder(
    (rnn): LSTM(39, 16, num_layers=2, batch_first=True)
    (lin_2): LinearPack(
      (linear): Linear(in_features=16, out_features=3, bias=True)
      (relu): Tanh()
    )
  )
  (decoder): RALDecoder(
    (rnn): LSTM(13, 3, num_layers=2, batch_first=True)
    (attention): ScaledDotProductAttention(
      (w_q): Linear(in_features=3, out_features=3, bias=True)
      (w_k): Linear(in_features=3, out_features=3, bias=True)
      (w_v): Linear(in_features=3, out_features=3, bias=True)
    )
    (lin_3): LinearPack(
      (linear): Linear(in_features=3, out_features=13, bias=True)
      (relu): Tanh()
    )
  )
)

In [8]:
# Display the loaded model's architecture.
model

SimplerPhxLearner(
  (encoder): RLEncoder(
    (rnn): LSTM(39, 16, num_layers=2, batch_first=True)
    (lin_2): LinearPack(
      (linear): Linear(in_features=16, out_features=3, bias=True)
      (relu): Tanh()
    )
  )
  (decoder): RALDecoder(
    (rnn): LSTM(13, 3, num_layers=2, batch_first=True)
    (attention): ScaledDotProductAttention(
      (w_q): Linear(in_features=3, out_features=3, bias=True)
      (w_k): Linear(in_features=3, out_features=3, bias=True)
      (w_v): Linear(in_features=3, out_features=3, bias=True)
    )
    (lin_3): LinearPack(
      (linear): Linear(in_features=3, out_features=13, bias=True)
      (relu): Tanh()
    )
  )
)

In [9]:
# Feature transform and the full word-segment dataset (with its log.csv).
mytrans = MyTransform(sample_rate=REC_SAMPLE_RATE, n_fft=N_FFT)
ds = SeqDatasetInfo(random_path, os.path.join(random_log_path, "log.csv"), transform=mytrans)

# Restrict to the validation split saved for run `load_ts`.
# NOTE(review): the indices come from a .pkl file — if read_indices unpickles it, only use trusted files.
valid_ds_indices = DS_Tools.read_indices(
    os.path.join(model_save_dir, "valid_ds_{}.pkl".format(load_ts))
)
valid_ds = torch.utils.data.Subset(ds, valid_ds_indices)

# For a quick debug run, random_split `valid_ds` into a small fraction (e.g. 0.2%)
# and the rest, and build the loader from the small part instead.

# shuffle=False: items are visited in the same fixed order every run.
valid_loader = DataLoader(
    valid_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=LOADER_WORKER,
    collate_fn=collate_fn,
)
valid_num = len(valid_loader.dataset)

In [10]:
# Number of validation batches; with BATCH_SIZE = 1 this equals the number of validation items.
len(valid_loader)

113

In [11]:
def oneOut2ProgFrame(oneOut):
    """Turn one sequence of hidden representations into a progression frame.

    Parameters
    ----------
    oneOut : array-like of shape (L, D)
        Per-timestep hidden representation of a single batch item (e.g. ``hid_r[0]``
        as a NumPy array). Must be 2-D; D is no longer required to be 3.

    Returns
    -------
    pandas.DataFrame
        Columns ``timestep, dim_0, ..., dim_{D-1}``, where ``timestep`` is the
        0-based row position.

    Raises
    ------
    ValueError
        If ``oneOut`` is not 2-D.
    """
    values = np.asarray(oneOut)
    if values.ndim != 2:
        raise ValueError("oneOut must be 2-D (L, D), got shape {}".format(values.shape))
    # Derive column names from D instead of hard-coding three of them.
    dim_cols = ["dim_{}".format(i) for i in range(values.shape[1])]
    df = pd.DataFrame(values, columns=dim_cols)
    # timestep goes first, before the dim_* columns.
    df.insert(0, "timestep", df.index)
    return df
def minmax(arr, a=-1, b=1):
    """Linearly rescale ``arr`` so its minimum maps to ``a`` and its maximum to ``b``.

    Parameters
    ----------
    arr : numpy.ndarray or pandas.Series
        Values to rescale (anything supporting ``.min()``, ``.max()`` and arithmetic).
    a, b : float
        Target range (default [-1, 1]).

    Returns
    -------
    Same type as ``arr``, rescaled. If every value is equal (zero range), every
    element maps to the midpoint ``(a + b) / 2``. Previously this case divided
    by zero and produced NaN.
    """
    # lo/hi instead of min/max so the builtins are not shadowed.
    lo = arr.min()
    hi = arr.max()
    span = hi - lo
    if span == 0:
        # arr * 0 keeps the container type (ndarray / Series) and its shape.
        return arr * 0 + (a + b) / 2
    return (b - a) * ((arr - lo) / span) + a
def operate_on(arr):
    """Normalization hook that ``framify`` applies to each ``dim_*`` column.

    Currently the identity (``arr`` is returned unchanged). Switch to the
    commented-out ``minmax(arr)`` to rescale each column to [-1, 1] instead.
    """
    # return minmax(arr)
    return arr
def framify(these_hids):
    """Return a copy of a trajectory frame with ``dim_{0,1,2}_norm`` columns added.

    Parameters
    ----------
    these_hids : pandas.DataFrame or array-like
        Hidden representations with columns ``dim_0``, ``dim_1`` and ``dim_2``
        (e.g. the output of ``oneOut2ProgFrame``). A plain (L, 3) array is also
        accepted: its default integer columns 0/1/2 are renamed to ``dim_0..2``.

    Returns
    -------
    pandas.DataFrame
        A new frame holding all input columns plus ``dim_0_norm``, ``dim_1_norm``
        and ``dim_2_norm`` (each computed by ``operate_on``). The input is never modified.
    """
    # Explicit copy: depending on the pandas version, pd.DataFrame(df) can share the
    # caller's data, and adding columns to it could then leak back into the caller's frame.
    df = pd.DataFrame(data=these_hids).copy()
    # No-op for frames that already use dim_* names; names raw (L, 3) arrays.
    df = df.rename(columns={0: "dim_0", 1: "dim_1", 2: "dim_2"})
    for col in ("dim_0", "dim_1", "dim_2"):
        df[col + "_norm"] = operate_on(df[col])
    return df

In [12]:
def plot3dtrajectory(X):
    """Render one hidden-state trajectory as an interactive 3-D line plot.

    ``X`` is passed through ``framify``. Its ``dim_0_norm``/``dim_1_norm``/
    ``dim_2_norm`` columns become x/y/z, and its ``timestep`` column (e.g. as
    produced by ``oneOut2ProgFrame``) is shown on hover.

    Returns
    -------
    str
        An HTML fragment (``full_html=False``) embedding the figure. The
        toolbar's image button exports a 1280x1280 PNG named ``custom_image``.
    """
    fig = px.line_3d(
        framify(X),
        x="dim_0_norm",
        y="dim_1_norm",
        z="dim_2_norm",
        hover_data=["timestep"],
        markers=True,
    )
    fig.update_traces(marker=dict(size=2, color="red"))

    # The same fixed [-1, 1] range, with 8 ticks, on every axis.
    fig.update_layout(
        scene={axis: dict(nticks=8, range=[-1, 1]) for axis in ("xaxis", "yaxis", "zaxis")}
    )

    # Legend styling. NOTE(review): with a single ungrouped trace plotly presumably
    # shows no legend, so this likely has no visible effect.
    legend_font = dict(family="Times New Roman", size=36, color="black")
    fig.update_layout(
        legend=dict(
            x=0,
            y=1,
            title_font_family="Times New Roman",
            font=legend_font,
            bordercolor="Black",
            borderwidth=1,
        )
    )
    fig.update_layout(margin=dict(l=0, r=0, t=0, b=0))

    # Camera placed on the +z axis, looking down onto the x-y plane.
    fig.update_layout(scene_camera=dict(eye=dict(x=0., y=0., z=2.5)))

    export_config = {
        'toImageButtonOptions': {
            'format': 'png',  # one of png, svg, jpeg, webp
            'filename': 'custom_image',
            'height': 1280,
            'width': 1280,
            'scale': 1,  # multiplies title/legend/axis/canvas sizes by this factor
        }
    }
    return fig.to_html(full_html=False, config=export_config)

In [13]:
def save_html(htmlplot, info_rec, info_idx, info_token, info_produce_segs, model_serialnum="", save_dir=None):
    """Write one plot, preceded by a small metadata header, to an HTML file.

    Parameters
    ----------
    htmlplot : str
        HTML fragment of the figure (e.g. the return value of ``plot3dtrajectory``).
    info_rec, info_idx :
        Recording identifier and index; used in the file name and the "Rec" header.
    info_token :
        Token label; used in the file name and the "Token" header.
    info_produce_segs :
        Shown under "Produced Segments".
    model_serialnum : str, optional
        File-name prefix (``infer`` passes the checkpoint name).
    save_dir : str, optional
        Output directory; defaults to ``word_plot_path``.

    The file ``{model_serialnum}_{info_rec}_{info_idx}_{info_token}.html`` is
    written as UTF-8 to match the ``<meta charset="UTF-8">`` tag it starts with.
    Header values are HTML-escaped, so characters such as ``<`` or ``&`` in
    labels display literally instead of breaking the markup.
    """
    import html

    if save_dir is None:
        save_dir = word_plot_path
    # The former .zfill(8) on this name was a no-op: the format string alone is
    # already 8 characters long, so it has been dropped without changing names.
    file_name = "{}_{}_{}_{}.html".format(model_serialnum, info_rec, info_idx, info_token)
    rec_label = "{}_{}".format(info_rec, info_idx).zfill(8)
    with open(os.path.join(save_dir, file_name), "w", encoding="utf-8") as f:
        f.write('<meta charset="UTF-8">')
        f.write("<h3>Rec: {}</h3>".format(html.escape(rec_label)))
        f.write("<h3>Token: {}</h3>".format(html.escape(str(info_token))))
        f.write("<h3>Produced Segments: {}</h3>".format(html.escape(str(info_produce_segs))))
        f.write("<hr>")
        f.write(htmlplot)

In [14]:
def infer(model_num=""):
    """Encode every validation item, then save its hidden trajectory and plot.

    Each batch of ``valid_loader`` is processed as follows. Only item 0 of the
    batch is used, which matches BATCH_SIZE = 1.

    - The encoder output ``hid_r`` is set, together with ``(info_token, info)``,
      on an ``EncoderResHandler``. The handler is built with
      ``data_dir=word_plot_res_path`` and ``info_dir=word_plot_info_path``, and
      its ``save()`` method is then called.
    - ``hid_r`` is drawn as a 3-D trajectory and written to HTML by ``save_html``.

    The batch index is printed after each item.

    Parameters
    ----------
    model_num : str, optional
        Currently unused. Output names are prefixed with the global ``model_raw_name``.
    """
    model.eval()
    reshandler = EncoderResHandler(data_dir=word_plot_res_path, info_dir=word_plot_info_path)

    # Pure inference: no_grad stops autograd from recording a graph for every batch.
    with torch.no_grad():
        for idx, (x, x_lens, info, info_rec, info_idx, info_token) in enumerate(valid_loader):
            # Unwrap the single item of the batch.
            info = info[0]
            info_rec = info_rec[0]
            info_idx = info_idx[0]
            info_token = info_token[0]
            reshandler.file_prefix = "{}_{}_{}_{}".format(model_raw_name, info_rec, info_idx, info_token).zfill(8)

            x_mask = generate_mask_from_lengths_mat(x_lens, device=device)
            x = x.to(device)

            hid_r = model.encode(x, x_lens, x_mask)
            hid_r = hid_r.cpu().detach().numpy()

            # Feed in the data and info to be saved.
            reshandler.data = hid_r[0]
            reshandler.info = (info_token, info)

            # plot3dtrajectory applies framify itself, so the progression frame is
            # passed in directly (the previous extra framify call was redundant).
            res_df = oneOut2ProgFrame(hid_r[0])  # first (only) item in the batch
            htmlplot = plot3dtrajectory(res_df)
            save_html(htmlplot, info_rec, info_idx, info_token, info, model_serialnum=model_raw_name)
            reshandler.save()
            print(idx)
        

In [15]:
# In a notebook kernel __name__ is "__main__", so running this cell processes
# the whole validation set.
if __name__ == "__main__":
    infer(model_num="")

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112


# Thoughts / next steps
We want to test:
1. the degree to which the trajectory turns back on itself (turn-back points, 回头点) along the frames;
2. clustering of the points — open question: whether to supply the number of phonemes as the number of clusters;
3. the transitional distances (between transition points) and the within-cluster distances (between points in the same cluster).