## Main Imports

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from tensorboardX import SummaryWriter

## Tools: CNN output dimensions

In [67]:
from numpy import floor

In [78]:
def conv2d(height, width, kernel, stride=(1,1), padding=(0,0), dilation=(1,1)):
    '''
    Return the output spatial dims (height, width) of a Conv2d / MaxPool2d layer.

    Applies PyTorch's shape formula independently to each spatial axis:

        out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1

    With the defaults (padding=0, dilation=1) this reduces to
    floor((in - kernel) / stride) + 1, i.e. the original behaviour.
    Note: MaxPool2d's own default stride equals its kernel size, so pass the
    stride explicitly when using this for pooling layers.

    Parameters
    ----------
    height, width : int
        Input spatial size.
    kernel : int or (int, int)
        Kernel size; a scalar is used for both axes.
    stride, padding, dilation : int or (int, int)
        Per-axis (vertical, horizontal) hyper-parameters; a scalar is used
        for both axes.

    Returns
    -------
    (int, int)
        Output (height, width). Values <= 0 mean the layer does not fit the input.
    '''
    def _pair(value):
        # torch.nn convention: a single scalar applies to both spatial axes
        try:
            first, second = value
        except TypeError:
            return value, value
        return first, second

    k_h, k_w = _pair(kernel)
    s_h, s_w = _pair(stride)
    p_h, p_w = _pair(padding)
    d_h, d_w = _pair(dilation)

    # Floor division gives an exact floor without a float round-trip, and both
    # axes now share the same formula.
    out_h = (height + 2 * p_h - d_h * (k_h - 1) - 1) // s_h + 1
    out_w = (width + 2 * p_w - d_w * (k_w - 1) - 1) // s_w + 1
    return int(out_h), int(out_w)

In [110]:
# Trace the spatial size through the conv stack, printing (height, width)
# after every layer. kernels/strides mirror the defaults used by CNN_Model.
height, width = (64, 256)
kernels = [3,5,14,16]
strides = [(1,1),(2,2),(1,3),(1,3)]

for k, stride in zip(kernels, strides):
    height, width = conv2d(height, width, k, stride)
    print(height, width)


62 254
29 125
16 38
1 8


## Network Definition

In [1]:
#Helper to count params
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
class CNN_Model(torch.nn.Module):
    '''
    Fully-convolutional per-position character classifier.

    Input images are channels-last, (N_batch, Height, Width, channels), and are
    permuted internally to the (N_batch, channels, Height, Width) layout that
    torch.nn.Conv2d requires.

    With the default hyper-parameters, a (N, 64, 256, 4) input is reduced to a
    (N, 10, 1, 8) map. The height axis is then squeezed away. A softmax over the
    10 channels gives, for each of the 8 horizontal positions, a distribution
    over 10 classes (presumably one per dictionary character; confirm).

    Parameters
    ----------
    kernels : sequence of int
        Square kernel size of each conv layer.
    strides : sequence of (int, int)
        (vertical, horizontal) stride of each conv layer.
    filters : sequence of int
        Channel counts. filters[0] is the input channel count and filters[n+1]
        is the output channel count of layer n, so len(filters) == len(kernels) + 1.
        (Open question: should the filter counts grow exponentially?)

    Raises
    ------
    ValueError
        If the hyper-parameter sequences have inconsistent lengths or are empty.
    '''

    def __init__(self, kernels=(3, 5, 14, 16),
                 strides=((1, 1), (2, 2), (1, 3), (1, 3)),
                 filters=(4, 4, 6, 8, 10)):
        # Zero-argument super(): the previous super(OCR_CNN, self) named an
        # undefined class and raised NameError.
        super().__init__()

        if not kernels or len(strides) != len(kernels) or len(filters) != len(kernels) + 1:
            raise ValueError(
                "need len(strides) == len(kernels) >= 1 and len(filters) == len(kernels) + 1"
            )

        # conv blocks. A ModuleList (not a plain list) registers each layer as a
        # submodule, so its weights appear in .parameters() / state_dict() and
        # follow .to(device).
        self.conv = torch.nn.ModuleList(
            torch.nn.Conv2d(filters[n], filters[n + 1], kernels[n], stride=strides[n])
            for n in range(len(kernels))
        )

        # activation
        self.activation = torch.nn.ReLU()
        # After the squeeze in forward(), channels are dim 1 and position is dim 2.
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        '''
        x: (N_batch, Height, Width, channels) tensor with channels == filters[0].

        Returns the softmax over the last layer's channels (dim 1). With the
        default config and a 64x256 input, the shape is (N_batch, 10, 8).
        '''
        # channels-last -> channels-first for Conv2d
        x = x.permute((0, 3, 1, 2))

        # ReLU after every conv except the last, whose output feeds the softmax
        for conv in self.conv[:-1]:
            x = self.activation(conv(x))
        x = self.conv[-1](x)

        # For a 64x256 input the height axis has size 1 here. torch.squeeze is
        # a no-op if it is not 1, which would leave a 4-D tensor.
        x = torch.squeeze(x, 2)  # [batch, n_filters, horizontal positions]

        return self.softmax(x)  # per-position probabilities over n_filters

## Monitoring tools

### TensorboardX

In [None]:
from tensorboardX import SummaryWriter

# TensorboardX writer for the current run; event files go under ./summary.
writer = SummaryWriter(
    log_dir="./summary",
)

# Intended logging calls, once the training loop defines these values:
#   writer.add_scalar("total_loss", ave_total_loss.average(), n_iter)
#   writer.add_scalar("CER", CER_total.average(), n_iter)
#   writer.add_scalar("lr", lr, n_iter)

### Memory management

In [None]:
# pip install memory-profiler
# !mprof run train.py

**TODO:** find a proper way to log decorated functions in Jupyter.

## Generators

### Generator tools

### Generator


In [None]:
n_chars_dict = 10
n_chars = 8
BATCH_SIZE = 2

def generator(char_img, epoch_size = 100, batch_size = 2):
    '''
    generator template
    '''
    
    for N in range(epoch_size):
        y_gt = np.zeros((batch_size, n_chars_dict, n_chars))
        imgs = [] 
        for N in range(batch_size):
            pass
    
        # yield batch
        yield imgs, y_gt
        

### Definitions and tests


## Learning loop