In [1]:
import torch
import os
import numpy as np
import pandas as pd
import math
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as Fun
from utils import price_to_log_cat, get_date_to_month_buckets, quantity_to_log_cat

In [2]:
# total_tasks = 3460
total_tasks = 3460
per_task = 100
embeddings = np.empty(0)
embedding_indices = set()

for i in tqdm(range(total_tasks)):
    emb_f = f'training_embeddings/training_embeddings_{i}.npy'
    if os.path.isfile(emb_f):
        embedding = np.load(emb_f)
        if embeddings.size == 0:
            embeddings = embedding[:,1:]
        else:
            embeddings = np.vstack((embeddings, embedding[:,1:]))
    else:
        print(f'Missing embedding file: {emb_f}')

print(f'Found {len(embedding_indices)} embeddings: {embeddings.shape}')
train_df = pd.read_csv('train.csv', index_col=0)
train_df.shape

  0%|          | 0/3460 [00:00<?, ?it/s]

Found 0 embeddings: (345018, 1536)


(345018, 31)

In [3]:
train_df['Total Price'] = train_df['Total Price'].apply(price_to_log_cat).to_numpy()
train_df['Unit Price'] = train_df['Unit Price'].apply(price_to_log_cat).to_numpy()

In [4]:
# Improve date sanitization with GPT cleaning
train_df['Purchase Date'] = pd.to_datetime([date[:-4]+'20'+date[-2:] if type(date) == str else date for date in train_df['Purchase Date'].values], format='%m/%d/%Y')
train_df['Creation Date'] = pd.to_datetime(train_df['Creation Date'], format='%m/%d/%Y')
date_to_month_buckets = get_date_to_month_buckets(train_df['Creation Date'].min())
train_df['Creation Date'] = train_df['Creation Date'].apply(date_to_month_buckets).to_numpy()
train_df['Purchase Date'] = train_df['Purchase Date'].apply(date_to_month_buckets).to_numpy()

In [5]:
train_df['Quantity'] = train_df['Quantity'].apply(quantity_to_log_cat).to_numpy()

In [6]:
class Task_Dataset(Dataset):
    def __init__(self, X : np.ndarray, 
                       y : np.ndarray):
        self.X = X
        self.y = Fun.one_hot(torch.from_numpy(y).to(torch.int64))
        assert self.X.shape[0] == self.y.shape[0]
        
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):    
        X = torch.from_numpy(self.X[idx,:]).float().squeeze()
        y = self.y[idx]
        return X, y

In [7]:
total_price_ds = Task_Dataset(embeddings, train_df.iloc[:total_tasks*per_task]['Total Price'].to_numpy())
total_price_dl = DataLoader(total_price_ds, batch_size = 64, shuffle = True)

unit_price_ds = Task_Dataset(embeddings, train_df.iloc[:total_tasks*per_task]['Unit Price'].to_numpy())
unit_price_dl = DataLoader(unit_price_ds, batch_size = 64, shuffle = True)

creation_date_ds = Task_Dataset(embeddings, train_df.iloc[:total_tasks*per_task]['Creation Date'].to_numpy())
creation_date_dl = DataLoader(creation_date_ds, batch_size = 64, shuffle = True)

purchase_date_ds = Task_Dataset(embeddings, train_df.iloc[:total_tasks*per_task]['Purchase Date'].to_numpy())
purchase_date_dl = DataLoader(purchase_date_ds, batch_size = 64, shuffle = True)

quantity_ds = Task_Dataset(embeddings, train_df.iloc[:total_tasks*per_task]['Quantity'].to_numpy())
quantity_dl = DataLoader(quantity_ds, batch_size = 64, shuffle = True)

