In [30]:
import numpy as np
import pandas as pd

from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data

import torchvision.models as models
from torchvision import transforms

In [31]:
# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [32]:
HIDDEN_DIM = 512  # backbone projection size; also the LSTM hidden/cell size
OUTPUT_DIM = 14   # symbol classes: digits 0-9, '.', ',', '-', end-of-sequence marker
MAX_LEN = 10      # fixed label length: symbols + end marker + padding
BATCH_SIZE = 32   # samples per DataLoader batch

In [33]:
# Relative data locations. IMG_PATH is concatenated directly with image file
# names, so it must keep its trailing slash.
IMG_PATH = "cell_images/training_set/"
LABELS = "cell_images/training_set_values.txt"
OUT_PATH = "processed_images/training_set/"

In [34]:
def one_hot_embedding(labels, num_classes):
    """Turn integer class labels into one-hot rows.

    Every label picks the matching row of a ``num_classes`` x ``num_classes``
    identity matrix, so the output gains a trailing axis of size ``num_classes``.

    Args:
      labels: (LongTensor or list of int) class labels, sized [N,].
      num_classes: (int) number of classes.

    Returns:
      (LongTensor) encoded labels, sized [N, num_classes].
    """
    return torch.eye(num_classes, dtype=torch.long)[labels]

In [35]:
def one_hot_digits(digits, max_length=MAX_LEN):
    """Encode a numeric label as a fixed-length sequence of one-hot rows.

    Class indices: '0'-'9' -> 0-9, '.' -> 10, ',' -> 11, '-' -> 12. The end
    marker 13 is appended after the last symbol, and the sequence is then
    right-padded with class 0 up to ``max_length``.

    NOTE(review): the pad class 0 is the same class as the digit '0', so
    positions after the end marker look like real zeros to the loss.

    Args:
      digits: the label to encode; non-string values are converted with str().
      max_length: (int) total sequence length, including the end marker.

    Returns:
      (LongTensor) sized [max_length, 14].

    Raises:
      ValueError: if the label contains a character other than 0-9 . , - ,
        or if it does not fit (with its end marker) into ``max_length``.
    """
    symbol_classes = {'.': 10, ',': 11, '-': 12}
    end_token = 13
    num_classes = 14

    # pandas may hand over a number instead of text; normalise to a string.
    text = str(digits)
    cleaned_digits = []
    for char in text:
        if char in symbol_classes:
            cleaned_digits.append(symbol_classes[char])
        elif char in '0123456789':
            cleaned_digits.append(int(char))
        else:
            raise ValueError(f'Unsupported character {char!r} in label {text!r}')

    cleaned_digits.append(end_token)
    # Without this check a too-long label yields a longer tensor, which only
    # fails later when the DataLoader tries to stack the batch.
    if len(cleaned_digits) > max_length:
        raise ValueError(
            f'Label {text!r} needs {len(cleaned_digits)} positions '
            f'(including the end token) but max_length is {max_length}')
    cleaned_digits += [0] * (max_length - len(cleaned_digits))

    return one_hot_embedding(cleaned_digits, num_classes)

In [36]:
class OCRDataset(torch.utils.data.dataset.Dataset):
    """Dataset of (image, one-hot label sequence) pairs driven by a label table.

    Column 0 of ``df`` holds image file names, which are appended to
    ``img_dir`` (``IMG_PATH`` when ``img_dir`` is None). Column 1 holds the
    label text that is encoded with ``one_hot_digits``.
    """

    def __init__(self, df, transforms=None, img_dir=None):
        """
        Args:
          df: pandas DataFrame with image names in column 0 and labels in column 1.
          transforms: optional callable applied to each loaded PIL image.
          img_dir: directory prefix (with trailing slash) for image names;
            None means the module-level IMG_PATH.
        """
        self.df = df
        self.images = self.df.values[:, 0]
        self.labels = self.df.values[:, 1]
        self.length = len(self.df.index)
        self.transforms = transforms
        self.img_dir = img_dir

    def __getitem__(self, index):
        img_dir = IMG_PATH if self.img_dir is None else self.img_dir
        image_path = self.images[index]

        # The context manager closes the file handle deterministically.
        # convert('RGB') guarantees 3 channels (grayscale, palette and RGBA
        # inputs otherwise yield 1- or 4-channel tensors), which is what the
        # ResNet backbone's first convolution accepts.
        with Image.open(f'{img_dir}{image_path}') as raw_image:
            image = raw_image.convert('RGB')

        if self.transforms:
            image = self.transforms(image)

        label = one_hot_digits(self.labels[index])

        return (image, label)

    def __len__(self):
        return self.length

