In [19]:
import torch
import torch.nn as nn
import sys
import os

# Root of the project checkout. It is both added to the import path (so the
# `model_ddp_trainer` package can be imported) and made the working directory.
PROJECT_ROOT = '/root/projects/PythonProjects/ip-dual-encoder-factorization-machine'

# Guarded so that re-running this cell does not stack duplicate sys.path entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
os.chdir(PROJECT_ROOT)

# Model

In [2]:
# Hyper-parameters that are unpacked as keyword arguments into the model constructor.
# Trailing comments list alternative values noted by the author.
my_model_config = {
    # Vocabulary sizes: index 0 is the padding token, hence the "+ 1".
    'nuser': 19672 + 1,  # NUM_USERS
    'nitem': 47844 + 1,  # NUM_ITEMS
    'd_model': 768,
    'nhead': 4,  # alternatives: 8, 12
    'd_hid': 1024,  # dim_feedforward; alternative: 2048
    'dropout': 0.1,  # alternatives: 0.3, 0.2
    'nlayers': 2,  # alternatives: 1, 3, 6, 12
    # Keep 'bert-base-cased'. With 'jjzha/jobbert-base-cased' the BCE loss
    # was observed not to decrease.
    'checkpoint': 'bert-base-cased',
    'using_cross_attention': False,  # noted as not useful
}

In [3]:
# Project-local model class; importable because the project root was added to sys.path.
from model_ddp_trainer.model import DualEncoderAttentionNetwork

# Build the network from the config dict (each key becomes a keyword argument).
# Construction loads the pretrained BERT checkpoint named in the config, which is
# what triggers the "Some weights ... were not used" notice in the cell output.
model = DualEncoderAttentionNetwork(**my_model_config)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
# Single device object shared by checkpoint loading, the metrics and the batches.
# Prefer the GPU, but fall back to CPU so the notebook still runs on a machine
# without CUDA instead of failing at the first transfer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Location of the saved training checkpoint to evaluate.
checkpoint_path = '/root/tmp/tmp/aws-model-training/model_checkpoint_training_steps#26000_run-20230722-224841.pt'

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location places the loaded tensors directly on `device`.
checkpoint = torch.load(checkpoint_path, map_location=device)

In [9]:
# Copy the checkpoint weights into the model. The returned report is shown as the
# cell output so any key mismatch is visible immediately.
load_result = model.load_state_dict(checkpoint)
load_result

<All keys matched successfully>

In [11]:
# Move the model to the shared `device` object instead of hard-coding .cuda(), so the
# model, the checkpoint tensors, the metrics and the batches all follow one setting.
model = model.to(device)
# Bare expression: display the module structure as the cell output.
model

