<a href="https://colab.research.google.com/github/Hramchenko/Handwritting/blob/master/prof_Att_HTR_tf_unif_att_v5_l.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
print("Device " + torch.cuda.get_device_name(0))
device = torch.device("cuda:0")
#device = torch.device("cpu")
print(device)

Device Tesla K80
cuda:0


In [0]:
batch_size = 100

In [3]:
import sys
sys.path.append("./Handwritting/")
from IAMWords import IAMWords
image_width = 1500
image_height = 200
train_set = IAMWords("train", "./IAM/", batch_size=batch_size, line_height=image_height, line_width=image_width, scale=1)
test_set = IAMWords("test", "./IAM/", batch_size=batch_size, line_height=image_height, line_width=image_width, scale=1)

Reading ./IAM/words.train.pkl...
Reading finished
Reading ./IAM/words.test.pkl...
Reading finished


In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from IPython.core.debugger import set_trace
%matplotlib inline
import matplotlib.pyplot as plt

In [0]:
class ConvLayer(nn.Module):
    def __init__(self, size, padding=1, pool_layer=nn.MaxPool2d(2, stride=2),
                 bn=False, dropout=False, activation_fn=nn.ReLU(), stride=1):
        super(ConvLayer, self).__init__()
        layers = []
        layers.append(nn.Conv2d(size[0], size[1], size[2], padding=padding, stride=stride))
        if pool_layer is not None:
            layers.append(pool_layer)
        if bn:
            layers.append(nn.BatchNorm2d(size[1]))
        if dropout:
            layers.append(nn.Dropout2d())
        layers.append(activation_fn)
        
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)

In [0]:
class DeconvLayer(nn.Module):
    def __init__(self, size, padding=1, stride=1, 
                 bn=False, dropout=False, activation_fn=nn.ReLU(), output_padding=0):
        super(DeconvLayer, self).__init__()
        layers = []
        layers.append(nn.ConvTranspose2d(size[0], size[1], size[2], padding=padding, 
                                         stride=stride, output_padding=output_padding))
        if bn:
            layers.append(nn.BatchNorm2d(size[1]))
        if dropout:
            layers.append(nn.Dropout2d())
        layers.append(activation_fn)
        
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)

In [0]:
class FullyConnected(nn.Module):
    def __init__(self, sizes, dropout=False, activation_fn=nn.Tanh):
        super(FullyConnected, self).__init__()
        layers = []
        
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i+1]))
            if dropout:
                layers.append(nn.Dropout())
            layers.append(activation_fn())
        else: # нам не нужен дропаут и фнкция активации в последнем слое
            layers.append(nn.Linear(sizes[-2], sizes[-1]))
        
        self.model = nn.Sequential(*layers)
        
    def forward(self, x):
        return self.model(x)

In [0]:
class FullyConnectedX(nn.Module):
    def __init__(self, sizes, dropout=False, activation_fn=nn.Tanh(), flatten=False, last_fn=None):
        super(FullyConnectedX, self).__init__()
        layers = []
        self.flatten = flatten
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i+1]))
            if dropout:
                layers.append(nn.Dropout(dropout))
            layers.append(activation_fn) # нам не нужен дропаут и фнкция активации в последнем слое
        else: 
            layers.append(nn.Linear(sizes[-2], sizes[-1]))
        if last_fn is not None:
            layers.append(last_fn)
        self.model = nn.Sequential(*layers)
        
    def forward(self, x):
        if self.flatten:
            x = x.view(x.shape[0], -1)
        return self.model(x)

In [0]:
batch = train_set.make_batch(use_binarization=False)
data, target = batch
target = target.to(device)
data = data/255.0
data = data.view(batch_size, 1, image_width, image_height).to(device)

In [0]:
class HTREncoder(nn.Module):
    def __init__(self, batchnorm=True, dropout=False):
        super(HTREncoder, self).__init__()
        
        self.convolutions = nn.Sequential(
        ConvLayer([1, 4, 3], padding=0, stride=2, bn=batchnorm, pool_layer=None),
        ConvLayer([4, 16, 3], padding=0, stride=2, bn=batchnorm, pool_layer=None),
        ConvLayer([16, 32, 3], padding=0, stride=2, bn=batchnorm, pool_layer=None),
        ConvLayer([32, 64, 3], padding=0, stride=2, bn=batchnorm, pool_layer=None),
        ConvLayer([64, 64, 1], padding=0, stride=(1,11), bn=batchnorm, pool_layer=None))
        
        #self.fc = FullyConnectedX([64*15*49, 64*49*3, 64*49], activation_fn=nn.ReLU())
    
    def forward(self, x):
        h = self.convolutions(x)
        h = h.squeeze(-1)
        #h = h.flatten(start_dim=1)
        #h = self.fc(h)
        #h = F.max_pool2d(h, [1, h.size(1)], padding=[0, 0])
        #h = h.permute([2, 3, 0, 1])[0]
        #h = h.permute([2, 3, 0, 1])
        
        return h
    

