"""
    Aim of this part: modify the code to support word-level attention.

        1. Check the dataloader and create a dummy dataloader similar to it.

        2. Pass the input from the dummy dataloader to the model.

        3. Modify the model.

        4. Train the modified model on the actual data.
    
"""

In [1]:

import sys
sys.path.insert(0, '/home/aniketag/Documents/phd/TensorFlow-2.x-YOLOv3_simula/Handwriting-1-master/VerticalAttentionOCR/')
from torch.optim import Adam
from OCR.document_OCR.v_attention.trainer_pg_va import Manager
from OCR.document_OCR.v_attention.models_pg_va import VerticalAttention, LineDecoderCTC
from basic.models import FCN_Encoder
from basic.generic_dataset_manager import OCRDataset
from basic.generic_training_manager import GenericTrainingManager
from torch.nn import Flatten, LSTM, Embedding
import torch
import torch.multiprocessing as mp
import torch.nn as nn


Apex not installed


In [7]:

# Experiment configuration: one nested dict with three sections
# ("dataset_params", "model_params", "training_params") that is handed to the
# training manager / model constructors used further down in this notebook.
dataset_name = "IAM"  # ["RIMES", "IAM", "READ_2016"]

params = {
    "dataset_params": {
        # Maps a dataset name to the folder holding its formatted data.
        "datasets": {
            dataset_name: "/media/aniketag/c4eb0693-4a65-4f0c-8d65-a6dad4b97ff9/IAM/formatted/{}_paragraph".format(dataset_name),
        },
        "train": {
            "name": "{}-train".format(dataset_name),
            "datasets": [dataset_name, ],
        },
        "valid": {
            "{}-valid".format(dataset_name): [dataset_name, ],
        },
        "dataset_class": OCRDataset,
        "config": {
            "width_divisor": 8,  # Image width will be divided by 8
            "height_divisor": 32,  # Image height will be divided by 32
            "padding_value": 0,  # Image padding value
            "padding_token": None,  # Label padding value (None: default value is chosen)
            "charset_mode": "CTC",  # add blank label
            "constraints": ["padding", "CTC_va"],  # Padding for models constraints and CTC requirements
            "padding": {
                "min_height": 480,  # to handle model requirements (AdaptivePooling)
                "min_width": 800,  # to handle model requirements (AdaptivePooling)
            },
            "preprocessings": [
                {
                    "type": "dpi",  # modify image resolution
                    "source": 300,  # from 300 dpi
                    "target": 150,  # to 150 dpi
                },
                {
                    "type": "to_RGB",
                    # if grayscale image, produce RGB one (3 channels with same value) otherwise do nothing
                },
            ],
            # Augmentation techniques to use at training time
            # (each entry carries its own application probability "proba")
            "augmentation": {
                "dpi": {
                    "proba": 0.2,
                    "min_factor": 0.75,
                    "max_factor": 1,
                },
                "perspective": {
                    "proba": 0.2,
                    "min_factor": 0,
                    "max_factor": 0.3,
                },
                "elastic_distortion": {
                    "proba": 0.2,
                    "max_magnitude": 20,
                    "max_kernel": 3,
                },
                "random_transform": {
                    "proba": 0.2,
                    "max_val": 125,
                },
                "dilation_erosion": {
                    "proba": 0.2,
                    "min_kernel": 1,
                    "max_kernel": 3,
                    "iterations": 1,
                },
                "brightness": {
                    "proba": 0.2,
                    "min_factor": 0.01,
                    "max_factor": 1,
                },
                "contrast": {
                    "proba": 0.2,
                    "min_factor": 0.01,
                    "max_factor": 1,
                },
                "sign_flipping": {
                    "proba": 0.2,
                },
            },
        }
    },

    "model_params": {
        # Model classes to use for each module
        "models": {
            "encoder": FCN_Encoder,
            "attention": VerticalAttention,
            "decoder": LineDecoderCTC,
        },
        "transfer_learning": None,
        # "transfer_learning": {
        #     # model_name: [state_dict_name, checkpoint_path, learnable, strict]
        #     "encoder": ["encoder", "../../line_OCR/ctc/outputs/iam/checkpoints/best_XX.pt", True, True],
        #     "decoder": ["decoder", "../../line_OCR/ctc/outputs/iam/checkpoints/best_XX.pt", True, True],
        #
        # },
        "input_channels": 3,  # 3 for RGB images, 1 for grayscale images

        # dropout probability for standard dropout (half dropout probability is taken for spatial dropout)
        "dropout": 0.5,  # dropout for encoder module
        "dec_dropout": 0.5,  # dropout for decoder module
        "att_dropout": 0,  # dropout for attention module

        "features_size": 256,  # encoder output features maps
        "att_fc_size": 256,  # number of channels for attention sum computation

        "use_location": True,  # use previous attention weights in attention module
        "use_coverage_vector": True,  # use coverage vector in attention module
        "coverage_mode": "clamp",  # mode to use for the coverage vector

        "emb_max_features_width": 250,  # maximum feature width (for use_abs_position)
        "emb_max_features_height": 100,  # maximum feature height (for use_abs_position)

        "use_hidden": True,  # use decoder hidden state in attention (and thus LSTM in decoder)
        "hidden_size": 256,  # number of cells for LSTM decoder hidden state
        "nb_layers_decoder": 1,  # number of layers for LSTM decoder

        "min_height_feat": 15,  # min height for attention module (AdaptivePooling)
        "min_width_feat": 100,  # min width for attention module (AdaptivePooling)
    },

    "training_params": {
        "output_folder": "van_iam_paragraph_learned_stop",  # folder name for logs and weights
        "max_nb_epochs": 5000,  # max number of epochs for the training
        "max_training_time": 3600 * (24 + 23),  # max training time limit (in seconds), i.e. 47 hours
        "load_epoch": "best",  # ["best", "last"], to load weights from best epoch or last trained epoch
        "interval_save_weights": None,  # None: keep best and last only
        "batch_size": 1,  # mini-batch size per GPU
        "use_ddp": False,  # Use DistributedDataParallel
        "ddp_port": "10000",  # Port for Distributed Data Parallel communications
        # NOTE(review): the import cell of this notebook printed "Apex not installed";
        # confirm how the training manager treats use_apex=True in that case.
        "use_apex": True,  # Enable mix-precision with apex package
        "nb_gpu": torch.cuda.device_count(),
        "optimizer": {
            "class": Adam,
            "args": {
                "lr": 0.0001,
                "amsgrad": False,
            }
        },
        "eval_on_valid": True,  # Whether to eval and logs metrics on validation set during training or not
        "eval_on_valid_interval": 2,  # Interval (in epochs) to evaluate during training
        "focus_metric": "cer",  # Metrics to focus on to determine best epoch
        "expected_metric_value": "low",  # ["high", "low"] What is best for the focus metric value
        "set_name_focus_metric": "{}-valid".format(dataset_name),  # same key as the "valid" entry above
        "train_metrics": ["loss_ctc", "cer", "wer"],  # Metrics name for training
        "eval_metrics": ["cer", "wer", "diff_len"],  # Metrics name for evaluation on validation set during training
        "force_cpu": False,  # True for debug purposes to run on cpu only
        "max_pred_lines": 30,  # Maximum number of line predictions at evaluation time
        "stop_mode": "learned",  # ["fixed", "early", "learned"]

    },

}


