In [1]:
import torch
import hub
import random
import math
import PIL
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
from torchinfo import summary
from torchvision import transforms
from torchvision.io import read_image
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import pandas as pd
import h5py
import numpy as np

from typing import Callable
import csv
import copy
import time
import json
import pathlib
import os
from os import listdir
from os.path import isfile, join

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# Side length used for the model summary input and for transforms.Resize further down.
image_size = 64
# presumably the number of kanji classes the model predicts — TODO confirm against kanji_detector
nb_symbols = 2199

# Network definition, imported from the kanji_detection_model module.
from kanji_detection_model import kanji_detector

    
def getModel():
    """Build and return a new ``kanji_detector`` instance."""
    model = kanji_detector()
    return model


def testModel():
    """Instantiate the model on ``device``, print it and its torchinfo summary, then free it."""
    network = getModel().to(device=device)
    print(network)

    # One dummy batch: 20 single-channel inputs of image_size x image_size.
    input_shapes = [(20, 1, image_size, image_size)]
    # NOTE(review): two dtypes are listed for a single input size — confirm this is intended.
    report = summary(
        network,
        input_size=input_shapes,
        dtypes=[torch.double, torch.double],
        depth=3,
    )
    print(report)

    # Drop the local reference, then ask PyTorch to release its cached CUDA memory.
    del network
    torch.cuda.empty_cache()



In [3]:
# CUDA memory counters before the (currently disabled) model test.
print(f"Allocated : {torch.cuda.memory_allocated()}")
print(f"Reserved : {torch.cuda.memory_reserved()}")

#testModel()
#start = time.time()

# Same counters again, to compare with the values printed above.
print(f"Allocated : {torch.cuda.memory_allocated()}")
print(f"Reserved : {torch.cuda.memory_reserved()}")

Allocated : 0
Reserved : 0
Allocated : 0
Reserved : 0


In [4]:
# Counters initialised to zero; they are not used by the other statements of this cell.
length_train = 0
length_eval = 0

# Empty placeholders, replaced just below by the datasets read from the HDF5 file.
train_set, eval_set = {}, {}
train_labels, eval_labels = {}, {}

# The file is left open: selectBatches indexes these datasets later on.
f = h5py.File('image_set.hdf5', 'r')

train_set = f['training_group']['dataset']
eval_set = f['evaluation_group']['dataset']
train_labels = f['training_group']['labels']
eval_labels = f['evaluation_group']['labels']
#print(train_set[0])

# One index per sample in each split, shuffled so batches are drawn in random order.
train_index_list = list(range(train_set.shape[0]))
eval_index_list = list(range(eval_set.shape[0]))

random.shuffle(train_index_list)
random.shuffle(eval_index_list)

#picturesNames = [f for f in listdir(trainingPath) if isfile(join(trainingPath, f))]
#g_dictNames = {name:{'name':name , 'number':int(name.split('_')[0]) , 'symbol':name.split('_')[1]} for name in picturesNames}
#shufflePicturesNames = picturesNames.copy()
#random.shuffle(shufflePicturesNames)


In [5]:
class KanjiImageDataset(Dataset):
    """Dataset pairing image files with the labels listed in a CSV annotations file.

    The CSV is read with its first row as header. Column 0 holds the image file
    name (joined onto ``img_dir``) and column 1 holds the label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        """Read the annotations table and store the directory and optional transforms."""
        self.transform = transform
        self.target_transform = target_transform
        self.img_dir = img_dir
        self.img_labels = pd.read_csv(annotations_file, header=0, delimiter=',')

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return self.img_labels.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``, applying each transform that is set."""
        file_name = self.img_labels.iloc[idx, 0]
        image = Image.open(os.path.join(self.img_dir, file_name))
        label = self.img_labels.iloc[idx, 1]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label
    
# Preprocessing applied to every image: convert to grayscale, resize, convert to a tensor.
# NOTE(review): Resize with a single int scales the shorter side to image_size and keeps
# the aspect ratio — confirm the source images are square.
img_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(image_size),
    transforms.ToTensor()
])

# Built with os.path.join rather than a literal containing a backslash followed by "T":
# that pair is not a valid escape sequence (recent Python versions warn about it) and a
# backslash is not a path separator outside Windows. On Windows the resulting string is
# identical to the previous literal.
training_dir = os.path.join(".", "Training_set")