In [0]:
encoder = HTREncoder().to(device)

In [12]:
c = encoder(data)
c.shape

torch.Size([100, 64, 92])

In [13]:
64*15*49


47040

In [0]:
class HTRDecoder(nn.Module):
    def __init__(self, ntoken, encoded_width=92, encoded_height=64, batchnorm=True, dropout=False, rnn_type="LSTM"):
        super(HTRDecoder, self).__init__()
        self.ntoken = ntoken
        self.encoded_width = encoded_width
        self.encoded_height = encoded_height
        self.lstm_size = 256
        self.lstm_layers = 2
        self.rnn_type = rnn_type
        self.emb_size = 128
        features_size = self.encoded_height*encoded_width + self.emb_size
        from math import floor
        lstm_inp_size = floor(features_size*0.3)
        
        if rnn_type == "LSTM":
          self.rnn = nn.LSTM(lstm_inp_size, self.lstm_size, self.lstm_layers, dropout=0.3, bidirectional=False)
        else:
          self.rnn = nn.GRU(lstm_inp_size, self.lstm_size, self.lstm_layers, dropout=0.3, bidirectional=False)
        self.embedding = nn.Embedding(ntoken, self.emb_size)
        self.decoder = nn.Linear(1*self.lstm_size*1, ntoken)#*batch_size)
        self.drop = nn.Dropout(0.0)

        self.fc = FullyConnectedX([features_size, floor(features_size*0.7), floor(features_size*0.5), lstm_inp_size], activation_fn=nn.ReLU(), last_fn=nn.Tanh())
        
        self.attention = FullyConnectedX([self.lstm_size*2 + self.encoded_height*encoded_width, self.encoded_height*encoded_width*2,  self.encoded_width], activation_fn=nn.LeakyReLU(0.2), last_fn=nn.Tanh())
#        self.attention = FullyConnectedX([self.lstm_size*2 + self.encoded_height*encoded_width, self.encoded_height*encoded_width*2,  self.encoded_height*self.encoded_width], activation_fn=nn.LeakyReLU(0.2), last_fn=nn.Tanh())
        #print(self.attention)
        #self.concatenated = torch.FloatTensor(24, )
        self.attention_weights = None
    
    def forward(self, x, prev, hidden=None):
        #set_trace()
        x = self.drop(x).squeeze()
        if hidden is not None:
          
          hidden_m = hidden.permute(1, 0, 2)

          hidden_m = hidden_m.flatten(start_dim=1)
          #print(hidden.shape)
          
          attention_inp = torch.cat([x, hidden_m], dim=1).detach()
          self.attention_weights = self.attention(attention_inp)
          #print(x.shape)
          #print(hidden_m.shape)
          #print(attention_inp.shape)
          #print(self.attention_weights.shape)
          
          self.attention_weights = F.softmax(self.attention_weights, dim=1)
          #print(self.attention_weights.shape)
          
          self.attention_weights = self.attention_weights.repeat([1, self.encoded_height])
          #print(self.attention_weights.shape)
                  
          #print(x.shape)
          x = x * self.attention_weights
          #print(x.shape)
          #raise Exception()
          
          
          #print("********************")
          #print(x)
          #print(attention_w)
          #print(X)
          #print("---------------")
        emb = self.embedding(prev).squeeze().detach()
        
        
        x = torch.cat([x, emb], dim=1)
        
        x = self.fc(x)
        x = x.unsqueeze(0)
        x, hidden = self.rnn(x, hidden)
        x = x.squeeze(dim=0)
        x = self.drop(x)
        x = self.decoder(x)
        
        x = F.log_softmax(x, dim=1)
        
        return x, hidden  
      
    def makeHidden(self):
      if self.rnn_type == "LSTM":
        h1 = torch.zeros(self.lstm_layers, batch_size, self.lstm_size).to(device)
        h2 = torch.zeros(self.lstm_layers, batch_size, self.lstm_size).to(device)
        return (h1, h2)
      else:
        h1 = torch.zeros(self.lstm_layers, batch_size, self.lstm_size).to(device)
        return h1
      
