In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random
import math
import tqdm
import numpy as np
import pandas as pd

# Fixed seeds for reproducibility: torch's global RNG (weight init, sampling in
# later cells) and a dedicated `random.Random` instance used for the data shuffle.
TORCH_SEED = 192837
torch.manual_seed(TORCH_SEED)
rand = random.Random(TORCH_SEED + 1)

In [2]:
# Columns of worldcities.csv that hold the (ASCII) city name and the ISO-2 country code.
CITY_COLUMN, COUNTRY_COLUMN = 'city_ascii', 'iso2'

In [3]:
df_all = pd.read_csv('data/worldcities.csv')

# The label set is a fixed, hand-picked list of ISO-2 country codes. (An earlier,
# disabled version derived it from the most frequent countries via
# df_all.groupby(COUNTRY_COLUMN).size().sort_values(ascending=False).)
classes = ['IN', 'US', 'BR', 'DE', 'CN', 'JP', 'RU', 'MX']
N_CLASSES = len(classes)

print(classes)
class_to_id = {iso2: class_id for class_id, iso2 in enumerate(classes)}

# Keep only rows from the selected countries that actually have a city name.
df_all = df_all[df_all[COUNTRY_COLUMN].isin(classes) & df_all[CITY_COLUMN].notna()]
df_all.head(10)

['IN', 'US', 'BR', 'DE', 'CN', 'JP', 'RU', 'MX']


Unnamed: 0,city,city_ascii,lat,lng,country,iso2,iso3,admin_name,capital,population,id
0,Tokyo,Tokyo,35.687,139.7495,Japan,JP,JPN,Tōkyō,primary,37785000.0,1392685764
2,Delhi,Delhi,28.61,77.23,India,IN,IND,Delhi,admin,32226000.0,1356872604
3,Guangzhou,Guangzhou,23.13,113.26,China,CN,CHN,Guangdong,admin,26940000.0,1156237133
4,Mumbai,Mumbai,19.0761,72.8775,India,IN,IND,Mahārāshtra,admin,24973000.0,1356226629
6,Shanghai,Shanghai,31.2286,121.4747,China,CN,CHN,Shanghai,admin,24073000.0,1156073548
7,São Paulo,Sao Paulo,-23.5504,-46.6339,Brazil,BR,BRA,São Paulo,admin,23086000.0,1076532519
9,Mexico City,Mexico City,19.4333,-99.1333,Mexico,MX,MEX,Ciudad de México,primary,21804000.0,1484247881
10,Kolkāta,Kolkata,22.5675,88.37,India,IN,IND,West Bengal,admin,21747000.0,1356060520
14,New York,New York,40.6943,-73.9249,United States,US,USA,New York,,18832416.0,1840034016
15,Beijing,Beijing,39.9067,116.3975,China,CN,CHN,Beijing,primary,18522000.0,1156228865


In [4]:
# Character vocabulary: every distinct character that occurs in any city name,
# sorted, with id 0 reserved for the padding token.
character_set = set(''.join(df_all[CITY_COLUMN]))

PAD_TOKEN = 0
tokens = ['<PAD>', *sorted(character_set)]
token_to_id = {token: token_id for token_id, token in enumerate(tokens)}
N_VOCAB = len(tokens)

''.join(tokens)

"<PAD> '()-./2ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

In [5]:
def encode(city: str, max_length: int):
    """Map `city` to a list of exactly `max_length` token ids.

    Names longer than `max_length` are truncated; shorter ones are right-padded
    with PAD_TOKEN. Every character must be present in `token_to_id`
    (a KeyError is raised otherwise). Returns [] when max_length <= 0.
    """
    # Clamp at 0 so a negative length yields [] instead of a negative slice.
    encoded = [token_to_id[ch] for ch in city[:max(max_length, 0)]]
    return encoded + [PAD_TOKEN] * (max_length - len(encoded))

def decode(city_enc):
    """Inverse of `encode`: concatenate the characters of all non-PAD ids.

    Accepts any iterable of integer ids, including 1-D tensors / arrays (rows of
    the encoded dataset are passed directly); elements are coerced with int().
    """
    return ''.join(tokens[tk] for tk in map(int, city_enc) if tk != PAD_TOKEN)

In [6]:
# Fixed input length = this quantile of city-name lengths (0.8 -> 80th percentile);
# `encode` truncates longer names and pads shorter ones to SEQ_LENGTH.
SEQ_LENGTH_QUANTILE = 0.8
city_lengths = np.array([len(city) for city in df_all[CITY_COLUMN]])
city_length_quantile = np.quantile(city_lengths, SEQ_LENGTH_QUANTILE)

SEQ_LENGTH = int(city_length_quantile)
SEQ_LENGTH

12

In [7]:
data_list_x = [encode(city, SEQ_LENGTH) for city in df_all[CITY_COLUMN]]
data_list_y = [class_to_id[iso2] for iso2 in df_all[COUNTRY_COLUMN]]
data_list = list(zip(data_list_x, data_list_y))
rand.shuffle(data_list)

# F.one_hot and F.cross_entropy require int64 indices. `np.long` is missing in
# NumPy 1.24-1.26 and is the platform C long (int32 on Windows) in NumPy 2,
# so request int64 explicitly.
data_x = torch.tensor(np.array([x for x, _ in data_list], dtype=np.int64))
data_y = torch.tensor(np.array([y for _, y in data_list], dtype=np.int64))

for i in range(5):
    print(decode(data_x[i]), classes[data_y[i]])

TRAIN_TEST_SPLIT = 0.9
n_train = round(data_x.shape[0] * TRAIN_TEST_SPLIT)
n_test = data_x.shape[0] - n_train
train_x, train_y = data_x[:n_train, :], data_y[:n_train]
test_x,  test_y  = data_x[n_train:, :], data_y[n_train:]
assert train_x.shape[0] == n_train and test_x.shape[0] == n_test

