In [1]:
import os,sys
import time
import logging
import json
from typing import Optional, Union, List, Dict, Tuple

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

# Single global seed for reproducible runs; the call's return value (the seed) is the cell output.
SEED = 42
seed_everything(SEED)

Seed set to 42


42

In [2]:
## import recformer library
from recformer import RecformerForPretraining, RecformerTokenizer, RecformerConfig, LitWrapper
from collator import PretrainDataCollatorWithPadding
from lightning_dataloader import ClickDataset

In [3]:
## arguments
# Central configuration: every path and hyper-parameter consumed by the cells below.

args = {
    # --- model / checkpoints ---
    "model_name_or_path": "../longformer-base-4096",
    "longformer_ckpt": '../longformer_ckpt/longformer-base-4096.bin',
    # Lightning checkpoint passed to `trainer.fit(ckpt_path=...)`; None starts a fresh run.
    "ckpt": None,

    # --- data ---
    "train_file": "../pretrain_data/train.json",
    "dev_file": "../pretrain_data/dev.json",
    "item_attr_file": "../pretrain_data/meta_data.json",

    # --- optimisation ---
    "batch_size": 2,
    "learning_rate": 5e-5,
    "num_train_epochs": 32,
    "mlm_probability": 0.15,           # passed to PretrainDataCollatorWithPadding
    "gradient_accumulation_steps": 8,  # Trainer(accumulate_grad_batches=...)

    # --- runtime ---
    "valid_step": 2000,                # Trainer(val_check_interval=...)
    "log_step": 2000,                  # Trainer(log_every_n_steps=...)
    "device": -1,                      # Trainer(devices=...)
    "preprocessing_num_workers": 0,
    "dataloader_num_workers": 0,
    "fp16": True,                      # selects Trainer precision 16 vs 32
    "fix_word_embedding": True,        # freeze the word-embedding table when True
    "temp": 0.05,                      # NOTE(review): not read in this notebook; presumably a contrastive temperature — confirm

    "output_dir": "../result/recformer_pretraining",
}

In [4]:
## Tokenize meta_data
# Reference notes only: the string below is not assigned or used by any code. It
# records the dataset sizes and a sample record (raw item attributes and their
# tokenized form) so readers know the data format. Because it is the cell's last
# expression, Jupyter echoes it as the cell output.
'''
# train.json  3,501,527
# dev.json  112,379

['B000001FLX', 'B000002GIF', 'B000058TD8', '6303234844', '0001393774', 'B0001JXLBK']
['0001393774', 'B0001JXLBK', 'B0016CP2GS', 'B0016CP2GS', 'B0076FJ2R4']

# meta_data.json  1,022,274

item_attrs['B000001FLX']

{'title': 'Best of Bee Gees, Vol. 1',
 'brand': 'Bee Gees',
 'category': 'CDs & Vinyl Pop Oldies Baroque Pop'}


tokenized_item['B000001FLX']

input_ids: [14691, 19183, 9, 15227, 4177, 293, 6, 5896, 4, 112, 11638, 40613, 4177, 293, 42747, 11579, 29, 359, 40236, 7975, 3470, 918, 1731, 139, 3407, 7975]
token_type_ids: [1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]

'''


"\n# train.json  3,501,527\n# dev.json  112,379\n\n['B000001FLX', 'B000002GIF', 'B000058TD8', '6303234844', '0001393774', 'B0001JXLBK']\n['0001393774', 'B0001JXLBK', 'B0016CP2GS', 'B0016CP2GS', 'B0076FJ2R4']\n\n# meta_data.json  1,022,274\n\nitem_attrs['B000001FLX']\n\n{'title': 'Best of Bee Gees, Vol. 1',\n 'brand': 'Bee Gees',\n 'category': 'CDs & Vinyl Pop Oldies Baroque Pop'}\n\n\ntokenized_item['B000001FLX']\n\ninput_ids: [14691, 19183, 9, 15227, 4177, 293, 6, 5896, 4, 112, 11638, 40613, 4177, 293, 42747, 11579, 29, 359, 40236, 7975, 3470, 918, 1731, 139, 3407, 7975]\ntoken_type_ids: [1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]\n\n"

In [5]:
## load the Recformer config and the Longformer-based tokenizer
config = RecformerConfig.from_pretrained(args['model_name_or_path'])
config.max_attr_num = 3
config.max_attr_length = 32
config.max_item_embeddings = 51  # 50 items + 1 for cls
# Longformer takes one local-attention window size per layer. Derive the list length
# from the config rather than hard-coding 12, so it always matches the model depth.
config.attention_window = [64] * config.num_hidden_layers
config.max_token_num = 1024

tokenizer = RecformerTokenizer.from_pretrained(args['model_name_or_path'], config)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LongformerTokenizer'. 
The class this function is called from is 'RecformerTokenizer'.


In [6]:
## tokenize meta_data.json (result is cached next to the source file)
path_tokenized_items = args['item_attr_file'] + '.tokenized'

if os.path.exists(path_tokenized_items):
    print(f'[Preprocessor] Use cache: {path_tokenized_items}')
    # NOTE: torch.load unpickles the file — only point this at caches this notebook wrote.
    tokenized_items = torch.load(path_tokenized_items)

