In [1]:
import os
import json
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast
from pytorch_lightning import seed_everything

## import recformer library
from utils import read_json, AverageMeterSet, Ranker
from optimization import create_optimizer_and_scheduler
from recformer import RecformerModel, RecformerForSeqRec, RecformerTokenizer, RecformerConfig
from collator import FinetuneDataCollatorWithPadding, EvalDataCollatorWithPadding
from dataloader import RecformerTrainDataset, RecformerEvalDataset

seed_everything(42)

Seed set to 42


42

In [2]:
def load_data(args):

    train = read_json(args['train_file'], True)
    val = read_json(args['dev_file'], True)
    test = read_json(args['test_file'], True)
    item_meta_dict = json.load(open( args['meta_file'] ))
    
    item2id = read_json(args['item2id_file'])
    id2item = {v:k for k, v in item2id.items()}

    item_meta_dict_filted = dict()
    for k, v in item_meta_dict.items():
        if k in item2id:
            item_meta_dict_filted[k] = v

    return train, val, test, item_meta_dict_filted, item2id, id2item

In [3]:

def encode_all_items(model: RecformerModel, tokenizer: RecformerTokenizer, tokenized_items, args):

    model.eval()

    items = sorted(list(tokenized_items.items()), key=lambda x: x[0])
    items = [ele[1] for ele in items]

    item_embeddings = []

    with torch.no_grad():
        for i in tqdm(range(0, len(items), args['batch_size']), ncols=100, desc='Encode all items'):

            item_batch = [[item] for item in items[i:i+args['batch_size']]]

            inputs = tokenizer.batch_encode(item_batch, encode_item=False)

            for k, v in inputs.items():
                inputs[k] = torch.LongTensor(v).to(args['device'])

            outputs = model(**inputs)

            item_embeddings.append(outputs.pooler_output.detach())

    item_embeddings = torch.cat(item_embeddings, dim=0)#.cpu()

    return item_embeddings

In [7]:
args = {
  "model_name_or_path": "../longformer-base-4096"
, "longformer_ckpt": '../longformer_ckpt/longformer-base-4096.bin'
, "pretrain_ckpt":"../pretrain_ckpt/recformer_seqrec_ckpt.bin"
, "ckpt":"best_model.bin"
    
, "train_file": "../finetune_data/Scientific/train.json"
, "dev_file": "../finetune_data/Scientific/val.json"
, "test_file": "../finetune_data/Scientific/test.json"
, "item2id_file": "../finetune_data/Scientific/smap.json"
, "meta_file" : "../finetune_data/Scientific/meta_data.json"

, "batch_size" : 2
, "finetune_negative_sample_size":-1
, "metric_ks":[10,50]
, "learning_rate":5e-5
, "weight_decay":0
, "warmup_steps":100
, "verbose":3
    
, "num_train_epochs": 32
, "gradient_accumulation_steps":8

, "preprocessing_num_workers" : 0
, "dataloader_num_workers": 0
, "device":-1
, "fp16":True
, "fix_word_embedding":True
, "temp" :0.05

, "output_dir": "../result/recformer_finetune"
}

args['device'] = torch.device('cuda:{}'.format(args['device'])) if args['device']>=0 else torch.device('cpu')

# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [8]:
## load raw data
train, val, test, item_meta_dict, item2id, id2item = load_data(args)

In [9]:
## load longformer tokenizer

config = RecformerConfig.from_pretrained(args['model_name_or_path'])
config.max_attr_num = 3
config.max_attr_length = 32
config.max_item_embeddings = 51
config.attention_window = [64] * 12
config.max_token_num = 1024
config.item_num = len(item2id)
config.finetune_negative_sample_size = args['finetune_negative_sample_size']

## load longformer tokenizer
tokenizer = RecformerTokenizer.from_pretrained(args['model_name_or_path'], config)  

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LongformerTokenizer'. 
The class this function is called from is 'RecformerTokenizer'.


In [10]:
path_tokenized_items = args['meta_file']+'.tokenized'

if os.path.exists(path_tokenized_items):
    print(f'[Preprocessor] Use cache: {path_tokenized_items}')
    tokenized_items = torch.load(path_tokenized_items)