print(data_x.shape[0], n_train, n_test)

Gohadi IN
Kandra IN
Chuanliaocun CN
Aracoiaba BR
Miryal IN
22320 20088 2232


In [8]:
class SimpleClassifier(nn.Module):
    """MLP that maps a one-hot encoded city name to country logits.

    Pipeline: a bias-free linear character embedding applied at every position,
    the per-position embeddings flattened, a `foot` projection, `hidden_layers`
    hidden->hidden layers (ReLU after each), and a linear `head` producing
    unnormalised class logits.

    Args:
        emb_dim: size of each character embedding.
        hidden_dim: width of the hidden layers.
        hidden_layers: number of hidden->hidden layers in `body`.
        n_vocab: one-hot width; defaults to the notebook-level N_VOCAB.
        seq_length: characters per input; defaults to the notebook-level SEQ_LENGTH.
        n_classes: number of output logits; defaults to the notebook-level N_CLASSES.
    """

    def __init__(self,
                 emb_dim: int,
                 hidden_dim: int,
                 hidden_layers: int,
                 n_vocab: int = None,
                 seq_length: int = None,
                 n_classes: int = None):
        super().__init__()

        # Fall back to the notebook constants so existing call sites are unchanged.
        n_vocab = N_VOCAB if n_vocab is None else n_vocab
        seq_length = SEQ_LENGTH if seq_length is None else seq_length
        n_classes = N_CLASSES if n_classes is None else n_classes

        # A bias-free Linear on a one-hot vector acts as an embedding lookup, but
        # also accepts soft / relaxed (non one-hot) inputs.
        self.emb = nn.Linear(n_vocab, emb_dim, bias=False)
        self.foot = nn.Linear(emb_dim * seq_length, hidden_dim)
        self.body = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(hidden_layers)])
        self.head = nn.Linear(hidden_dim, n_classes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor):
        """Return (B, n_classes) logits for x of shape (B, seq_length, n_vocab).

        A 2-D (seq_length, n_vocab) input is treated as a batch of one, so the
        result then has shape (1, n_classes).
        """
        if x.dim() == 2:  # unbatched input
            x = x.unsqueeze(0)
        batch_size = x.shape[0]

        x = self.emb(x).reshape((batch_size, -1))
        x = self.relu(self.foot(x))
        for layer in self.body:
            x = self.relu(layer(x))
        return self.head(x)

In [9]:
def train_epoch(model: nn.Module,
                optimizer: optim.Optimizer,
                dataset_x: torch.Tensor,
                dataset_y: torch.Tensor,
                batch_size: int):
    model.train()
    
    loss_sum = 0
    accu_sum = 0
    n_batches = math.ceil(dataset_x.shape[0] / batch_size)

    for i in tqdm.tqdm(range(n_batches), 'train'):
        x, y = dataset_x[i * batch_size: (i+1) * batch_size, :], dataset_y[i * batch_size: (i+1) * batch_size]
        x = F.one_hot(x, N_VOCAB).to(dtype=torch.float32)
        y_hat = model.forward(x)

        loss = F.cross_entropy(y_hat, y)
        loss_sum += loss.detach().clone()

        with torch.no_grad():
            accu_sum += torch.sum(y_hat.argmax(dim=-1) == y) / x.shape[0]

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    return loss_sum / n_batches, accu_sum / n_batches

def test_epoch(model: nn.Module, dataset_x: torch.Tensor, dataset_y: torch.Tensor, batch_size):
    model.eval()
    
    with torch.no_grad():
        loss_sum = 0
        accu_sum = 0
        n_batches = math.ceil(dataset_x.shape[0] / batch_size)

        for i in tqdm.tqdm(range(n_batches), ' test'):
            x, y = dataset_x[i * batch_size: (i+1) * batch_size, :], dataset_y[i * batch_size: (i+1) * batch_size]
            x = F.one_hot(x, N_VOCAB).to(dtype=torch.float32)
            y_hat = model.forward(x)

            loss = F.cross_entropy(y_hat, y)
            loss_sum += loss
            accu_sum += torch.sum(y_hat.argmax(dim=-1) == y) / x.shape[0]
        
        return loss_sum / n_batches, accu_sum / n_batches

In [10]:
# Training schedule.
N_EPOCHS = 50
BATCH_SIZE = 32

# Classifier and its optimiser (Adam with L2 weight decay).
model = SimpleClassifier(emb_dim=16, hidden_dim=32, hidden_layers=6)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

In [11]:
for epoch_i in range(N_EPOCHS):
    print(f'=== epoch {epoch_i} ===')

    train_loss, train_accu = train_epoch(model, optimizer, train_x, train_y, BATCH_SIZE)
    test_loss, test_accu = test_epoch(model, test_x, test_y, BATCH_SIZE)

    # Report losses first, then accuracies; train before test in each pair.
    for label, metric in (('train loss:', train_loss),
                          (' test loss:', test_loss),
                          ('train accu:', train_accu),
                          (' test accu:', test_accu)):
        print(label, metric.item())

=== epoch 0 ===


train: 100%|██████████| 628/628 [00:10<00:00, 59.88it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 358.56it/s]


train loss: 1.6578563451766968
 test loss: 1.5767815113067627
train accu: 0.4182092547416687
 test accu: 0.47083330154418945
=== epoch 1 ===


train: 100%|██████████| 628/628 [00:14<00:00, 42.27it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 227.22it/s]


train loss: 1.4602901935577393
 test loss: 1.4780000448226929
train accu: 0.4997677505016327
 test accu: 0.48020830750465393
=== epoch 2 ===


