In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random
import math
import tqdm
import numpy as np
import pandas as pd

# Reproducibility: seed torch's global RNG (weight init, Gumbel noise in the
# search cells) and a dedicated `random.Random` used to shuffle the dataset.
torch.manual_seed(192837)
rand = random.Random(192837 + 1)

In [2]:
# Input column (ASCII city name) and label column (two-letter country code).
CITY_COLUMN, COUNTRY_COLUMN = 'city_ascii', 'iso2'

In [3]:
df_all = pd.read_csv('../emb/data/worldcities.csv')

# Hand-picked target countries (values of the `iso2` column), used instead of
# the top-N countries by city count. List order defines the class ids.
classes = ['IN', 'US', 'BR', 'DE', 'CN', 'JP', 'RU', 'MX']
N_CLASSES = len(classes)

print(classes)
class_to_id = {iso2: class_id for class_id, iso2 in enumerate(classes)}

# Keep only cities of the chosen countries that have a usable ASCII name.
df_all = df_all[df_all[COUNTRY_COLUMN].isin(classes)]
df_all = df_all[df_all[CITY_COLUMN].notna()]
# Every class must have training data, otherwise its logit is never trained.
assert set(df_all[COUNTRY_COLUMN]) == set(classes), 'some class has no cities'
df_all.head(10)

['IN', 'US', 'BR', 'DE', 'CN', 'JP', 'RU', 'MX']


Unnamed: 0,city,city_ascii,lat,lng,country,iso2,iso3,admin_name,capital,population,id
0,Tokyo,Tokyo,35.687,139.7495,Japan,JP,JPN,Tōkyō,primary,37785000.0,1392685764
2,Delhi,Delhi,28.61,77.23,India,IN,IND,Delhi,admin,32226000.0,1356872604
3,Guangzhou,Guangzhou,23.13,113.26,China,CN,CHN,Guangdong,admin,26940000.0,1156237133
4,Mumbai,Mumbai,19.0761,72.8775,India,IN,IND,Mahārāshtra,admin,24973000.0,1356226629
6,Shanghai,Shanghai,31.2286,121.4747,China,CN,CHN,Shanghai,admin,24073000.0,1156073548
7,São Paulo,Sao Paulo,-23.5504,-46.6339,Brazil,BR,BRA,São Paulo,admin,23086000.0,1076532519
9,Mexico City,Mexico City,19.4333,-99.1333,Mexico,MX,MEX,Ciudad de México,primary,21804000.0,1484247881
10,Kolkāta,Kolkata,22.5675,88.37,India,IN,IND,West Bengal,admin,21747000.0,1356060520
14,New York,New York,40.6943,-73.9249,United States,US,USA,New York,,18832416.0,1840034016
15,Beijing,Beijing,39.9067,116.3975,China,CN,CHN,Beijing,primary,18522000.0,1156228865


In [4]:
# Character-level vocabulary: every character seen in any kept city name,
# sorted for a stable ordering, with id 0 reserved for padding.
character_set = {ch for name in df_all[CITY_COLUMN] for ch in name}

PAD_TOKEN = 0
tokens = ['<PAD>', *sorted(character_set)]
token_to_id = {tok: tok_id for tok_id, tok in enumerate(tokens)}
N_VOCAB = len(tokens)

''.join(tokens)

"<PAD> '()-./2ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

In [5]:
def encode(city: str, max_length: int):
    """Map `city` to a list of exactly `max_length` token ids (for max_length >= 0).

    Characters past `max_length` are dropped; shorter names are right-padded
    with PAD_TOKEN. A character missing from `token_to_id` raises KeyError.
    """
    kept = city[:max(max_length, 0)]
    ids = list(map(token_to_id.__getitem__, kept))
    ids.extend([PAD_TOKEN] * (max_length - len(ids)))
    return ids

def decode(city_enc: list[int]):
    """Inverse of `encode`: join the token strings, skipping PAD_TOKEN ids.

    Also used with 1-D id tensors, whose elements index `tokens` directly.
    """
    chars = []
    for token_id in city_enc:
        if token_id == PAD_TOKEN:
            continue
        chars.append(tokens[token_id])
    return ''.join(chars)

In [6]:
# Truncation/padding length: this quantile of city-name lengths, so most names
# fit entirely while the rare very long ones are cut. The variable name now
# matches the quantile actually used (0.8, which yields SEQ_LENGTH == 12).
SEQ_LENGTH_QUANTILE = 0.8
city_length_quantile = np.quantile(df_all[CITY_COLUMN].str.len(), SEQ_LENGTH_QUANTILE)

SEQ_LENGTH = int(city_length_quantile)
SEQ_LENGTH

12

In [7]:
# Encode every name, pair it with its class id, and shuffle with the seeded
# `rand` so the train/test split below is random but reproducible.
data_list_x = [encode(city, SEQ_LENGTH) for city in df_all[CITY_COLUMN]]
data_list_y = [class_to_id[iso2] for iso2 in df_all[COUNTRY_COLUMN]]
data_list = list(zip(data_list_x, data_list_y))
rand.shuffle(data_list)

# Build int64 tensors directly: F.one_hot and F.cross_entropy (class-index
# targets) require torch.long. `np.long` is not portable: it was removed in
# NumPy 1.24, and in NumPy 2 it means C `long`, which is 32-bit on Windows.
data_x = torch.tensor([x for x, _ in data_list], dtype=torch.long)
data_y = torch.tensor([y for _, y in data_list], dtype=torch.long)
assert data_x.shape == (len(data_list), SEQ_LENGTH)

for i in range(5):
    print(decode(data_x[i]), classes[data_y[i]])

