<a href="https://colab.research.google.com/github/Hramchenko/Handwritting/blob/master/AT_prof_v2_Att_HTR_tf_unif_att_v5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
# Prefer the first CUDA GPU; fall back to the CPU so the notebook still runs
# (slowly) on hosts without one instead of raising on the CUDA device query.
if torch.cuda.is_available():
    print("Device " + torch.cuda.get_device_name(0))
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

Device Tesla K80
cuda:0


In [0]:
# Mini-batch size. Read as a notebook-level global by the dataset loaders,
# the model classes and the training code in the cells below.
batch_size = 100

In [3]:
import sys
sys.path.append("./Handwritting/")
from IAMWords import IAMWords
# Image dimensions handed to the loader and used later when reshaping batches.
image_width = 1500
image_height = 200
# Both splits use identical loader settings; only the split name differs.
train_set, test_set = (
    IAMWords(split, "./IAM/", batch_size=batch_size, line_height=image_height,
             line_width=image_width, scale=1)
    for split in ("train", "test")
)

Reading ./IAM/words.train.pkl...
Reading finished
Reading ./IAM/words.test.pkl...
Reading finished


In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from IPython.core.debugger import set_trace
%matplotlib inline
import matplotlib.pyplot as plt

In [0]:
class ConvLayer(nn.Module):
    """Conv2d followed by optional pooling, batch norm and dropout, then an activation.

    Args:
        size: ``[in_channels, out_channels, kernel_size]`` for the convolution.
        padding: Padding passed to ``nn.Conv2d``.
        pool_layer: Module placed right after the convolution; ``None`` disables it.
        bn: If true, add ``nn.BatchNorm2d(out_channels)``.
        dropout: If true, add ``nn.Dropout2d()``.
        activation_fn: Module instance appended as the last stage.
        stride: Stride passed to ``nn.Conv2d``.

    Note: the default ``pool_layer`` / ``activation_fn`` objects are created once,
    when the class is defined, and are therefore shared by every instance that
    relies on the defaults.
    """

    def __init__(self, size, padding=1, pool_layer=nn.MaxPool2d(2, stride=2),
                 bn=False, dropout=False, activation_fn=nn.ReLU(), stride=1):
        super().__init__()
        in_channels, out_channels, kernel_size = size[0], size[1], size[2]

        stages = [nn.Conv2d(in_channels, out_channels, kernel_size,
                            padding=padding, stride=stride)]
        if pool_layer is not None:
            stages.append(pool_layer)
        if bn:
            stages.append(nn.BatchNorm2d(out_channels))
        if dropout:
            stages.append(nn.Dropout2d())
        stages.append(activation_fn)

        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stacked stages to ``x``."""
        return self.model(x)

In [0]:
class DeconvLayer(nn.Module):
    """ConvTranspose2d followed by optional batch norm and dropout, then an activation.

    Args:
        size: ``[in_channels, out_channels, kernel_size]`` for the transposed convolution.
        padding: Padding passed to ``nn.ConvTranspose2d``.
        stride: Stride passed to ``nn.ConvTranspose2d``.
        bn: If true, add ``nn.BatchNorm2d(out_channels)``.
        dropout: If true, add ``nn.Dropout2d()``.
        activation_fn: Module instance appended as the last stage (the default
            instance is shared by all layers that rely on it).
        output_padding: ``output_padding`` passed to ``nn.ConvTranspose2d``.
    """

    def __init__(self, size, padding=1, stride=1,
                 bn=False, dropout=False, activation_fn=nn.ReLU(), output_padding=0):
        super().__init__()
        in_channels, out_channels, kernel_size = size[0], size[1], size[2]

        stages = [nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                     padding=padding, stride=stride,
                                     output_padding=output_padding)]
        if bn:
            stages.append(nn.BatchNorm2d(out_channels))
        if dropout:
            stages.append(nn.Dropout2d())
        stages.append(activation_fn)

        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stacked stages to ``x``."""
        return self.model(x)

