In [None]:
import sys

# Extend the import path with the parent directory before the project imports
# below (presumably where `common_utils` and the `lstm` package live — confirm).
sys.path.append("..")
from common_utils import set_data_home

# NOTE(review): DATA_HOME is imported only after set_data_home() has been
# called, which suggests the call determines its value — confirm before
# reordering or merging these two `common_utils` imports.
set_data_home("~/datasets")
from common_utils import DATA_HOME, join
from lstm.sales_data import Sales_Dataset
import torch
import random

# Prefer the GPU whenever torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Directories below DATA_HOME for the sales data and for model files.
SALE_HOME = join(DATA_HOME, "sales_data")
MODEL_HOME = join(DATA_HOME, "sale_model")

# Fixed seeds for Python's `random` module and for torch.
random.seed(42)
torch.manual_seed(42)

# Model and data dimensions shared by the cells below.
I = 71  # input feature size of the LSTM
H = 36  # LSTM hidden size, also the Transformer d_model
B = 4  # appears in the checkpoint file name read by the inference cell
TARGET_DIM = 33  # output features of the final linear layer
HEAD = 4  # number of Transformer attention heads
SEQ_LEN = 516  # seq_len handed to Sales_Dataset
INFER_DAYS = 16  # number of days the inference loop iterates over

In [None]:
from torch.nn import LSTM, Transformer, Linear
from torch.nn import MSELoss
from torch import optim
import matplotlib.pyplot as plt
from IPython.display import clear_output


class Predictor(torch.nn.Module):
    """LSTM followed by a Transformer, a ReLU and a linear output layer.

    The layer sizes come from the module-level constants ``I`` (LSTM input
    size), ``H`` (LSTM hidden size and Transformer ``d_model``), ``HEAD``
    (attention heads) and ``TARGET_DIM`` (output features).

    The constructor does not pick a device. Every submodule is created like
    any other parameter and follows the module when the caller moves it with
    ``.cuda()`` or ``.to(...)``, so the class can also be built on machines
    without CUDA.
    """

    def __init__(self):
        super().__init__()
        # Two stacked LSTM layers; batched inputs are (batch, seq, feature).
        self.lstm = LSTM(I, H, num_layers=2, batch_first=True)
        self.trans = Transformer(
            d_model=H,
            nhead=HEAD,
            num_encoder_layers=3,
            num_decoder_layers=3,
            batch_first=True,
        )
        self.relu = torch.nn.ReLU()
        self.linear = Linear(H, TARGET_DIM)

    def forward(self, x1, x2):
        """Run the network.

        Args:
            x1: input sequence for the LSTM; its last dimension must be ``I``.
            x2: target sequence for the Transformer; its last dimension must
                be ``H``.

        Returns:
            Tensor with ``TARGET_DIM`` features in its last dimension and the
            sequence length of ``x2``.
        """
        # Only the per-step LSTM outputs are used; the final (h, c) state is dropped.
        h, (_, _) = self.lstm(x1)
        # LSTM outputs act as the Transformer source, x2 as its target.
        h = self.trans(h, x2)
        h = self.relu(h)
        return self.linear(h)


# Place the network on the device chosen in the configuration cell
# ("cuda" when available, otherwise "cpu") instead of always requiring a GPU.
model = Predictor().to(device)

In [None]:
# Build the dataset from SALE_HOME with training mode switched off.
sd = Sales_Dataset(
    SALE_HOME,
    seq_len=SEQ_LEN,
    is_train=False,
)

In [None]:
# Shape of the first element of each of the first three dataset items.
print(*(sd[k][0].shape for k in range(3)))

### Perform Inference

In [None]:
import pandas as pd
from datetime import timedelta

# NOTE(review): the checkpoint is read from the current working directory even
# though MODEL_HOME is defined in the configuration cell — confirm the intended
# location. torch.load unpickles the file, so only load trusted checkpoints.
# map_location="cpu" lets the file open regardless of the device it was saved
# from; load_state_dict then copies the weights onto the model's own device.
model.load_state_dict(torch.load(f"sales_model_1000_{B}.pth", map_location="cpu"))
# Inference mode: disables the Transformer's dropout so predictions are
# deterministic and not perturbed by training-time noise.
model.eval()
# Inputs follow whichever device the model actually lives on.
model_device = next(model.parameters()).device

sales = pd.read_csv(join(SALE_HOME, "test.csv"), index_col=0)
sales["sales"] = 0.0

# Row masks that do not depend on the store are computed once up front
# instead of once per (store, day, family) cell.
forecast_dates = [
    (sd.train_max_date + timedelta(days=i + 1)).strftime("%Y-%m-%d")
    for i in range(INFER_DAYS)
]
date_rows = [sales.date == ts for ts in forecast_dates]
family_rows = {f: sales.family == f for f in sd.families}

# No gradients are needed for prediction; skip building autograd graphs.
with torch.no_grad():
    for X1, X2, base_sales, store_id in sd:
        X1 = X1.to(model_device)
        X2 = X2.to(model_device)
        base_sales = base_sales.to(model_device)
        yhat = model(X1, X2)
        store_rows = sales.store_nbr == store_id

        # The last INFER_DAYS steps of the output are the days being predicted.
        for i in range(INFER_DAYS):
            # Convert the predicted values back to sales using the matching base sales.
            yhat_rets = yhat[-INFER_DAYS + i]
            tmp_base_sales = base_sales[-INFER_DAYS + i]
            # One device-to-host transfer per day instead of one per family.
            curr_sales = sd.ret_2_sale(yhat_rets, tmp_base_sales).cpu()
            day_rows = store_rows & date_rows[i]

            # Write each family's value into the answer dataframe.
            for j, f in enumerate(sd.families):
                sales.loc[day_rows & family_rows[f], "sales"] = curr_sales[j].item()

### Output the Answer

In [None]:
# Bare expression: the notebook renders the `sales` frame as this cell's output.
sales

In [None]:
# Drop the feature columns, order the rows by id, and write the result
# (index included) to answer.csv.
(
    sales
    .drop(columns=["store_nbr", "date", "family", "onpromotion"])
    .sort_values("id")
    .sort_index()
    .to_csv("answer.csv", index=True)
)