TRAIN_TEST_SPLIT = 0.9
n_train = round(data_x.shape[0] * TRAIN_TEST_SPLIT)
train_x, train_y = data_x[:n_train, :], data_y[:n_train]
test_x,  test_y  = data_x[n_train:, :], data_y[n_train:]

print(data_x.shape[0], n_train, data_x.shape[0] - n_train)

Gohadi IN
Kandra IN
Chuanliaocun CN
Aracoiaba BR
Miryal IN
22320 20088 2232


In [8]:
class RNNClassifier(nn.Module):
    """Single-layer GRU classifier over character sequences.

    Input: float tensor (batch, seq_len, n_vocab), or unbatched (seq_len, n_vocab),
    which is treated as a batch of one. Output: class logits (batch, n_classes).
    """

    def __init__(self,
                 emb_dim: int,
                 hidden_dim: int,
                 n_vocab: 'int | None' = None,
                 n_classes: 'int | None' = None):
        """
        Args:
            emb_dim: width of the per-character embedding.
            hidden_dim: GRU hidden-state width.
            n_vocab: input feature size; defaults to the notebook's N_VOCAB.
            n_classes: number of output logits; defaults to N_CLASSES.
        """
        super().__init__()
        n_vocab = N_VOCAB if n_vocab is None else n_vocab
        n_classes = N_CLASSES if n_classes is None else n_classes

        # A bias-free Linear on one-hot rows equals an embedding lookup, but it
        # also accepts non-one-hot (real-valued) inputs, which nn.Embedding cannot.
        self.emb = nn.Linear(n_vocab, emb_dim, bias=False)
        self.rnn = nn.GRU(emb_dim, hidden_dim, num_layers=1, batch_first=True)
        self.head = nn.Linear(hidden_dim, n_classes)

    def forward(self, x: torch.Tensor):
        if x.dim() == 2:  # unbatched input -> batch of one
            x = x.unsqueeze(0)
        _, h_n = self.rnn(self.emb(x))
        # h_n is (num_layers, batch, hidden_dim); classify from the last
        # layer's final hidden state.
        return self.head(h_n[-1])

In [9]:
def train_epoch(model: nn.Module,
                optimizer: optim.Optimizer,
                dataset_x: torch.Tensor,
                dataset_y: torch.Tensor,
                batch_size: int):
    model.train()
    
    loss_sum = 0
    accu_sum = 0
    n_batches = math.ceil(dataset_x.shape[0] / batch_size)

    for i in tqdm.tqdm(range(n_batches), 'train'):
        x, y = dataset_x[i * batch_size: (i+1) * batch_size, :], dataset_y[i * batch_size: (i+1) * batch_size]
        x = F.one_hot(x, N_VOCAB).to(dtype=torch.float32)
        y_hat = model.forward(x)

        loss = F.cross_entropy(y_hat, y)
        loss_sum += loss.detach().clone()

        with torch.no_grad():
            accu_sum += torch.sum(y_hat.argmax(dim=-1) == y) / x.shape[0]

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    return loss_sum / n_batches, accu_sum / n_batches

def test_epoch(model: nn.Module, dataset_x: torch.Tensor, dataset_y: torch.Tensor, batch_size):
    model.eval()
    
    with torch.no_grad():
        loss_sum = 0
        accu_sum = 0
        n_batches = math.ceil(dataset_x.shape[0] / batch_size)

        for i in tqdm.tqdm(range(n_batches), ' test'):
            x, y = dataset_x[i * batch_size: (i+1) * batch_size, :], dataset_y[i * batch_size: (i+1) * batch_size]
            x = F.one_hot(x, N_VOCAB).to(dtype=torch.float32)
            y_hat = model.forward(x)

            loss = F.cross_entropy(y_hat, y)
            loss_sum += loss
            accu_sum += torch.sum(y_hat.argmax(dim=-1) == y) / x.shape[0]
        
        return loss_sum / n_batches, accu_sum / n_batches

In [10]:
# Training-loop settings.
N_EPOCHS = 50
BATCH_SIZE = 32

# Small GRU classifier; Adam with an L2 penalty folded into the gradient.
model = RNNClassifier(
    emb_dim=16,
    hidden_dim=32,
)
optimizer = optim.Adam(model.parameters(), weight_decay=5e-4, lr=1e-3)

In [11]:
for epoch_i in range(N_EPOCHS):
    print(f'=== epoch {epoch_i} ===')

    train_loss, train_accu = train_epoch(model, optimizer, train_x, train_y, BATCH_SIZE)
    test_loss, test_accu = test_epoch(model, test_x, test_y, BATCH_SIZE)

    # One line per metric: label, then the scalar value of the 0-d tensor.
    report = (
        ('train loss:', train_loss),
        (' test loss:', test_loss),
        ('train accu:', train_accu),
        (' test accu:', test_accu),
    )
    for label, metric in report:
        print(label, metric.item())

=== epoch 0 ===


train: 100%|██████████| 628/628 [00:14<00:00, 42.19it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 336.29it/s]


train loss: 1.5864789485931396
 test loss: 1.3765908479690552
train accu: 0.4347963333129883
 test accu: 0.5248511433601379
=== epoch 1 ===


train: 100%|██████████| 628/628 [00:16<00:00, 38.63it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 244.46it/s]


train loss: 1.2607035636901855
 test loss: 1.2409735918045044
train accu: 0.5563130378723145
 test accu: 0.5565475821495056
=== epoch 2 ===


train: 100%|██████████| 628/628 [00:11<00:00, 52.54it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 377.39it/s]


train loss: 1.1126981973648071
 test loss: 1.1395307779312134
train accu: 0.5985270738601685
 test accu: 0.586904764175415
=== epoch 3 ===