decoder = HTRDecoder(len(train_set.codes), rnn_type="GRU").to(device)

In [0]:
class HTRRecognition(nn.Module):
  def __init__(self):
    super(HTRRecognition, self).__init__()
    self.encoder = HTREncoder()
    self.decoder = HTRDecoder(len(train_set.codes), rnn_type="GRU")
    self.encoder_optimizer = optim.Adam(self.encoder.parameters(), lr=1e-4, weight_decay=0.00005)
    self.decoder_optimizer = optim.Adam(self.decoder.parameters(), lr=1e-4, weight_decay=0.00005)
    self.START = train_set.start_code
    self.STOP = train_set.stop_code
    self.recognition_result = torch.LongTensor(batch_size, 30+1).to(device)
    self.old_symbol = torch.LongTensor(batch_size, 1).to(device)
    self.loss = 0
    self.stop_symbol = torch.LongTensor(batch_size, 1).to(device)
    self.stop_symbol.fill_(self.STOP)
    self.criterion = nn.NLLLoss()
    
  def zero_grad(self):
    self.encoder.zero_grad()
    self.decoder.zero_grad()
    
  def forward(self, data, target, use_teacher_forcing):
    orig_data = data

    hidden = self.decoder.makeHidden()   

    self.loss = 0
    enc = self.encoder(data)
    s = enc.permute(1, 0, 2)
    s = s.flatten(start_dim=1).view(1, batch_size, -1)
    
    self.old_symbol[:, 0] = self.START

    for i in range(0, target.shape[1]):

      dec, hidden = self.decoder(s, self.old_symbol, hidden)
      self.hidden = hidden
      
      self.recognition_result[:, i] = dec.topk(1, dim=1)[1].flatten().detach()
      if use_teacher_forcing:
        self.old_symbol[:, 0] = target[:, i]
      else:
        self.old_symbol[:, 0] = self.recognition_result[:, i]
      self.loss += self.criterion(dec, target[:, i])
    self.length = target.shape[1]
    return self.recognition_result[:, 0: target.shape[1]]
  
  def hidden_state(self):
    return self.hidden
  
  def normed_loss(self):
    return self.loss/target.shape[1]
   
  def backprop(self):
    self.loss.backward()
    
  def step(self):
    grad_clip = 0.1
    torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), grad_clip)
    torch.nn.utils.clip_grad_norm_(self.decoder.parameters(), grad_clip)
    self.encoder_optimizer.step()
    self.decoder_optimizer.step()


In [0]:
# From https://github.com/aryopg/Professor_Forcing_Pytorch/blob/master/models/losses.py
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, input_length):
        super(Discriminator, self).__init__()
        self.hidden_size = hidden_size
        self.input_length = input_length
        self.rnn_layers = 2

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, self.rnn_layers)
        from math import floor
        gru_out = input_length*hidden_size
        self.fc = FullyConnectedX([gru_out, floor(gru_out*0.7), floor(gru_out*0.3), 1], activation_fn=nn.ReLU())
        self.optimizer = optim.Adam(self.parameters(), lr=1e-4, weight_decay=0.00005)
        
    def zero_grad(self):
      self.optimizer.zero_grad()

    def forward(self, x):
        outputs = torch.zeros(self.input_length, batch_size, self.hidden_size, device=device)

        hidden = self.initHidden()
        for ei in range(x.shape[1]):
            embedded = self.embedding(x[:, ei])
            embedded = embedded.view(1, batch_size, -1)
            output = embedded
            output, hidden = self.gru(output, hidden)
            outputs[ei] = output[0, 0]

        outputs = outputs.permute(1,0,2)
        feat = outputs.contiguous().view(x.shape[0], -1)
        out = self.fc(feat)
        
        self.features = feat

        return out

    def initHidden(self):
        return torch.zeros(self.rnn_layers, batch_size, self.hidden_size, device=device)