In [8]:
class MultiTask_Network(torch.nn.Module):
    def __init__(self, input_dim, 
                 output_dim_0 : int = 1,
                 output_dim_1 : int = 1,
                 output_dim_2 : int = 1,
                 output_dim_3 : int = 1,
                 output_dim_4 : int = 1,
                 hidden_dim : int = 2048):
        
        super(MultiTask_Network, self).__init__()
        self.input_dim = input_dim
        self.output_dim_0 = output_dim_0
        self.output_dim_1 = output_dim_1
        self.output_dim_2 = output_dim_2
        self.output_dim_3 = output_dim_3
        self.output_dim_4 = output_dim_4
        self.hidden_dim = hidden_dim
        
        self.hidden0 = torch.nn.Linear(self.input_dim, self.hidden_dim)
        self.hidden1 = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
        self.hidden2 = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
        self.final_0 = torch.nn.Linear(self.hidden_dim, self.output_dim_0)
        self.final_1 = torch.nn.Linear(self.hidden_dim, self.output_dim_1)
        self.final_2 = torch.nn.Linear(self.hidden_dim, self.output_dim_2)
        self.final_3 = torch.nn.Linear(self.hidden_dim, self.output_dim_3)
        self.final_4 = torch.nn.Linear(self.hidden_dim, self.output_dim_4)
        
    def forward(self, x : torch.Tensor, task_id : int):
        x = self.hidden0(x)
        x = torch.relu(x)
        x = self.hidden1(x)
        x = torch.relu(x)
        x = self.hidden2(x)
        x = torch.relu(x)
        if task_id == 0:
            x = self.final_0(x)
        elif task_id == 1:
            x = self.final_1(x)
        elif task_id == 2:
            x = self.final_2(x)
        elif task_id == 3:
            x = self.final_3(x)
        elif task_id == 4:
            x = self.final_4(x)
        else:
            assert False, 'Bad Task ID passed'
        
        return x

In [9]:
model = MultiTask_Network(total_price_ds.X.shape[1],
                          output_dim_0 = total_price_ds.y.shape[1],
                          output_dim_1 = unit_price_ds.y.shape[1],
                          output_dim_2 = creation_date_ds.y.shape[1],
                          output_dim_3 = purchase_date_ds.y.shape[1],
                          output_dim_4 = quantity_ds.y.shape[1])
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
loss_fn_0 = torch.nn.CrossEntropyLoss()
loss_fn_1 = torch.nn.CrossEntropyLoss()
loss_fn_2 = torch.nn.CrossEntropyLoss()
loss_fn_3 = torch.nn.CrossEntropyLoss()
loss_fn_4 = torch.nn.CrossEntropyLoss()

In [21]:
print(total_price_ds.X.shape[1],
                          total_price_ds.y.shape[1],
                          unit_price_ds.y.shape[1],
                          creation_date_ds.y.shape[1],
                          purchase_date_ds.y.shape[1],
                          quantity_ds.y.shape[1])

1536 13 13 39 75 19