train: 100%|██████████| 628/628 [00:10<00:00, 62.04it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 440.43it/s]


train loss: 1.0536291599273682
 test loss: 1.1016048192977905
train accu: 0.6179339289665222
 test accu: 0.6038690209388733
=== epoch 4 ===


train: 100%|██████████| 628/628 [00:14<00:00, 43.76it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 378.98it/s]


train loss: 1.020969033241272
 test loss: 1.0700860023498535
train accu: 0.631667971611023
 test accu: 0.6239582896232605
=== epoch 5 ===


train: 100%|██████████| 628/628 [00:12<00:00, 49.97it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 199.49it/s]


train loss: 0.9873481392860413
 test loss: 1.0331345796585083
train accu: 0.6450205445289612
 test accu: 0.6328868865966797
=== epoch 6 ===


train: 100%|██████████| 628/628 [00:12<00:00, 49.70it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 347.39it/s]


train loss: 0.9386317729949951
 test loss: 0.9812482595443726
train accu: 0.6632828712463379
 test accu: 0.6547619104385376
=== epoch 7 ===


train: 100%|██████████| 628/628 [00:09<00:00, 65.35it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 323.25it/s]


train loss: 0.8842469453811646
 test loss: 0.9354801774024963
train accu: 0.6875
 test accu: 0.6691964268684387
=== epoch 8 ===


train: 100%|██████████| 628/628 [00:09<00:00, 62.93it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 386.58it/s]


train loss: 0.8441709280014038
 test loss: 0.9021945595741272
train accu: 0.7032245397567749
 test accu: 0.6764881014823914
=== epoch 9 ===


train: 100%|██████████| 628/628 [00:09<00:00, 66.56it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 387.95it/s]


train loss: 0.816473126411438
 test loss: 0.8752551078796387
train accu: 0.7142714858055115
 test accu: 0.692559540271759
=== epoch 10 ===


train: 100%|██████████| 628/628 [00:10<00:00, 61.92it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 316.64it/s]


train loss: 0.7941454648971558
 test loss: 0.852144181728363
train accu: 0.7224323153495789
 test accu: 0.6947916746139526
=== epoch 11 ===


train: 100%|██████████| 628/628 [00:23<00:00, 26.26it/s]
 test: 100%|██████████| 70/70 [00:01<00:00, 48.69it/s]


train loss: 0.774700939655304
 test loss: 0.8321852087974548
train accu: 0.7294652462005615
 test accu: 0.7041667103767395
=== epoch 12 ===


train: 100%|██████████| 628/628 [00:15<00:00, 41.32it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 463.30it/s]


train loss: 0.7570507526397705
 test loss: 0.8144172430038452
train accu: 0.735884428024292
 test accu: 0.7122023701667786
=== epoch 13 ===


train: 100%|██████████| 628/628 [00:18<00:00, 33.41it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 398.09it/s]


train loss: 0.740339457988739
 test loss: 0.7978391051292419
train accu: 0.7417064309120178
 test accu: 0.7206845283508301
=== epoch 14 ===


train: 100%|██████████| 628/628 [00:14<00:00, 43.02it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 386.45it/s]


train loss: 0.7239969968795776
 test loss: 0.7818132042884827
train accu: 0.7493199110031128
 test accu: 0.729315459728241
=== epoch 15 ===


train: 100%|██████████| 628/628 [00:08<00:00, 70.89it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 365.84it/s]


train loss: 0.7078695893287659
 test loss: 0.7662736773490906
train accu: 0.7549595236778259
 test accu: 0.7360118627548218
=== epoch 16 ===


train: 100%|██████████| 628/628 [00:21<00:00, 29.02it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 385.74it/s]


train loss: 0.692112147808075
 test loss: 0.7516171336174011
train accu: 0.7606820464134216
 test accu: 0.7406250238418579
=== epoch 17 ===


train: 100%|██████████| 628/628 [00:15<00:00, 41.42it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 562.92it/s]


train loss: 0.6768902540206909
 test loss: 0.7381412386894226
train accu: 0.767267107963562
 test accu: 0.7455357313156128
=== epoch 18 ===


train: 100%|██████████| 628/628 [00:09<00:00, 65.08it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 201.76it/s]


train loss: 0.6623246073722839
 test loss: 0.7257729172706604
train accu: 0.7733379602432251
 test accu: 0.7508928775787354
=== epoch 19 ===


train: 100%|██████████| 628/628 [00:12<00:00, 50.26it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 446.67it/s]


train loss: 0.6483950614929199
 test loss: 0.7144452929496765
train accu: 0.7791102528572083
 test accu: 0.7571428418159485
=== epoch 20 ===


train: 100%|██████████| 628/628 [00:05<00:00, 109.12it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 499.09it/s]


train loss: 0.635031521320343
 test loss: 0.7041754722595215
train accu: 0.7835887670516968
 test accu: 0.7611607313156128
=== epoch 21 ===


train: 100%|██████████| 628/628 [00:10<00:00, 59.41it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 478.00it/s]


train loss: 0.6222336888313293
 test loss: 0.6949217319488525
train accu: 0.7878682613372803
 test accu: 0.7666667103767395
=== epoch 22 ===


train: 100%|██████████| 628/628 [00:09<00:00, 63.67it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 475.33it/s]


train loss: 0.6100534796714783
 test loss: 0.6865898966789246
train accu: 0.7925955653190613
 test accu: 0.7688988447189331
=== epoch 23 ===


train: 100%|██████████| 628/628 [00:06<00:00, 102.74it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 529.17it/s]


train loss: 0.5985538363456726
 test loss: 0.6790516972541809
train accu: 0.7964271306991577
 test accu: 0.7715774178504944
=== epoch 24 ===


