In [1]:
import torch
import hub
import random
import math
import PIL
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
from torchinfo import summary
from torchvision import transforms
from torchvision.io import read_image
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import pandas as pd
import h5py
import numpy as np

from typing import Callable
import csv
import copy
import time
import json
import pathlib
import os
from os import listdir
from os.path import isfile, join

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
image_size = 64
nb_symbols = 2199

from kanji_detection_model import kanji_detector

    
def getModel():
    return kanji_detector()


def testModel():
    modelRunnable = getModel().to(device=device)
    print(modelRunnable)
    
    summary1 = summary(
        modelRunnable,
        input_size=[
            (20, 1, image_size, image_size)
        ],
        dtypes=[torch.double, torch.double],
        depth=3
    )
    
    print(summary1)
    
    del modelRunnable
    torch.cuda.empty_cache()



In [3]:
timeSelectionDataset = 0
class KanjiRandomImageCustomLoader():
    def __init__(self, hdf5_file_name: str, batch_size: int, isTraining:bool, transform=None, target_transform=None):
        self.f = h5py.File(hdf5_file_name, 'r')
        self.batch_size = batch_size
        
        self.group = 'training_group' if isTraining else 'evaluation_group'
        self.dataset = self.f[self.group]['dataset']
        self.labels = self.f[self.group]['labels']
        
        self.index_list = [i for i in range(self.dataset.shape[0])]
        random.shuffle(self.index_list)
        
        self.transform = transform
        self.target_transform = target_transform
        self.selector_index = 0

    def getNextBatch(self):
        global timeSelectionDataset
        
        true_list_len = len(self.index_list)
        selected_index = -1
        list_len = true_list_len - self.selector_index
        
        selected_indices = []
        while len(selected_indices) < self.batch_size:
            list_len = true_list_len - self.selector_index
            n_to_find = self.batch_size - len(selected_indices)
            start = self.selector_index

            if list_len >= n_to_find:
                end = self.selector_index + n_to_find
                selected_indices.extend(self.index_list[start:end])
                self.selector_index += n_to_find
            else:
                selected_indices.extend(self.index_list[start:])
                random.shuffle(self.index_list)
                self.selector_index = 0
    
        #print(str(batch[0]) +" "+ str(batch[99]))
        #print(batch_size)
        #print(len(batch))
        
        labels = torch.LongTensor([self.labels[index,0] for index in selected_indices])-1
        
        start=time.time() #debug
        images = np.array([self.dataset[index] for index in selected_indices])
        end=time.time() #debug
        timeSelectionDataset+=end-start  #debug
        
        #selected_indices.sort()
        #np_indices = np.array(selected_indices)
        #images = self.dataset[np_indices]
        
        
        images = (torch.as_tensor(images,dtype=torch.float)/255).unsqueeze(1)
        
        
        if self.transform:
            images = self.transform(images)
        if self.target_transform:
            labels = self.target_transform(labels)
            
        return images, labels


In [4]:
def countCorrect(answer: torch.FloatTensor, correctAnswerIndices: torch.FloatTensor):
    
    _,indicesAnswer = torch.max(answer, dim=1)
    
    #print(indicesAnswer)
    #print(indicesCorrect)
    numCorrect = (indicesAnswer == correctAnswerIndices).long().sum()
    
    return numCorrect.item()

def countTop5Correct(answer: torch.FloatTensor, correctAnswerIndices: torch.FloatTensor):
    _,indicesAnswer = answer.topk(k=5, dim=1)
    
    #print(indicesAnswer.shape)
    #print(correctAnswerIndices.shape)
    numCorrect = (indicesAnswer == correctAnswerIndices.unsqueeze(-1)).long().sum()
    
    return numCorrect.item()

In [5]:

def train(model, n_epoch, batch_size, lr: Callable[[int], float]):
    
    global timeSelectionDataset
    
    n_batches = 100
    optimizer = torch.optim.Adam(model.parameters(), lr=lr(0))
    
    custom_loader_train = KanjiRandomImageCustomLoader("image_set.hdf5", batch_size=batch_size, isTraining=True)
    custom_loader_eval = KanjiRandomImageCustomLoader("image_set.hdf5", batch_size=batch_size, isTraining=False)
    
    frequency_detailed_results = 5
    
    model.train()
    loss_f = torch.nn.CrossEntropyLoss()
    best_percent = 0
    for epoch in range(n_epoch):
        
        n_correct_1_t = 0
        n_correct_5_t = 0
        n_correct_1_e = 0
        n_correct_5_e = 0
        
        n_total = n_batches*batch_size
        t_loss = 0
        
        for g in optimizer.param_groups:
            g['lr'] = lr(epoch)
        #model.zero_grad()
            
        time_select = 0 #debug
        time_model = 0 #debug
        timeSelectionDataset = 0 # debug
        
        start_epoch = time.time()
            
        print("Epoch " + str(epoch+1))
        
        for i in range(n_batches):
            model.train()
            model.zero_grad()
            #optimizer.zero_grad() not needed ?
            
            start = time.time() #debug
            images, labels = custom_loader_train.getNextBatch()
            end = time.time() #debug
            time_select += end-start # debug
            
            start = time.time() #debug
            answer = model(images.to(device=device))
            end = time.time() #debug
            time_model += end-start # debug
            
            loss = loss_f(answer,labels.to(device=device)).cpu()
            t_loss += loss.item()
            
            loss.backward()
            optimizer.step() #Trying at the end of the epoch ?
            
            
            if (epoch+1) % frequency_detailed_results == 0:
                model.eval()
                images_eval, labels_eval = custom_loader_eval.getNextBatch()
                answer_eval = model(images_eval.to(device=device))
                
                n_correct_1_t += countCorrect(answer, labels.to(device=device))
                n_correct_5_t += countTop5Correct(answer, labels.to(device=device))
                n_correct_1_e += countCorrect(answer_eval, labels_eval.to(device=device))
                n_correct_5_e += countTop5Correct(answer_eval, labels_eval.to(device=device))
        
        adjust = 100
        percent_1_t = math.floor(adjust*100*n_correct_1_t/n_total)/adjust
        percent_5_t = math.floor(adjust*100*n_correct_5_t/n_total)/adjust
        percent_1_e = math.floor(adjust*100*n_correct_1_e/n_total)/adjust
        percent_5_e = math.floor(adjust*100*n_correct_5_e/n_total)/adjust
        
        display_loss = math.floor(adjust*t_loss)/adjust
        
        best_percent = percent_5_e if percent_5_e > best_percent else best_percent
        
        end_epoch = time.time()
        time_epoch = end_epoch-start_epoch
        
        
        timeSelection=0
        print("Time epoch : " + str(math.floor(time_epoch*adjust)/adjust) + "s") #debug
        
        #print("\tTime select : " + str(math.floor(time_select*adjust)/adjust)) #debug
        #print("\t - Time select (inside) : " + str(math.floor(timeSelectionDataset*adjust)/adjust)) #debug
        #print("\tTime model : " + str(math.floor(time_model*adjust)/adjust)) #debug
        
        
        timer1 = 0
        timer2 = 0
        timer3 = 0
        
        print("\tLoss : " + str(display_loss))
        
        if (epoch+1) % frequency_detailed_results == 0:
            print("\tTop-1 training accuracy : " + str(percent_1_t) + "%")
            print("\tTop-5 training accuracy : " + str(percent_5_t) + "%")
            print("\tTop-1 evaluation accuracy : " + str(percent_1_e) + "%")
            print("\tTop-5 evaluation accuracy : " + str(percent_5_e) + "%")
        
        if percent_5_e > 99.5:
            break
        
        print("")
        
    return best_percent


def weights_init(m):
    if isinstance(m, torch.nn.Conv2d):
        m.weight.data.normal_(0, 0.02)
        m.bias.data.normal_(0, 0.001)
    
    if isinstance(m, torch.nn.Linear):
        m.weight.data.normal_(0, 0.02)
        m.bias.data.normal_(0, 0.001)
        
        

In [6]:
batch_sizes = [100] #[25, 50, 100, 150, 200]
learning_rates = [0.00001] #[0.001, 0.005, 0.01, 0.05, 0.1, 0.5]

n_epochs = 50

In [7]:
trainModel = getModel().to(device=device)
weights_init(trainModel)
n_epochs = 1000

#lr: Callable[[int], float] = lambda epoch: 0.0003
#lr: Callable[[int], float] = lambda epoch: 0.0005/(1.002**epoch)
lr: Callable[[int], float] = lambda epoch: 0.0003/(1.002**epoch)

print("Running on " + device + "\n")
#train(trainModel, n_epochs, batch_sizes[0], learning_rates[0])
train(trainModel, n_epochs, 100, lr)
#Best in 1.56s/epoch

Running on cuda

Epoch 1
Time epoch : 4.3s
	Loss : 769.72

Epoch 2
Time epoch : 1.05s
	Loss : 769.61

Epoch 3
Time epoch : 1.04s
	Loss : 769.6

Epoch 4
Time epoch : 1.48s
	Loss : 769.67

Epoch 5
Time epoch : 2.52s
	Loss : 769.66
	Top-1 training accuracy : 0.07%
	Top-5 training accuracy : 0.23%
	Top-1 evaluation accuracy : 0.03%
	Top-5 evaluation accuracy : 0.15%

Epoch 6
Time epoch : 1.47s
	Loss : 769.65

Epoch 7
Time epoch : 1.5s
	Loss : 769.7

Epoch 8
Time epoch : 1.47s
	Loss : 769.46

Epoch 9
Time epoch : 1.48s
	Loss : 769.48

Epoch 10
Time epoch : 2.57s
	Loss : 766.09
	Top-1 training accuracy : 0.12%
	Top-5 training accuracy : 0.45%
	Top-1 evaluation accuracy : 0.06%
	Top-5 evaluation accuracy : 0.38%

Epoch 11
Time epoch : 1.48s
	Loss : 751.52

Epoch 12
Time epoch : 1.55s
	Loss : 732.14

Epoch 13
Time epoch : 1.51s
	Loss : 710.88

Epoch 14
Time epoch : 1.56s
	Loss : 692.22

Epoch 15
Time epoch : 2.57s
	Loss : 674.65
	Top-1 training accuracy : 0.82%
	Top-5 training accuracy : 4.09%

96.45

In [8]:
trainModel.eval()

torch.save(trainModel.cpu(),"./Models/kanji_model_v7_top5_96_eval.pt")

torch.save(trainModel.cpu().state_dict(), "./Models/kanji_model_v7_top5_96_eval.pth")

#temp = torch.jit.script(trainModel.cpu())
#torch.jit.save(temp, "./Models/kanji_model_96_1.pt")
