In [1]:
import os
import argparse
import wandb
import torch
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pathlib import Path
from datetime import datetime
from models.GTM import GTM
from models.FCN import FCN
from utils.data_multitrends import ZeroShotDataset


In [2]:
# Force TOKENIZERS_PARALLELISM to "false" for this process.
# NOTE(review): presumably this silences the Hugging Face tokenizers
# fork/parallelism warning when DataLoader workers start — confirm.
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [3]:
# General arguments
class Args(argparse.Namespace):
    """Notebook stand-in for the training script's command-line arguments.

    Every path and hyperparameter is stored as a plain attribute holding the
    default assigned in ``__init__``. Individual defaults can be replaced at
    construction time, e.g. ``Args(epochs=5, model_type='FCN')``, so the
    class body does not have to be edited for each experiment.
    """

    def __init__(self, **overrides):
        """Populate the defaults, then apply ``overrides`` on top of them.

        Args:
            **overrides: attribute names mapped to replacement values. Only
                names that already exist as defaults are accepted.

        Raises:
            TypeError: if ``overrides`` contains a name that is not one of
                the defaults (guards against typos such as ``epoch=5``
                silently creating an attribute nothing reads).
        """
        super().__init__()

        # General arguments
        self.data_folder = 'dataset/'
        self.img_root = 'dataset/images/'
        self.pop_path = 'signals/pop.pt'

        self.log_dir = 'log'
        self.ckpt_dir = 'ckpt'
        self.seed = 21
        self.epochs = 1
        self.gpu_num = 0

        # Model specific arguments
        self.use_trends = 1
        self.use_img = 1
        self.use_text = 1
        self.num_trends = 3
        self.trend_len = 52
        self.decoder_input_type = 3
        self.batch_size = 128
        self.embedding_dim = 32
        self.hidden_dim = 64
        self.output_dim = 12
        self.use_encoder_mask = 1
        self.autoregressive = 0
        self.num_attn_heads = 4
        self.num_hidden_layers = 1
        self.model_type = 'GTM'

        # wandb arguments
        self.wandb_entity = 'irshadgirachirshu'
        self.wandb_proj = 'Apparel.Ai'
        self.wandb_run = 'Run1'

        # Overrides are applied last so they win over the defaults above;
        # anything that is not a known default is rejected up front.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError("Unknown argument(s): " + ", ".join(unknown))
        for name, value in overrides.items():
            setattr(self, name, value)


In [4]:
# Build the run configuration using the defaults defined in Args
args = Args()

In [5]:
print(args)
# Seeds for reproducibility (By default we use the number 21)
pl.seed_everything(args.seed)

# Resolve the dataset directory once. Joining with `/` works whether or not
# args.data_folder ends with a path separator, unlike string concatenation.
data_dir = Path(args.data_folder)
img_dir = data_dir / 'images'

# Load sales data
train_df = pd.read_csv(data_dir / 'train.csv', parse_dates=['release_date'])
test_df = pd.read_csv(data_dir / 'test.csv', parse_dates=['release_date'])

# Load category and color encodings
# NOTE(review): torch.load unpickles the file contents; only load label
# files that come from a trusted source.
cat_dict = torch.load(data_dir / 'category_labels.pt')
col_dict = torch.load(data_dir / 'color_labels.pt')
fab_dict = torch.load(data_dir / 'fabric_labels.pt')

# Load Google trends
gtrends = pd.read_csv(data_dir / 'gtrends.csv', index_col=[0], parse_dates=True)

# Build the loaders: the training loader uses the configured batch size,
# the test loader yields one sample per batch.
train_loader = ZeroShotDataset(
    train_df, img_dir, gtrends, cat_dict, col_dict, fab_dict, args.trend_len
).get_loader(batch_size=args.batch_size, train=True)
test_loader = ZeroShotDataset(
    test_df, img_dir, gtrends, cat_dict, col_dict, fab_dict, args.trend_len
).get_loader(batch_size=1, train=False)



Global seed set to 21