train: 100%|██████████| 628/628 [00:06<00:00, 100.45it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 505.16it/s]


train loss: 0.587769091129303
 test loss: 0.6721553206443787
train accu: 0.800457775592804
 test accu: 0.7760416865348816
=== epoch 25 ===


train: 100%|██████████| 628/628 [00:06<00:00, 97.75it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 315.49it/s]


train loss: 0.5776784420013428
 test loss: 0.6657480597496033
train accu: 0.8033439517021179
 test accu: 0.7782738208770752
=== epoch 26 ===


train: 100%|██████████| 628/628 [00:11<00:00, 54.02it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 530.47it/s]


train loss: 0.5682169198989868
 test loss: 0.6597204208374023
train accu: 0.8072750568389893
 test accu: 0.7800595164299011
=== epoch 27 ===


train: 100%|██████████| 628/628 [00:13<00:00, 46.02it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 463.33it/s]


train loss: 0.5593007206916809
 test loss: 0.6540277600288391
train accu: 0.809779703617096
 test accu: 0.7805059552192688
=== epoch 28 ===


train: 100%|██████████| 628/628 [00:10<00:00, 61.33it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 558.88it/s]


train loss: 0.5508477687835693
 test loss: 0.6486741304397583
train accu: 0.8132132291793823
 test accu: 0.7845238447189331
=== epoch 29 ===


train: 100%|██████████| 628/628 [00:05<00:00, 110.17it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 329.34it/s]


train loss: 0.5427957773208618
 test loss: 0.6436883807182312
train accu: 0.8157510757446289
 test accu: 0.7818452715873718
=== epoch 30 ===


train: 100%|██████████| 628/628 [00:06<00:00, 100.96it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 542.03it/s]


train loss: 0.5351032018661499
 test loss: 0.6391057372093201
train accu: 0.8184381723403931
 test accu: 0.7831845283508301
=== epoch 31 ===


train: 100%|██████████| 628/628 [00:10<00:00, 57.53it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 587.12it/s]


train loss: 0.5277458429336548
 test loss: 0.6349490880966187
train accu: 0.8200305700302124
 test accu: 0.7840774059295654
=== epoch 32 ===


train: 100%|██████████| 628/628 [00:05<00:00, 105.68it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 609.45it/s]


train loss: 0.520708441734314
 test loss: 0.6312112212181091
train accu: 0.8234640955924988
 test accu: 0.7831845283508301
=== epoch 33 ===


train: 100%|██████████| 628/628 [00:05<00:00, 107.00it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 580.19it/s]


train loss: 0.513977587223053
 test loss: 0.6278491020202637
train accu: 0.8248076438903809
 test accu: 0.786309540271759
=== epoch 34 ===


train: 100%|██████████| 628/628 [00:06<00:00, 95.59it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 595.25it/s]


train loss: 0.5075368285179138
 test loss: 0.6247977018356323
train accu: 0.8273454308509827
 test accu: 0.786309540271759
=== epoch 35 ===


train: 100%|██████████| 628/628 [00:06<00:00, 99.04it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 483.82it/s]


train loss: 0.5013708472251892
 test loss: 0.6219882965087891
train accu: 0.8297837376594543
 test accu: 0.7867559790611267
=== epoch 36 ===


train: 100%|██████████| 628/628 [00:06<00:00, 95.75it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 279.40it/s]


train loss: 0.495474249124527
 test loss: 0.6193605661392212
train accu: 0.8322718143463135
 test accu: 0.7867559790611267
=== epoch 37 ===


train: 100%|██████████| 628/628 [00:06<00:00, 94.02it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 621.36it/s]


train loss: 0.48984280228614807
 test loss: 0.6168720126152039
train accu: 0.8346603512763977
 test accu: 0.7889881134033203
=== epoch 38 ===


train: 100%|██████████| 628/628 [00:07<00:00, 84.69it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 174.46it/s]


train loss: 0.4844695031642914
 test loss: 0.6145021319389343
train accu: 0.8373972177505493
 test accu: 0.7903273701667786
=== epoch 39 ===


train: 100%|██████████| 628/628 [00:07<00:00, 83.27it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 494.70it/s]


train loss: 0.4793420433998108
 test loss: 0.6122503280639648
train accu: 0.8393378853797913
 test accu: 0.7921131253242493
=== epoch 40 ===


train: 100%|██████████| 628/628 [00:12<00:00, 48.56it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 389.41it/s]


train loss: 0.47444823384284973
 test loss: 0.6101255416870117
train accu: 0.8417264223098755
 test accu: 0.7921131253242493
=== epoch 41 ===


train: 100%|██████████| 628/628 [00:06<00:00, 103.11it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 523.29it/s]


train loss: 0.4697708189487457
 test loss: 0.6081411838531494
train accu: 0.8432689905166626
 test accu: 0.7921131253242493
=== epoch 42 ===


train: 100%|██████████| 628/628 [00:06<00:00, 89.97it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 553.98it/s]


train loss: 0.4652925729751587
 test loss: 0.6063047051429749
train accu: 0.8457570672035217
 test accu: 0.7934523820877075
=== epoch 43 ===


train: 100%|██████████| 628/628 [00:08<00:00, 74.89it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 516.46it/s]


train loss: 0.4609960913658142
 test loss: 0.6046156287193298
train accu: 0.8471006155014038
 test accu: 0.7938988208770752
=== epoch 44 ===


train: 100%|██████████| 628/628 [00:06<00:00, 97.98it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 536.73it/s]


train loss: 0.4568655788898468
 test loss: 0.6030648350715637
train accu: 0.848344624042511
 test accu: 0.7938988208770752
=== epoch 45 ===


train: 100%|██████████| 628/628 [00:05<00:00, 104.78it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 522.98it/s]


