In [272]:
import math
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
from tqdm.notebook import tqdm

In [283]:
labels = {}
with open('dates_simple.tsv', 'r') as f:
    for n, line in enumerate(f.readlines()):
        items = line.split('\t')
        normal = items[0].strip()
        variant = items[1].strip()
        labels[normal] = variant
        if n == -1:
            break
            
print(len(labels))

8453


In [284]:
max_len = 0
characters = set()

for normal, variant in labels.items():
    characters.update(normal)
    characters.update(variant)
    if len(normal) > max_len:
        max_len = len(normal)
    if len(variant) > max_len:
        max_len = len(variant)
        
print('num of characters', len(characters))

characters.update(['[unk]', '[cls]', '[sep]', '[pad]'])
max_len *= 2 + 2

print(max_len)

num_embeddings = len(characters)

num of characters 11
40


In [285]:
character_to_idx = {}

for idx, character in enumerate(characters):
    character_to_idx[character] = idx    

character_to_idx

{'7': 0,
 '1': 1,
 '2': 2,
 '4': 3,
 '5': 4,
 '3': 5,
 '0': 6,
 '9': 7,
 '[cls]': 8,
 '[sep]': 9,
 '[pad]': 10,
 '[unk]': 11,
 '/': 12,
 '6': 13,
 '8': 14}

In [286]:
train_x = []
train_y = []

def to_idxs(string):
    res = []
    for character in string:
        if character in character_to_idx:
            res.append(character_to_idx[character])
        else:
            res.append(character_to_idx['[unk]'])
    return res;

def get_x(normal, variant):
    x = [character_to_idx['[cls]']]
    x += to_idxs(normal)
    x += [character_to_idx['[sep]']]
    x += to_idxs(variant)
    return x

def add_padding(x):
    x += [character_to_idx['[pad]']] * (max_len - len(x))
    return x

# positive
for normal, variant in labels.items():
    x = get_x(normal, variant)
    x = add_padding(x)
    #print(normal, variant, x)
    train_x.append(x)
    train_y.append(1)

# negative
all_strings = list(labels.keys()) + list(labels.values())
for normal, _ in labels.items():
    # this could be improved
    variant = np.random.choice(all_strings)
    x = get_x(normal, variant)
    x = add_padding(x)
    #print(normal, variant, x)
    train_x.append(x)
    train_y.append(0)

In [287]:
len(train_x), len(train_y)

(16906, 16906)

In [288]:
all_train_x = torch.tensor(train_x, dtype=torch.long)
all_train_y = torch.tensor(train_y, dtype=torch.long)

In [289]:
all_train_x.shape, all_train_y.shape

(torch.Size([16906, 40]), torch.Size([16906]))

In [290]:
train_set = data_utils.TensorDataset(all_train_x, all_train_y)
train_loader = data_utils.DataLoader(train_set, batch_size = 50, shuffle=True)

In [291]:
class PositionalEncoding:
    
    def __init__(self, embedding_dim, max_len=100):
        self.pe = torch.zeros(max_len, embedding_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        self.pe[:, 0::2] = torch.sin(position * div_term)
        self.pe[:, 1::2] = torch.cos(position * div_term)
        
    def encode(self, x):
        return x + self.pe

In [296]:
class Model(nn.Module):
    
    def __init__(self, 
                 num_embeddings, 
                 embedding_dim, 
                 nhead=8, num_layers=2, dim_feedforward=1024, max_len=100):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.embed = nn.Embedding(num_embeddings, embedding_dim)
        self.pos_encoder = PositionalEncoding(embedding_dim, max_len)
        self.transformer_encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(embedding_dim, nhead, dim_feedforward), num_layers)
        self.fc1 = nn.Linear(embedding_dim*max_len, (int) (embedding_dim/2))
        self.fc2 = nn.Linear((int) (embedding_dim/2), 2)
        self.criterion = nn.CrossEntropyLoss()
    
    def forward(self, x, y=None):
        y_pred = self.embed(x)
        y_pred = self.pos_encoder.encode(y_pred) * math.sqrt(self.embedding_dim)  
        y_pred = self.transformer_encoder(y_pred)
        y_pred = y_pred.reshape((y_pred.shape[0], -1))
        y_pred = self.fc1(y_pred)
        y_pred = self.fc2(y_pred)
        if y is None:
            y_pred = nn.functional.softmax(y_pred, dim=0)[:, 0]
            return y_pred
        else:
            return self.criterion(y_pred, y)
        

In [297]:
model = Model(num_embeddings, 64, max_len=max_len)

In [298]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
num_running_loss = 5

for epoch in range(1000):
    
    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):

        inputs, labels = data
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        loss = model(inputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % num_running_loss == num_running_loss - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 5))
            running_loss = 0.0
        
    model.eval()
    with torch.no_grad():
        mse = 0
        n = 0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            output = model(inputs)
            mse += torch.sum((output - labels)**2)
            n += len(labels)
        mse /= n
        print('train mse %.3f' % mse)

[1,     5] loss: 0.720
[1,    10] loss: 0.674
[1,    15] loss: 0.632
[1,    20] loss: 0.611
[1,    25] loss: 0.625
[1,    30] loss: 0.611
[1,    35] loss: 0.596
[1,    40] loss: 0.563
[1,    45] loss: 0.568
[1,    50] loss: 0.558
[1,    55] loss: 0.553
[1,    60] loss: 0.557
[1,    65] loss: 0.513
[1,    70] loss: 0.474
[1,    75] loss: 0.522
[1,    80] loss: 0.487
[1,    85] loss: 0.540
[1,    90] loss: 0.529
[1,    95] loss: 0.527
[1,   100] loss: 0.515
[1,   105] loss: 0.479
[1,   110] loss: 0.510
[1,   115] loss: 0.508
[1,   120] loss: 0.467
[1,   125] loss: 0.527
[1,   130] loss: 0.498
[1,   135] loss: 0.450
[1,   140] loss: 0.445
[1,   145] loss: 0.504
[1,   150] loss: 0.454
[1,   155] loss: 0.497
[1,   160] loss: 0.470
[1,   165] loss: 0.502
[1,   170] loss: 0.465
[1,   175] loss: 0.547
[1,   180] loss: 0.516
[1,   185] loss: 0.526
[1,   190] loss: 0.442
[1,   195] loss: 0.491
[1,   200] loss: 0.468
[1,   205] loss: 0.465
[1,   210] loss: 0.490
[1,   215] loss: 0.447
[1,   220] 

In [None]:
for parameter in model.parameters():
    print(parameter)