In [3]:
class miniTrain(GenericTrainingManager):
    """Stripped-down training manager used to smoke-test data and model loading.

    The constructor stores ``params`` and immediately calls the inherited
    ``load_dataset()`` and ``load_model()``. Note that the parent
    ``GenericTrainingManager.__init__`` is not invoked here, so any attribute
    it would normally set up is absent on instances of this class.
    """

    def __init__(self, params) -> None:
        """Keep ``params`` on the instance, then load the dataset and the model."""
        self.params = params

        print("\n\t inside minitrain!!")

        self.load_dataset()
        self.load_model()

    def testdata(self):
        """Placeholder for a future data check; currently does nothing.

        Declared with ``self`` so it can be called on an instance (the
        original signature took no arguments and raised ``TypeError`` when
        invoked as a bound method).
        """
    
    

In [4]:



# Build the smoke-test manager; its __init__ runs load_dataset() and load_model() right away.
miniTrain(params)


	 inside minitrain!!

	 stamp 1

	 datasets[key] = /media/aniketag/c4eb0693-4a65-4f0c-8d65-a6dad4b97ff9/IAM/formatted/IAM_paragraph/

	 joinPath = /media/aniketag/c4eb0693-4a65-4f0c-8d65-a6dad4b97ff9/IAM/formatted/IAM_paragraph/labels.pkl

	 is file: True

	 stamp 1.1

	 stamp 1.3

	 paths_and_sets: [{'path': '/media/aniketag/c4eb0693-4a65-4f0c-8d65-a6dad4b97ff9/IAM/formatted/IAM_paragraph', 'set_name': 'train'}]

	 from_segmentation: False  	 paths_and_sets: [{'path': '/media/aniketag/c4eb0693-4a65-4f0c-8d65-a6dad4b97ff9/IAM/formatted/IAM_paragraph', 'set_name': 'train'}]

	 paths_and_sets: [{'path': '/media/aniketag/c4eb0693-4a65-4f0c-8d65-a6dad4b97ff9/IAM/formatted/IAM_paragraph', 'set_name': 'train'}]

	 keys: dict_keys(['name', 'label', 'img', 'unchanged_label', 'raw_line_seg_label'])

	 self.samples[i][name]: IAM_paragraph/train/train_0.png

	 self.samples[i][name]: IAM_paragraph/train/train_1.png

	 self.samples[i][name]: IAM_paragraph/train/train_2.png

	 self.samples[i][name]