train loss: 0.4528876543045044
 test loss: 0.6016384363174438
train accu: 0.849290132522583
 test accu: 0.7952381372451782
=== epoch 46 ===


train: 100%|██████████| 628/628 [00:06<00:00, 91.38it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 270.18it/s]


train loss: 0.44904935359954834
 test loss: 0.6003215312957764
train accu: 0.8506336808204651
 test accu: 0.7943452596664429
=== epoch 47 ===


train: 100%|██████████| 628/628 [00:08<00:00, 71.01it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 522.89it/s]


train loss: 0.445340096950531
 test loss: 0.5990995168685913
train accu: 0.8516786694526672
 test accu: 0.7934523820877075
=== epoch 48 ===


train: 100%|██████████| 628/628 [00:06<00:00, 99.16it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 593.14it/s]


train loss: 0.44174909591674805
 test loss: 0.5979602336883545
train accu: 0.8536193370819092
 test accu: 0.7965773940086365
=== epoch 49 ===


train: 100%|██████████| 628/628 [00:07<00:00, 88.50it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 272.74it/s]

train loss: 0.4382658004760742
 test loss: 0.5968939661979675
train accu: 0.855659544467926
 test accu: 0.7970238327980042





In [12]:
# Spot-check the first held-out names: "<name> <true> <predicted> (<confidence>)".
# eval() is set explicitly rather than relying on test_epoch having run last,
# and no_grad avoids building an autograd graph for pure inference.
N_SHOW = min(20, len(test_x))

model.eval()
with torch.no_grad():
    for i in range(N_SHOW):
        print(decode(test_x[i]), classes[test_y[i]], end=' ')
        one_hot = F.one_hot(test_x[i], N_VOCAB).to(dtype=torch.float32)
        y_hat = torch.softmax(model(one_hot), dim=-1).squeeze()
        print(classes[y_hat.argmax(dim=-1)], f'({y_hat.max().item() * 100:.2f}%)')

Leme BR US (46.62%)
Minatitlan MX MX (59.94%)
Valley US US (86.04%)
Garden City US US (99.88%)
Hampton US US (95.88%)
Dicholi IN IN (91.98%)
Pau Brasil BR IN (76.90%)
Taloda IN IN (79.22%)
Mutum BR BR (38.58%)
Degana IN IN (58.15%)
Sao Goncalo  BR BR (87.28%)
Pileru IN IN (52.74%)
Hunsur IN IN (87.33%)
Vincennes US US (62.03%)
Lawrence US US (99.57%)
Bad Iburg DE DE (81.99%)
Morbi IN IN (61.50%)
Othello US US (85.89%)
Amari IN IN (53.31%)
Lobau DE US (40.05%)


In [15]:
from torch.distributions import Categorical
from torch.distributions.gumbel import Gumbel
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

class AntiEmb:
    """Search for a character sequence that `inner_net` assigns to `desired_c`.

    The learnable parameter `input` is a (1, seq_length, n_vocab) tensor of
    per-position logits. Each `step` draws a hard one-hot sequence with the
    Gumbel-max trick, scores it with `inner_net`, and updates the logits with a
    straight-through estimator: the gradient w.r.t. the hard one-hot is pushed
    back through the matching Gumbel-softmax relaxation.

    Loss = wrong_c_coef * P(other classes) - right_c_coef * P(desired_c).
    """

    def __init__(self,
                 inner_net: nn.Module,
                 desired_c: int,
                 lr: float = 1e-2,
                 right_c_coef: float = 1.0,
                 wrong_c_coef: float = 1.0,
                 input_min: float = 0.0,
                 input_max: float = 1.0,
                 seq_length: 'int | None' = None,
                 n_vocab: 'int | None' = None):
        """
        Args:
            inner_net: classifier taking (1, seq_length, n_vocab) float input.
            desired_c: target class index.
            lr: Adam learning rate for the searched logits.
            right_c_coef / wrong_c_coef: loss weights (see class docstring).
            input_min / input_max: clamp bounds used by `crop()`.
            seq_length / n_vocab: default to the notebook's SEQ_LENGTH / N_VOCAB.
        """
        self.seq_length = SEQ_LENGTH if seq_length is None else seq_length
        self.n_vocab = N_VOCAB if n_vocab is None else n_vocab

        self.input_shape = (1, self.seq_length, self.n_vocab)
        self.input = nn.Parameter(
            data=torch.normal(0, 1, size=self.input_shape),
            requires_grad=True
        )

        self.inner_net = inner_net
        self.anti_optimizer = torch.optim.Adam((self.input,), lr=lr)

        self.desired_c = desired_c
        self.right_c_coef = right_c_coef
        self.wrong_c_coef = wrong_c_coef

        # Bounds for crop(); without them crop() raised AttributeError.
        self.input_min = input_min
        self.input_max = input_max

        self.gumbel = Gumbel(loc=0.0, scale=1.0)
        self.epsilon = 1e-8

    def test(self):
        """Logits of `inner_net` evaluated on the raw parameter tensor, without grad."""
        # NOTE(review): this feeds the unnormalized logits themselves (not a
        # one-hot sample or a softmax), i.e. not what `step` optimizes; confirm intent.
        with torch.no_grad():
            return self.inner_net(self.input)

    def zero_grad(self):
        """Clear gradients of both the inner network and the searched input."""
        self.inner_net.zero_grad()
        self.anti_optimizer.zero_grad()

    def crop(self):
        """Clamp the searched logits in place to [input_min, input_max]."""
        with torch.no_grad():
            self.input.clamp_(self.input_min, self.input_max)

    def step(self):
        """One optimization step.

        Returns:
            (input before the update, inner_net logits, loss). The logits and loss
            come from the one-hot sample drawn from that pre-update input.
        """
        prev_input = self.input.detach().clone()

        # Gumbel-max: argmax(g + log p) samples Categorical(p) at each position.
        probs = torch.softmax(self.input, dim=-1)
        perturbed = self.gumbel.sample(self.input_shape) + torch.log(probs + self.epsilon)

        j_onehot = F.one_hot(perturbed.argmax(dim=-1), self.n_vocab).to(torch.float32)
        j_onehot.requires_grad_(True)
        # Temperature-1 relaxation of the same sample; only used to carry the
        # gradient back to self.input.
        j_continuous = torch.softmax(perturbed, dim=-1)

        output = self.inner_net(j_onehot).squeeze()
        output_soft = torch.softmax(output, dim=-1)

        is_wrong = torch.arange(output_soft.shape[-1], device=output_soft.device) != self.desired_c
        loss = (output_soft[is_wrong].sum() * self.wrong_c_coef
                + output_soft[self.desired_c] * -self.right_c_coef)

        # Gradient w.r.t. the hard one-hot only; inner_net's parameter .grad is
        # never populated, so no optimizer is needed to erase it.
        (grad_onehot,) = torch.autograd.grad(loss, j_onehot)
        # Straight-through: apply the one-hot gradient to the relaxed sample.
        j_continuous.backward(grad_onehot)

        self.anti_optimizer.step()
        self.zero_grad()

        return prev_input, output.detach().clone(), loss.detach().clone()