train: 100%|██████████| 628/628 [00:12<00:00, 51.50it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 427.98it/s]


train loss: 1.3626585006713867
 test loss: 1.3984136581420898
train accu: 0.5138999223709106
 test accu: 0.5166667103767395
=== epoch 3 ===


train: 100%|██████████| 628/628 [00:08<00:00, 73.84it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 509.97it/s]


train loss: 1.248785376548767
 test loss: 1.3033498525619507
train accu: 0.5594313740730286
 test accu: 0.5450893044471741
=== epoch 4 ===


train: 100%|██████████| 628/628 [00:07<00:00, 89.64it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 520.48it/s]


train loss: 1.175220251083374
 test loss: 1.2416893243789673
train accu: 0.5893378257751465
 test accu: 0.5778273940086365
=== epoch 5 ===


train: 100%|██████████| 628/628 [00:11<00:00, 53.23it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 482.34it/s]


train loss: 1.1168216466903687
 test loss: 1.1951146125793457
train accu: 0.6181495189666748
 test accu: 0.588690459728241
=== epoch 6 ===


train: 100%|██████████| 628/628 [00:11<00:00, 53.68it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 429.39it/s]


train loss: 1.0657399892807007
 test loss: 1.1564881801605225
train accu: 0.6365943551063538
 test accu: 0.5995535850524902
=== epoch 7 ===


train: 100%|██████████| 628/628 [00:12<00:00, 50.62it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 802.35it/s]


train loss: 1.020544409751892
 test loss: 1.1350778341293335
train accu: 0.6509255766868591
 test accu: 0.6117559671401978
=== epoch 8 ===


train: 100%|██████████| 628/628 [00:10<00:00, 60.33it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 942.49it/s]


train loss: 0.9810142517089844
 test loss: 1.1255720853805542
train accu: 0.6611431837081909
 test accu: 0.6200892925262451
=== epoch 9 ===


train: 100%|██████████| 628/628 [00:10<00:00, 58.75it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 715.13it/s]


train loss: 0.9475372433662415
 test loss: 1.1095243692398071
train accu: 0.6706973314285278
 test accu: 0.6255952715873718
=== epoch 10 ===


train: 100%|██████████| 628/628 [00:10<00:00, 62.48it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 359.95it/s]


train loss: 0.9188485145568848
 test loss: 1.087493896484375
train accu: 0.6784102916717529
 test accu: 0.6276785731315613
=== epoch 11 ===


train: 100%|██████████| 628/628 [00:10<00:00, 59.08it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 796.19it/s]


train loss: 0.8925533294677734
 test loss: 1.0873417854309082
train accu: 0.6862227916717529
 test accu: 0.6316964030265808
=== epoch 12 ===


train: 100%|██████████| 628/628 [00:16<00:00, 37.17it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1134.98it/s]


train loss: 0.8697267770767212
 test loss: 1.0799897909164429
train accu: 0.6960754990577698
 test accu: 0.6407738327980042
=== epoch 13 ===


train: 100%|██████████| 628/628 [00:18<00:00, 34.04it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1061.44it/s]


train loss: 0.8506645560264587
 test loss: 1.0792841911315918
train accu: 0.7009189128875732
 test accu: 0.6434524059295654
=== epoch 14 ===


train: 100%|██████████| 628/628 [00:12<00:00, 51.53it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 801.44it/s]


train loss: 0.82981938123703
 test loss: 1.0751893520355225
train accu: 0.7073215246200562
 test accu: 0.6461309790611267
=== epoch 15 ===


train: 100%|██████████| 628/628 [00:21<00:00, 29.43it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 601.82it/s]


train loss: 0.8076635003089905
 test loss: 1.0731360912322998
train accu: 0.7172240018844604
 test accu: 0.6468750238418579
=== epoch 16 ===


train: 100%|██████████| 628/628 [00:10<00:00, 59.99it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 977.33it/s]


train loss: 0.7915974259376526
 test loss: 1.0685449838638306
train accu: 0.7204086780548096
 test accu: 0.6495535969734192
=== epoch 17 ===


train: 100%|██████████| 628/628 [00:15<00:00, 40.49it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 376.24it/s]


train loss: 0.7766727209091187
 test loss: 1.058601975440979
train accu: 0.7252355217933655
 test accu: 0.653124988079071
=== epoch 18 ===


train: 100%|██████████| 628/628 [00:13<00:00, 46.28it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 852.32it/s]


train loss: 0.7663139700889587
 test loss: 1.0641287565231323
train accu: 0.7284865975379944
 test accu: 0.6529761552810669
=== epoch 19 ===


train: 100%|██████████| 628/628 [00:13<00:00, 46.02it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 826.45it/s]


train loss: 0.7500232458114624
 test loss: 1.0619685649871826
train accu: 0.7326996922492981
 test accu: 0.6571428775787354
=== epoch 20 ===


train: 100%|██████████| 628/628 [00:04<00:00, 134.47it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 262.79it/s]


train loss: 0.7362562417984009
 test loss: 1.0648199319839478
train accu: 0.736282467842102
 test accu: 0.664434552192688
=== epoch 21 ===


train: 100%|██████████| 628/628 [00:08<00:00, 75.04it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 1150.41it/s]


train loss: 0.7251154184341431
 test loss: 1.0691219568252563
train accu: 0.7422704696655273
 test accu: 0.6581845283508301
=== epoch 22 ===


train: 100%|██████████| 628/628 [00:09<00:00, 65.28it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 1026.36it/s]


train loss: 0.7097309231758118
 test loss: 1.0774987936019897
train accu: 0.7478602528572083
 test accu: 0.6595238447189331
=== epoch 23 ===


train: 100%|██████████| 628/628 [00:03<00:00, 168.02it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1119.47it/s]


