# Kanji Recognition with CNN
**Author:** Thomas K/BIDI  
**Description:** This notebook aims to train a CNN that classifies the 500 most frequently used kanji from handwritten images, using Triplet Loss.

In [None]:
import pandas as pd
import numpy as np
import os
import torch
import torchvision
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision.io import decode_image, read_file
from torchvision import transforms
import timm
import json
from PIL import Image

## Kanji Data Loading

We're going to use the dataset from [davidluzgouveia/kanji-data](https://github.com/davidluzgouveia/kanji-data).  
From it, we extract the 500 most frequent kanji, in order, together with their kun'yomi readings (stored in the `Furigana` column) and their meanings. These kanji are used to train our model.

In [None]:
with open("kanji.json", "r", encoding="utf-8") as f:
    kanji_json = json.load(f)

kanji500 = [None] * 500  # placeholders, filled by frequency rank in the next cell

In [None]:
for (key, value) in kanji_json.items():
    # "freq" is a 1-based frequency rank (None for kanji without a rank)
    if value["freq"] is not None and value["freq"] <= 500:
        kanji500[value["freq"] - 1] = key
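
The loop above assumes that every rank from 1 to 500 is present in `kanji.json`; if one were missing, a placeholder would remain in `kanji500`. A slightly more defensive variant, sketched below, sorts the entries by their `freq` rank instead. It assumes, like the loop above, that each entry has a `freq` field that is either an integer or `None`; `select_top_kanji` is a hypothetical helper name and the toy dictionary is not real data.

```python
def select_top_kanji(kanji_json, n=500):
    """Return the n kanji with the lowest (i.e. most frequent) 'freq' rank."""
    ranked = sorted(
        (value["freq"], key)
        for key, value in kanji_json.items()
        if value.get("freq") is not None  # skip kanji without a frequency rank
    )
    return [key for _, key in ranked[:n]]

# Toy dictionary with the same field layout as kanji.json (not real data)
toy_json = {
    "日": {"freq": 1},
    "一": {"freq": 2},
    "国": {"freq": 3},
    "魑": {"freq": None},
}
print(select_top_kanji(toy_json, n=2))  # ['日', '一']
```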

In [None]:
rows = []
for kanji in kanji500:
    rows.append({
        "Kanji": kanji,
        "Furigana": " ".join(kanji_json[kanji]["readings_kun"]),
        "Meaning": " | ".join(kanji_json[kanji]["meanings"])
    })
df = pd.DataFrame(rows)
df

## Triplet Creation

Let's start with a short introduction to **Triplet Loss**.  
**Triplet Loss** is used in *metric learning*, notably for *one-shot learning*. Our objective is to create *embeddings* of handwritten kanji images such that, in the *latent* space:
 - Similar images are close to each other
 - Different images are far from each other

This is where **Triplet Loss** comes in. It considers these three elements:
 - **Anchor (A)**: a reference element
 - **Positive (P)**: an example of the same class as the Anchor
 - **Negative (N)**: an example of another class

The network maps each image to its embedding.  
For an input $x$, we call $f(x) \in \mathbb{R}^d$ its embedding.

We want the following inequality to hold:  
$||f(A)-f(P)||^2 < ||f(A)-f(N)||^2$  

The objective is to make embeddings of the same class closer to each other than to those of different classes.

We introduce $\alpha > 0$, a margin that sets the minimum gap required between the anchor-negative and the anchor-positive distances.  
This leads us to the following formulation of the **Triplet Loss**:  
$\mathcal{L}(A,P,N)=\max(0, ||f(A)-f(P)||^2 - ||f(A)-f(N)||^2 + \alpha)$  

The loss reaches its minimum, zero, when $||f(A)-f(P)||^2 + \alpha \le ||f(A)-f(N)||^2$.  
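
As a small numeric check of this formula, the sketch below implements the squared-distance loss directly on toy 2-D embeddings. PyTorch's `torch.nn.TripletMarginLoss` (used later in this notebook) applies the margin to *non-squared* Euclidean distances, so the two give different values for the same inputs. `triplet_loss_squared` is a hypothetical helper, not part of PyTorch.

```python
import torch

def triplet_loss_squared(f_a, f_p, f_n, alpha=0.2):
    """Squared-distance triplet loss, averaged over the batch."""
    d_pos = (f_a - f_p).pow(2).sum(dim=1)  # ||f(A) - f(P)||^2
    d_neg = (f_a - f_n).pow(2).sum(dim=1)  # ||f(A) - f(N)||^2
    return torch.clamp(d_pos - d_neg + alpha, min=0).mean()

# Two toy triplets in R^2
f_a = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
f_p = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
f_n = torch.tensor([[2.0, 0.0], [0.5, 0.0]])

# Triplet 1: 1 - 4 + 0.2 < 0, so its loss is 0 (margin satisfied)
# Triplet 2: 1 - 0.25 + 0.2 = 0.95
# Batch mean: (0 + 0.95) / 2 = 0.475
print(triplet_loss_squared(f_a, f_p, f_n))

# PyTorch's loss uses ||.||_2 (not squared): triplet 2 gives 1 - 0.5 + 0.2 = 0.7,
# so the batch mean is about 0.35
print(torch.nn.TripletMarginLoss(margin=0.2, p=2)(f_a, f_p, f_n))
```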

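We start by generating a number of triplets (A, P, N) for each of our 500 kanji. Since the Anchor and the Positive belong to the same class (the same kanji, drawn from its set of handwritten images), we store each triplet as a couple of characters (AnchorPositive, Negative), even though we keep calling it a triplet; the actual images are sampled later, in the dataset.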

In [None]:
def generate_triplet(df, kanji_column, nb_couple_by_char=5):
    """
    @brief Generate the couples (AnchorPositive, Negative) for the Triplet Loss.
    @details This function returns a list of nb_couple_by_char couples per char, as (AnchorPositive, Negative).
    @param df The dataframe.
    @param kanji_column The name of the column with the chars.
    @param nb_couple_by_char The number of "triplets" to create for each char.
    """
    triplets = []
    chars = df[kanji_column].tolist()
    len_chars = len(chars)
    for i in range(len_chars):
        for _ in range(nb_couple_by_char):
            rand_idx = i

            # Re-draw until the Negative is a different char than the AnchorPositive
            while chars[rand_idx] == chars[i]:
                rand_idx = np.random.randint(0, len_chars)

            triplets.append((chars[i], chars[rand_idx]))
    return triplets

triplet = generate_triplet(df, "Kanji")
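
The `while` loop above re-draws until it hits a different kanji. Since the 500 kanji are distinct, an alternative that never needs a retry is to draw a random offset in `[1, n-1]` and take `(i + offset) % n`, which can never land back on index `i`. A sketch with a seeded NumPy `Generator`, assuming the characters are distinct and there are at least two of them (`generate_triplet_offsets` is a hypothetical name):

```python
import numpy as np

def generate_triplet_offsets(chars, nb_couple_by_char=5, seed=None):
    """Return (AnchorPositive, Negative) couples; assumes `chars` holds >= 2 distinct characters."""
    rng = np.random.default_rng(seed)
    n = len(chars)
    triplets = []
    for i, char in enumerate(chars):
        # Offsets are drawn in [1, n-1], so (i + offset) % n is always different from i
        offsets = rng.integers(1, n, size=nb_couple_by_char)
        for offset in offsets:
            triplets.append((char, chars[(i + offset) % n]))
    return triplets

pairs = generate_triplet_offsets(["日", "一", "国", "人"], nb_couple_by_char=3, seed=0)
print(len(pairs))  # 12
```

Passing a `seed` also makes the triplet list reproducible across runs, which the global `np.random` calls above do not guarantee.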

## Loading Handwritten Kanji images

In this work, we use the ETL Character Database provided by the Electrotechnical Laboratory (ETL) of Japan (subsets ETL1 to ETL9B).  
I cannot redistribute this data, but you can download it yourself from their [website](http://etlcdb.db.aist.go.jp).
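
The loading functions below do not parse the raw ETL files; they expect the data to have already been converted to PNG images laid out as `datasets/<ETL folder>/<subfolder>/`, where each subfolder holds a `meta.csv` with a `char` column and one image per row, named after the row index zero-padded to the number of digits of the last index. This layout is inferred from the code, not from the ETL distribution itself. The hypothetical helper below builds a tiny fake tree with that layout (folder names `ETLX`/`ETLX_1` are made up), which can help smoke-test the pipeline without the real data.

```python
import os
import tempfile

import pandas as pd
from PIL import Image

def make_fake_etl_tree(root, layout):
    """
    Create root/<etl>/<sub>/meta.csv plus one black 64x64 PNG per row.
    layout maps an ETL folder name to {subfolder name: [one char per image]}.
    """
    for etl, subs in layout.items():
        for sub, chars in subs.items():
            folder = os.path.join(root, etl, sub)
            os.makedirs(folder, exist_ok=True)
            pd.DataFrame({"char": chars}).to_csv(os.path.join(folder, "meta.csv"), index=False)
            # Filenames are zero-padded to the number of digits of the last row index
            width = len(str(len(chars) - 1))
            for idx in range(len(chars)):
                Image.new("L", (64, 64), 0).save(os.path.join(folder, f"{str(idx).zfill(width)}.png"))

root = tempfile.mkdtemp()
make_fake_etl_tree(root, {"ETLX": {"ETLX_1": ["日", "日", "一", "一", "国"]}})
print(sorted(os.listdir(os.path.join(root, "ETLX", "ETLX_1"))))
# ['0.png', '1.png', '2.png', '3.png', '4.png', 'meta.csv']
```

The root can then be passed to `get_subfolders(root + "/")`; the trailing slash matters if `get_subfolders` builds paths by string concatenation.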

In [None]:
def get_subfolders(path):
    """
    @brief Return a dict mapping each folder of path (one per ETL set) to the list of its subfolders' paths.
    @param path The path to the root directory in which the ETLX folders are located.
    """
    folders = sorted(name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name)))
    subfolders = {}
    for folder in folders:
        path_to_folder = os.path.join(path, folder)
        # Sorted so that the order (and thus seeded splits) is the same across runs
        subfolders[folder] = [os.path.join(path_to_folder, name) for name in sorted(os.listdir(path_to_folder))
                              if os.path.isdir(os.path.join(path_to_folder, name))]
    return subfolders

subfolders = get_subfolders("datasets/")

In [None]:
def get_dictionary(subfolders):
    """
    @brief Load the meta.csv file of each subfolder.
    @details The return value is a dictionary with the subfolders' paths as keys and a dataframe with the subfolder's metadata as values.
    @param subfolders The dictionary returned by get_subfolders.
    """
    dictionary = {}
    for _, subpacks in subfolders.items():
        for subfolder in subpacks:
            dictionary[subfolder] = pd.read_csv(os.path.join(subfolder, "meta.csv"))
    return dictionary

dictionary = get_dictionary(subfolders)

In [None]:
def get_kanji_paths(df, kanji_column, dictionary):
    """
    @brief Return a dictionary linking a kanji to the paths of all its handwritten images.
    @details Images are gathered from every subfolder in which the kanji appears.
    @param df The dataframe with the kanji to search for.
    @param kanji_column The column with the kanji.
    @param dictionary The dictionary with the information of each subfolder.
    """
    kanji_paths = {}
    for kanji in df[kanji_column]:
        for path, index_df in dictionary.items():
            # Row indices of this subfolder's meta.csv that correspond to the kanji
            matches = index_df.index[index_df["char"] == kanji]
            if len(matches) == 0:
                continue
            # Filenames are zero-padded to the number of digits of the last row index
            max_size = len(str(index_df.index[-1]))
            # Append (rather than reset the list) so that images from every subfolder are kept
            kanji_paths.setdefault(kanji, []).extend(
                f"{path}/{str(idx).zfill(max_size)}.png" for idx in matches
            )

    return kanji_paths
kanji_paths = get_kanji_paths(df, "Kanji", dictionary)
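
`TripletDataset` raises a `KeyError` for any of the 500 kanji that has no image in the downloaded ETL subsets, and `np.random.choice` fails on an empty list, which `split_images_paths` can produce for a kanji with a single image. One possible safeguard, sketched below, is to drop such couples before building the datasets. With the 80/20 split, at least two images per kanji is enough for both the training and test lists to be non-empty. `filter_triplets` and its `min_images` threshold are hypothetical, not part of the original pipeline.

```python
def filter_triplets(triplets, kanji_paths, min_images=2):
    """Keep only couples whose two kanji each have at least `min_images` image paths."""
    def has_enough(char):
        return len(kanji_paths.get(char, [])) >= min_images
    return [(a, n) for a, n in triplets if has_enough(a) and has_enough(n)]

# Toy data: "一" has a single image and "人" has none
toy_paths = {"日": ["a.png", "b.png"], "一": ["c.png"], "国": ["d.png", "e.png", "f.png"]}
toy_triplets = [("日", "国"), ("日", "一"), ("国", "人")]
print(filter_triplets(toy_triplets, toy_paths))  # [('日', '国')]
```

With the notebook's variables, this would be applied as `triplet = filter_triplets(triplet, kanji_paths)` before training.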

## Custom triplet dataset

We create a function that splits the image paths of each kanji between training and test sets, and a custom `TripletDataset` class that turns each (AnchorPositive, Negative) couple into an (anchor, positive, negative) triplet of images.

In [None]:
def split_images_paths(images_paths, split_ratio=0.8, seed=None):
    """
    @brief Return the train and test images_path according to the given split_ratio and seed.
    @param images_paths The dictionary linking the char (kanji) to the paths of all its handwritten images.
    @param split_ratio The desired split_ratio for training and testing.
    @param seed Optional seed for shuffling and splitting. None means random.
    """
    rng = np.random.default_rng(seed)
    train_paths, test_paths = {}, {}

    for cls, paths in images_paths.items():
        n = len(paths)
        split_idx = int(n * split_ratio)
        paths = paths.copy()
        rng.shuffle(paths)

        train_paths[cls] = paths[:split_idx]
        test_paths[cls]  = paths[split_idx:]

    return train_paths, test_paths

class TripletDataset(Dataset):
    def __init__(self, triplet, images_paths, transform=None):
        # Triplets are represented as couples (AnchorPositive, Negative)
        self.triplet = triplet
        self.transform = transform
        self.images_paths = images_paths

    def _load_gray(self, path):
        # Load an image in GRAYSCALE as a uint8 tensor of shape [1, H, W]
        return decode_image(read_file(str(path)), torchvision.io.ImageReadMode.GRAY)

    def __getitem__(self, index):
        # Use the dataset's own triplets (not the global `triplet` variable)
        anchor_char, negative_char = self.triplet[index]

        # Anchor and Positive come from the same kanji: draw two distinct images when possible
        anchor_paths = self.images_paths[anchor_char]
        anchor_path, positive_path = np.random.choice(anchor_paths, 2, replace=len(anchor_paths) < 2)
        negative_path = np.random.choice(self.images_paths[negative_char])

        anchor = self._load_gray(anchor_path)
        positive = self._load_gray(positive_path)
        negative = self._load_gray(negative_path)

        # Apply the transforms
        if self.transform:
            anchor = self.transform(anchor)
            positive = self.transform(positive)
            negative = self.transform(negative)

        return anchor, positive, negative

    def __len__(self):
        return len(self.triplet)

## Functions to train the model

In this part, we define a series of functions that will be used to train and test the model.

In [None]:
def get_data(batch_size, triplet, images_path, seed=None):
    """
    @brief Create DataLoaders for training, validation, and test using triplets and a given images_path split.
    @details Return the training, validation and testing loaders.
    @param batch_size The batch size for the DataLoaders.
    @param triplet The list of triplets to use (actually couples (AnchorPositive, Negative)).
    @param images_path The dictionary linking each class (kanji) to the paths of all its handwritten images.
    @param seed Optional seed for shuffling and splitting. None means random.
    """
    # Transform for the ETL images (black background); the same transform, including RandomRotation, is used for training, validation and test
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ConvertImageDtype(torch.float32),
        transforms.RandomRotation(10),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    # 80% Training, 20% test data
    train_images, test_images = split_images_paths(images_path, 0.8, seed)
    
    # Instantiate the TripletDataset objects
    training_data = TripletDataset(triplet, train_images, transform)
    test_data = TripletDataset(triplet, test_images, transform)

    num_train = len(training_data)
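    # Hold out 20% of the training triplets for validation (their images are drawn from the same pool as the training images)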
    val_size = int(num_train * 0.2)
    train_size = num_train - val_size

    used_generator = torch.Generator().manual_seed(seed) if seed is not None else None

    # Get the training and validation data
    training_data, validation_data = torch.utils.data.random_split(training_data, [train_size, val_size], generator=used_generator)

    # Initialization of the dataloader
    train_loader = torch.utils.data.DataLoader(training_data, batch_size, shuffle=True, generator=used_generator)
    validation_loader = torch.utils.data.DataLoader(validation_data, batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size, shuffle=False)

    return train_loader, validation_loader, test_loader

In [None]:
def get_optimizer(net, learning_rate, weight_decay):
    """
    @brief Create an Adam optimizer for the given network with specified learning rate and weight decay.
    @param net The neural network whose parameters will be optimized.
    @param learning_rate The learning rate for the optimizer.
    @param weight_decay The weight decay (L2 regularization) factor.
    """
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
    return optimizer

In [None]:
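def get_cost_function(margin=0.2):
    """
    @brief Return the TripletMarginLoss cost function with the given margin.
    @details We use the Euclidean distance (p=2). Note that PyTorch's TripletMarginLoss applies the
             margin to non-squared distances, max(0, ||f(A)-f(P)|| - ||f(A)-f(N)|| + margin),
             whereas the formulation above uses squared distances.
    @param margin The margin parameter (alpha).
    """
    cost_function = torch.nn.TripletMarginLoss(margin=margin, p=2)
    return cost_function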

In [None]:
def train(net, data_loader, optimizer, cost_function, device='cuda:0'):
    """
    @brief Train the network for one epoch using the provided DataLoader, optimizer, and loss function.
    @param net The neural network to train.
    @param data_loader DataLoader providing batches of triplets.
    @param optimizer The optimizer used to update the network parameters.
    @param cost_function The loss function used to compute the triplet loss.
    @param device The device on which to run the computations.
    """
    samples = 0.
    cumulative_loss = 0.
    correct = 0
    
    net.train() # Set the network to training mode

    for anchor, positive, negative in data_loader:
        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)

        # Get the embeddings of each images of the triplet
        emb_a = net(anchor)
        emb_p = net(positive)
        emb_n = net(negative)

        # Use them to compute the loss
        loss = cost_function(emb_a, emb_p, emb_n)

        # Backpropagation phase
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Triplet accuracy:
        # we compute the squared Euclidean distances Anchor-Positive and Anchor-Negative.
        # If the Anchor is closer to the Positive, we count the triplet as correct.
        dist_pos = (emb_a - emb_p).pow(2).sum(dim=1)
        dist_neg = (emb_a - emb_n).pow(2).sum(dim=1)
        correct += (dist_pos < dist_neg).sum().item()

        # Accumulate the batch loss weighted by the batch size (the loss is a per-batch mean)
        batch_size = anchor.size(0)
        samples += batch_size
        cumulative_loss += loss.item() * batch_size

    # Return the mean loss and the accuracy
    mean_loss = cumulative_loss / samples
    accuracy = correct / samples
    return mean_loss, accuracy

In [None]:
def test(net, data_loader, cost_function, device='cuda:0'):
    """
    @brief Evaluate the network for one epoch using the provided DataLoader and loss function.
    @param net The neural network to evaluate.
    @param data_loader DataLoader providing batches of triplets.
    @param cost_function The loss function used to compute the triplet loss.
    @param device The device on which to run the computations.
    """
    samples = 0.
    cumulative_loss = 0.
    correct = 0

    net.eval()  # Set the network to evaluation mode

    # Disable gradient computation
    with torch.no_grad():
        for anchor, positive, negative in data_loader:
            anchor = anchor.to(device)
            positive = positive.to(device)
            negative = negative.to(device)

            # Get embeddings
            emb_a = net(anchor) # shape [batch_size, embedding_dim]
            emb_p = net(positive) # shape [batch_size, embedding_dim]
            emb_n = net(negative) # shape [batch_size, embedding_dim]

            # Compute the loss
            loss = cost_function(emb_a, emb_p, emb_n)

            # Triplet accuracy:
            # we compute the squared Euclidean distances Anchor-Positive and Anchor-Negative.
            # If the Anchor is closer to the Positive, we count the triplet as correct.
            dist_pos = (emb_a - emb_p).pow(2).sum(dim=1)
            dist_neg = (emb_a - emb_n).pow(2).sum(dim=1)
            correct += (dist_pos < dist_neg).sum().item()

            # Accumulate the batch loss weighted by the batch size (the loss is a per-batch mean)
            batch_size = anchor.size(0)
            samples += batch_size
            cumulative_loss += loss.item() * batch_size

    # Return the mean loss and the accuracy
    mean_loss = cumulative_loss / samples
    accuracy = correct / samples
    return mean_loss, accuracy

The **accuracy** is only *informative*; it is not a good metric here.  
It counts a triplet as correct whenever the inequality $||f(A)-f(P)||^2 < ||f(A)-f(N)||^2$ holds, but being closer to the Positive than to a single randomly drawn Negative does not mean the image would be recognized correctly among all 500 kanji.  
However, it shows how often this inequality holds across the epochs.
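
A metric closer to the real use case is nearest-prototype classification: represent each kanji by a prototype embedding (for example the mean embedding of a few reference images), then check whether the true kanji is among the `k` prototypes most similar to each query. The sketch below assumes integer class labels and compares embeddings with cosine similarity; `topk_prototype_accuracy` is a hypothetical name and the toy tensors are made up.

```python
import torch
import torch.nn.functional as F

def topk_prototype_accuracy(query_emb, query_labels, prototypes, k=1):
    """
    query_emb: [num_queries, d], query_labels: [num_queries] (class indices),
    prototypes: [num_classes, d], where row c is the prototype of class c.
    """
    # Cosine similarity between every query and every prototype
    sims = F.normalize(query_emb, dim=1) @ F.normalize(prototypes, dim=1).T
    topk = sims.topk(k, dim=1).indices                       # [num_queries, k]
    hits = (topk == query_labels.unsqueeze(1)).any(dim=1)    # true class among the top k?
    return hits.float().mean().item()

prototypes = torch.eye(3)                   # 3 toy classes along the axes
queries = torch.tensor([[0.9, 0.1, 0.0],    # closest to class 0
                        [0.0, 0.2, 0.8],    # closest to class 2
                        [0.6, 0.5, 0.0]])   # closest to class 0, then class 1
labels = torch.tensor([0, 2, 1])
print(topk_prototype_accuracy(queries, labels, prototypes, k=1))  # about 0.667 (2 of 3)
print(topk_prototype_accuracy(queries, labels, prototypes, k=2))  # 1.0
```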

### EfficientNet as the backbone

We use [EfficientNet](https://en.wikipedia.org/wiki/EfficientNet) (by default the `tf_efficientnetv2_s` variant from timm) as the backbone of our model: we add a final linear layer that produces 128-dimensional embeddings, and we modify the first convolution so that it accepts GRAYSCALE images.  
You can set the *pretrained* parameter to *True* or *False*, depending on whether you want to start from the pretrained EfficientNet weights.

In [None]:
class EfficientNetEmbedding(torch.nn.Module):
    def __init__(self, model_name='tf_efficientnetv2_s', embedding_dim=128, pretrained=True):
        super().__init__()

        # Load the model without its classification head (num_classes=0 returns pooled features)
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)

        # Adapt the first convolution to accept GRAYSCALE (1-channel) images.
        # The existing layer is modified in place to keep its padding scheme: for the tf_* variants,
        # timm may use a TF-style "same" padding layer that a plain Conv2d would not reproduce.
        first_conv = self.backbone.conv_stem
        # Average the weights of the 3 RGB channels into a single channel
        first_conv.weight = torch.nn.Parameter(first_conv.weight.mean(dim=1, keepdim=True))
        first_conv.in_channels = 1

        # Add an embedding layer as the last layer
        self.embedding = torch.nn.Linear(self.backbone.num_features, embedding_dim)

    def forward(self, x):
        # Extract the pooled features with the backbone
        x = self.backbone(x)
        # Project to the embedding space and L2-normalize (unit-norm embeddings)
        x = F.normalize(self.embedding(x), p=2, dim=1)
        return x
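
Why combine the RGB filters? If a grayscale image is replicated into three identical channels, a convolution with the original 3-channel weights gives the same result as a 1-channel convolution with the *sum* of the weights over the channel dimension; averaging instead scales the stem output by 1/3. The toy check below uses a random kernel, not the actual EfficientNet stem. timm's `create_model` also accepts an `in_chans` argument that adapts the pretrained first-layer weights, which may be worth comparing against the manual approach above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
rgb_weight = torch.randn(8, 3, 3, 3)   # toy stem: 8 filters, 3 input channels, 3x3 kernels
gray = torch.rand(1, 1, 16, 16)        # one grayscale image
gray_rgb = gray.repeat(1, 3, 1, 1)     # the same image replicated on 3 channels

out_rgb = F.conv2d(gray_rgb, rgb_weight)
out_sum = F.conv2d(gray, rgb_weight.sum(dim=1, keepdim=True))
out_mean = F.conv2d(gray, rgb_weight.mean(dim=1, keepdim=True))

print(torch.allclose(out_rgb, out_sum, atol=1e-5))       # True
print(torch.allclose(out_rgb, 3 * out_mean, atol=1e-5))  # True
```

Because a BatchNorm layer follows the stem and the whole network is fine-tuned, this constant factor is likely absorbed during training, so either choice should work; the sum simply matches the pretrained activation scale more closely.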

### Training

Here, we have the actual training of the model.

In [None]:
def main(batch_size=128,
         device=('cuda:0' if torch.cuda.is_available() else 'cpu'),
         learning_rate=0.001,
         weight_decay=1e-6,
         epochs=50,
         save_interval=50):
    """
    @brief Train and evaluate EfficientNetEmbedding using triplet loss.
    @param batch_size Batch size for DataLoaders.
    @param device Device to use ('cuda' or 'cpu').
    @param learning_rate Learning rate for the optimizer.
    @param weight_decay Weight decay for the optimizer.
    @param epochs Number of training epochs.
    @param save_interval Save the model every N epochs. The name will be net_{current_epoch}.pth
    """
    # Load datasets
    train_loader, val_loader, test_loader = get_data(batch_size, triplet, kanji_paths)

    # Initialize network, optimizer, and loss function
    net = EfficientNetEmbedding().to(device)
    optimizer = get_optimizer(net, learning_rate, weight_decay)
    cost_function = get_cost_function(margin=0.2)

    # Evaluation before training
    train_loss, train_acc = test(net, train_loader, cost_function, device)
    val_loss, val_acc = test(net, val_loader, cost_function, device)
    test_loss, test_acc = test(net, test_loader, cost_function, device)
    
    # Print the model's evaluation metrics, before training
    print("Before training :")
    print(f"\t Training loss : {train_loss:.5f} | Training triplet accuracy : {train_acc:.3f}")
    print(f"\t Validation loss : {val_loss:.5f} | Validation triplet accuracy : {val_acc:.3f}")
    print(f"\t Test loss : {test_loss:.5f} | Test triplet accuracy : {test_acc:.3f}")
    print("-"*50)

    # Training loop
    for e in range(epochs):
        # Train one epoch
        train_loss, train_acc = train(net, train_loader, optimizer, cost_function, device)

        # Evaluate on validation set with triplet accuracy
        val_loss, val_acc = test(net, val_loader, cost_function, device)

        # Print epoch summary
        print(f"Epoch {e+1}/{epochs} | "
              f"Train loss: {train_loss:.5f} | "
              f"Train triplet accuracy: {train_acc:.3f} | "
              f"Validation loss: {val_loss:.5f} | "
              f"Validation triplet accuracy: {val_acc:.3f}")
        print("-"*50)

        # Save every save_interval epochs, after the epoch has been trained,
        # so that net_{e+1}.pth really contains e+1 epochs of training
        if (e + 1) % save_interval == 0:
            torch.save(net.state_dict(), f"net_{e+1}.pth")


    # Final evaluation after training
    train_loss, train_acc = test(net, train_loader, cost_function, device)
    val_loss, val_acc = test(net, val_loader, cost_function, device)
    test_loss, test_acc = test(net, test_loader, cost_function, device)

    # Print the model's evaluation metrics, after training
    print("After training :")
    print(f"\t Training loss : {train_loss:.5f} | Training triplet accuracy : {train_acc:.3f}")
    print(f"\t Validation loss : {val_loss:.5f} | Validation triplet accuracy : {val_acc:.3f}")
    print(f"\t Test loss : {test_loss:.5f} | Test triplet accuracy : {test_acc:.3f}")
    print("-"*50)

    return net

In [None]:
net = main()

With my RTX 2070, 50 epochs take **~390 min (~6h30)**.  
For a *small number of kanji*, you can reduce the size of the input images in the transform: a lower resolution *decreases training time*.  
I have previously trained a model on 75 kanji at a resolution of *64x64*, and used it to predict the 10 most probable kanji from a drawing.  
However, for a large number of complex kanji, this resolution is not enough.

### Inference functions

These functions apply the model to real-world inputs: they convert common white-background images into the format the model expects, compute their embeddings, and compute the cosine similarity with a set of reference vectors (the embeddings of reference images for each kanji).

In [None]:
# Transform for the white background images
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: 1.0 - x),  # invert colors: dark strokes on a light background become light on dark, like the ETL images
    transforms.Normalize(mean=[0.5], std=[0.5])
])

In [None]:
def load_image(path, transform, device):
    """
    @brief Load an image, apply the transform on it and return its tensor.
    @param path The image's path, assumed to be valid.
    @param transform The transform to apply to the images.
    @param device The device on which the tensor will be loaded.
    """
    image = Image.open(path).convert('L')
    # shape [1, 1, H, W]
    tensor = transform(image).unsqueeze(0).to(device)
    return tensor

In [None]:
def get_embedding(net, image_tensor):
    """
    @brief Return the embedding of a tensor according to the given net.
    @param net The net used for the embedding.
    @param image_tensor An image's tensor.
    """
    net.eval()
    with torch.no_grad():
        embedding = net(image_tensor)
    return embedding

In [None]:
def cosine_sim(embedding, reference_vectors):
    """
    @brief Compute the cosine similarity between the given embedding and all reference vectors.
    @param embedding The image embedding (tensor of shape [1, embedding_dim]).
    @param reference_vectors Tensor containing reference embeddings (shape [num_refs, embedding_dim]).
    """
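    # The model's embeddings are already unit-norm; normalizing again keeps the function safe for any input
    embedding = F.normalize(embedding, p=2, dim=1)
    reference_vectors = F.normalize(reference_vectors, p=2, dim=1)

    # [1, embedding_dim] x [embedding_dim, num_refs] -> [1, num_refs]
    sims = torch.mm(embedding, reference_vectors.T)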
    return sims.squeeze(0)
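
The notebook does not show how the reference vectors are built. One simple option, sketched here as an assumption rather than the method used for the 75-kanji model mentioned above, is to average the embeddings of several reference images per kanji (one prototype per kanji) and return the `k` kanji with the highest cosine similarity. `build_reference_vectors` and `predict_topk` are hypothetical helpers, and the 2-D tensors stand in for real network outputs; with the notebook's functions, each embedding would come from `get_embedding(net, load_image(path, transform, device))`.

```python
import torch
import torch.nn.functional as F

def build_reference_vectors(embeddings_by_kanji):
    """Average, then re-normalize, the reference embeddings of each kanji.

    embeddings_by_kanji: dict kanji -> tensor [num_refs, embedding_dim].
    Returns (labels, reference_vectors), reference_vectors of shape [num_kanji, embedding_dim].
    """
    labels = list(embeddings_by_kanji)
    prototypes = torch.stack([embeddings_by_kanji[k].mean(dim=0) for k in labels])
    return labels, F.normalize(prototypes, p=2, dim=1)

def predict_topk(embedding, labels, reference_vectors, k=10):
    """Return up to k (kanji, cosine similarity) pairs, most similar first."""
    sims = F.normalize(embedding, p=2, dim=1) @ reference_vectors.T  # [1, num_kanji]
    scores, idx = sims.squeeze(0).topk(min(k, len(labels)))
    return [(labels[i], s) for i, s in zip(idx.tolist(), scores.tolist())]

# Toy 2-D embeddings standing in for real network outputs
refs = {
    "日": torch.tensor([[1.0, 0.1], [0.9, -0.1]]),
    "一": torch.tensor([[0.0, 1.0], [0.1, 0.9]]),
}
labels, reference_vectors = build_reference_vectors(refs)
print(predict_topk(torch.tensor([[0.8, 0.2]]), labels, reference_vectors, k=2))  # '日' first, then '一'
```

Averaging several references per kanji smooths out the variations of individual handwriting samples; keeping every reference embedding and taking the best match per kanji is an alternative worth comparing.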