In [16]:
ANTI_NET_STEPS = 3000
TEST_TRIES = 5
# At least one step and one attempt are needed for `loss` / `fin_inp` below.
assert ANTI_NET_STEPS > 0 and TEST_TRIES > 0

for i in range(N_CLASSES):
    desired_c = i

    print(f'-=-=- step {i} -=-=-')
    print('des class:', classes[desired_c], f'({desired_c})')

    # Best of TEST_TRIES independent searches, judged by each search's final
    # step loss (a single Gumbel sample, so this criterion is noisy).
    fin_loss = float('+inf')
    fin_inp, fin_outp = None, None

    for try_i in range(TEST_TRIES):
        anti_net = AntiEmb(model, desired_c, lr=5e-2)

        for _ in tqdm.tqdm(range(ANTI_NET_STEPS), f'attempt #{try_i}'):
            inp, outp, loss = anti_net.step()

        # `fin_inp is None` keeps the first attempt even if its loss is NaN,
        # so fin_inp is never None after this loop.
        if fin_inp is None or loss.item() < fin_loss:
            fin_loss = loss.item()
            fin_inp = inp
            fin_outp = outp

        print(f'attempt #{try_i} loss:', loss.item())

    print('fin loss:', fin_loss)
    print('fin output (from function):', fin_outp)

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        # Most likely character at each position.
        inp_argmax = fin_inp.argmax(dim=-1)
        print('fin input  (argmax):', decode(inp_argmax.squeeze()))
        print('fin output (argmax):', model(F.one_hot(inp_argmax, N_VOCAB).to(torch.float32)))

        # One random sequence drawn from the learned per-position distributions.
        dist = torch.distributions.Categorical(logits=fin_inp)
        sample = dist.sample()
        entropy = dist.entropy()

        print('fin input  (sample):', decode(sample.squeeze()))
        print('fin output (sample):', model(F.one_hot(sample, N_VOCAB).to(torch.float32)))
        print('fin input entropy:', entropy)

-=-=- step 0 -=-=-
des class: IN (0)


attempt #0: 100%|██████████| 3000/3000 [00:16<00:00, 178.44it/s]


attempt #0 loss: -0.9995288848876953


attempt #1: 100%|██████████| 3000/3000 [00:15<00:00, 195.00it/s]


attempt #1 loss: -0.9999167919158936


attempt #2: 100%|██████████| 3000/3000 [00:17<00:00, 169.42it/s]


attempt #2 loss: -0.9999181628227234


attempt #3: 100%|██████████| 3000/3000 [00:16<00:00, 179.56it/s]


attempt #3 loss: -0.9998440742492676


attempt #4: 100%|██████████| 3000/3000 [00:16<00:00, 181.01it/s]


attempt #4 loss: -0.9996631145477295
fin loss: -0.9999181628227234
fin output (from function): tensor([10.4733, -1.0560, -2.5797, -2.1453, -1.0501, -4.2405, -0.9023, -1.9820])
fin input  (argmax): jKhPJhJPjJrI
fin output (argmax): tensor([[10.0322, -1.0514, -2.4324, -2.5660, -2.0233, -4.1931,  0.5492, -1.6034]],
       grad_fn=<AddmmBackward0>)
fin input  (sample): aQjPJaJPjhhI
fin output (sample): tensor([[ 9.3463, -0.3912, -1.8614, -1.5222, -1.0442, -4.3258, -1.0329, -2.7199]],
       grad_fn=<AddmmBackward0>)
fin input entropy: tensor([[0.8291, 1.0870, 1.1360, 1.4094, 1.5970, 1.8089, 1.5921, 0.2800, 0.4279,
         1.4387, 1.4600, 0.4594]])
-=-=- step 1 -=-=-
des class: US (1)


attempt #0: 100%|██████████| 3000/3000 [00:16<00:00, 182.64it/s]


attempt #0 loss: -0.9997773170471191


attempt #1: 100%|██████████| 3000/3000 [00:16<00:00, 182.61it/s]


attempt #1 loss: -0.999244213104248


attempt #2: 100%|██████████| 3000/3000 [00:20<00:00, 149.92it/s]


attempt #2 loss: -0.9994984865188599


