In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random
import math
import tqdm
import numpy as np
import pandas as pd

torch.manual_seed(192837)
rand = random.Random(192838)

In [2]:
CITY_COLUMN = 'city_ascii'
COUNTRY_COLUMN = 'iso2'

In [3]:
df_all = pd.read_csv('../emb/data/worldcities.csv')

'''
df_top_countries = df_all.groupby(COUNTRY_COLUMN).size().sort_values(ascending=False).reset_index()
classes = list(df_top_countries.head(N_CLASSES)[COUNTRY_COLUMN])
'''
classes = ['IN']
N_CLASSES = len(classes)

print(classes)
class_to_id = {classes[i]: i for i in range(N_CLASSES)}

df_all = df_all[df_all[COUNTRY_COLUMN].isin(classes)]
df_all = df_all[df_all[CITY_COLUMN].notna()]
df_all.head(10)

['IN']


Unnamed: 0,city,city_ascii,lat,lng,country,iso2,iso3,admin_name,capital,population,id
2,Delhi,Delhi,28.61,77.23,India,IN,IND,Delhi,admin,32226000.0,1356872604
4,Mumbai,Mumbai,19.0761,72.8775,India,IN,IND,Mahārāshtra,admin,24973000.0,1356226629
10,Kolkāta,Kolkata,22.5675,88.37,India,IN,IND,West Bengal,admin,21747000.0,1356060520
21,Bangalore,Bangalore,12.9789,77.5917,India,IN,IND,Karnātaka,admin,15386000.0,1356410365
29,Chennai,Chennai,13.0825,80.275,India,IN,IND,Tamil Nādu,admin,12395000.0,1356374944
39,Hyderābād,Hyderabad,17.3617,78.4747,India,IN,IND,Telangāna,admin,10494000.0,1356871768
55,Pune,Pune,18.5203,73.8567,India,IN,IND,Mahārāshtra,,8231000.0,1356081074
58,Ahmedabad,Ahmedabad,23.0225,72.5714,India,IN,IND,Gujarāt,minor,8009000.0,1356304381
79,Sūrat,Surat,21.205,72.84,India,IN,IND,Gujarāt,,6538000.0,1356758738
90,Prayagraj,Prayagraj,25.4358,81.8464,India,IN,IND,Uttar Pradesh,,5954391.0,1356718332


In [4]:
character_set = set()
for city in df_all[CITY_COLUMN]:
    for ch in city:
        character_set.add(ch)

PAD_TOKEN = 0
tokens = ['<PAD>'] + sorted(list(character_set))
token_to_id = {tokens[i]: i for i in range(len(tokens))}
N_VOCAB = len(tokens)

''.join(tokens)

'<PAD> -ABCDEFGHIJKLMNOPQRSTUVWYZabcdefghijklmnopqrstuvwxyz'

In [5]:
def encode(city: str, max_length: int):
    encoded = [token_to_id[city[i]] for i in range(min(max_length, len(city)))]
    return encoded + [PAD_TOKEN] * (max_length - len(encoded))

def decode(city_enc: list[int], decode_special: bool=False):
    return ''.join([tokens[tk] for tk in city_enc if (tk != PAD_TOKEN or decode_special)])

In [6]:
city_length_90_perc = np.quantile(np.array(list(map(len, df_all[CITY_COLUMN]))), 0.8)

SEQ_LENGTH = int(city_length_90_perc)
SEQ_LENGTH

11

In [7]:
data_list_x = [encode(city, SEQ_LENGTH) for city in df_all[CITY_COLUMN]]
data_list_y = [class_to_id[iso2] for iso2 in df_all[COUNTRY_COLUMN]]
data_list = list(zip(data_list_x, data_list_y))
rand.shuffle(data_list)

data_x = torch.tensor(np.array([x for x, _ in data_list], dtype=np.long))
data_y = torch.tensor(np.array([y for _, y in data_list], dtype=np.long))

for i in range(5):
    print(decode(data_x[i]), classes[data_y[i]])

TRAIN_TEST_SPLIT = 0.9
n_train = round(data_x.shape[0] * TRAIN_TEST_SPLIT)
train_x, train_y = data_x[:n_train, :], data_y[:n_train]
test_x,  test_y  = data_x[n_train:, :], data_y[n_train:]

print(data_x.shape[0], n_train, data_x.shape[0] - n_train)

