##### Guideline: https://www.digitalocean.com/community/tutorials/how-to-build-a-neural-network-to-translate-sign-language-into-english

## Load the Sign Language MNIST dataset

In [None]:
# Download the Sign Language MNIST archive (gzip-compressed tarball) into the
# current working directory.
!wget https://assets.digitalocean.com/articles/signlanguage_data/sign-language-mnist.tar.gz

In [None]:
# Unpack the downloaded archive in place.
# NOTE(review): presumably this creates the `data/` folder with the
# `sign_mnist_*.csv` files read further down -- confirm after extraction.
!tar -xzf sign-language-mnist.tar.gz

## Imports

In [None]:
from torch.utils.data import Dataset
from torch.autograd import Variable
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import torch

from typing import List

import csv

import onnx
import onnxruntime as ort

In [None]:
class SignLanguageMNIST(Dataset):
    """Sign Language classification dataset.

    Utility for loading Sign Language dataset into PyTorch. Dataset posted on
    Kaggle in 2017, by an unnamed author with username `tecperson`:
    https://www.kaggle.com/datamunge/sign-language-mnist

    Every item is a dict with an `image` tensor built from one 28 x 28
    single-channel sample and a `label` tensor holding one class index.
    """

    @staticmethod
    def get_label_mapping():
        """Return the letter index that each class label stands for.

        We map all labels to [0, 23]. Position `i` of the returned list is
        the letter index of class `i`; letter indices run over [0, 24] with
        9 left out.
        """
        return [letter for letter in range(25) if letter != 9]

    @staticmethod
    def read_label_samples_from_csv(path: str):
        """Read class labels and pixel rows from a CSV file.

        Assumes first column in CSV is the label and subsequent 28^2 values
        are image pixel values 0-255. The first row is treated as a header
        and skipped; completely blank rows are ignored.

        Args:
            path: Path to the `.csv` file.

        Returns:
            A `(labels, samples)` pair of equally long lists: `labels` holds
            class indices in [0, 23], `samples` holds one list of ints per
            row.

        Raises:
            ValueError: If a row carries a letter index that has no class
                (or a value that is not an integer).
        """
        # Invert the mapping once so each row is an O(1) lookup.
        class_of_letter = {
            letter: cls
            for cls, letter in enumerate(SignLanguageMNIST.get_label_mapping())
        }
        labels, samples = [], []
        # newline="" is what the csv module asks for when reading files.
        with open(path, newline="") as f:
            reader = csv.reader(f)
            next(reader, None)  # skip header; an empty file yields no rows
            for row in reader:
                if not row:  # blank line, e.g. trailing newline at EOF
                    continue
                letter = int(row[0])
                if letter not in class_of_letter:
                    raise ValueError(
                        f"unsupported label {letter} in {path!r}")
                labels.append(class_of_letter[letter])
                samples.append([int(value) for value in row[1:]])
        return labels, samples

    def __init__(self,
            path: str="data/sign_mnist_train.csv",
            mean: List[float]=[0.485],
            std: List[float]=[0.229]):
        """
        Args:
            path: Path to `.csv` file containing `label`, `pixel0`, `pixel1`...
            mean: Per-channel mean handed to `transforms.Normalize`.
            std: Per-channel standard deviation handed to
                `transforms.Normalize`.
        """
        labels, samples = SignLanguageMNIST.read_label_samples_from_csv(path)
        self._samples = np.array(samples, dtype=np.uint8).reshape((-1, 28, 28, 1))
        self._labels = np.array(labels, dtype=np.uint8).reshape((-1, 1))

        # Copy so instances never alias the shared default lists (or a
        # caller-owned list that may be mutated later).
        self._mean = list(mean)
        self._std = list(std)

    def __len__(self):
        """Return the number of samples loaded from the CSV."""
        return len(self._labels)

    def __getitem__(self, idx):
        """Return the sample at `idx` as `{'image': ..., 'label': ...}`.

        NOTE: the random resized crop is applied on every access, whichever
        CSV the dataset was built from, so repeated reads of the same index
        can differ.
        """
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomResizedCrop(28, scale=(0.8, 1.2)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._mean, std=self._std)])

        return {
            'image': transform(self._samples[idx]).float(),
            'label': torch.from_numpy(self._labels[idx]).float()
        }
    