else:
    tokenized_items = {}
    loop  = 0 
    
    for item_id, item_attr in item_meta_dict.items():
        input_ids, token_type_ids = tokenizer.encode_item(item_attr)
        tokenized_items[ item2id[item_id] ] = [input_ids, token_type_ids]
        if loop % 1000 == 0:
            print(time.ctime(), loop)
        loop+=1

    torch.save(tokenized_items, path_tokenized_items)

print("Item List:",len(tokenized_items)) # 5327

[Preprocessor] Use cache: ../finetune_data/Scientific/meta_data.json.tokenized
Item List: 5327


In [None]:
## load data

finetune_data_collator = FinetuneDataCollatorWithPadding(tokenizer, tokenized_items)
eval_data_collator = EvalDataCollatorWithPadding(tokenizer, tokenized_items)

train_data = RecformerTrainDataset(train, collator=finetune_data_collator)
val_data = RecformerEvalDataset(train, val, test, mode='val', collator=eval_data_collator)
test_data = RecformerEvalDataset(train, val, test, mode='test', collator=eval_data_collator)


train_loader = DataLoader(train_data, 
                          batch_size=args['batch_size'], 
                          shuffle=True, 
                          collate_fn=train_data.collate_fn)
dev_loader = DataLoader(val_data, 
                        batch_size=args['batch_size'],
                        collate_fn=val_data.collate_fn)
test_loader = DataLoader(test_data, 
                        batch_size=args['batch_size'],
                        collate_fn=test_data.collate_fn)

In [None]:
for step, batch in enumerate(dev_loader):
    print(step, batch)
    break

In [13]:
## load checkpoint

model = RecformerForSeqRec(config)
pretrain_ckpt = torch.load(args['pretrain_ckpt'])
model.load_state_dict(pretrain_ckpt, strict=False)
model.to(args['device'])

if args['fix_word_embedding']:
    print('Fix word embeddings.')
    for param in model.longformer.embeddings.word_embeddings.parameters():
        param.requires_grad = False

Fix word embeddings.


In [18]:
## item embedding

path_item_embeddings =  args['meta_file']+'.embedding'

try:
    print(f'[Item Embeddings] Use cache: {path_tokenized_items}')
    item_embeddings = torch.load(path_item_embeddings)
except:
    print(f'Encoding items.')
    item_embeddings = encode_all_items(model.longformer, tokenizer, tokenized_items, args)
    torch.save(item_embeddings, path_item_embeddings)

model.init_item_embedding(item_embeddings)
model.to(args['device'])

print(len(item_embeddings)) # 5327

[Item Embeddings] Use cache: ../finetune_data/Scientific/meta_data.json.tokenized
Initalize item embeddings from vectors.
5327


In [19]:
def eval(model, dataloader, args):

    model.eval()

    ranker = Ranker(args['metric_ks'])
    average_meter_set = AverageMeterSet()

    for batch, labels in tqdm(dataloader, ncols=100, desc='Evaluate'):

        for k, v in batch.items():
            batch[k] = v.to(args['device'])
            
        labels = labels.to(args['device'])

        with torch.no_grad():
            scores = model(**batch)

        res = ranker(scores, labels)

        metrics = {}
        for i, k in enumerate(args['metric_ks']):
            metrics["NDCG@%d" % k] = res[2*i]
            metrics["Recall@%d" % k] = res[2*i+1]
        metrics["MRR"] = res[-3]
        metrics["AUC"] = res[-2]

        for k, v in metrics.items():
            average_meter_set.update(k, v)

    average_metrics = average_meter_set.averages()

    return average_metrics

def train_one_epoch(model, dataloader, optimizer, scheduler, scaler, args):

    model.train()

    for step, batch in enumerate(tqdm(dataloader, ncols=100, desc='Training')):
        for k, v in batch.items():
            batch[k] = v.to(args['device'])

        if args['fp16']:
            with autocast():
                loss = model(**batch)
        else:
            loss = model(**batch)

        if args['gradient_accumulation_steps'] > 1:
            loss = loss / args['gradient_accumulation_steps']

        if args['fp16']:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        if (step + 1) % args['gradient_accumulation_steps'] == 0:
            if args['fp16']:
                scale_before = scaler.get_scale()
                scaler.step(optimizer)
                scaler.update()
                scale_after = scaler.get_scale()
                optimizer_was_run = scale_before <= scale_after
                optimizer.zero_grad()

                if optimizer_was_run:
                    scheduler.step()

            else:
                scheduler.step()  # Update learning rate schedule
                optimizer.step()
                optimizer.zero_grad()

