# Sequence Learning - Direct - English - Testing Session - Plots
In this session, we check how well our LSTM autoencoder (AE) is working and plot:   
- the progression plot (how the hidden representations, `hid_r`, progress along the timeline)

In [14]:
import torch
import torchaudio
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px

In [11]:
from padding import generate_mask_from_lengths_mat, mask_it
from paths import *
from my_utils import *
from loss import *
from model import SimplerPhxLearner
from dataset import SeqDatasetInfo, MelSpecTransform
from my_dataset import DS_Tools
from reshandler import EncoderResHandler
from misc_progress_bar import draw_progress_bar

### Dirs

In [3]:
# Dataset root and log directory handed to SeqDatasetInfo when the dataset is built.
random_path = word_seg_anno_path
random_log_path = word_seg_anno_log_path

# NOTE(review): this is bound to `phone_seg_anno_path`, not a `*_log_path` name — confirm it is intended.
anno_log_path = phone_seg_anno_path

# Directory the checkpoint and the saved validation indices are read from.
model_save_dir = model_eng_save_dir

### Constants

In [4]:
EPOCHS, BATCH_SIZE = 10, 1

# Feature width at the model's input and output.
INPUT_DIM = OUTPUT_DIM = 64

# Intermediate layer widths, widest first.
INTER_DIM_0, INTER_DIM_1, INTER_DIM_2 = 32, 16, 8

# Size lists passed to SimplerPhxLearner for the encoder and the decoder.
ENC_SIZE_LIST = [
    INPUT_DIM,
    INTER_DIM_0,
    INTER_DIM_1,
    INTER_DIM_2,
]
DEC_SIZE_LIST = [
    OUTPUT_DIM,
    INTER_DIM_0,
    INTER_DIM_1,
    INTER_DIM_2,
]

DROPOUT = 0.5

# Parameters given to MelSpecTransform.
REC_SAMPLE_RATE, N_FFT, N_MELS = 16000, 400, 64

# Number of worker processes used by the DataLoader.
LOADER_WORKER = 16

In [5]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Element-wise MSE (no reduction) wrapped in MaskedLoss; the wrapped loss is
# reachable under both `masked_recon_loss` and `model_loss`.
recon_loss = nn.MSELoss(reduction='none')
model_loss = masked_recon_loss = MaskedLoss(recon_loss)

model = SimplerPhxLearner(
    enc_size_list=ENC_SIZE_LIST,
    dec_size_list=DEC_SIZE_LIST,
    num_layers=2,
)
model = model.to(device)

optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [6]:
# Timestamp and epoch that together identify the saved checkpoint to load.
load_ts, stop_epoch = "0926172946", "194"

In [7]:
# Checkpoint files are named "PT_<timestamp>_<epoch>_full.pt" inside model_save_dir.
model_raw_name = "PT_{}_{}_full".format(load_ts, stop_epoch)
model_name = model_raw_name + ".pt"
model_path = os.path.join(model_save_dir, model_name)

# map_location remaps the stored tensors onto the active device, so a checkpoint
# saved on a GPU machine still loads when only the CPU is available.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
state = torch.load(model_path, map_location=device)

model.load_state_dict(state)
# Last expression of the cell: moves the model to `device` and displays its structure.
model.to(device)

SimplerPhxLearner(
  (encoder): RLEncoder(
    (rnn): LSTM(64, 16, num_layers=2, batch_first=True)
    (lin_2): LinearPack(
      (linear): Linear(in_features=16, out_features=8, bias=True)
      (relu): Tanh()
    )
  )
  (decoder): RALDecoder(
    (rnn): LSTM(64, 8, num_layers=2, batch_first=True)
    (attention): ScaledDotProductAttention(
      (w_q): Linear(in_features=8, out_features=8, bias=True)
      (w_k): Linear(in_features=8, out_features=8, bias=True)
      (w_v): Linear(in_features=8, out_features=8, bias=True)
    )
    (lin_3): LinearPack(
      (linear): Linear(in_features=8, out_features=64, bias=True)
      (relu): Tanh()
    )
  )
)