In [10]:
for i in range(10):
    zipped_dls = zip(total_price_dl, unit_price_dl, creation_date_dl, purchase_date_dl, quantity_dl)
    for j, ((batch_X_0, batch_y_0), (batch_X_1, batch_y_1), (batch_X_2, batch_y_2), (batch_X_3, batch_y_3), (batch_X_4, batch_y_4)) in (pbar := tqdm(enumerate(zipped_dls), unit="batch", total=len(total_price_dl))):
        preds_0 = model(batch_X_0, task_id=0)
        loss_0 = loss_fn_0(preds_0, batch_y_0.float())
        preds_1 = model(batch_X_1, task_id=1)
        loss_1 = loss_fn_1(preds_1, batch_y_1.float())
        preds_2 = model(batch_X_2, task_id=2)
        loss_2 = loss_fn_2(preds_2, batch_y_2.float())
        preds_3 = model(batch_X_3, task_id=3)
        loss_3 = loss_fn_3(preds_3, batch_y_3.float())
        preds_4 = model(batch_X_4, task_id=4)
        loss_4 = loss_fn_4(preds_4, batch_y_4.float())
        loss = loss_0 + loss_1 + loss_2 + loss_3 + loss_4
        pbar.set_description("Loss = " + f"{loss:.4f}"[:6])
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    n_correct = [0]*5
    n_test = [0]*5
    zipped_dls = zip(total_price_dl, unit_price_dl, creation_date_dl, purchase_date_dl, quantity_dl)
    for j, ((batch_X_0, batch_y_0), (batch_X_1, batch_y_1), (batch_X_2, batch_y_2), (batch_X_3, batch_y_3), (batch_X_4, batch_y_4)) in (pbar := tqdm(enumerate(zipped_dls), unit="batch", total=len(total_price_dl))):
        preds_0 = model(batch_X_0, task_id=0)
        preds_1 = model(batch_X_1, task_id=1)
        preds_2 = model(batch_X_2, task_id=2)
        preds_3 = model(batch_X_3, task_id=3)
        preds_4 = model(batch_X_4, task_id=4)
        n_correct[0] += (preds_0.detach().numpy().argmax(axis=1) == batch_y_0.numpy().argmax(axis=1)).sum()
        n_correct[1] += (preds_1.detach().numpy().argmax(axis=1) == batch_y_1.numpy().argmax(axis=1)).sum()
        n_correct[2] += (preds_2.detach().numpy().argmax(axis=1) == batch_y_2.numpy().argmax(axis=1)).sum()
        n_correct[3] += (preds_3.detach().numpy().argmax(axis=1) == batch_y_3.numpy().argmax(axis=1)).sum()
        n_correct[4] += (preds_4.detach().numpy().argmax(axis=1) == batch_y_4.numpy().argmax(axis=1)).sum()
        n_test[0] += preds_0.shape[0]
        n_test[1] += preds_1.shape[0]
        n_test[2] += preds_2.shape[0]
        n_test[3] += preds_3.shape[0]
        n_test[4] += preds_4.shape[0]
    print(f'Epoch {i} accuracy: total_price={n_correct[0]/n_test[0]*100} unit_price={n_correct[1]/n_test[1]*100} creation_date={n_correct[2]/n_test[2]*100} purchase_date={n_correct[3]/n_test[3]*100} quantity_dl={n_correct[4]/n_test[4]*100}')


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 0 accuracy: total_price=61.706345755873606 unit_price=58.213484513851455 creation_date=19.404784677900864 purchase_date=23.401967433583177 quantity_dl=70.27082644963451


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 1 accuracy: total_price=67.05708107982773 unit_price=65.15283260583506 creation_date=34.28255917082587 purchase_date=37.233419705638546 quantity_dl=73.801366885206


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 2 accuracy: total_price=70.74587412830635 unit_price=70.68529757867705 creation_date=46.162519056976734 purchase_date=46.863931736894884 quantity_dl=76.2650064634309


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 3 accuracy: total_price=73.0724194100018 unit_price=73.35182512216754 creation_date=55.674776388478286 purchase_date=54.33861421722924 quantity_dl=77.93158617811244


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 4 accuracy: total_price=72.87011112463698 unit_price=74.3851045452701 creation_date=59.40008927070471 purchase_date=56.67269533763456 quantity_dl=80.22248114591123


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 5 accuracy: total_price=76.9559849051354 unit_price=78.03998631955434 creation_date=63.42451698172269 purchase_date=60.96754372235651 quantity_dl=81.21895089531561


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 6 accuracy: total_price=79.01095015332533 unit_price=80.43841190894389 creation_date=67.3585146282223 purchase_date=63.69667669512895 quantity_dl=82.85335837550505


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 7 accuracy: total_price=78.11041742749654 unit_price=79.71265267319386 creation_date=70.28010132804665 purchase_date=66.7069544197694 quantity_dl=83.5927400889229


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 8 accuracy: total_price=81.64501562237334 unit_price=83.48549930728252 creation_date=72.06986302163945 purchase_date=68.1845584868036 quantity_dl=84.76050524900151


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 9 accuracy: total_price=82.62583401445721 unit_price=85.1770052576967 creation_date=73.283422893878 purchase_date=69.9418581059539 quantity_dl=85.9453709661525