IndexError: list index out of range

In [None]:
#params

In [3]:
import torch

# Random stand-in for an encoder feature map laid out as (batch, channels, width).
context_vector = torch.randn(1, 256, 121)

# Read the three sizes off the tensor in one go.
batch_size, input_size, max_line_len = context_vector.shape
hidden_size = 64

# Put the channel axis last so the linear layer acts on the features of each position.
context_vector = context_vector.permute(0, 2, 1).reshape(batch_size, max_line_len, input_size)

# Project every position from input_size to hidden_size features.
linear_layer = torch.nn.Linear(input_size, hidden_size)
context_vector = linear_layer(context_vector)

# Normalise over dim 1 (the max_line_len axis): each feature column sums to 1.
weights = torch.softmax(context_vector, dim=1)

# Expected: torch.Size([batch_size, max_line_len, hidden_size])
print(weights.shape)



torch.Size([1, 121, 64])


In [None]:
# NOTE(review): load_samples is neither defined nor imported in the visible cells — confirm its origin.
samples = load_samples()



In [3]:
import sys
import random
import os
import time
import torch
import torch.nn as nn
from torch import tanh, log_softmax, softmax, relu
from torch.nn import Conv1d, Conv2d, Dropout,  Linear, AdaptiveMaxPool2d, InstanceNorm1d, AdaptiveMaxPool1d
from torch.nn import Flatten, LSTM, Embedding

import sys
sys.path.insert(0, '/home/aniketag/Documents/phd/TensorFlow-2.x-YOLOv3_simula/Handwriting-1-master/VerticalAttentionOCR/')
from torch.optim import Adam
from OCR.document_OCR.v_attention.trainer_pg_va import Manager
#from OCR.document_OCR.v_attention.models_pg_va import VerticalAttention, LineDecoderCTC
from basic.models import FCN_Encoder
from basic.generic_dataset_manager import OCRDataset
import torch
import torch.multiprocessing as mp
#from parameters import params

class BahdanauAttention(torch.nn.Module):
    """Additive (Bahdanau-style) attention over a sequence of encoder outputs.

    A query vector and every encoder position are projected with two
    bias-free linear maps, summed, squashed with ``tanh`` and scored by a
    third bias-free linear map. The scores are soft-maxed over the sequence
    axis and used to average the encoder outputs.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Layers are created in the order W1, W2, V.
        self.W1 = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        self.W2 = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        self.V = torch.nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Attend over ``encoder_outputs`` using ``hidden`` as the query.

        Args:
            hidden: tensor of shape (batch_size, hidden_size).
            encoder_outputs: tensor of shape (batch_size, max_line_len, hidden_size).

        Returns:
            ``(context_vector, attention_weights)`` with shapes
            (batch_size, hidden_size) and (batch_size, max_line_len, 1).
        """
        # The query gets a length-1 sequence axis so it broadcasts over all positions.
        query = self.W1(hidden.unsqueeze(1))   # (batch_size, 1, hidden_size)
        keys = self.W2(encoder_outputs)        # (batch_size, max_line_len, hidden_size)
        energy = torch.tanh(query + keys)

        # One score per position, normalised along the sequence axis.
        attention_weights = self.V(energy).softmax(dim=1)

        # Weighted average of the encoder outputs.
        context_vector = torch.sum(attention_weights * encoder_outputs, dim=1)

        return context_vector, attention_weights


In [4]:
# NOTE(review): the recorded run of this cell raised KeyError: 'use_hidden'. In `params` that key
# lives under params["model_params"], so the constructor presumably expects that sub-dict — confirm.
LineDecoderCTC(params)

KeyError: 'use_hidden'

In [29]:

import sys
import random
import os
import time
import torch
import torch.nn as nn
from torch import tanh, log_softmax, softmax, relu
from torch.nn import Conv1d, Conv2d, Dropout,  Linear, AdaptiveMaxPool2d, InstanceNorm1d, AdaptiveMaxPool1d
from torch.nn import Flatten, LSTM, Embedding
from basic.models import DepthSepConv2D