Barah IN
Chittaurgar IN
Bhogpur IN
Jami IN
Bhadreswar IN
7108 6397 711


In [8]:
class RNNGenerator(nn.Module):
    def __init__(self,
                 emb_dim: int,
                 hidden_dim: int,
                 hidden_layers: int=1):
        super().__init__()

        self.emb = nn.Linear(N_VOCAB, emb_dim, bias=False)

        self.rnn = nn.GRU(emb_dim, hidden_dim, num_layers=hidden_layers, batch_first=True)
        self.head_1 = nn.Linear(hidden_dim, hidden_dim)
        self.head_2 = nn.Linear(hidden_dim, hidden_dim)
        self.head_3 = nn.Linear(hidden_dim, N_VOCAB)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor):
        if len(x.shape) == 2:  # unbatched input
            x = x.unsqueeze(0)
        B, *_ = x.shape

        x = self.emb(x)
        rnn_o, h_n = self.rnn(x)
        o = self.relu(self.head_1(rnn_o))
        o = self.relu(self.head_2(o))
        o = self.head_3(o)
        return o

In [9]:
def train_epoch(model: nn.Module,
                optimizer: optim.Optimizer,
                dataset_x: torch.Tensor,
                batch_size: int):
    model.train()
    
    loss_sum = 0
    accu_sum = 0
    n_batches = math.ceil(dataset_x.shape[0] / batch_size)

    for i in tqdm.tqdm(range(n_batches), 'train'):
        x = dataset_x[i * batch_size: (i+1) * batch_size, :]
        x_oh = F.one_hot(x, N_VOCAB).to(dtype=torch.float32)
        x_hat = model.forward(x_oh)

        # DEBUG
        # print(x.shape, x_oh.shape, x_hat.shape)

        loss = F.cross_entropy(x_hat[:, :-1, :].reshape((-1, N_VOCAB)), x[:, 1:].reshape((-1,)), ignore_index=PAD_TOKEN)
        loss_sum += loss.detach().clone()

        with torch.no_grad():
            accu_sum += torch.sum(x_hat[:, :-1, :].argmax(dim=-1) == x[:, 1:]) / x.shape[0] / (x.shape[1] - 1)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    return loss_sum / n_batches, accu_sum / n_batches

def test_epoch(model: nn.Module, dataset_x: torch.Tensor, batch_size):
    model.eval()
    
    with torch.no_grad():
        loss_sum = 0
        accu_sum = 0
        n_batches = math.ceil(dataset_x.shape[0] / batch_size)

        for i in tqdm.tqdm(range(n_batches), ' test'):
            x = dataset_x[i * batch_size: (i+1) * batch_size, :]
            x_oh = F.one_hot(x, N_VOCAB).to(dtype=torch.float32)
            x_hat = model.forward(x_oh)

            loss = F.cross_entropy(x_hat[:, :-1, :].reshape((-1, N_VOCAB)), x[:, 1:].reshape((-1,)), ignore_index=PAD_TOKEN)
            loss_sum += loss
            accu_sum += torch.sum(x_hat[:, :-1, :].argmax(dim=-1) == x[:, 1:]) / x.shape[0] / (x.shape[1] - 1)

        return loss_sum / n_batches, accu_sum / n_batches

In [10]:
model = RNNGenerator(emb_dim=16,
                     hidden_dim=32,
                     hidden_layers=2)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-5)
N_EPOCHS = 50
BATCH_SIZE = 32

In [11]:
for epoch_i in range(N_EPOCHS):
    print(f'=== epoch {epoch_i} ===')

    train_loss, train_accu = train_epoch(model, optimizer, train_x, BATCH_SIZE)
    test_loss, test_accu = test_epoch(model, test_x, BATCH_SIZE)

    print('train loss:', train_loss.item())
    print(' test loss:', test_loss.item())
    print('train accu:', train_accu.item())
    print(' test accu:', test_accu.item())

=== epoch 0 ===


train: 100%|██████████| 200/200 [00:02<00:00, 75.90it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 373.38it/s]


train loss: 3.0323071479797363
 test loss: 2.7356069087982178
train accu: 0.1586713343858719
 test accu: 0.18144410848617554
=== epoch 1 ===


train: 100%|██████████| 200/200 [00:02<00:00, 72.12it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 372.14it/s]


train loss: 2.5660009384155273
 test loss: 2.3441948890686035
