# Sequence Learning - Word Training - English - Testing Session - Plots
In this session, we inspect how our HMRNN-based autoencoder (AE) is working and plot:
- the progression plot (how the hidden representations `hid_r` evolve along the timeline)

In [38]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import torch
import torchaudio
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split

from paths import *
from my_utils import *
from padding import generate_mask_from_lengths_mat, mask_it, masked_loss

In [39]:
# from model import PhonLearn_Net
from model import PhonLearn_Net
# DirectPassModel, TwoRNNModel, TwoRNNAttn

### Dirs

In [40]:
# Where trained checkpoints are read from.
model_save_dir = model_eng_save_dir

# Despite the "random_" prefix, both names point at the *annotated* word segments.
# (The random-segment alternatives would be phone_seg_random_path / phone_seg_random_log_path.)
random_path, random_log_path = word_seg_anno_path, word_seg_anno_log_path

# NOTE(review): not referenced anywhere in this notebook, and its value's name suggests a
# data path rather than a log path — confirm whether this alias is intended.
anno_log_path = phone_seg_anno_path

### Dataset and Feature Transform

In [41]:
class AnnoWordWholeDatasetPlot(Dataset):
    """
    A PyTorch dataset of word-level .wav clips, used to plot hidden-state trajectories.

    Each item is the 5-tuple ``(data, info, info_rec, info_idx, info_token)``:
    the audio loaded from ``load_dir`` (passed through ``transform`` when one is
    given), followed by the row's ``produced_segments``, ``rec``, zero-padded
    ``idx`` and ``token`` values from the control CSV, all as strings.

    Version 3: the feature type is whatever ``transform`` produces
    (this notebook passes an MFCC transform).
    """

    def __init__(self, load_dir, load_control_path, transform=None,
                 min_frames=400, max_duration=2.0):
        """
        Read and filter the control CSV and build the per-item file/metadata lists.

        Rows are kept only if ``n_frames > min_frames`` and ``duration <= max_duration``;
        either filter is skipped when its threshold is None. For every kept row the
        file name is ``"<rec>_<idx zero-padded to 8 digits>.wav"``.

        Args:
        load_dir (str): The directory containing the .wav files.
        load_control_path (str or file-like): CSV with "rec", "idx", "token",
            "produced_segments", "n_frames" and "duration" columns.
        transform (callable, optional): Called as ``transform(waveform, sr=sample_rate)``.
        min_frames (int or None): Exclusive lower bound on "n_frames" (default 400).
        max_duration (float or None): Inclusive upper bound on "duration" (default 2.0;
            presumably seconds — confirm against whatever wrote the CSV).

        Attributes:
        dataset (list[str]): .wav file names, one per kept row.
        infoset (list[str]): "produced_segments" values.
        info_rec_set (list[str]): "rec" values.
        info_idx_set (list[str]): "idx" values, zero-padded to 8 characters.
        info_token_set (list[str]): "token" values.
        load_dir, transform: the constructor arguments.
        """
        control_file = pd.read_csv(load_control_path)
        if min_frames is not None:
            control_file = control_file[control_file['n_frames'] > min_frames]
        if max_duration is not None:
            control_file = control_file[control_file['duration'] <= max_duration]

        # Metadata columns, all normalised to strings; idx is zero-padded so it matches
        # the on-disk file naming scheme.
        rec_col = control_file['rec'].astype(str)
        idx_col = control_file['idx'].astype(str).str.zfill(8)
        token_col = control_file['token'].astype(str)
        produced_segments_col = control_file['produced_segments'].astype(str)

        # "<rec>_<idx>.wav"
        merged_col = rec_col + '_' + idx_col + ".wav"

        self.dataset = merged_col.tolist()
        self.infoset = produced_segments_col.tolist()
        self.load_dir = load_dir
        self.transform = transform
        self.info_rec_set = rec_col.tolist()
        self.info_idx_set = idx_col.tolist()
        self.info_token_set = token_col.tolist()

    def __len__(self):
        """
        Returns:
            int: The number of clips that survived the CSV filters.
        """
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Load one clip and return it together with its metadata.

        Args:
        idx (int or torch.Tensor): Position in the (filtered) dataset.

        Returns:
        tuple: ``(data, info, info_rec, info_idx, info_token)`` where ``data`` is the
        waveform returned by ``torchaudio.load(..., normalize=True)`` (transformed when
        a transform was given) and the remaining items are the metadata strings built
        in ``__init__`` ("produced_segments", "rec", padded "idx", "token").
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        wav_name = os.path.join(self.load_dir, self.dataset[idx])

        data, sample_rate = torchaudio.load(wav_name, normalize=True)
        if self.transform:
            data = self.transform(data, sr=sample_rate)

        return (data,
                self.infoset[idx],
                self.info_rec_set[idx],
                self.info_idx_set[idx],
                self.info_token_set[idx])