In [0]:
# From https://github.com/aryopg/Professor_Forcing_Pytorch/blob/master/models/losses.py
class LatentDiscriminator(nn.Module):
    def __init__(self):
        super(LatentDiscriminator, self).__init__()
        self.rnn_size = 256

        self.rnn_layers = 2
        latent_layer_size = 2*256


        from math import floor
        self.fc = FullyConnectedX([latent_layer_size, floor(latent_layer_size*1.5), floor(latent_layer_size*0.7), floor(latent_layer_size*0.3), 1], activation_fn=nn.ReLU())
        self.optimizer = optim.Adam(self.parameters(), lr=1e-4, weight_decay=0.00005)
        
    def zero_grad(self):
      self.optimizer.zero_grad()

    def forward(self, x):
        x = x.permute(1, 0, 2)
        x = x.flatten(start_dim=1)
        out = self.fc(x)
        return out



In [0]:
batch_zeros = torch.zeros((batch_size, 1)).to(device)
batch_ones = torch.ones((batch_size, 1)).to(device)

In [0]:
generator = HTRRecognition().to(device)
#discriminator = Discriminator(len(train_set.codes), 256, 32).to(device)
from math import floor
discriminator = LatentDiscriminator().to(device)

In [24]:
def process_batch(data, target, batch_idx):
  generator.zero_grad()
  for i in range(0, 3):
    discriminator.zero_grad()
    free_run = generator(data, target, False)
    free_run_hid = generator.hidden_state().detach()
    teacher_forcing = generator(data, target, True)
    teacher_forcing_hid = generator.hidden_state().detach()
    d_free_run = discriminator(free_run_hid)
    d_teacher_forcing = discriminator(teacher_forcing_hid)
    true_loss = F.binary_cross_entropy_with_logits(d_free_run, batch_zeros)
    fake_loss = F.binary_cross_entropy_with_logits(d_teacher_forcing, batch_ones)
    D_loss = 0.5*(fake_loss + true_loss)
    d_val = D_loss.item()
    D_loss.backward()
    discriminator.optimizer.step()
    
  generator.zero_grad()
  discriminator.zero_grad()
  teacher_forcing_ = generator(data, target, True)
  teacher_forcing_hid_ = generator.hidden_state()
  fake_pred = discriminator(teacher_forcing_hid_)  
  tf_val = generator.normed_loss().item()
  G_loss = generator.normed_loss() + F.binary_cross_entropy_with_logits(fake_pred, batch_zeros)
  g_val = G_loss.item()
  G_loss.backward()
  generator.step()
  if batch_idx %10 == 0:
    print("Batch: %d Descr %.4f TF %.4f Full %.4f" % (batch_idx, d_val, tf_val, g_val)) 
  
def train(epoch):
  train_set.to_start()
  batch_idx = 0
  while True:
    batch = train_set.make_batch(use_binarization=False)
    if batch is None:
      break
    data, target = batch
    target = target.to(device)
    data = data/255.0
    data = data.view(batch_size, 1, image_width, image_height).to(device)
    process_batch(data, target, batch_idx)
    batch_idx += 1
    
for i in range(0, 100):
  print("Epoch %d" % i)
  print("***************************************")
  train(i)

Epoch 0
***************************************
Batch: 0 Descr 0.6936 TF 4.4127 Full 5.1358
Batch: 10 Descr 0.6933 TF 3.6388 Full 4.3337
Batch: 20 Descr 0.6932 TF 2.3977 Full 3.0910
Batch: 30 Descr 0.6839 TF 8.8002 Full 9.4991
Batch: 40 Descr 0.6710 TF 7.7865 Full 8.5653
Batch: 50 Descr 0.6612 TF 7.2025 Full 8.1167
Batch: 60 Descr 0.6700 TF 6.8196 Full 7.7149
Batch: 70 Descr 0.6865 TF 6.8619 Full 7.6532
Batch: 80 Descr 0.6880 TF 6.3007 Full 7.0354
Batch: 90 Descr 0.6868 TF 5.9838 Full 6.7120
Batch: 100 Descr 0.6954 TF 5.7079 Full 6.4347
Batch: 110 Descr 0.6938 TF 5.5551 Full 6.2669
Batch: 120 Descr 0.6923 TF 5.0200 Full 5.7162
Batch: 130 Descr 0.6934 TF 5.1465 Full 5.8545
Batch: 140 Descr 0.6883 TF 9.7425 Full 10.4125
Batch: 150 Descr 0.1900 TF 9.1778 Full 11.8826
Batch: 160 Descr 0.7253 TF 8.8114 Full 9.4711


KeyboardInterrupt: ignored