# Both CSV files reference images stored in the same directory.
training_data = KanjiImageDataset("loader_data_train.csv", training_dir, img_transform, int)
evaluation_data = KanjiImageDataset("loader_data_eval.csv", training_dir, img_transform, int)

In [16]:
# Read cursors into the shuffled index lists. Each is a one-element list so that
# selectBatches can advance it in place without rebinding a global name.
selectorIndexTrain = [0]
selectorIndexEval = [0]


def selectBatches(batch_size: int, isTraining: bool) -> tuple:
    """Draw the next ``batch_size`` samples from the training or evaluation split.

    Indices are consumed sequentially from the split's shuffled index list. When
    the list runs out before the batch is full, the remaining indices are taken,
    the list is reshuffled in place, the cursor is reset to 0 and the batch is
    completed from the start of the reshuffled list.

    Args:
        batch_size: number of samples to return.
        isTraining: True to read from the training split, False for evaluation.

    Returns:
        A tuple ``(images, labels)``: ``images`` is a float32 tensor holding the
        selected samples divided by 255, with a singleton dimension inserted at
        position 1; ``labels`` is a long tensor with column 0 of the label rows,
        shifted by -1 (the stored labels are presumably 1-based — TODO confirm).

    Raises:
        ValueError: if samples are requested from a split whose index list is
            empty (the refill loop could otherwise never terminate).
    """
    index_list = train_index_list if isTraining else eval_index_list
    cursor = selectorIndexTrain if isTraining else selectorIndexEval
    dataset = train_set if isTraining else eval_set
    labels = train_labels if isTraining else eval_labels

    if batch_size > 0 and not index_list:
        raise ValueError("cannot draw a batch from an empty index list")

    selected_indices = []
    while len(selected_indices) < batch_size:
        n_missing = batch_size - len(selected_indices)
        start = cursor[0]
        n_available = len(index_list) - start

        if n_available >= n_missing:
            selected_indices.extend(index_list[start:start + n_missing])
            cursor[0] = start + n_missing
        else:
            # Use up what is left, then start a new pass in a fresh random order.
            selected_indices.extend(index_list[start:])
            random.shuffle(index_list)
            cursor[0] = 0

    # Samples are read one index at a time, in the (unsorted) order they were selected.
    images = np.array([dataset[index] for index in selected_indices])
    selected_labels = torch.tensor(
        [int(labels[index, 0]) for index in selected_indices], dtype=torch.long
    )

    scaled_images = torch.as_tensor(images, dtype=torch.float32) / 255
    return scaled_images.unsqueeze(1), selected_labels - 1


#def countCorrect(answer: torch.FloatTensor, correctAnswer: torch.FloatTensor):
def countCorrect(answer: torch.FloatTensor, correctAnswerIndices: torch.FloatTensor):
    """Count the rows of ``answer`` whose highest-scoring column is the expected one.

    ``answer`` is reduced along dim 1 to the index of its maximum, which is
    compared element-wise with ``correctAnswerIndices``. Returns the number of
    matches as a Python int.
    """
    predicted = answer.max(dim=1).indices
    matches = torch.eq(predicted, correctAnswerIndices)
    return matches.long().sum().item()

#def countTop5Correct(answer: torch.FloatTensor, correctAnswer: torch.FloatTensor):
def countTop5Correct(answer: torch.FloatTensor, correctAnswerIndices: torch.FloatTensor):
    """Count the rows of ``answer`` whose expected index is among the 5 best-scoring columns.

    The five highest-scoring column indices of each row (along dim 1) are compared
    with that row's entry of ``correctAnswerIndices``. Returns the number of
    matches as a Python int.
    """
    top5 = torch.topk(answer, k=5, dim=1).indices
    # Add a trailing dimension so each expected index broadcasts over its row's 5 candidates.
    expected = correctAnswerIndices.unsqueeze(-1)
    hits = top5 == expected
    return hits.long().sum().item()



In [12]:

def train(model, n_epoch, batch_size, lr: Callable[[int], float],
          n_batches: int = 100, frequency_detailed_results: int = 5):
    """Train ``model`` with Adam and cross-entropy on batches from ``selectBatches``.

    Every epoch runs ``n_batches`` optimisation steps and prints the summed
    loss. Every ``frequency_detailed_results``-th epoch additionally measures
    top-1 / top-5 accuracy on the training batches and on evaluation batches
    drawn with ``selectBatches(batch_size, isTraining=False)``. The evaluation
    forward pass runs in eval mode without gradient tracking. Training stops
    early once the top-5 evaluation accuracy exceeds 98%.

    Args:
        model: the torch module to optimise; it is left in training mode.
        n_epoch: maximum number of epochs.
        batch_size: number of samples per batch.
        lr: schedule mapping the 0-based epoch index to the learning rate.
        n_batches: optimisation steps per epoch.
        frequency_detailed_results: accuracies are measured and printed every
            this many epochs (must be at least 1).

    Returns:
        The best top-5 evaluation accuracy (in percent, truncated to two
        decimals) seen on the epochs where it was measured, or 0 if none.
    """
    global timer1
    global timer2
    global timer3

    optimizer = torch.optim.Adam(model.parameters(), lr=lr(0))
    loss_f = torch.nn.CrossEntropyLoss()
    adjust = 100  # factor used to truncate displayed values to two decimals

    model.train()
    best_percent = 0
    for epoch in range(n_epoch):
        detailed = (epoch+1) % frequency_detailed_results == 0

        n_correct_1_t = 0
        n_correct_5_t = 0
        n_correct_1_e = 0
        n_correct_5_e = 0

        n_total = n_batches*batch_size
        t_loss = 0

        # Apply this epoch's learning rate to every parameter group.
        for g in optimizer.param_groups:
            g['lr'] = lr(epoch)

        time_select = 0 #debug
        time_model = 0 #debug

        start_epoch = time.time()

        print("Epoch " + str(epoch+1))

        for i in range(n_batches):

            model.zero_grad()

            start = time.time() #debug
            images, labels = selectBatches(batch_size, isTraining=True)
            end = time.time() #debug
            time_select += end-start # debug

            # Move the labels once; they are reused by the loss and the accuracy counters.
            labels = labels.to(device=device)

            start = time.time() #debug
            answer = model(images.to(device=device))
            end = time.time() #debug
            time_model += end-start # debug

            loss = loss_f(answer, labels)
            t_loss += loss.item()

            loss.backward()
            optimizer.step()

            if detailed:
                # Evaluation samples come from the evaluation split, not the training one.
                images_eval, labels_eval = selectBatches(batch_size, isTraining=False)
                labels_eval = labels_eval.to(device=device)

                # Evaluate in eval mode and without building an autograd graph,
                # then return to training mode for the next step.
                model.eval()
                with torch.no_grad():
                    answer_eval = model(images_eval.to(device=device))
                model.train()

                n_correct_1_t += countCorrect(answer, labels)
                n_correct_5_t += countTop5Correct(answer, labels)
                n_correct_1_e += countCorrect(answer_eval, labels_eval)
                n_correct_5_e += countTop5Correct(answer_eval, labels_eval)

        # Avoid a division by zero when no sample was processed in the epoch.
        denominator = max(n_total, 1)
        percent_1_t = math.floor(adjust*100*n_correct_1_t/denominator)/adjust
        percent_5_t = math.floor(adjust*100*n_correct_5_t/denominator)/adjust
        percent_1_e = math.floor(adjust*100*n_correct_1_e/denominator)/adjust
        percent_5_e = math.floor(adjust*100*n_correct_5_e/denominator)/adjust

        display_loss = math.floor(adjust*t_loss)/adjust

        best_percent = percent_5_e if percent_5_e > best_percent else best_percent

        end_epoch = time.time()
        time_epoch = end_epoch-start_epoch

        #print("Time epoch : " + str(math.floor(time_epoch*adjust)/adjust)) #debug
        #print("\tTime select : " + str(math.floor(time_select*adjust)/adjust)) #debug
        #print("\tTime model : " + str(math.floor(time_model*adjust)/adjust)) #debug

        # Module-level timers are reset every epoch; nothing in this function reads them.
        timer1 = 0
        timer2 = 0
        timer3 = 0

        print("\tLoss : " + str(display_loss))

        if detailed:
            print("\tTop-1 training accuracy : " + str(percent_1_t) + "%")
            print("\tTop-5 training accuracy : " + str(percent_5_t) + "%")
            print("\tTop-1 evaluation accuracy : " + str(percent_1_e) + "%")
            print("\tTop-5 evaluation accuracy : " + str(percent_5_e) + "%")

        # Early stop once the top-5 evaluation accuracy is high enough.
        if percent_5_e > 98.0:
            break

        print("")

    return best_percent