In [12]:
for i in range(5):
    zipped_dls = zip(total_price_dl, unit_price_dl, creation_date_dl, purchase_date_dl, quantity_dl)
    for j, ((batch_X_0, batch_y_0), (batch_X_1, batch_y_1), (batch_X_2, batch_y_2), (batch_X_3, batch_y_3), (batch_X_4, batch_y_4)) in (pbar := tqdm(enumerate(zipped_dls), unit="batch", total=len(total_price_dl))):
        preds_0 = model(batch_X_0, task_id=0)
        loss_0 = loss_fn_0(preds_0, batch_y_0.float())
        preds_1 = model(batch_X_1, task_id=1)
        loss_1 = loss_fn_1(preds_1, batch_y_1.float())
        preds_2 = model(batch_X_2, task_id=2)
        loss_2 = loss_fn_2(preds_2, batch_y_2.float())
        preds_3 = model(batch_X_3, task_id=3)
        loss_3 = loss_fn_3(preds_3, batch_y_3.float())
        preds_4 = model(batch_X_4, task_id=4)
        loss_4 = loss_fn_4(preds_4, batch_y_4.float())
        loss = loss_0 + loss_1 + loss_2 + loss_3 + loss_4
        pbar.set_description("Loss = " + f"{loss:.4f}"[:6])
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    n_correct = [0]*5
    n_test = [0]*5
    zipped_dls = zip(total_price_dl, unit_price_dl, creation_date_dl, purchase_date_dl, quantity_dl)
    for j, ((batch_X_0, batch_y_0), (batch_X_1, batch_y_1), (batch_X_2, batch_y_2), (batch_X_3, batch_y_3), (batch_X_4, batch_y_4)) in (pbar := tqdm(enumerate(zipped_dls), unit="batch", total=len(total_price_dl))):
        preds_0 = model(batch_X_0, task_id=0)
        preds_1 = model(batch_X_1, task_id=1)
        preds_2 = model(batch_X_2, task_id=2)
        preds_3 = model(batch_X_3, task_id=3)
        preds_4 = model(batch_X_4, task_id=4)
        n_correct[0] += (preds_0.detach().numpy().argmax(axis=1) == batch_y_0.numpy().argmax(axis=1)).sum()
        n_correct[1] += (preds_1.detach().numpy().argmax(axis=1) == batch_y_1.numpy().argmax(axis=1)).sum()
        n_correct[2] += (preds_2.detach().numpy().argmax(axis=1) == batch_y_2.numpy().argmax(axis=1)).sum()
        n_correct[3] += (preds_3.detach().numpy().argmax(axis=1) == batch_y_3.numpy().argmax(axis=1)).sum()
        n_correct[4] += (preds_4.detach().numpy().argmax(axis=1) == batch_y_4.numpy().argmax(axis=1)).sum()
        n_test[0] += preds_0.shape[0]
        n_test[1] += preds_1.shape[0]
        n_test[2] += preds_2.shape[0]
        n_test[3] += preds_3.shape[0]
        n_test[4] += preds_4.shape[0]
    print(f'Epoch {i} accuracy: total_price={n_correct[0]/n_test[0]*100} unit_price={n_correct[1]/n_test[1]*100} creation_date={n_correct[2]/n_test[2]*100} purchase_date={n_correct[3]/n_test[3]*100} quantity_dl={n_correct[4]/n_test[4]*100}')


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 0 accuracy: total_price=83.2182668730327 unit_price=85.8424777837678 creation_date=76.51513834060832 purchase_date=72.47332023256757 quantity_dl=87.11226660637996


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 1 accuracy: total_price=84.58370287927006 unit_price=86.97691134955278 creation_date=75.0126080378415 purchase_date=70.77109020398936 quantity_dl=87.5661559686741


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 2 accuracy: total_price=85.35438730732889 unit_price=87.4803633433618 creation_date=78.44315369053209 purchase_date=73.99961741126549 quantity_dl=88.37104151087769


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 3 accuracy: total_price=85.58712878748355 unit_price=88.78638215977136 creation_date=80.49637989901977 purchase_date=76.18385127732466 quantity_dl=87.81918624535533


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 4 accuracy: total_price=86.94328991530877 unit_price=89.32316574787403 creation_date=79.60599157145424 purchase_date=75.30708542742698 quantity_dl=88.77710728135924