In [8]:
# Fraction of the validation set that is actually processed, and the seed that
# fixes which samples are drawn.
SUBSET_FRACTION = 0.05
SUBSET_SEED = 0

mytrans = MelSpecTransform(sample_rate=REC_SAMPLE_RATE, n_fft=N_FFT, n_mels=N_MELS)
ds = SeqDatasetInfo(random_path, os.path.join(random_log_path, "log.csv"), transform=mytrans)

# Restrict the dataset to the indices stored in valid_ds_<load_ts>.pkl.
valid_ds_indices = DS_Tools.read_indices(os.path.join(model_save_dir, "valid_ds_{}.pkl".format(load_ts)))
valid_full_ds = torch.utils.data.Subset(ds, valid_ds_indices)

# Reduce the size of the dataset when the computing power is not sufficient.
small_len = int(SUBSET_FRACTION * len(valid_full_ds))
other_len = len(valid_full_ds) - small_len

# Randomly split the validation set into the small subset that is used and the
# remainder. A fixed-seed generator makes the draw repeatable, so re-running the
# notebook selects the same samples and writes the same result files.
split_generator = torch.Generator().manual_seed(SUBSET_SEED)
valid_ds, other_ds = random_split(valid_full_ds, [small_len, other_len], generator=split_generator)

valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=LOADER_WORKER, collate_fn=SeqDatasetInfo.collate_fn)
valid_num = len(valid_loader.dataset)

In [9]:
# Number of batches in the loader; with BATCH_SIZE = 1 this equals the number of samples.
len(valid_loader)

2840

In [12]:
def infer(model_num=""):
    """Encode every sample in ``valid_loader`` and save its hidden representation.

    Each batch is passed through ``model.encode``; the first item of the result
    is handed to an ``EncoderResHandler`` together with ``(info_token, info)``
    and saved under the prefix ``<model_raw_name>_<rec>_<idx>_<token>``.

    Only the first item of every batch is saved, so this relies on the loader
    yielding one sample per batch (BATCH_SIZE is 1 in this notebook).

    Args:
        model_num: Unused; kept so that existing calls remain valid.
    """
    model.eval()
    reshandler = EncoderResHandler(data_dir=word_plot_res_path, info_dir=word_plot_info_path)
    total = len(valid_loader)

    # Inference only: disabling autograd avoids building a graph for every sample.
    with torch.no_grad():
        for idx, (x, x_lens, info, info_rec, info_idx, info_token) in enumerate(valid_loader):
            # Unwrap the single sample's metadata from its batch container.
            info = info[0]
            info_rec = info_rec[0]
            info_idx = info_idx[0]
            info_token = info_token[0]
            reshandler.file_prefix = "{}_{}_{}_{}".format(model_raw_name, info_rec, info_idx, info_token).zfill(8)

            x_mask = generate_mask_from_lengths_mat(x_lens, device=device)

            x = x.to(device)

            hid_r = model.encode(x, x_lens, x_mask)

            hid_r = hid_r.detach().cpu().numpy()
            # feed in the data and info to be saved
            reshandler.data = hid_r[0]
            reshandler.info = (info_token, info)

            reshandler.save()
            # Refresh the progress bar every 10th sample only.
            if idx % 10 == 0:
                draw_progress_bar(idx, total)
        

In [13]:
# Run the full encoding pass when this code is executed as the main module.
if __name__ == "__main__": 
    infer()



# Thoughts
We want to test:   
1. the degree of returning (回头点, i.e. turn-back points) along the frames
2. clustering of the points: the open question is whether or not to provide the number of phonemes as the number of clusters
3. the transitional distances (between transition points) and the in-cluster distances (between points within one cluster)