In [0]:
def none_lambda(_):
    """Default no-op weight initialiser: takes one argument and returns None."""
    return None

In [0]:
class FullyConnected(nn.Module):
    """Plain MLP: every layer except the last is Linear -> (Dropout) -> activation.

    Args:
        sizes: Layer widths ``[in, hidden..., out]``.
        dropout: If true, insert ``nn.Dropout()`` after each hidden ``Linear``.
        activation_fn: Module *class*; a fresh instance is created per hidden layer.
        init_fn: Called on the weight of each hidden ``Linear`` (the output layer
            keeps PyTorch's default initialisation).
    """

    def __init__(self, sizes, dropout=False, activation_fn=nn.Tanh, init_fn=none_lambda):
        super().__init__()
        stack = []

        for n_in, n_out in zip(sizes[:-2], sizes[1:-1]):
            hidden = nn.Linear(n_in, n_out)
            init_fn(hidden.weight)
            stack.append(hidden)
            if dropout:
                stack.append(nn.Dropout())
            stack.append(activation_fn())

        # The output layer gets neither dropout nor an activation function.
        stack.append(nn.Linear(sizes[-2], sizes[-1]))

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.model(x)

In [0]:
class FullyConnectedX(nn.Module):
    """MLP with optional input flattening, dropout and output activation.

    Hidden layers are ``Linear -> (Dropout) -> activation_fn``; the final
    ``Linear`` is followed only by ``last_fn`` (when given).

    Args:
        sizes: Layer widths ``[in, hidden..., out]`` (at least two entries).
        dropout: ``False``/``0`` for no dropout, a float drop probability, or
            ``True`` for ``nn.Dropout``'s default probability (0.5).
        batch_norm: Accepted for call-site compatibility but ignored; no
            normalisation layer is created.
        activation_fn: Module instance shared by all hidden layers.
        flatten: If true, ``forward`` first reshapes its input to ``(batch, -1)``.
        last_fn: Optional module appended after the final ``Linear``.
        init_fn: Called on the weight of each hidden ``Linear``; the final layer
            keeps PyTorch's default initialisation.
    """

    def __init__(self, sizes, dropout=False, batch_norm=False, activation_fn=nn.Tanh(), flatten=False, last_fn=None, init_fn=none_lambda):
        super(FullyConnectedX, self).__init__()
        layers = []
        self.flatten = flatten
        for i in range(len(sizes) - 2):
            fc = nn.Linear(sizes[i], sizes[i+1])
            init_fn(fc.weight)
            layers.append(fc)
            if dropout:
                # A bare `True` must not be forwarded as the probability: it
                # would be read as p=1 and zero every activation.
                p = 0.5 if dropout is True else dropout
                layers.append(nn.Dropout(p))
            layers.append(activation_fn)

        # The output layer gets neither dropout nor the hidden activation.
        layers.append(nn.Linear(sizes[-2], sizes[-1]))
        if last_fn is not None:
            layers.append(last_fn)
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Optionally flatten ``x`` per sample, then run it through the stack."""
        if self.flatten:
            x = x.view(x.shape[0], -1)
        return self.model(x)

In [0]:
# Pull one batch to use as a sample input for the shape check further down.
batch = train_set.make_batch(use_binarization=False)
data, target = batch
target = target.to(device)
# Divide by 255, then add a channel axis: (batch, 1, width, height).
data = (data / 255.0).view(batch_size, 1, image_width, image_height).to(device)

In [0]:
class HTREncoder(nn.Module):
    """Convolutional encoder built from five un-pooled ``ConvLayer`` stages.

    Four 3x3 stride-2 convolutions (1 -> 4 -> 16 -> 32 -> 64 channels) are
    followed by a 1x1 convolution with stride ``(1, 11)``. ``forward`` then
    drops the last axis when it has size 1.

    Args:
        batchnorm: Forwarded to every ``ConvLayer`` as ``bn``.
        dropout: Accepted but not used by this class.
    """

    def __init__(self, batchnorm=False, dropout=False):
        super().__init__()

        channel_plan = [(1, 4), (4, 16), (16, 32), (32, 64)]
        stages = [
            ConvLayer([c_in, c_out, 3], padding=0, stride=2, bn=batchnorm, pool_layer=None)
            for c_in, c_out in channel_plan
        ]
        stages.append(
            ConvLayer([64, 64, 1], padding=0, stride=(1, 11), bn=batchnorm, pool_layer=None)
        )
        self.convolutions = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x`` and squeeze the trailing axis of the feature map."""
        return self.convolutions(x).squeeze(-1)


