In [3]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import csv
import gzip
import os

import numpy as np
import torch as t
import torch.nn as nn
import torch.utils.data as Data
from torch.autograd import Variable as V
from torchvision import models, datasets
from torchvision import transforms as T

SEED = 777  # single place to change the random seed for the whole notebook
np.random.seed(SEED)  # numpy is imported above; seed it too for reproducibility
# Seeds torch's default RNG (DataLoader(shuffle=True) below draws from it).
# Trailing `;` suppresses the returned Generator repr in the cell output.
t.manual_seed(SEED);

<torch._C.Generator at 0x7faba007fe10>

In [5]:
class NameDataset(Data.Dataset):
    """Name -> country-of-origin dataset backed by a gzipped CSV file.

    Each CSV row holds the name in column 0 and the country label in
    column 1; any extra columns are ignored. Items are returned as raw
    ``(name, country)`` string pairs, so default DataLoader collation
    yields tuples of strings per batch.

    Args:
        is_train_set: if True, read ``names_train.csv.gz``; otherwise read
            ``names_test.csv.gz``.
        data_dir: directory holding the CSV files. The default ``'./data'``
            is resolved against the current working directory and matches
            the paths this class has always used.
    """

    def __init__(self, is_train_set=False, data_dir='./data'):
        basename = 'names_train.csv.gz' if is_train_set else 'names_test.csv.gz'
        filename = os.path.join(data_dir, basename)

        # newline='' is what the csv module documents for files it reads.
        # NOTE(review): no encoding is passed, so the locale default is used;
        # confirm that matches how the data files were written.
        with gzip.open(filename, "rt", newline='') as f:
            # csv.reader yields [] for blank lines (e.g. a trailing newline);
            # skip them rather than failing with IndexError below.
            rows = [row for row in csv.reader(f) if row]

        self.names = [row[0] for row in rows]
        self.countries = [row[1] for row in rows]
        self.len = len(self.countries)

        # Sorted so the country <-> id mapping is deterministic across runs.
        self.country_list = sorted(set(self.countries))

    def __getitem__(self, index):
        """Return the ``(name, country)`` pair at ``index``."""
        return self.names[index], self.countries[index]

    def __len__(self):
        """Return the number of (name, country) rows loaded."""
        return self.len

    def get_countries(self):
        """Return the sorted list of distinct country labels."""
        return self.country_list

    def get_country(self, id):
        """Map a class id to its country label (inverse of get_country_id)."""
        # `id` shadows the builtin but is kept for keyword-argument callers.
        return self.country_list[id]

    def get_country_id(self, country):
        """Map a country label to its class id (its index in get_countries()).

        Raises:
            ValueError: if ``country`` does not occur in this dataset.
        """
        return self.country_list.index(country)
    

In [7]:
# Demo: load the held-out (test) split and iterate it in shuffled mini-batches.
# NOTE: despite its name, `train_loader` wraps the *test* split
# (is_train_set=False); the name is kept so any later cells using it still work.
BATCH_SIZE = 10
N_EPOCHS = 2
MAX_BATCHES_TO_SHOW = 3  # print only the first few batches per epoch to keep output readable

dataset = NameDataset(is_train_set=False)
print("countries: ", dataset.get_countries())
print("id 3 map to country: ", dataset.get_country(3))
print("'Korean' map to id: ", dataset.get_country_id('Korean'))

train_loader = Data.DataLoader(dataset=dataset,
                               batch_size=BATCH_SIZE,
                               shuffle=True)

print(len(train_loader.dataset))
for epoch in range(N_EPOCHS):
    for i, (names, countries) in enumerate(train_loader):
        # Placeholder for the training step: every batch is still visited,
        # only the printing is limited.
        if i < MAX_BATCHES_TO_SHOW:
            print(epoch, i, "names", names, "countries", countries)

countries:  ['Arabic', 'Chinese', 'Czech', 'Dutch', 'English', 'French', 'German', 'Greek', 'Irish', 'Italian', 'Japanese', 'Korean', 'Polish', 'Portuguese', 'Russian', 'Scottish', 'Spanish', 'Vietnamese']
id 3 map to country:  Dutch
'Korean map to id:  11
6700
0 0 names ('Pavlichenko', 'De laurentis', 'Naslednikov', 'Kajahara', 'Rutshtein', 'Byon', 'Schneider', 'Jongolovich', 'Mahoney', 'Mikhailyants') countries ('Russian', 'Italian', 'Russian', 'Japanese', 'Russian', 'Korean', 'German', 'Russian', 'English', 'Russian')
0 1 names ('Dizhbak', 'De la fuente', 'Srour', 'Deniskin', 'Bagdatiev', 'Biryukov', 'Bahmutsky', 'Yakimets', 'Toman', 'Zimovets') countries ('Russian', 'Spanish', 'Arabic', 'Russian', 'Russian', 'Russian', 'Russian', 'Russian', 'Russian', 'Russian')
0 2 names ('Senmatsu', 'Quraishi', 'Essa', 'Cumming', 'Finnegan', 'Haletsky', 'Hankeev', 'Gordusenko', 'Durylin', 'Vinding') countries ('Japanese', 'Arabic', 'Arabic', 'English', 'Irish', 'Russian', 'Russian', 'Russian', 'R