In [1]:
# Mount Google Drive so the project checkout under MyDrive is reachable;
# force_remount refreshes the mount if the cell is run again in the same runtime.
from google.colab import drive

drive.mount(mountpoint='/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [2]:
%cd gdrive/MyDrive/csci567

/content/gdrive/MyDrive/csci567


In [3]:
# Sanity check: list the working directory to confirm the previous cell's %cd landed in the repo root.
!ls

csci567  csci567.egg-info  data  experiments  README.md  setup.py  venv


In [4]:
# Explicit imports instead of `import *`, so every name used below has a visible source.
import torch
from torch.utils.data import DataLoader

from csci567.models.two_tower_model import (
    PurchasesDataset,
    TwoTowerModel,
    TwoTowerTrainer,
    count_parameters,
    get_customer_purchases,
    get_train_data,
)

In [9]:
# Init
import random

import numpy as np

experiment_name = "two_tower_all_no_nn"
print(f"Initializing experiment: {experiment_name}")

# Seed every RNG the run may touch (DataLoader shuffling, embedding init, and any
# Python/numpy sampling the project code might do) so runs are repeatable.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the CPU and all CUDA generators

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 10000
val_size = 100
learning_rate = 2e-3

Initializing experiment: two_tower_all_no_nn


In [10]:
# Load in training data
print("Loading in training data")
transactions_df = get_train_data(cutoff_date=None)
(
    purchases_df,
    customers_index_dict,
    articles_index_dict,
) = get_customer_purchases(transactions_df, after="2020-08-23")
# The raw transactions are not used after this point; drop the reference early.
del transactions_df

# NOTE(review): the validation frame is the first `val_size` rows of the very frame
# used for training, so the "Val Acc" logged during training is measured on rows the
# model also trains on. A customer-disjoint split is not a drop-in fix if the model
# learns one embedding per customer id (held-out ids would stay untrained); a
# time-based holdout is likely needed instead — confirm against the model code.
training_purchases_df = purchases_df
testing_purchases_df = purchases_df[:val_size]

Loading in training data


In [11]:
# Make dataset and dataloader
print("Making purchases dataloader")
train_batch_size = 64


def make_purchases_loader(purchases_frame, batch_size, shuffle):
    """Build a PurchasesDataset over `purchases_frame` and wrap it in a DataLoader.

    Uses `customers_index_dict`, `articles_index_dict` and `device` from earlier cells.

    Args:
        purchases_frame: a slice of the purchases table returned by get_customer_purchases.
        batch_size: DataLoader batch size.
        shuffle: whether the DataLoader reshuffles on every iteration.

    Returns:
        (dataset, dataloader) tuple.
    """
    dataset = PurchasesDataset(
        purchases_frame, customers_index_dict, articles_index_dict, device=device)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return dataset, loader


training_purchases_dataset, training_purchases_dataloader = make_purchases_loader(
    training_purchases_df, batch_size=train_batch_size, shuffle=True)
# The validation frame has at most `val_size` rows and batch_size=val_size, so it is
# (presumably) one batch and shuffling would only permute rows within it. Keep a
# fixed order so evaluation batches are deterministic.
testing_purchases_dataset, testing_purchases_dataloader = make_purchases_loader(
    testing_purchases_df, batch_size=val_size, shuffle=False)

Making purchases dataloader


In [12]:
# Make model and trainer
print("Making model and trainer")
num_customers = len(customers_index_dict)
num_articles = len(articles_index_dict)
two_tower_model = TwoTowerModel(num_customers, num_articles, device=device)

# Prints a per-module trainable-parameter table (see the output below).
count_parameters(two_tower_model)

two_tower_trainer = TwoTowerTrainer(
    two_tower_model,
    training_purchases_dataloader,
    testing_purchases_dataloader,
    epochs,
    learning_rate,
    experiment_name,
)

Making model and trainer
+---------------------------+------------+
|          Modules          | Parameters |
+---------------------------+------------+
| queries_embeddings.weight |  3007428   |
| objects_embeddings.weight |   350844   |
|  query_net.model.0.weight |    144     |
|   query_net.model.0.bias  |     12     |
|  query_net.model.2.weight |    144     |
|   query_net.model.2.bias  |     12     |
| object_net.model.0.weight |    144     |
|  object_net.model.0.bias  |     12     |
| object_net.model.2.weight |    144     |
|  object_net.model.2.bias  |     12     |
+---------------------------+------------+
Total Trainable Params: 3358896


In [None]:
# Start training
# Runs the trainer built above. The output below shows a tqdm bar of ~3916 steps
# (presumably one pass over training_purchases_dataloader) and "Total Loss" /
# "Val Acc" lines logged every 500 iterations.
# NOTE(review): with epochs=10000 at ~3916 iterations per pass and ~140 it/s (per the
# output), a full run is roughly 39M iterations, i.e. ~78 hours of wall time —
# likely longer than a single Colab session. Confirm the trainer checkpoints so a
# disconnect does not lose progress.
print("Starting training")
two_tower_trainer.train()

 19%|█▉        | 752/3916 [00:05<00:22, 138.75it/s][A
 20%|█▉        | 766/3916 [00:05<00:23, 134.65it/s][A

It 991500: Total Loss: 4.195425987243652
It 991500: Val Acc: 0.02



 20%|█▉        | 780/3916 [00:05<00:23, 133.62it/s][A
 20%|██        | 795/3916 [00:05<00:22, 137.56it/s][A
 21%|██        | 811/3916 [00:05<00:21, 141.46it/s][A
 21%|██        | 826/3916 [00:05<00:21, 143.83it/s][A
 21%|██▏       | 841/3916 [00:05<00:21, 144.28it/s][A
 22%|██▏       | 856/3916 [00:06<00:21, 144.04it/s][A
 22%|██▏       | 871/3916 [00:06<00:21, 143.25it/s][A
 23%|██▎       | 886/3916 [00:06<00:21, 142.87it/s][A
 23%|██▎       | 901/3916 [00:06<00:20, 144.41it/s][A
 23%|██▎       | 916/3916 [00:06<00:22, 135.41it/s][A
 24%|██▎       | 930/3916 [00:06<00:23, 128.06it/s][A
 24%|██▍       | 945/3916 [00:06<00:22, 133.87it/s][A
 25%|██▍       | 961/3916 [00:06<00:21, 138.82it/s][A
 25%|██▍       | 976/3916 [00:06<00:20, 140.25it/s][A
 25%|██▌       | 991/3916 [00:07<00:20, 141.91it/s][A
 26%|██▌       | 1006/3916 [00:07<00:20, 143.33it/s][A
 26%|██▌       | 1021/3916 [00:07<00:20, 143.51it/s][A
 26%|██▋       | 1036/3916 [00:07<00:19, 144.33it/s][A
 27%|█

It 992000: Total Loss: 4.179941654205322
It 992000: Val Acc: 0.03



 33%|███▎      | 1295/3916 [00:09<00:18, 144.42it/s][A
 33%|███▎      | 1311/3916 [00:09<00:17, 146.30it/s][A
 34%|███▍      | 1326/3916 [00:09<00:17, 147.31it/s][A
 34%|███▍      | 1341/3916 [00:09<00:17, 147.17it/s][A
 35%|███▍      | 1356/3916 [00:09<00:17, 146.03it/s][A
 35%|███▌      | 1371/3916 [00:09<00:17, 145.97it/s][A
 35%|███▌      | 1386/3916 [00:09<00:17, 142.99it/s][A
 36%|███▌      | 1401/3916 [00:09<00:17, 141.34it/s][A
 36%|███▌      | 1416/3916 [00:09<00:17, 139.48it/s][A
 37%|███▋      | 1431/3916 [00:10<00:17, 141.97it/s][A
 37%|███▋      | 1446/3916 [00:10<00:17, 142.39it/s][A
 37%|███▋      | 1461/3916 [00:10<00:18, 131.23it/s][A
 38%|███▊      | 1476/3916 [00:10<00:18, 134.19it/s][A
 38%|███▊      | 1491/3916 [00:10<00:17, 138.19it/s][A
 38%|███▊      | 1506/3916 [00:10<00:17, 139.92it/s][A
 39%|███▉      | 1521/3916 [00:10<00:17, 140.75it/s][A
 39%|███▉      | 1536/3916 [00:10<00:16, 141.62it/s][A
 40%|███▉      | 1551/3916 [00:10<00:16, 142.91

It 992500: Total Loss: 4.160948276519775
It 992500: Val Acc: 0.01



 46%|████▌     | 1793/3916 [00:12<00:15, 138.12it/s][A
 46%|████▌     | 1809/3916 [00:12<00:14, 142.18it/s][A
 47%|████▋     | 1824/3916 [00:12<00:14, 144.18it/s][A
 47%|████▋     | 1839/3916 [00:12<00:14, 145.39it/s][A
 47%|████▋     | 1855/3916 [00:13<00:13, 147.47it/s][A
 48%|████▊     | 1871/3916 [00:13<00:13, 148.43it/s][A
 48%|████▊     | 1886/3916 [00:13<00:13, 148.30it/s][A
 49%|████▊     | 1902/3916 [00:13<00:13, 149.33it/s][A
 49%|████▉     | 1917/3916 [00:13<00:13, 147.87it/s][A
 49%|████▉     | 1932/3916 [00:13<00:13, 147.32it/s][A
 50%|████▉     | 1947/3916 [00:13<00:13, 145.02it/s][A
 50%|█████     | 1962/3916 [00:13<00:13, 141.94it/s][A
 50%|█████     | 1977/3916 [00:13<00:13, 141.24it/s][A
 51%|█████     | 1992/3916 [00:14<00:14, 131.28it/s][A
 51%|█████     | 2006/3916 [00:14<00:14, 127.61it/s][A
 52%|█████▏    | 2022/3916 [00:14<00:14, 134.17it/s][A
 52%|█████▏    | 2037/3916 [00:14<00:13, 137.75it/s][A
 52%|█████▏    | 2052/3916 [00:14<00:13, 141.18

It 993000: Total Loss: 4.192275047302246
It 993000: Val Acc: 0.02



 58%|█████▊    | 2282/3916 [00:16<00:12, 129.27it/s][A
 59%|█████▊    | 2298/3916 [00:16<00:11, 135.20it/s][A
 59%|█████▉    | 2314/3916 [00:16<00:11, 139.73it/s][A
 59%|█████▉    | 2329/3916 [00:16<00:11, 140.31it/s][A
 60%|█████▉    | 2345/3916 [00:16<00:10, 144.09it/s][A
 60%|██████    | 2361/3916 [00:16<00:10, 146.04it/s][A
 61%|██████    | 2376/3916 [00:16<00:10, 143.37it/s][A
 61%|██████    | 2391/3916 [00:16<00:10, 145.21it/s][A
 61%|██████▏   | 2406/3916 [00:16<00:10, 146.43it/s][A
 62%|██████▏   | 2422/3916 [00:17<00:10, 147.71it/s][A
 62%|██████▏   | 2437/3916 [00:17<00:09, 148.11it/s][A
 63%|██████▎   | 2452/3916 [00:17<00:09, 147.99it/s][A
 63%|██████▎   | 2467/3916 [00:17<00:09, 146.32it/s][A
 63%|██████▎   | 2482/3916 [00:17<00:10, 142.31it/s][A
 64%|██████▍   | 2497/3916 [00:17<00:10, 141.19it/s][A
 64%|██████▍   | 2512/3916 [00:17<00:09, 141.63it/s][A
 65%|██████▍   | 2527/3916 [00:17<00:09, 141.09it/s][A
 65%|██████▍   | 2543/3916 [00:17<00:09, 145.12

It 993500: Total Loss: 4.1506571769714355
It 993500: Val Acc: 0.04



 71%|███████   | 2785/3916 [00:19<00:08, 141.14it/s][A
 72%|███████▏  | 2800/3916 [00:19<00:07, 142.50it/s][A
 72%|███████▏  | 2815/3916 [00:19<00:07, 142.89it/s][A
 72%|███████▏  | 2830/3916 [00:19<00:07, 144.15it/s][A
 73%|███████▎  | 2845/3916 [00:20<00:07, 141.70it/s][A
 73%|███████▎  | 2860/3916 [00:20<00:07, 133.47it/s][A
 73%|███████▎  | 2875/3916 [00:20<00:07, 136.47it/s][A
 74%|███████▍  | 2890/3916 [00:20<00:07, 139.73it/s][A
 74%|███████▍  | 2905/3916 [00:20<00:07, 142.21it/s][A
 75%|███████▍  | 2920/3916 [00:20<00:06, 143.70it/s][A
 75%|███████▍  | 2936/3916 [00:20<00:06, 145.67it/s][A
 75%|███████▌  | 2951/3916 [00:20<00:06, 146.72it/s][A
 76%|███████▌  | 2966/3916 [00:20<00:06, 144.48it/s][A
 76%|███████▌  | 2981/3916 [00:20<00:06, 143.05it/s][A
 77%|███████▋  | 2996/3916 [00:21<00:06, 144.18it/s][A
 77%|███████▋  | 3011/3916 [00:21<00:06, 142.00it/s][A
 77%|███████▋  | 3026/3916 [00:21<00:06, 142.30it/s][A
 78%|███████▊  | 3042/3916 [00:21<00:06, 144.43

It 994000: Total Loss: 4.20861291885376
It 994000: Val Acc: 0.0



 84%|████████▍ | 3285/3916 [00:23<00:04, 144.45it/s][A
 84%|████████▍ | 3300/3916 [00:23<00:04, 143.01it/s][A
 85%|████████▍ | 3316/3916 [00:23<00:04, 146.26it/s][A
 85%|████████▌ | 3331/3916 [00:23<00:04, 145.91it/s][A
 85%|████████▌ | 3347/3916 [00:23<00:03, 147.44it/s][A
 86%|████████▌ | 3363/3916 [00:23<00:03, 149.04it/s][A
 86%|████████▋ | 3378/3916 [00:23<00:03, 148.89it/s][A
 87%|████████▋ | 3393/3916 [00:23<00:03, 147.66it/s][A
 87%|████████▋ | 3408/3916 [00:23<00:03, 145.73it/s][A
 87%|████████▋ | 3423/3916 [00:24<00:03, 138.60it/s][A
 88%|████████▊ | 3437/3916 [00:24<00:03, 128.72it/s][A
 88%|████████▊ | 3451/3916 [00:24<00:03, 129.35it/s][A
 89%|████████▊ | 3466/3916 [00:24<00:03, 133.67it/s][A
 89%|████████▉ | 3481/3916 [00:24<00:03, 137.34it/s][A
 89%|████████▉ | 3496/3916 [00:24<00:03, 139.82it/s][A
 90%|████████▉ | 3511/3916 [00:24<00:02, 142.52it/s][A
 90%|█████████ | 3527/3916 [00:24<00:02, 145.18it/s][A
 90%|█████████ | 3542/3916 [00:24<00:02, 143.41

It 994500: Total Loss: 4.1463093757629395
It 994500: Val Acc: 0.0



 97%|█████████▋| 3785/3916 [00:26<00:00, 141.79it/s][A
 97%|█████████▋| 3801/3916 [00:26<00:00, 144.33it/s][A
 97%|█████████▋| 3816/3916 [00:26<00:00, 144.75it/s][A
 98%|█████████▊| 3831/3916 [00:26<00:00, 144.36it/s][A
 98%|█████████▊| 3847/3916 [00:27<00:00, 146.55it/s][A
 99%|█████████▊| 3862/3916 [00:27<00:00, 146.22it/s][A
 99%|█████████▉| 3877/3916 [00:27<00:00, 145.96it/s][A
 99%|█████████▉| 3893/3916 [00:27<00:00, 147.71it/s][A
100%|█████████▉| 3908/3916 [00:27<00:00, 142.02it/s][A