In [1]:
import argparse
import torch
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from tqdm import tqdm
from models.GTM import GTM
from models.FCN import FCN
from utils.data_multitrends import ZeroShotDataset
from pathlib import Path
from sklearn.metrics import mean_absolute_error
from pathlib import Path

In [2]:
def cal_error_metrics(gt, forecasts):
    # Absolute errors
    mae = mean_absolute_error(gt, forecasts)
    wape = 100 * np.sum(np.sum(np.abs(gt - forecasts), axis=-1)) / np.sum(gt)

    return round(mae, 3), round(wape, 3)

In [3]:
def print_error_metrics(y_test, y_hat, rescaled_y_test, rescaled_y_hat):
    mae, wape = cal_error_metrics(y_test, y_hat)
    rescaled_mae, rescaled_wape = cal_error_metrics(rescaled_y_test, rescaled_y_hat)
    print(mae, wape, rescaled_mae, rescaled_wape)

In [4]:
def run(args):    
    # Set up device
    device = 'cpu'

    # Seeds for reproducibility
    pl.seed_everything(args.seed)

    # Load sales data    
    print(Path(args.data_folder + '/test.csv'))
    test_df = pd.read_csv(Path(args.data_folder + '/test.csv'), parse_dates=['release_date'])
    item_codes = test_df['external_code'].values

     # Load category and color encodings
    cat_dict = torch.load(Path(args.data_folder + '/category_labels.pt'))
    col_dict = torch.load(Path(args.data_folder + '/color_labels.pt'))
    fab_dict = torch.load(Path(args.data_folder + '/fabric_labels.pt'))

    # Load Google trends
    gtrends = pd.read_csv(Path(args.data_folder + '/gtrends.csv'), index_col=[0], parse_dates=True)
    
    test_loader = ZeroShotDataset(test_df, Path(args.data_folder + '/images'), gtrends, cat_dict, col_dict, \
            fab_dict, args.trend_len).get_loader(batch_size=1, train=False)


    model_savename = f'{args.wandb_run}_{args.output_dim}'
    print("Model save name:  ",model_savename)
    
    # Create model
    model = None
    if args.model_type == 'FCN':
        model = FCN(
            embedding_dim=args.embedding_dim,
            hidden_dim=args.hidden_dim,
            output_dim=args.output_dim,
            cat_dict=cat_dict,
            col_dict=col_dict,
            fab_dict=fab_dict,
            use_trends=args.use_trends,
            use_text=args.use_text,
            use_img=args.use_img,
            trend_len=args.trend_len,
            num_trends=args.num_trends,
            use_encoder_mask=args.use_encoder_mask,
            gpu_num=args.gpu_num
        )
    else:
        model = GTM(
            embedding_dim=args.embedding_dim,
            hidden_dim=args.hidden_dim,
            output_dim=args.output_dim,
            num_heads=args.num_attn_heads,
            num_layers=args.num_hidden_layers,
            cat_dict=cat_dict,
            col_dict=col_dict,
            fab_dict=fab_dict,
            use_text=args.use_text,
            use_img=args.use_img,
            trend_len=args.trend_len,
            num_trends=args.num_trends,
            use_encoder_mask=args.use_encoder_mask,
            autoregressive=args.autoregressive,
            gpu_num=args.gpu_num
        )
    print('Loading: ', args.ckpt_path)
    #model.load_state_dict(torch.load(args.ckpt_path)['state_dict'], strict=False)

    # Forecast the testing set
    model.to(device)
    model.eval()
    gt, forecasts, attns = [], [],[]
    for test_data in tqdm(test_loader, total=len(test_loader), ascii=True):
        with torch.no_grad():
            test_data = [tensor.to(device) for tensor in test_data]
            item_sales, category, color, textures, temporal_features, gtrends, images =  test_data
            y_pred, att = model(category, color,textures, temporal_features, gtrends, images)
            forecasts.append(y_pred.detach().cpu().numpy().flatten()[:args.output_dim])
            gt.append(item_sales.detach().cpu().numpy().flatten()[:args.output_dim])
            attns.append(att.detach().cpu().numpy())

    attns = np.stack(attns)
    forecasts = np.array(forecasts)
    gt = np.array(gt)

    rescale_vals = np.load(args.data_folder + 'normalization_scale.npy')
    rescaled_forecasts = forecasts * rescale_vals
    rescaled_gt = gt * rescale_vals
    print_error_metrics(gt, forecasts, rescaled_gt, rescaled_forecasts)

    print("Saving results:   ", Path('results/' + model_savename+'.pth'))
    torch.save({'results': forecasts* rescale_vals, 'gts': gt* rescale_vals, 'codes': item_codes.tolist()}, Path('results/' + model_savename+'.pth'))

In [5]:
class Args():
    def __init__(self):
        self.data_folder = 'dataset/'
        self.ckpt_path = 'log/path-to-model.ckpt'
        self.gpu_num = 0
        self.seed = 21
    
        self.model_type = 'GTM'
        self.use_trends = 1
        self.use_img = 1
        self.use_text = 1
        self.trend_len = 52
        self.num_trends = 2
        self.embedding_dim =32 
        self.hidden_dim = 64
        self.output_dim = 12
        self.use_encoder_mask =1
        self.autoregressive = 0
        self.num_attn_heads = 4
        self.num_hidden_layers = 1
    
        self.wandb_run = 'Run1'

In [6]:
args = Args()

In [7]:
run(args)

Global seed set to 21


dataset/test.csv
Starting dataset creation process...


100%|###################################################| 497/497 [00:08<00:00, 55.57it/s]


Done.
Model save name:   Run1_12
Loading:  log/path-to-model.ckpt


  0%|                                                             | 0/497 [00:01<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (52x3 and 2x64)