attempt #3: 100%|██████████| 3000/3000 [00:22<00:00, 130.67it/s]


attempt #3 loss: -0.9992470145225525


attempt #4: 100%|██████████| 3000/3000 [00:17<00:00, 172.55it/s]


attempt #4 loss: -0.9998098015785217
fin loss: -0.9998098015785217
fin output (from function): tensor([-1.3387, 10.4837, -3.1515,  0.4188, -3.4715, -6.1400, -1.3712,  0.2512])
fin input  (argmax): WwGwWWCrprw
fin output (argmax): tensor([[-0.9928, 10.5835, -3.1138,  0.5028, -3.4937, -6.5151, -1.7641,  0.1379]],
       grad_fn=<AddmmBackward0>)
fin input  (sample): twLLWlSpprFF
fin output (sample): tensor([[-1.2539,  9.7284, -1.6509,  1.1403, -5.5939, -6.3190, -0.8994,  0.4930]],
       grad_fn=<AddmmBackward0>)
fin input entropy: tensor([[2.2293, 1.7470, 2.4404, 0.8199, 1.8741, 2.7545, 1.5936, 2.0165, 1.3701,
         1.1650, 2.2871, 1.6949]])
-=-=- step 2 -=-=-
des class: BR (2)


attempt #0: 100%|██████████| 3000/3000 [00:17<00:00, 172.85it/s]


attempt #0 loss: -0.9954688549041748


attempt #1: 100%|██████████| 3000/3000 [00:17<00:00, 171.83it/s]


attempt #1 loss: -0.9992583394050598


attempt #2: 100%|██████████| 3000/3000 [00:15<00:00, 192.30it/s]


attempt #2 loss: -0.9992420673370361


attempt #3: 100%|██████████| 3000/3000 [00:15<00:00, 188.08it/s]


attempt #3 loss: -0.9995051622390747


attempt #4: 100%|██████████| 3000/3000 [00:20<00:00, 149.36it/s]


attempt #4 loss: -0.9992285370826721
fin loss: -0.9995051622390747
fin output (from function): tensor([ 0.5276, -0.4830, 10.6987, -2.0878, -8.1769, -2.0951, -2.7906,  2.1214])
fin input  (argmax): IIIIUIIIUAAA
fin output (argmax): tensor([[-0.2248, -0.6680, 10.2627, -2.6923, -6.9744, -0.0745, -3.6271,  2.3920]],
       grad_fn=<AddmmBackward0>)
fin input  (sample): IIIIOIIIIAOO
fin output (sample): tensor([[-2.1792, -1.0520,  9.8637, -2.6510, -5.7125,  1.6895, -2.3658,  2.3584]],
       grad_fn=<AddmmBackward0>)
fin input entropy: tensor([[0.6862, 0.5750, 0.7694, 0.1913, 2.0524, 0.3337, 0.8064, 0.8046, 1.6109,
         2.1435, 1.2905, 1.2316]])
-=-=- step 3 -=-=-
des class: DE (3)


attempt #0: 100%|██████████| 3000/3000 [00:19<00:00, 155.38it/s]


attempt #0 loss: -0.9986081719398499


attempt #1: 100%|██████████| 3000/3000 [00:18<00:00, 160.45it/s]


attempt #1 loss: -0.9990404844284058


attempt #2: 100%|██████████| 3000/3000 [00:22<00:00, 136.11it/s]


attempt #2 loss: -0.9987223148345947


attempt #3: 100%|██████████| 3000/3000 [00:20<00:00, 145.26it/s]


attempt #3 loss: -0.9973835349082947


attempt #4: 100%|██████████| 3000/3000 [00:20<00:00, 149.28it/s]


attempt #4 loss: -0.9800529479980469
fin loss: -0.9990404844284058
fin output (from function): tensor([-3.4690,  1.3772, -2.5431, 10.3155, -0.8053, -2.7635,  2.2917, -4.0674])
fin input  (argmax): bbbbeubblgbZ
fin output (argmax): tensor([[-3.3880,  1.4144, -2.4541, 10.3080, -0.9940, -2.8041,  2.2735, -4.0898]],
       grad_fn=<AddmmBackward0>)
fin input  (sample): bbbbeebblgbZ
fin output (sample): tensor([[-3.5018,  1.3824, -2.3078, 10.2506, -0.9564, -2.7443,  2.2006, -4.0227]],
       grad_fn=<AddmmBackward0>)
fin input entropy: tensor([[0.6048, 0.2857, 0.9976, 0.4507, 1.5848, 2.4089, 1.1834, 0.2953, 0.2607,
         0.4647, 0.9449, 0.5791]])
-=-=- step 4 -=-=-
des class: CN (4)


attempt #0: 100%|██████████| 3000/3000 [00:19<00:00, 152.78it/s]


attempt #0 loss: -0.9978901147842407


attempt #1: 100%|██████████| 3000/3000 [00:20<00:00, 149.05it/s]


attempt #1 loss: -0.9988809823989868


attempt #2: 100%|██████████| 3000/3000 [00:16<00:00, 184.34it/s]


attempt #2 loss: -0.9990564584732056


attempt #3: 100%|██████████| 3000/3000 [00:15<00:00, 189.38it/s]


attempt #3 loss: -0.9984514713287354


attempt #4: 100%|██████████| 3000/3000 [00:17<00:00, 171.53it/s]


attempt #4 loss: -0.9994179010391235
fin loss: -0.9994179010391235
fin output (from function): tensor([ 0.3742,  1.0522, -8.6425,  0.8195, 10.2873,  0.4455, -2.1112, -1.2828])
fin input  (argmax): XQgwXiiYQYYw
fin output (argmax): tensor([[ 0.0846,  0.9909, -8.1058,  0.4177, 10.3619,  0.6887, -2.6004, -0.8179]],
       grad_fn=<AddmmBackward0>)
