<a href="https://colab.research.google.com/github/Hramchenko/Handwritting/blob/master/HTR_tf_unif.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [61]:
import torch
print("Device " + torch.cuda.get_device_name(0))
device = torch.device("cuda:0")
#device = torch.device("cpu")
print(device)

Device Tesla K80
cuda:0


In [0]:
batch_size = 30

In [63]:
import sys
sys.path.append("./Handwritting/")
from IAMWords import IAMWords
train_set = IAMWords("train", "./IAM/", batch_size=batch_size)
test_set = IAMWords("test", "./IAM/", batch_size=batch_size)

Reading ./IAM/words.train.pkl...
Reading finished
Reading ./IAM/words.test.pkl...
Reading finished


In [0]:
def modify_dataset(dataset):
  l = len(dataset.codes)
  s = "<START>"
  dataset.codes[s] = l
  dataset.inv_codes[l] = s
  return dataset

train_set = modify_dataset(train_set)
test_set = modify_dataset(test_set)


In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [0]:
class ConvLayer(nn.Module):
    def __init__(self, size, padding=1, pool_layer=nn.MaxPool2d(2, stride=2),
                 bn=False, dropout=False, activation_fn=nn.ReLU(), stride=1):
        super(ConvLayer, self).__init__()
        layers = []
        layers.append(nn.Conv2d(size[0], size[1], size[2], padding=padding, stride=stride))
        if pool_layer is not None:
            layers.append(pool_layer)
        if bn:
            layers.append(nn.BatchNorm2d(size[1]))
        if dropout:
            layers.append(nn.Dropout2d())
        layers.append(activation_fn)
        
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)

In [0]:
class DeconvLayer(nn.Module):
    def __init__(self, size, padding=1, stride=1, 
                 bn=False, dropout=False, activation_fn=nn.ReLU(), output_padding=0):
        super(DeconvLayer, self).__init__()
        layers = []
        layers.append(nn.ConvTranspose2d(size[0], size[1], size[2], padding=padding, 
                                         stride=stride, output_padding=output_padding))
        if bn:
            layers.append(nn.BatchNorm2d(size[1]))
        if dropout:
            layers.append(nn.Dropout2d())
        layers.append(activation_fn)
        
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)

In [0]:
class FullyConnected(nn.Module):
    def __init__(self, sizes, dropout=False, activation_fn=nn.Tanh):
        super(FullyConnected, self).__init__()
        layers = []
        
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i+1]))
            if dropout:
                layers.append(nn.Dropout())
            layers.append(activation_fn())
        else: # нам не нужен дропаут и фнкция активации в последнем слое
            layers.append(nn.Linear(sizes[-2], sizes[-1]))
        
        self.model = nn.Sequential(*layers)
        
    def forward(self, x):
        return self.model(x)

In [0]:
batch = train_set.make_batch()
data, target = batch
target = target.to(device)
data = data/255.0
data = data.view(batch_size, 1, 128, 400).to(device)

In [0]:
class HTREncoder(nn.Module):
    def __init__(self, batchnorm=True, dropout=False):
        super(HTREncoder, self).__init__()
        
        self.convolutions = nn.Sequential(
        ConvLayer([1, 16, 3], padding=0, bn=batchnorm),
        ConvLayer([16, 32, 3], padding=0, bn=batchnorm),
        ConvLayer([32, 50, 3], padding=0, bn=batchnorm),
        ConvLayer([50, 64, 3], padding=0, stride=2, bn=batchnorm, pool_layer=None))
    
    def forward(self, x):
        h = self.convolutions(x)
        h = F.max_pool2d(h, [h.size(2), 1], padding=[0, 0])
        h = h.permute([2, 3, 0, 1])[0]
        return h
    

In [0]:
encoder = HTREncoder().to(device)