class SignLanguageCollected(Dataset):
    """Sign language dataset loaded from a self-collected CSV file.

    Uses the same CSV layout as `SignLanguageMNIST` (label column followed
    by 28^2 pixel values) but accepts all 26 label values 0-25.

    Every item is a dict with an `image` tensor built from one 28 x 28
    single-channel sample and a `label` tensor holding one class index.
    """

    @staticmethod
    def get_label_mapping():
        """Return the identity mapping over the 26 labels [0, 25]."""
        return list(range(26))

    @staticmethod
    def read_label_samples_from_csv(path: str):
        """Read class labels and pixel rows from a CSV file.

        Assumes first column in CSV is the label and subsequent 28^2 values
        are image pixel values 0-255. The first row is treated as a header
        and skipped; completely blank rows are ignored.

        Args:
            path: Path to the `.csv` file.

        Returns:
            A `(labels, samples)` pair of equally long lists: `labels` holds
            class indices in [0, 25], `samples` holds one list of ints per
            row.

        Raises:
            ValueError: If a row carries a label outside the mapping (or a
                value that is not an integer).
        """
        # Invert the mapping once so each row is an O(1) lookup.
        class_of_label = {
            label: cls
            for cls, label in enumerate(SignLanguageCollected.get_label_mapping())
        }
        labels, samples = [], []
        # newline="" is what the csv module asks for when reading files.
        with open(path, newline="") as f:
            reader = csv.reader(f)
            next(reader, None)  # skip header; an empty file yields no rows
            for row in reader:
                if not row:  # blank line, e.g. trailing newline at EOF
                    continue
                label = int(row[0])
                if label not in class_of_label:
                    raise ValueError(
                        f"unsupported label {label} in {path!r}")
                labels.append(class_of_label[label])
                samples.append([int(value) for value in row[1:]])
        return labels, samples

    def __init__(self,
            path: str="../Data/collected_train.csv",
            mean: List[float]=[0.485],
            std: List[float]=[0.229]):
        """
        Args:
            path: Path to `.csv` file containing `label`, `pixel0`, `pixel1`...
            mean: Per-channel mean handed to `transforms.Normalize`.
            std: Per-channel standard deviation handed to
                `transforms.Normalize`.
        """
        labels, samples = SignLanguageCollected.read_label_samples_from_csv(path)
        self._samples = np.array(samples, dtype=np.uint8).reshape((-1, 28, 28, 1))
        self._labels = np.array(labels, dtype=np.uint8).reshape((-1, 1))

        # Copy so instances never alias the shared default lists (or a
        # caller-owned list that may be mutated later).
        self._mean = list(mean)
        self._std = list(std)

    def __len__(self):
        """Return the number of samples loaded from the CSV."""
        return len(self._labels)

    def __getitem__(self, idx):
        """Return the sample at `idx` as `{'image': ..., 'label': ...}`.

        NOTE: the random resized crop is applied on every access, whichever
        CSV the dataset was built from, so repeated reads of the same index
        can differ.
        """
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomResizedCrop(28, scale=(0.8, 1.2)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._mean, std=self._std)])

        return {
            'image': transform(self._samples[idx]).float(),
            'label': torch.from_numpy(self._labels[idx]).float()
        }

In [None]:
def get_train_test_loaders(use_mnist_dataset,batch_size=32):
    """Build the train and test `DataLoader`s for the chosen dataset.

    Args:
        use_mnist_dataset: Truthy selects the Kaggle CSVs under `data/`,
            falsy selects the collected CSVs under `../Data/`.
        batch_size: Batch size used by both loaders.

    Returns:
        `(trainloader, testloader)`; only the train loader shuffles.
    """
    if use_mnist_dataset:
        print("Using kaggle get_train_test_loader")
        dataset_cls = SignLanguageMNIST
        train_csv = 'data/sign_mnist_train.csv'
        test_csv = 'data/sign_mnist_test.csv'
    else:
        print("Using collected get_train_test_loader")
        dataset_cls = SignLanguageCollected
        train_csv = '../Data/collected_train.csv'
        test_csv = '../Data/collected_test.csv'

    make_loader = torch.utils.data.DataLoader
    trainloader = make_loader(
        dataset_cls(train_csv), batch_size=batch_size, shuffle=True)
    testloader = make_loader(
        dataset_cls(test_csv), batch_size=batch_size, shuffle=False)
    return trainloader, testloader