class LineDecoderCTC(torch.nn.Module):
    """Experimental line decoder that adds a word-level attention loop.

    ``forward`` first runs ``max_line_len`` attention steps over a linear
    projection of the input features (updating a GRU-driven hidden state and
    printing intermediate shapes), then produces the actual output by passing
    the input through an LSTM and a 1x1 convolution followed by log-softmax.

    ``params`` is stored but not read: all sizes are currently hard-coded
    (hidden/input size 256, vocabulary size 79).
    """

    def __init__(self, params):
        super(LineDecoderCTC, self).__init__()

        self.params = params
        self.hidden_size = 256  # params["hidden_size"]
        self.batch_size = 2  # only used for the placeholder hidden state below
        self.max_line_len = 10  # number of attention steps (assumed max words per line)

        # Placeholder so the attribute exists before the first forward();
        # forward() re-creates it to match the actual input batch.
        self.hidden = torch.zeros(self.batch_size, self.hidden_size)

        self.use_hidden = True  # params["use_hidden"]
        self.input_size = 256  # params["features_size"]
        self.vocab_size = 79  # params["vocab_size"]
        self.attention = BahdanauAttention1(self.hidden_size)
        self.word_context_vectors = []

        # Results of the most recent attention step.
        self.context_vector2 = None
        self.context_vector = None
        self.attention_weights = None

        # Maps a word context vector to the size the LSTM consumes
        # (not used by forward() at the moment).
        self.wrdCon2Dcd = nn.Linear(256, self.input_size)

        self.decoder = RNNDecoder(self.hidden_size, 1)

        if self.use_hidden:
            self.lstm = LSTM(self.input_size, self.hidden_size, num_layers=1)
            self.end_conv = Conv2d(in_channels=self.hidden_size, out_channels=self.vocab_size + 1, kernel_size=1)
        else:
            self.end_conv = Conv2d(in_channels=self.input_size, out_channels=self.vocab_size + 1, kernel_size=1)

        self.linear1 = torch.nn.Linear(256, 256)

    def forward(self, x, h=None):
        """Decode one line of features.

        Args:
            x: features of shape (B, C, W).
            h: optional LSTM state passed to ``self.lstm``.

        Returns:
            ``(out, h)``: ``out`` is the log-softmax (over dim 1) of the
            1x1-convolution output, ``h`` is the state returned by the LSTM
            run on ``x``.
        """
        # Start every call from a clean state sized from the input itself.
        # Previously the state was fixed to batch size 2 and carried over
        # between calls, which broke other batch sizes, kept the previous
        # call's autograd graph alive and let the context list grow forever.
        self.hidden = x.new_zeros(x.size(0), self.hidden_size)
        self.word_context_vectors = []

        # (B, C, W) -> (B, W, C) so the linear layer acts on the channel axis.
        x1 = x.permute(0, 2, 1)
        hidden_rep = self.linear1(x1)

        print("\n\t 1.x.shape:",x1.shape,"\t x.shape:",x.shape,"\t hidden_rep.shape:",hidden_rep.shape)
        print("\n\t self.hidden:",self.hidden.shape)

        for i in range(self.max_line_len):
            print("\n\t i:",i)

            self.context_vector, self.context_vector2, self.attention_weights = self.attention(self.hidden, hidden_rep)

            print("\n\t word context_vector:",self.context_vector.permute(2, 0, 1).shape," \t attention_weights.shape:",self.attention_weights.shape)
            print("\n\t word context_vector2:",self.context_vector2.shape)

            # Exploratory branch: run the tiled context through the LSTM and
            # the output convolution. The result (out2) is only printed.
            xOut, hOut = self.lstm(self.context_vector.permute(2, 0, 1), h)
            xOut = xOut.permute(1, 2, 0)
            print("\n\t xOut:",xOut.shape,"\t hOut.shape =",hOut[0].shape)

            self.word_context_vectors.append(self.context_vector)

            out2 = self.end_conv(xOut.unsqueeze(3)).squeeze(3)
            print("\n\t out2.shape:",out2.shape)

            # Advance the attention query with the pooled context vector.
            print("\n\t self.hidden.shape before:",self.hidden.shape)
            self.hidden = self.decoder(self.hidden, self.context_vector2)
            print("\n\t self.hidden.shape after:",self.hidden.shape)

        # Main path: the LSTM expects (W, B, C).
        print("\n\t 1.1.0.x.shape:",x.permute(2, 0, 1).shape)
        x, h = self.lstm(x.permute(2, 0, 1), h)
        print("\n\t 1.1.1. x.shape:",x.shape," \t h.shape:",h[0].shape)
        x = x.permute(1, 2, 0)  # back to (B, C, W)

        # Conv2d needs a 4-D input, hence the temporary trailing axis.
        temp2 = x.unsqueeze(3)
        print("\n\t 2.x.shape:",x.shape," temp2.shape:",temp2.shape)

        out = self.end_conv(temp2).squeeze(3)
        print("\n\t out1.shape:",out.shape,"\t out2.shape:",out2.shape)

        out = torch.squeeze(out, dim=2)
        out = log_softmax(out, dim=1)
        return out, h


