In [29]:
import cv2 as cv
import numpy as np
import pandas as pd
import PIL
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
from albumentations import Compose
from albumentations.pytorch import ToTensorV2
from torch import nn
from torch.utils.data import Dataset
from tqdm import tqdm

# Directory that test.csv, sample_submission.csv and the test parquet files are read from.
INPUT_PATH = '/kaggle/input/bengaliai-cv19'

In [30]:
# ----------------------------------------------------------------------
# Parameters
# ----------------------------------------------------------------------

# Data loading / run configuration
BATCH_SIZE = 32
N_WORKERS = 4
N_EPOCHS = 5

# Image geometry
HEIGHT = 137
WIDTH = 236
TARGET_SIZE = 256

# Weights dataset for this competition (feel free to upvote the dataset ;)
# https://www.kaggle.com/pestipeti/bengali-ai-model-weights
# NOTE(review): the checkpoint-loading cell below reads a different path; confirm
# whether this constant is still needed.
WEIGHTS_FILE = '/kaggle/input/bengali-ai-model-weights/baseline_weights.pth'

In [31]:
# Height and width used to reshape each flattened parquet row into an image,
# and the default side length of the square image produced by crop_resize.
HEIGHT, WIDTH = 137, 236
SIZE = 128

def bbox(img):
    '''
    Bounding box of the truthy entries of a 2-D array.

    Returns ``(rmin, rmax, cmin, cmax)``: the first and last row index and the
    first and last column index that contain at least one truthy value.
    Raises IndexError when ``img`` has no truthy entry at all.
    '''
    # Indices of the rows / columns that hold at least one truthy pixel.
    row_hits = np.where(np.any(img, axis=1))[0]
    col_hits = np.where(np.any(img, axis=0))[0]
    # Fancy-index first and last hit in one go (fails on an empty hit list).
    first_row, last_row = row_hits[[0, -1]]
    first_col, last_col = col_hits[[0, -1]]
    return first_row, last_row, first_col, last_col

def crop_resize(img0, size=SIZE, pad=16):
    '''
    Crop ``img0`` around its bright content, pad the crop to a square and
    resize it to ``size`` x ``size``.

    Parameters:
        img0: 2-D image array; pixels brighter than 80 count as content.
        size: side length of the returned square image.
        pad:  extra pixels added to the longer crop side before squaring.

    Returns the resized square image as produced by ``cv.resize``.
    '''
    # Locate content while ignoring a 5-pixel frame (some images have lines at the sides).
    # NOTE(review): the coordinates come from the trimmed view but are applied to
    # the untrimmed image without adding the 5-pixel offset back. Left as is,
    # since changing the preprocessing would change what the model receives.
    top, bottom, left, right = bbox(img0[5:-5, 5:-5] > 80)

    # The tight box may cut too much, so grow it by a margin, clamped to the image.
    left = 0 if left <= 13 else left - 13
    top = 0 if top <= 10 else top - 10
    right = WIDTH if right >= WIDTH - 13 else right + 13
    bottom = HEIGHT if bottom >= HEIGHT - 10 else bottom + 10
    cropped = img0[top:bottom, left:right]

    crop_w = right - left
    crop_h = bottom - top
    side = max(crop_w, crop_h) + pad

    # Zero-pad both axes symmetrically so the aspect ratio survives the resize.
    pad_rows = (side - crop_h) // 2
    pad_cols = (side - crop_w) // 2
    squared = np.pad(cropped, [(pad_rows,), (pad_cols,)], mode='constant')
    return cv.resize(squared, (size, size))

def threshold_image(img):
    '''
    Binarize a 2-D grayscale image with Otsu's method.

    Parameters:
        img: 2-D array-like of pixel intensities; it is cast to uint8 first.

    Returns a uint8 array of the same shape containing only 0 and 255.
    '''
    # Cast straight to a contiguous uint8 array. The previous PIL round trip
    # (fromarray -> np.array) gave back the same data for a 2-D input and only
    # worked if something else had already imported the PIL.Image submodule.
    gray = np.ascontiguousarray(img, dtype=np.uint8)
    # With THRESH_OTSU the threshold argument (0) is ignored and chosen automatically.
    _, binary = cv.threshold(gray, 0, 255, cv.THRESH_BINARY + cv.THRESH_OTSU)
    return binary

