In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

Core: Nehalem


In [45]:
# A simple encoder decoder design to do machine translation
class Encoder(nn.Module):
    
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        
    def forward(self, x):
        x = self.embedding(x)
        o, h = self.gru(x)
        return o, h


class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        
        
    def forward(self, x, h):
        x = self.embedding(x)
        o, h = self.gru(x, h)
        o = self.fc(o)
        return o, h     

In [159]:
from faker import Faker
import random
from tqdm import tqdm
from babel.dates import format_date

In [160]:
fake = Faker()
Faker.seed(12345)
random.seed(12345)

In [161]:
FORMATS = ['short',
           'medium',
           'long',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'd MMM YYY', 
           'd MMMM YYY',
           'dd MMM YYY',
           'd MMM, YYY',
           'd MMMM, YYY',
           'dd, MMM YYY',
           'd MM YY',
           'd MMMM YYY',
           'MMMM d YYY',
           'MMMM d, YYY',
           'dd.MM.YY']

# change this if you want it to work with another language
LOCALES = ['en_US']

In [179]:
def load_date():
    """
        Loads some fake dates 
        :returns: tuple containing human readable string, machine readable string, and date object
    """
    dt = fake.date_object()

    try:
        human_readable = format_date(dt, format=random.choice(FORMATS),  locale='en_US') # locale=random.choice(LOCALES))
        human_readable = human_readable.lower()
        human_readable = human_readable.replace(',','')
        machine_readable = dt.isoformat()
        
    except AttributeError as e:
        return None, None, None

    return human_readable, machine_readable, dt

def load_dataset(m):
    """
        Loads a dataset with m examples and vocabularies
        :m: the number of examples to generate
    """
    
    human_vocab = set()
    machine_vocab = set()
    dataset = []
    

    for i in tqdm(range(m)):
        h, m, _ = load_date()
        if h is not None:
            dataset.append((h, m))
            human_vocab.update(tuple(h))
            machine_vocab.update(tuple(m))
    
    human = dict(zip(sorted(human_vocab) + ['<unk>', '<pad>'], 
                     list(range(len(human_vocab) + 2))))
    inv_machine = dict(enumerate(sorted(machine_vocab)))
    machine = {v:k for k,v in inv_machine.items()}
    return dataset, human, machine, inv_machine

In [180]:
m = 10000
dataset, human_vocab, machine_vocab, inv_machine_vocab = load_dataset(m)

100%|██████████| 10000/10000 [00:00<00:00, 48397.30it/s]


In [181]:
dataset[:1]

[('sunday june 28 1998', '1998-06-28')]

In [183]:
human_dates, machine_dates = zip(*dataset)

X, Y = [], []

Tx = 30
Ty = 10

for string in human_dates:
    rep = list(map(lambda x: human_vocab.get(x, '<unk>'), string))
    if len(string) < Tx:
        rep += [human_vocab['<pad>']] * (Tx - len(string))
    X.append(rep)
    

for date in machine_dates:
    rep = list(map(lambda x: machine_vocab.get(x), date))
    Y.append(rep)

In [185]:
X = torch.tensor(X, dtype=torch.long)
Y = torch.tensor(Y, dtype=torch.long)

In [186]:
Y.shape

torch.Size([10000, 10])

In [187]:
len(human_vocab)

37

In [188]:
from torch.utils.data import Dataset, DataLoader

In [189]:
class CustomDataset(Dataset):
    
    def __init__(self, X, Y):
        self.X = X
        self.Y = Y
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, index):
        return self.X[index], self.Y[index]
    

In [190]:
batch_size = 32
dataset = CustomDataset(X, Y)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [90]:
human_vocab_size = len(human_vocab)
machine_vocab_size = len(machine_vocab)

embedding_dim = 64
hidden_dim = 64

encoder = Encoder(human_vocab_size, embedding_dim, hidden_dim)
decoder = Decoder(machine_vocab_size, embedding_dim, hidden_dim)

criterion = nn.CrossEntropyLoss()
encoder_optimizer = optim.Adam(encoder.parameters())
decoder_optimizer = optim.Adam(decoder.parameters())

In [252]:
n_epochs = 10
for epoch in range(n_epochs):
    total_loss = 0
    for X, Y in dataloader:
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        _, encoder_hidden = encoder(X)
        #print(encoder_hidden.shape)
        
        
        decoder_input = Y[:, 0:1]
        decoder_hidden = encoder_hidden
        loss = 0
        for t in range(Ty):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            loss += criterion(decoder_output.squeeze(), Y[: ,t])
            decoder_input = Y[:, t:t+1]
        
        loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()
        
        total_loss += loss.item()
        
    print(f'Epoch [{epoch+1}/{n_epochs}], Loss: {total_loss/len(dataloader)}')   
    
    

Epoch [1/10], Loss: 0.3404660337982467
Epoch [2/10], Loss: 0.04244079551519677
Epoch [3/10], Loss: 0.0910513338754876
Epoch [4/10], Loss: 0.5122210739519649
Epoch [5/10], Loss: 0.17097430388196208
Epoch [6/10], Loss: 0.03939784933833745
Epoch [7/10], Loss: 0.08137495556887918
Epoch [8/10], Loss: 0.6496298866816603
Epoch [9/10], Loss: 0.09856251039253637
Epoch [10/10], Loss: 0.04138681307113685


In [262]:
def decode_date(human_date):
    rep = list(map(lambda x: human_vocab.get(x, '<unk>'), human_date))
    if len(human_date) < Tx:
        rep += [human_vocab['<pad>']] * (Tx - len(human_date))
    
    X = torch.tensor(rep, dtype=torch.long)
    X = X.view(1, X.shape[0])
    
    with torch.no_grad():
        _, encoder_hidden = encoder(X)
        
        date_start_idx = machine_vocab.get("1")
        decoder_input = torch.tensor([date_start_idx]).view(1, 1)
        decoder_hidden = encoder_hidden
        output = [date_start_idx]
        for ty in range(Ty):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            #print(decoder_output.shape)
            top_token_idx = torch.argmax(decoder_output, dim=2)
            output.append(top_token_idx.item())
            #print(top_token_idx.shape)
            decoder_input = top_token_idx  
    
    output = output[1:]
    ds = list(map(lambda x: inv_machine_vocab.get(x), output))
    ds = "".join(ds)
    return ds

In [263]:
decode_date("sunday october 22 1996")

'1996-10-22'

In [264]:
torch.save(encoder, "encoder_basic_nmt.pth")
torch.save(decoder, "decoder_basic_nmt.pth")