In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torch import nn
import torchvision

from tqdm import tqdm
from torch.utils.data import Dataset
from albumentations import Compose
from albumentations.pytorch import ToTensorV2
import albumentations as albu

import PIL

import cv2 as cv

# Root directory of the competition data. test.csv, sample_submission.csv and
# the test_image_data_{i}.parquet files are all read relative to this path.
INPUT_PATH = '/kaggle/input/bengaliai-cv19'

In [2]:
# ======================
# Params
# Batch size and number of worker processes handed to the test DataLoader.
BATCH_SIZE = 64
N_WORKERS = 4

# Weights dataset for this competition; feel free to upvote the dataset ;)
# https://www.kaggle.com/pestipeti/bengali-ai-model-weights
# NOTE(review): WEIGHTS_FILE is not referenced by any cell shown in this
# notebook -- the checkpoint actually loaded further down comes from a
# different, hard-coded path. Confirm which file is intended.
WEIGHTS_FILE = '/kaggle/input/bengali-ai-model-weights/baseline_weights.pth'

In [3]:
# Image height and width: every row of pixel columns is reshaped to
# (HEIGHT, WIDTH) before any further processing.
HEIGHT = 137
WIDTH = 236

# Side length of the square image produced by the Resize step of the
# transforms. BengaliModel.fc1 expects 16384 = 256 * 8 * 8 inputs, i.e. a
# SIZE x SIZE image halved twice by the two MaxPool2d(2) layers, so SIZE cannot
# be changed without changing the model as well.
SIZE = 32

def threshold_image(img):
    '''
    Binarise a grayscale image with Otsu's automatically chosen threshold.

    Parameters
    ----------
    img : array-like
        Pixel values. They are cast to uint8 before thresholding.
        Presumably a 2-D (height, width) array -- confirm against callers.

    Returns
    -------
    numpy.ndarray
        The thresholded image produced by ``cv.threshold`` (second element of
        its return value).
    '''
    # cv.threshold with THRESH_OTSU needs an 8-bit single-channel array. Cast
    # with NumPy directly: the former detour through PIL.Image produced the
    # very same uint8 array, cost an extra copy, and only worked because some
    # other import happened to load the PIL.Image submodule.
    gray = np.ascontiguousarray(np.uint8(img))
    # The threshold value passed in (0) is ignored when THRESH_OTSU is set;
    # 255 is the value assigned to pixels above the computed threshold.
    _, th = cv.threshold(gray, 0, 255, cv.THRESH_BINARY + cv.THRESH_OTSU)
    return th

# Training-time pipeline. It is fully deterministic: take the central 128x128
# window of the image, then shrink it to SIZE x SIZE. A random rotation step
# (`albu.Rotate(limit=5, p=p)`) is left disabled.
train_transforms = albu.Compose(
    [
        albu.CenterCrop(128, 128),
        albu.Resize(SIZE, SIZE),
    ],
    p=1.0,
)

# Validation / inference pipeline: central 128x128 crop followed by a resize
# to SIZE x SIZE. No random augmentation is involved.
valid_transforms = albu.Compose(
    [
        albu.CenterCrop(128, 128),
        albu.Resize(SIZE, SIZE),
    ],
    p=1.0,
)

def get_image(idx, df, labels):
    '''
    Helper function to get the image and label from the training set

    Parameters
    ----------
    idx : int
        Positional index of the wanted row in ``df``.
    df : pandas.DataFrame
        Image table with an ``image_id`` column first, followed by
        HEIGHT * WIDTH pixel columns.
    labels : pandas.DataFrame
        Label table with the columns ``image_id``, ``grapheme_root``,
        ``vowel_diacritic`` and ``consonant_diacritic``.

    Returns
    -------
    tuple
        ``(img, targets)`` where ``img`` is a float array of shape
        (HEIGHT, WIDTH) and ``targets`` is the tuple
        ``(grapheme_root, vowel_diacritic, consonant_diacritic)``.

    Raises
    ------
    IndexError
        If ``labels`` holds no row for the image id.
    '''
    # The row is already addressed by position, so read it directly instead of
    # scanning the whole frame for a matching image_id on every call.
    record = df.iloc[idx]
    image_id = record.image_id
    # Column 0 is the image id; everything after it is pixel data.
    img = record.values[1:].reshape(HEIGHT, WIDTH).astype(float)

    # Labels live in a separate table and are matched by image id.
    row = labels[labels.image_id == image_id]

    # Built under its own name so the `labels` argument is not overwritten.
    targets = (
        row['grapheme_root'].values[0],
        row['vowel_diacritic'].values[0],
        row['consonant_diacritic'].values[0],
    )

    return img, targets