class BengaliParquetDataset(Dataset):
    '''
    Dataset over one parquet file of flattened images.

    Column 0 of the file is read as the image id; the remaining columns are
    reshaped to a HEIGHT x WIDTH image.

    Parameters:
        parquet_file: path passed to ``pd.read_parquet``.
        transform:    optional callable invoked as ``transform(image=img)``
                      whose result is indexed with ``['image']``.

    Each item is a dict with keys ``'image_id'`` and ``'image'``.
    '''

    def __init__(self, parquet_file, transform=None):

        self.data = pd.read_parquet(parquet_file)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        image_id = row.iloc[0]

        # A row mixes the id with the pixel columns, so its values can be an
        # object array. Convert to float64 once so the arithmetic below is
        # vectorised rather than running per Python object.
        img = np.asarray(row.iloc[1:].values, dtype=np.float64).reshape(HEIGHT, WIDTH)

        # Invert the image.
        img = 255.0 - img
        peak = img.max()
        if peak > 0:
            # Stretch so the brightest pixel becomes 255, then crop and resize.
            img = (img * (255.0 / peak)).astype(np.uint8)
            img = crop_resize(img, size=SIZE, pad=16)
        else:
            # An all-white scan inverts to all zeros: scaling would divide by
            # zero and there is nothing to crop, so use an empty canvas.
            img = np.zeros((SIZE, SIZE), dtype=np.uint8)
        img = threshold_image(img)
        # Invert back after thresholding.
        img = 255 - img

        if self.transform:
            transformed = self.transform(image=img)
            img = transformed['image']

        img = TF.to_tensor(img)

        return {
            'image_id': image_id,
            'image': img
        }

In [32]:
# Mish activation, functional form
def f_mish(input):
    '''
    Element-wise mish activation:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    Returns a tensor with the same shape as ``input``.
    '''
    gate = torch.tanh(F.softplus(input))
    return input * gate

# Module wrapper around the mish activation
class mish(nn.Module):
    '''
    nn.Module form of mish, applied element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    Shape:
        - Input: (N, *), where * is any number of additional dimensions
        - Output: (N, *), same shape as the input

    Examples:
        >>> m = mish()
        >>> input = torch.randn(2)
        >>> output = m(input)

    '''
    # No custom __init__: the module holds no parameters or buffers, so the
    # inherited nn.Module constructor is all that is needed.

    def forward(self, input):
        '''
        Apply mish to ``input`` and return the result.
        '''
        activated = f_mish(input)
        return activated
    
# Swish activation, functional form
def f_swish(input):
    '''
    Element-wise swish activation:
    swish(x) = x * sigmoid(x)

    Returns a tensor with the same shape as ``input``.
    '''
    gate = torch.sigmoid(input)
    return input * gate

# Module wrapper around the swish activation
class swish(nn.Module):
    '''
    nn.Module form of swish, applied element-wise:
    swish(x) = x * sigmoid(x)

    Shape:
        - Input: (N, *), where * is any number of additional dimensions
        - Output: (N, *), same shape as the input

    Examples:
        >>> m = swish()
        >>> input = torch.randn(2)
        >>> output = m(input)

    '''
    # No custom __init__: the module holds no parameters or buffers, so the
    # inherited nn.Module constructor is all that is needed.

    def forward(self, input):
        '''
        Apply swish to ``input`` and return the result.
        '''
        activated = f_swish(input)
        return activated

