In [1]:
import topmost
import torch
from topmost.data import download_dataset

# Pick the compute device at run time: use the GPU when CUDA is usable and
# fall back to the CPU otherwise, so the notebook also runs on CUDA-less
# machines. Assign "cuda" or "cpu" explicitly here to force a device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Directory the later cells read the Amazon_Review dataset from.
dataset_dir = "./datasets/Amazon_Review"

# Fetch the Amazon_Review archive, using ./datasets as the cache directory.
download_dataset('Amazon_Review', cache_path='./datasets')

https://raw.githubusercontent.com/BobXWu/TopMost/master/data/Amazon_Review.zip
Using downloaded and verified file: ./datasets/Amazon_Review.zip


In [2]:
# Directory that the dictionary file used by the dataset handler is read from.
dict_dir = './datasets/dict'

# Fetch the 'dict' archive, using ./datasets as the cache directory.
download_dataset(
    'dict',
    cache_path='./datasets',
)

https://raw.githubusercontent.com/BobXWu/TopMost/master/data/dict.zip
Using downloaded and verified file: ./datasets/dict.zip


In [3]:
# Build the cross-lingual dataset handler (lang1='en', lang2='cn') from the
# preprocessed files in `dataset_dir`, pointing it at the ch_en dictionary
# file under `dict_dir` and at the selected compute device.
dataset = topmost.data.CrosslingualDatasetHandler(
    dataset_dir,
    lang1='en',
    lang2='cn',
    dict_path=f'{dict_dir}/ch_en_dict.dat',
    device=device,
    batch_size=128,
    as_tensor=True,
)

In [4]:
# Model: InfoCTM configured from the dataset's translation matrix, pretrained
# word embeddings and vocabulary sizes for both languages, then placed on
# `device`.
model = topmost.models.InfoCTM(
    trans_e2c=dataset.trans_matrix_en,
    pretrain_word_embeddings_en=dataset.pretrained_WE_en,
    pretrain_word_embeddings_cn=dataset.pretrained_WE_cn,
    vocab_size_en=dataset.vocab_size_en,
    vocab_size_cn=dataset.vocab_size_cn,
    weight_MI=50,
).to(device)

# Trainer: 500 epochs with the 'StepLR' scheduler and a step size of 125.
trainer = topmost.trainers.CrosslingualTrainer(
    model,
    lr_scheduler='StepLR',
    lr_step_size=125,
    epochs=500,
)

# Run training on the cross-lingual dataset.
trainer.train(dataset)

# Alternative: train and export top words / thetas in a single call.
# top_words_en, top_words_cn, train_theta_en, train_theta_cn = trainer.fit_transform(dataset)

  1%|          | 5/500 [00:17<28:13,  3.42s/it]

Epoch: 005 loss: 6.729 loss_TAMI: 2.135


  2%|▏         | 10/500 [00:34<27:38,  3.39s/it]

Epoch: 010 loss: 6.556 loss_TAMI: 2.129


  3%|▎         | 15/500 [00:51<27:20,  3.38s/it]

Epoch: 015 loss: 6.493 loss_TAMI: 2.128


  4%|▍         | 20/500 [01:07<27:02,  3.38s/it]

Epoch: 020 loss: 6.466 loss_TAMI: 2.127


  5%|▌         | 25/500 [01:24<26:45,  3.38s/it]

Epoch: 025 loss: 6.455 loss_TAMI: 2.126


  6%|▌         | 30/500 [01:41<26:28,  3.38s/it]

Epoch: 030 loss: 6.441 loss_TAMI: 2.126


  7%|▋         | 35/500 [01:58<26:11,  3.38s/it]

Epoch: 035 loss: 6.432 loss_TAMI: 2.126


  8%|▊         | 40/500 [02:15<25:54,  3.38s/it]

Epoch: 040 loss: 6.422 loss_TAMI: 2.126


  9%|▉         | 45/500 [02:32<25:38,  3.38s/it]

Epoch: 045 loss: 6.416 loss_TAMI: 2.125


 10%|█         | 50/500 [02:49<25:22,  3.38s/it]