if __name__ == '__main__':
    # Smoke test: pull a single batch of two Kaggle samples and show it.
    train_loader, _unused_test_loader = get_train_test_loaders(True, 2)
    first_batch = next(iter(train_loader))
    print(first_batch)

## Training

In [None]:
class Net(nn.Module):
    """Small CNN: three 3x3 convolutions followed by three linear layers.

    Takes single-channel input; the flatten step expects 16 feature maps of
    5 x 5, which is what a 28 x 28 image produces. The final layer emits 24
    scores when `use_mnist_dataset` is truthy and 26 otherwise.
    """

    def __init__(self, use_mnist_dataset):
        super().__init__()
        # Feature extractor (layers are registered in the order they run,
        # except that `pool` is shared by conv2 and conv3).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=6, kernel_size=3)
        self.conv3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3)
        # Classifier head.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 48)
        num_classes = 24 if use_mnist_dataset else 26
        self.fc3 = nn.Linear(48, num_classes)

    def forward(self, x):
        """Return the unnormalised class scores for a batch `x`."""
        features = F.relu(self.conv1(x))
        for conv in (self.conv2, self.conv3):
            features = self.pool(F.relu(conv(features)))
        flat = features.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


def main(use_mnist_dataset,checkpoint_name):
    """Train a fresh `Net` and save its weights.

    Trains for 12 epochs with SGD (lr 0.01, momentum 0.9); the learning rate
    is multiplied by 0.1 after every 10 scheduler steps. The resulting
    `state_dict` is written to `checkpoint_name`.

    Args:
        use_mnist_dataset: Forwarded to `Net` and `get_train_test_loaders`
            to pick the dataset and the matching output size.
        checkpoint_name: Destination path for `torch.save`.
    """
    num_epochs = 12

    model = Net(use_mnist_dataset).float()
    loss_fn = nn.CrossEntropyLoss()
    sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    lr_schedule = optim.lr_scheduler.StepLR(sgd, step_size=10, gamma=0.1)

    # Only the training split is needed here.
    train_batches, _ = get_train_test_loaders(use_mnist_dataset=use_mnist_dataset)

    for epoch_idx in range(num_epochs):
        train(model, loss_fn, sgd, train_batches, epoch_idx)
        lr_schedule.step()

    torch.save(model.state_dict(), checkpoint_name)


def train(net, criterion, optimizer, trainloader, epoch):
    """Run one optimisation pass over `trainloader`.

    Args:
        net: Model called on each image batch.
        criterion: Loss called as `criterion(outputs, labels)`.
        optimizer: Optimiser stepped once per batch.
        trainloader: Iterable of dicts with an `image` tensor and a `label`
            tensor whose first column holds the class index.
        epoch: Epoch number, used only in the progress message.

    Prints the running mean loss on the first batch and every 100th batch
    after it. Returns nothing.
    """
    running_loss = 0.0
    for i, data in enumerate(trainloader):
        # Plain tensors carry autograd state themselves; the deprecated
        # `Variable` wrapper is not needed.
        inputs = data['image'].float()
        # Labels arrive as an (N, 1) column; the loss wants a 1-D index vector.
        targets = data['label'].long()[:, 0]

        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 0:
            print('[%d, %5d] loss: %.6f' % (epoch, i, running_loss / (i + 1)))




## Evaluation

In [None]:


def evaluate(outputs: "np.ndarray | torch.Tensor",
             labels: "np.ndarray | torch.Tensor") -> float:
    """Evaluate neural network outputs against non-one-hotted labels.

    Args:
        outputs: Per-class scores, one row per sample. May be a NumPy array
            or a tensor (including one that tracks gradients).
        labels: One class index per sample, as a tensor or array. A column
            of shape (N, 1) is accepted and flattened.

    Returns:
        The number of samples whose highest-scoring class equals its label,
        as a float.
    """
    if isinstance(outputs, torch.Tensor):
        # Detach first: NumPy conversion fails on grad-tracking tensors.
        outputs = outputs.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    # Flatten so an (N, 1) label column cannot broadcast against the (N,)
    # predictions into an N x N comparison.
    expected = np.asarray(labels).reshape(-1)
    predicted = np.argmax(outputs, axis=1)
    return float(np.sum(predicted == expected))


def batch_evaluate(
        net: "Callable[[torch.Tensor], object]",
        dataloader: torch.utils.data.DataLoader) -> float:
    """Evaluate neural network in batches, if dataset is too large.

    Args:
        net: Any callable mapping an image batch to per-class scores. It may
            return a tensor (converted to NumPy here) or a NumPy array.
        dataloader: Yields dicts with an `image` batch and a `label` tensor
            whose first column holds the class index.

    Returns:
        Fraction of samples classified correctly, in [0, 1].

    Raises:
        ZeroDivisionError: If `dataloader` yields no batches.
    """
    score = n = 0.0
    # Pure inference: skip building the autograd graph for every batch.
    with torch.no_grad():
        for batch in dataloader:
            n += len(batch['image'])
            outputs = net(batch['image'])
            if isinstance(outputs, torch.Tensor):
                outputs = outputs.detach().numpy()
            score += evaluate(outputs, batch['label'][:, 0])
    return score / n


def validate(checkpoint_name,filename,use_mnist_dataset):
    """Report accuracy of a saved checkpoint, before and after ONNX export.

    Loads the weights from `checkpoint_name` into a `Net` in eval mode,
    prints train/test accuracy, exports the model to `filename` as ONNX,
    checks the exported file, and prints the same accuracies again using an
    ONNX Runtime session.

    Args:
        checkpoint_name: Path of the `state_dict` saved by training.
        filename: Destination path for the exported ONNX model.
        use_mnist_dataset: Forwarded to `Net` and `get_train_test_loaders`.
    """
    def report(title, model, train_batches, test_batches):
        # Print one banner plus the two accuracy lines for `model`.
        print('=' * 10, title, '=' * 10)
        train_acc = batch_evaluate(model, train_batches) * 100.
        print('Training accuracy: %.1f' % train_acc)
        test_acc = batch_evaluate(model, test_batches) * 100.
        print('Validation accuracy: %.1f' % test_acc)

    trainloader, testloader = get_train_test_loaders(use_mnist_dataset=use_mnist_dataset)

    net = Net(use_mnist_dataset).float().eval()
    net.load_state_dict(torch.load(checkpoint_name))
    report('PyTorch', net, trainloader, testloader)

    # The export below traces a single-sample input, so the ONNX pass is
    # fed one sample at a time.
    trainloader, testloader = get_train_test_loaders(
        use_mnist_dataset=use_mnist_dataset, batch_size=1)

    # export to onnx
    example_input = torch.randn(1, 1, 28, 28)
    torch.onnx.export(net, example_input, filename, input_names=['input'])

    # check model is well-formed
    onnx.checker.check_model(onnx.load(filename))

    # create runnable session with exported model
    ort_session = ort.InferenceSession(filename)

    def onnx_net(inp):
        # First (only requested) output of the session for this batch.
        return ort_session.run(None, {'input': inp.data.numpy()})[0]

    report('ONNX', onnx_net, trainloader, testloader)

In [None]:
if __name__ == '__main__':
    USE_MNIST_DATASET = True

    # Pick the artifact names for the selected dataset, then train and
    # validate with them.
    if USE_MNIST_DATASET:
        checkpoint_path = "checkpoint_kaggle.pth"
        onnx_path = "signlanguage_kaggle.onnx"
    else:
        checkpoint_path = "checkpoint_collected.pth"
        onnx_path = "signlanguage_collected.onnx"

    main(use_mnist_dataset=USE_MNIST_DATASET, checkpoint_name=checkpoint_path)
    validate(checkpoint_name=checkpoint_path,
             filename=onnx_path,
             use_mnist_dataset=USE_MNIST_DATASET)