# Sequence Learning - Word Training - English - Testing Session
In this session, we will look into the working status of our HMRNN-based AE and test:   
    1. whether the segmentation is okayish  
    2. by plotting the progression of each sub-phoneme segment and try to grab some info from it  

In [3]:
import torch
import torchaudio
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, pack_sequence
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np
# import csv
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import homogeneity_completeness_v_measure
import pickle
from paths import *
from my_utils import *
from recorder import *
from padding import generate_mask_from_lengths_mat, mask_it, masked_loss
# import pytz
from datetime import datetime
# import random
# import gc

In [4]:
from model import PhonLearn_Net

### Dirs

In [5]:
model_save_dir = model_eng_save_dir
# random_data:phone_seg_random_path
# anno_data: phone_seg_anno_path

# random_log_path = phone_seg_random_log_path + "log.csv"
random_log_path = word_seg_anno_log_path
random_path = word_seg_anno_path
anno_log_path = phone_seg_anno_path

### Constants

In [6]:
# 规范用语；规定两种方式：全加载；按rec加载（舍弃了按chunk加载，处理起来更简单）
# RandomPhoneDataset; AnnoPhoneDataset; AnnoSeqDataset

In [7]:
class AnnoWordWholeDataset(Dataset):
    """
    A PyTorch dataset that loads cutted wave files from disk and returns input-output pairs for
    training autoencoder. 
    
    Version 3: wav -> mel
    """
    
    def __init__(self, load_dir, load_control_path, transform=None):
        """
        Initializes the class by reading a CSV file and merging the "rec" and "idx" columns.

        The function reads the CSV file from the provided control path, extracts the "rec" and "idx" columns,
        and concatenates the values from these columns using an underscore. It then appends the ".wav" extension
        to each of the merged strings and converts the merged pandas Series to a list, which is assigned to
        the 'dataset' attribute of the class.

        Args:
        load_dir (str): The directory containing the files to load.
        load_control_path (str): The path to the CSV file containing the "rec" and "idx" columns.

        Attributes:
        dataset (list): A list of merged strings from the "rec" and "idx" columns, with the ".wav" extension.
        """
        control_file = pd.read_csv(load_control_path)
        control_file = control_file[control_file['n_frames'] > 400]
        control_file = control_file[control_file['duration'] <= 2.0]
        
        # Extract the "rec" and "idx" columns
        rec_col = control_file['rec'].astype(str)
        idx_col = control_file['idx'].astype(str).str.zfill(8)
        
        # Merge the two columns by concatenating the strings with '_' and append extension name
        merged_col = rec_col + '_' + idx_col + ".wav"
        
        self.dataset = merged_col.tolist()
        self.load_dir = load_dir
        self.transform = transform
        
    
    def __len__(self):
        """
        Returns the length of the dataset.
        
        Returns:
            int: The number of input-output pairs in the dataset.
        """
        return len(self.dataset)
    
    def __getitem__(self, idx):
        """
        Returns a tuple (input_data, output_data) for the given index.

        The function first checks if the provided index is a tensor, and if so, converts it to a list.
        It then constructs the file path for the .wav file using the dataset attribute and the provided index.
        The .wav file is loaded using torchaudio, and its data is normalized. If a transform is provided,
        the data is transformed using the specified transform. Finally, the input_data and output_data are
        set to the same data (creating a tuple), and the tuple is returned.

        Args:
        idx (int or torch.Tensor): The index of the desired data.

        Returns:
        tuple: A tuple containing input_data and output_data, both of which are the audio data
               from the .wav file at the specified index.

        Note: 
        This function assumes that the class has the following attributes:
        - self.load_dir (str): The directory containing the .wav files.
        - self.dataset (list): A list of .wav file names.
        - self.transform (callable, optional): An optional transform to apply to the audio data.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        wav_name = os.path.join(self.load_dir,
                                self.dataset[idx])
        
        data, sample_rate = torchaudio.load(wav_name, normalize=True)
        if self.transform:
            data = self.transform(data, sr=sample_rate)
        
        # # Prepare for possible in-out discrepencies in the future
        # input_data = data
        # output_data = data
        
        return data

def collate_fn(xx):
    # only working for one data at the moment
    batch_first = True
    x_lens = [len(x) for x in xx]
    xx_pad = pad_sequence(xx, batch_first=batch_first, padding_value=0)
    return xx_pad, x_lens


class MyTransform(nn.Module): 
    def __init__(self, sample_rate, n_fft): 
        super().__init__()
        # self.transform = torchaudio.transforms.MelSpectrogram(sample_rate, n_fft=n_fft, n_mels=64)
        # self.to_db = torchaudio.transforms.AmplitudeToDB()
        # self.transform = torchaudio.transforms.MFCC(n_mfcc=13)
    
    def forward(self, waveform, sr=16000): 
        # extract mfcc
        feature = torchaudio.compliance.kaldi.mfcc(waveform, sample_frequency=sr)

        # add deltas
        d1 = torchaudio.functional.compute_deltas(feature)
        d2 = torchaudio.functional.compute_deltas(d1)
        feature = torch.cat([feature, d1, d2], dim=-1)

        # Apply normalization (CMVN)
        eps = 1e-9
        mean = feature.mean(0, keepdim=True)
        std = feature.std(0, keepdim=True, unbiased=False)
        # print(feature.shape)
        # print(mean, std)
        feature = (feature - mean) / (std + eps)

        # mel_spec = self.transform(waveform)
        # # mel_spec = self.to_db(mel_spec)
        # mel_spec = mel_spec.squeeze()
        # mel_spec = mel_spec.permute(1, 0) # (F, L) -> (L, F)
        return feature

In [8]:
EPOCHS = 10
BATCH_SIZE = 128

# SEGMENTS_IN_CHUNK = 100  # set_size

# INPUT_DIM = 128
# OUTPUT_DIM = 128

INPUT_DIM = 39
OUTPUT_DIM = 13

INTER_DIM_0 = 32
INTER_DIM_1 = 16
INTER_DIM_2 = 8
INTER_DIM_3 = 3

SIZE_LIST = [INTER_DIM_1, INTER_DIM_2]

DROPOUT = 0.5

REC_SAMPLE_RATE = 16000
N_FFT = 400

LOADER_WORKER = 8

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
recon_loss = nn.MSELoss(reduction='none')
# model = TwoRNNAttn(1.0, SIZE_LIST, in_size=INPUT_DIM, 
#                       in2_size=INTER_DIM_0, hid_size=INTER_DIM_3, out_size=OUTPUT_DIM)
model = PhonLearn_Net(1.0, SIZE_LIST, in_size=INPUT_DIM, 
                      in2_size=INTER_DIM_0, hid_size=INTER_DIM_3, out_size=OUTPUT_DIM)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [10]:
model

PhonLearn_Net(
  (encoder): Encoder(
    (lin_1): LinearPack(
      (linear): Linear(in_features=39, out_features=32, bias=True)
      (relu): Tanh()
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (rnn): HM_LSTM(
      (cell_1): HM_LSTMCell()
      (cell_2): HM_LSTMCell()
    )
    (lin_2): LinearPack(
      (linear): Linear(in_features=8, out_features=3, bias=True)
      (relu): Tanh()
      (dropout): Dropout(p=0.5, inplace=False)
    )
  )
  (decoder): Decoder(
    (lin_1): LinearPack(
      (linear): Linear(in_features=13, out_features=8, bias=True)
      (relu): Tanh()
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (rnn): LSTM(8, 32, batch_first=True)
    (attention): ScaledDotProductAttention(
      (w_q): Linear(in_features=32, out_features=32, bias=True)
      (w_k): Linear(in_features=3, out_features=32, bias=True)
      (w_v): Linear(in_features=3, out_features=32, bias=True)
    )
    (lin_2): LinearPack(
      (linear): Linear(in_features=32, out_feature

In [11]:
device

device(type='cuda')

In [12]:
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])

In [13]:
params

13066

In [14]:
# Just for keeping records of training hists. 
ts = str(get_timestamp())
# ts = "0620172459"
save_txt_name = "train_txt_{}.hst".format(ts)
save_trainhist_name = "train_hist_{}.hst".format(ts)
save_valhist_name = "val_hist_{}.hst".format(ts)

In [15]:
valid_losses = LossRecorder(model_save_dir + save_valhist_name)
train_losses = LossRecorder(model_save_dir + save_trainhist_name)
text_hist = HistRecorder(model_save_dir + save_txt_name)

In [16]:
# READ = False
READ = True

In [17]:
if READ: 
    # valid_losses.read()
    # train_losses.read()

    # model_name = last_model_name
    model_name = "PT_0620172459_25_full.pt"
    model_path = os.path.join(model_save_dir, model_name)
    state = torch.load(model_path)
    model = PhonLearn_Net(1.0, SIZE_LIST, in_size=INPUT_DIM, 
                      in2_size=INTER_DIM_0, hid_size=INTER_DIM_3, out_size=OUTPUT_DIM)
    model.load_state_dict(state)
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

In [18]:
mytrans = MyTransform(sample_rate=REC_SAMPLE_RATE, n_fft=N_FFT)

In [19]:
test_wave, sr = torchaudio.load("s2201b_00000935.wav")

In [20]:
testMFCC = mytrans(test_wave)

In [21]:
testMFCC = testMFCC.unsqueeze(0)

In [22]:
mask = torch.ones((testMFCC.size(0), testMFCC.size(1)))

In [23]:
testMFCC = testMFCC.to(device)
mask = mask.to(device)

In [24]:
testout, z1, z2, h_2 = model.encode(testMFCC, mask)

In [25]:
z1.sum()

tensor(28., device='cuda:0')

In [29]:
torch.sum(z1, dim=(1, 2)).shape

torch.Size([1])

In [24]:
h_2

tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000],
         [-0.1940, -0.0320, -0.0912, -0.1086,  0.1649, -0.1969, -0.0430,
           0.0038],
         [-0.1940, -0.0320, -0.0912, -0.1086,  0.1649, -0.1969, -0.0430,
           0.0038],
         [-0.2573, -0.0402, -0.1558, -0.1350,  0.1660, -0.2730, -0.1038,
           0.1004],
         [-0.2788, -0.0502, -0.1857, -0.1453,  0.1439, -0.3296, -0.1363,
           0.1360],
         [-0.2678, -0.0521, -0.1794, -0.1274,  0.1039, -0.3238, -0.1822,
           0.1629],
         [-0.2048, -0.0855, -0.0401, -0.1142,  0.0841, -0.2590, -0.1721,
           0.2236],
         [-0.2272, -0.1069,  0.0527, -0.1168,  0.0382, -0.2404, -0.2292,
           0.2997],
         [-0.2275, -0.1348,  0.1122, -0.1235, -0.0073, -0.2245, -0.2834,
           0.3573],
         [-0.2044, -0.1625,  0.1398, -0.1443, -0.0608, -0.2236, -0.334

In [25]:
testout

tensor([[[ 0.0000, -0.0806,  0.0000],
         [ 0.0000, -0.0806,  0.2843],
         [ 0.3472, -0.0000,  0.0000],
         [ 0.0000, -0.0017,  0.3534],
         [ 0.4092, -0.0583,  0.0000],
         [ 0.4520, -0.0590,  0.0000],
         [ 0.3712, -0.0531,  0.0000],
         [ 0.1336, -0.1509,  0.3174],
         [-0.0000, -0.2219,  0.2558],
         [-0.3430, -0.2595,  0.2376],
         [-0.0000, -0.2733,  0.2224],
         [-0.5613, -0.2793,  0.0000],
         [-0.5737, -0.0000,  0.1620],
         [-0.6126, -0.0000,  0.0000],
         [-0.6677, -0.2427,  0.1063],
         [-0.7462, -0.2522,  0.0000],
         [-0.0000, -0.1878, -0.0000],
         [-0.0000, -0.0000, -0.1147],
         [-0.0000, -0.0000, -0.0000],
         [-0.0000, -0.0000, -0.0000],
         [-0.0000, -0.0532, -0.3803],
         [-0.0000, -0.0184, -0.4731],
         [-0.5566,  0.0030, -0.5576],
         [-0.5470,  0.0132, -0.5724],
         [-0.5258,  0.0278, -0.5560],
         [-0.5440,  0.0000, -0.0000],
         [-0

In [81]:
import plotly.express as px
from sklearn.decomposition import PCA

In [108]:
def oneOut2ProgFrame(oneOut): 
    # oneOut is of tensor of shape (L, D)
    df = pd.DataFrame(oneOut, columns=["dim_0", "dim_1", "dim_2"])
    df["timestep"] = df.index
    df = df[["timestep", "dim_0", "dim_1", "dim_2"]]
    return df
def minmax(arr, a=-1, b=1): 
    min = arr.min()
    max = arr.max()
    return (b - a) * ((arr - min) / (max - min)) + a
def operate_on(arr): 
    return minmax(arr)
    # return arr
def framify(these_hids): 
    # these are token categories to be included
    # these hids are the corresponding hids
    # these numtags are the corresponding tags, named using indices in these
    # these_hids = st.zscore(these_hids, axis=0)
    df = pd.DataFrame(data=these_hids)
    # df = df.rename(columns={0: "dim_0", 1: "dim_1", 2: "dim_2"})
    df['dim_0_norm'] = operate_on(df['dim_0'])
    df['dim_1_norm'] = operate_on(df['dim_1'])
    df['dim_2_norm'] = operate_on(df['dim_2'])
    return df

In [109]:
def plot3d(X): 
    df = framify(X)
    config = {
    'toImageButtonOptions': {
        'format': 'png', # one of png, svg, jpeg, webp
        'filename': 'custom_image',
        'height': 1280,
        'width': 1280,
        'scale': 1 # Multiply title/legend/axis/canvas sizes by this factor
    }
    }
    fig = px.scatter_3d(df, x="dim_0_norm", y="dim_1_norm", z="dim_2_norm", animation_frame="timestep")
                # color='IPA')
    fig.update_traces(marker=dict(size=2),
                    selector=dict(mode='markers'))
    fig.update_layout(
        scene = dict(
            xaxis = dict(nticks=8, range=[-1,1],),
                        yaxis = dict(nticks=8, range=[-1,1],),
                        zaxis = dict(nticks=8, range=[-1,1],),),)
    fig.update_layout(legend= {'itemsizing': 'constant'})
    fig.update_layout(legend_title_text='Phone')
    fig.update_layout(
        legend=dict(
            x=0,
            y=1,
            title_font_family="Times New Roman",
            font=dict(
                family="Times New Roman",
                size=36,
                color="black"
            ),
            # bgcolor="LightSteelBlue",
            bordercolor="Black",
            borderwidth=1
        )
    )
    fig.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),
    )
    camera = dict(
        eye=dict(x=0., y=0., z=2.5)
    )
    fig.update_layout(scene_camera=camera)
    # html_plot = fig.to_html(full_html=False, config=config)
    fig.show(config=config)
    # return html_plot
    return 

In [111]:
testoutcpu = h_2.cpu().detach()[:, :, :3]
progf = oneOut2ProgFrame(testoutcpu[0])
# framify(progf)

In [91]:
h2cpu = h_2.cpu().detach()[0]

In [92]:
pca = PCA(n_components=2)

In [94]:
X = pca.fit_transform(h2cpu)

In [95]:
df = pd.DataFrame(X, columns=["dim_0", "dim_1"])
df["timestep"] = df.index
df = df[["timestep", "dim_0", "dim_1"]]

In [None]:
df = pd.DataFrame(data=df)
df = df.rename(columns={0: "dim_0", 1: "dim_1"})
df['dim_0_norm'] = operate_on(df['dim_0'])
df['dim_1_norm'] = operate_on(df['dim_1'])
df['dim_2_norm'] = operate_on(df['dim_2'])

In [53]:
config = {
'toImageButtonOptions': {
    'format': 'png', # one of png, svg, jpeg, webp
    'filename': 'custom_image',
    'height': 1280,
    'width': 1280,
    'scale': 1 # Multiply title/legend/axis/canvas sizes by this factor
}
}
fig = px.scatter(framify(X), x="dim_0_norm", y="dim_1_norm", z="dim_2_norm", animation_frame="timestep")
fig.update_traces(marker=dict(size=2),
                selector=dict(mode='markers'))
fig.update_layout(
    scene = dict(
        xaxis = dict(nticks=8, range=[-1,1],),
                    yaxis = dict(nticks=8, range=[-1,1],),
                    zaxis = dict(nticks=8, range=[-1,1],),),)
fig.update_layout(legend= {'itemsizing': 'constant'})
fig.update_layout(legend_title_text='Phone')
fig.update_layout(
    legend=dict(
        x=0,
        y=1,
        title_font_family="Times New Roman",
        font=dict(
            family="Times New Roman",
            size=36,
            color="black"
        ),
        # bgcolor="LightSteelBlue",
        bordercolor="Black",
        borderwidth=1
    )
)
fig.update_layout(
    margin=dict(l=0, r=0, t=0, b=0),
)
camera = dict(
    eye=dict(x=0., y=0., z=2.5)
)
fig.update_layout(scene_camera=camera)
# html_plot = fig.to_html(full_html=False, config=config)
fig.show(config=config)