In [0]:
from __future__ import print_function, division
import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from skimage import io
import torch.nn as nn
import torch.nn.functional as F

from skimage.color import rgb2gray

import PIL
from PIL import Image

from google.colab import output

In [2]:
%cd /content/
if(os.path.isdir('/content/Lithuanian_OCR') == False):
  !git clone https://github.com/PauliusMilmantas/Lithuanian_OCR

/content
Cloning into 'Lithuanian_OCR'...
remote: Enumerating objects: 1777, done.[K
remote: Counting objects: 100% (1777/1777), done.[K
remote: Compressing objects: 100% (1217/1217), done.[K
remote: Total 1777 (delta 563), reused 1703 (delta 490), pack-reused 0[K
Receiving objects: 100% (1777/1777), 18.65 MiB | 23.01 MiB/s, done.
Resolving deltas: 100% (563/563), done.


In [0]:
class ORCDataset(Dataset):
  def __init__(self, root):
    self.root = root

  def __len__(self):
    lt = 0
    classes = os.listdir(self.root)
    for cl in classes:
      lt += len(os.listdir(self.root + '/' + cl))

    return lt

  def __getitem__(self, idx):
    if torch.is_tensor(idx):
     idx = idx.tolist()

    if(idx <= len(self)):
      found_file = ""
      found_type = ""

      fldrs = os.listdir(self.root)
      for fld in fldrs:
        fls = os.listdir(self.root + '/' + fld + '/')
        for fl in fls:
          if(fl == str(idx) + ".jpg"):
            found_file = self.root + '/' + fld + '/' + fl         
            found_type = fld

      try:
        img = io.imread(found_file)
        img = rgb2gray(img)

        return {'image': img, 'class_name': found_type}
      except:
        if(found_file != ""):
          print("Bad file: " + found_file)
        else:
          print("File not found, idx = " + str(idx))
    else:
      print()
      raise Exception("Dataset index out of boundaries")

train_dataset = ORCDataset('/content/Lithuanian_OCR/Data/training')
val_dataset = ORCDataset('/content/Lithuanian_OCR/Data/val')
test_dataset = ORCDataset('/content/Lithuanian_OCR/Data/test')

In [0]:
class Net(nn.Module):
    def __init__(self, input_size, hidden1_size, hidden2_size, num_classes):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden1_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden1_size, hidden2_size)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden2_size, num_classes)  
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu1(out)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        return out

In [5]:
network = Net(4096, 130, 28, 3)
optimizer = torch.optim.SGD(network.parameters(), lr=0.0000001, momentum=0.6)
criterion = nn.MSELoss()

print(network)

Net(
  (fc1): Linear(in_features=4096, out_features=130, bias=True)
  (relu1): ReLU()
  (fc2): Linear(in_features=130, out_features=28, bias=True)
  (relu2): ReLU()
  (fc3): Linear(in_features=28, out_features=3, bias=True)
)


In [0]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = len(train_dataset),shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = len(val_dataset), shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = len(test_dataset), shuffle=True)

In [7]:
dataiter = iter(train_loader)
itr = dataiter.next()

label = itr['class_name']
img = itr['image']
print("Class name: {}".format(label[0]))

fig = plt.figure(figsize = (5,5)) 
ax = fig.add_subplot(111)
ax.imshow(img[0])

File not found, idx = 282
File not found, idx = 263
File not found, idx = 285
File not found, idx = 280
File not found, idx = 236
File not found, idx = 268
File not found, idx = 284
File not found, idx = 277
File not found, idx = 253
File not found, idx = 265
File not found, idx = 237
File not found, idx = 300
File not found, idx = 257
File not found, idx = 255
File not found, idx = 269
File not found, idx = 264
File not found, idx = 267
File not found, idx = 242
File not found, idx = 286
File not found, idx = 294
File not found, idx = 299
File not found, idx = 298
File not found, idx = 290
File not found, idx = 259
File not found, idx = 281
File not found, idx = 249
File not found, idx = 296
File not found, idx = 262
File not found, idx = 256
File not found, idx = 254
File not found, idx = 297
File not found, idx = 295
File not found, idx = 270
File not found, idx = 289
File not found, idx = 288
File not found, idx = 241
File not found, idx = 271
File not found, idx = 244
File not fou

TypeError: ignored

In [0]:
train_losses = []
train_counter = []
test_losses = []
test_counter = [i*len(train_loader.dataset) for i in range(10 + 1)]

In [0]:
!mkdir results

In [0]:
def name_to_int(data):

  switcher = {
      'a': 1,
      'a2': 2,
      'b': 3,
      'c': 4,
      'c2': 5,
      'd': 6,
      'e': 7,
      'e2': 8,
      'e3': 9,
      'f': 10,
      'g': 11,
      'h': 12,
      'i': 13,
      'i2': 14,
      'j': 15,
      'k': 16,
      'l': 17,
      'm': 18,
      'n': 19,
      'y': 20,
      'A': 1,
      'B': 2,
      'C': 3
  }

  new_data = []

  for dt in data:
    new_data.append(switcher.get(dt))

  return new_data

In [0]:
def get_max_from_tensor(data):
  maxVal = data[0]
  maxId = 0
  for i in range(len(data)):
    if(data[i] > maxVal):
      maxVal = data[i]
      maxId = i

  return maxId + 1

In [0]:
def train(train_loader, val_loader, epoch_amount, save_checkpoint = 10):
  network.eval()
  train_loss_hist = []
  val_loss_hist = []
  checkpoint = save_checkpoint
  for epoch in range(epoch_amount):
    num_images_train = 0
    num_images_val = 0

    # TRAINING DATASET
    correct = 0
    for data in train_loader:
      images = data['image']
      labels = torch.from_numpy(np.array(name_to_int(data['class_name'])))

      num_images_train = len(images)

      lossSum = 0
      for idx in range(len(images)):
        optimizer.zero_grad()

        outputs = network(images[idx].flatten().float())

        maxIdx = get_max_from_tensor(outputs)
        real_value = np.zeros(3)
        real_value[labels[idx] - 1] = 1

        loss = criterion(outputs, torch.Tensor(real_value)) 
        
        loss.backward()

        lossSum += loss.item()
        train_loss_hist.append(lossSum)

        optimizer.step()

    # VALIDATION
    for data in val_loader:
      images = data['image']
      labels = torch.from_numpy(np.array(name_to_int(data['class_name'])))

      num_images_val = len(images)

      lossSum = 0
      for idx in range(len(images)):
        optimizer.zero_grad()

        outputs = network(images[idx].flatten().float())

        if(labels[idx] == get_max_from_tensor(outputs)):
          correct += 1

        maxIdx = get_max_from_tensor(outputs)
        real_value = np.zeros(3)
        real_value[labels[idx] - 1] = 1
        loss = criterion(outputs, torch.Tensor(real_value)) 

        lossSum += loss.item()
        val_loss_hist.append(lossSum)

    print("Epoch: {} Training loss: {} Eval loss: {} Correct: {}%".format(epoch,train_loss_hist[len(train_loss_hist) - 1],val_loss_hist[len(val_loss_hist) - 1],correct*100/num_images_train))

    if(checkpoint == 0):
      torch.save(network.state_dict(), '/content/results/model.pth')
      torch.save(optimizer.state_dict(), '/content/results/optimizer.pth')

      checkpoint = save_checkpoint
    else:
      checkpoint -= 1

train(train_loader, val_loader, 200)

In [0]:
for data in test_loader:
  images = data['image']
  labels = torch.from_numpy(np.array(name_to_int(data['class_name'])))

  for i in range(len(images)):
    output = network(images[i].flatten().float())
    
    if(get_max_from_tensor(output) == labels[i]):
      print("OK: {} == {}".format(get_max_from_tensor(output), labels[i]))
    else:
      print("False: {} == {}".format(get_max_from_tensor(output), labels[i]))