### Importing Dependencies

In [1]:
import os
import argparse
import wandb
import torch
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pathlib import Path
from datetime import datetime
from models.GTM_POP import GTM
from utils.data import POPDataset

# Turn off Hugging Face `tokenizers` internal parallelism for this process.
# NOTE(review): presumably because the model's TextEmbedder uses a HF tokenizer
# alongside forked DataLoader workers, which otherwise triggers fork warnings — confirm.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

### Configuration and dataset loading

In [2]:
class Args(argparse.Namespace):
    """Paths and hyper-parameters for a GTM training run with POP signals.

    Stands in for command-line parsing inside the notebook. Every attribute
    below has a default, and any of them can be overridden by keyword, e.g.
    ``Args(epochs=200, batch_size=64)``. Calling ``Args()`` with no arguments
    gives exactly the defaults listed here.

    Raises:
        TypeError: if an override names an attribute that has no default,
            so a typo such as ``Args(epoch=200)`` fails loudly instead of
            being silently ignored.
    """

    def __init__(self, **overrides):
        super().__init__()

        # General arguments
        self.data_folder = 'dataset/'
        self.img_root = 'dataset/images/'
        self.pop_path = 'signals/pop.pt'

        # NOTE(review): log_dir is only referenced by the commented-out
        # TensorBoard logger in this notebook.
        self.log_dir = 'log'
        self.ckpt_dir = 'ckpt'
        self.seed = 21
        # 1 epoch = quick smoke-test run; the previously commented-out value
        # here was 200. Use Args(epochs=200) for a full run.
        self.epochs = 1
        self.gpu_num = 0

        # Model specific arguments
        # NOTE(review): use_trends is not passed to GTM anywhere in this notebook.
        self.use_trends = 1
        self.num_trends = 1
        self.trend_len = 52
        self.decoder_input_type = 3
        self.batch_size = 128
        self.embedding_dim = 32
        self.hidden_dim = 64
        self.output_dim = 12
        self.use_encoder_mask = 1
        self.autoregressive = 0
        self.num_attn_heads = 4
        self.num_hidden_layers = 1

        # wandb arguments
        self.wandb_entity = 'irshadgirachirshu'
        self.wandb_proj = 'Apparel.Ai'
        self.wandb_run = 'POP1'

        # Apply keyword overrides only to attributes that already have a default.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(f"Unknown Args override(s): {', '.join(unknown)}")
        for name, value in overrides.items():
            setattr(self, name, value)


In [3]:
# Build the run configuration with all default values.
args = Args()

In [4]:
# Seeds for reproducibility (by default args.seed is 21)
pl.seed_everything(args.seed)

# Join with `/` instead of string concatenation so data_folder works
# with or without a trailing slash.
data_folder = Path(args.data_folder)

# Load sales data
train_df = pd.read_csv(data_folder / 'train.csv', parse_dates=['release_date'])
test_df = pd.read_csv(data_folder / 'test.csv', parse_dates=['release_date'])

# Load category, color and fabric encodings.
# NOTE(review): torch.load unpickles the file, which can execute code —
# only load .pt files from a trusted source.
cat_dict = torch.load(data_folder / 'category_labels.pt')
col_dict = torch.load(data_folder / 'color_labels.pt')
fab_dict = torch.load(data_folder / 'fabric_labels.pt')

pop_signal = torch.load(args.pop_path)


def build_loader(df, batch_size, train):
    """Wrap `df` in a POPDataset that shares this notebook's POP signal and
    label encodings, and return the loader from its `get_loader`."""
    dataset = POPDataset(df, args.img_root, pop_signal, cat_dict, col_dict,
                         fab_dict, args.trend_len)
    return dataset.get_loader(batch_size=batch_size, train=train)


train_loader = build_loader(train_df, batch_size=args.batch_size, train=True)
# Evaluation loader uses batch_size=1, as in the original configuration.
test_loader = build_loader(test_df, batch_size=1, train=False)


Global seed set to 21


Starting dataset creation process...


100%|#######################################| 5080/5080 [01:10<00:00, 72.16it/s]


Done.
Starting dataset creation process...


100%|#########################################| 497/497 [00:09<00:00, 54.23it/s]

Done.