In [0]:
# Stand-alone encoder on the selected device; the next cell uses it to check
# the encoded feature shape. HTRRecognition builds its own encoder.
encoder = HTREncoder().to(device)

In [13]:
# Sanity check: encode the sample batch and display the feature-map shape.
c = encoder(data)
c.shape

torch.Size([100, 64, 92])

In [0]:
def init_gru(a, mean=0.0, std=0.02):
    """Re-initialise the weight tensors of a recurrent module from a normal distribution.

    Every parameter of ``a`` whose name contains ``'weight'`` is overwritten in
    place with samples from N(mean, std**2); other parameters (biases) are
    left untouched.

    Args:
        a: Module to initialise (used here with ``nn.GRU`` / ``nn.LSTM``).
        mean: Mean of the normal distribution.
        std: Standard deviation of the normal distribution.
    """
    from torch.nn import init

    # named_parameters() is the public way to walk the parameters, and
    # init.normal_ replaces the deprecated init.normal alias.
    for name, param in a.named_parameters():
        if 'weight' in name:
            init.normal_(param, mean, std)

In [15]:
class HTRDecoder(nn.Module):
    """Single-step attention decoder: encoded image + previous symbol -> next-symbol log-probs.

    Each ``forward`` call performs one decoding step. The flattened image
    features are re-weighted by attention computed from the features and the
    recurrent state, concatenated with the previous symbol's embedding,
    compressed by an MLP and fed through a 2-layer LSTM/GRU.

    Args:
        ntoken: Size of the symbol vocabulary (embedding rows and output classes).
        encoded_width: Width of the encoder feature map (number of attention weights).
        encoded_height: Number of rows of the feature map per position; the
            flattened feature vector has ``encoded_height * encoded_width`` entries.
        batchnorm: Accepted but not used by this class.
        dropout: Accepted but not used by this class.
        rnn_type: ``"LSTM"`` for ``nn.LSTM``; any other value selects ``nn.GRU``.
    """

    def __init__(self, ntoken, encoded_width=92, encoded_height=64, batchnorm=True, dropout=False, rnn_type="LSTM"):
        super(HTRDecoder, self).__init__()
        from math import floor

        self.ntoken = ntoken
        self.encoded_width = encoded_width
        self.encoded_height = encoded_height
        self.lstm_size = 256
        self.lstm_layers = 2
        self.rnn_type = rnn_type
        self.emb_size = 128
        features_size = self.encoded_height*encoded_width + self.emb_size
        lstm_inp_size = floor(features_size*0.3)

        rnn_cls = nn.LSTM if rnn_type == "LSTM" else nn.GRU
        self.rnn = rnn_cls(lstm_inp_size, self.lstm_size, self.lstm_layers, dropout=0.3, bidirectional=False)
        init_gru(self.rnn)

        self.embedding = nn.Embedding(ntoken, self.emb_size)
        self.decoder = nn.Linear(self.lstm_size, ntoken)
        self.drop = nn.Dropout(0.0)
        self.fc = FullyConnectedX([features_size, floor(features_size*0.7), floor(features_size*0.5), lstm_inp_size], activation_fn=nn.ReLU(), last_fn=nn.ReLU(), init_fn=nn.init.kaiming_uniform_, batch_norm=True)
        # Attention sees the image features plus the hidden state of every RNN layer.
        attention_inp_size = self.lstm_size*self.lstm_layers + self.encoded_height*encoded_width
        self.attention = FullyConnectedX([attention_inp_size, self.encoded_height*encoded_width*2, self.encoded_width], activation_fn=nn.LeakyReLU(0.2), last_fn=nn.Tanh(), init_fn=nn.init.kaiming_uniform_, batch_norm=True)
        self.attention_weights = None
        # Filled by forward(); defined here so rnnInput() is safe before the first step.
        self.rnn_input = None

    def forward(self, x, prev, hidden=None):
        """Run one decoding step.

        Args:
            x: Flattened encoder features, ``(1, batch, H*W)`` or ``(batch, H*W)``.
            prev: Previous symbol ids, one per sample (e.g. shape ``(batch, 1)``).
            hidden: Recurrent state from ``makeHidden`` or the previous step;
                ``None`` skips attention and lets the RNN start from zeros.

        Returns:
            ``(log_probs, hidden)`` where ``log_probs`` is ``(batch, ntoken)``.
        """
        # Collapse to (batch, features) without a bare squeeze(), which would
        # also drop the batch axis when the batch holds a single sample.
        x = self.drop(x)
        x = x.reshape(-1, x.shape[-1])
        if hidden is not None:
            # An LSTM state is an (h, c) tuple; attention only looks at h.
            h_state = hidden[0] if isinstance(hidden, tuple) else hidden
            hidden_m = h_state.permute(1, 0, 2).flatten(start_dim=1)

            # detach(): the attention MLP gets gradients, but nothing flows
            # back into the features or the recurrent state through its input.
            attention_inp = torch.cat([x, hidden_m], dim=1).detach()
            self.attention_weights = self.attention(attention_inp)
            self.attention_weights = F.softmax(self.attention_weights, dim=1)
            # One weight per horizontal position, tiled over the feature rows.
            self.attention_weights = self.attention_weights.repeat([1, self.encoded_height])
            x = x * self.attention_weights
        # NOTE(review): the embedding output is detached, so no gradient reaches
        # self.embedding from this path - confirm that this is intended.
        emb = self.embedding(prev).reshape(-1, self.emb_size).detach()

        x = torch.cat([x, emb], dim=1)
        x = self.fc(x)
        x = x.unsqueeze(0)
        self.rnn_input = x
        x, hidden = self.rnn(x, hidden)
        x = x.squeeze(dim=0)
        x = self.drop(x)
        x = self.decoder(x)
        x = F.log_softmax(x, dim=1)
        return x, hidden

    def rnnInput(self):
        """Return the tensor fed to the RNN in the latest step (None before the first)."""
        return self.rnn_input

    def makeHidden(self):
        """Return a zero initial state for the notebook-level ``batch_size`` on ``device``.

        An ``(h, c)`` tuple for LSTM, a single tensor otherwise.
        """
        shape = (self.lstm_layers, batch_size, self.lstm_size)
        h1 = torch.zeros(*shape).to(device)
        if self.rnn_type == "LSTM":
            h2 = torch.zeros(*shape).to(device)
            return (h1, h2)
        return h1
      