train accu: 0.20120632648468018
 test accu: 0.23051242530345917
=== epoch 2 ===


train: 100%|██████████| 200/200 [00:02<00:00, 75.56it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 391.89it/s]


train loss: 2.332127332687378
 test loss: 2.2992665767669678
train accu: 0.23267507553100586
 test accu: 0.2343944013118744
=== epoch 3 ===


train: 100%|██████████| 200/200 [00:02<00:00, 70.30it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 388.24it/s]


train loss: 2.2909586429595947
 test loss: 2.2617979049682617
train accu: 0.2345096468925476
 test accu: 0.24194489419460297
=== epoch 4 ===


train: 100%|██████████| 200/200 [00:02<00:00, 76.87it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 363.07it/s]


train loss: 2.263923406600952
 test loss: 2.2380614280700684
train accu: 0.23623929917812347
 test accu: 0.24468167126178741
=== epoch 5 ===


train: 100%|██████████| 200/200 [00:02<00:00, 73.88it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 297.00it/s]


train loss: 2.244938850402832
 test loss: 2.2206871509552
train accu: 0.23732484877109528
 test accu: 0.2437305897474289
=== epoch 6 ===


train: 100%|██████████| 200/200 [00:03<00:00, 60.38it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 364.18it/s]


train loss: 2.2305305004119873
 test loss: 2.2073850631713867
train accu: 0.23786106705665588
 test accu: 0.24481754004955292
=== epoch 7 ===


train: 100%|██████████| 200/200 [00:03<00:00, 64.02it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 247.23it/s]


train loss: 2.219320058822632
 test loss: 2.1974446773529053
train accu: 0.23842662572860718
 test accu: 0.24623450636863708
=== epoch 8 ===


train: 100%|██████████| 200/200 [00:02<00:00, 66.96it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 321.20it/s]


train loss: 2.209426164627075
 test loss: 2.188187837600708
train accu: 0.23925799131393433
 test accu: 0.24767084419727325
=== epoch 9 ===


train: 100%|██████████| 200/200 [00:02<00:00, 69.94it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 355.02it/s]


train loss: 2.196610927581787
 test loss: 2.173112392425537
train accu: 0.24101123213768005
 test accu: 0.25215449929237366
=== epoch 10 ===


train: 100%|██████████| 200/200 [00:02<00:00, 70.18it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 392.25it/s]


train loss: 2.173067092895508
 test loss: 2.1501963138580322
train accu: 0.24801774322986603
 test accu: 0.25677406787872314
=== epoch 11 ===


train: 100%|██████████| 200/200 [00:02<00:00, 74.07it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 334.80it/s]


train loss: 2.1455719470977783
 test loss: 2.1282620429992676
train accu: 0.2547381520271301
 test accu: 0.2628299593925476
=== epoch 12 ===


train: 100%|██████████| 200/200 [00:02<00:00, 72.76it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 330.38it/s]


train loss: 2.125689744949341
 test loss: 2.113877058029175
train accu: 0.25893041491508484
 test accu: 0.26554736495018005
=== epoch 13 ===


train: 100%|██████████| 200/200 [00:02<00:00, 71.32it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 364.32it/s]


train loss: 2.1121699810028076
 test loss: 2.102344512939453
train accu: 0.26182594895362854
 test accu: 0.2641110122203827
=== epoch 14 ===


train: 100%|██████████| 200/200 [00:02<00:00, 74.59it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 365.72it/s]


train loss: 2.1020190715789795
 test loss: 2.0944716930389404
train accu: 0.262419730424881
 test accu: 0.26335409283638
=== epoch 15 ===


train: 100%|██████████| 200/200 [00:02<00:00, 73.44it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 329.67it/s]


train loss: 2.0934860706329346
 test loss: 2.0868020057678223
train accu: 0.2634822130203247
 test accu: 0.26389750838279724
=== epoch 16 ===


train: 100%|██████████| 200/200 [00:02<00:00, 76.51it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 328.86it/s]


train loss: 2.0859527587890625
 test loss: 2.0807011127471924
train accu: 0.26426514983177185
 test accu: 0.26457688212394714
=== epoch 17 ===


train: 100%|██████████| 200/200 [00:02<00:00, 73.67it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 375.96it/s]


train loss: 2.0786633491516113
 test loss: 2.07452130317688
train accu: 0.2649196982383728
 test accu: 0.26512032747268677