In [5]:
# Create model: the GTM forecaster, configured entirely from `args` plus the
# category / color / fabric encodings loaded earlier.
model = GTM(
    # learned sizes
    embedding_dim=args.embedding_dim,
    hidden_dim=args.hidden_dim,
    output_dim=args.output_dim,
    # transformer depth / width
    num_layers=args.num_hidden_layers,
    num_heads=args.num_attn_heads,
    # categorical vocabularies
    cat_dict=cat_dict,
    col_dict=col_dict,
    fab_dict=fab_dict,
    # POP trend signal
    num_trends=args.num_trends,
    trend_len=args.trend_len,
    # decoder behaviour
    decoder_input_type=args.decoder_input_type,
    autoregressive=args.autoregressive,
    use_encoder_mask=args.use_encoder_mask,
    # device index handed to the model
    gpu_num=args.gpu_num,
)


# Model Training
# Checkpoint names combine the run name with a launch timestamp.
dt_string = datetime.strftime(datetime.now(), "%d-%m-%Y-%H-%M-%S")

# NOTE(review): this produces 'POP_POP1', but the checkpoint path recorded in
# this notebook's output begins with 'GTM_POP1', so that output predates this
# line — confirm which prefix is intended.
model_savename = 'POP_' + args.wandb_run



In [6]:
# Keep only the single best checkpoint, ranked by lowest validation MAE.
# The filename template is '<savename>---{epoch}---<timestamp>'.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_mae',
    mode='min',
    save_top_k=1,
    dirpath=args.ckpt_dir,
    filename='---'.join([model_savename, '{epoch}', dt_string]),
)

# Start the W&B run with entity/project/name taken from Args.
# WandbLogger() is given no project or name of its own; it presumably
# attaches to this already-active run.
wandb.init(
    entity=args.wandb_entity,
    project=args.wandb_proj,
    name=args.wandb_run,
)
wandb_logger = pl_loggers.WandbLogger()
# Track the model in W&B (per the output below, this also logs the model graph).
wandb_logger.watch(model)

[34m[1mwandb[0m: Currently logged in as: [33mirshadgirachirshu[0m. Use [1m`wandb login --relogin`[0m to force relogin


  rank_zero_warn(
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [7]:
# Trainer: args.epochs epochs on CPU, validating after every epoch, logging to
# W&B and checkpointing through checkpoint_callback.
# NOTE(review): accelerator is fixed to 'cpu' although args.gpu_num is passed
# to the model and the output below reports an available (unused) MPS device.
# To log to TensorBoard instead of W&B, use:
#   logger=pl_loggers.TensorBoardLogger(args.log_dir + '/', name=model_savename)
trainer = pl.Trainer(
    accelerator='cpu',
    max_epochs=args.epochs,
    check_val_every_n_epoch=1,
    callbacks=[checkpoint_callback],
    logger=wandb_logger,
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


In [8]:
# Fit model. The test loader doubles as the validation loader, so the
# checkpoint callback ranks epochs by test-set val_mae.
# NOTE(review): selecting checkpoints on the test split leaks test information
# into model selection — consider a separate validation split.
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=test_loader,
)

# Show where the best checkpoint was written
print(checkpoint_callback.best_model_path)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name                   | Type               | Params
--------------------------------------------------------------
0 | dummy_encoder          | DummyEmbedder      | 4.4 K 
1 | image_encoder          | ImageEmbedder      | 23.5 M
2 | text_encoder           | TextEmbedder       | 24.6 K
3 | POP_encoder            | POPEmbedder        | 562 K 
4 | static_feature_encoder | FusionNetwork      | 81.2 K
5 | decoder_linear         | TimeDistributed    | 128   
6 | decoder                | TransformerDecoder | 50.0 K
7 | decoder_fc             | Sequential         | 780   
--------------------------------------------------------------
723 K     Trainable params
23.5 M    Non-trainable params
24.2 M    Total params
96.926    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Validation MAE: 618.0101 LR: None


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


Validation MAE: 534.8379 LR: tensor(2.4075e-06)
/Users/irshad/Dev/PyEnv/GTM-Transformer/ckpt/GTM_POP1---epoch=0---02-10-2023-18-03-34.ckpt