Epoch: 050 loss: 6.406 loss_TAMI: 2.125


 11%|█         | 55/500 [03:05<25:04,  3.38s/it]

Epoch: 055 loss: 6.401 loss_TAMI: 2.126


 12%|█▏        | 60/500 [03:22<24:47,  3.38s/it]

Epoch: 060 loss: 6.399 loss_TAMI: 2.126


 13%|█▎        | 65/500 [03:39<24:30,  3.38s/it]

Epoch: 065 loss: 6.389 loss_TAMI: 2.126


 14%|█▍        | 70/500 [03:56<24:13,  3.38s/it]

Epoch: 070 loss: 6.385 loss_TAMI: 2.126


 15%|█▌        | 75/500 [04:13<23:56,  3.38s/it]

Epoch: 075 loss: 6.381 loss_TAMI: 2.126


 16%|█▌        | 80/500 [04:30<23:39,  3.38s/it]

Epoch: 080 loss: 6.380 loss_TAMI: 2.126


 17%|█▋        | 85/500 [04:47<23:23,  3.38s/it]

Epoch: 085 loss: 6.369 loss_TAMI: 2.126


 18%|█▊        | 90/500 [05:04<23:06,  3.38s/it]

Epoch: 090 loss: 6.372 loss_TAMI: 2.126


 19%|█▉        | 95/500 [05:20<22:49,  3.38s/it]

Epoch: 095 loss: 6.369 loss_TAMI: 2.126


 20%|██        | 100/500 [05:37<22:32,  3.38s/it]

Epoch: 100 loss: 6.364 loss_TAMI: 2.126


 21%|██        | 105/500 [05:54<22:15,  3.38s/it]

Epoch: 105 loss: 6.366 loss_TAMI: 2.126


 22%|██▏       | 110/500 [06:11<21:58,  3.38s/it]

Epoch: 110 loss: 6.361 loss_TAMI: 2.126


 23%|██▎       | 115/500 [06:28<21:41,  3.38s/it]

Epoch: 115 loss: 6.358 loss_TAMI: 2.126


 24%|██▍       | 120/500 [06:45<21:24,  3.38s/it]

Epoch: 120 loss: 6.354 loss_TAMI: 2.126


 25%|██▌       | 125/500 [07:02<21:07,  3.38s/it]

Epoch: 125 loss: 6.352 loss_TAMI: 2.126


 26%|██▌       | 130/500 [07:18<20:50,  3.38s/it]

Epoch: 130 loss: 6.337 loss_TAMI: 2.125


 27%|██▋       | 135/500 [07:35<20:33,  3.38s/it]

Epoch: 135 loss: 6.336 loss_TAMI: 2.125


 28%|██▊       | 140/500 [07:52<20:17,  3.38s/it]

Epoch: 140 loss: 6.336 loss_TAMI: 2.125


 29%|██▉       | 145/500 [08:09<19:59,  3.38s/it]

Epoch: 145 loss: 6.337 loss_TAMI: 2.125


 30%|███       | 150/500 [08:26<19:43,  3.38s/it]

Epoch: 150 loss: 6.334 loss_TAMI: 2.125


 31%|███       | 155/500 [08:43<19:26,  3.38s/it]

Epoch: 155 loss: 6.334 loss_TAMI: 2.125


 32%|███▏      | 160/500 [09:00<19:09,  3.38s/it]

Epoch: 160 loss: 6.334 loss_TAMI: 2.125


 33%|███▎      | 165/500 [09:16<18:52,  3.38s/it]

Epoch: 165 loss: 6.332 loss_TAMI: 2.125


 34%|███▍      | 170/500 [09:33<18:35,  3.38s/it]

Epoch: 170 loss: 6.331 loss_TAMI: 2.125


 35%|███▌      | 175/500 [09:50<18:18,  3.38s/it]

Epoch: 175 loss: 6.331 loss_TAMI: 2.125


 36%|███▌      | 180/500 [10:07<18:03,  3.39s/it]