def weights_init(m):
    """Re-initialise every Conv2d and Linear layer contained in ``m``.

    Weights are drawn from a normal distribution with mean 0 and std 0.02,
    biases from a normal distribution with mean 0 and std 0.001. The function
    walks ``m.modules()``, which yields ``m`` itself and all of its
    sub-modules, so it can be called directly on a whole model. It can still
    be passed to ``Module.apply``; nested layers are then simply re-drawn from
    the same distributions more than once.

    Args:
        m: any torch module (a single layer or a container of layers).
    """
    for layer in m.modules():
        if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear)):
            layer.weight.data.normal_(0, 0.02)
            # Layers created with bias=False have no bias tensor to initialise.
            if layer.bias is not None:
                layer.bias.data.normal_(0, 0.001)
        
        

In [13]:
# Hyper-parameter values iterated by the (currently disabled) grid search below;
# the commented lists are wider candidate ranges.
batch_sizes = [100] #[25, 50, 100, 150, 200]
learning_rates = [0.00001] #[0.001, 0.005, 0.01, 0.05, 0.1, 0.5]

# Number of epochs per run; the training cell further down overrides this to 3.
n_epochs = 50

In [14]:
# Disabled grid search over batch_sizes x learning_rates, kept as a bare string literal.
# NOTE(review): it passes a float `lr` to train(), which calls lr(0) and lr(epoch);
# re-enabling it as-is would need a callable such as `lambda epoch: lr`.
"""
for bs in batch_sizes:
    for lr in learning_rates:
        trainModel = getModel().to(device=device)
        weights_init(trainModel)
        percent = train(trainModel, n_epochs, bs, lr)
        print("bs=" + str(bs) + " lr=" + str(lr) + " : " + str(percent) + "%")
"""

'\nfor bs in batch_sizes:\n    for lr in learning_rates:\n        trainModel = getModel().to(device=device)\n        weights_init(trainModel)\n        percent = train(trainModel, n_epochs, bs, lr)\n        print("bs=" + str(bs) + " lr=" + str(lr) + " : " + str(percent) + "%")\n'

In [17]:
# New model on the selected device; weights_init is then called on it.
trainModel = getModel().to(device=device)
weights_init(trainModel)
n_epochs = 3

#lr: Callable[[int], float] = lambda epoch: 0.0003
def lr(epoch: int) -> float:
    """Learning-rate schedule: 0.0005 at epoch 0, divided by 1.002 for each further epoch."""
    return 0.0005 / (1.002 ** epoch)

print(f"Running on {device}\n")
#train(trainModel, n_epochs, batch_sizes[0], learning_rates[0])
train(trainModel, n_epochs, 100, lr)

Running on cuda

Epoch 1
Time epoch : 1.76
	Time select : 0.43
	Time model : 0.55
	Loss : 769.73

Epoch 2
Time epoch : 1.46
	Time select : 0.44
	Time model : 0.43
	Loss : 769.64

Epoch 3
Time epoch : 1.53
	Time select : 0.53
	Time model : 0.42
	Loss : 769.73



0

In [78]:
# Switch to evaluation mode before exporting the trained network.
trainModel.eval()

# torch.save does not create missing parent directories, so make sure it exists.
os.makedirs("./Models", exist_ok=True)

# Module.cpu() moves the parameters in place and returns the module itself,
# so a single call is enough for both exports.
cpu_model = trainModel.cpu()

# Whole pickled module (loading it requires the model class to be importable).
torch.save(cpu_model, "./Models/kanji_model_v5_top5_88_eval.pt")

# Weights only.
torch.save(cpu_model.state_dict(), "./Models/kanji_model_v5_top5_88_eval.pth")

#temp = torch.jit.script(trainModel.cpu())
#torch.jit.save(temp, "./Models/kanji_model_96_1.pt")