# Stand-alone GRU decoder on the selected device. HTRRecognition constructs its
# own decoder, so this instance is not the one trained below.
decoder = HTRDecoder(len(train_set.codes), rnn_type="GRU").to(device)

  import sys


In [0]:
class HTRRecognitionState:
  """Empty attribute container; callers attach fields to an instance after creating it."""

  def __init__(self):
    pass
  

class HTRRecognition(nn.Module):
  """Encoder + step-wise decoder that transcribes a batch of word images.

  The module owns its two Adam optimizers and keeps, after every ``forward``,
  the summed per-step NLL loss, the greedy predictions and the decoder's
  per-step hidden states / RNN inputs (exposed through ``state()``).

  Relies on the notebook-level ``train_set``, ``batch_size`` and ``device``.
  """

  def __init__(self):
    super(HTRRecognition, self).__init__()
    self.encoder = HTREncoder()
    self.decoder = HTRDecoder(len(train_set.codes), rnn_type="GRU")
    self.encoder_optimizer = optim.Adam(self.encoder.parameters(), lr=1e-4, weight_decay=0.00005)
    self.decoder_optimizer = optim.Adam(self.decoder.parameters(), lr=1e-4, weight_decay=0.00005)
    self.START = train_set.start_code
    self.STOP = train_set.stop_code
    # Scratch buffers (plain attributes, not registered buffers). The result
    # buffer has 11 columns, so forward() can decode at most 11 steps.
    self.recognition_result = torch.LongTensor(batch_size, 10+1).to(device)
    self.old_symbol = torch.LongTensor(batch_size, 1).to(device)
    self.loss = 0
    self.stop_symbol = torch.LongTensor(batch_size, 1).to(device)
    self.stop_symbol.fill_(self.STOP)
    self.criterion = nn.NLLLoss()
    # Populated by forward(); defined up front so the accessors are safe to read.
    self.length = 0
    self.result = None
    self.hidden_states_ = []
    self.rnn_inputs_ = []

  def zero_grad(self):
    """Clear the gradients of both sub-networks."""
    self.encoder.zero_grad()
    self.decoder.zero_grad()

  def forward(self, data, target, use_teacher_forcing):
    """Decode ``target.shape[1]`` symbols for every image in ``data``.

    Args:
      data: Image batch passed to the encoder.
      target: ``(batch, steps)`` tensor of ground-truth symbol ids.
      use_teacher_forcing: If true the decoder is fed the ground-truth symbol
        at each step, otherwise its own previous prediction.

    Returns:
      ``(batch, steps)`` tensor with the greedy prediction of every step.
      ``self.loss`` holds the sum of the per-step NLL losses.
    """
    hidden = self.decoder.makeHidden()

    self.loss = 0
    enc = self.encoder(data)
    # Flatten each sample's feature map on its own: (batch, C, W) -> (1, batch, C*W).
    # Permuting the batch axis away before flattening would interleave features
    # of different samples; this channel-major layout is also what the decoder's
    # tiled attention weights expect.
    s = enc.flatten(start_dim=1).unsqueeze(0)

    self.old_symbol[:, 0] = self.START

    self.hidden_states_ = []
    self.rnn_inputs_ = []

    steps = target.shape[1]
    for i in range(0, steps):
      dec, hidden = self.decoder(s, self.old_symbol, hidden)
      self.hidden_states_.append(hidden.unsqueeze(0))
      self.rnn_inputs_.append(self.decoder.rnnInput())
      self.recognition_result[:, i] = dec.topk(1, dim=1)[1].flatten().detach()
      if use_teacher_forcing:
        self.old_symbol[:, 0] = target[:, i]
      else:
        self.old_symbol[:, 0] = self.recognition_result[:, i]
      self.loss += self.criterion(dec, target[:, i])
    self.length = steps
    # Copy out of the reusable buffer so results saved by the caller are not
    # overwritten by the next forward pass.
    self.result = self.recognition_result[:, 0:steps].clone()
    return self.result

  def state(self):
    """Snapshot of the latest forward pass: hidden states, RNN inputs and predictions."""
    r = HTRRecognitionState()
    r.hidden = self.hidden_states()
    r.rnn_inputs = self.rnn_inputs()
    r.result = self.result
    return r

  def normed_loss(self):
    """Loss of the latest forward pass divided by its number of decoding steps."""
    return self.loss/self.length

  def backprop(self):
    """Back-propagate the accumulated loss."""
    self.loss.backward()

  def hidden_states(self):
    """Per-step decoder states as ``(steps, batch, layers * hidden_size)``."""
    r = torch.cat(self.hidden_states_, dim = 0)
    r = r.permute(0, 2, 1, 3)
    r = r.flatten(start_dim=2)
    return r

  def rnn_inputs(self):
    """Per-step RNN inputs concatenated along the first axis."""
    return torch.cat(self.rnn_inputs_, dim=0)

  def step(self):
    """Apply one optimizer update to the encoder and to the decoder."""
    self.encoder_optimizer.step()
    self.decoder_optimizer.step()