In [0]:
class HTRDecoder(nn.Module):
    def __init__(self, ntoken, encoded_width=23, encoded_height=64, batchnorm=True, dropout=False, rnn_type="LSTM"):
        super(HTRDecoder, self).__init__()
        self.ntoken = ntoken
        self.encoded_height = encoded_height
        self.lstm_size = 256
        lstm_layers = 1
        self.rnn_type = rnn_type
        
        if rnn_type == "LSTM":
          self.rnn = nn.LSTM(self.encoded_height*encoded_width + ntoken, self.lstm_size, lstm_layers, dropout=0.3, bidirectional=False)
        else:
          self.rnn = nn.GRU(self.encoded_height*encoded_width + ntoken, self.lstm_size, lstm_layers, dropout=0.3, bidirectional=False)
        self.embedding = nn.Embedding(ntoken, ntoken)
        self.decoder = nn.Linear(1*self.lstm_size*1, ntoken)#*batch_size)
        self.drop = nn.Dropout(0.3)
        self.concatenated = torch.FloatTensor(24, )
    
    def forward(self, x, prev, hidden=None):
        x = self.drop(x)
        emb = self.embedding(prev)
        emb = emb.permute([1, 0, 2])
        x = torch.cat([x, emb], dim=2)
        x, hidden = self.rnn(x, hidden)
        x = x.permute(1, 0, 2)
        x = x.flatten(start_dim=1)
        x = self.drop(x)
        x = self.decoder(x)
        return x, hidden  
      
    def makeHidden(self):
      if self.rnn_type == "LSTM":
        h1 = torch.zeros(1, batch_size, self.lstm_size).to(device)
        h2 = torch.zeros(1, batch_size, self.lstm_size).to(device)
        return (h1, h2)
      else:
        h1 = torch.zeros(1, batch_size, self.lstm_size).to(device)
        return h1

In [73]:
decoder = HTRDecoder(len(train_set.codes)).to(device)

  "num_layers={}".format(dropout, num_layers))


In [0]:
START = train_set.codes['<START>']
current_symbol = torch.LongTensor(batch_size, 1).to(device)
current_symbol[:, :] = START

In [75]:
encoder_optimizer = optim.Adam(encoder.parameters(), lr=1e-4, weight_decay=0.00005)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=1e-4, weight_decay=0.00005)

criterion = nn.CrossEntropyLoss()

teacher_forcing_ratio = 0.5

from random import random


def train(epoch):
  print("Training epoch " + str(epoch) + "...")
  train_set.to_start()
  batch_idx = 0
  c_loss = 0
  START = train_set.codes['<START>']
  current_symbol = torch.LongTensor(batch_size, 30+1).to(device)
  while True:
    batch = train_set.make_batch()
    if batch is None:
      break
    encoder.zero_grad()
    decoder.zero_grad()
    
    data, target = batch
    data = data.view(batch_size, 1, 128, 400)/255.0
    data = data.to(device)
    target = target.to(device)
    hidden = decoder.makeHidden()    

    loss = 0
    enc = encoder(data)
    #s = enc.contiguous().view(1, batch_size, -1)
   
    s = enc.permute(1, 0, 2)
    s = s.flatten(start_dim=1).view(1, 30, 1472)
    
    current_symbol[:, 0] = START
    use_teacher_forcing = True if random() < teacher_forcing_ratio else False
    for i in range(0, target.shape[1]):
      symb = current_symbol[:, i].view(batch_size, 1).contiguous()
      dec, hidden = decoder(s, symb, hidden)
      if use_teacher_forcing:
        current_symbol[:, i + 1] = target[:, i]
      else:
        sampled = torch.multinomial(dec.exp(), 1)
        current_symbol[:, i+1] = sampled.squeeze()
      o = dec.view(30, 1, 81).flatten(start_dim=0,end_dim=1)
      t = target[:, i].flatten()
      loss += criterion(o, t)
    c_loss += loss.item()
    freq = 30
    if (batch_idx % freq == 0) and (batch_idx != 0):
      print("TF: " + str(use_teacher_forcing))
      if not use_teacher_forcing:
        for k in range(0, 5):
           print("  " + train_set.decode_word(target[k,:]) + " -> " + train_set.decode_word(current_symbol[k,:]))
      c_loss /= freq 
      print("  Batch: " + str(batch_idx) + " Loss: " + str(c_loss))
      c_loss = 0
      

      
    loss.backward()
    grad_clip = 0.1
    torch.nn.utils.clip_grad_norm_(decoder.parameters(), grad_clip)
    encoder_optimizer.step()
    decoder_optimizer.step()
    batch_idx += 1

for i in range(0, 100):
  train(i)


Training epoch 0...
TF: True
  Batch: 30 Loss: 87.11978683471679
TF: False
  of                             -> <START>k                            
  .                              -> <START>rhJ                          
  .                              -> <START>qr                        Dn 
  had                            -> <START>Z  s+         Z     h        
  come                           -> <START>xp   ? osq                   
  Batch: 60 Loss: 27.65483678181966
TF: True
  Batch: 90 Loss: 19.13895454406738
TF: True
  Batch: 120 Loss: 17.177790355682372
TF: True
  Batch: 150 Loss: 16.284815947214764
TF: False
  change                         -> <START>+Sdlornans                   
  '                              -> <START>lz                           
  ebriety                        -> <START>'/6omn                       
  to                             -> <START>h                            
  direction                      -> <START>pgno                         
  Batch: 1

KeyboardInterrupt: ignored