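# Predict Data

Load a trained `LinearHandshapePredictor` checkpoint, run it over every clip in the detection directory, and write spectrogram / line-graph plots plus HTML viewers for the hidden representations and the predictions.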

### Import Libs

In [9]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from model_dataset import HandshapeDataset, HandshapeDict, HandshapeIndexor, DS_Tools, FixedHandshapeDict
from paths import *
from model_model import LinearHandshapePredictor
from model_configs import *
from utils import *
from recorder import *
from graph_tools import GraphTool, Plotter, Smoother

### Init Model

In [2]:
# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# The dimension arguments (in_dim, enc_lat_dims, ...) are defined outside this
# cell and reach this namespace through the star imports at the top.
model = LinearHandshapePredictor(
    input_dim=in_dim,
    enc_lat_dims=enc_lat_dims,
    hid_dim=hid_dim,
    dec_lat_dims=dec_lat_dims,
    output_dim=out_dim,
).to(device)

# Loss and optimizer objects are created alongside the model.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [3]:
# Display the module tree to sanity-check the instantiated architecture.
model

LinearHandshapePredictor(
  (encoder): Sequential(
    (0): LinPack(
      (lin): Linear(in_features=63, out_features=128, bias=True)
      (relu): ReLU()
      (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ResBlock(
      (lin1): Linear(in_features=128, out_features=128, bias=True)
      (lin2): Linear(in_features=128, out_features=128, bias=True)
      (relu): ReLU()
    )
    (2): LinPack(
      (lin): Linear(in_features=128, out_features=32, bias=True)
      (relu): ReLU()
      (batch_norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): ResBlock(
      (lin1): Linear(in_features=32, out_features=32, bias=True)
      (lin2): Linear(in_features=32, out_features=32, bias=True)
      (relu): ReLU()
    )
    (4): LinPack(
      (lin): Linear(in_features=32, out_features=5, bias=True)
      (relu): ReLU()
      (batch_norm): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=Tru

In [4]:
# Count trainable scalars: keep only parameters that require gradients, then
# add up the element count of each tensor (product of its shape).
model_parameters = (p for p in model.parameters() if p.requires_grad)
params = sum(np.prod(p.size()) for p in model_parameters)
print(params)

53516


In [5]:
# Checkpoint selection: `ts` names the sub-directory under model_save_dir and
# `stop_epoch` names the "<epoch>.pt" file inside it.
ts = "1111200151"
stop_epoch = "621"
save_subdir = os.path.join(model_save_dir, "{}/".format(ts))
model_raw_name = f"{stop_epoch}"
model_name = model_raw_name + ".pt"
model_path = os.path.join(save_subdir, model_name)

# map_location remaps the stored tensors onto `device`. Without it, a checkpoint
# written from a GPU cannot be loaded when this notebook falls back to the CPU.
state = torch.load(model_path, map_location=device)
model.load_state_dict(state)
model.to(device)

# This notebook only predicts. The model contains BatchNorm1d layers, which must
# use their stored running statistics (and not update them) during inference.
# NOTE(review): harmless if model.predict already switches to eval mode - confirm.
# eval() returns the module, so the cell still displays the model tree.
model.eval()

LinearHandshapePredictor(
  (encoder): Sequential(
    (0): LinPack(
      (lin): Linear(in_features=63, out_features=128, bias=True)
      (relu): ReLU()
      (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ResBlock(
      (lin1): Linear(in_features=128, out_features=128, bias=True)
      (lin2): Linear(in_features=128, out_features=128, bias=True)
      (relu): ReLU()
    )
    (2): LinPack(
      (lin): Linear(in_features=128, out_features=32, bias=True)
      (relu): ReLU()
      (batch_norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): ResBlock(
      (lin1): Linear(in_features=32, out_features=32, bias=True)
      (lin2): Linear(in_features=32, out_features=32, bias=True)
      (relu): ReLU()
    )
    (4): LinPack(
      (lin): Linear(in_features=32, out_features=5, bias=True)
      (relu): ReLU()
      (batch_norm): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=Tru

In [6]:
class HiddenViewer: 
    """Namespace of HTML/JS fragments used to assemble the hidden-representation viewer page.

    The class is never instantiated here; its string attributes are concatenated
    by the caller, with plot HTML and the prediction list inserted between them.
    """
    # Stylesheet link. This is an f-string, so hidview_style_path is read once,
    # when the class body is executed.
    css_in = f"""
<head>
    <link rel="stylesheet" href="{hidview_style_path}">
</head>
"""
    # Number input plus an output <div>; the script below wires them together.
    frame_pred_vis = """
<div class="container">
<input type="number" id="integerInput" min="0">
<div class="output" id="outputDiv"></div>
</div>
"""
    # Opening part of the control script. It deliberately ends in
    # `const stringsList =`, so the caller must append a JavaScript array
    # literal right after it and then append control_script_post.
    control_script_pre = """
<script>
const integerInput = document.getElementById('integerInput');
const outputDiv = document.getElementById('outputDiv');

const stringsList ="""

    # Closing part of the control script: on every input event, show
    # stringsList[index] in the output div, or the placeholder text when the
    # typed value is not a valid index into the list.
    control_script_post = """
integerInput.addEventListener('input', () => {
const selectedIndex = parseInt(integerInput.value);

if (selectedIndex >= 0 && selectedIndex < stringsList.length) {
    const selectedString = stringsList[selectedIndex];
    outputDiv.textContent = selectedString;
} else {
    outputDiv.textContent = "無";
}
});
</script>
"""
    

class PredictionViewer: 
    """Builds a standalone HTML page for browsing per-frame predictions of one sign.

    The finished page is stored in ``self.html``. It links the stylesheet at
    predsview_style_path and loads the script at predsview_js_path, and exposes
    three JavaScript globals (``sign``, ``filenum``, ``preds``) to that script.
    """
    def __init__(self, file_name, sign_name, frame_count, predictions):
        """Render the page template into ``self.html``.

        Args:
            file_name: used as the page <title>; should be sign name + left/right.
            sign_name: name of the sign as in the guide file; emitted as the
                JavaScript string ``sign``.
            frame_count: total number of frames of this sign; emitted as the
                JavaScript variable ``filenum``.
            predictions: list of predicted signs; emitted as the JavaScript
                variable ``preds``.

        NOTE(review): all values are interpolated with plain f-string
        formatting and no escaping, so this assumes they render as valid
        HTML / JavaScript literals - confirm for the prediction type in use.
        """
        
        self.html = f"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>{file_name}</title>
    <link rel="stylesheet" href="{predsview_style_path}">
    <script>
        var sign = "{sign_name}"; 
        var filenum = {frame_count}; 

        var preds = {predictions}; 
    </script>
</head>
<body>
<!-- Container to center the image -->
<div id="image-container">
    <img id="image-viewer" src="" alt="Image Viewer">
</div>
<div id="text-container">
    <img id="text-viewer" src="" alt="Text Viewer">
    <p id="text-alt"></p>
</div>
<script src="{predsview_js_path}"></script>
</body>
</html>
"""

In [7]:
hsdict = FixedHandshapeDict()
det_sub_dir = os.path.join(det_dir, "Cynthia_train/")

# One output root per artefact type, all under preds_dir.
linegraph_ = os.path.join(preds_dir, "line/")
spectrogram_ = os.path.join(preds_dir, "spec/")
hidden_ = os.path.join(preds_dir, "hidden_viewer/")
preds_ = os.path.join(preds_dir, "preds_viewer/")

output_roots = [linegraph_, spectrogram_, hidden_, preds_]
for output_root in output_roots:
    mk(output_root)

# Every root gets one sub-directory per smoothing mode:
# "non/" for unsmoothed output and "ma/" for moving-average output.
for data_sub_dir in output_roots:
    no_smooth_sub = os.path.join(data_sub_dir, "non/")
    ma_smooth_sub = os.path.join(data_sub_dir, "ma/")
    mk(no_smooth_sub)
    mk(ma_smooth_sub)

In [10]:
# Load the guide file into a DataFrame; find_target_hs filters it on the
# 'NewFileName' column.
guideline = pd.read_csv(guide_path)

In [12]:
def find_target_hs(sign_name, side, guideline): 
    """Look up a sign in the guide table.

    Selects the rows of `guideline` whose 'NewFileName' column equals
    `sign_name` and returns the string "NAS" when no row matches.

    NOTE(review): in the code visible here `side` is never used and nothing is
    returned when a match is found (the function falls through and implicitly
    returns None). The rest of the function may be missing from this view -
    confirm against the original notebook.
    """
    result = guideline[guideline['NewFileName'] == sign_name]
    if result.empty:
        # Not expected to happen: a sign missing from the guide file is
        # reported as "not a sign" = NAS.
        return "NAS"
    

Unnamed: 0,ASLLEXcode,NewFileName,EntryID,Arthur's Notes (from shoot),File number,Side,DH_1,OH_1,DH_2,OH_2,DH_3,OH_3,DH_4,OH_4,note,ASLinTextbook,SAME_1,ONLY_1
0,B_01_011,B_01_011-START-15CB-2,start,,1.0,D,x,x,NONE,NONE,NONE,NONE,NONE,NONE,,,True,1


In [8]:
# Each clip is a directory under det_sub_dir (its contents are listed below to
# count frames), so stray files such as OS metadata are skipped instead of
# crashing the run. Sorting makes the processing order reproducible.
clips = sorted(
    entry for entry in os.listdir(det_sub_dir)
    if os.path.isdir(os.path.join(det_sub_dir, entry))
)
total = len(clips)

for idx, clip in enumerate(clips):
    draw_progress_bar(idx, total)

    # Same for both sides and both smoothing modes, so count once per clip.
    frame_count = len(os.listdir(os.path.join(det_sub_dir, clip)))

    for side in ["Right", "Left"]:
        find_name = "{}_{}".format(side, clip)
        reverse_find_name = "{}_{}".format(clip, side)

        # Loading and interpolating the graph does not depend on the smoothing
        # mode, so it is done once per (clip, side).
        gt = GraphTool(graph_dir, find_name)
        gt.interpolate(window_size=2)

        # non = no smoothing, ma = moving average smoothed.
        # "non" must stay first: it reads gt.interpolated_features before the
        # smoother is ever handed that array.
        for whether_smooth in ["non", "ma"]:
            if whether_smooth == "ma":
                smoothed_features = Smoother.moving_average(gt.interpolated_features)
            else:
                smoothed_features = gt.interpolated_features

            # .copy() gives torch its own buffer, independent of gt's array.
            this_features = torch.from_numpy(smoothed_features.copy())
            if this_features.dim() != 3:
                raise ValueError(
                    f"expected 3-D features for {find_name}, "
                    f"got shape {tuple(this_features.shape)}"
                )

            x = this_features.to(device).to(torch.float32)

            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                hid_rep, pred = model.predict(x, hsdict)

            hid_rep = hid_rep.cpu().detach().numpy()

            # hidden viewer: the fragments are evaluated in page order, so the
            # plot files are saved in the same order as before.
            html_parts = [
                HiddenViewer.css_in,
                f"""<h1>{reverse_find_name}</h1><br>""",
                HiddenViewer.frame_pred_vis,
                Plotter.plot_spectrogram(
                    hid_rep,
                    title="ML Spectrogram" + side,
                    save_path=os.path.join(spectrogram_, whether_smooth, reverse_find_name),
                ),
                Plotter.plot_line_graph(
                    hid_rep,
                    ["0", "1", "2", "3", "4"],
                    "ML Linegraph" + side,
                    y_axis_label="Val",
                    save_path=os.path.join(linegraph_, whether_smooth, reverse_find_name),
                ),
                # The prediction list becomes the JS `stringsList` literal.
                HiddenViewer.control_script_pre + str(pred) + HiddenViewer.control_script_post,
            ]
            html = "".join(html_parts)

            Plotter.write_to_html(
                html,
                os.path.join(hidden_, whether_smooth, f"{reverse_find_name}.html"),
            )

            # preds viewer
            predsviewer = PredictionViewer(
                file_name=reverse_find_name,
                sign_name=clip,
                frame_count=frame_count,
                predictions=pred,
            )
            Plotter.write_to_html(
                predsviewer.html,
                os.path.join(preds_, whether_smooth, f"{reverse_find_name}.html"),
            )

 [                                                  ] 0%