In [13]:
for i in range(5):
    zipped_dls = zip(total_price_dl, unit_price_dl, creation_date_dl, purchase_date_dl, quantity_dl)
    for j, ((batch_X_0, batch_y_0), (batch_X_1, batch_y_1), (batch_X_2, batch_y_2), (batch_X_3, batch_y_3), (batch_X_4, batch_y_4)) in (pbar := tqdm(enumerate(zipped_dls), unit="batch", total=len(total_price_dl))):
        preds_0 = model(batch_X_0, task_id=0)
        loss_0 = loss_fn_0(preds_0, batch_y_0.float())
        preds_1 = model(batch_X_1, task_id=1)
        loss_1 = loss_fn_1(preds_1, batch_y_1.float())
        preds_2 = model(batch_X_2, task_id=2)
        loss_2 = loss_fn_2(preds_2, batch_y_2.float())
        preds_3 = model(batch_X_3, task_id=3)
        loss_3 = loss_fn_3(preds_3, batch_y_3.float())
        preds_4 = model(batch_X_4, task_id=4)
        loss_4 = loss_fn_4(preds_4, batch_y_4.float())
        loss = loss_0 + loss_1 + loss_2 + loss_3 + loss_4
        pbar.set_description("Loss = " + f"{loss:.4f}"[:6])
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    n_correct = [0]*5
    n_test = [0]*5
    zipped_dls = zip(total_price_dl, unit_price_dl, creation_date_dl, purchase_date_dl, quantity_dl)
    for j, ((batch_X_0, batch_y_0), (batch_X_1, batch_y_1), (batch_X_2, batch_y_2), (batch_X_3, batch_y_3), (batch_X_4, batch_y_4)) in (pbar := tqdm(enumerate(zipped_dls), unit="batch", total=len(total_price_dl))):
        preds_0 = model(batch_X_0, task_id=0)
        preds_1 = model(batch_X_1, task_id=1)
        preds_2 = model(batch_X_2, task_id=2)
        preds_3 = model(batch_X_3, task_id=3)
        preds_4 = model(batch_X_4, task_id=4)
        n_correct[0] += (preds_0.detach().numpy().argmax(axis=1) == batch_y_0.numpy().argmax(axis=1)).sum()
        n_correct[1] += (preds_1.detach().numpy().argmax(axis=1) == batch_y_1.numpy().argmax(axis=1)).sum()
        n_correct[2] += (preds_2.detach().numpy().argmax(axis=1) == batch_y_2.numpy().argmax(axis=1)).sum()
        n_correct[3] += (preds_3.detach().numpy().argmax(axis=1) == batch_y_3.numpy().argmax(axis=1)).sum()
        n_correct[4] += (preds_4.detach().numpy().argmax(axis=1) == batch_y_4.numpy().argmax(axis=1)).sum()
        n_test[0] += preds_0.shape[0]
        n_test[1] += preds_1.shape[0]
        n_test[2] += preds_2.shape[0]
        n_test[3] += preds_3.shape[0]
        n_test[4] += preds_4.shape[0]
    print(f'Epoch {i} accuracy: total_price={n_correct[0]/n_test[0]*100} unit_price={n_correct[1]/n_test[1]*100} creation_date={n_correct[2]/n_test[2]*100} purchase_date={n_correct[3]/n_test[3]*100} quantity_dl={n_correct[4]/n_test[4]*100}')


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 0 accuracy: total_price=86.19637236318106 unit_price=88.48002133222035 creation_date=82.46381348219514 purchase_date=78.29475563593785 quantity_dl=89.9019761287817


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 1 accuracy: total_price=87.62412395874998 unit_price=89.7034357627718 creation_date=83.09015761496502 purchase_date=78.79125147093775 quantity_dl=90.61034496750894


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 2 accuracy: total_price=85.27207276142114 unit_price=87.87599487562969 creation_date=84.5677616819992 purchase_date=80.32595400819667 quantity_dl=90.86018700473598


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 3 accuracy: total_price=86.71721475401284 unit_price=89.91936652580446 creation_date=83.93764962987437 purchase_date=79.44686943869596 quantity_dl=91.31610524668278


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 4 accuracy: total_price=88.17047226521515 unit_price=91.02539577645224 creation_date=85.99667263736964 purchase_date=80.98562973526019 quantity_dl=91.54044136827643