def get_validation(idx, df):
    '''
    Helper function to get the validation image and image_id from the test set

    Parameters
    ----------
    idx : int
        Positional index of the wanted row in ``df``.
    df : pandas.DataFrame
        Image table with an ``image_id`` column first, followed by
        HEIGHT * WIDTH pixel columns.

    Returns
    -------
    tuple
        ``(img, image_id)`` where ``img`` is a float array of shape
        (HEIGHT, WIDTH).
    '''
    # The row is already addressed by position, so read it directly instead of
    # scanning the whole frame for a matching image_id on every call.
    record = df.iloc[idx]
    image_id = record.image_id
    # Column 0 is the image id; everything after it is pixel data.
    img = record.values[1:].reshape(HEIGHT, WIDTH).astype(float)
    return img, image_id

In [4]:
class BengaliParquetDataset(Dataset):
    '''
    Dataset over one parquet file of images.

    Column 0 of the file holds the image id and the remaining columns hold the
    HEIGHT * WIDTH pixel values of one image per row.

    Parameters
    ----------
    parquet_file : str
        Path of the parquet file to read.
    transform : callable, optional
        Called as ``transform(image=img)`` and expected to return a mapping
        with an ``'image'`` entry that ``TF.to_tensor`` accepts. When omitted,
        the module-level ``valid_transforms`` pipeline is used.
    '''

    def __init__(self, parquet_file, transform=None):

        self.data = pd.read_parquet(parquet_file)
        self.transform = transform

    def __len__(self):
        '''Number of images (rows) in the parquet file.'''
        return len(self.data)

    def __getitem__(self, idx):
        '''
        Return ``{'image_id': ..., 'image': ...}`` for the row at position
        ``idx``; the image is thresholded, transformed and converted to a
        tensor.
        '''
        image_id = self.data.iloc[idx, 0]
        img = self.data.iloc[idx, 1:].values.reshape(HEIGHT, WIDTH)

        img = threshold_image(img)

        # Honour the transform given to the constructor. It used to be stored
        # and then ignored; passing None keeps the previous behaviour.
        transform = self.transform if self.transform is not None else valid_transforms
        aug = transform(image=img)
        img = TF.to_tensor(aug['image'])

        return {
            'image_id': image_id,
            'image': img
        }

In [5]:
class BengaliModel(nn.Module):
    '''
    Small CNN with three classification heads.

    Five convolutions (with two 2x max-poolings in between) feed two fully
    connected layers, whose 512-dimensional output is shared by three linear
    heads: grapheme root (168 outputs), vowel diacritic (11) and consonant
    diacritic (7).

    The first fully connected layer takes 16384 = 256 * 8 * 8 features, which
    is what a single-channel 32x32 input yields after the two poolings.
    '''

    def __init__(self):
        super().__init__()

        # Convolutional trunk; padding keeps the spatial size, pooling halves it.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=5, padding=2)

        # Shared fully connected layers.
        self.fc1 = nn.Linear(16384, 1024)
        self.fc2 = nn.Linear(1024, 512)

        # One linear head per target.
        self.fc3 = nn.Linear(512, 168)  # grapheme_root
        self.fc4 = nn.Linear(512, 11)   # vowel_diacritic
        self.fc5 = nn.Linear(512, 7)    # consonant_diacritic

    def forward(self, x):
        '''
        Return the raw scores ``(grapheme_root, vowel_diacritic,
        consonant_diacritic)`` for a batch ``x``.
        '''
        feats = x
        for conv in (self.conv1, self.conv2):
            feats = F.relu(conv(feats))
        feats = self.pool1(feats)

        for conv in (self.conv3, self.conv4):
            feats = F.relu(conv(feats))
        feats = self.pool2(feats)

        feats = F.relu(self.conv5(feats))

        # Collapse everything but the batch dimension.
        hidden = feats.flatten(1)
        for dense in (self.fc1, self.fc2):
            hidden = F.relu(dense(hidden))

        # The three heads all read the same hidden representation.
        return self.fc3(hidden), self.fc4(hidden), self.fc5(hidden)