class RNNDecoder(nn.Module):
    """Single-step GRU that advances a hidden state given one context vector."""

    def __init__(self, embed_size, num_layers, drop=0.3):
        super().__init__()

        self.num_layers = num_layers
        self.rnn = nn.GRU(input_size=embed_size, hidden_size=embed_size, num_layers=num_layers)
        # Dropout is only switched on for stacked GRUs.
        if self.num_layers > 1:
            self.rnn.dropout = drop

    def forward(self, hidden, context):
        """Run one GRU step.

        ``context`` is fed as a length-1 sequence and ``hidden`` is expanded
        to one initial state per layer. Returns the last layer's new hidden
        state.
        """
        step_input = context.unsqueeze(0)
        initial_state = hidden.expand(self.num_layers, -1, -1).contiguous()
        _, final_state = self.rnn(step_input, initial_state)

        return final_state[-1]

class DeepOutputLayer(nn.Module):
    """Two-layer output head over the concatenation of three feature vectors."""

    def __init__(self, embed_size, vocab_size, drop=0.3):
        super().__init__()

        self.l1 = nn.Linear(embed_size*3, embed_size)
        self.l2 = nn.Linear(embed_size, vocab_size)
        self.drop = nn.Dropout(drop)

    def forward(self, prev, hidden, context):
        """Score the vocabulary for one timestep.

        ``prev``, ``hidden`` and ``context`` are concatenated on the last
        axis (so each must have ``embed_size`` features there), passed
        through ``l1``, leaky-ReLU, dropout and ``l2``.
        """
        out = self.l1(torch.cat([prev,hidden,context], -1))
        # `F` is not imported in this cell, so use the functional namespace
        # reachable through `nn` instead of F.leaky_relu.
        out = self.l2(self.drop(nn.functional.leaky_relu(out)))
        return out