train loss: 0.6976060271263123
 test loss: 1.0894118547439575
train accu: 0.7533174157142639
 test accu: 0.6627976298332214
=== epoch 24 ===


train: 100%|██████████| 628/628 [00:05<00:00, 123.28it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1140.11it/s]


train loss: 0.6855646371841431
 test loss: 1.097861647605896
train accu: 0.7583930492401123
 test accu: 0.6619047522544861
=== epoch 25 ===


train: 100%|██████████| 628/628 [00:11<00:00, 55.24it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 710.78it/s]


train loss: 0.6738675236701965
 test loss: 1.0894370079040527
train accu: 0.7633857727050781
 test accu: 0.666815459728241
=== epoch 26 ===


train: 100%|██████████| 628/628 [00:04<00:00, 138.86it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1253.02it/s]


train loss: 0.6593242287635803
 test loss: 1.1012741327285767
train accu: 0.7665041089057922
 test accu: 0.6659225821495056
=== epoch 27 ===


train: 100%|██████████| 628/628 [00:04<00:00, 148.92it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1241.98it/s]


train loss: 0.6483054757118225
 test loss: 1.085457682609558
train accu: 0.7747644782066345
 test accu: 0.6760416626930237
=== epoch 28 ===


train: 100%|██████████| 628/628 [00:03<00:00, 157.58it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1365.97it/s]


train loss: 0.6360803842544556
 test loss: 1.1364537477493286
train accu: 0.7796410918235779
 test accu: 0.6605654358863831
=== epoch 29 ===


train: 100%|██████████| 628/628 [00:04<00:00, 150.04it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 850.34it/s]


train loss: 0.6252405047416687
 test loss: 1.088197946548462
train accu: 0.7833731770515442
 test accu: 0.6840773820877075
=== epoch 30 ===


train: 100%|██████████| 628/628 [00:04<00:00, 151.86it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1340.20it/s]


train loss: 0.6116631031036377
 test loss: 1.1000868082046509
train accu: 0.7901738286018372
 test accu: 0.6788690090179443
=== epoch 31 ===


train: 100%|██████████| 628/628 [00:06<00:00, 99.13it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 751.49it/s]


train loss: 0.6007989048957825
 test loss: 1.088274598121643
train accu: 0.7950670123100281
 test accu: 0.6839285492897034
=== epoch 32 ===


train: 100%|██████████| 628/628 [00:08<00:00, 71.68it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 1362.89it/s]


train loss: 0.5870400667190552
 test loss: 1.1159673929214478
train accu: 0.8004246354103088
 test accu: 0.676934540271759
=== epoch 33 ===


train: 100%|██████████| 628/628 [00:10<00:00, 59.39it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 364.44it/s]


train loss: 0.5811120867729187
 test loss: 1.1028326749801636
train accu: 0.8055997490882874
 test accu: 0.6854166984558105
=== epoch 34 ===


train: 100%|██████████| 628/628 [00:08<00:00, 75.88it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 199.24it/s]


train loss: 0.5718633532524109
 test loss: 1.1411614418029785
train accu: 0.8081044554710388
 test accu: 0.6772321462631226
=== epoch 35 ===


train: 100%|██████████| 628/628 [00:05<00:00, 112.31it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 957.56it/s]


train loss: 0.5621746182441711
 test loss: 1.1047070026397705
train accu: 0.812699019908905
 test accu: 0.6845238208770752
=== epoch 36 ===


train: 100%|██████████| 628/628 [00:03<00:00, 174.07it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1174.77it/s]


train loss: 0.5557459592819214
 test loss: 1.1406359672546387
train accu: 0.8158173561096191
 test accu: 0.6802083253860474
=== epoch 37 ===


train: 100%|██████████| 628/628 [00:04<00:00, 156.39it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1196.85it/s]


train loss: 0.5535765886306763
 test loss: 1.0960819721221924
train accu: 0.8151705265045166
 test accu: 0.6918154358863831
=== epoch 38 ===


train: 100%|██████████| 628/628 [00:06<00:00, 92.56it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 458.09it/s]


train loss: 0.5442521572113037
 test loss: 1.0674926042556763
train accu: 0.8178077936172485
 test accu: 0.6979166865348816
=== epoch 39 ===


train: 100%|██████████| 628/628 [00:06<00:00, 96.51it/s] 
 test: 100%|██████████| 70/70 [00:00<00:00, 783.27it/s]


train loss: 0.5368861556053162
 test loss: 1.1059092283248901
train accu: 0.8224688768386841
 test accu: 0.6888392567634583
=== epoch 40 ===


train: 100%|██████████| 628/628 [00:03<00:00, 171.88it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1319.94it/s]


train loss: 0.531315803527832
 test loss: 1.0758333206176758
train accu: 0.8225517272949219
 test accu: 0.6882440447807312
=== epoch 41 ===


train: 100%|██████████| 628/628 [00:04<00:00, 152.99it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1614.18it/s]


train loss: 0.5235191583633423
 test loss: 1.0838522911071777
train accu: 0.8268643617630005
 test accu: 0.7025297284126282
=== epoch 42 ===


train: 100%|██████████| 628/628 [00:03<00:00, 170.07it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1140.85it/s]


train loss: 0.5195631980895996
 test loss: 1.082788348197937
train accu: 0.8290704488754272
 test accu: 0.6869047284126282
=== epoch 43 ===


train: 100%|██████████| 628/628 [00:04<00:00, 139.30it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1337.45it/s]


train loss: 0.5157086849212646
 test loss: 1.0552030801773071
train accu: 0.8313594460487366
 test accu: 0.7040178775787354
=== epoch 44 ===


train: 100%|██████████| 628/628 [00:04<00:00, 150.60it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 732.88it/s]