Epoch: 180 loss: 6.333 loss_TAMI: 2.125


 37%|███▋      | 185/500 [10:24<17:45,  3.38s/it]

Epoch: 185 loss: 6.332 loss_TAMI: 2.125


 38%|███▊      | 190/500 [10:41<17:28,  3.38s/it]

Epoch: 190 loss: 6.326 loss_TAMI: 2.125


 39%|███▉      | 195/500 [10:58<17:10,  3.38s/it]

Epoch: 195 loss: 6.326 loss_TAMI: 2.125


 40%|████      | 200/500 [11:15<16:54,  3.38s/it]

Epoch: 200 loss: 6.327 loss_TAMI: 2.125


 41%|████      | 205/500 [11:31<16:37,  3.38s/it]

Epoch: 205 loss: 6.326 loss_TAMI: 2.125


 42%|████▏     | 210/500 [11:48<16:20,  3.38s/it]

Epoch: 210 loss: 6.325 loss_TAMI: 2.125


 43%|████▎     | 215/500 [12:05<16:03,  3.38s/it]

Epoch: 215 loss: 6.329 loss_TAMI: 2.125


 44%|████▍     | 220/500 [12:22<15:46,  3.38s/it]

Epoch: 220 loss: 6.324 loss_TAMI: 2.125


 45%|████▌     | 225/500 [12:39<15:29,  3.38s/it]

Epoch: 225 loss: 6.325 loss_TAMI: 2.125


 46%|████▌     | 230/500 [12:56<15:12,  3.38s/it]

Epoch: 230 loss: 6.328 loss_TAMI: 2.125


 47%|████▋     | 235/500 [13:13<14:56,  3.38s/it]

Epoch: 235 loss: 6.324 loss_TAMI: 2.125


 48%|████▊     | 240/500 [13:30<14:39,  3.38s/it]

Epoch: 240 loss: 6.325 loss_TAMI: 2.125


 49%|████▉     | 245/500 [13:46<14:22,  3.38s/it]

Epoch: 245 loss: 6.323 loss_TAMI: 2.125


 50%|█████     | 250/500 [14:03<14:05,  3.38s/it]

Epoch: 250 loss: 6.324 loss_TAMI: 2.125


 51%|█████     | 255/500 [14:20<13:48,  3.38s/it]

Epoch: 255 loss: 6.318 loss_TAMI: 2.125


 52%|█████▏    | 260/500 [14:37<13:31,  3.38s/it]

Epoch: 260 loss: 6.313 loss_TAMI: 2.125


 53%|█████▎    | 265/500 [14:54<13:14,  3.38s/it]

Epoch: 265 loss: 6.315 loss_TAMI: 2.125


 54%|█████▍    | 270/500 [15:11<12:57,  3.38s/it]

Epoch: 270 loss: 6.317 loss_TAMI: 2.125


 55%|█████▌    | 275/500 [15:28<12:40,  3.38s/it]

Epoch: 275 loss: 6.314 loss_TAMI: 2.125


 56%|█████▌    | 280/500 [15:44<12:23,  3.38s/it]

Epoch: 280 loss: 6.314 loss_TAMI: 2.125


 57%|█████▋    | 285/500 [16:01<12:06,  3.38s/it]

Epoch: 285 loss: 6.315 loss_TAMI: 2.125


 58%|█████▊    | 290/500 [16:18<11:50,  3.38s/it]

Epoch: 290 loss: 6.319 loss_TAMI: 2.125


 59%|█████▉    | 295/500 [16:35<11:33,  3.38s/it]

Epoch: 295 loss: 6.318 loss_TAMI: 2.125


 60%|██████    | 300/500 [16:52<11:16,  3.38s/it]

Epoch: 300 loss: 6.310 loss_TAMI: 2.125


 61%|██████    | 305/500 [17:09<10:59,  3.38s/it]

Epoch: 305 loss: 6.311 loss_TAMI: 2.125


 62%|██████▏   | 310/500 [17:26<10:42,  3.38s/it]