else:
    # Use a context manager so the metadata file handle is closed instead of leaked;
    # JSON is UTF-8 by specification, so do not rely on the platform default encoding.
    with open(args['item_attr_file'], encoding='utf-8') as f:
        item_attrs = json.load(f)

    # item_id -> [input_ids, token_type_ids], as returned by tokenizer.encode_item
    tokenized_items = {}
    for item_id, item_attr in item_attrs.items():
        input_ids, token_type_ids = tokenizer.encode_item(item_attr)
        tokenized_items[item_id] = [input_ids, token_type_ids]

    torch.save(tokenized_items, path_tokenized_items)

[Preprocessor] Use cache: ../pretrain_data/meta_data.json.tokenized


In [7]:
## load data

data_collator = PretrainDataCollatorWithPadding(tokenizer, tokenized_items, mlm_probability=args['mlm_probability'])

# Read the click-sequence files inside context managers so their handles are closed promptly.
with open(args['train_file'], encoding='utf-8') as f:
    train_sequences = json.load(f)
train_data = ClickDataset(train_sequences, data_collator)

with open(args['dev_file'], encoding='utf-8') as f:
    dev_sequences = json.load(f)
dev_data = ClickDataset(dev_sequences, data_collator)

train_loader = DataLoader(train_data,
                          batch_size=args['batch_size'],
                          shuffle=True,
                          collate_fn=train_data.collate_fn,
                          num_workers=args['dataloader_num_workers'])

# No shuffling for evaluation, so validation order is stable across runs.
dev_loader = DataLoader(dev_data,
                        batch_size=args['batch_size'],
                        collate_fn=dev_data.collate_fn,
                        num_workers=args['dataloader_num_workers'])

In [8]:
# Preview the first few raw dev samples (item-id sequences, before collation).
# The loop variable is not called `batch`, so it cannot be mistaken for a collated batch.
n_preview = 3
for step, sequence in enumerate(dev_data):
    print(step, sequence)
    if step + 1 >= n_preview:
        break

0 ['B000001FLX', 'B000002GIF', 'B000058TD8', '6303234844', '0001393774', 'B0001JXLBK']
1 ['0001393774', 'B0001JXLBK', 'B0016CP2GS', 'B0016CP2GS', 'B0076FJ2R4']
2 ['B000008GO6', '0001393774', 'B0001JXLBK', 'B0002EZZMC', 'B0007OQA3A', 'B0000996GP', 'B000K6DPZ6']


In [9]:
# Take one collated batch from the dev loader for a smoke-test forward pass later on.
batch = next(iter(dev_loader))

In [10]:
## load checkpoint

pytorch_model = RecformerForPretraining(config)
# map_location='cpu': training in this notebook uses the CPU accelerator, and this also
# lets a checkpoint that was saved from GPU tensors load on a machine without CUDA.
pytorch_model.load_state_dict(torch.load(args['longformer_ckpt'], map_location='cpu'))

if args['fix_word_embedding']:
    print('Fix word embeddings.')
    # Exclude every parameter of the word-embedding table from gradient computation.
    pytorch_model.longformer.embeddings.word_embeddings.requires_grad_(False)

Fix word embeddings.


In [None]:
# Smoke-test one forward pass. no_grad avoids building an autograd graph that is never used.
with torch.no_grad():
    smoke_outputs = pytorch_model(**batch)
smoke_outputs

In [None]:

# Wrap the plain PyTorch module so it can be handed to the Lightning Trainer.
model = LitWrapper(
    pytorch_model,
    learning_rate=args['learning_rate'],
)

# Keep the 5 best checkpoints ranked by the "accuracy" metric (higher is better).
# NOTE(review): this assumes LitWrapper logs a metric named "accuracy" — confirm.
checkpoint_callback = ModelCheckpoint(
    filename="{epoch}-{accuracy:.4f}",
    monitor="accuracy",
    mode="max",
    save_top_k=5,
)

In [None]:

# Lightning's CPU accelerator only accepts a positive int for `devices`; a value <= 0
# (e.g. -1, the GPU "use all devices" convention) is mapped to a single process.
n_devices = args['device']
if isinstance(n_devices, int) and n_devices <= 0:
    n_devices = 1
# The plain 'ddp' strategy launches subprocesses, which Lightning refuses inside an
# interactive notebook. 'ddp_notebook' forks instead, and a single process needs no DDP.
train_strategy = 'auto' if n_devices == 1 else 'ddp_notebook'

trainer = Trainer(accelerator="cpu",
                  max_epochs=args['num_train_epochs'],
                  devices=n_devices,
                  accumulate_grad_batches=args['gradient_accumulation_steps'],
                  val_check_interval=args['valid_step'],  # validate every N training batches
                  default_root_dir=args['output_dir'],
                  gradient_clip_val=1.0,
                  log_every_n_steps=args['log_step'],
                  # NOTE(review): fp16 mixed precision targets GPUs; on CPU, 'bf16-mixed' is the
                  # supported AMP mode — confirm against the installed Lightning version.
                  precision=16 if args['fp16'] else 32,
                  strategy=train_strategy,
                  callbacks=[checkpoint_callback]
                  )

# `ckpt` is optional: when it is missing or None, training starts fresh instead of resuming.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=dev_loader,
            ckpt_path=args.get('ckpt'))