def collate_fn(data):
    """Batch ``(feature, info, rec, idx, token)`` samples from AnnoWordWholeDatasetPlot.

    Features are zero-padded along their first dimension to the longest item and
    stacked batch-first.

    Returns:
        (padded, lengths, infos, recs, idxs, tokens): ``lengths`` is a list holding each
        feature's ``len()`` before padding; the metadata fields are tuples in batch order.
    """
    features, infos, recs, idxs, tokens = zip(*data)
    lengths = list(map(len, features))
    padded = pad_sequence(features, batch_first=True, padding_value=0)
    return padded, lengths, infos, recs, idxs, tokens


class MyTransform(nn.Module):
    """Waveform -> per-utterance-normalised MFCC + delta + delta-delta features.

    ``forward`` returns a tensor shaped (num_frames, 3 * num_ceps). With torchaudio's
    Kaldi defaults (13 cepstra) this is 39 columns, matching INPUT_DIM.
    """

    def __init__(self, sample_rate, n_fft, delta_over_time=False):
        """
        Args:
            sample_rate: Accepted for interface compatibility; unused (forward takes ``sr``).
            n_fft: Accepted for interface compatibility; unused (Kaldi MFCC does its own framing).
            delta_over_time (bool): If True, compute deltas along the time (frame) axis.
                The default False reproduces the historical behaviour (see ``forward``),
                which existing checkpoints were presumably trained on.
        """
        super().__init__()
        self.delta_over_time = delta_over_time

    def forward(self, waveform, sr=16000):
        """Compute normalised MFCC(+deltas) for ``waveform`` sampled at ``sr`` Hz."""
        # kaldi.mfcc returns (num_frames, num_ceps).
        feature = torchaudio.compliance.kaldi.mfcc(waveform, sample_frequency=sr)

        # compute_deltas differentiates along the LAST dimension. On a (frames, ceps)
        # tensor that is the cepstral axis, not time. This is kept as the default so
        # features stay compatible with already-trained models. delta_over_time=True
        # transposes first to get conventional temporal deltas.
        if self.delta_over_time:
            d1_t = torchaudio.functional.compute_deltas(feature.transpose(0, 1))
            d2_t = torchaudio.functional.compute_deltas(d1_t)
            d1, d2 = d1_t.transpose(0, 1), d2_t.transpose(0, 1)
        else:
            d1 = torchaudio.functional.compute_deltas(feature)
            d2 = torchaudio.functional.compute_deltas(d1)
        feature = torch.cat([feature, d1, d2], dim=-1)

        # Per-utterance CMVN: normalise every column over the frames of this clip.
        eps = 1e-9
        mean = feature.mean(0, keepdim=True)
        std = feature.std(0, keepdim=True, unbiased=False)
        feature = (feature - mean) / (std + eps)
        return feature

In [42]:
# --- Loop / loader settings (EPOCHS is not referenced in this plotting notebook) ---
EPOCHS = 10
BATCH_SIZE = 1          # infer() only uses item 0 of each batch
LOADER_WORKER = 0

# --- Feature sizes ---
# 39 = 13 Kaldi MFCCs + deltas + delta-deltas (see MyTransform); 13 = decoder output width.
INPUT_DIM, OUTPUT_DIM = 39, 13

# --- Network sizes for PhonLearn_Net ---
INTER_DIM_0, INTER_DIM_1, INTER_DIM_2, INTER_DIM_3 = 16, 8, 3, 3
SIZE_LIST = [INTER_DIM_1, INTER_DIM_2]   # passed as PhonLearn_Net's size list
DROPOUT = 0.5

# --- Audio front-end (N_FFT is accepted by MyTransform but not used there) ---
REC_SAMPLE_RATE = 16000
N_FFT = 400