In [33]:
class LinearBottleNeck(nn.Module):
    '''
    Bottleneck block: 1x1 expansion conv -> 3x3 depthwise conv -> 1x1
    projection conv, each followed by batch norm. The activation is applied
    after the first two stages only.

    Parameters:
        in_channels:  channels of the input feature map.
        out_channels: channels of the output feature map.
        stride:       stride of the depthwise 3x3 convolution.
        t:            expansion factor; the hidden width is ``in_channels * t``.
        class_num:    unused; kept so existing call sites keep working.
        activation:   one of ``'relu'`` (ReLU6), ``'swish'`` or ``'mish'``.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.

    The input is added to the output when ``stride == 1`` and
    ``in_channels == out_channels``.
    '''

    def __init__(self, in_channels, out_channels, stride, t=6, class_num=10, activation = 'relu'):
        super().__init__()

        # Resolve the activation by name; reject unknown names explicitly
        # instead of failing later with an UnboundLocalError.
        if activation == 'relu':
            f_activation = nn.ReLU6(inplace=True)
        elif activation == 'swish':
            f_activation = swish()
        elif activation == 'mish':
            f_activation = mish()
        else:
            raise ValueError(
                "activation must be 'relu', 'swish' or 'mish', got {!r}".format(activation)
            )

        hidden_channels = in_channels * t

        # Layer order (and therefore the Sequential indices used as
        # state_dict keys) must stay exactly as is to load saved weights.
        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1),
            nn.BatchNorm2d(hidden_channels),
            f_activation,

            # groups == channels makes this a depthwise convolution
            nn.Conv2d(hidden_channels, hidden_channels, 3, stride=stride, padding=1, groups=hidden_channels),
            nn.BatchNorm2d(hidden_channels),
            f_activation,

            nn.Conv2d(hidden_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels)
        )

        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x):
        '''
        Run the block on ``x`` and return the resulting feature map.
        '''
        residual = self.residual(x)

        # The skip connection is only valid when the shapes match.
        if self.stride == 1 and self.in_channels == self.out_channels:
            residual += x

        return residual
    
class MobileNetV2(nn.Module):
    '''
    MobileNetV2-style classifier for single-channel images, built from
    LinearBottleNeck blocks.

    Parameters:
        class_num:  number of output values per sample.
        activation: one of ``'relu'`` (ReLU6), ``'swish'`` or ``'mish'``;
                    also forwarded to every bottleneck block.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.

    ``forward`` returns a tensor of shape ``(N, class_num)``.
    '''

    def __init__(self, class_num=10, activation = 'relu'):
        super().__init__()

        # Resolve the activation by name; reject unknown names explicitly
        # instead of failing later with an UnboundLocalError.
        if activation == 'relu':
            f_activation = nn.ReLU6(inplace=True)
        elif activation == 'swish':
            f_activation = swish()
        elif activation == 'mish':
            f_activation = mish()
        else:
            raise ValueError(
                "activation must be 'relu', 'swish' or 'mish', got {!r}".format(activation)
            )

        # NOTE(review): a 1x1 kernel with padding=1 is unusual for a stem.
        # Left unchanged so the network stays consistent with weights trained
        # on this definition; confirm before touching.
        self.pre = nn.Sequential(
            nn.Conv2d(1, 32, 1, padding=1),
            nn.BatchNorm2d(32),
            f_activation
        )

        # Attribute names below are state_dict key prefixes; keep them stable.
        self.stage1 = LinearBottleNeck(32, 16, 1, 1, activation = activation)
        self.stage2 = self._make_stage(2, 16, 24, 2, 6, activation = activation)
        self.stage3 = self._make_stage(3, 24, 32, 2, 6, activation = activation)
        self.stage4 = self._make_stage(4, 32, 64, 2, 6, activation = activation)
        self.stage5 = self._make_stage(3, 64, 96, 1, 6, activation = activation)
        self.stage6 = self._make_stage(3, 96, 160, 1, 6, activation = activation)
        self.stage7 = LinearBottleNeck(160, 320, 1, 6, activation = activation)

        self.conv1 = nn.Sequential(
            nn.Conv2d(320, 1280, 1),
            nn.BatchNorm2d(1280),
            f_activation
        )

        # 1x1 conv acting as the classifier on the pooled features
        self.conv2 = nn.Conv2d(1280, class_num, 1)

    def forward(self, x):
        '''
        Run the network on ``x`` and return a ``(N, class_num)`` tensor.
        '''
        x = self.pre(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        x = self.conv1(x)
        # Global average pool to 1x1, classify, then drop the spatial dims.
        x = F.adaptive_avg_pool2d(x, 1)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)

        return x

    def _make_stage(self, repeat, in_channels, out_channels, stride, t, activation = 'relu'):
        '''
        Build a stage of ``repeat`` bottleneck blocks.

        The first block maps ``in_channels`` to ``out_channels`` with the
        given ``stride``; the remaining ``repeat - 1`` blocks keep
        ``out_channels`` and use stride 1.
        '''
        layers = [LinearBottleNeck(in_channels, out_channels, stride, t, activation = activation)]

        # range() terminates for any repeat, unlike the former
        # `while repeat - 1` countdown which looped forever for repeat < 1.
        for _ in range(repeat - 1):
            layers.append(LinearBottleNeck(out_channels, out_channels, 1, t, activation = activation))

        return nn.Sequential(*layers)