In [0]:
# From https://github.com/aryopg/Professor_Forcing_Pytorch/blob/master/models/losses.py
class Discriminator(nn.Module):
    """Sequence classifier over decoder traces; emits one logit per sample.

    At each step the decoder's hidden state, its RNN input and the embedding of
    its emitted symbol are concatenated, compressed by an MLP and fed to a
    2-layer GRU. The GRU outputs of all ``input_length`` steps (zero for steps
    that were not run) are flattened and mapped to a single logit.

    Args:
        input_size: Ignored; the per-step input width is fixed internally.
        hidden_size: GRU hidden size.
        input_length: Number of step slots in the output buffer.
        symbs_cnt: Vocabulary size for the symbol embedding.

    Relies on the notebook-level ``batch_size`` and ``device``.
    """

    def __init__(self, input_size, hidden_size, input_length, symbs_cnt):
        super().__init__()
        from math import floor

        self.embedding = nn.Embedding(symbs_cnt, 128)

        self.hidden_cells = 256
        self.hidden_layers = 2

        self.hidden_size = hidden_size
        self.input_length = input_length
        self.rnn_layers = 2

        # The argument is overridden: hidden state (512) + RNN input (1804)
        # + symbol embedding (128).
        input_size = 512 + 1804 + 128
        gru_input_size = 256*2

        self.enc = FullyConnectedX([input_size, floor(input_size*0.7), gru_input_size], activation_fn=nn.ReLU())
        self.gru = nn.GRU(gru_input_size, hidden_size, self.rnn_layers)
        init_gru(self.gru)

        gru_out = input_length*hidden_size
        self.fc = FullyConnectedX([gru_out, floor(gru_out*0.7), floor(gru_out*0.3), 1], activation_fn=nn.ReLU(), init_fn=nn.init.kaiming_uniform_, batch_norm=True)
        self.optimizer = optim.Adam(self.parameters(), lr=1e-4, weight_decay=0.00005)

    def zero_grad(self):
        """Clear gradients through the module's own optimizer."""
        self.optimizer.zero_grad()

    def forward(self, hidden_states, dec_inputs, dec_outputs):
        """Score one batch of decoder traces.

        Args:
            hidden_states: Step-major decoder states, ``(steps, batch, ...)``.
            dec_inputs: Step-major decoder RNN inputs, ``(steps, batch, ...)``.
            dec_outputs: ``(batch, steps)`` emitted symbol ids.

        Returns:
            The output of the final MLP on the flattened GRU outputs; the
            flattened features are also stored in ``self.features``.
        """
        # Embed the symbols and make them step-major like the other two inputs.
        symbol_emb = self.embedding(dec_outputs).permute(1, 0, 2)
        per_step = torch.cat([hidden_states, dec_inputs, symbol_emb], dim=2)

        # Slots for steps beyond the actual sequence length stay zero.
        outputs = torch.zeros(self.input_length, batch_size, self.hidden_size, device=device)

        state = self.initHidden()
        for step in range(hidden_states.shape[0]):
            frame = self.enc(per_step[step, :, :]).unsqueeze(0)
            frame, state = self.gru(frame, state)
            outputs[step] = frame.squeeze(0)

        feat = outputs.permute(1, 0, 2).flatten(start_dim=1)
        out = self.fc(feat)

        self.features = feat

        return out

    def initHidden(self):
        """Zero initial GRU state for the notebook-level ``batch_size`` on ``device``."""
        return torch.zeros(self.rnn_layers, batch_size, self.hidden_size, device=device)