train loss: 0.5099453330039978
 test loss: 1.0832167863845825
train accu: 0.8303808569908142
 test accu: 0.6937500238418579
=== epoch 45 ===


train: 100%|██████████| 628/628 [00:04<00:00, 146.14it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1381.30it/s]


train loss: 0.5060484409332275
 test loss: 1.0756645202636719
train accu: 0.8335323333740234
 test accu: 0.6921131014823914
=== epoch 46 ===


train: 100%|██████████| 628/628 [00:04<00:00, 155.52it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 306.62it/s]


train loss: 0.5022204518318176
 test loss: 1.0893667936325073
train accu: 0.8341294527053833
 test accu: 0.6895833015441895
=== epoch 47 ===


train: 100%|██████████| 628/628 [00:04<00:00, 139.47it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1662.88it/s]


train loss: 0.4990904927253723
 test loss: 1.0830504894256592
train accu: 0.8353403806686401
 test accu: 0.6889880895614624
=== epoch 48 ===


train: 100%|██████████| 628/628 [00:05<00:00, 122.23it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 1021.13it/s]


train loss: 0.4950273334980011
 test loss: 1.0912714004516602
train accu: 0.8378118276596069
 test accu: 0.6913690567016602
=== epoch 49 ===


train: 100%|██████████| 628/628 [00:03<00:00, 175.59it/s]
 test: 100%|██████████| 70/70 [00:00<00:00, 923.34it/s]

train loss: 0.4871978759765625
 test loss: 1.0959736108779907
train accu: 0.8391056060791016
 test accu: 0.6895833015441895





In [12]:
# Show the classifier's top prediction and its probability for a few test cities.
N_PREVIEW = 20

model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    for i in range(min(N_PREVIEW, test_x.shape[0])):
        city_enc, label = test_x[i], test_y[i]
        print(decode(city_enc), classes[label], end=' ')
        y_hat = torch.softmax(model(F.one_hot(city_enc, N_VOCAB).to(dtype=torch.float32)), dim=-1).squeeze()
        print(classes[y_hat.argmax(dim=-1)], f'({y_hat.max().item() * 100:.2f}%)')

Leme BR BR (38.76%)
Minatitlan MX MX (63.12%)
Valley US US (92.83%)
Garden City US US (99.52%)
Hampton US US (77.23%)
Dicholi IN IN (72.40%)
Pau Brasil BR BR (97.84%)
Taloda IN US (60.19%)
Mutum BR US (25.32%)
Degana IN IN (58.53%)
Sao Goncalo  BR BR (99.01%)
Pileru IN IN (66.45%)
Hunsur IN DE (27.55%)
Vincennes US DE (81.73%)
Lawrence US US (96.06%)
Bad Iburg DE US (46.03%)
Morbi IN IN (46.55%)
Othello US DE (47.44%)
Amari IN IN (77.16%)
Lobau DE US (40.57%)


In [13]:
from torch.distributions import Categorical
from torch.distributions.gumbel import Gumbel
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

class AntiEmb():
    """Search for a character sequence that `inner_net` assigns to `desired_c`.

    The search variable `input` holds per-position character logits of shape
    (1, seq_length, n_vocab). Each `step` draws a hard one-hot sequence with the
    Gumbel-max trick, scores it with `inner_net`, and updates the logits with a
    straight-through estimator: the gradient w.r.t. the hard one-hot is routed
    through the temperature-1 Gumbel-softmax relaxation of the same sample.
    `inner_net`'s parameters are never updated and their `.grad` is not written.

    Args:
        inner_net: classifier mapping a (1, seq_length, n_vocab) tensor to class logits.
        desired_c: index of the class to maximise.
        lr: Adam learning rate for the character logits.
        right_c_coef: weight of the (negated) desired-class probability in the loss.
        wrong_c_coef: weight of the largest other-class probability in the loss.
        seq_length: sequence length; defaults to the notebook-level SEQ_LENGTH.
        n_vocab: vocabulary size; defaults to the notebook-level N_VOCAB.
        input_min, input_max: optional bounds applied by `crop` (None = unbounded).
    """

    def __init__(self,
                 inner_net: nn.Module,
                 desired_c: int,
                 lr: float = 1e-2,
                 right_c_coef: float = 1.0,
                 wrong_c_coef: float = 1.0,
                 seq_length: int = None,
                 n_vocab: int = None,
                 input_min: float = None,
                 input_max: float = None):
        super().__init__()

        seq_length = SEQ_LENGTH if seq_length is None else seq_length
        n_vocab = N_VOCAB if n_vocab is None else n_vocab

        self.input_shape = (1, seq_length, n_vocab)
        self.input = nn.Parameter(
            data=torch.normal(0, 1, size=self.input_shape),
            requires_grad=True
        )

        self.inner_net = inner_net
        self.anti_optimizer = torch.optim.Adam((self.input,), lr=lr)

        self.desired_c = desired_c
        self.right_c_coef = right_c_coef
        self.wrong_c_coef = wrong_c_coef

        # Bounds used by `crop`; previously these were never set, so crop() raised.
        self.input_min = input_min
        self.input_max = input_max

        self.gumbel = Gumbel(loc=0.0, scale=1.0)
        self.epsilon = 1e-8

    def test(self):
        """Return inner_net's logits for the raw `input` logits (no sampling, no grad).

        NOTE(review): this feeds un-normalised logits rather than a one-hot or a
        softmax; confirm that is the intended probe.
        """
        with torch.no_grad():
            return self.inner_net(self.input)

    def zero_grad(self):
        """Clear gradients of both the classifier and the search logits."""
        self.inner_net.zero_grad()
        self.anti_optimizer.zero_grad()

    def crop(self):
        """Clamp `input` in place to [input_min, input_max]; no-op when both are None."""
        if self.input_min is None and self.input_max is None:
            return
        with torch.no_grad():
            self.input.clamp_(min=self.input_min, max=self.input_max)

    def step(self):
        """Perform one straight-through Gumbel optimisation step.

        Returns:
            (input logits before this step, inner_net logits for the sampled
            one-hot, loss) — all detached copies.
        """
        prev_input = self.input.detach().clone()
        n_vocab = self.input_shape[-1]

        # Gumbel-perturbed log-probabilities: argmax gives a categorical sample,
        # softmax gives its temperature-1 relaxation.
        gumb = self.gumbel.sample(self.input_shape)
        perturbed = gumb + torch.log(torch.softmax(self.input, dim=-1) + self.epsilon)

        j_onehot = F.one_hot(perturbed.argmax(dim=-1), n_vocab).to(torch.float32).requires_grad_(True)
        j_continuous = torch.softmax(perturbed, dim=-1)

        output = self.inner_net(j_onehot).squeeze()
        output_soft = torch.softmax(output, dim=-1)

        # Penalise the strongest competing class, reward the desired one.
        wrong_mask = torch.arange(output_soft.shape[-1]) != self.desired_c
        loss = output_soft[wrong_mask].max() * self.wrong_c_coef + \
               output_soft[self.desired_c] * -self.right_c_coef

        # Differentiate only w.r.t. the sampled one-hot, so the classifier's
        # parameter .grad fields are never populated.
        (onehot_grad,) = torch.autograd.grad(loss, j_onehot)
        # Straight-through: push the hard-sample gradient through the relaxed sample.
        j_continuous.backward(onehot_grad)

        self.anti_optimizer.step()
        self.zero_grad()

        return prev_input, output.detach().clone(), loss.detach().clone()