In [43]:
# Prefer the GPU when one is visible.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Freshly initialised network; replaced by the checkpoint below when READ is True.
# nn.Module.to() moves parameters in place and returns the module itself.
model = PhonLearn_Net(1.0, SIZE_LIST, in_size=INPUT_DIM,
                      in2_size=INTER_DIM_0, hid_size=INTER_DIM_3, out_size=OUTPUT_DIM).to(device)

# Training-time objects; not used by the plotting code in this notebook.
recon_loss = nn.MSELoss(reduction='none')
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [44]:
model  # display the module tree of the current model

PhonLearn_Net(
  (encoder): Encoder(
    (lin_1): LinearPack(
      (linear): Linear(in_features=39, out_features=16, bias=True)
      (relu): Tanh()
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (rnn): HM_LSTM(
      (cell_1): HM_LSTMCell()
      (cell_2): HM_LSTMCell()
    )
  )
  (decoder): Decoder(
    (lin_1): LinearPack(
      (linear): Linear(in_features=13, out_features=3, bias=True)
      (relu): Tanh()
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (rnn): LSTM(3, 16, batch_first=True)
    (attention): ScaledDotProductAttention(
      (w_q): Linear(in_features=16, out_features=16, bias=True)
      (w_k): Linear(in_features=3, out_features=16, bias=True)
      (w_v): Linear(in_features=3, out_features=16, bias=True)
    )
    (lin_2): LinearPack(
      (linear): Linear(in_features=16, out_features=13, bias=True)
      (relu): Tanh()
      (dropout): Dropout(p=0.5, inplace=False)
    )
  )
)

In [45]:
# Set to False to keep the freshly initialised model instead of loading a checkpoint.
READ = True

In [46]:
if READ:
    # Checkpoint to visualise; its raw name is also used to prefix the saved HTML plots.
    model_raw_name = "PT_0623152604_35_full"
    model_name = model_raw_name + ".pt"
    model_path = os.path.join(model_save_dir, model_name)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa).
    # torch.load unpickles the file, so only point it at trusted checkpoints.
    state = torch.load(model_path, map_location=device)
    model = PhonLearn_Net(1.0, SIZE_LIST, in_size=INPUT_DIM,
                          in2_size=INTER_DIM_0, hid_size=INTER_DIM_3, out_size=OUTPUT_DIM)
    model.load_state_dict(state)
    model.to(device)

In [47]:
mytrans = MyTransform(sample_rate=REC_SAMPLE_RATE, n_fft=N_FFT)
ds = AnnoWordWholeDatasetPlot(random_path, os.path.join(random_log_path, "log.csv"), transform=mytrans)

# Hold out ~0.05% of the clips (1 - 0.9995); only this held-out slice is plotted below.
train_len = int(0.9995 * len(ds))
valid_len = len(ds) - train_len

# Seed the split so that re-running the notebook plots the same held-out words.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_ds, valid_ds = random_split(ds, [train_len, valid_len], generator=split_generator)

# Inference only: a fixed order makes runs comparable. Each clip is saved to its own
# file, so order only affects which files exist if the loop is interrupted.
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=LOADER_WORKER, collate_fn=collate_fn)
valid_num = len(valid_loader.dataset)

In [48]:
len(valid_loader)  # number of batches; with BATCH_SIZE = 1 this is the number of clips to plot

143

In [49]:
def oneOut2ProgFrame(oneOut):
    """Turn one sequence of hidden states into a progression DataFrame.

    Args:
        oneOut: 2-D array-like of shape (L, D), one row per time step. In this
            notebook it is a NumPy array (``hid_r[0]``) with D == 3.

    Returns:
        DataFrame with an integer ``timestep`` column (0 .. L-1) followed by
        ``dim_0`` .. ``dim_{D-1}``.
    """
    values = np.asarray(oneOut)
    columns = ["dim_{}".format(i) for i in range(values.shape[1])]
    df = pd.DataFrame(values, columns=columns)
    df.insert(0, "timestep", df.index)
    return df
