In [1]:
import torch
print('Version', torch.__version__)
print('CUDA enabled:', torch.cuda.is_available())

Version 1.10.0+cu102
CUDA enabled: False


In [2]:
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

import sys
import os
import pickle
import re
import csv

import pt_util

In [3]:
DATA_DIR = 'dakshina_dataset_v1.0'
LANG = 'ta'
LANG_DIR = 'lexicons'

In [4]:
def load_file(dirpath, filename):
    # Load a single file
    path = os.path.join(dirpath, filename)
    data = []
    with open(path, "r") as f:
        tsv = csv.reader(f, delimiter="\t")
        for l in tsv:
            data.append(l)
    return data

def load_directory(dirpath, filename=None):
    # Create a list of all data from all files in directory
    if filename != None:
        return load_file(dirpath, filename)

    all_files = os.listdir(dirpath)
    all_data = []
    for f in all_files:
        all_data.append(load_file(dirpath, f))
    
    return all_data

In [5]:
d = load_directory("{}/{}/{}".format(DATA_DIR, LANG, LANG_DIR))
d

[[['ஃபார்ம்', 'faarm', '1'],
  ['ஃபார்ம்', 'farm', '2'],
  ['ஃபார்ம்', 'form', '1'],
  ['ஃபார்ம்', 'hpaarm', '1'],
  ['ஃபேஸ்', 'face', '3'],
  ['ஃபேஸ்', 'hpaes', '1'],
  ['ஃபேஸ்', 'pace', '2'],
  ['ஃபேஸ்', 'paes', '1'],
  ['ஃபேஸ்', 'phase', '1'],
  ['அஇஅதிமுக', 'aeathimuka', '1'],
  ['அஇஅதிமுக', 'aiathimuka', '1'],
  ['அஇஅதிமுக', 'ayiathimuka', '1'],
  ['அகத்தி', 'agaththi', '3'],
  ['அகத்தி', 'akaththi', '2'],
  ['அகத்திக்கீரை', 'agaththikkeerai', '3'],
  ['அகத்திக்கீரை', 'akaththikkeerai', '2'],
  ['அகமதாபாத்', 'agamadhabaath', '1'],
  ['அகமதாபாத்', 'agamathaabaath', '1'],
  ['அகமதாபாத்', 'ahamadabad', '1'],
  ['அகமதாபாத்', 'ahemadaabad', '1'],
  ['அகமதாபாத்', 'ahmadabad', '1'],
  ['அகமதாபாத்', 'ahmedabad', '3'],
  ['அகமதாபாத்', 'akamatapat', '1'],
  ['அகமதாபாத்தில்', 'agamadhabaatthil', '1'],
  ['அகமதாபாத்தில்', 'agamathaabaaththil', '1'],
  ['அகமதாபாத்தில்', 'ahmadabadil', '1'],
  ['அகமதாபாத்தில்', 'ahmedabadil', '1'],
  ['அகமதாபாத்தில்', 'ahmedabadthil', '1'],
  ['அகழாய்வில்', 'ag

In [None]:
# Any preprocessing we may need.

In [6]:
class TransliterateNet(nn.Module):
    def __init__(self, in_alph_size, out_alph_size, feature_size):
        super(TransliterateNet, self).__init__()
        # Encoder and Decoder RNN
        self.encoder = nn.Embedding(in_alph_size, self.feature_size)
        self.rnn = nn.RNN(self.feature_size, self.feature_size, 2)
        # Decoder embedding
        self.dec = nn.Linear(self.feature_size, out_alph_size)
        
    def forward(self, x, hidden_state=None):
        x = self.enc(x)
        x, hs = self.rnn(x, hidden_state)
        x = self.dec(x)
        return x, hs

    # This defines the function that gives a probability distribution and implements the temperature computation.
    def inference(self, x, hidden_state=None, temperature=1):
        x = x.view(-1, 1)
        x, hidden_state = self.forward(x, hidden_state)
        x = x.view(1, -1)
        x = x / max(temperature, 1e-20)
        x = F.softmax(x, dim=1)
        return x, hidden_state

    # Predefined loss function
    def loss(self, prediction, label, reduction='mean'):
        loss_val = F.cross_entropy(prediction, label)
        return loss_val

    # Saves the current model
    def save_model(self, file_path, num_to_keep=1):
        pt_util.save(self, file_path, num_to_keep)

    # Saves the best model so far
    def save_best_model(self, accuracy, file_path, num_to_keep=1):
        if accuracy > self.best_accuracy:
            self.save_model(file_path, num_to_keep)
            self.best_accuracy = accuracy

    def load_model(self, file_path):
        pt_util.restore(self, file_path)

    def load_last_model(self, dir_path):
        return pt_util.restore_latest(self, dir_path)

In [7]:
# Ripped from HW 1
import time
def train(model, device, train_loader, optimizer, epoch, log_interval):
    model.train()
    losses = []
    for batch_idx, (data, label) in enumerate(train_loader):
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = model.loss(output, label)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('{} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                time.ctime(time.time()),
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    return np.mean(losses)

def test(model, device, test_loader, log_interval=None):
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for batch_idx, (data, label) in enumerate(test_loader):
            data, label = data.to(device), label.to(device)
            output = model(data)
            test_loss_on = model.loss(output, label, reduction='sum').item()
            test_loss += test_loss_on
            pred = output.max(1)[1]
            correct_mask = pred.eq(label.view_as(pred))
            num_correct = correct_mask.sum().item()
            correct += num_correct
            if log_interval is not None and batch_idx % log_interval == 0:
                print('{} Test: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    time.ctime(time.time()),
                    batch_idx * len(data), len(test_loader.dataset),
                    100. * batch_idx / len(test_loader), test_loss_on))

    test_loss /= len(test_loader.dataset)
    test_accuracy = 100. * correct / len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset), test_accuracy))
    return test_loss, test_accuracy

In [None]:
def generate_transliteration(model, input_chars):
    transliterations = []
    hidden = None
    
    for c in input_chars:
        x, hidden = model.inference(c, hidden)
        transliterations.append(torch.argmax(x))
        
    return transliterations

In [8]:
def main():
    pass