In [14]:
ANTI_NET_STEPS = 3000
TEST_TRIES = 5

for i in range(N_CLASSES):
    desired_c = i

    print(f'-=-=- step {i} -=-=-')
    print('des class:', classes[desired_c], f'({desired_c})')

    # Best attempt so far, judged by the loss of its final optimisation step.
    fin_loss = float('+inf')
    fin_inp, fin_outp = None, None

    for try_i in range(TEST_TRIES):
        anti_net = AntiEmb(model, desired_c, lr=5e-2)

        for step_i in tqdm.tqdm(range(ANTI_NET_STEPS), f'attempt #{try_i}'):
            inp, outp, loss = anti_net.step()

        if loss.item() < fin_loss:
            fin_loss = loss.item()
            fin_inp = inp
            fin_outp = outp

        print(f'attempt #{try_i} loss:', loss.item())

    print('fin loss:', fin_loss)
    if fin_inp is None:
        # Only possible when every attempt ended with a NaN loss.
        print('no attempt produced a finite loss')
        continue
    print('fin output (from function):', fin_outp)

    # Diagnostics only: evaluate without building autograd graphs.
    with torch.no_grad():
        inp_argmax = fin_inp.argmax(dim=-1)

        print('fin input  (argmax):', decode(inp_argmax.squeeze()))
        print('fin output (argmax):', model(F.one_hot(inp_argmax, N_VOCAB).to(torch.float32)))

        dist = Categorical(logits=fin_inp)
        sample = dist.sample()
        entropy = dist.entropy()

        print('fin input  (sample):', decode(sample.squeeze()))
        print('fin output (sample):', model(F.one_hot(sample, N_VOCAB).to(torch.float32)))
        print('fin input entropy:', entropy)

-=-=- step 0 -=-=-
des class: IN (0)


attempt #0:   0%|          | 0/3000 [00:00<?, ?it/s]

attempt #0: 100%|██████████| 3000/3000 [00:28<00:00, 105.26it/s]


attempt #0 loss: -1.0


attempt #1: 100%|██████████| 3000/3000 [00:21<00:00, 140.94it/s]


attempt #1 loss: -1.0


attempt #2: 100%|██████████| 3000/3000 [00:21<00:00, 140.57it/s]


attempt #2 loss: -1.0


attempt #3: 100%|██████████| 3000/3000 [00:19<00:00, 150.75it/s]


attempt #3 loss: -1.0


attempt #4: 100%|██████████| 3000/3000 [00:12<00:00, 243.59it/s]


attempt #4 loss: -1.0
fin loss: -1.0
fin output (from function): tensor([ 33.0438,  -7.7766, -19.0549, -37.6644, -20.2585, -13.5219,   9.1719,
        -15.1635])
fin input  (argmax): PNjaaKpjjBhh
fin output (argmax): tensor([[ 46.6128, -10.8342, -18.9394, -50.0701, -28.9233, -25.5937,   9.6162,
         -20.0990]], grad_fn=<AddmmBackward0>)
fin input  (sample): PjUaMTpajBhh
fin output (sample): tensor([[ 30.2518,  -6.7675, -10.3639, -30.6660, -19.6235, -21.5355,   5.2117,
         -13.3198]], grad_fn=<AddmmBackward0>)
fin input entropy: tensor([[2.3988, 2.5359, 2.5615, 1.6119, 2.7087, 1.3533, 1.5548, 2.3987, 0.6532,
         0.6330, 0.5502, 0.6463]])
-=-=- step 1 -=-=-
des class: US (1)


attempt #0: 100%|██████████| 3000/3000 [00:11<00:00, 255.93it/s]


attempt #0 loss: -0.9955480694770813


attempt #1: 100%|██████████| 3000/3000 [00:11<00:00, 270.47it/s]


attempt #1 loss: -0.9980511665344238


attempt #2: 100%|██████████| 3000/3000 [00:10<00:00, 277.11it/s]


attempt #2 loss: -0.9990313649177551