fin input  (sample): iQgwXiQXQYYw
fin output (sample): tensor([[ 0.9445,  1.1064, -8.2214,  0.7922, 10.0543,  0.1949, -2.7997, -1.3981]],
       grad_fn=<AddmmBackward0>)
fin input entropy: tensor([[1.3054, 0.4360, 2.3853, 0.7826, 0.4866, 0.5872, 1.2036, 1.2403, 1.2501,
         0.3110, 0.3257, 1.8892]])
-=-=- step 5 -=-=-
des class: JP (5)


attempt #0: 100%|██████████| 3000/3000 [00:18<00:00, 160.19it/s]


attempt #0 loss: -0.9929953217506409


attempt #1: 100%|██████████| 3000/3000 [00:18<00:00, 162.82it/s]


attempt #1 loss: -0.9952360391616821


attempt #2: 100%|██████████| 3000/3000 [00:17<00:00, 167.86it/s]


attempt #2 loss: -0.9898184537887573


attempt #3: 100%|██████████| 3000/3000 [00:18<00:00, 165.01it/s]


attempt #3 loss: -0.9792752861976624


attempt #4: 100%|██████████| 3000/3000 [00:17<00:00, 169.80it/s]


attempt #4 loss: -0.9662270545959473
fin loss: -0.9952360391616821
fin output (from function): tensor([-0.2252, -2.4073,  0.5624, -1.6663,  0.7484,  7.9218,  0.4685, -3.2506])
fin input  (argmax): okoOe-ii'YuK
fin output (argmax): tensor([[-0.2252, -2.4073,  0.5624, -1.6663,  0.7484,  7.9218,  0.4685, -3.2506]],
       grad_fn=<AddmmBackward0>)
fin input  (sample): oOoee-i''YuK
fin output (sample): tensor([[-1.0284, -2.1288,  1.0187, -1.6193,  0.6812,  7.6448,  0.2284, -2.5987]],
       grad_fn=<AddmmBackward0>)
fin input entropy: tensor([[0.7193, 0.1217, 0.1255, 1.2787, 0.7753, 0.1393, 1.4559, 0.7558, 0.1803,
         0.0262, 0.0406, 0.1777]])
-=-=- step 6 -=-=-
des class: RU (6)


attempt #0: 100%|██████████| 3000/3000 [00:15<00:00, 198.22it/s]


attempt #0 loss: -0.9994046092033386


attempt #1: 100%|██████████| 3000/3000 [00:15<00:00, 196.39it/s]


attempt #1 loss: -0.9991394281387329


attempt #2: 100%|██████████| 3000/3000 [00:18<00:00, 166.61it/s]


attempt #2 loss: -0.9993845224380493


attempt #3: 100%|██████████| 3000/3000 [00:16<00:00, 185.66it/s]


attempt #3 loss: -0.9993496537208557


attempt #4: 100%|██████████| 3000/3000 [00:17<00:00, 175.48it/s]


attempt #4 loss: -0.9995476603507996
fin loss: -0.9995476603507996
fin output (from function): tensor([ 0.3014, -1.1302, -2.1875,  0.9285, -4.7087, -0.0523, 10.0750, -2.3627])
fin input  (argmax): ykvVvz-vVvVR
fin output (argmax): tensor([[ 0.0447, -0.8562, -1.9420,  0.9086, -4.8968, -0.0721,  9.9002, -2.3093]],
       grad_fn=<AddmmBackward0>)
fin input  (sample): yUkInVVvVvVR
fin output (sample): tensor([[ 0.4198, -0.5406, -2.1704,  0.3477, -4.8682, -0.3831,  9.7866, -2.1415]],
       grad_fn=<AddmmBackward0>)
fin input entropy: tensor([[2.3115, 1.9536, 2.2852, 2.5473, 0.9685, 2.3031, 1.9851, 0.4072, 0.7650,
         1.1792, 0.0874, 0.5176]])
-=-=- step 7 -=-=-
des class: MX (7)


attempt #0: 100%|██████████| 3000/3000 [00:17<00:00, 174.05it/s]


attempt #0 loss: -0.9961458444595337


attempt #1: 100%|██████████| 3000/3000 [00:17<00:00, 172.51it/s]


attempt #1 loss: -0.9953432083129883


attempt #2: 100%|██████████| 3000/3000 [00:15<00:00, 188.50it/s]


attempt #2 loss: -0.9951199889183044


attempt #3: 100%|██████████| 3000/3000 [00:14<00:00, 208.87it/s]


attempt #3 loss: -0.9959519505500793


attempt #4: 100%|██████████| 3000/3000 [00:15<00:00, 199.25it/s]


attempt #4 loss: -0.9958794713020325
fin loss: -0.9961458444595337
fin output (from function): tensor([-0.8930,  2.9028,  0.8547, -4.5569,  1.3687, -5.6742, -4.1996,  9.4665])
fin input  (argmax):    q qqqqqqq
fin output (argmax): tensor([[-0.7942,  3.1887,  1.0007, -4.7220,  0.8489, -5.7898, -4.1859,  9.5906]],
       grad_fn=<AddmmBackward0>)
fin input  (sample):  qqqZqUqqqqq
fin output (sample): tensor([[-0.8638,  2.8850,  0.7848, -4.4925,  1.5249, -5.7696, -4.2774,  9.4701]],
       grad_fn=<AddmmBackward0>)
fin input entropy: tensor([[1.0411, 0.9175, 1.1635, 1.2277, 1.0803, 0.3053, 1.6267, 0.3843, 0.6132,
         0.8116, 0.4747, 1.1735]])