DualEncoderAttentionNetwork(
  (job_feature_embedding_layer): JobFeatureEmbeddings(
    (metadata_lookup_table): Embedding(47845, 4, padding_idx=0)
    (metadata_embedding_layers): ModuleList(
      (0): Embedding(84, 768, padding_idx=0)
      (1): Embedding(31, 768, padding_idx=0)
      (2): Embedding(306, 768, padding_idx=0)
      (3): Embedding(5, 768, padding_idx=0)
    )
  )
  (job_embedding_layer): JobEmbedding(
    (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (job_id_embedding_layer): Embedding(47845, 768, padding_idx=0)
    (job_feature_embedding_layer): JobFeatureEmbeddings(
      (metadata_lookup_table): Embedding(47845, 4, padding_idx=0)
      (metadata_embedding_layers): ModuleList(
        (0): Embedding(84, 768, padding_idx=0)
        (1): Embedding(31, 768, padding_idx=0)
        (2): Embedding(306, 768, padding_idx=0)
        (3): Embedding(5, 768, padding_idx=0)
      )
    )
  )
  (item_encoder): 

In [12]:
# Sanity check: look at the first parameter to see which device the model lives on.
first_param = next(model.parameters())
first_param.device

device(type='cuda', index=0)

# Test dataset

In [13]:
# Earlier dataset location, kept for reference. The parenthesised part of the directory
# name is Chinese for roughly "binary classification works; unsure whether the text has problems":
#   '/opt/project/data/input_processed_output(二分类可行，文本是否有问题不确定)/final_dataset_dict/'
#
# Dataset actually used (the "P1N1" variant). Despite the `_s3_` in the variable name,
# this is a local filesystem path.
dataset_dict_s3_path = (
    '/root/projects/PythonProjects/ip-dual-encoder-factorization-machine'
    '/data/input_processed_output(P1N1)/final_dataset_dict/'
)

In [14]:
from datasets import DatasetDict

# Load the saved dataset splits from disk; the "test" split is consumed further down.
dataset_dict = DatasetDict.load_from_disk(dataset_dict_s3_path)

In [18]:

from model_ddp_trainer.custom_collate_function import CustomCollateFunc
from torch.utils.data import DataLoader

# Loader settings.
tokenizer_checkpoint = my_model_config['checkpoint']  # same checkpoint name the model was built with
bert_tokenizer_model_max_length = 512
batch_size = 64
num_workers = 0  # load batches in the main process

# Project-local collate function, configured with the tokenizer checkpoint and max length.
custom_collate_function = CustomCollateFunc(tokenizer_checkpoint, bert_tokenizer_model_max_length)

# No DistributedSampler here: the test split is iterated once, in order (shuffle=False).
test_dataloader = DataLoader(
    dataset=dataset_dict["test"],
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=custom_collate_function,
    pin_memory=True,
)

# Evaluation

In [20]:
from torchmetrics.classification import BinaryAUROC, BinaryPrecision, BinaryRecall, BinaryF1Score

# Loss on raw logits, averaged over the elements of each batch.
criterion = nn.BCEWithLogitsLoss(reduction='mean')

# One instance of each binary metric, with its state kept on `device`.
metric_auroc, metric_precision, metric_recall, metric_f1score = (
    metric_cls().to(device)
    for metric_cls in (BinaryAUROC, BinaryPrecision, BinaryRecall, BinaryF1Score)
)

In [22]:
from tqdm.notebook import tqdm
from model_ddp_trainer.ddp_trainer import get_features, get_label

# Evaluate the model on the test loader.
# Produces:
#   test_losses   - list of per-batch mean losses (0-dim tensors), kept for inspection
#   avg_test_loss - 0-dim tensor: mean loss over all scored elements
# and accumulates state in the four torchmetrics objects.

eval_metrics = (metric_auroc, metric_precision, metric_recall, metric_f1score)

# Switch to inference behaviour and clear metric state left over from a previous
# run, so the cell can be re-executed safely.
model.eval()
for eval_metric in eval_metrics:
    eval_metric.reset()

test_losses = []
weighted_loss_sum = 0  # sum over batches of (batch mean loss * elements in batch)
num_scored = 0         # total number of label elements seen

with torch.no_grad():
    for test_batch in tqdm(test_dataloader):
        output = model(**get_features(test_batch, device))
        labels = get_label(test_batch, device)
        test_loss = criterion(output, labels)
        test_losses.append(test_loss)

        # The criterion averages over every element of the batch, so weight each
        # batch mean by its element count. A plain mean of batch means would
        # over-weight a smaller final batch.
        batch_elems = labels.numel()
        weighted_loss_sum = weighted_loss_sum + test_loss * batch_elems
        num_scored += batch_elems

        # Convert logits to probabilities explicitly. torchmetrics only applies a
        # sigmoid when it sees values outside [0, 1] in a given update() call, so a
        # batch whose logits all fall inside that range would otherwise be
        # misread as probabilities.
        probs = torch.sigmoid(output)
        targets = labels.int()
        for eval_metric in eval_metrics:
            eval_metric.update(preds=probs, target=targets)

# Kept as a tensor so downstream `.item()` calls continue to work; NaN (instead of
# a ZeroDivisionError) flags an empty loader.
if num_scored:
    avg_test_loss = weighted_loss_sum / num_scored
else:
    avg_test_loss = torch.tensor(float('nan'))

  0%|          | 0/2930 [00:00<?, ?it/s]

  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)


In [23]:
# Collapse the accumulated loss and metric states into plain Python floats.
metrics = {
    'test Loss': avg_test_loss.item(),
    **{
        metric_name: metric_obj.compute().item()
        for metric_name, metric_obj in (
            ('test AUC', metric_auroc),
            ('test Precision', metric_precision),
            ('test Recall', metric_recall),
            ('test F1Score', metric_f1score),
        )
    },
}

In [24]:
# Bare expression: display the final test metrics as the cell output.
metrics

{'test Loss': 0.2982197403907776,
 'test AUC': 0.9672754406929016,
 'test Precision': 0.9273624420166016,
 'test Recall': 0.8761467933654785,
 'test F1Score': 0.9010273814201355}