def minmax(arr, a=-1, b=1):
    """Linearly rescale ``arr`` so its minimum maps to ``a`` and its maximum to ``b``.

    Works on anything with ``.min()``/``.max()`` and arithmetic (NumPy arrays,
    pandas Series, tensors). If every value is equal, the range is zero. Instead of
    dividing 0/0 (which gives NaN), every element is mapped to ``a``.
    """
    lo = arr.min()
    hi = arr.max()
    span = hi - lo
    if span == 0:
        return (arr - lo) + a
    return (b - a) * ((arr - lo) / span) + a
def operate_on(arr):
    """Per-axis post-processing hook used by ``framify``; currently the identity.

    Returns ``arr`` itself (no copy). Swap the body for ``minmax(arr)`` to rescale
    each axis into [-1, 1].
    """
    processed = arr  # no normalisation applied at the moment
    return processed
def framify(these_hids, transform=None):
    """Return a copy of a progression frame with ``dim_<k>_norm`` columns (k = 0, 1, 2) added.

    Args:
        these_hids: DataFrame (or DataFrame-constructible data) containing ``dim_0``,
            ``dim_1`` and ``dim_2`` columns, e.g. the output of ``oneOut2ProgFrame``.
        transform: Callable applied to each ``dim_k`` column to produce ``dim_k_norm``.
            Defaults to the module-level ``operate_on``.

    Returns:
        A new DataFrame; the caller's frame is left unmodified.
    """
    if transform is None:
        transform = operate_on
    # Explicit copy: pd.DataFrame(existing_df) may share data with the input, and the
    # columns added below must not leak back into the caller's frame.
    df = pd.DataFrame(data=these_hids).copy()
    for col in ("dim_0", "dim_1", "dim_2"):
        df[col + "_norm"] = transform(df[col])
    return df

In [50]:
# info_rec, info_idx, info_token, info_produce_segs

In [51]:
# def plot3d(X): 
#     config = {
#     'toImageButtonOptions': {
#         'format': 'png', # one of png, svg, jpeg, webp
#         'filename': 'custom_image',
#         'height': 1280,
#         'width': 1280,
#         'scale': 1 # Multiply title/legend/axis/canvas sizes by this factor
#     }
#     }
#     fig = px.scatter_3d(framify(X), x="dim_0_norm", y="dim_1_norm", z="dim_2_norm", animation_frame="timestep")
#     fig.update_traces(marker=dict(size=2),
#                     selector=dict(mode='markers'))
#     fig.update_layout(
#         scene = dict(
#             xaxis = dict(nticks=8, range=[-1,1],),
#                         yaxis = dict(nticks=8, range=[-1,1],),
#                         zaxis = dict(nticks=8, range=[-1,1],),),)
#     fig.update_layout(legend= {'itemsizing': 'constant'})
#     # fig.update_layout(legend_title_text='Phone')
#     fig.update_layout(
#         legend=dict(
#             x=0,
#             y=1,
#             title_font_family="Times New Roman",
#             font=dict(
#                 family="Times New Roman",
#                 size=36,
#                 color="black"
#             ),
#             # bgcolor="LightSteelBlue",
#             bordercolor="Black",
#             borderwidth=1
#         )
#     )
#     fig.update_layout(
#         margin=dict(l=0, r=0, t=0, b=0),
#     )
#     camera = dict(
#         eye=dict(x=0., y=0., z=2.5)
#     )
#     fig.update_layout(scene_camera=camera)
#     html_plot = fig.to_html(full_html=False, config=config)
#     # fig.show(config=config)
#     return html_plot

In [53]:
def plot3dtrajectory(X):
    """Render the hidden-state trajectory in ``X`` as an embeddable plotly HTML fragment.

    ``X`` is passed through ``framify``. The resulting ``dim_0_norm``/``dim_1_norm``/``dim_2_norm``
    columns are drawn as one connected 3-D line. Every scene axis is fixed to [-1, 1] with
    8 ticks, margins are removed, and the camera looks down the z axis.

    Returns:
        str: ``fig.to_html(full_html=False, ...)``. The PNG export button is configured
        for a 1280x1280 image.
    """
    export_config = {
        'toImageButtonOptions': {
            'format': 'png',  # one of png, svg, jpeg, webp
            'filename': 'custom_image',
            'height': 1280,
            'width': 1280,
            'scale': 1,  # multiply title/legend/axis/canvas sizes by this factor
        }
    }

    frame = framify(X)
    fig = px.line_3d(frame, x="dim_0_norm", y="dim_1_norm", z="dim_2_norm")

    # Same fixed range on every axis so plots of different words are comparable.
    fixed_axes = {axis: dict(nticks=8, range=[-1, 1]) for axis in ("xaxis", "yaxis", "zaxis")}
    fig.update_layout(scene=fixed_axes)

    legend_font = dict(family="Times New Roman", size=36, color="black")
    fig.update_layout(
        legend=dict(x=0, y=1, title_font_family="Times New Roman", font=legend_font,
                    bordercolor="Black", borderwidth=1)
    )
    fig.update_layout(margin=dict(l=0, r=0, t=0, b=0))
    fig.update_layout(scene_camera=dict(eye=dict(x=0., y=0., z=2.5)))

    return fig.to_html(full_html=False, config=export_config)