In [0]:
# Constant (batch_size, 1) targets for the discriminator's binary cross-entropy.
batch_zeros = torch.zeros(batch_size, 1, device=device)
batch_ones = torch.ones(batch_size, 1, device=device)

In [19]:
# The recogniser acts as the generator; the discriminator scores its decoder traces.
generator = HTRRecognition().to(device)
discriminator = Discriminator(
    input_size=256*2,
    hidden_size=512,
    input_length=10,
    symbs_cnt=len(train_set.codes),
).to(device)

  import sys


In [0]:
def process_batch(data, target, batch_idx):
  """Run one adversarial training step on a single batch.

  When the targets have more than one symbol, the discriminator is first
  updated on a free-running trace (label 0) and a teacher-forced trace
  (label 1), both computed without gradients. The generator is then updated
  on its teacher-forced loss; the adversarial term is added only when the
  discriminator loss of this batch is below 0.25.

  Every 20th batch prints up to two sample decodings and the loss values.
  Uses the notebook-level ``generator``, ``discriminator``, ``train_set``,
  ``batch_zeros`` and ``batch_ones``.
  """
  period = 20
  log_now = batch_idx % period == 0
  seq_len = target.shape[1]
  activate_discriminator = seq_len > 1

  d_val = 0
  if activate_discriminator:
    # --- discriminator update ---
    discriminator.zero_grad()
    generator.zero_grad()
    with torch.no_grad():
      free_run_result = generator(data, target, False)

      if log_now:
        for k in range(min(2, target.shape[0])):
          decoded = free_run_result[k, 0:seq_len]
          print("  " + train_set.decode_word(target[k,:]) + " -> " + train_set.decode_word(decoded))

      free_run_state = generator.state()
      generator(data, target, True)
      teacher_forcing_state = generator.state()

    d_free_run = discriminator(free_run_state.hidden.detach(),
                               free_run_state.rnn_inputs.detach(),
                               free_run_state.result.detach())
    d_teacher_forcing = discriminator(teacher_forcing_state.hidden.detach(),
                                      teacher_forcing_state.rnn_inputs.detach(),
                                      teacher_forcing_state.result.detach())
    true_loss = F.binary_cross_entropy_with_logits(d_free_run, batch_zeros)
    fake_loss = F.binary_cross_entropy_with_logits(d_teacher_forcing, batch_ones)
    D_loss = 0.5*(fake_loss + true_loss)
    d_val = D_loss.item()
    D_loss.backward()
    discriminator.optimizer.step()

  # --- generator update ---
  generator.zero_grad()
  discriminator.zero_grad()
  generator(data, target, True)
  if activate_discriminator and (d_val < 0.25):
    state_ = generator.state()
    fake_pred = discriminator(state_.hidden, state_.rnn_inputs, state_.result)
    G_loss = generator.normed_loss() + F.binary_cross_entropy_with_logits(fake_pred, batch_zeros)
  else:
    G_loss = generator.normed_loss()
  tf_val = generator.normed_loss().item()
  g_val = G_loss.item()
  G_loss.backward()
  generator.step()
  if log_now:
    print("Batch: %d Descr %.4f TF %.4f Full %.4f" % (batch_idx, d_val, tf_val, g_val))
  
def train(epoch, max_size):
  """Run one pass over the training set, calling ``process_batch`` on every batch.

  Args:
    epoch: Not used inside the function.
    max_size: Forwarded to ``train_set.to_start``.
  """
  train_set.to_start(max_size)
  batch_idx = 0
  batch = train_set.make_batch()
  while batch is not None:
    images, labels = batch
    labels = labels.to(device)
    # Divide by 255, then add a channel axis: (batch, 1, width, height).
    images = (images / 255.0).view(batch_size, 1, image_width, image_height).to(device)
    process_batch(images, labels, batch_idx)
    batch_idx += 1
    batch = train_set.make_batch()

    


In [0]:
# 100 training passes; each one calls train() with max_size=5.
for i in range(100):
  print("Epoch %d" % i)
  print("***************************************")
  train(i, 5)

Epoch 0
***************************************
  Flor -> then
  Next -> ande
Batch: 0 Descr 0.3883 TF 3.4760 Full 3.4760
