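# Linear Model Trial
V1.0  
This version is the first end-to-end build of the whole system: it constructs a `HandshapePredictor`, optionally loads a trained checkpoint, and, for each clip, hand side, and smoothing mode, writes an HTML report showing the model's hidden representation and its handshape predictions.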

### Import Libs

In [1]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, random_split

from model_dataset import HandshapeDataset, HandshapeDict
from paths import *
from model_model import HandshapePredictor
from model_configs import *
from utils import *
from recorder import *
from graph_tools import GraphTool, Plotter, Smoother

### Init Model

In [2]:
# Build the predictor from the star-imported hyper-parameters (`in_dim`,
# `enc_lat_dims`, `hid_dim`, `dec_lat_dims`, `out_dim` -- presumably from
# `model_configs`), place it on the GPU when one is present, and create the
# loss function and optimizer.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
criterion = nn.CrossEntropyLoss()

predictor_kwargs = dict(
    input_dim=in_dim,
    enc_lat_dims=enc_lat_dims,
    hid_dim=hid_dim,
    dec_lat_dims=dec_lat_dims,
    output_dim=out_dim,
)
# nn.Module.to() moves parameters in place and returns the module itself.
model = HandshapePredictor(**predictor_kwargs).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [3]:
# Display the predictor's layer structure as this cell's output.
model

HandshapePredictor(
  (encoder): Sequential(
    (0): LinPack(
      (lin): Linear(in_features=63, out_features=32, bias=True)
      (relu): ReLU()
    )
    (1): ResBlock(
      (lin1): Linear(in_features=32, out_features=32, bias=True)
      (lin2): Linear(in_features=32, out_features=32, bias=True)
      (relu): ReLU()
    )
    (2): LinPack(
      (lin): Linear(in_features=32, out_features=16, bias=True)
      (relu): ReLU()
    )
    (3): ResBlock(
      (lin1): Linear(in_features=16, out_features=16, bias=True)
      (lin2): Linear(in_features=16, out_features=16, bias=True)
      (relu): ReLU()
    )
    (4): Linear(in_features=16, out_features=5, bias=True)
  )
  (decoder): Sequential(
    (0): LinPack(
      (lin): Linear(in_features=5, out_features=16, bias=True)
      (relu): ReLU()
    )
    (1): LinPack(
      (lin): Linear(in_features=16, out_features=32, bias=True)
      (relu): ReLU()
    )
    (2): Linear(in_features=32, out_features=93, bias=True)
  )
)

In [4]:
# Count trainable parameters. `Tensor.numel()` always returns an int, whereas
# `np.prod` over a 0-d parameter's empty shape returns the float 1.0 and would
# turn the whole sum into a float.
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(params)

9026


In [5]:
# Disabled: timestamped recorders for training/validation loss, validation
# accuracy and a text log. The visible cells only run inference, so nothing
# here is needed; uncomment when training in this notebook.
# # Just for keeping records of training hists. 
# ts = str(get_timestamp())
# # ts = "0623152604"
# save_txt_name = "train_txt_{}.hst".format(ts)
# save_trainhist_name = "train_{}.hst".format(ts)
# save_valhist_name = "val_{}.hst".format(ts)
# save_valacc_name = "valacc{}.hst".format(ts)

# valid_losses = LossRecorder(model_save_dir + save_valhist_name)
# train_losses = LossRecorder(model_save_dir + save_trainhist_name)
# valid_accuracies = LossRecorder(model_save_dir + save_valacc_name)
# text_hist = HistRecorder(model_save_dir + save_txt_name)

In [6]:
# When True, the next cell loads a saved checkpoint into `model`;
# set to False to keep the freshly initialised weights.
READ = True

In [7]:
# Optionally restore trained weights into `model`.
if READ:
    model_name = "PT_0816115032_1799_full.pt"
    # Alternative checkpoint: "PT_0816184446_1349_full.pt"
    model_path = os.path.join(model_save_dir, model_name)
    # map_location remaps tensors that were saved on a GPU onto `device`, so
    # the checkpoint also loads when `device` fell back to the CPU.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state)
    model.to(device)

In [8]:
css_in = """
<head>
    <link rel="stylesheet" href="{}">
</head>
""".format(outstyle_path)

control_script_pre = """
<script>
const integerInput = document.getElementById('integerInput');
const outputDiv = document.getElementById('outputDiv');

const stringsList ="""

control_script_post = """
integerInput.addEventListener('input', () => {
  const selectedIndex = parseInt(integerInput.value);
  
  if (selectedIndex >= 0 && selectedIndex < stringsList.length) {
    const selectedString = stringsList[selectedIndex];
    outputDiv.textContent = selectedString;
  } else {
    outputDiv.textContent = "無";
  }
});
</script>
"""


frame_pred_vis = """
<div class="container">
  <input type="number" id="integerInput" min="0">
  <div class="output" id="outputDiv"></div>
</div>
"""

In [9]:
file_prefix = "cynthia_data"
tag_path = os.path.join(data_dir, file_prefix + "_tag.npz")
hsdict = HandshapeDict(tag_path)

for vd in os.listdir(det_dir): 
    for clip in os.listdir(det_dir + vd + "/")[876:]: 
        print(clip)
        for whether_smooth in ["non", "ma"]: 
            smooth_dir = os.path.join(spec_dir, whether_smooth + "/")
            pic_smooth_dir = os.path.join(spec_pic_dir, whether_smooth + "/")
            mk(smooth_dir)
            mk(pic_smooth_dir)

            for side in ["Right", "Left"]:
                find_name = "{}_{}".format(side, clip)
                reverse_find_name = "{}_{}".format(clip, side)

                gt = GraphTool(graph_dir, find_name)
                gt.interpolate(window_size=2)

                if whether_smooth == "non": 
                    smoothed_features = gt.interpolated_features
                elif whether_smooth == "ma": 
                    smoothed_features = Smoother.moving_average(gt.interpolated_features)
                else: 
                    smoothed_features = gt.interpolated_features

                this_features = torch.from_numpy(smoothed_features.copy())
                batch_num, lm_num, dim_num = this_features.size()

                x = this_features
                x = x.to(device)
                x = x.to(torch.float32)

                hid_rep, pred = model.predict(x, hsdict)

                hid_rep = hid_rep.cpu().detach().numpy()

                html = """"""

                html += css_in

                html += """<h1>{}</h1><br>""".format(reverse_find_name)

                html += frame_pred_vis

                html += Plotter.plot_spectrogram(
                    hid_rep, 
                    title="ML Spectrogram" + side, 
                    save_path= os.path.join(pic_smooth_dir, reverse_find_name + "_mlspec")
                )

                html += Plotter.plot_line_graph(hid_rep, ["0", "1", "2", "3", "4"], 
                                                "ML Linegraph" + side, y_axis_label="Val", 
                                                save_path= os.path.join(pic_smooth_dir, reverse_find_name + "_mlline"))
                
                html += control_script_pre + str(pred) + control_script_post
                
                Plotter.write_to_html(html, "{}{}.html".format(smooth_dir, reverse_find_name))

HKSL_lesson_only465-WHEELCHAIR-13PA-943
HKSL_lesson_only466-BLIND-0TNI-944
HKSL_lesson_only467-DISABLED_PERSON-0QSO-945
HKSL_lesson_only468-FIRST_AID-0O15-946
HKSL_lesson_only469-RESCUE-0ONF-947
HKSL_lesson_only47-YESTERDAY-0PH8-509
HKSL_lesson_only470-HOSPITALIZE-1401-948
HKSL_lesson_only471-SURGERY-0OIB-949
HKSL_lesson_only472-CANCER-0TIC-950
HKSL_lesson_only473-FEVER-0TJS-951
HKSL_lesson_only474-RECOVER-0NLN-952
HKSL_lesson_only475-SHY-0MTJ-953
HKSL_lesson_only476-BLUSHING-15R2-954
HKSL_lesson_only477-PALE_FACED-15R2-955
HKSL_lesson_only478-SCARED-0MTJ-956
HKSL_lesson_only479-SCARY-0O2G-957
HKSL_lesson_only48-TOMORROW-0PGE-510
HKSL_lesson_only480-GREEDY-135A-958
HKSL_lesson_only480-^GREEDY_2-135A-959
HKSL_lesson_only481-STINGY-0MR4-960
HKSL_lesson_only482-PETTY-0N0F-961
HKSL_lesson_only483-GRUDGEFUL-12GO-962
HKSL_lesson_only484-LUSTFUL-0MBT-963
HKSL_lesson_only485-WIDE_EYED-0TPS-964
HKSL_lesson_only486-POINTLESS-0S91-965
HKSL_lesson_only487-JEALOUS_IN_LOVE-0L3N-968
HKSL_lesson_only4