In [None]:
num_train_optimization_steps = int(len(train_loader) / args['gradient_accumulation_steps']) * args['num_train_epochs']
optimizer, scheduler = create_optimizer_and_scheduler(model, num_train_optimization_steps, args)

if args['fp16']:
    scaler = torch.cuda.amp.GradScaler()
else:
    scaler = None

test_metrics = eval(model, test_loader, args)
print(f'Test set: {test_metrics}')

best_target = float('-inf')
patient = 5

In [None]:

test_data = RecformerEvalDataset(train, val, test, mode='test', collator=eval_data_collator)


train_loader = DataLoader(train_data, 
                          batch_size=args['batch_size'], 
                          shuffle=True, 
                          collate_fn=train_data.collate_fn)
dev_loader = DataLoader(val_data, 
                        batch_size=args['batch_size'],
                        collate_fn=val_data.collate_fn)
test_loader = DataLoader(test_data, 
                        batch_size=args['batch_size'],
                        collate_fn=test_data.collate_fn)

In [28]:
loop = 0

for batch, labels in test_data:
    print(batch), print(labels)
    if loop ==1:
        break
    else:
        loop+=1

[193, 346, 353, 1278, 186, 61, 237, 0, 1268]
[4173]
[2936, 4125, 4345, 751]
[0]


In [66]:
model.eval()

ranker = Ranker(args['metric_ks'])
average_meter_set = AverageMeterSet()

loop = 0

for batch, labels in test_loader:

    with torch.no_grad():
        scores = model(**batch) # [batch_size, 5327]

    res = ranker(scores, labels)

    metrics = {}
    for i, k in enumerate(args['metric_ks']):
        metrics["NDCG@%d" % k] = res[2*i]
        metrics["Recall@%d" % k] = res[2*i+1]
    metrics["MRR"] = res[-3]
    metrics["AUC"] = res[-2]

    for k, v in metrics.items():
        average_meter_set.update(k, v)

    break

# average_metrics = average_meter_set.averages()

labels: tensor([4173,    0])
predicts: tensor([[18.4771],
        [15.0274]])
tensor([   0., 1528.])
k: 10
indicator: tensor([1., 0.])
ncdg: tensor([1., 0.])
hr: tensor([1., 0.])
MRR: tensor([1.0000e+00, 6.5402e-04])
AUC: tensor([1.0000, 0.7132])
k: 50
indicator: tensor([1., 0.])
ncdg: tensor([1., 0.])
hr: tensor([1., 0.])
MRR: tensor([1.0000e+00, 6.5402e-04])
AUC: tensor([1.0000, 0.7132])


In [46]:
labels

tensor([[4173],
        [   0]])

In [57]:
predicts = scores[torch.arange(scores.size(0)), labels.squeeze()].unsqueeze(-1) 
print(predicts)

tensor([[18.4771],
        [15.0274]])


In [58]:
 (predicts < scores).sum(-1)

tensor([   0, 1528])

In [67]:
import torch.nn as nn

MAX_VAL = 1e4


class Ranker(nn.Module):
    def __init__(self, metrics_ks):
        super().__init__()
        self.ks = metrics_ks
        self.ce = nn.CrossEntropyLoss()
        
    def forward(self, scores, labels):
        labels = labels.squeeze()
        print("labels:", labels)
        
        try:
            loss = self.ce(scores, labels).item()
        except:
            print(scores.size())
            print(labels.size())
            loss = 0.0
        
        predicts = scores[torch.arange(scores.size(0)), labels].unsqueeze(-1) # gather perdicted values
        print("predicts:", predicts)
        
        valid_length = (scores > -MAX_VAL).sum(-1).float()
        rank = (predicts < scores).sum(-1).float()
        print("rank:", rank)
        res = []
        for k in self.ks:
            print("k:", k)
            indicator = (rank < k).float()
            print("indicator:", indicator)
            print("ncdg:", ((1 / torch.log2(rank+2)) * indicator))
            print("hr:", indicator)
            print("MRR:", (1 / (rank+1)))
            print("AUC:", (1 - (rank/valid_length)) )
            res.append(
                ((1 / torch.log2(rank+2)) * indicator).mean().item() # ndcg@k
            ) 
            res.append(
                indicator.mean().item() # hr@k
            )
        res.append((1 / (rank+1)).mean().item()) # MRR
        res.append((1 - (rank/valid_length)).mean().item()) # AUC

        return res + [loss]