In [54]:
def save_html(htmlplot, info_rec, info_idx, info_token, info_produce_segs, model_serialnum="", save_dir=None):
    """Write a plot plus a small metadata header to ``<save_dir>/<serial>_<rec>_<idx>.html``.

    Args:
        htmlplot (str): HTML fragment to embed (e.g. from ``plot3dtrajectory``).
        info_rec, info_idx: Recording id and clip index. Callers pass the dataset's
            already zero-padded idx, so no extra padding is applied here.
        info_token: Word token, shown in the header.
        info_produce_segs: Produced segments, shown in the header.
        model_serialnum (str): File-name prefix identifying the model.
        save_dir (str, optional): Output directory; defaults to ``word_plot_path``.
    """
    import html  # local import keeps the block self-contained

    if save_dir is None:
        save_dir = word_plot_path
    rec_name = "{}_{}".format(info_rec, info_idx)
    save_html_path = os.path.join(save_dir, "{}_{}.html".format(model_serialnum, rec_name))
    # The page declares UTF-8, so write it as UTF-8 explicitly. The platform default
    # encoding (e.g. cp1252 on Windows) cannot encode many phonetic symbols.
    with open(save_html_path, "w", encoding="utf-8") as f:
        f.write('<meta charset="UTF-8">')
        # Escape metadata so characters like '<' or '&' in tokens/segments render as text.
        f.write("<h3>Rec: {}</h3>".format(html.escape(rec_name, quote=False)))
        f.write("<h3>Token: {}</h3>".format(html.escape(str(info_token), quote=False)))
        f.write("<h3>Produced Segments: {}</h3>".format(html.escape(str(info_produce_segs), quote=False)))
        f.write("<hr>")
        f.write(htmlplot)

In [55]:
def infer(model_num=""):
    """Save one HTML trajectory plot per clip in ``valid_loader``.

    For each batch, the encoder's ``hid_r`` output for the first item (the only one,
    since BATCH_SIZE is 1) is turned into a progression frame and rendered with
    ``plot3dtrajectory``. The result is written with ``save_html``, and the batch
    index is printed as progress.

    Args:
        model_num (str): File-name prefix for the saved plots. When empty (the
            default), the global ``model_raw_name`` set by the checkpoint cell is used.
    """
    serial = model_num or model_raw_name
    model.eval()
    # Pure inference: no autograd graph needed, which saves memory and time.
    with torch.no_grad():
        for idx, (x, x_lens, info, info_rec, info_idx, info_token) in enumerate(valid_loader):
            # collate_fn returns metadata as tuples; take the single entry of the batch.
            info = info[0]
            info_rec = info_rec[0]
            info_idx = info_idx[0]
            info_token = info_token[0]

            x_mask = generate_mask_from_lengths_mat(x_lens, device=device)
            x = x.to(device)

            hid_r, _, _ = model.encode(x, x_mask)

            hid_r = hid_r.detach().cpu().numpy()
            res_df = oneOut2ProgFrame(hid_r[0])  # first (only) item in the batch
            # plot3dtrajectory applies framify itself, so it is not called here.
            htmlplot = plot3dtrajectory(res_df)
            save_html(htmlplot, info_rec, info_idx, info_token, info, model_serialnum=serial)
            print(idx)
        

In [56]:
# Inside a notebook kernel __name__ is "__main__", so this runs whenever the cell runs.
if __name__ == "__main__": 
    infer()

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101


KeyboardInterrupt: 