def mobilenetv2(activation = 'relu'):
    '''
    Build a MobileNetV2 with 1000 outputs and the given activation name.
    '''
    net = MobileNetV2(class_num=1000, activation=activation)
    return net

In [34]:
# Backbone with the default ('relu') activation and 1000 outputs
backbone_model = mobilenetv2()

In [35]:
class BengaliModel(nn.Module):
    '''
    Three-headed classifier on top of a backbone.

    ``backbone_model`` must map an input batch to 1000 features per sample.
    Three linear heads then produce logits for the grapheme root (168),
    vowel diacritic (11) and consonant diacritic (7).
    '''

    def __init__(self, backbone_model):
        super().__init__()
        self.backbone_model = backbone_model
        # One linear head per target, all fed by the same backbone features.
        self.fc1 = nn.Linear(in_features=1000, out_features=168)  # grapheme_root
        self.fc2 = nn.Linear(in_features=1000, out_features=11)   # vowel_diacritic
        self.fc3 = nn.Linear(in_features=1000, out_features=7)    # consonant_diacritic

    def forward(self, x):
        '''
        Return ``(grapheme_root, vowel_diacritic, consonant_diacritic)`` logits.
        '''
        features = self.backbone_model(x)
        return self.fc1(features), self.fc2(features), self.fc3(features)

In [36]:
# Wrap the backbone with the three classification heads
model = BengaliModel(backbone_model)

In [37]:
# Deserialize the checkpoint onto the CPU; the model is moved to its device later.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
state = torch.load('../input/bengaliaiutils/mobilenet_v2_15.pth', map_location='cpu')
# Kept as the last expression so the key-matching report is shown as cell output.
model.load_state_dict(state["state_dict"])

<All keys matched successfully>

In [38]:
test_df = pd.read_csv(INPUT_PATH + '/test.csv')
submission_df = pd.read_csv(INPUT_PATH + '/sample_submission.csv')

# Use the first GPU when one is available and fall back to the CPU otherwise,
# so the notebook still runs on a machine without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Flat list of predicted class indices, filled by the inference loop.
results = []

In [None]:
# Start from an empty list so that re-running this cell never appends a second
# copy of the predictions to an already filled list.
results = []

# Inference mode for batch norm and friends; set once, not per parquet file.
model.eval()

for i in range(4):
    parq = INPUT_PATH + '/test_image_data_{}.parquet'.format(i)
    test_dataset = BengaliParquetDataset(
        parquet_file=parq,
        transform=None
    )
    data_loader_test = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        num_workers=N_WORKERS,
        shuffle=False
    )

    print('Parquet {}'.format(i))

    tk0 = tqdm(data_loader_test, desc="Iteration")

    # No gradients are needed for prediction; skipping autograd bookkeeping
    # saves memory and time.
    with torch.no_grad():
        for batch in tk0:
            inputs = batch["image"].to(device, dtype=torch.float)

            out_graph, out_vowel, out_conso = model(inputs)

            # Softmax is monotonic, so the argmax of the raw logits is the
            # predicted class; no need to normalise first.
            pred_graph = out_graph.cpu().numpy().argmax(axis=1)
            pred_vowel = out_vowel.cpu().numpy().argmax(axis=1)
            pred_conso = out_conso.cpu().numpy().argmax(axis=1)

            # Three entries per image, in the order consonant, grapheme, vowel.
            # NOTE(review): this must match the row order of the submission
            # file the list is later written into - confirm against sample_submission.csv.
            for conso, graph, vowel in zip(pred_conso, pred_graph, pred_vowel):
                results.extend((conso, graph, vowel))

In [None]:
# Store the flat prediction list as the 'target' column. The assignment is
# positional, so `results` must be in the same row order as submission_df.
submission_df['target'] = results
# Write the submission file without the DataFrame index.
submission_df.to_csv('./submission.csv', index=False)

In [None]:
# Display the submission table as a final visual check
submission_df