In [2]:
from argparse import Namespace
from collections import Counter
import json
import os
import string

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm_notebook

## Dataset class

In [9]:
class SurnameDataset(Dataset):
    """Dataset over a surname DataFrame that is partitioned by its ``split`` column.

    The frame is sliced once into 'train', 'valid' and 'test' partitions.
    Exactly one partition is active at a time (see ``set_split``); ``len()``
    and indexing operate on the active partition.
    """

    def __init__(self,surname_df,surname_voectorizer):
        """
        Args:
            surname_df (pandas.DataFrame): data with a ``split`` column whose
                values are 'train', 'valid' or 'test'.
            surname_voectorizer: object used by ``__getitem__`` to vectorize
                rows. The parameter name (including its spelling) is kept
                unchanged so existing keyword callers keep working.
        """
        self.df = surname_df
        self.vectorizer = surname_voectorizer

        self.train_df = self.df[self.df['split'] == 'train']
        self.train_size = len(self.train_df)

        self.valid_df = self.df[self.df['split'] == 'valid']
        self.valid_size = len(self.valid_df)

        self.test_df = self.df[self.df['split'] == 'test']
        self.test_size = len(self.test_df)

        self.lookup_dict = {'train': (self.train_df, self.train_size),
                            'valid': (self.valid_df, self.valid_size),
                            'test': (self.test_df, self.test_size)}

        # Select a default partition so len() and indexing work right after
        # construction instead of failing on missing _target_* attributes.
        self.set_split('train')

    @classmethod
    def load_data_build_vectorizer(cls,surname_path,vectorizer_path):
        """Read the CSV at ``surname_path`` and load a saved vectorizer.

        Args:
            surname_path (str): path of the CSV file passed to ``pd.read_csv``.
            vectorizer_path (str): path of the JSON file holding the
                serialized vectorizer.
        Returns:
            SurnameDataset: a new instance of ``cls``.
        """
        surname_df = pd.read_csv(surname_path)
        vectorizer = cls.load_vectorizer_only(vectorizer_path)
        return cls(surname_df, vectorizer)

    @staticmethod
    def load_vectorizer_only(vectorizer_path):
        """Load a vectorizer from the JSON file at ``vectorizer_path``.

        NOTE(review): relies on ``SurnameVectorizer.from_serializable``, which
        is not defined in this part of the notebook - confirm it exists.
        """
        with open(vectorizer_path) as fp:
            return SurnameVectorizer.from_serializable(json.load(fp))

    def save_vectorizer(self, vectorizer_path):
        """Write the vectorizer's ``to_serializable()`` output as JSON.

        Args:
            vectorizer_path (str): destination file; it is overwritten.
        """
        with open(vectorizer_path, "w") as fp:
            # The attribute set in __init__ is `vectorizer` (no underscore).
            json.dump(self.vectorizer.to_serializable(), fp)

    def get_vectorizer(self):
        """Return the vectorizer given at construction time."""
        return self.vectorizer

    def set_split(self, split="train"):
        """Make ``split`` the active partition.

        Args:
            split (str): one of 'train', 'valid' or 'test'.
        Raises:
            KeyError: if ``split`` is not one of the known partitions.
        """
        self._target_split = split
        self._target_df, self._target_size = self.lookup_dict[split]

    def __len__(self):
        """Number of rows in the active partition."""
        return self._target_size

    def __getitem__(self, index):
        """Vectorize one row of the active partition.

        Args:
            index (int): positional index into the active partition.
        Returns:
            dict: ``{'x_surname': ..., 'y_nationality': ...}``.
        """
        row = self._target_df.iloc[index]

        # NOTE(review): the column names 'names' and 'namescoutry' are taken
        # as-is from the original code; verify them against the CSV header.
        surname_vector = self.vectorizer.vectorize(row['names'])
        # `look_up_token` is the method name defined by Vocabulary in this
        # notebook.
        nationality_index = self.vectorizer.nationality_vocab.look_up_token(
            row['namescoutry'])
        return {'x_surname': surname_vector,
                'y_nationality': nationality_index}

    def get_num_batches(self, batch_size):
        """Return how many full batches of ``batch_size`` fit in the active partition."""
        return len(self) // batch_size


In [10]:
def generate_batches(dataset, batch_size, shuffle=True,
                     drop_last=True, device="cpu"):
    """Yield batches from ``dataset`` with every value moved to ``device``.

    A ``DataLoader`` is built over ``dataset``; each batch it produces is
    treated as a dict, and ``.to(device)`` is called on each of its values.

    Args:
        dataset: dataset handed to ``DataLoader``.
        batch_size (int): number of items per batch.
        shuffle (bool): forwarded to ``DataLoader``.
        drop_last (bool): forwarded to ``DataLoader``.
        device: target passed to ``.to(...)`` for every batch value.
    Yields:
        dict: same keys as the loader's batch, values moved to ``device``.
    """
    loader = DataLoader(dataset=dataset, batch_size=batch_size,
                        shuffle=shuffle, drop_last=drop_last)

    for batch in loader:
        yield {key: value.to(device) for key, value in batch.items()}

In [None]:
class Vocabulary(object):
    """Two-way mapping between tokens and integer indices.

    Note from the original author: the vocabulary should include all the
    data we have, both testing and training.
    """

    def __init__(self,token2idx = None, add_UNK = True, UNK = '<UNK>'):
        """
        Args:
            token2idx (dict or None): existing token -> index mapping. It is
                copied, so the caller's dict is never modified.
            add_UNK (bool): whether to register the unknown token ``UNK``.
            UNK (str): the token used for unknown lookups.
        """
        if token2idx is None:
            token2idx = {}

        # Keep both directions of the mapping in sync from the start.
        self.token2idx = dict(token2idx)
        self.idx2token = {idx: token for token, idx in self.token2idx.items()}

        self.add_UNK = add_UNK
        self.UNK = UNK

        # -1 is a sentinel meaning "no unknown token registered"; real
        # indices start at 0.
        self.UNK_index = -1

        if add_UNK:
            self.UNK_index = self.add_token(UNK)

    def serialized(self):
        """Return a dict that ``from_serialied`` can turn back into a Vocabulary."""
        return {'token2idx' : self.token2idx,
                'add_UNK': self.add_UNK,
                'UNK':self.UNK}

    @classmethod
    def from_serialied(cls,content):
        """Build a Vocabulary from the dict produced by ``serialized``."""
        return cls(**content)

    # Correctly spelled alias; the original name is kept for existing callers.
    from_serialized = from_serialied

    def add_token(self,token):
        """Register ``token`` if it is new and return its index.

        Args:
            token (str): the token to add.
        Returns:
            int: the index of ``token`` (existing or newly assigned).
        """
        if token in self.token2idx:
            return self.token2idx[token]

        idx = len(self.idx2token)
        self.token2idx[token] = idx
        self.idx2token[idx] = token
        return idx

    def add_manytoken(self,tokens):
        """Add every token in ``tokens`` and return the list of their indices."""
        return [self.add_token(t) for t in tokens]

    def look_up_token(self,token):
        """Return the index of ``token``.

        When an unknown token is registered, tokens that are not in the
        vocabulary map to its index; otherwise a missing token is an error.

        Raises:
            KeyError: if ``token`` is missing and no unknown token exists.
        """
        if self.UNK_index >= 0:
            return self.token2idx.get(token, self.UNK_index)
        # No unknown token available, so read directly from the dict.
        return self.token2idx[token]

    # Alias matching the name SurnameDataset.__getitem__ uses in this notebook.
    lookup_token = look_up_token
            
            
    
        
        
    