class Embedding(nn.Module):
    """Token embedding scaled by ``sqrt(d_model)`` and followed by dropout.

    Note: this class name shadows ``Embedding`` imported from ``torch.nn`` in
    the same cell; the lookup table itself is an ``nn.Embedding``.
    """

    def __init__(self, vocab, d_model, drop=0.2):
        super(Embedding, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Look up ``x``, multiply by ``sqrt(d_model)`` and apply dropout."""
        # Local import: `math` is not imported anywhere in the visible cells,
        # so the original math.sqrt call raised NameError.
        import math

        emb = self.lut(x) * math.sqrt(self.d_model)
        return self.drop(emb)



In [30]:

# Smoke test: push a random (batch=2, channels=256, width=116) feature map
# through the decoder and inspect what comes back.
contextVect = torch.rand(2, 256, 116)

decoder = LineDecoderCTC(params)

# The decoder returns its output together with the LSTM state `h`.
out, h = decoder(contextVect)

print("\n\t out.shape:", out.shape)
print("\t h.shape", len(h))

# Shapes of the first two entries of `h`.
print("1.", h[0].shape)
print("1.", h[1].shape)



	 1.x.shape: torch.Size([2, 116, 256]) 	 x.shape: torch.Size([2, 256, 116]) 	 hidden_rep.shape: torch.Size([2, 116, 256])

	 self.hidden: torch.Size([2, 256])

	 i: 0

	 attention_weights = torch.Size([2, 116, 1])

	 word context_vector: torch.Size([116, 2, 256])  	 attention_weights.shape: torch.Size([2, 116])

	 word context_vector2: torch.Size([2, 256])

	 xOut: torch.Size([2, 256, 116]) 	 hOut.shape = torch.Size([1, 2, 256])

	 out2.shape: torch.Size([2, 80, 116])

	 self.hidden.shape before: torch.Size([2, 256])

	 self.hidden.shape after: torch.Size([2, 256])

	 i: 1

	 attention_weights = torch.Size([2, 116, 1])

	 word context_vector: torch.Size([116, 2, 256])  	 attention_weights.shape: torch.Size([2, 116])

	 word context_vector2: torch.Size([2, 256])

	 xOut: torch.Size([2, 256, 116]) 	 hOut.shape = torch.Size([1, 2, 256])

	 out2.shape: torch.Size([2, 80, 116])

	 self.hidden.shape before: torch.Size([2, 256])

	 self.hidden.shape after: torch.Size([2, 256])

	 i: 2

	 att

In [15]:
"""
modified attention to handle context vector 
"""


class BahdanauAttention1(torch.nn.Module):
    """Additive attention that also returns the context tiled along the sequence.

    Same scoring as a plain Bahdanau attention (``V(tanh(W1 q + W2 k))``
    soft-maxed over the sequence axis), but the pooled context vector is
    returned twice: once repeated for every sequence position and once as is.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Bias-free projections, created in the order W1, W2, V.
        self.W1 = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        self.W2 = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        self.V = torch.nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Attend over ``encoder_outputs`` with ``hidden`` as the query.

        Args:
            hidden: (batch_size, hidden_size).
            encoder_outputs: (batch_size, max_line_len, hidden_size).

        Returns:
            ``(context_vector, context_vector2, attention_weights)`` with
            shapes (batch_size, hidden_size, max_line_len),
            (batch_size, hidden_size) and (batch_size, max_line_len).
        """
        seq_len = encoder_outputs.shape[1]

        # Broadcast the projected query over all projected positions.
        query = self.W1(hidden.unsqueeze(1))    # (batch_size, 1, hidden_size)
        keys = self.W2(encoder_outputs)         # (batch_size, max_line_len, hidden_size)
        scores = self.V(torch.tanh(query + keys))

        # Normalise the scores along the sequence axis.
        attention_weights = torch.softmax(scores, dim=1)  # (batch_size, max_line_len, 1)
        print("\n\t attention_weights =",attention_weights.shape)

        # Attention-weighted average of the encoder outputs.
        pooled = (attention_weights * encoder_outputs).sum(dim=1)  # (batch_size, hidden_size)
        context_vector2 = pooled.clone()

        # Tile the pooled vector once per position:
        # (hidden, batch) -> (seq, hidden, batch) -> (batch, hidden, seq).
        context_vector = pooled.t().unsqueeze(0).repeat(seq_len, 1, 1).permute(2, 1, 0)

        return context_vector, context_vector2, attention_weights.squeeze(2)


In [12]:
import torch

# Dimensions used for the attention smoke test.
batch_size, max_line_len, hidden_size = 12, 116, 256

# Random stand-ins, drawn in this order: encoder outputs first, then the query.
encoder_outputs = torch.randn(batch_size, max_line_len, hidden_size)
hidden = torch.randn(batch_size, hidden_size)

# Module under test.
attention1 = BahdanauAttention1(hidden_size)

# One forward pass: tiled context, pooled context and attention weights.
context_vector, context_vector2, attention_weights = attention1(hidden, encoder_outputs)

# Expected: (batch_size, hidden_size, max_line_len) and (batch_size, hidden_size)
print("context_vector shape:", context_vector.shape," \t cv2.shape:",context_vector2.shape)
# Expected: (batch_size, max_line_len)
print("attention_weights shape:", attention_weights.shape)



	 attention_weights = torch.Size([12, 116, 1])
context_vector shape: torch.Size([12, 256, 116])  	 cv2.shape: torch.Size([12, 256])
attention_weights shape: torch.Size([12, 116])


In [1]:

%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, '/home/aniketag/Documents/phd/TensorFlow-2.x-YOLOv3_simula/Handwriting-1-master/VerticalAttentionOCR/')
from torch.optim import Adam
from OCR.document_OCR.v_attention.trainer_pg_va import Manager
from OCR.document_OCR.v_attention.models_pg_va import VerticalAttention, LineDecoderCTC
from basic.models import FCN_Encoder
from basic.generic_dataset_manager import OCRDataset
from basic.generic_training_manager import GenericTrainingManager
from torch.nn import Flatten, LSTM, Embedding
import torch
import torch.multiprocessing as mp
import torch.nn as nn

import sys
import random
import os
import time
import torch
import torch.nn as nn
from torch import tanh, log_softmax, softmax, relu
import torch.nn.functional as F

from torch.nn import Conv1d, Conv2d, Dropout,  Linear, AdaptiveMaxPool2d, InstanceNorm1d, AdaptiveMaxPool1d
from torch.nn import Flatten, LSTM, Embedding



from basic.models import DepthSepConv2D

from torch.nn import CrossEntropyLoss, CTCLoss
from basic.generic_dataset_manager import DatasetManager


#dataset = DatasetManager(params["dataset_params"])


# CTC loss whose first positional argument (the blank label index) is 79;
# the per-sample losses are summed rather than averaged.
loss_ctc_func = CTCLoss(79, reduction="sum")

Apex not installed


In [56]:
class RNNDecoder(nn.Module):
    """One-step GRU cell wrapper: (hidden, context) -> next hidden state."""

    def __init__(self, embed_size, num_layers, drop=0.3):
        super().__init__()

        self.num_layers = num_layers
        self.rnn = nn.GRU(input_size=embed_size, hidden_size=embed_size, num_layers=num_layers)
        # Inter-layer dropout only makes sense with more than one layer.
        if self.num_layers > 1:
            self.rnn.dropout = drop

    def forward(self, hidden, context):
        """Feed ``context`` as a length-1 sequence and return the top layer's state.

        ``hidden`` is expanded so that every GRU layer starts from the same
        initial state.
        """
        sequence = context.unsqueeze(0)
        start_state = hidden.expand(self.num_layers, -1, -1).contiguous()
        _, end_state = self.rnn(sequence, start_state)

        return end_state[-1]

class DeepOutputLayer(nn.Module):
    """Output head producing ``vocab_size + 1`` scores from hidden state and context."""

    def __init__(self, embed_size, vocab_size, drop=0.3):
        super().__init__()

        # The input is the concatenation of two embed_size vectors.
        self.l1 = nn.Linear(2 * embed_size, embed_size)
        self.l2 = nn.Linear(embed_size, vocab_size + 1)
        self.drop = nn.Dropout(drop)

    def forward(self, hidden, context):
        """Score one timestep: concat -> l1 -> leaky ReLU -> dropout -> l2."""
        joint = torch.cat((hidden, context), dim=-1)
        activated = F.leaky_relu(self.l1(joint))
        return self.l2(self.drop(activated))

class Embedding(nn.Module):
    """Token embedding scaled by ``sqrt(d_model)`` and followed by dropout.

    Note: this class name shadows ``Embedding`` imported from ``torch.nn``
    earlier in the notebook; the lookup table itself is an ``nn.Embedding``.
    """

    def __init__(self, vocab, d_model, drop=0.2):
        super(Embedding, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Look up ``x``, multiply by ``sqrt(d_model)`` and apply dropout."""
        # Local import: `math` is not imported anywhere in the visible cells,
        # so the original math.sqrt call raised NameError.
        import math

        emb = self.lut(x) * math.sqrt(self.d_model)
        return self.drop(emb)
    



class BahdanauAttention1(torch.nn.Module):
    """Additive attention returning both a tiled and a pooled context vector.

    Scores are ``V(tanh(W1 q + W2 k))`` soft-maxed over the sequence axis.
    The weighted average of the encoder outputs is returned as is and also
    repeated once per sequence position.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Bias-free projections, created in the order W1, W2, V.
        self.W1 = Linear(hidden_size, hidden_size, bias=False)
        self.W2 = Linear(hidden_size, hidden_size, bias=False)
        self.V = Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Attend over ``encoder_outputs`` with ``hidden`` as the query.

        Args:
            hidden: (batch_size, hidden_size).
            encoder_outputs: (batch_size, max_line_len, hidden_size).

        Returns:
            ``(context_vector, context_vector2, attention_weights)`` with
            shapes (batch_size, hidden_size, max_line_len),
            (batch_size, hidden_size) and (batch_size, max_line_len).
        """
        n_positions = encoder_outputs.shape[1]

        # Projected query broadcasts over the projected positions.
        projected_query = self.W1(hidden.unsqueeze(1))   # (batch_size, 1, hidden_size)
        projected_keys = self.W2(encoder_outputs)        # (batch_size, max_line_len, hidden_size)
        scores = self.V(torch.tanh(projected_query + projected_keys))

        # Normalise along the sequence axis.
        attention_weights = torch.softmax(scores, dim=1)  # (batch_size, max_line_len, 1)

        # Attention-weighted average of the encoder outputs.
        pooled = (attention_weights * encoder_outputs).sum(dim=1)  # (batch_size, hidden_size)
        context_vector2 = pooled.clone()

        # Repeat the pooled vector for every position:
        # (hidden, batch) -> (seq, hidden, batch) -> (batch, hidden, seq).
        context_vector = pooled.t().unsqueeze(0).repeat(n_positions, 1, 1).permute(2, 1, 0)

        return context_vector, context_vector2, attention_weights.squeeze(2)


class LineDecoderCTC1(torch.nn.Module):
    """Experimental character-level decoder driven by additive attention.

    For every column of the input feature map, ``forward`` attends over a
    linear projection of the features, advances a GRU hidden state with the
    pooled context and scores ``vocab_size + 1`` classes with a small output
    head.

    Only ``params["training_params"]["batch_size"]`` is read from ``params``;
    the remaining sizes are hard-coded (hidden/input size 256, vocabulary 79).
    """

    def __init__(self, params):
        super(LineDecoderCTC1, self).__init__()

        self.params = params
        self.hidden_size = 256  # params["hidden_size"]
        self.batch_size = params["training_params"]["batch_size"]
        self.max_line_len = 150  # not used by forward(), which iterates over the input width

        # Same placement as before when a GPU is present, but fall back to the
        # CPU instead of crashing at construction time on CPU-only machines.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Placeholder so the attribute exists before the first forward();
        # forward() re-creates it to match the actual input batch.
        self.hidden = torch.zeros(self.batch_size, self.hidden_size, device=device)

        self.use_hidden = True  # params["use_hidden"]
        self.input_size = 256  # params["features_size"]
        self.vocab_size = 79  # params["vocab_size"]
        self.attention = BahdanauAttention1(self.hidden_size).to(device)
        self.word_context_vectors = []
        self.dec_inp = []

        # Results of the most recent attention step.
        self.context_vector2 = None
        self.context_vector = None
        self.attention_weights = None

        # Maps a word context vector to the size the LSTM consumes
        # (not used by forward() at the moment).
        self.wrdCon2Dcd = Linear(256, self.input_size)

        self.decoder = RNNDecoder(self.hidden_size, 1).to(device)
        self.output = DeepOutputLayer(self.hidden_size, self.vocab_size).to(device)

        # lstm / end_conv are kept for parity with the CTC line decoder but
        # are not used by forward().
        if self.use_hidden:
            self.lstm = LSTM(self.input_size, self.hidden_size, num_layers=1)
            self.end_conv = Conv2d(in_channels=self.hidden_size, out_channels=self.vocab_size + 1, kernel_size=1)
        else:
            self.end_conv = Conv2d(in_channels=self.input_size, out_channels=self.vocab_size + 1, kernel_size=1)

        self.linear1 = Linear(256, 256).to(device)

    def forward(self, x, h=None):
        """Decode one feature map column by column.

        Args:
            x: features of shape (B, C, W); moved to the module's device.
            h: accepted for interface compatibility, not used.

        Returns:
            ``(res, attns, dec_inp)``: ``res`` stacks the per-step
            log-softmax scores, shape (W, B, vocab_size + 1); ``attns``
            stacks the per-step attention weights, shape (W, B, W);
            ``dec_inp`` is the list of per-step arg-max class indices for
            this call.
        """
        # Follow wherever the module's weights actually live.
        device = self.linear1.weight.device

        maxChars = x.shape[2]

        # (B, C, W) -> (B, W, C) so the linear layer acts on the channel axis.
        x1 = x.to(device).permute(0, 2, 1)
        print("\n\t x1.shape:",x1.shape)

        hidden_rep = self.linear1(x1)
        print("\n\t hidden_rep.shape:",hidden_rep.shape," \t maxChars:",maxChars)

        # Start every call from a clean state sized from the input itself.
        # Previously the hidden state and the prediction list were carried
        # over between calls, which broke other batch sizes, kept old
        # autograd graphs alive and made dec_inp grow without bound.
        self.hidden = x1.new_zeros(x1.size(0), self.hidden_size)
        self.dec_inp = []

        res, attns = [], []

        for _ in range(maxChars):
            # Shapes: (B, 256, W), (B, 256), (B, W)
            self.context_vector, self.context_vector2, self.attention_weights = self.attention(self.hidden, hidden_rep)

            # Advance the query state with the pooled context.
            self.hidden = self.decoder(self.hidden, self.context_vector2)

            # Class scores for this step, shape (B, vocab_size + 1).
            charOut = log_softmax(self.output(self.hidden, self.context_vector2), dim=1)

            res.append(charOut)
            attns.append(self.attention_weights)
            # Greedy prediction for this step, detached from the graph.
            self.dec_inp.append(charOut.detach().max(1)[1])

        res = torch.stack(res)
        attns = torch.stack(attns)
        return res, attns, self.dec_inp

# Build the decoder and move all of its registered submodules/parameters to the first GPU.
ldc1 = LineDecoderCTC1(params).to("cuda:0")

In [58]:
# Smoke test: one random (batch=1, channels=256, width=116) feature map.
x = torch.rand(1, 256, 116)

# Per-step scores, per-step attention weights and per-step arg-max predictions.
res, attns, dec_inp = ldc1(x)

print(res.shape, attns.shape)



	 x1.shape: torch.Size([1, 116, 256])

	 hidden_rep.shape: torch.Size([1, 116, 256])  	 maxChars: 116
torch.Size([116, 1, 80]) torch.Size([116, 1, 116])