=== epoch 18 ===


train: 100%|██████████| 200/200 [00:02<00:00, 76.57it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 335.41it/s]


train loss: 2.071485757827759
 test loss: 2.0683367252349854
train accu: 0.26549622416496277
 test accu: 0.2660714089870453
=== epoch 19 ===


train: 100%|██████████| 200/200 [00:02<00:00, 75.40it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 357.98it/s]


train loss: 2.0646533966064453
 test loss: 2.0625202655792236
train accu: 0.2662181556224823
 test accu: 0.2684588134288788
=== epoch 20 ===


train: 100%|██████████| 200/200 [00:02<00:00, 73.93it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 385.00it/s]


train loss: 2.058098793029785
 test loss: 2.058354139328003
train accu: 0.2664526104927063
 test accu: 0.26919642090797424
=== epoch 21 ===


train: 100%|██████████| 200/200 [00:02<00:00, 76.31it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 383.42it/s]


train loss: 2.0519065856933594
 test loss: 2.0540802478790283
train accu: 0.2672995626926422
 test accu: 0.2696816623210907
=== epoch 22 ===


train: 100%|██████████| 200/200 [00:02<00:00, 75.94it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 350.86it/s]


train loss: 2.0459368228912354
 test loss: 2.0501859188079834
train accu: 0.2685808539390564
 test accu: 0.2691381871700287
=== epoch 23 ===


train: 100%|██████████| 200/200 [00:02<00:00, 74.24it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 364.63it/s]


train loss: 2.0401549339294434
 test loss: 2.046708106994629
train accu: 0.26897141337394714
 test accu: 0.2685947120189667
=== epoch 24 ===


train: 100%|██████████| 200/200 [00:02<00:00, 72.39it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 308.11it/s]


train loss: 2.0342981815338135
 test loss: 2.042804002761841
train accu: 0.27014654874801636
 test accu: 0.2695458233356476
=== epoch 25 ===


train: 100%|██████████| 200/200 [00:02<00:00, 67.46it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 367.62it/s]


train loss: 2.0285372734069824
 test loss: 2.039987564086914
train accu: 0.270865261554718
 test accu: 0.26962345838546753
=== epoch 26 ===


train: 100%|██████████| 200/200 [00:02<00:00, 71.13it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 397.41it/s]


train loss: 2.0229663848876953
 test loss: 2.0361270904541016
train accu: 0.271335631608963
 test accu: 0.271195650100708
=== epoch 27 ===


train: 100%|██████████| 200/200 [00:02<00:00, 71.63it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 342.87it/s]


train loss: 2.0172502994537354
 test loss: 2.032392740249634
train accu: 0.2722263038158417
 test accu: 0.27364128828048706
=== epoch 28 ===


train: 100%|██████████| 200/200 [00:02<00:00, 76.59it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 327.87it/s]


train loss: 2.0114941596984863
 test loss: 2.0280659198760986
train accu: 0.27344658970832825
 test accu: 0.27459239959716797
=== epoch 29 ===


train: 100%|██████████| 200/200 [00:02<00:00, 76.81it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 397.84it/s]


train loss: 2.005596876144409
 test loss: 2.024137020111084
train accu: 0.2742278277873993
 test accu: 0.27546587586402893
=== epoch 30 ===


train: 100%|██████████| 200/200 [00:02<00:00, 75.92it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 319.88it/s]


train loss: 1.9999995231628418
 test loss: 2.020784616470337
train accu: 0.2749137878417969
 test accu: 0.27560168504714966
=== epoch 31 ===


train: 100%|██████████| 200/200 [00:02<00:00, 76.96it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 356.06it/s]


train loss: 1.994636058807373
 test loss: 2.0174386501312256
train accu: 0.2764310836791992
 test accu: 0.2773680090904236
=== epoch 32 ===


train: 100%|██████████| 200/200 [00:02<00:00, 76.20it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 365.56it/s]


train loss: 1.98944890499115
 test loss: 2.0137839317321777
train accu: 0.27730441093444824
 test accu: 0.2783190608024597
=== epoch 33 ===


train: 100%|██████████| 200/200 [00:02<00:00, 76.97it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 375.79it/s]


train loss: 1.9845402240753174
 test loss: 2.0106406211853027
train accu: 0.2780856788158417
 test accu: 0.2781831920146942
=== epoch 34 ===