In [14]:
for i in range(10):
    zipped_dls = zip(total_price_dl, unit_price_dl, creation_date_dl, purchase_date_dl, quantity_dl)
    for j, ((batch_X_0, batch_y_0), (batch_X_1, batch_y_1), (batch_X_2, batch_y_2), (batch_X_3, batch_y_3), (batch_X_4, batch_y_4)) in (pbar := tqdm(enumerate(zipped_dls), unit="batch", total=len(total_price_dl))):
        preds_0 = model(batch_X_0, task_id=0)
        loss_0 = loss_fn_0(preds_0, batch_y_0.float())
        preds_1 = model(batch_X_1, task_id=1)
        loss_1 = loss_fn_1(preds_1, batch_y_1.float())
        preds_2 = model(batch_X_2, task_id=2)
        loss_2 = loss_fn_2(preds_2, batch_y_2.float())
        preds_3 = model(batch_X_3, task_id=3)
        loss_3 = loss_fn_3(preds_3, batch_y_3.float())
        preds_4 = model(batch_X_4, task_id=4)
        loss_4 = loss_fn_4(preds_4, batch_y_4.float())
        loss = loss_0 + loss_1 + loss_2 + loss_3 + loss_4
        pbar.set_description("Loss = " + f"{loss:.4f}"[:6])
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    n_correct = [0]*5
    n_test = [0]*5
    zipped_dls = zip(total_price_dl, unit_price_dl, creation_date_dl, purchase_date_dl, quantity_dl)
    for j, ((batch_X_0, batch_y_0), (batch_X_1, batch_y_1), (batch_X_2, batch_y_2), (batch_X_3, batch_y_3), (batch_X_4, batch_y_4)) in (pbar := tqdm(enumerate(zipped_dls), unit="batch", total=len(total_price_dl))):
        preds_0 = model(batch_X_0, task_id=0)
        preds_1 = model(batch_X_1, task_id=1)
        preds_2 = model(batch_X_2, task_id=2)
        preds_3 = model(batch_X_3, task_id=3)
        preds_4 = model(batch_X_4, task_id=4)
        n_correct[0] += (preds_0.detach().numpy().argmax(axis=1) == batch_y_0.numpy().argmax(axis=1)).sum()
        n_correct[1] += (preds_1.detach().numpy().argmax(axis=1) == batch_y_1.numpy().argmax(axis=1)).sum()
        n_correct[2] += (preds_2.detach().numpy().argmax(axis=1) == batch_y_2.numpy().argmax(axis=1)).sum()
        n_correct[3] += (preds_3.detach().numpy().argmax(axis=1) == batch_y_3.numpy().argmax(axis=1)).sum()
        n_correct[4] += (preds_4.detach().numpy().argmax(axis=1) == batch_y_4.numpy().argmax(axis=1)).sum()
        n_test[0] += preds_0.shape[0]
        n_test[1] += preds_1.shape[0]
        n_test[2] += preds_2.shape[0]
        n_test[3] += preds_3.shape[0]
        n_test[4] += preds_4.shape[0]
    print(f'Epoch {i} accuracy: total_price={n_correct[0]/n_test[0]*100} unit_price={n_correct[1]/n_test[1]*100} creation_date={n_correct[2]/n_test[2]*100} purchase_date={n_correct[3]/n_test[3]*100} quantity_dl={n_correct[4]/n_test[4]*100}')


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 0 accuracy: total_price=90.52745074170043 unit_price=92.92964425044491 creation_date=87.35457280489713 purchase_date=82.15078633578537 quantity_dl=92.39865746134984


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 1 accuracy: total_price=89.36461286077828 unit_price=91.74564805314506 creation_date=86.76185010637127 purchase_date=82.29802503057812 quantity_dl=91.81868772064065


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 2 accuracy: total_price=85.39467506043164 unit_price=89.35417862256462 creation_date=87.38587551953812 purchase_date=82.84321397724177 quantity_dl=92.28851828020566


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 3 accuracy: total_price=86.20361836194054 unit_price=85.73900492148235 creation_date=86.96183967213304 purchase_date=82.47772579981334 quantity_dl=91.17118527149309


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 4 accuracy: total_price=91.16451895263435 unit_price=93.78959938322059 creation_date=89.08143922925761 purchase_date=84.68224846239907 quantity_dl=92.56618495266913


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 5 accuracy: total_price=90.02950570694863 unit_price=93.23107779883948 creation_date=89.73097055805785 purchase_date=83.9837341819847 quantity_dl=93.78872986336944


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 6 accuracy: total_price=92.54792503579523 unit_price=93.88553640679616 creation_date=90.11037105310447 purchase_date=85.76103275771119 quantity_dl=91.97607081369668


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 7 accuracy: total_price=92.30619851717881 unit_price=94.66433635346561 creation_date=89.90922212754117 purchase_date=85.0926618321363 quantity_dl=93.7936571425259


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 8 accuracy: total_price=92.61951550353893 unit_price=94.59477476537455 creation_date=91.01583105808973 purchase_date=86.14304181231124 quantity_dl=94.08030885345113


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 9 accuracy: total_price=92.00447512883386 unit_price=94.37971352219304 creation_date=91.89955306679651 purchase_date=87.03806757908283 quantity_dl=94.39043760035707