Epoch: 310 loss: 6.318 loss_TAMI: 2.125


 63%|██████▎   | 315/500 [17:43<10:25,  3.38s/it]

Epoch: 315 loss: 6.314 loss_TAMI: 2.125


 64%|██████▍   | 320/500 [17:59<10:08,  3.38s/it]

Epoch: 320 loss: 6.316 loss_TAMI: 2.125


 65%|██████▌   | 325/500 [18:16<09:51,  3.38s/it]

Epoch: 325 loss: 6.312 loss_TAMI: 2.125


 66%|██████▌   | 330/500 [18:33<09:34,  3.38s/it]

Epoch: 330 loss: 6.310 loss_TAMI: 2.125


 67%|██████▋   | 335/500 [18:50<09:18,  3.38s/it]

Epoch: 335 loss: 6.312 loss_TAMI: 2.125


 68%|██████▊   | 340/500 [19:07<09:01,  3.38s/it]

Epoch: 340 loss: 6.314 loss_TAMI: 2.125


 69%|██████▉   | 345/500 [19:24<08:44,  3.38s/it]

Epoch: 345 loss: 6.315 loss_TAMI: 2.125


 70%|███████   | 350/500 [19:41<08:28,  3.39s/it]

Epoch: 350 loss: 6.316 loss_TAMI: 2.125


 71%|███████   | 355/500 [19:58<08:10,  3.38s/it]

Epoch: 355 loss: 6.313 loss_TAMI: 2.125


 72%|███████▏  | 360/500 [20:14<07:53,  3.38s/it]

Epoch: 360 loss: 6.313 loss_TAMI: 2.125


 73%|███████▎  | 365/500 [20:31<07:36,  3.38s/it]

Epoch: 365 loss: 6.313 loss_TAMI: 2.125


 74%|███████▍  | 370/500 [20:48<07:19,  3.38s/it]

Epoch: 370 loss: 6.312 loss_TAMI: 2.125


 75%|███████▌  | 375/500 [21:05<07:02,  3.38s/it]

Epoch: 375 loss: 6.312 loss_TAMI: 2.125


 76%|███████▌  | 380/500 [21:22<06:45,  3.38s/it]

Epoch: 380 loss: 6.310 loss_TAMI: 2.125


 77%|███████▋  | 385/500 [21:39<06:28,  3.38s/it]

Epoch: 385 loss: 6.308 loss_TAMI: 2.125


 78%|███████▊  | 390/500 [21:56<06:11,  3.38s/it]

Epoch: 390 loss: 6.307 loss_TAMI: 2.125


 79%|███████▉  | 395/500 [22:12<05:55,  3.38s/it]

Epoch: 395 loss: 6.307 loss_TAMI: 2.125


 80%|████████  | 400/500 [22:29<05:38,  3.38s/it]

Epoch: 400 loss: 6.309 loss_TAMI: 2.125


 81%|████████  | 405/500 [22:46<05:21,  3.38s/it]

Epoch: 405 loss: 6.309 loss_TAMI: 2.125


 82%|████████▏ | 410/500 [23:03<05:04,  3.38s/it]

Epoch: 410 loss: 6.309 loss_TAMI: 2.125


 83%|████████▎ | 415/500 [23:20<04:47,  3.38s/it]

Epoch: 415 loss: 6.308 loss_TAMI: 2.125


 84%|████████▍ | 420/500 [23:37<04:30,  3.38s/it]

Epoch: 420 loss: 6.306 loss_TAMI: 2.125


 85%|████████▌ | 425/500 [23:54<04:13,  3.38s/it]

Epoch: 425 loss: 6.309 loss_TAMI: 2.125


 86%|████████▌ | 430/500 [24:11<03:56,  3.38s/it]

Epoch: 430 loss: 6.306 loss_TAMI: 2.125


 87%|████████▋ | 435/500 [24:27<03:39,  3.38s/it]

Epoch: 435 loss: 6.309 loss_TAMI: 2.125


 88%|████████▊ | 440/500 [24:44<03:22,  3.38s/it]