train: 100%|██████████| 200/200 [00:02<00:00, 72.53it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 369.66it/s]


train loss: 1.979685664176941
 test loss: 2.007685422897339
train accu: 0.27900758385658264
 test accu: 0.27503880858421326
=== epoch 35 ===


train: 100%|██████████| 200/200 [00:02<00:00, 72.56it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 358.23it/s]


train loss: 1.975180983543396
 test loss: 2.0045664310455322
train accu: 0.2802763879299164
 test accu: 0.27558231353759766
=== epoch 36 ===


train: 100%|██████████| 200/200 [00:02<00:00, 74.11it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 352.23it/s]


train loss: 1.9704474210739136
 test loss: 2.002554416656494
train accu: 0.28108885884284973
 test accu: 0.2772127389907837
=== epoch 37 ===


train: 100%|██████████| 200/200 [00:02<00:00, 73.36it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 350.52it/s]


train loss: 1.966017484664917
 test loss: 1.9993489980697632
train accu: 0.2820263206958771
 test accu: 0.27585408091545105
=== epoch 38 ===


train: 100%|██████████| 200/200 [00:02<00:00, 77.55it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 348.97it/s]


train loss: 1.9617438316345215
 test loss: 1.9972224235534668
train accu: 0.2827155590057373
 test accu: 0.27503886818885803
=== epoch 39 ===


train: 100%|██████████| 200/200 [00:02<00:00, 75.35it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 374.39it/s]


train loss: 1.957587480545044
 test loss: 1.994613766670227
train accu: 0.2836061418056488
 test accu: 0.27571818232536316
=== epoch 40 ===


train: 100%|██████████| 200/200 [00:02<00:00, 74.29it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 382.03it/s]


train loss: 1.953540325164795
 test loss: 1.9922764301300049
train accu: 0.2846199870109558
 test accu: 0.2768051028251648
=== epoch 41 ===


train: 100%|██████████| 200/200 [00:02<00:00, 74.86it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 350.35it/s]


train loss: 1.9495657682418823
 test loss: 1.9895766973495483
train accu: 0.28549182415008545
 test accu: 0.27897903323173523
=== epoch 42 ===


train: 100%|██████████| 200/200 [00:02<00:00, 76.64it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 346.72it/s]


train loss: 1.9456897974014282
 test loss: 1.98690927028656
train accu: 0.2863200008869171
 test accu: 0.2812888026237488
=== epoch 43 ===


train: 100%|██████████| 200/200 [00:02<00:00, 75.48it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 387.69it/s]


train loss: 1.9420208930969238
 test loss: 1.984712839126587
train accu: 0.28689974546432495
 test accu: 0.28033772110939026
=== epoch 44 ===


train: 100%|██████████| 200/200 [00:02<00:00, 75.82it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 329.42it/s]


train loss: 1.9382902383804321
 test loss: 1.9824978113174438
train accu: 0.2873232960700989
 test accu: 0.2829386591911316
=== epoch 45 ===


train: 100%|██████████| 200/200 [00:02<00:00, 69.22it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 355.26it/s]


train loss: 1.9345935583114624
 test loss: 1.9806784391403198
train accu: 0.2877139747142792
 test accu: 0.2829386591911316
=== epoch 46 ===


train: 100%|██████████| 200/200 [00:02<00:00, 75.20it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 306.64it/s]


train loss: 1.9309440851211548
 test loss: 1.978567361831665
train accu: 0.28863751888275146
 test accu: 0.2833462655544281
=== epoch 47 ===


train: 100%|██████████| 200/200 [00:02<00:00, 75.96it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 348.92it/s]


train loss: 1.9272695779800415
 test loss: 1.9765784740447998
train accu: 0.2894498407840729
 test accu: 0.284297376871109
=== epoch 48 ===


train: 100%|██████████| 200/200 [00:02<00:00, 77.39it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 370.96it/s]


train loss: 1.923789381980896
 test loss: 1.9746588468551636
train accu: 0.2900296151638031
 test accu: 0.2837538421154022
=== epoch 49 ===


train: 100%|██████████| 200/200 [00:02<00:00, 77.09it/s]
 test: 100%|██████████| 23/23 [00:00<00:00, 372.56it/s]

train loss: 1.9203715324401855
 test loss: 1.9723210334777832
train accu: 0.29088738560676575
 test accu: 0.2835403382778168





In [12]:
from torch.distributions import Categorical