In [37]:
class CRNN(nn.Module):
    """CNN encoder plus LSTM decoder that reads a digit sequence from an image.

    ``backbone`` is expected to map a batch of images to [B, HIDDEN_DIM]
    features (as the ResNet-50 with a replaced ``fc`` does). Those features
    seed the LSTM hidden state. The decoder emits OUTPUT_DIM logits per
    sequence position, and ``linear2`` produces a separate [B, MAX_LEN]
    "length" head.
    """

    def __init__(self, backbone):
        super(CRNN, self).__init__()
        self.backbone = backbone
        self.linear2 = nn.Linear(HIDDEN_DIM, MAX_LEN)
        self.lstm = nn.LSTM(OUTPUT_DIM, HIDDEN_DIM, batch_first=True)
        self.out = nn.Linear(HIDDEN_DIM, OUTPUT_DIM)

    def _initial_state(self, latent):
        """Return (zero start token [B, 1, OUTPUT_DIM], (h0, c0)) for the LSTM.

        The tensors are built from ``latent`` so they share its batch size,
        device and dtype. This avoids hard-coding BATCH_SIZE (a smaller final
        batch would break) and avoids allocating on the CPU while the model
        is on the GPU.
        """
        batch_size = latent.size(0)
        start = latent.new_zeros(batch_size, 1, OUTPUT_DIM)
        hidden = (latent.unsqueeze(0), latent.new_zeros(1, batch_size, HIDDEN_DIM))
        return start, hidden

    def forward(self, x, target):
        """Teacher-forced decoding.

        Args:
          x: image batch accepted by the backbone.
          target: one-hot targets, presumably [B, T, OUTPUT_DIM] as produced
            by ``one_hot_digits`` and default batching.

        Returns:
          (length logits [B, MAX_LEN], digit logits [B, T, OUTPUT_DIM])
        """
        latent = self.backbone(x)
        length = self.linear2(latent)
        start, hidden = self._initial_state(latent)

        # Teacher forcing: step 0 sees the zero start token, step i sees
        # target[i-1]. Feeding the whole shifted sequence at once is the same
        # recurrence as stepping one token at a time.
        previous = target[:, :-1, :].to(start.dtype)
        inputs = torch.cat([start, previous], dim=1)
        output, _ = self.lstm(inputs, hidden)

        return length, self.out(output)

    def evaluate(self, x):
        """Greedy decoding for inference (no targets).

        Each step feeds back the one-hot argmax of the previous step's
        logits. Runs for MAX_LEN steps. The caller is responsible for
        ``model.eval()`` and ``torch.no_grad()``.

        Returns:
          (length logits [B, MAX_LEN], digit logits [B, MAX_LEN, OUTPUT_DIM])
        """
        latent = self.backbone(x)
        length = self.linear2(latent)
        inputs, hidden = self._initial_state(latent)

        steps = []
        for _ in range(MAX_LEN):
            output, hidden = self.lstm(inputs, hidden)
            logits = self.out(output[:, -1, :])
            steps.append(logits)
            predicted = logits.argmax(dim=-1)
            inputs = F.one_hot(predicted, OUTPUT_DIM).to(logits.dtype).unsqueeze(1)

        return length, torch.stack(steps, dim=1)

In [38]:
# ImageNet-pretrained ResNet-50 as the image encoder. Its classification head
# is swapped for a projection to HIDDEN_DIM, whose output CRNN uses as the
# LSTM's initial hidden state.
# NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
resnet50 = models.resnet50(pretrained=True)
# Read the head's input width from the model instead of hard-coding 2048.
resnet50.fc = nn.Linear(resnet50.fc.in_features, HIDDEN_DIM)

model = CRNN(resnet50).to(device)

In [39]:
# Cross-entropy over symbol classes, applied once per sequence position.
criterion = nn.CrossEntropyLoss()
# SGD with momentum over all model parameters (backbone and decoder alike).
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)

In [21]:
# Training schedule; the loader batch size mirrors the BATCH_SIZE constant.
numEpochs, batchSize = 50, BATCH_SIZE

In [22]:
# Label table: column 0 = image file name, column 1 = label text.
# dtype=str keeps labels as text, so pandas cannot coerce them to numbers
# (e.g. '1.50' -> 1.5 or '007' -> 7), which would drop characters that
# one_hot_digits encodes.
df = pd.read_csv(LABELS, sep=';', dtype=str)
mytransforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Channel statistics the torchvision ImageNet-pretrained ResNet weights expect.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = OCRDataset(df, transforms=mytransforms)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchSize, shuffle=True)

In [29]:
for epoch in range(numEpochs):
    model.train()
    running_loss = 0.0
    n_correct = 0  # samples whose whole predicted sequence matches the target
    n_seen = 0     # samples processed this epoch (last batch may be short)

    for features, target in tqdm(dataloader):
        features, target = features.to(device), target.to(device)

        optimizer.zero_grad()

        # `length` (the sequence-length head) is not part of this loss.
        length, number = model(features, target)
        labels = target.argmax(dim=-1)  # one-hot -> class index per position

        # Sum of the per-position cross-entropies.
        loss = sum(criterion(number[:, i, :], labels[:, i]) for i in range(number.size(1)))

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Sequence-level accuracy: a sample counts only if every position is
        # right. Count against the actual batch size, not BATCH_SIZE.
        with torch.no_grad():
            y_hat = number.argmax(dim=-1)
            n_correct += (y_hat == labels).all(dim=1).sum().item()
            n_seen += labels.size(0)

    accuracy = n_correct / n_seen if n_seen else 0.0
    print('[Epoch {}] Loss: {:.5f} Accuracy: {:.5f}'.format(epoch, running_loss/len(dataloader), accuracy))





  0%|          | 0/235 [00:00<?, ?it/s][A[A[A[A

torch.Size([32, 10, 14])
torch.Size([32, 10])






  0%|          | 1/235 [00:19<1:17:02, 19.75s/it][A[A[A[A

[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]
torch.Size([32, 10, 14])
torch.Size([32, 10])






  1%|          | 2/235 [00:37<1:13:59, 19.05s/it][A[A[A[A

[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]
torch.Size([32, 10, 14])
torch.Size([32, 10])






  1%|▏         | 3/235 [00:54<1:11:59, 18.62s/it][A[A[A[A

[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]
torch.Size([32, 10, 14])
torch.Size([32, 10])


KeyboardInterrupt: 