<a href="https://colab.research.google.com/github/TranGiaHuan0102/HME_Assignment_SS25_10423043/blob/main/Assignment_HME_recognition_ColabVersion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
import kagglehub
ntcuong2103_crohme2019_path = kagglehub.dataset_download('ntcuong2103/crohme2019')

print('Data source import complete.')

In [None]:
# If this assertion fails, re-run the cell above.
assert ntcuong2103_crohme2019_path == "/kaggle/input/crohme2019"

## **CROHME Temporal Sequence Classification (CROHME-CTC) Baseline**

In this project, we will build and train a recurrent neural network (RNN) that translates handwritten mathematical expressions into a sequence of symbols and relations, based on the Symbol Relation Tree representation and the CROHME dataset.

The model we'll be using is a simple RNN built from bidirectional LSTM layers (the exact configuration is given in Task 4). We will use Connectionist Temporal Classification (CTC) loss to train the model.

The model is trained on the CROHME dataset, which contains handwritten mathematical expressions annotated as Symbol Relation Trees. Each annotation is preprocessed into a sequence of symbols and relations, which serves as the target sequence for the model.

Your job is to feature engineer the input data and train the model to achieve the best performance possible. Additionally, consider implementing data augmentation techniques to enhance the training dataset.

Some features from the project:

*   Public competition dataset: CROHME https://paperswithcode.com/dataset/crohme-2019
*   Advanced data preparation: time series feature extraction, collate function
*   High-level API of PyTorch Lightning
*   Experiment tracking with Weights & Biases



## Dependencies

In [None]:
!pip -q install pytorch-lightning torchmetrics wandb

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import numpy as np
import pandas as pd
import os

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
import wandb

import xml.etree.ElementTree as ET
from torch.nn.utils.rnn import pad_sequence

import matplotlib.pyplot as plt
from collections.abc import Sequence
from typing import Literal, Optional, Union

import json

> An example of an InkML file: the points (x, y coordinates) of each handwritten stroke are recorded in a `<trace>` element.

```
<ink xmlns="http://www.w3.org/2003/InkML">
<traceFormat>
<channel name="X" type="decimal"/>
<channel name="Y" type="decimal"/>
</traceFormat>
...

<trace id="0">
241 123, 240 123, 239 123, 238 123, 237 123, 236 123, 236 122, 237 122, 238 122, 240 122, 242 122, 244 122, 247 122, 251 121, 254 121, 257 121, 260 120, 262 120, 265 120, 266 120, 267 120, 268 120, 269 120, 268 120
</trace>
<trace id="1">
301 123, 300 124, 299 124, 298 124, 297 124, 296 124, 295 124, 294 124, 295 124, 296 124, 297 124, 298 124, 299 123, 301 123, 302 123, 303 123, 304 122, 305 122, 306 122, 307 122, 308 122, 310 122, 311 122, 314 122, 316 121, 317 121, 319 121, 321 121, 323 121, 328 120, 330 120, 332 120, 333 120, 335 120, 336 120, 344 119, 346 119, 348 119, 350 119, 352 119, 357 119, 358 119, 360 119, 362 119, 364 119, 365 119, 370 119, 372 119, 373 119, 374 119, 375 119, 377 119, 377 118, 376 118, 374 118, 372 118
</trace>
<trace id="2">
318 89, 319 89, 320 89, 321 88, 323 87, 324 86, 325 84, 326 83, 327 82, 328 80, 329 79, 330 77, 331 75, 332 74, 333 73, 333 72, 333 73, 333 75, 333 76, 333 78, 333 80, 332 81, 332 83, 332 85, 332 86, 332 88, 332 90, 332 92, 332 93, 333 95, 333 97, 333 98, 333 100, 333 101, 333 102, 333 103, 333 104, 333 103, 333 101, 332 99
</trace>
<trace id="3">
302 148, 301 148, 302 147, 302 146, 303 145, 304 144, 305 143, 306 142, 307 140, 308 139, 308 138, 308 137, 308 136, 308 137, 308 138, 309 139, 309 140, 309 141, 309 142, 309 144, 309 146, 309 147, 309 149, 309 151, 309 153, 309 155, 309 157, 310 159, 310 162, 310 164, 310 169, 310 171, 310 172, 310 174, 311 175, 311 179, 311 180, 311 181, 311 182, 311 181
</trace>
...

</ink>

```
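
As a rough illustration of how such a file can be read, the sketch below parses a shortened, made-up InkML string with Python's standard `xml.etree.ElementTree`. The `SAMPLE` string and the helper name `parse_traces` are assumptions for this example only; the `Inkml` class provided later in the notebook does the real parsing.

```python
import xml.etree.ElementTree as ET

# Shortened, hypothetical InkML content (real files contain many more points).
SAMPLE = """<ink xmlns="http://www.w3.org/2003/InkML">
<trace id="0">241 123, 240 123, 239 123</trace>
<trace id="1">301 123, 300 124</trace>
</ink>"""

# InkML elements live in this XML namespace, so lookups must be namespace-aware.
NS = {"ns": "http://www.w3.org/2003/InkML"}


def parse_traces(xml_text):
    """Return {trace_id: [(x, y), ...]} for every <trace> element."""
    root = ET.fromstring(xml_text)
    traces = {}
    for trace in root.findall("ns:trace", namespaces=NS):
        points = []
        for point in trace.text.strip().split(","):
            x, y = point.split()[:2]  # keep only the X and Y channels
            points.append((float(x), float(y)))
        traces[trace.attrib["id"]] = points
    return traces


traces = parse_traces(SAMPLE)
print(len(traces))   # number of strokes
print(traces["1"])   # points of the second stroke
```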

> An example of an annotation file: the left part is the InkML file path, and the right part is the label (target) for that file.

```
...
crohme2019/test/UN19_1041_em_597.inkml	- Right \sqrt Inside 2
crohme2019/test/UN19_1019_em_256.inkml	a Right n Right y
crohme2019/test/UN19_1033_em_474.inkml	V Sub n Right - Right 1 NoRel = Right \int Right d Sup n Right - Right 1 NoRel x Right \sqrt Inside h

...
```

> The label is a sequence of **symbols** and **spatial relations** between symbols (based on writing order).



## Task 1: Build Vocabulary

In this task, you will build and manage a vocabulary that converts tokens to indices (encoding) and back (decoding). This is essential for processing text data in machine learning tasks, particularly for sequence-to-sequence models like those used in handwriting recognition or mathematical expression translation.

In this project, the vocabulary is for encoding the target sequence (for example: "- Right \\sqrt Inside 2") into a sequence of indices ([5, 37, 74, 30, 10]). It is constructed from multiple annotation files (train, test, and validation sets) and includes a special blank character (this is for CTC loss). It supports encoding (tokens to indices) and decoding (indices to tokens) operations.
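
To make the encode/decode round trip concrete, here is a minimal sketch on two hypothetical labels. Because the toy vocabulary is built from only these two labels, its indices differ from those of the full CROHME vocabulary quoted above; the variable names are illustrative and not part of the required `Vocab` API.

```python
# Two hypothetical labels, tokenized on whitespace as in the annotation files.
labels = ["- Right \\sqrt Inside 2", "a Right n Right y"]

tokens = set()
for label in labels:
    tokens.update(label.strip().split())
tokens.add("")  # blank token required by CTC

# Sorting puts the empty string first, so the blank gets index 0.
char2idx = {tok: i for i, tok in enumerate(sorted(tokens))}
idx2char = {i: tok for tok, i in char2idx.items()}

encoded = [char2idx[t] for t in labels[0].split()]
decoded = [idx2char[i] for i in encoded]

print(char2idx)
print(encoded)   # indices for "- Right \sqrt Inside 2" in the toy vocabulary
print(decoded)   # the original tokens again
```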


> _You will need to build the `class Vocab` using the two functions provided. Also the vocabulary should be exported as a JSON file so that we can reuse it later in the dataset._

***Input***: the path to the dataset files _(train, valid, and test)_.


```python
paths = [
    "dataset/crohme2019_train.txt",
    "dataset/crohme2019_test.txt",
    "dataset/crohme2019_valid.txt",
]
```

***Output***: the vocabulary.

```python
vocab = {
    "a": 0,
    "b": 1,
    "c": 2,
    ...
}
```

In [None]:
ROOT_DIR = "/kaggle/input/crohme2019" # Use this if using Google Colab
#ROOT_DIR = "/dataset"  # Use this if you cloned this repo

In [None]:
paths = [
    "/kaggle/input/crohme2019/crohme2019_train.txt",
    "/kaggle/input/crohme2019/crohme2019_test.txt",
    "/kaggle/input/crohme2019/crohme2019_valid.txt",
]

df = pd.read_csv(paths[1], sep="\t", header=None, names=["path", "label"]).dropna().astype(str)

df.head()

In [None]:
def get_unique_chars(path) -> set:
    df = pd.read_csv(path, sep="\t", header=None, names=["path", "label"]).dropna().astype(str)

    res = set()

    for i in df['label'].apply(lambda x: x.strip().split()):
        res.update(i)
    return res

In [None]:
unique_chars = get_unique_chars(paths[0])
unique_chars

The vocabulary should be sorted in ASCII order. The first entry (index 0) must be the blank token (`''`).

In [None]:
def chr2idx(unique_chars: set) -> dict:
    unique_chars.add('')
    return {char: idx for idx, char in enumerate(sorted(unique_chars))}

In [None]:
vocab = chr2idx(get_unique_chars(paths[0]))

assert vocab[""] == 0
assert vocab["|"] == 108
assert vocab["\\pi"] == 68
assert vocab["\\exists"] == 51
assert len(vocab) == 109

In [None]:
class Vocab:
    """
    Vocabulary that maps tokens (symbols and relations) to indices and back.

    Attributes:
        char2idx (dict): A dictionary mapping characters to their corresponding indices.  This is the primary vocabulary.
        idx2char (dict): A reverse mapping from indices to characters, facilitating decoding.

    Methods:
        build_vocab(): Builds the vocabulary from the specified annotation files, sorting characters lexicographically and adding a blank character.
        get_vocab(): Returns the character-to-index vocabulary.
        save_vocab(path): Saves the vocabulary to a JSON file.
        encode(tokens): Converts a list of tokens (characters) into a list of their corresponding indices.
        decode(ids): Converts a list of indices back into a list of tokens (characters).
    """
    def __init__(self, vocab_file=None):
        self.char2idx = {}
        if vocab_file is not None:
            # use a context manager so the file handle is closed after loading
            with open(vocab_file) as f:
                self.char2idx = json.load(f)
        self.idx2char = {v: k for k, v in self.char2idx.items()}

    def build_vocab(self, annotations) -> None:
        def get_unique_chars(path) -> set:
            df = pd.read_csv(path, sep="\t", header=None, names=["path", "label"]).dropna().astype(str)

            res = set()

            for i in df['label'].apply(lambda x: x.strip().split()):
                res.update(i)
            return res

        def chr2idx(unique_chars: set) -> dict:
            unique_chars.add('')
            return {char: idx for idx, char in enumerate(sorted(unique_chars))}

        # Collect the unique tokens from every annotation file (train, test, valid).
        # Taking the union ensures that no label token is missing from the vocabulary.
        unique_chars = set()
        for a in annotations:
            unique_chars |= get_unique_chars(a)

        # assign vocab to the class and keep the reverse mapping in sync
        self.char2idx = chr2idx(unique_chars)
        self.idx2char = {v: k for k, v in self.char2idx.items()}

    def get_vocab(self) -> dict:  # getter
        return self.char2idx

    def save_vocab(self, path: str) -> None:  # save vocab to json file
        with open(path, 'w') as f:
            json.dump(self.char2idx, f)

    def encode(self, tokens):
        return [self.char2idx[token] for token in tokens]

    def decode(self, ids):
        return [self.idx2char[id] for id in ids]


In [None]:
# CHECKPOINT: build vocab from annotations
annotations = [
    "/kaggle/input/crohme2019/crohme2019_train.txt",
    "/kaggle/input/crohme2019/crohme2019_test.txt",
    "/kaggle/input/crohme2019/crohme2019_valid.txt",
]


vocab = Vocab()
vocab.build_vocab(annotations)
vocab.save_vocab('vocab.json')

## Task 2: Build Dataset

In this task, we will build a dataset class that will be used to load the dataset files _(train, valid, and test)_.

The dataset class will also be used to preprocess the data by converting the characters in the data to integers using the vocabulary.

> _Before going to this task, it is **highly recommended** to read this [Tutorial on creating Custom Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files) with PyTorch._
>
> _You should have basic understanding of required methods in the Dataset class such as `__init__`, `__len__`, and `__getitem__`._


***Input***: The Dataset class should take the following arguments:

- `annotation`: the path to the dataset annotation *(`.txt`)*.
- `root_dir`: the root directory of the dataset file.
- `vocab`: the vocabulary.

***Output***: the dataset class.

> *Note: You can adjust the Dataset class based on your preferences.*


### Dataflow
Below is a typical dataflow to train a model.
```mermaid
flowchart LR
    A[Raw Data] --> |parse| B[Cleaned Data]
    B --> |feature engineering| C[Dataset]
    C --> |package| D[DataLoader]
    D --> |load| E[Model]
    E --> |predict| F[Output]
    F --> |calculate| G[Loss]
    G --> |backpropagate| E
```

### Handling Inkml files

We have an inkml file: `crohme2019/test/UN19_1041_em_597.inkml`

Its label is `- Right \sqrt Inside 2`

Target would be: `[5, 37, 74, 30, 10]`, where the index here is the index of tokens (from a vocabulary)

We provided you with class `Inkml` to handle the parsing and processing of the inkml files. The main method you will use is `getTraces()` which returns the traces of the inkml file.

> _Inspect this class and understand how you should be working with this type of data._


In [None]:
import xml.etree.ElementTree as ET
import numpy as np
import matplotlib.pyplot as plt
from collections.abc import Sequence
from typing import Literal, Optional, Union


In [None]:
class Segment(object):
    """Class to represent a segment composed of strokes (ids), with an id and a label."""

    __slots__ = ("id", "label", "strId")

    def __init__(self, *args):
        if len(args) == 3:
            self.id = args[0]
            self.label = args[1]
            self.strId = args[2]
        else:
            self.id = "none"
            self.label = ""
            self.strId = set([])


class Inkml(object):
    """Class to represent an INKML file with strokes, segmentation and labels"""

    __slots__ = ("fileName", "strokes", "strkOrder", "segments", "truth", "UI")

    NS = {
        "ns": "http://www.w3.org/2003/InkML",
        "xml": "http://www.w3.org/XML/1998/namespace",
    }

    def __init__(self, *args):
        self.fileName = None
        self.strokes = {}
        self.strkOrder = []
        self.segments = {}
        self.truth = ""
        self.UI = ""
        if len(args) == 1:
            self.fileName = args[0]
            self.loadFromFile()

    def fixNS(self, ns, att):
        """Build the right tag or element name with namespace"""
        return "{" + Inkml.NS[ns] + "}" + att

    def loadFromFile(self):
        """load the ink from an inkml file (strokes, segments, labels)"""
        tree = ET.parse(self.fileName)
        # # ET.register_namespace();
        root = tree.getroot()

        # look at all annotation tags
        for info in root.findall("ns:annotation", namespaces=Inkml.NS):
            # check if current annotation tag has the attribute type
            if "type" in info.attrib:
                if info.attrib["type"] == "truth":                          # if yes, check if the type attribute has value truth, if yes, assign the content of tag to truth
                    self.truth = info.text.strip()                          # this is the overall label of the entire image
                if info.attrib["type"] == "UI":                             # then check if the type attribute has value UI, if yes, assign the content of tag to UI
                    self.UI = info.text.strip()
        # look at all trace tags
        for strk in root.findall("ns:trace", namespaces=Inkml.NS):
            self.strokes[strk.attrib["id"]] = strk.text.strip()             # map each trace id to the text content of its trace tag
            self.strkOrder.append(strk.attrib["id"])                        # remember the order in which the trace ids appear

        # look for first instance of a trace group tag that represents stroke groups (e.g., strokes that form a single character), assign it to segments
        segments = root.find("ns:traceGroup", namespaces=Inkml.NS)
        if segments is None or len(segments) == 0:
            return                                                          # return if segments has nothing (no trace group tags)

        # loop over every child trace group nested within the outer trace group in segments (most trace groups have children nested within them)
        for seg in segments.iterfind("ns:traceGroup", namespaces=Inkml.NS):
            id = seg.attrib[self.fixNS("xml", "id")]                        # grab the content of the xml:id attribute of every child trace group tag
            label = seg.find("ns:annotation", namespaces=Inkml.NS).text     # grab the content of the annotation tag below this child trace group tag (this is the label of every stroke group)

            strkList = set([])

            for t in seg.findall("ns:traceView", namespaces=Inkml.NS):      # traceDataRef is a number that points to the stroke in self.strokes that corresponds to this stroke group
                strkList.add(t.attrib["traceDataRef"])                      # in the .inkml file, the first stroke group refers to id=0, which is the stroke responsible for the minus
            self.segments[id] = Segment(id, label, strkList)

    def getTraces(self, height=256):
        # Extract every x,y coordinate pair of every stroke in self strokes, and put them in traces array
        '''
        for example, traces_array would look like:
        [
            [[460, 226], [468, 221], [470, 221], [473, 221], [475, 220], [478, 220], [481, 219], ...] <-- first trace,
            [...]   <-- second trace,
            [...]   <-- third trace
        ]
        '''

        traces_array = [
            np.array(
                [p.strip().split()[:2] for p in self.strokes[id].split(",")], dtype="float"
            )
            for id in self.strkOrder
        ]

        # Concatenate all traces along axis 0 into a single 2D array of shape (n_points, 2)
        '''
        after concatenation, the points look like this (2D array):
        [
            [460, 226], [468, 221], [470, 221], [473, 221], [475, 220], [478, 220], [481, 219], ...
        ]

        .max(0) and .min(0) give the per-axis maximum and minimum ([481, 226] and [460, 219] in this case)
        their difference is [21, 7], the second value (7) is the difference between max-y and min-y
        '''

        ratio = height / (
            (
                np.concatenate(traces_array, 0).max(0)
                - np.concatenate(traces_array, 0).min(0)
            )[1]
            + 1e-6
        )

        # Scale every coordinate by ratio so that the y-extent (max-y minus min-y) becomes `height` pixels (256 by default)
        return [(trace * ratio).astype(int).tolist() for trace in traces_array]

    def view(self):
        plt.figure(figsize = (16, 4))
        plt.axis("off")
        for trace in self.getTraces():
            trace_arr = np.array(trace)
            plt.plot(trace_arr[:, 0], -trace_arr[:, 1])  # invert y coordinate

In [None]:
inkml_path = '/kaggle/input/crohme2019/crohme2019/crohme2019/test/UN19_1041_em_597.inkml'
ink = Inkml(inkml_path)
ink.view()

In [None]:
# Display shapes of first 5 traces in one file
traces = ink.getTraces() # list of trace, each trace is a list of points
# traces[0]
[np.array(trace).shape for trace in ink.getTraces()][:5]

In [None]:
len(traces), traces[0][:60]

### Build Dataset

For calculating the Connectionist Temporal Classification (CTC) loss, read about the inputs this loss function expects in the [PyTorch documentation](https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html).

> _Hint: focus on the shapes and the expected variables. (see code example for calculating loss)_
> ```python
> loss = ctc_loss(input_tensor, target_tensor, input_lengths, target_lengths)
> ```
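
The sketch below shows the tensor shapes that `torch.nn.CTCLoss` works with, using random stand-in data. All sizes (`T`, `N`, `C`, `S`) and the length values are arbitrary assumptions for illustration and are not taken from the CROHME dataset.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

T, N, C = 50, 2, 10   # time steps, batch size, number of classes (blank included)
S = 5                 # longest target in the batch

# Stand-in for the model output: log-probabilities of shape (T, N, C).
log_probs = torch.randn(T, N, C).log_softmax(dim=2)

# Targets hold class indices in [1, C-1]; index 0 is reserved for the blank.
targets = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)

input_lengths = torch.tensor([50, 42], dtype=torch.long)   # unpadded input lengths
target_lengths = torch.tensor([5, 3], dtype=torch.long)    # unpadded target lengths

ctc_loss = nn.CTCLoss(blank=0)
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
print(loss.shape, loss.item())
```

Note that `nn.CTCLoss` expects time-major log-probabilities, so a batch-first model output of shape `(N, T, C)` would need something like `log_probs.permute(1, 0, 2)` before the loss is computed.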

**Based on these observations, how can we build a dataset class from the dataset files?**

> Hints:
>   - *Understand the big picture. (Dataflow)*
>   - *What information is needed to calculate our loss?*
>   - *What should be returned in the `__getitem__` method?*
>   - *The outputs should be of Tensor type*
   
***References from pytorch [tutorials](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html)***

`torch.utils.data.Dataset` is an abstract class representing a dataset. Your custom dataset should inherit `Dataset` and override the following methods:

- `__len__` so that `len(dataset)` returns the size of the dataset.

- `__getitem__` to support indexing, such that `dataset[i]` returns the i-th sample.
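
As a minimal sketch of these two methods, the toy dataset below wraps in-memory lists. The class name `ToySequenceDataset` and its data are hypothetical; the real `InkmlDataset` reads InkML files and annotation labels instead.

```python
import torch
from torch.utils.data import Dataset


class ToySequenceDataset(Dataset):
    """Hypothetical dataset of variable-length sequences with integer labels."""

    def __init__(self, sequences, labels):
        assert len(sequences) == len(labels)
        self.sequences = sequences
        self.labels = labels

    def __len__(self):
        # Number of samples in the dataset.
        return len(self.sequences)

    def __getitem__(self, idx):
        x = torch.tensor(self.sequences[idx], dtype=torch.float32)
        y = torch.tensor(self.labels[idx], dtype=torch.long)
        # Return the lengths too, since CTC loss needs them after padding.
        return x, y, torch.tensor(len(x)), torch.tensor(len(y))


toy = ToySequenceDataset(
    sequences=[[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6]]],
    labels=[[1, 2], [3]],
)
print(len(toy))
x, y, x_len, y_len = toy[0]
print(x.shape, y.shape, x_len.item(), y_len.item())
```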


### Feature representation

In this task, we will convert the strokes into feature representations that can be used by the model.

Initially, the `.getTraces()` method returns a list of strokes, where each stroke is a list of (x, y) coordinates. Intuitively, we could use those coordinates as features. However, if we feed the raw coordinates to the model directly, it will have a hard time learning the patterns.

Instead, we can calculate the difference *(∆d)* between consecutive coordinates as features. This way, the model can learn the patterns more easily.

$\Delta_x = x_{i+1} - x_i$ and $\Delta_y = y_{i+1} - y_i$

The feature would then be normalized as

($\frac{\Delta_x}{d}$ , $\frac{\Delta_y} {d}$), where $d = \sqrt{\Delta_x^2 + \Delta_y^2}$

**Pen-up and Pen-down**

In the dataset, each stroke is separated by a pen-up event. We can use this information to separate the strokes.

1. The pen is lifted from the paper (connecting the end of a stroke to the start of a stroke): $pen\_up = 1$
2. The pen is on the paper: $pen\_up = 0$

Then, our feature representation would be:
($\frac{\Delta_x}{d}$ , $\frac{\Delta_y} {d}$, $d$, $pen\_up$)


```
point 1: (x1, y1)
point 2: (x2, y2)
point 3: (x3, y3) <--- end of stroke #1
point 4: (x4, y4) <--- start of stroke #2
...
point n: (xn, yn)
```

The feature representation will be:

```
f1 = ((x2 - x1) / d, (y2 - y1)/d, d, 0)
f2 = ((x3 - x2) / d, (y3 - y2)/d, d, 0)
f3 = ((x4 - x3) / d, (y4 - y3)/d, d, 1) <-- pen up
...
fn-1 = ((xn - xn-1) / d, (yn - yn-1)/d, d, 0)
```
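
The scheme above can be traced by hand on a tiny example. The sketch below uses two made-up strokes and plain Python (no NumPy) so that every step is visible; the dataset implementation later in the notebook performs the same steps with arrays. The variable names are illustrative.

```python
import math

# Two hypothetical strokes; each stroke is a list of (x, y) points.
strokes = [[(0, 0), (3, 4), (3, 4)], [(6, 8), (6, 9)]]

points = [p for stroke in strokes for p in stroke]

# Feature i describes the move from point i to point i + 1.
# The move that leaves the last point of a stroke (except the final stroke) is a pen-up.
pen_up_moves = set()
end = 0
for stroke in strokes[:-1]:
    end += len(stroke)
    pen_up_moves.add(end - 1)

features = []
for i in range(len(points) - 1):
    dx = points[i + 1][0] - points[i][0]
    dy = points[i + 1][1] - points[i][1]
    d = math.hypot(dx, dy)
    if d == 0:
        continue  # drop repeated points to avoid dividing by zero
    features.append((dx / d, dy / d, d, 1 if i in pen_up_moves else 0))

for f in features:
    print(f)
```

Here the repeated point `(3, 4)` produces a zero-length move that is dropped, and the move from `(3, 4)` to `(6, 8)` crosses a stroke boundary, so it is the only one flagged as pen-up.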

#### **CHECK LIST**

For each data sample, we will do the following steps
- Combine all the strokes into a single stroke (N, 2)
- Compute first order differences of x and y coordinates
- Compute Euclidean distances between consecutive points
- Remove any zero length consecutive points
- Normalize the x and y coordinates by Euclidean distance
- Add feature pen-up/pen-down

> *Features of transformed data: (delta traces, distance, pen_up_down)*

Label
- Define label (list of indices of the words)

Finally, we will convert the data and label to PyTorch tensors.
- Convert data and label to PyTorch tensors

In [None]:
class InkmlDataset(Dataset):
    def __init__(self, annotation, root_dir, vocab : Vocab):
        """
        Arguments:
            annotation (string): annotation file txt.
            root_dir (string): directory holds the dataset.
            vocab (Vocab): vocabulary used to encode the labels.
        """
        self.annotation = annotation
        self.root_dir = root_dir
        self.vocab = vocab

        # load annotations and create self.inks and self.labels from the annotation file
        df = pd.read_csv(annotation, sep="\t", header=None, names=["path", "label"]).dropna().astype(str)

        self.inks = [Inkml(f"{root_dir}/crohme2019/{stroke_id}") for stroke_id in df["path"]]
        self.labels = [label_id for label_id in df["label"].apply(lambda x: x.strip().split())]

    def __len__(self):
        """This code should return the number of samples in the dataset"""
        return len(self.labels)

    def __getitem__(self, idx):
        def feature_extraction(traces):
            trace_lengths = [len(trace) for trace in traces]   # For the first example, there are 3 traces. Altogether, they contribute n_points (306) coordinate pairs (points)
            n_points = sum(trace_lengths)                      # total number of points
            feature_length = n_points - 1                      # feature length = n_points - 1

            # 1) pen-up feature
            cumsum_trace_lengths = np.cumsum(trace_lengths)    # cumulative sum of the trace lengths (the boundary between two distinct traces is where the pen is lifted)
            pen_up_index = cumsum_trace_lengths - 1            # -1 because we start counting from 0

            pen_up_feature = np.zeros(feature_length)          # initialize pen-up feature by a vector of all zeros (all strokes are continuous)
            pen_up_feature[pen_up_index[:-1]] = 1              # except for the ones that mark the end and start of two distinct traces

            # 2) delta_x, delta_y
            # combine all the strokes into a single stroke (N, 2)
            combined_trace = np.concatenate(traces, axis=0)             # combine every stroke into one big array of shape (306, 2), 306 coordinate pairs
            deltas = combined_trace[1:] - combined_trace[:-1]           # (delta x, delta y) of shape (305, 2), one less since we're taking the difference

            # 3) compute Euclidean distances between consecutive points
            deltas_squared = deltas ** 2    # ( delta_x ^ 2, delta_y ^ 2)
            euclidean_distances = np.sqrt(np.sum(deltas_squared, axis=1, keepdims=True)) # e_d = sqrt( delta_x ^ 2 + delta_y ^ 2),  e_d.shape = (305, 1)

            # shape: deltas(305, 2), euclidean_distances(305,1), pen_up_feature(305,)
            # shape: feature(305, 4) T = 305 (each timestep has 4 features, unrelated to C), N = 1 (collate_fn will allow for greater n)

            # 4) merge all the features (delta_x, delta_y, d, pen_up)
            feature = np.concatenate([deltas, euclidean_distances, pen_up_feature[:, np.newaxis]], axis=1)      # np.newaxis turns (305,) into (305, 1), needed for concat

            # 5) remove any zero-length steps between consecutive points (euclidean distance = 0)
            feature_without_zeros = feature[np.where(feature[:, 2]  != 0)]                                      # spoiler: no useless feature for first example

            # 6) normalize the delta_x and delta_y by Euclidean distance d
            # (delta_x, delta_y, d, pen_up) ->  (delta_x/d, delta_y/d, d, pen_up)
            # indexing a single column returns a 1D array of shape (305,); np.newaxis reshapes it to (305, 1) so it broadcasts over both delta columns
            feature_without_zeros[:, :2] /= feature_without_zeros[:, 2][:, np.newaxis] # No ZeroDivision, crisis averted!
            feature = feature_without_zeros
            return feature

        """This code should return the idx-th sample in the dataset"""

        ink = self.inks[idx]
        label = self.labels[idx]

        # Extract tensor_data
        traces = ink.getTraces()
        input_tensor = torch.tensor(feature_extraction(traces), dtype=torch.float32)
        input_length = torch.tensor(len(input_tensor), dtype=torch.long)

        # Extract label data
        target_tensor = torch.tensor(self.vocab.encode(label), dtype=torch.long)
        target_length = torch.tensor(len(target_tensor), dtype=torch.long)

        return input_tensor, target_tensor, input_length, target_length




In [None]:
# Testing dataset
dataset = InkmlDataset(annotation=f"{ROOT_DIR}/crohme2019_test.txt", root_dir=ROOT_DIR, vocab=Vocab("vocab.json"))
input_tensor, target_tensor, input_length, target_length = dataset[0]

target_length

## Task 3: Build Lightning Data Module via Dataloader

Review [this tutorial](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#preparing-your-data-for-training-with-dataloaders) to see how you can prepare your data for training with DataLoader.

Then refer to this [documentation](https://pytorch.org/docs/stable/data.html#loading-batched-and-non-batched-data) to understand how DataLoader in PyTorch loads Batched or Non-Batched data.

> **TASK:** Write the dataloader with custom collate function to pad the input and target sequence

Deep learning models consume batches of uniform shape, so variable-length sequences must be formatted and padded consistently before training. Learn more about this topic [here](https://plainenglish.io/blog/understanding-collate-fn-in-pytorch-f9d1742647d3).

> _**Hint 1**: use `torch.nn.utils.rnn.pad_sequence` to pad the input and target sequences. Read more about this function [here](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_sequence.html)._

> _**Hint 2**: read Example 3 from this [tutorial](https://www.programiz.com/python-programming/methods/built-in/zip) to unpack data using `zip()`._
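
The two hints can be combined as in the sketch below, which pads a hypothetical batch of two samples. The tensors are made-up stand-ins for the `(features, label, feature_length, label_length)` tuples a dataset might return; this is one possible way to structure a collate function, not the required solution.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two hypothetical samples: (features, label, feature_length, label_length).
batch = [
    (torch.ones(3, 4), torch.tensor([5, 7]), torch.tensor(3), torch.tensor(2)),
    (torch.ones(5, 4), torch.tensor([9, 2, 4]), torch.tensor(5), torch.tensor(3)),
]

# zip(*batch) regroups the samples field by field.
features, labels, feature_lens, label_lens = zip(*batch)

# Pad every sequence with zeros up to the longest one in the batch.
padded_features = pad_sequence(list(features), batch_first=True, padding_value=0)
padded_labels = pad_sequence(list(labels), batch_first=True, padding_value=0)

# Stack the scalar length tensors into 1D tensors of shape (batch,).
feature_lens = torch.stack(feature_lens)
label_lens = torch.stack(label_lens)

print(padded_features.shape)  # (batch, longest sequence, feature size)
print(padded_labels)
print(feature_lens, label_lens)
```

The original lengths are kept alongside the padded tensors because CTC loss uses them to ignore the padded positions.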

In [None]:
def collate_fn(batch):
    # Unpack the batch into separate lists  (for the first 2 samples)
    input_tensors = [sample[0] for sample in batch]
    label_tensors = [sample[1] for sample in batch]        # the 2nd sample has greater length (567)
    input_lens = [sample[2] for sample in batch]
    label_lens = [sample[3] for sample in batch]

    # Padding with zero
    padded_features = pad_sequence(input_tensors, batch_first=True, padding_value=0)   # shape (B, T_max, 4), e.g. (2, 567, 4): batch size, longest input length, features per step
    padded_labels = pad_sequence(label_tensors, batch_first=True, padding_value=0)     # shape (B, S_max), e.g. (2, 35): batch size, longest label length

    # Stack input_lens
    input_lens = torch.stack(input_lens)
    label_lens = torch.stack(label_lens)


    # features, labels, input_lens, label_lens should be torch.tensor
    return padded_features, padded_labels, input_lens, label_lens

In [None]:
dataset = InkmlDataset(annotation=f"{ROOT_DIR}/crohme2019_test.txt", root_dir=ROOT_DIR, vocab=Vocab("vocab.json"))
features, labels, input_lens, label_lens = collate_fn([dataset[0], dataset[1]])

import numpy.testing as npt

assert type(input_lens) == torch.Tensor
assert type(label_lens) == torch.Tensor

> **TASK:** implement InkmlDataset_PL Lightning Datamodule

In [None]:
class InkmlDataset_PL(pl.LightningDataModule):
    """
    PyTorch Lightning data module for handling the INKML dataset.
    """

    def __init__(
        self,
        batch_size: int = 10,
        workers: int = 5,
        train_data: str = "",
        val_data: str = "",
        test_data: str = "",
        root_dir: str = ROOT_DIR,
        vocab_file='vocab.json',
    ):
        super().__init__()
        self.batch_size = batch_size
        self.workers = workers
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.root_dir = root_dir
        self.vocab = Vocab(vocab_file)

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            self.train_dataset = InkmlDataset(self.train_data, root_dir=self.root_dir, vocab=self.vocab)
            self.val_dataset = InkmlDataset(self.val_data, root_dir=self.root_dir, vocab=self.vocab)
        if stage == "validate" or stage is None:
            self.val_dataset = InkmlDataset(self.val_data, root_dir=self.root_dir, vocab=self.vocab)
        if stage == "test" or stage is None:
            self.test_dataset = InkmlDataset(self.test_data, root_dir=self.root_dir, vocab=self.vocab)

    @staticmethod
    def custom_collate_fn(batch):
        # Unpack the batch into separate lists  (example: for the first 2 samples)
        input_tensors = [sample[0] for sample in batch]
        label_tensors = [sample[1] for sample in batch]        # the 2nd sample has greater length (567)
        input_lens = [sample[2] for sample in batch]
        label_lens = [sample[3] for sample in batch]

        # Padding with zero
        padded_features = pad_sequence(input_tensors, batch_first=True, padding_value=0)   # shape (B, T_max, 4): batch size, longest input length in the batch, features per step
        padded_labels = pad_sequence(label_tensors, batch_first=True, padding_value=0)     # shape (B, S_max): batch size, longest label length in the batch

        # Stack input_lens
        input_lens = torch.stack(input_lens)
        label_lens = torch.stack(label_lens)

        # tuple of four torch tensors
        return padded_features, padded_labels, input_lens, label_lens

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.workers,
            collate_fn=self.custom_collate_fn,
            shuffle = True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.workers,
            collate_fn=self.custom_collate_fn,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.workers,
            collate_fn=self.custom_collate_fn,
        )

## Task 4: Build Model

In this task, we will build a model that will be used to train the dataset.
The model will be a simple RNN built from bidirectional LSTM layers. It takes the input features from the dataset, outputs the predicted sequence of symbols and relations, and is trained with CTC loss.

The model will be built using PyTorch and will use the following architecture:


> _We highly recommend you to read through this [tutorial on Creating a model using Pytorch](https://pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html) and understand the basic building blocks of a model in PyTorch._

> _**ATTENTION** Below is the architecture for the model. You must **STRICTLY** follow the architecture and the parameters mentioned below._
> | Layer Type         | Configuration                              |
> |--------------------|--------------------------------------------|
> | Input              | Sequence of vectors with `input_size=4`    |
> | LSTM Layer         | `hidden_size=256`, `num_layers=2`          |
> |                    | `batch_first=True`                         |
> |                    | `bidirectional=True`                       |
> | LSTM Output        | Output shape: `(batch_size, seq_len, hidden_size*2)`|
> | Fully Connected    | `Linear(hidden_size*2, num_classes)`|
> | Activation         | `LogSoftmax(dim=...)`|
> | Output             | `(batch_size, seq_len, num_classes)`|


In [None]:
class LSTM_TemporalClassification(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(LSTM_TemporalClassification, self).__init__()

        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            batch_first=True,
                            bidirectional=True)

        self.fc = nn.Linear(hidden_size*2, num_classes)

        '''LSTM records how the sequence changes over time and builds a context-aware hidden state
           Linear translates LSTM's outputs into class scores
        '''

        self.activation = torch.nn.LogSoftmax(dim=2)       # convert class scores into log probabilities (on third dim since that's where the class scores are)

    def forward(self, x):
        lstm_out, _ = self.lstm(x)  # LSTM returns (output, (h_n, c_n)); output has shape (B, Tmax, hidden_size*2), e.g. 256*2=512 features
        x = self.fc(lstm_out)       # Linear output: tensor of shape (B, Tmax, C) = (batch_size, longest_seq, num_classes)
        x = self.activation(x)      # LogSoftmax over dim=2 (the class dimension) yields log-probabilities
        return x

In [None]:
# Test your implementation
model = LSTM_TemporalClassification(4, 256, 2, 109)
assert model.forward(torch.rand((10, 100, 4))).shape == (10, 100, 109)

## Task 5: Understand CTC Loss

In this task, we will understand how to use CTC loss to train the model. The CTC loss is used to train the model to predict the sequence of symbols and relations from the input sequence of features.

> _For a deeper understanding of CTC loss, read this [blog](https://distill.pub/2017/ctc/)._

Check out this [Pytorch documentation](https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html) to learn more about its implementation.

> _Hint: read the expected input and output of the CTC loss function. The input should be of shape `(T, N, C)` where T is the length of the input sequence, N is the batch size, and C is the number of classes. The target should be of shape `(N, S)` where S is the length of the target sequence._

After understanding how CTC Loss works, we can proceed to implement it in our model.
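
As a quick check of the shapes described in the hint, the sketch below calls `nn.CTCLoss` on random data. All sizes and tensors here are made-up assumptions chosen for illustration; only the `(T, N, C)` input layout, the `(N, S)` target layout, and the two length tensors come from the PyTorch documentation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, N, C, S = 50, 2, 109, 10   # time steps, batch size, classes, max target length

# Model-style output: (N, T, C) log-probabilities, as produced with batch_first=True
log_probs_btc = torch.randn(N, T, C).log_softmax(dim=2)

# nn.CTCLoss expects (T, N, C), so swap the first two dimensions
log_probs_tbc = log_probs_btc.permute(1, 0, 2)

# Targets use indices 1..C-1; index 0 is reserved for the blank (the CTCLoss default)
targets = torch.randint(low=1, high=C, size=(N, S))
input_lengths = torch.tensor([50, 42])   # unpadded length of each input sequence
target_lengths = torch.tensor([10, 7])   # unpadded length of each target sequence

criterion = nn.CTCLoss(blank=0)
loss = criterion(log_probs_tbc, targets, input_lengths, target_lengths)
print(loss.shape, loss.item())           # a scalar tensor with the default 'mean' reduction
```

Entries beyond `input_lengths` and `target_lengths` are ignored by the loss, which is why padded batches can be passed in directly as long as the true lengths are supplied.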

In [None]:
class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, vocab):
        super().__init__()
        self.vocab = [char for char in vocab.char2idx.keys()]     # Grab list of vocabs (the keys of the dict)
        self.blank = int(vocab.char2idx[""])                      # Blank is the first elem of dict

    def forward(self, emission: torch.Tensor) -> list:
        """Given a sequence emission over labels, get the best path
        Args:
          emission (Tensor): Logit tensors. Shape `[seq_len, num_label]`.

        Returns:
          list: The resulting transcript
        """
        # Best path: most likely class per timestep, merge repeats, then drop blanks
        indices = torch.argmax(emission, dim=-1)
        indices = torch.unique_consecutive(indices, dim=-1)
        indices = [i for i in indices.tolist() if i != self.blank]
        # Return the tokens directly. The "|" replacement in the ASR tutorial handles a word
        # separator and would wrongly delete a "|" symbol if the vocabulary contains one.
        output_seq_list = [self.vocab[i] for i in indices]
        return output_seq_list

In [None]:
def edit_distance(pred_seq: list, label_seq: list):
    # TODO: implement Token Edit distance
    m = len(pred_seq)
    n = len(label_seq)

    # Create a dp table to store results of subproblems
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    # Base cases
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j

    # Fill the rest of dp[][]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if pred_seq[i - 1] == label_seq[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = 1 + min(dp[i][j - 1], dp[i - 1][j], dp[i - 1][j - 1])
    return dp[m][n]


## Task 6: Build Lightning Module

In this task, we will build a Lightning module that will be used to train the model. The Lightning module will mostly be used to define the training and validation steps, as well as the optimizer and learning rate scheduler.

> _More on building a Lightning module can be found [here](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html). Read about the core methods to learn what to implement in your module._

In [None]:
class MathOnlineModel(pl.LightningModule):
    def __init__(
        self,
        lr=0.001,
        input_size=4,       # number of features (delta_x, delta_y, ed, penup)
        output_size=109,    # number of classes (words in Vocab)
        hidden_size=256,
        num_layers=2,
        decoder=None,
        vocab:Vocab=None
    ):
        super().__init__()
        self.model = LSTM_TemporalClassification(
            input_size, hidden_size, num_layers, output_size
        )
        self.criterion = nn.CTCLoss()
        self.lr = lr
        self.decoder = decoder
        self.vocab = vocab
        self.RELATION_TOKENS=['Right', 'Above', 'Below', 'Inside', 'Sub', 'Sup', 'NoRel']
        self.hyperparamater = 0.1       # lambda: weight of the constraint loss in the total loss

    def forward(self, x):
        return self.model(x)

    def get_relation_indices(self):
        relation_indices = self.vocab.encode(self.RELATION_TOKENS)
        return relation_indices

    def constraint_loss(self, output, features, feature_lens):
        # Constraint loss to penalize relation outputs during pen-down timesteps (relations should only appear at pen-up)
        B, Tmax, C = output.shape
        total_loss = 0.0
        valid_samples = 0

        # Define relation token indices
        relation_indices = self.get_relation_indices()

        for i in range(B):
            real_timestep_length = feature_lens[i].item()
            if real_timestep_length <= 0 or real_timestep_length > Tmax:       # Invalid input length
                continue

            # Get pen states (pen_down=0, pen_up=1)
            pen_states = features[i, :real_timestep_length, 3]                  # shape(T, )    # T: original (unpadded) timestep
            pen_down_mask = (pen_states == 0).float()                           # pen down = 0

            # Get model output for this sample
            output_sample = output[i, :real_timestep_length, :]                 # shape (seq_len, vocab_size)

            # Convert from log probs to normal probs
            probs = output_sample.exp()                                         # the model already applies LogSoftmax, so exp() recovers the per-timestep probabilities

            # Calculate p_rel: sum of probabilities for relation tokens at each timestep
            p_rel = torch.zeros(real_timestep_length, device=output.device)
            for rel_idx in relation_indices:
                p_rel += probs[:, rel_idx]                                      # Add probabilities for this relation token across all timesteps

            # Calculate constraint loss for this sample
            # L_constraint = -log(1 - sum(p_rel * pen_down_mask))
            weighted_p_rel = p_rel * pen_down_mask  # Zero out pen-up timesteps
            sum_weighted_p_rel = torch.sum(weighted_p_rel)

            # Add small epsilon to avoid log(0)
            epsilon = 1e-8
            sample_loss = -torch.log(torch.clamp(1 - sum_weighted_p_rel, min=epsilon))

            total_loss += sample_loss
            valid_samples += 1
        # Return average loss across valid samples
        return total_loss / valid_samples if valid_samples > 0 else torch.tensor(0.0, device=output.device)

    def training_step(self, batch, batch_idx):                                  # <-- This 'batch' is the RETURN VALUE from collate_fn
        features, labels, feature_lens, label_lens = batch                      # Every batch contains 4 tensors

        output = self.forward(features)                                         # Train on all padded tensors of x
        loss_constraint = self.constraint_loss(output, features, feature_lens)  # Constraint loss (do this before permuting dims)

        output = torch.permute(output, (1, 0, 2))                               # Permute dimensions for CTC Loss Calculation (B, Tmax, C) --> (Tmax, B, C)
        loss = self.criterion(output, labels, feature_lens, label_lens)         # pass the unpadded input and label lengths so CTC ignores the padding

        # Combine constraint loss with CTCLoss for total loss
        total_loss = loss + self.hyperparamater * loss_constraint

        # Logging
        self.log("train_loss_ctc", loss, on_step=True, on_epoch=True)
        self.log("train_loss_constraint", loss_constraint, on_step=True, on_epoch=True)
        self.log("train_loss_total", total_loss, on_step=True, on_epoch=True)

        return total_loss

    def validation_step(self, batch, batch_idx):
        features, labels, feature_lens, label_lens = batch

        # Calculate constraint loss before permuting dims
        output = self.forward(features)
        loss_constraint = self.constraint_loss(output, features, feature_lens)

        # Calculate CTCLoss
        output = torch.permute(output, (1, 0, 2))
        loss = self.criterion(output, labels, feature_lens, label_lens)

        # Combine constraint loss with CTCLoss for total loss
        total_loss = loss + self.hyperparamater * loss_constraint

        # Handling WER
        if (self.decoder is not None and self.vocab is not None):
            total_wer = 0
            total_sym_wer = 0
            total_rel_wer = 0

            # Count samples that actually have symbols/relations
            valid_samples = 0
            valid_sym_samples = 0
            valid_rel_samples = 0

            batch_size = output.size(1)

            for i in range(batch_size):
                # Get prediction for this sample
                real_timestep_length = feature_lens[i].item()
                sample_output = output[:real_timestep_length, i, :]         # Since dims have been swapped, indexing must also follow suit
                decoded_pred = self.decoder(sample_output)

                # Decode labels
                real_label_length = label_lens[i].item()
                sample_label = labels[i][:real_label_length].tolist()
                decoded_label = self.vocab.decode(sample_label)

                # Calculate overall WER (avoid division by zero)
                if len(decoded_label) > 0:
                    sample_wer = edit_distance(decoded_pred, decoded_label) / len(decoded_label)
                    total_wer += sample_wer
                    valid_samples += 1

                # Separate symbols and relations
                symbol_pred = [pred for pred in decoded_pred if pred not in self.RELATION_TOKENS]
                relation_pred = [pred for pred in decoded_pred if pred in self.RELATION_TOKENS]

                symbol_label = [label for label in decoded_label if label not in self.RELATION_TOKENS]
                relation_label = [label for label in decoded_label if label in self.RELATION_TOKENS]

                # Calculate symbol WER
                if len(symbol_label) > 0:
                    sym_sample_wer = edit_distance(symbol_pred, symbol_label) / len(symbol_label)
                    total_sym_wer += sym_sample_wer
                    valid_sym_samples += 1

                # Calculate relation WER
                if len(relation_label) > 0:
                    rel_sample_wer = edit_distance(relation_pred, relation_label) / len(relation_label)
                    total_rel_wer += rel_sample_wer
                    valid_rel_samples += 1

            # Calculate averages only over valid samples
            avg_wer = total_wer / valid_samples if valid_samples > 0 else 0.0
            avg_sym_wer = total_sym_wer / valid_sym_samples if valid_sym_samples > 0 else 0.0
            avg_rel_wer = total_rel_wer / valid_rel_samples if valid_rel_samples > 0 else 0.0

            # Log metrics
            self.log("val_wer", avg_wer, prog_bar=True, on_step=False, on_epoch=True)
            self.log("sym_wer", avg_sym_wer, prog_bar=True, on_step=False, on_epoch=True)
            self.log("rel_wer", avg_rel_wer, prog_bar=True, on_step=False, on_epoch=True)

        self.log("val_loss", total_loss, prog_bar=True, on_step=False, on_epoch=True)
        return total_loss

    def test_step(self, batch, batch_idx):
        features, labels, feature_lens, label_lens = batch

        # Calculate constraint loss before permuting dims
        output = self.forward(features)
        loss_constraint = self.constraint_loss(output, features, feature_lens)

        # Calculate CTCLoss
        output = torch.permute(output, (1, 0, 2))
        loss = self.criterion(output, labels, feature_lens, label_lens)

        # Combine constraint loss with CTCLoss for total loss
        total_loss = loss + self.hyperparamater * loss_constraint

        self.log("test_loss", total_loss, prog_bar=True, on_step=False, on_epoch=True)
        return total_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

## Task 7: Train the Model with PyTorch Lightning

Read more about Training with PyTorch Lightning [here](https://lightning.ai/docs/pytorch/stable/common/trainer.html) and understand how to use the Trainer class to train the model.

From the documentation:

> The Lightning Trainer does much more than just “training”. Under the hood, it handles all loop details for you, some examples include:
> - Automatically enabling/disabling grads
> - Running the training, validation and test dataloaders
> - Calling the Callbacks at the appropriate times
> - Putting batches and computations on the correct devices

> **IMPORTANT**: You must configure your WandB logger to log the training and validation metrics. Without this, your work will not be graded.

In [None]:
student_id = "10423043"  # TODO: replace with your student ID
api_key = os.environ.get("WANDB_API_KEY", "")  # set WANDB_API_KEY in the environment; do not hard-code the key in the notebook

if api_key == "":
    raise ValueError("Please set your wandb key in the code or in the environment variable WANDB_API_KEY")
else:
    print("WandB API key is set. Proceeding with login...")

wandb.login(key=api_key)

WandB API key is set. Proceeding with login...

wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: 10423043 (10423043-vietnamese-german-university) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


True

In [None]:
wandb_logger = WandbLogger(
    entity="cuong-nt-vgu-ai-2025",  # DO NOT CHANGE THIS
    project="math_online_2025", # DO NOT CHANGE THIS
    name=f"{student_id}_run_1",
    config={
        "student_id": student_id,  # DO NOT CHANGE THIS
        "model": "LSTM_TemporalClassification",
        "lr": 0.001,
        "batch_size": 32,
        "hidden_size" : 256,
        "num_layers" : 2
    },
    log_model=True,
    save_dir="wandb_logs",
)

trainer = Trainer(
    callbacks = [
        LearningRateMonitor(logging_interval='step'),
        ModelCheckpoint(filename='{epoch}-{val_wer:.4f}', save_top_k=5, monitor='val_wer', mode='min'),
    ],
    logger = wandb_logger,
    check_val_every_n_epoch=1,
    limit_train_batches=2,  # Only run 2 batches for testing
    limit_val_batches=1,    # Only run 1 validation batch
    fast_dev_run=True,      # enable for testing model
    default_root_dir='checkpoint',
    deterministic=False,
    max_epochs=50,
    log_every_n_steps=20,
    devices = "auto",
)

# model = MathOnlineModel(decoder=GreedyCTCDecoder(vocab), vocab=vocab)
model = MathOnlineModel.load_from_checkpoint("wandb_logs/math_online_2025/ll1zayhv/checkpoints/epoch=26-val_wer=0.1401.ckpt")  # forward slashes work on both Windows and Linux (Colab)

# Set decoder and vocab after loading
model.decoder = GreedyCTCDecoder(vocab)
model.vocab = Vocab('vocab.json')



dm = InkmlDataset_PL(train_data=f"{ROOT_DIR}/crohme2019_train.txt",
                     val_data=f"{ROOT_DIR}/crohme2019_valid.txt",
                     test_data=f"{ROOT_DIR}/crohme2019_test.txt",
                     vocab_file='vocab.json',
                     batch_size=32,
                     workers=0
                     )

trainer.fit(model, dm)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.


FileNotFoundError: [Errno 2] No such file or directory: '/content/wandb_logs\\math_online_2025\\ll1zayhv\\checkpoints\\epoch=26-val_wer=0.1401.ckpt'

**Requirements:**
- Validation loss (val_loss < 0.7)

**Tips on training:**
- Debug with **fast_dev_run**:
Before starting a full training run, do a quick check of the training and validation loops. Set `fast_dev_run=True`, run the trainer, and confirm that no errors occur.
- Start training with a **small batch size**:
In practice, training with CTC loss converges slowly with a large batch size. To speed up convergence, train with a small batch size first, save the model, then continue training with a larger batch size.
- Train for more epochs by setting:

```python
trainer = Trainer(
    ...
    max_epochs=20,
    ...
)
```

## Task 8: Test your model

Run the test set and check the accuracy of your model. The test set is used to evaluate the performance of the model on unseen data.

In [None]:
trainer = Trainer(
    devices=1,
)

# Load the model from a checkpoint
model = MathOnlineModel.load_from_checkpoint(
    "wandb_logs/math_online_2025/ll1zayhv/checkpoints/epoch=26-val_wer=0.1401.ckpt",  # forward slashes work on both Windows and Linux (Colab)
)

# Set decoder and vocab after loading
model.decoder = GreedyCTCDecoder(vocab)
model.vocab = Vocab("vocab.json")

# Initialize the data module
dm = InkmlDataset_PL(train_data=f"{ROOT_DIR}/crohme2019_train.txt",
                     val_data=f"{ROOT_DIR}/crohme2019_valid.txt",
                     test_data=f"{ROOT_DIR}/crohme2019_test.txt",
                     vocab_file='vocab.json',
                     batch_size=32,
                     workers=0
                     )

# Test the model
trainer.test(model, datamodule=dm)

## Task 9: Inference

The network output must be passed through a decoding step.

- Greedy decoding: converts the output into a sequence of symbols and relations (the same form as the labels). The greedy decoder takes the most likely class at each timestep (the best path), merges consecutive repeated symbols/relations, and then removes the \<blank\> token.

> **TASK**: Implement the greedy decoder for the model output.

Based on the GreedyCTCDecoder class from this [link](https://pytorch.org/audio/main/tutorials/asr_inference_with_ctc_decoder_tutorial.html#greedy-decoder), write the decoding for a single output. Your implementation should convert the model output into a sequence of symbols and relations, merging consecutive repeated symbols/relations and removing the blank token.
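
The collapse rule can be traced on a toy example without any model. The sketch below works on a plain list of per-timestep argmax indices; the function name `greedy_ctc_collapse` and the five-entry vocabulary are hypothetical and only serve the illustration.

```python
from itertools import groupby

def greedy_ctc_collapse(best_path, blank=0):
    """Collapse a per-timestep argmax path: merge repeats, then drop blanks."""
    merged = [idx for idx, _ in groupby(best_path)]
    return [idx for idx in merged if idx != blank]

# Hypothetical vocabulary for illustration only; index 0 is the blank
toy_vocab = ["", "x", "Sup", "2", "Right"]

# Argmax index at each of 10 timesteps: x x _ Sup Sup _ 2 2 2 _
best_path = [1, 1, 0, 2, 2, 0, 3, 3, 3, 0]
tokens = [toy_vocab[i] for i in greedy_ctc_collapse(best_path)]
print(tokens)  # ['x', 'Sup', '2']

# A blank between two identical indices keeps both copies
print(greedy_ctc_collapse([3, 0, 3]))  # [3, 3]
```

The order of the two steps matters: merging repeats before removing blanks is what allows the same symbol to appear twice in a row in the decoded output.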

In [None]:
class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, vocab):
        super().__init__()
        self.vocab = [char for char in vocab.char2idx.keys()]     # Grab list of vocabs (the keys of the dict)
        self.blank = int(vocab.char2idx[""])                      # Blank is the first elem of dict

    def forward(self, emission: torch.Tensor) -> list:
        """Given a sequence emission over labels, get the best path
        Args:
          emission (Tensor): Logit tensors. Shape `[seq_len, num_label]`.

        Returns:
          list: The resulting transcript
        """
        # Best path: most likely class per timestep, merge repeats, then drop blanks
        indices = torch.argmax(emission, dim=-1)
        indices = torch.unique_consecutive(indices, dim=-1)
        indices = [i for i in indices.tolist() if i != self.blank]
        # Return the tokens directly. The "|" replacement in the ASR tutorial handles a word
        # separator and would wrongly delete a "|" symbol if the vocabulary contains one.
        output_seq_list = [self.vocab[i] for i in indices]
        return output_seq_list

Then you can test the output of your model here

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MathOnlineModel.load_from_checkpoint("wandb_logs/math_online_2025/ll1zayhv/checkpoints/epoch=26-val_wer=0.1401.ckpt")  # forward slashes work on both Windows and Linux (Colab)
model.to(device)
model.eval()

# Load vocabulary and decoder
vocab = Vocab('vocab.json')
greedy_decoder = GreedyCTCDecoder(vocab)

dataset = InkmlDataset(f"{ROOT_DIR}/crohme2019_test.txt", ROOT_DIR, vocab)
feature, label, input_len, label_len = dataset[0]

# Extract some important data
feature = feature.to(device)
decoded_labels = vocab.decode(label.tolist())

with torch.no_grad():
    emission = model(feature.unsqueeze(0))

# Decode the model output
decoded_output = greedy_decoder(emission[0])  # emission[0] has shape (T, C) and holds log-probabilities
print(decoded_output)

# possible output if your training worked well
# decoded -> ['\\phi', 'Right', '(', 'Right', '0', 'Right', '(', 'Right', 'n', 'Right', ')', 'Right', ')']


## Task 10: Implement calculation metric for training

In [None]:
def edit_distance(pred_seq: list, label_seq: list):
    # TODO: implement Token Edit distance
    m = len(pred_seq)
    n = len(label_seq)

    # Create a dp table to store results of subproblems
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    # Base cases
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j

    # Fill the rest of dp[][]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if pred_seq[i - 1] == label_seq[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = 1 + min(dp[i][j - 1], dp[i - 1][j], dp[i - 1][j - 1])
    return dp[m][n]


wer = edit_distance(decoded_output, decoded_labels) / label_len.item()

print(wer)
# Test your implementation
assert (
    edit_distance(
        [
            "\\phi",
            "Right",
            "(",
            "Right",
            "0",
            "Right",
            "(",
            "Right",
            "n",
            "Right",
            ")",
            "Right",
            ")",
        ],
        [
            "\\phi",
            "Right",
            "(",
            "Right",
            "\\phi",
            "Right",
            "(",
            "Right",
            "n",
            "Right",
            ")",
            "Right",
            ")",
        ],
    )
    == 1
)

> **TASK**: Implement the word error rate (WER) metric for training, validation, and testing in your model.
>
> $\mathrm{WER} = \frac{\text{total edit distance}(\text{predicted sequences},\ \text{target sequences})}{\text{total target sequence length}}$


**Steps to process:**

- Decode the predicted sequences to obtain the text output.
- Calculate the total edit distance between the predicted and target sequences.
- Compute the word error rate using the formula provided.
- Log the WER metric during training and validation.
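
The steps above can be sketched on already-decoded token lists as follows. This is a minimal illustration under stated assumptions: the helper names `token_edit_distance` and `corpus_wer` and the example sequences are made up, and the WER here pools edit distances and lengths over all samples, as the formula does, rather than averaging per-sample ratios.

```python
def token_edit_distance(pred, target):
    """Levenshtein distance between two token lists (single rolling row)."""
    prev = list(range(len(target) + 1))
    for i, p in enumerate(pred, start=1):
        curr = [i]
        for j, t in enumerate(target, start=1):
            cost = 0 if p == t else 1
            curr.append(min(prev[j] + 1,          # delete p
                            curr[j - 1] + 1,      # insert t
                            prev[j - 1] + cost))  # match or substitute
        prev = curr
    return prev[-1]

def corpus_wer(preds, targets):
    """Total edit distance divided by total target length."""
    total_dist = sum(token_edit_distance(p, t) for p, t in zip(preds, targets))
    total_len = sum(len(t) for t in targets)
    return total_dist / total_len if total_len > 0 else 0.0

# Made-up predictions and targets, for illustration only
preds = [["x", "Sup", "2"], ["a", "Right", "+", "Right", "b"]]
targets = [["x", "Sup", "3"], ["a", "Right", "+", "Right", "b"]]
print(corpus_wer(preds, targets))  # (1 + 0) / (3 + 5) = 0.125
```

Averaging the per-sample ratios instead would give (1/3 + 0) / 2 for the same data, so the two conventions generally produce different numbers; whichever is used should be applied consistently when comparing runs.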

**Continue to train the model**

- Load the latest trained model:

```
model = MathOnlineModel.load_from_checkpoint('path/to/your/checkpoint.ckpt')
```
- Change the `ModelCheckpoint` configuration to monitor the new metric (`val_wer`) instead.

```
        ModelCheckpoint(filename='{epoch}-{val_loss:.4f}', save_top_k=5, monitor='val_loss', mode='min')
        
      --> ModelCheckpoint(filename='{epoch}-{val_wer:.4f}', save_top_k=5, monitor='val_wer', mode='min'),

```


In [None]:
'''BONUS 3: Visualize model prediction by timesteps and probability of softmax outputs'''
# Prioritize GPU over CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model from checkpoint
model = MathOnlineModel.load_from_checkpoint(
    "wandb_logs/math_online_2025/ll1zayhv/checkpoints/epoch=26-val_wer=0.1401.ckpt",  # forward slashes work on both Windows and Linux (Colab)
)
model.to(device)
model.eval()

# Prepare vocab and dataset
vocab = Vocab("vocab.json")
dataset = InkmlDataset(f"{ROOT_DIR}/crohme2019_test.txt", ROOT_DIR, vocab)

def plot_probs_over_time(sample_idx: int, tokens: list[str]):
    """
    Plot the softmax probabilities for each token in `tokens` over the time dimension
    for the example at index `sample_idx`.
    """
    # Extract features and move to device
    feature, _, _, _ = dataset[sample_idx]
    x = feature.unsqueeze(0).to(device)  # Add new batch dimension with unsqueeze, now of shape (1, T, 4)

    # Forward pass: get log-probs then convert to probabilities
    with torch.no_grad():
        probs = torch.softmax(model(x).squeeze(0), dim=-1).cpu().numpy()  # Remove batch dimension with squeeze (T, C). dim=-1 to target C (vocabulary)

    # Plotting
    timesteps = np.arange(probs.shape[0])          # Numpy array from 0 to (T-1)
    plt.figure(figsize=(12, 6))

    # Encode tokens into idx
    encoded_tokens = vocab.encode(tokens)

    # Plot
    for idx in range(len(tokens)):
        plt.plot(timesteps, probs[:, encoded_tokens[idx]], label=tokens[idx])

    plt.xlabel("Time Step")
    plt.ylabel("Probability")
    plt.title(f"Softmax Probabilities over Time (Sample {sample_idx})")
    plt.legend(loc="upper right")
    plt.tight_layout()
    plt.show()

# Example: visualize the probabilities of the tokens '2', '\\sqrt' and 'Right' on sample #0
plot_probs_over_time(0, ['2', '\\sqrt', 'Right'])

## **BONUS TASKS**

The following tasks are optional and can be done for extra credit to the final exam.

### **Bonus 1**: Add metrics to evaluate the accuracy of symbols and relations separately *(+0.5pt)*

**Add the metrics: wer for symbols and wer for relations**
```
self.log('wer_sym',...)
self.log('wer_rel',...)
```

### **Bonus 2**: Modify loss function to constrain the output of relations at the time step of pen-up *(+1pt)*
Modify the loss function so that relations are only output at pen-up timesteps.

The idea: provide a mask over the input sequence that marks the pen-down positions (pen-up positions are zeroed out). The additional loss then penalizes any relation probability assigned to a pen-down timestep.


**Loss function to constrain relation output**

```
pen_down: a masked sequence, where len(pen_down) = len(input_sequence)
        and pen_down[t] == 1 (pen down), pen_down[t] == 0 (pen up)
p_rel: total probability of relation outputs for every time step t, len(p_rel) = len(input_sequence), p_rel[t] = sum(p[t][rel] for rel in ['Sub', 'Sup', 'Above', ...])

The loss is defined as:

    L_{constraint} = -log(1 - sum(p_rel * pen_down))
    
Explanation:

    L_{constraint} takes values in [0, +inf). Minimizing it maximizes (1 - sum(p_rel * pen_down)), i.e. drives sum(p_rel * pen_down) -> 0, so relation outputs on pen_down timesteps are penalized.
```

**Apply constraint loss function with ctc loss**

```
    L = loss_{ctc} + \lambda * L_{constraint}
    
where \lambda is a weighting parameter that controls the balance between the two losses.
   
```
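
A small worked example may help to see how the constraint term behaves. The sketch below evaluates the formula on hand-picked numbers in plain Python; the function name, the `eps` clamp value, the example probabilities, and the placeholder CTC loss value are all assumptions made for illustration, not results from the trained model.

```python
import math

def constraint_loss(p_rel, pen_down, eps=1e-8):
    """-log(1 - sum_t p_rel[t] * pen_down[t]), clamped so the log stays finite."""
    mass_on_pen_down = sum(p * m for p, m in zip(p_rel, pen_down))
    return -math.log(max(1.0 - mass_on_pen_down, eps))

# Made-up 5-step sequence; pen_down[t] is 1 while the pen touches the surface
pen_down = [1, 1, 0, 1, 1]

# All relation mass sits on the pen-up step: nothing to penalize
good = constraint_loss([0.0, 0.0, 0.9, 0.0, 0.0], pen_down)

# Relation mass leaks onto pen-down steps: 0.2 + 0.3 = 0.5, so the loss is -log(0.5)
leaky = constraint_loss([0.2, 0.0, 0.5, 0.3, 0.0], pen_down)

# The summed mass can exceed 1 on longer sequences; the clamp then caps the loss at -log(eps)
saturated = constraint_loss([0.6, 0.6, 0.0, 0.6, 0.6], pen_down)

# Combined objective with a hypothetical CTC loss value and lambda = 0.1
loss_ctc, lam = 1.25, 0.1
total = loss_ctc + lam * leaky
print(good, leaky, saturated, total)
```

Because the sum runs over all timesteps, it is not bounded by 1; once the clamp is active, the term is constant and carries no gradient in an autograd implementation that uses a hard clamp, which is worth keeping in mind when choosing \lambda.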


You can also find it in a published paper here, section 3.4:
https://arxiv.org/pdf/2105.10156

### **Bonus 3**: Visualize model prediction by timesteps and probability of softmax outputs *(+0.5pt)*

![alt text](https://github.com/fuisl/crohme-ctc/blob/6722b97ec2000afcd16068220f0b1b83b3134ff8/assets/graph_with_prob.png?raw=true)

### **Bonus 4**: Use CUDA CTC Decoder to optimize decoding process in model *(+0.5pt)*

Follow the instructions via this [link](https://pytorch.org/audio/main/tutorials/asr_inference_with_ctc_decoder_tutorial.html#cuda-ctc-decoder) to implement CUDA CTC decoder in your model.