for test in ('Mosc', 'Rio ', 'Delh', 'Shan'):
    test_tokens = encode(test, 4)
    test_tokens = torch.tensor(test_tokens, dtype=torch.long).unsqueeze(0)

    for i in range(4, SEQ_LENGTH):
        pred = model.forward(F.one_hot(test_tokens, N_VOCAB).to(torch.float32))
        dist = Categorical(logits=pred[:, -1, :])
        next_tokens = dist.sample().unsqueeze(1)

        test_tokens = torch.cat((test_tokens, next_tokens), dim=-1)

    print(decode(test_tokens.squeeze(), decode_special=True))

Moschaurika
Rio Manjunk
Delhadinghe
Shaniharcha


In [13]:
from torch.distributions import Categorical
from torch.distributions.gumbel import Gumbel
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

class AntiRNNGen():
    def __init__(self,
                 inner_net: nn.Module,
                 initial_tokens: int=6,
                 desired_tokens: str='abc',
                 lr: float=1e-2,
                 right_c_coef: float=1.0,
                 wrong_c_coef: float=1.0):
        super().__init__()

        self.n_initial_tokens = initial_tokens
        self.n_desired_tokens = len(desired_tokens)
        self.desired_tokens = [token_to_id[ch] for ch in desired_tokens]
        self.desired_tokens = torch.tensor(self.desired_tokens, dtype=torch.long)

        self.input_shape = (1, initial_tokens, N_VOCAB)
        self.input = nn.Parameter(
            data=torch.normal(0, 1, size=self.input_shape),
            requires_grad=True
        )

        self.inner_net = inner_net
        self.grad_eraser = torch.optim.SGD(self.inner_net.parameters())
        self.anti_optimizer = torch.optim.Adam((self.input,), lr=lr)
        # self.anti_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

        self.right_c_coef = right_c_coef
        self.wrong_c_coef = wrong_c_coef

        # self.input_min = 0.0
        # self.input_max = 1.0

        self.gumbel = Gumbel(loc=0.0, scale=1.0)
        self.epsilon = 1e-8

    def test(self):
        with torch.no_grad():
            return self.inner_net.forward(self.input)

    def zero_grad(self):
        self.grad_eraser.zero_grad()
        self.anti_optimizer.zero_grad()

    # def crop(self):
    #     self.input.data.copy_(torch.clamp(self.input.detach().clone(), self.input_min, self.input_max))
    
    def step(self):
        prev_input = self.input.detach().clone()

        probs = torch.softmax(self.input, dim=-1)

        extra_tokens = torch.zeros((1, self.n_desired_tokens - 1, N_VOCAB))
        extra_tokens[0, torch.arange(0, self.n_desired_tokens - 1), self.desired_tokens[:-1]] = 1.0
        probs = torch.cat((probs, extra_tokens), dim=1)

        gumb = self.gumbel.sample(probs.shape)

        j_onehot = F.one_hot(torch.argmax(gumb + torch.log(probs + self.epsilon), dim=-1), N_VOCAB)
        j_onehot = j_onehot.to(torch.float32).requires_grad_(True)

        j_continuous = torch.softmax(gumb + torch.log(probs + self.epsilon), dim=-1)

        output = self.inner_net.forward(j_onehot)
        output_soft = torch.softmax(output, dim=-1)
        
        # DEBUG
        # print(j_onehot.shape, j_continuous.shape, output.shape, output_soft.shape)

        loss = output_soft[:, self.n_initial_tokens - 1: self.n_initial_tokens + self.n_desired_tokens - 1, :]\
                          [torch.arange(0, N_VOCAB).unsqueeze(0).unsqueeze(1) != self.desired_tokens.unsqueeze(0).unsqueeze(2)]\
            .sum() * self.wrong_c_coef + \
               output_soft[:, self.n_initial_tokens - 1: self.n_initial_tokens + self.n_desired_tokens - 1, :]\
                          [:, :, self.desired_tokens].sum() * -self.right_c_coef
        loss.backward()

        j_continuous.backward(j_onehot.grad)

        self.anti_optimizer.step()
        self.zero_grad()

        # self.anti_scheduler.step()

        # self.crop()

        return prev_input, output.detach().clone(), loss.detach().clone()

In [19]:
ANTI_NET_STEPS = 3000
TEST_TRIES = 4
TESTS = ('pal', 'pur', 'ana', 'tia', 'ish', 'ika', 'ung', 'cha')