In [None]:
labels: tensor([4173,    0])
predicts: tensor([[18.4771],
        [15.0274]])
rank in 5327 items: tensor([   0., 1528.])


k: 10
indicator: tensor([1., 0.])
ncdg: tensor([1., 0.])
hr: tensor([1., 0.])
MRR: tensor([1.0000e+00, 6.5402e-04])
AUC: tensor([1.0000, 0.7132])

k: 50
indicator: tensor([1., 0.])
ncdg: tensor([1., 0.])
hr: tensor([1., 0.])
MRR: tensor([1.0000e+00, 6.5402e-04])
AUC: tensor([1.0000, 0.7132])

In [None]:

for epoch in range(args['num_train_epochs']):

    item_embeddings = encode_all_items(model.longformer, tokenizer, tokenized_items, args)
    model.init_item_embedding(item_embeddings)

    train_one_epoch(model, train_loader, optimizer, scheduler, scaler, args)
    
    if (epoch + 1) % args['verbose'] == 0:
        dev_metrics = eval(model, dev_loader, args)
        print(f'Epoch: {epoch}. Dev set: {dev_metrics}')

        if dev_metrics['NDCG@10'] > best_target:
            print('Save the best model.')
            best_target = dev_metrics['NDCG@10']
            patient = 5
            torch.save(model.state_dict(), path_ckpt)
        
        else:
            patient -= 1
            if patient == 0:
                break

print('Load best model in stage 1.')
model.load_state_dict(torch.load(path_ckpt))

patient = 3

for epoch in range(args['num_train_epochs']):

    train_one_epoch(model, train_loader, optimizer, scheduler, scaler, args)
    
    if (epoch + 1) % args['verbose'] == 0:
        dev_metrics = eval(model, dev_loader, args)
        print(f'Epoch: {epoch}. Dev set: {dev_metrics}')

        if dev_metrics['NDCG@10'] > best_target:
            print('Save the best model.')
            best_target = dev_metrics['NDCG@10']
            patient = 3
            torch.save(model.state_dict(), path_ckpt)
        
        else:
            patient -= 1
            if patient == 0:
                break

print('Test with the best checkpoint.')  
model.load_state_dict(torch.load(path_ckpt))
test_metrics = eval(model, test_loader, args)
print(f'Test set: {test_metrics}')


In [None]:
import torch
from recformer import RecformerModel, RecformerConfig, RecformerForSeqRec

config = RecformerConfig.from_pretrained('longformer-base-4096')
config.max_attr_num = 3  # max number of attributes for each item
config.max_attr_length = 32 # max number of tokens for each attribute
config.max_item_embeddings = 51 # max number of items in a sequence +1 for cls token
config.attention_window = [64] * 12 # attention window for each layer

model = RecformerModel(config) 
model.load_state_dict(torch.load('recformer_ckpt.bin')) # RecformerModel = recformer_ckpt.bin

rec_model = RecformerForSeqRec(config)
rec_model.load_state_dict(torch.load('recformer_seqrec_ckpt.bin'), strict=False)
# strict=False because RecformerForSeqRec doesn't have lm_head

# missing_keys=[]
# unexpected_keys=['lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.bias'])
# lm_head.decoder.bias torch.Size([50265]) is missing

In [None]:
_ = torch.load('recformer_seqrec_ckpt.bin')
for k, v in _.items():
    print(k, v.size())

In [None]:
for name, param in model.state_dict().items():
    if not torch.all(param == rec_model.state_dict()[ 'longformer.'+name]):
        print(name)
    else:
        print("Match:", name)