attempt #3: 100%|██████████| 3000/3000 [00:14<00:00, 205.84it/s]


attempt #3 loss: -0.9974507093429565


attempt #4: 100%|██████████| 3000/3000 [00:10<00:00, 276.96it/s]


attempt #4 loss: -0.9992530941963196
fin loss: -0.9992530941963196
fin output (from function): tensor([  0.6444,  22.8224, -21.7648,  14.9299, -34.1634, -73.3837, -17.1440,
        -17.1059])
fin input  (argmax): BLwtHCH.wWlw
fin output (argmax): tensor([[  0.5226,  21.8606, -20.6603,  14.1541, -32.6519, -69.9312, -16.5659,
         -16.0854]], grad_fn=<AddmmBackward0>)
fin input  (sample): BcwtHCHxZWlH
fin output (sample): tensor([[  0.4048,  20.8629, -19.5230,  13.3520, -31.0935, -66.3527, -15.9618,
         -15.0477]], grad_fn=<AddmmBackward0>)
fin input entropy: tensor([[0.1034, 1.7189, 0.1233, 1.9884, 2.1215, 0.3926, 0.5872, 0.4841, 0.8569,
         0.1909, 0.6080, 1.2833]])
-=-=- step 2 -=-=-
des class: BR (2)


attempt #0: 100%|██████████| 3000/3000 [00:14<00:00, 209.28it/s]


attempt #0 loss: -0.9946171045303345


attempt #1: 100%|██████████| 3000/3000 [00:11<00:00, 264.82it/s]


attempt #1 loss: -0.9958387613296509


attempt #2: 100%|██████████| 3000/3000 [00:14<00:00, 207.15it/s]


attempt #2 loss: -0.9971375465393066


attempt #3: 100%|██████████| 3000/3000 [00:11<00:00, 272.25it/s]


attempt #3 loss: -0.9966521859169006


attempt #4: 100%|██████████| 3000/3000 [00:22<00:00, 133.38it/s]


attempt #4 loss: -0.9971781969070435
fin loss: -0.9971781969070435
fin output (from function): tensor([  2.6685,   1.8251,   9.4617,  -7.9852,  -3.8837,  -3.6653, -12.3450,
          0.2301])
fin input  (argmax): I xiqAoaCoOA
fin output (argmax): tensor([[  3.1549,   1.9031,   9.9987,  -8.7853,  -3.9970,  -3.7589, -12.8650,
          -0.1669]], grad_fn=<AddmmBackward0>)
fin input  (sample): I QiqXoCCoOA
fin output (sample): tensor([[  3.1862,   1.9966,  10.0456,  -8.7883,  -4.3169,  -4.1472, -13.0729,
           0.0676]], grad_fn=<AddmmBackward0>)
fin input entropy: tensor([[1.0608, 0.1727, 2.3082, 0.3539, 0.5114, 0.6343, 0.2311, 3.0917, 1.6122,
         0.2896, 1.6428, 1.6460]])
-=-=- step 3 -=-=-
des class: DE (3)


attempt #0: 100%|██████████| 3000/3000 [00:10<00:00, 273.34it/s]


attempt #0 loss: -0.9984422922134399


attempt #1: 100%|██████████| 3000/3000 [00:21<00:00, 141.17it/s]


attempt #1 loss: -0.9999470710754395


attempt #2: 100%|██████████| 3000/3000 [00:23<00:00, 126.40it/s]


attempt #2 loss: -0.9998937249183655


attempt #3: 100%|██████████| 3000/3000 [00:21<00:00, 142.03it/s]


attempt #3 loss: -0.9991379976272583


attempt #4: 100%|██████████| 3000/3000 [00:13<00:00, 229.76it/s]


attempt #4 loss: -0.9998248815536499
fin loss: -0.9999470710754395
fin output (from function): tensor([  2.6247,   6.6649, -19.1119,  17.5444, -14.3742, -48.7725,   6.4302,
        -25.9031])
fin input  (argmax): WtrRWWWBfyfL
fin output (argmax): tensor([[  3.0196,   9.3819, -22.7777,  20.6435, -18.7063, -59.9129,   5.3309,
         -30.1048]], grad_fn=<AddmmBackward0>)
fin input  (sample): WtrRWWNBfyfW
fin output (sample): tensor([[  2.3757,   8.4066, -19.7279,  18.1674, -16.4356, -52.4573,   4.2610,
         -26.2113]], grad_fn=<AddmmBackward0>)
fin input entropy: tensor([[0.3848, 0.6816, 1.0163, 0.3989, 0.4146, 0.0628, 0.4794, 1.3249, 0.3184,
         2.0621, 1.7584, 1.5260]])
-=-=- step 4 -=-=-
des class: CN (4)


attempt #0: 100%|██████████| 3000/3000 [00:10<00:00, 278.02it/s]


attempt #0 loss: -0.9999775290489197


attempt #1: 100%|██████████| 3000/3000 [00:10<00:00, 284.15it/s]


attempt #1 loss: -0.9999811053276062


attempt #2: 100%|██████████| 3000/3000 [00:10<00:00, 290.36it/s]


attempt #2 loss: -0.9999977350234985


attempt #3: 100%|██████████| 3000/3000 [00:10<00:00, 281.98it/s]


attempt #3 loss: -0.9999986290931702


attempt #4: 100%|██████████| 3000/3000 [00:10<00:00, 282.50it/s]


attempt #4 loss: -0.9985603094100952
fin loss: -0.9999986290931702
fin output (from function): tensor([  1.4530,  -0.4674,  -9.5489,  -8.5752,  15.8480,   0.1466, -11.3359,
         -2.8385])