In [15]:
for i in range(10):
    zipped_dls = zip(total_price_dl, unit_price_dl, creation_date_dl, purchase_date_dl, quantity_dl)
    for j, ((batch_X_0, batch_y_0), (batch_X_1, batch_y_1), (batch_X_2, batch_y_2), (batch_X_3, batch_y_3), (batch_X_4, batch_y_4)) in (pbar := tqdm(enumerate(zipped_dls), unit="batch", total=len(total_price_dl))):
        preds_0 = model(batch_X_0, task_id=0)
        loss_0 = loss_fn_0(preds_0, batch_y_0.float())
        preds_1 = model(batch_X_1, task_id=1)
        loss_1 = loss_fn_1(preds_1, batch_y_1.float())
        preds_2 = model(batch_X_2, task_id=2)
        loss_2 = loss_fn_2(preds_2, batch_y_2.float())
        preds_3 = model(batch_X_3, task_id=3)
        loss_3 = loss_fn_3(preds_3, batch_y_3.float())
        preds_4 = model(batch_X_4, task_id=4)
        loss_4 = loss_fn_4(preds_4, batch_y_4.float())
        loss = loss_0 + loss_1 + loss_2 + loss_3 + loss_4
        pbar.set_description("Loss = " + f"{loss:.4f}"[:6])
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    n_correct = [0]*5
    n_test = [0]*5
    zipped_dls = zip(total_price_dl, unit_price_dl, creation_date_dl, purchase_date_dl, quantity_dl)
    for j, ((batch_X_0, batch_y_0), (batch_X_1, batch_y_1), (batch_X_2, batch_y_2), (batch_X_3, batch_y_3), (batch_X_4, batch_y_4)) in (pbar := tqdm(enumerate(zipped_dls), unit="batch", total=len(total_price_dl))):
        preds_0 = model(batch_X_0, task_id=0)
        preds_1 = model(batch_X_1, task_id=1)
        preds_2 = model(batch_X_2, task_id=2)
        preds_3 = model(batch_X_3, task_id=3)
        preds_4 = model(batch_X_4, task_id=4)
        n_correct[0] += (preds_0.detach().numpy().argmax(axis=1) == batch_y_0.numpy().argmax(axis=1)).sum()
        n_correct[1] += (preds_1.detach().numpy().argmax(axis=1) == batch_y_1.numpy().argmax(axis=1)).sum()
        n_correct[2] += (preds_2.detach().numpy().argmax(axis=1) == batch_y_2.numpy().argmax(axis=1)).sum()
        n_correct[3] += (preds_3.detach().numpy().argmax(axis=1) == batch_y_3.numpy().argmax(axis=1)).sum()
        n_correct[4] += (preds_4.detach().numpy().argmax(axis=1) == batch_y_4.numpy().argmax(axis=1)).sum()
        n_test[0] += preds_0.shape[0]
        n_test[1] += preds_1.shape[0]
        n_test[2] += preds_2.shape[0]
        n_test[3] += preds_3.shape[0]
        n_test[4] += preds_4.shape[0]
    print(f'Epoch {i} accuracy: total_price={n_correct[0]/n_test[0]*100} unit_price={n_correct[1]/n_test[1]*100} creation_date={n_correct[2]/n_test[2]*100} purchase_date={n_correct[3]/n_test[3]*100} quantity_dl={n_correct[4]/n_test[4]*100}')


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 0 accuracy: total_price=93.06297062761942 unit_price=95.7663078448081 creation_date=91.79608020451107 purchase_date=87.76092841532905 quantity_dl=94.52985061648957


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 1 accuracy: total_price=91.4578369824183 unit_price=93.86901552962455 creation_date=91.67579662510363 purchase_date=87.58209716594496 quantity_dl=93.6023627752755


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 2 accuracy: total_price=91.73463413503063 unit_price=95.5031331698636 creation_date=91.95636169707088 purchase_date=87.840344561733 quantity_dl=93.73568915245002


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 3 accuracy: total_price=92.89109553704445 unit_price=95.3428516773038 creation_date=93.23194731869063 purchase_date=88.98550220568202 quantity_dl=95.38574798995995


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 4 accuracy: total_price=94.38232208174647 unit_price=96.37989901976128 creation_date=93.45802247998655 purchase_date=89.51098203571988 quantity_dl=95.57124555820276


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 5 accuracy: total_price=93.72032763507991 unit_price=95.40690630633763 creation_date=93.79713522193045 purchase_date=89.66343784961944 quantity_dl=95.11329843660332


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 6 accuracy: total_price=94.40956703708213 unit_price=96.8975531711389 creation_date=93.49425247378397 purchase_date=89.55764626773096 quantity_dl=95.60370763264525


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 7 accuracy: total_price=95.03504165000088 unit_price=96.62365441803037 creation_date=94.42579807430337 purchase_date=89.89646916972448 quantity_dl=95.63414082743509


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 8 accuracy: total_price=94.91620727034532 unit_price=96.2657020793118 creation_date=93.1490530928821 purchase_date=89.44055092777768 quantity_dl=96.06571251355001


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 9 accuracy: total_price=94.38319160159759 unit_price=96.49409596021077 creation_date=94.4394205519712 purchase_date=90.00776771067017 quantity_dl=95.30720136340712


In [16]:
torch.save(model.state_dict(), './ckpt')