In [6]:
# Build the network with freshly initialised parameters; the trained weights
# are loaded into it in the following cell.
model = BengaliModel()

In [7]:
# Deserialise the checkpoint onto the CPU whatever device it was saved from;
# the model is moved to its target device in a later cell.
state = torch.load('../input/bengaliaiutils/cnn-1_4.pth', map_location='cpu')
# The parameters are stored under the checkpoint's "state_dict" entry.
model.load_state_dict(state["state_dict"])

<All keys matched successfully>

In [8]:
# Competition metadata: the list of test rows and the submission template
# whose row order the predictions have to follow.
test_df = pd.read_csv(INPUT_PATH + '/test.csv')
submission_df = pd.read_csv(INPUT_PATH + '/sample_submission.csv')

# Use the first GPU when there is one, otherwise fall back to the CPU so the
# notebook still runs on a session without an accelerator.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Flat list of predicted class indices, filled by the inference loop.
results = []

In [9]:
# Start from an empty list so that running this cell a second time does not
# append another copy of every prediction.
results = []

# Put the model in evaluation mode once, before any batch is processed.
model.eval()

for i in range(4):
    parq = INPUT_PATH + '/test_image_data_{}.parquet'.format(i)
    test_dataset = BengaliParquetDataset(
        parquet_file=parq,
        transform=None
    )
    # shuffle=False keeps the predictions in file order.
    data_loader_test = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        num_workers=N_WORKERS,
        shuffle=False
    )

    print('Parquet {}'.format(i))

    tk0 = tqdm(data_loader_test, desc="Iteration")

    # No gradients are needed for prediction; skipping the autograd
    # bookkeeping saves memory and time.
    with torch.no_grad():
        for batch in tk0:
            inputs = batch["image"]
            image_ids = batch["image_id"]
            inputs = inputs.to(device, dtype=torch.float)

            out_graph, out_vowel, out_conso = model(inputs)
            # Softmax does not change which score is largest, so the class
            # index is taken from the raw scores directly.
            out_graph = out_graph.argmax(dim=1).cpu().numpy()
            out_vowel = out_vowel.argmax(dim=1).cpu().numpy()
            out_conso = out_conso.argmax(dim=1).cpu().numpy()

            # Three values per image, in the order used by the rows of
            # sample_submission.csv: consonant_diacritic, grapheme_root,
            # vowel_diacritic.
            for idx in range(len(image_ids)):
                results.append(out_conso[idx])
                results.append(out_graph[idx])
                results.append(out_vowel[idx])

Iteration:   0%|          | 0/1 [00:00<?, ?it/s]

Parquet 0


Iteration: 100%|██████████| 1/1 [00:01<00:00,  1.18s/it]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]

Parquet 1


Iteration: 100%|██████████| 1/1 [00:00<00:00,  3.05it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]

Parquet 2


Iteration: 100%|██████████| 1/1 [00:00<00:00,  3.02it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]

Parquet 3


Iteration: 100%|██████████| 1/1 [00:00<00:00,  2.89it/s]


In [10]:
# Fill the template's target column with the predictions (one value per row,
# in template order) and write the submission file.
submission_df = submission_df.assign(target=results)
submission_df.to_csv('./submission.csv', index=False)

In [11]:
# Display the submission table as a final visual check of row ids and targets.
submission_df

Unnamed: 0,row_id,target
0,Test_0_consonant_diacritic,0
1,Test_0_grapheme_root,3
2,Test_0_vowel_diacritic,0
3,Test_1_consonant_diacritic,0
4,Test_1_grapheme_root,118
5,Test_1_vowel_diacritic,2
6,Test_2_consonant_diacritic,0
7,Test_2_grapheme_root,19
8,Test_2_vowel_diacritic,0
9,Test_3_consonant_diacritic,0