fin input  (argmax): ki'iuo'gggeV
fin output (argmax): tensor([[  0.8226,  -0.2983,  -8.3794,  -7.1657,  14.7365,  -0.4255, -10.7098,
          -2.3388]], grad_fn=<AddmmBackward0>)
fin input  (sample): kt'iuaaggDCz
fin output (sample): tensor([[-1.0502,  0.2363, -4.2104, -2.8806, 10.9335, -2.2181, -8.8071, -0.7746]],
       grad_fn=<AddmmBackward0>)
fin input entropy: tensor([[0.6841, 0.9543, 0.0738, 2.2771, 2.2248, 2.9769, 1.2987, 1.0489, 1.3200,
         0.4684, 0.8853, 2.2021]])
-=-=- step 5 -=-=-
des class: JP (5)


attempt #0: 100%|██████████| 3000/3000 [00:10<00:00, 287.42it/s]


attempt #0 loss: -0.9993529915809631


attempt #1: 100%|██████████| 3000/3000 [00:10<00:00, 285.13it/s]


attempt #1 loss: -0.9997848272323608


attempt #2: 100%|██████████| 3000/3000 [00:10<00:00, 279.83it/s]


attempt #2 loss: -0.9995297789573669


attempt #3: 100%|██████████| 3000/3000 [00:10<00:00, 289.85it/s]


attempt #3 loss: -0.9999352693557739


attempt #4: 100%|██████████| 3000/3000 [00:10<00:00, 285.67it/s]


attempt #4 loss: -0.999035120010376
fin loss: -0.9999352693557739
fin output (from function): tensor([  9.4077,  -2.2033, -23.7475, -22.2308,  -2.7551,  19.8812,   8.1771,
         -8.5535])
fin input  (argmax): uOOOT'hTiao'
fin output (argmax): tensor([[  7.1054,  -1.6663, -20.3315, -17.5537,  -2.1678,  16.4518,   7.1227,
          -7.0917]], grad_fn=<AddmmBackward0>)
fin input  (sample): uOSOj'hTiao'
fin output (sample): tensor([[  6.8919,  -1.5111, -18.7631, -17.3031,  -1.8705,  16.1752,   6.1061,
          -6.9613]], grad_fn=<AddmmBackward0>)
fin input entropy: tensor([[0.2601, 0.1830, 1.8949, 2.0207, 0.9865, 0.7062, 0.0169, 0.1933, 0.0367,
         0.2528, 0.6719, 0.0428]])
-=-=- step 6 -=-=-
des class: RU (6)


attempt #0: 100%|██████████| 3000/3000 [00:10<00:00, 290.18it/s]


attempt #0 loss: -0.9999980926513672


attempt #1: 100%|██████████| 3000/3000 [00:10<00:00, 289.61it/s]


attempt #1 loss: -0.9991814494132996


attempt #2: 100%|██████████| 3000/3000 [00:10<00:00, 282.42it/s]


attempt #2 loss: -0.9999995231628418


attempt #3: 100%|██████████| 3000/3000 [00:10<00:00, 275.69it/s]


attempt #3 loss: -1.0


attempt #4: 100%|██████████| 3000/3000 [00:11<00:00, 251.71it/s]


attempt #4 loss: -1.0
fin loss: -1.0
fin output (from function): tensor([ 16.1617,  -6.5635, -33.0676,   1.7112, -20.4552, -32.7809,  34.1356,
        -27.9527])
fin input  (argmax): p-vm'vv'kkXv
fin output (argmax): tensor([[ 16.6827,  -7.1821, -30.4896,  -7.2995, -19.4865, -18.2040,  30.2312,
         -19.5458]], grad_fn=<AddmmBackward0>)
fin input  (sample): Kgvs'vvhksXv
fin output (sample): tensor([[ 12.0222,  -5.0522, -23.0480,  -1.1248, -15.2734, -20.6257,  24.0229,
         -18.4207]], grad_fn=<AddmmBackward0>)
fin input entropy: tensor([[1.2813, 1.5263, 0.6045, 2.3623, 2.0119, 2.0138, 1.5451, 2.3079, 0.1849,
         2.7784, 0.4477, 1.0542]])
-=-=- step 7 -=-=-
des class: MX (7)


attempt #0: 100%|██████████| 3000/3000 [00:18<00:00, 165.58it/s]


attempt #0 loss: -0.8864269852638245


attempt #1: 100%|██████████| 3000/3000 [00:10<00:00, 280.93it/s]


attempt #1 loss: -0.878266453742981


attempt #2: 100%|██████████| 3000/3000 [00:11<00:00, 260.23it/s]


attempt #2 loss: -0.8908573389053345


attempt #3: 100%|██████████| 3000/3000 [00:11<00:00, 268.05it/s]


attempt #3 loss: -0.8882911801338196


attempt #4: 100%|██████████| 3000/3000 [00:14<00:00, 202.28it/s]


attempt #4 loss: -0.87801593542099
fin loss: -0.8908573389053345
fin output (from function): tensor([-1.0934,  2.6656,  2.4943, -2.4017, -5.1932, -5.7134, -8.7659,  5.8622])
fin input  (argmax): sqcFaYAIqi F
fin output (argmax): tensor([[-1.0934,  2.6656,  2.4943, -2.4017, -5.1932, -5.7134, -8.7659,  5.8622]],
       grad_fn=<AddmmBackward0>)
fin input  (sample): sqcFuYAIqi F
fin output (sample): tensor([[-1.0831,  2.6334,  2.5211, -2.4175, -5.1583, -5.6532, -8.7498,  5.8585]],
       grad_fn=<AddmmBackward0>)
fin input entropy: tensor([[0.9058, 0.0627, 0.2960, 0.4487, 0.7257, 0.1034, 0.5485, 0.0823, 0.2642,
         0.1631, 0.2633, 0.0858]])