for i in range(len(TESTS)):
    n_desired_tokens = 3
    desired_tokens = TESTS[i]

    print(f'-=-=- step {i} -=-=-')
    print(f'desired tokens:', desired_tokens)

    fin_loss = float('+inf')
    fin_inp, fin_outp = None, None

    for try_i in range(TEST_TRIES):
        anti_net = AntiRNNGen(model, initial_tokens=6, desired_tokens=desired_tokens, lr=5e-2)

        for step_i in tqdm.tqdm(range(ANTI_NET_STEPS), f'attempt #{try_i}'):
            inp, outp, loss = anti_net.step()
        
        if loss.item() < fin_loss:
            fin_loss = loss.item()
            fin_inp = inp
            fin_outp = outp
        
        print(f'attempt #{try_i} loss:', loss.item())
    
    print(f'fin loss:', fin_loss)
    print(f'fin input:', decode(fin_inp.argmax(dim=-1).squeeze()))
    print(f'fin output (from function):', decode(fin_outp.argmax(dim=-1).squeeze(), decode_special=True))
    
    outp_gen = fin_inp.argmax(dim=-1)
    for i in range(n_desired_tokens):
        pred = model.forward(F.one_hot(outp_gen, N_VOCAB).to(torch.float32))
        next_tokens = pred[:, -1, :].argmax(dim=-1).unsqueeze(1)
        outp_gen = torch.cat((outp_gen, next_tokens), dim=-1)

    print(f'fin output (argmax sample):', decode(outp_gen.squeeze(), decode_special=True))

    outp_gen = fin_inp.argmax(dim=-1)
    for i in range(n_desired_tokens):
        pred = model.forward(F.one_hot(outp_gen, N_VOCAB).to(torch.float32))

        dist = Categorical(logits=pred[:, -1, :])
        next_tokens = dist.sample().unsqueeze(1)

        outp_gen = torch.cat((outp_gen, next_tokens), dim=-1)

    print(f'fin output (random sample):', decode(outp_gen.squeeze(), decode_special=True))

-=-=- step 0 -=-=-
desired tokens: pal


attempt #0: 100%|██████████| 3000/3000 [00:27<00:00, 110.29it/s]


attempt #0 loss: 1.3960310220718384


attempt #1: 100%|██████████| 3000/3000 [00:26<00:00, 114.04it/s]


attempt #1 loss: -0.5257922410964966


attempt #2: 100%|██████████| 3000/3000 [00:25<00:00, 119.34it/s]


attempt #2 loss: 1.3415793180465698


attempt #3: 100%|██████████| 3000/3000 [00:25<00:00, 117.92it/s]


attempt #3 loss: 1.4076471328735352
fin loss: -0.5257922410964966
fin input: EZuyAY
fin output (from function): rarampal
fin output (argmax sample): EZuyAYpal
fin output (random sample): EZuyAYali
-=-=- step 1 -=-=-
desired tokens: pur


attempt #0: 100%|██████████| 3000/3000 [00:25<00:00, 115.58it/s]


attempt #0 loss: -0.8369953632354736


attempt #1: 100%|██████████| 3000/3000 [00:25<00:00, 117.25it/s]


attempt #1 loss: -1.128077507019043


attempt #2: 100%|██████████| 3000/3000 [00:25<00:00, 117.23it/s]


attempt #2 loss: -0.5411638021469116


attempt #3: 100%|██████████| 3000/3000 [00:27<00:00, 110.60it/s]


attempt #3 loss: -0.9788793325424194
fin loss: -1.128077507019043
fin input: dURToT
fin output (from function): inaaaaur
fin output (argmax sample): dURToTpur
fin output (random sample): dURToT La
-=-=- step 2 -=-=-
desired tokens: ana


attempt #0: 100%|██████████| 3000/3000 [00:25<00:00, 117.47it/s]


attempt #0 loss: -1.155346155166626


attempt #1: 100%|██████████| 3000/3000 [00:25<00:00, 119.64it/s]


attempt #1 loss: -1.3665924072265625


attempt #2: 100%|██████████| 3000/3000 [00:25<00:00, 115.93it/s]


attempt #2 loss: -1.3541224002838135


attempt #3: 100%|██████████| 3000/3000 [00:25<00:00, 117.03it/s]