Epoch: 440 loss: 6.311 loss_TAMI: 2.125


 89%|████████▉ | 445/500 [25:01<03:06,  3.38s/it]

Epoch: 445 loss: 6.306 loss_TAMI: 2.125


 90%|█████████ | 450/500 [25:18<02:49,  3.38s/it]

Epoch: 450 loss: 6.307 loss_TAMI: 2.125


 91%|█████████ | 455/500 [25:35<02:32,  3.38s/it]

Epoch: 455 loss: 6.304 loss_TAMI: 2.125


 92%|█████████▏| 460/500 [25:52<02:15,  3.38s/it]

Epoch: 460 loss: 6.309 loss_TAMI: 2.125


 93%|█████████▎| 465/500 [26:09<01:58,  3.38s/it]

Epoch: 465 loss: 6.306 loss_TAMI: 2.125


 94%|█████████▍| 470/500 [26:25<01:41,  3.38s/it]

Epoch: 470 loss: 6.308 loss_TAMI: 2.125


 95%|█████████▌| 475/500 [26:42<01:24,  3.38s/it]

Epoch: 475 loss: 6.307 loss_TAMI: 2.125


 96%|█████████▌| 480/500 [26:59<01:07,  3.38s/it]

Epoch: 480 loss: 6.305 loss_TAMI: 2.125


 97%|█████████▋| 485/500 [27:16<00:50,  3.38s/it]

Epoch: 485 loss: 6.310 loss_TAMI: 2.125


 98%|█████████▊| 490/500 [27:33<00:33,  3.38s/it]

Epoch: 490 loss: 6.305 loss_TAMI: 2.125


 99%|█████████▉| 495/500 [27:50<00:16,  3.38s/it]

Epoch: 495 loss: 6.310 loss_TAMI: 2.125


100%|██████████| 500/500 [28:07<00:00,  3.37s/it]

Epoch: 500 loss: 6.306 loss_TAMI: 2.125





In [5]:
# ============================== Evaluate ==============================
from topmost import evaluations

# Doc-topic distributions (theta) for the train and test splits of both
# languages.
(
    train_theta_en,
    train_theta_cn,
    test_theta_en,
    test_theta_cn,
) = trainer.export_theta(dataset)

# Top words of every topic, one collection per language.
top_words_en, top_words_cn = trainer.export_top_words(
    dataset.vocab_en,
    dataset.vocab_cn,
)

# Topic coherence (CNPMI) is not computed here;
# see https://github.com/BobXWu/CNPMI

# Topic diversity computed over the top words of both languages.
TD = evaluations.multiaspect_topic_diversity([top_words_en, top_words_cn])
print(f"TD: {TD:.5f}")

# Cross-lingual classification on the thetas with an SVM classifier.
# Positional order: train/test thetas (en, cn), then train/test labels (en, cn).
results = evaluations.crosslingual_classification(
    train_theta_en, train_theta_cn,
    test_theta_en, test_theta_cn,
    dataset.train_labels_en, dataset.train_labels_cn,
    dataset.test_labels_en, dataset.test_labels_cn,
    classifier="SVM",
    gamma="auto",
)

print(results)

Topic 0: stress ink layout paper paperback print prints silver printer thick printing slight layer colors legal
Topic 1: worn crew tire gun address cash trash staff complaint church boot door agent broke wrist
Topic 2: sunday r press ship slow respond rest dan yesterday night morning afternoon re moon day
Topic 3: adult baby parents children width dad parenting draw prefer mother plain house adults greek everyday
Topic 4: yesterday afternoon blow files p cardboard file shock compact speed moment gift download load birthday
Topic 5: knife knives shoe socks sharp comfortable round size dan pants blade cutting sells feet tax
Topic 6: loves toy daughter son toys baby doll children grandson sister smile girl kids parents year
Topic 7: potter season magic episodes seasons clever spare rob harry series final tip caught devil priced
Topic 8: doctor household parent maintain health patient collar train child method dog control cure human ill
Topic 9: matrix risk invest bank logic factor aid dou