In [17]:
for i in range(10):
    zipped_dls = zip(total_price_dl, unit_price_dl, creation_date_dl, purchase_date_dl, quantity_dl)
    for j, ((batch_X_0, batch_y_0), (batch_X_1, batch_y_1), (batch_X_2, batch_y_2), (batch_X_3, batch_y_3), (batch_X_4, batch_y_4)) in (pbar := tqdm(enumerate(zipped_dls), unit="batch", total=len(total_price_dl))):
        preds_0 = model(batch_X_0, task_id=0)
        loss_0 = loss_fn_0(preds_0, batch_y_0.float())
        preds_1 = model(batch_X_1, task_id=1)
        loss_1 = loss_fn_1(preds_1, batch_y_1.float())
        preds_2 = model(batch_X_2, task_id=2)
        loss_2 = loss_fn_2(preds_2, batch_y_2.float())
        preds_3 = model(batch_X_3, task_id=3)
        loss_3 = loss_fn_3(preds_3, batch_y_3.float())
        preds_4 = model(batch_X_4, task_id=4)
        loss_4 = loss_fn_4(preds_4, batch_y_4.float())
        loss = loss_0 + loss_1 + loss_2 + loss_3 + loss_4
        pbar.set_description("Loss = " + f"{loss:.4f}"[:6])
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    n_correct = [0]*5
    n_test = [0]*5
    zipped_dls = zip(total_price_dl, unit_price_dl, creation_date_dl, purchase_date_dl, quantity_dl)
    for j, ((batch_X_0, batch_y_0), (batch_X_1, batch_y_1), (batch_X_2, batch_y_2), (batch_X_3, batch_y_3), (batch_X_4, batch_y_4)) in (pbar := tqdm(enumerate(zipped_dls), unit="batch", total=len(total_price_dl))):
        preds_0 = model(batch_X_0, task_id=0)
        preds_1 = model(batch_X_1, task_id=1)
        preds_2 = model(batch_X_2, task_id=2)
        preds_3 = model(batch_X_3, task_id=3)
        preds_4 = model(batch_X_4, task_id=4)
        n_correct[0] += (preds_0.detach().numpy().argmax(axis=1) == batch_y_0.numpy().argmax(axis=1)).sum()
        n_correct[1] += (preds_1.detach().numpy().argmax(axis=1) == batch_y_1.numpy().argmax(axis=1)).sum()
        n_correct[2] += (preds_2.detach().numpy().argmax(axis=1) == batch_y_2.numpy().argmax(axis=1)).sum()
        n_correct[3] += (preds_3.detach().numpy().argmax(axis=1) == batch_y_3.numpy().argmax(axis=1)).sum()
        n_correct[4] += (preds_4.detach().numpy().argmax(axis=1) == batch_y_4.numpy().argmax(axis=1)).sum()
        n_test[0] += preds_0.shape[0]
        n_test[1] += preds_1.shape[0]
        n_test[2] += preds_2.shape[0]
        n_test[3] += preds_3.shape[0]
        n_test[4] += preds_4.shape[0]
    print(f'Epoch {i} accuracy: total_price={n_correct[0]/n_test[0]*100} unit_price={n_correct[1]/n_test[1]*100} creation_date={n_correct[2]/n_test[2]*100} purchase_date={n_correct[3]/n_test[3]*100} quantity_dl={n_correct[4]/n_test[4]*100}')


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 0 accuracy: total_price=95.83905767235332 unit_price=97.56302569721001 creation_date=94.5350677355964 purchase_date=91.48595145760511 quantity_dl=96.76596583366664


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 1 accuracy: total_price=94.24725666486967 unit_price=96.28946895524292 creation_date=94.706652986221 purchase_date=91.34856732112527 quantity_dl=95.91528557930312


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 2 accuracy: total_price=92.58647374919569 unit_price=94.30899257430048 creation_date=94.9515677442916 purchase_date=91.50189265487599 quantity_dl=96.41844773316176


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 3 accuracy: total_price=94.73679634106047 unit_price=97.10073097635485 creation_date=93.96350335344823 purchase_date=91.07930600722281 quantity_dl=96.18338753340406


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 4 accuracy: total_price=94.95736454329919 unit_price=97.39955596519601 creation_date=95.05504060657705 purchase_date=92.26330220452267 quantity_dl=96.82654238329594


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 5 accuracy: total_price=96.27149887831939 unit_price=98.0879258473471 creation_date=95.62573546887408 purchase_date=92.63893478021437 quantity_dl=97.44245227785217


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 6 accuracy: total_price=95.4509619787953 unit_price=96.60075706195039 creation_date=95.84746303091433 purchase_date=92.71110492785883 quantity_dl=95.9871658869972


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 7 accuracy: total_price=93.87973960778858 unit_price=96.29149783489557 creation_date=95.55153644157696 purchase_date=92.71400332736263 quantity_dl=97.17319096394971


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 8 accuracy: total_price=96.48424140189788 unit_price=97.82822925180716 creation_date=95.61906915001536 purchase_date=92.60792190552377 quantity_dl=97.44708971705825


  0%|          | 0/5391 [00:00<?, ?batch/s]

  0%|          | 0/5391 [00:00<?, ?batch/s]

Epoch 9 accuracy: total_price=96.79958726791065 unit_price=97.95894706942826 creation_date=95.08576364131727 purchase_date=92.85544522314778 quantity_dl=97.53433154212244


In [18]:
torch.save(model.state_dict(), './ckpt2')