Args(autoregressive=0, batch_size=128, ckpt_dir='ckpt', data_folder='dataset/', decoder_input_type=3, embedding_dim=32, epochs=1, gpu_num=0, hidden_dim=64, img_root='dataset/images/', log_dir='log', model_type='GTM', num_attn_heads=4, num_hidden_layers=1, num_trends=3, output_dim=12, pop_path='signals/pop.pt', seed=21, trend_len=52, use_encoder_mask=1, use_img=1, use_text=1, use_trends=1, wandb_entity='irshadgirachirshu', wandb_proj='Apparel.Ai', wandb_run='Run1')
Starting dataset creation process...


100%|#######################################| 5080/5080 [01:14<00:00, 68.34it/s]


5080
Done.
Starting dataset creation process...


100%|#########################################| 497/497 [00:09<00:00, 51.43it/s]

497
Done.





In [6]:
# Create model
# Keyword arguments that both architectures receive unchanged
shared_model_kwargs = dict(
    embedding_dim=args.embedding_dim,
    hidden_dim=args.hidden_dim,
    output_dim=args.output_dim,
    cat_dict=cat_dict,
    col_dict=col_dict,
    fab_dict=fab_dict,
    use_text=args.use_text,
    use_img=args.use_img,
    trend_len=args.trend_len,
    num_trends=args.num_trends,
    use_encoder_mask=args.use_encoder_mask,
    gpu_num=args.gpu_num,
)

if args.model_type == 'FCN':
    # FCN additionally takes the use_trends switch
    model = FCN(use_trends=args.use_trends, **shared_model_kwargs)
else:
    # Any other model_type falls back to GTM, which additionally takes the
    # attention-head count, layer count and autoregressive flag
    model = GTM(
        num_heads=args.num_attn_heads,
        num_layers=args.num_hidden_layers,
        autoregressive=args.autoregressive,
        **shared_model_kwargs,
    )




In [7]:
# Model Training
# Timestamp embedded in the checkpoint filename
dt_string = datetime.now().strftime("%d-%m-%Y-%H-%M-%S")

# Checkpoint name prefix: "<model_type>_<wandb_run>"
model_savename = '_'.join([args.model_type, args.wandb_run])

# Keep only the single checkpoint with the lowest monitored 'val_mae'.
# NOTE(review): checkpoints are written under args.log_dir/<model_type>;
# args.ckpt_dir is not used here — confirm that is intended.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath='/'.join([args.log_dir, args.model_type]),
    filename='---'.join([model_savename, '{epoch}', dt_string]),
    monitor='val_mae',
    mode='min',
    save_top_k=1,
)

# Start the wandb run first, then create the Lightning logger and have it
# watch the model.
wandb.init(
    entity=args.wandb_entity,
    project=args.wandb_proj,
    name=args.wandb_run,
)
wandb_logger = pl_loggers.WandbLogger()
wandb_logger.watch(model)


[34m[1mwandb[0m: Currently logged in as: [33mirshadgirachirshu[0m. Use [1m`wandb login --relogin`[0m to force relogin


  rank_zero_warn(
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [8]:
# Configure the Lightning trainer: CPU execution, args.epochs epochs at most,
# validation after every epoch, wandb logging and best-checkpoint saving.
trainer = pl.Trainer(
    accelerator='cpu',
    max_epochs=args.epochs,
    check_val_every_n_epoch=1,
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


In [9]:
# Fit model
trainer.fit(model, train_dataloaders=train_loader,
            val_dataloaders=test_loader)

# Print out path of best model
print(checkpoint_callback.best_model_path)


  | Name                   | Type               | Params
--------------------------------------------------------------
0 | dummy_encoder          | DummyEmbedder      | 4.4 K 
1 | image_encoder          | ImageEmbedder      | 23.5 M
2 | text_encoder           | TextEmbedder       | 24.6 K
3 | gtrend_encoder         | GTrendEmbedder     | 562 K 
4 | static_feature_encoder | FusionNetwork      | 81.2 K
5 | decoder_linear         | TimeDistributed    | 128   
6 | decoder                | TransformerDecoder | 50.0 K
7 | decoder_fc             | Sequential         | 780   
--------------------------------------------------------------
723 K     Trainable params
23.5 M    Non-trainable params
24.2 M    Total params
96.927    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Validation MAE: 364.158 LR: None


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


Validation MAE: 319.17078 LR: tensor(3.4259e-06)
/Users/irshad/Dev/PyEnv/GTM-Transformer/log/GTM/GTM_Run1---epoch=0---28-09-2023-23-07-37.ckpt