attempt #3 loss: -1.2274953126907349
fin loss: -1.3665924072265625
fin input: CncaMR
fin output (from function): hihiaapa
fin output (argmax sample): CncaMRapu
fin output (random sample): CncaMRuru
-=-=- step 3 -=-=-
desired tokens: tia


attempt #0: 100%|██████████| 3000/3000 [00:25<00:00, 116.62it/s]


attempt #0 loss: 0.5438131093978882


attempt #1: 100%|██████████| 3000/3000 [00:25<00:00, 116.11it/s]


attempt #1 loss: 0.032302141189575195


attempt #2: 100%|██████████| 3000/3000 [00:25<00:00, 117.60it/s]


attempt #2 loss: 1.1580132246017456


attempt #3: 100%|██████████| 3000/3000 [00:25<00:00, 116.78it/s]


attempt #3 loss: 0.020190954208374023
fin loss: 0.020190954208374023
fin input: FAcQMR
fin output (from function): aahaaaia
fin output (argmax sample): FAcQMRagh
fin output (random sample): FAcQMRana
-=-=- step 4 -=-=-
desired tokens: ish


attempt #0: 100%|██████████| 3000/3000 [00:25<00:00, 117.46it/s]


attempt #0 loss: 1.6670098304748535


attempt #1: 100%|██████████| 3000/3000 [00:26<00:00, 111.74it/s]


attempt #1 loss: 1.6715195178985596


attempt #2: 100%|██████████| 3000/3000 [00:33<00:00, 89.80it/s] 


attempt #2 loss: 0.5110284090042114


attempt #3: 100%|██████████| 3000/3000 [00:29<00:00, 102.97it/s]


attempt #3 loss: 1.665529727935791
fin loss: 0.5110284090042114
fin input: gwwwvq
fin output (from function): aaaaaiih
fin output (argmax sample): gwwwvqiin
fin output (random sample): gwwwvqied
-=-=- step 5 -=-=-
desired tokens: ika


attempt #0: 100%|██████████| 3000/3000 [00:26<00:00, 111.86it/s]


attempt #0 loss: 1.8541810512542725


attempt #1: 100%|██████████| 3000/3000 [00:28<00:00, 105.91it/s]


attempt #1 loss: 0.695175290107727


attempt #2: 100%|██████████| 3000/3000 [00:33<00:00, 90.18it/s] 


attempt #2 loss: 1.769472360610962


attempt #3: 100%|██████████| 3000/3000 [00:32<00:00, 91.94it/s] 


attempt #3 loss: 1.4867336750030518
fin loss: 0.695175290107727
fin input: IVbH-R
fin output (from function): laaanaaa
fin output (argmax sample): IVbH-Ramp
fin output (random sample): IVbH-Ruri
-=-=- step 6 -=-=-
desired tokens: ung


attempt #0: 100%|██████████| 3000/3000 [00:28<00:00, 104.21it/s]


attempt #0 loss: 1.0059199333190918


attempt #1: 100%|██████████| 3000/3000 [00:25<00:00, 117.01it/s]


attempt #1 loss: 1.4877104759216309


attempt #2: 100%|██████████| 3000/3000 [00:24<00:00, 121.84it/s]


attempt #2 loss: 1.436935544013977


attempt #3: 100%|██████████| 3000/3000 [00:29<00:00, 100.29it/s]


attempt #3 loss: 1.522752285003662
fin loss: 1.0059199333190918
fin input: qLqSEp
fin output (from function): uaiaaurd
fin output (argmax sample): qLqSEppal
fin output (random sample): qLqSEpkhu
-=-=- step 7 -=-=-
desired tokens: cha


attempt #0: 100%|██████████| 3000/3000 [00:40<00:00, 73.49it/s]


attempt #0 loss: 0.1935443878173828


attempt #1: 100%|██████████| 3000/3000 [00:43<00:00, 69.37it/s]


attempt #1 loss: -0.4566993713378906


attempt #2: 100%|██████████| 3000/3000 [00:43<00:00, 69.35it/s]


attempt #2 loss: -0.4763883352279663


attempt #3: 100%|██████████| 3000/3000 [00:45<00:00, 65.76it/s]


attempt #3 loss: -0.5307084321975708
fin loss: -0.5307084321975708
fin input: EnlY w
fin output (from function): rdaaraha
fin output (argmax sample): EnlY wall
fin output (random sample): EnlY wani
