In this notebook, we  worked on improving the model and continue to test metrics to see if our model has improved. 

## Model with Repetition Penalty

We trained a new model where we added penalty for the model whenever the model returns a prediction that already exists in the input tokens. 


### New Loss functions

The major change in the new model is the loss functino. Here is the loss function that we defined for the model to penalize duplicates.

```python
class DedupCrossEntropyLoss(nn.Module):
    """
    Custom loss function that combines cross-entropy loss with a penalty term
    for tokens that are repeated from the input sequence.
    """
    def __init__(self, penalty_weight = 1.0):
        super(DedupCrossEntropyLoss, self).__init__()
        self.penalty_weight = penalty_weight
        self.ce_loss = nn.CrossEntropyLoss()
        logger.info(f"Penalty weight: {penalty_weight}")

    def forward(self, logits: torch.Tensor, labels: torch.Tensor, input_tokens: torch.Tensor):
        # Calculate the standard cross-entropy loss
        ce_loss = self.ce_loss(logits, labels)
        if self.penalty_weight == 0:
            return ce_loss
        token_output = torch.argmax(logits, dim=1)
        duplicated_masks = torch.eq(input_tokens, token_output.unsqueeze(-1)).any(dim=-1).float()
        penalty = duplicated_masks * self.penalty_weight
        loss = ce_loss + penalty.mean()
        return loss
```

### Model training diff

We played around with a few different learning rates and penalty factor and finally landed on the numbers below. We show the diff from the original model trained in previous post below:

```diff
-> % git --no-pager diff training_config.yaml 
diff --git a/movielens-ntp/training_config.yaml b/movielens-ntp/training_config.yaml
index 235e7b5..9dff23a 100644
--- a/movielens-ntp/training_config.yaml
+++ b/movielens-ntp/training_config.yaml
@@ -2,12 +2,13 @@ trainer_config:
   data_dir: ./data/ml-1m
   model_dir: ./models
   batch_size: 512
-  starting_learning_rate: 0.0005
+  starting_learning_rate: 0.0008
   learning_rate_decay: 0.95
   device: cuda
   num_epochs: 1000
   validation_fraction: 0.15
   tensorboard_dir: ./runs
+  penalize_duplicates_factor: 0.2
 
 movie_transformer_config:
   context_window_size: 5

```

Finally, we trained the model:

```shell
python model_train.py --config_file=./training_config.yaml --penalize-duplicates
```

In [7]:
import torch
import torch.nn.functional as F
from model_train import run_model_training, load_config, get_model_config
from data import MovieLensSequenceDataset
from torch.utils.data import DataLoader
from tbparse import SummaryReader
import plotly.express as px
from eval import (
    get_model_predictions,
    get_model_predictions,
    calculate_metrics,
    calculate_relevance,
)
import pandas as pd
import numpy as np
import torch.nn as nn
from movielens_transformer import MovieLensTransformer

In [8]:
def load_model_artifacts(model_file: str, config_file: str):
    config_file = "./training_config.yaml"
    config = load_config(config_file)
    sequence_length = config["movie_transformer_config"]["context_window_size"]
    batch_size = config["trainer_config"]["batch_size"]
    valid_dataset = MovieLensSequenceDataset(
        movies_file="./data/ml-1m/movies.dat",
        users_file="./data/ml-1m/users.dat",
        ratings_file="./data/ml-1m/ratings.dat",
        sequence_length=sequence_length,
        window_size=1,
        is_validation=True,
    )
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
    with open(model_file, "rb") as f:
        model_state_dict = torch.load(f, weights_only=True)
        model_config = get_model_config(config, valid_dataset)
    trained_model = MovieLensTransformer(model_config)
    trained_model.load_state_dict(model_state_dict)
    return config, trained_model, valid_dataloader


def predict_next_movie(model, movie_ids, user_ids):
    model.eval()
    logits = model(movie_ids=movie_ids, user_ids=user_ids)
    probabilities = F.softmax(logits, dim=-1)
    predicted_movie_ids = torch.argmax(probabilities, dim=-1)
    return predicted_movie_ids

In [9]:
def read_tensorboard_logs(log_file: str):
    reader = SummaryReader(log_file)
    metrics = reader.scalars
    return metrics

In [10]:
models = {
    "baseline": (
        "./models/model_1000e_512_32_32_4_4.pth",
        "./models/events.out.tfevents.1724703053.big.514232.0",
    ),
    "with_duplication_penalty": (
        "./models/model_1000e_512_32_32_4_4_w_pen.pth",
        "./models/events.out.tfevents.1725245582.big.1079645.0",
    ),
}

In [11]:
def populate_ranking_metrics(
    model_name, model, valid_dataloader, sequence_length, k_values
):
    metrics = {}

    for k in k_values:
        # we would torch.cat these later
        model_relevances = []
        model_scores = []

        for batch in valid_dataloader:
            (
                movie_id_tokens,
                rating_id_tokens,
                user_id_tokens,
                output_movie_id_tokens,
                output_rating_id_tokens,
            ) = batch
            model_predictions = get_model_predictions(
                model, movie_id_tokens, user_id_tokens, n=k
            )
            model_relevance = calculate_relevance(
                model_predictions.predictions, output_movie_id_tokens
            )
            model_relevances.append(model_relevance)
            model_scores.append(model_predictions.scores)

        model_relevances_tensor = torch.cat(model_relevances)
        model_scores_tensor = torch.cat(model_scores)

        # Calculate the metrics
        model_metrics = calculate_metrics(model_relevances_tensor, model_scores_tensor)

        metrics[k] = model_metrics

    metrics_df = pd.DataFrame(
        [
            {
                "k": k,
                f"MRR_{model_name}": v.MRR,
                f"MAP_{model_name}": v.MAP,
                f"NDCG_{model_name}": v.NDCG,
            }
            for k, v in metrics.items()
        ]
    )
    return metrics_df

In [12]:
k_values = [3, 5, 10]

for model_name, (model_file, tb_file) in models.items():
    print(f"======= Model: {model_name} =======")
    config, model, valid_dataloader = load_model_artifacts(
        model_file, "./training_config.yaml"
    )
    metrics = read_tensorboard_logs(tb_file)
    metrics_by_run = metrics.groupby(["step", "tag"]).mean().unstack()["value"]
    metrics_by_run.columns.name = None
    fig = px.line(metrics_by_run, title=f"Model: {model_name}")
    fig.show()

    print("======= Ranking Metrics =======")
    ranking_metrics = populate_ranking_metrics(
        model_name,
        model,
        valid_dataloader,
        config["movie_transformer_config"]["context_window_size"],
        k_values,
    )
    print(ranking_metrics.to_markdown())

2025-07-10 12:59:33.697 | INFO     | data:__init__:89 - Creating MovieLensSequenceDataset with validation set: %s
2025-07-10 12:59:33.698 | INFO     | data:read_movielens_data:12 - Reading data from files




2025-07-10 12:59:35.967 | INFO     | data:_add_tokens:140 - Adding tokens to data
2025-07-10 12:59:36.017 | INFO     | data:_generate_sequences:159 - Generating sequences
2025-07-10 12:59:36.968 | INFO     | data:__init__:110 - Train data length: 884191
2025-07-10 12:59:36.969 | INFO     | data:__init__:111 - Validation data length: 97898
2025-07-10 12:59:37.038 | INFO     | model_train:get_model_config:103 - Model config:
MovieLensTransformerConfig(movie_transformer_config=TransformerConfig(vocab_size=3885, context_window_size=5, embedding_dimension=32, num_layers=4, num_heads=4, dropout_embeddings=0.1, dropout_attention=0.1, dropout_residual=0.1, layer_norm_epsilon=1e-05), user_embedding_dimension=32, num_users=6040, interaction_mlp_hidden_sizes=[16]) 




2025-07-10 13:00:00.016 | INFO     | data:__init__:89 - Creating MovieLensSequenceDataset with validation set: %s
2025-07-10 13:00:00.017 | INFO     | data:read_movielens_data:12 - Reading data from files


|    |   k |   MRR_baseline |   MAP_baseline |   NDCG_baseline |
|---:|----:|---------------:|---------------:|----------------:|
|  0 |   3 |      0.075405  |       0.115396 |       0.0856271 |
|  1 |   5 |      0.0874063 |       0.168155 |       0.107293  |
|  2 |  10 |      0.101339  |       0.268381 |       0.140093  |


2025-07-10 13:00:02.160 | INFO     | data:_add_tokens:140 - Adding tokens to data
2025-07-10 13:00:02.208 | INFO     | data:_generate_sequences:159 - Generating sequences
2025-07-10 13:00:03.108 | INFO     | data:__init__:110 - Train data length: 883611
2025-07-10 13:00:03.108 | INFO     | data:__init__:111 - Validation data length: 98478
2025-07-10 13:00:03.135 | INFO     | model_train:get_model_config:103 - Model config:
MovieLensTransformerConfig(movie_transformer_config=TransformerConfig(vocab_size=3885, context_window_size=5, embedding_dimension=32, num_layers=4, num_heads=4, dropout_embeddings=0.1, dropout_attention=0.1, dropout_residual=0.1, layer_norm_epsilon=1e-05), user_embedding_dimension=32, num_users=6040, interaction_mlp_hidden_sizes=[16]) 


|    |   k |   MRR_with_duplication_penalty |   MAP_with_duplication_penalty |   NDCG_with_duplication_penalty |
|---:|----:|-------------------------------:|-------------------------------:|--------------------------------:|
|  0 |   3 |                      0.0726017 |                       0.110938 |                       0.0823983 |
|  1 |   5 |                      0.084306  |                       0.162635 |                       0.103595  |
|  2 |  10 |                      0.0971976 |                       0.260454 |                       0.135058  |


check if the dedup logic is actually working. 
check how many predicted values are duplicated in baseline vs duplication_penalty model to root-cause.

In [13]:
def get_duplicated_movies(k: int, model: nn.Module, valid_dataloader: DataLoader):
    duplicated_movies = []
    for i, batch in enumerate(valid_dataloader):
        movie_id_tokens, rating_ids, user_id_tokens, movie_targets, rating_targets = (
            batch
        )
        with torch.no_grad():
            # batch x num_tokens
            output = model(movie_id_tokens, user_id_tokens)
        output_probabilites = F.softmax(output, dim=-1)
        _, top_tokens = output_probabilites.topk(k, dim=-1)

        for i in range(movie_id_tokens.shape[0]):
            input_tokens = movie_id_tokens[i]
            output_tokens = top_tokens[i]
            concat_tensor, counts = torch.cat([input_tokens, output_tokens]).unique(
                return_counts=True
            )
            intersection = concat_tensor[torch.where(counts.gt(1))]
            if intersection.shape[0] > 0:
                duplicated_movies.append(intersection)

    return torch.cat(duplicated_movies)

In [14]:
k_values = [3, 5, 10]

for model_name, (model_file, tb_file) in models.items():
    config, model, valid_dataloader = load_model_artifacts(
        model_file, "./training_config.yaml"
    )

    for k in k_values:
        duplicated_movies = get_duplicated_movies(k, model, valid_dataloader)
        average_duplications = duplicated_movies.shape[0] / (
            k * len(valid_dataloader.dataset)
        )
        print(f"[{model_name}] Average Duplications @ {k}: ", average_duplications)

2025-07-10 13:00:26.181 | INFO     | data:__init__:89 - Creating MovieLensSequenceDataset with validation set: %s
2025-07-10 13:00:26.185 | INFO     | data:read_movielens_data:12 - Reading data from files


2025-07-10 13:00:28.296 | INFO     | data:_add_tokens:140 - Adding tokens to data
2025-07-10 13:00:28.343 | INFO     | data:_generate_sequences:159 - Generating sequences
2025-07-10 13:00:29.323 | INFO     | data:__init__:110 - Train data length: 884412
2025-07-10 13:00:29.323 | INFO     | data:__init__:111 - Validation data length: 97677
2025-07-10 13:00:29.338 | INFO     | model_train:get_model_config:103 - Model config:
MovieLensTransformerConfig(movie_transformer_config=TransformerConfig(vocab_size=3885, context_window_size=5, embedding_dimension=32, num_layers=4, num_heads=4, dropout_embeddings=0.1, dropout_attention=0.1, dropout_residual=0.1, layer_norm_epsilon=1e-05), user_embedding_dimension=32, num_users=6040, interaction_mlp_hidden_sizes=[16]) 


[baseline] Average Duplications @ 3:  0.14770109647102184
[baseline] Average Duplications @ 5:  0.1377601687193505


2025-07-10 13:00:44.730 | INFO     | data:__init__:89 - Creating MovieLensSequenceDataset with validation set: %s
2025-07-10 13:00:44.730 | INFO     | data:read_movielens_data:12 - Reading data from files


[baseline] Average Duplications @ 10:  0.11828373107282164


2025-07-10 13:00:47.098 | INFO     | data:_add_tokens:140 - Adding tokens to data
2025-07-10 13:00:47.149 | INFO     | data:_generate_sequences:159 - Generating sequences
2025-07-10 13:00:48.130 | INFO     | data:__init__:110 - Train data length: 883635
2025-07-10 13:00:48.130 | INFO     | data:__init__:111 - Validation data length: 98454
2025-07-10 13:00:48.147 | INFO     | model_train:get_model_config:103 - Model config:
MovieLensTransformerConfig(movie_transformer_config=TransformerConfig(vocab_size=3885, context_window_size=5, embedding_dimension=32, num_layers=4, num_heads=4, dropout_embeddings=0.1, dropout_attention=0.1, dropout_residual=0.1, layer_norm_epsilon=1e-05), user_embedding_dimension=32, num_users=6040, interaction_mlp_hidden_sizes=[16]) 


[with_duplication_penalty] Average Duplications @ 3:  0.14398263825407467
[with_duplication_penalty] Average Duplications @ 5:  0.13379649379405611
[with_duplication_penalty] Average Duplications @ 10:  0.1154559489710931


In [15]:
model_file = "./models/model_500e_4098_32_32_4_4_compiled_cos_ann_lr.pth"
tb_file = "./models/events.out.tfevents.1725310210.big.1156831.0"
model_name = "compiled_cos_ann_lr_with_duplication_penalty"
config, model, valid_dataloader = load_model_artifacts(
    model_file, "./training_config.yaml"
)
metrics = read_tensorboard_logs(tb_file)
metrics_by_run = metrics.groupby(["step", "tag"]).mean().unstack()["value"]
metrics_by_run.columns.name = None
fig = px.line(metrics_by_run, title=f"Model: {model_name}")
fig.show()

print("======= Ranking Metrics =======")
ranking_metrics = populate_ranking_metrics(
    model_name,
    model,
    valid_dataloader,
    config["movie_transformer_config"]["context_window_size"],
    k_values,
)
print(ranking_metrics.to_markdown())

2025-07-10 13:01:06.210 | INFO     | data:__init__:89 - Creating MovieLensSequenceDataset with validation set: %s
2025-07-10 13:01:06.211 | INFO     | data:read_movielens_data:12 - Reading data from files
2025-07-10 13:01:08.409 | INFO     | data:_add_tokens:140 - Adding tokens to data
2025-07-10 13:01:08.457 | INFO     | data:_generate_sequences:159 - Generating sequences
2025-07-10 13:01:09.577 | INFO     | data:__init__:110 - Train data length: 883769
2025-07-10 13:01:09.578 | INFO     | data:__init__:111 - Validation data length: 98320
2025-07-10 13:01:09.600 | INFO     | model_train:get_model_config:103 - Model config:
MovieLensTransformerConfig(movie_transformer_config=TransformerConfig(vocab_size=3885, context_window_size=5, embedding_dimension=32, num_layers=4, num_heads=4, dropout_embeddings=0.1, dropout_attention=0.1, dropout_residual=0.1, layer_norm_epsilon=1e-05), user_embedding_dimension=32, num_users=6040, interaction_mlp_hidden_sizes=[16]) 


|    |   k |   MRR_compiled_cos_ann_lr_with_duplication_penalty |   MAP_compiled_cos_ann_lr_with_duplication_penalty |   NDCG_compiled_cos_ann_lr_with_duplication_penalty |
|---:|----:|---------------------------------------------------:|---------------------------------------------------:|----------------------------------------------------:|
|  0 |   3 |                                          0.0737524 |                                           0.112764 |                                           0.0837217 |
|  1 |   5 |                                          0.0863046 |                                           0.166009 |                                           0.105919  |
|  2 |  10 |                                          0.0992404 |                                           0.264138 |                                           0.137483  |


It seems there isn't any improvement in the model after these changes. Although training run finished in half the time 

In [16]:
for k in k_values:
    duplicated_movies = get_duplicated_movies(k, model, valid_dataloader)
    average_duplications = duplicated_movies.shape[0] / (
        k * len(valid_dataloader.dataset)
    )
    print(f"[{model_name}] Average Duplications @ {k}: ", average_duplications)

[compiled_cos_ann_lr_with_duplication_penalty] Average Duplications @ 3:  0.15278003797125034
[compiled_cos_ann_lr_with_duplication_penalty] Average Duplications @ 5:  0.1414279902359642
[compiled_cos_ann_lr_with_duplication_penalty] Average Duplications @ 10:  0.12085638730675345


In [17]:
model_file = (
    "./models/model_500e_4098_32_32_4_4_compiled_cos_ann_lr_from_checkpoint.pth"
)
tb_file = "./models/events.out.tfevents.1725330223.big.1178851.0"
model_name = "compiled_cos_ann_lr_with_duplication_penalty_from_checkpoint"
metrics = read_tensorboard_logs(tb_file)
metrics_by_run = metrics.groupby(["step", "tag"]).mean().unstack()["value"]
metrics_by_run.columns.name = None
fig = px.line(metrics_by_run, title=f"Model: {model_name}")
fig.show()

We see that the model loss improved further slightly but not a lot. we wanted to test the duplication metric as well, however, it seems currently it is not possible to save/load compiled_model without some effort. See discussion [forum](https://discuss.pytorch.org/t/how-to-save-load-a-model-with-torch-compile/179739) and [on github issue](https://github.com/pytorch/pytorch/issues/101107#issuecomment-1542688089).

In this Step we got some errors which we couldn't resolve before deadline 

In [20]:
config, model, valid_dataloader = load_model_artifacts(
    model_file, "./training_config.yaml"
)
metrics = read_tensorboard_logs(tb_file)
metrics_by_run = metrics.groupby(["step", "tag"]).mean().unstack()["value"]
metrics_by_run.columns.name = None
fig = px.line(metrics_by_run, title=f"Model: {model_name}")
fig.show()

print("======= Ranking Metrics =======")
ranking_metrics = populate_ranking_metrics(
    model_name,
    model,
    valid_dataloader,
    config["movie_transformer_config"]["context_window_size"],
    k_values,
)
print(ranking_metrics.to_markdown())

2025-07-10 13:03:13.330 | INFO     | data:__init__:89 - Creating MovieLensSequenceDataset with validation set: %s
2025-07-10 13:03:13.331 | INFO     | data:read_movielens_data:12 - Reading data from files
2025-07-10 13:03:15.479 | INFO     | data:_add_tokens:140 - Adding tokens to data
2025-07-10 13:03:15.527 | INFO     | data:_generate_sequences:159 - Generating sequences
2025-07-10 13:03:16.432 | INFO     | data:__init__:110 - Train data length: 884075
2025-07-10 13:03:16.433 | INFO     | data:__init__:111 - Validation data length: 98014
2025-07-10 13:03:16.450 | INFO     | model_train:get_model_config:103 - Model config:
MovieLensTransformerConfig(movie_transformer_config=TransformerConfig(vocab_size=3885, context_window_size=5, embedding_dimension=32, num_layers=4, num_heads=4, dropout_embeddings=0.1, dropout_attention=0.1, dropout_residual=0.1, layer_norm_epsilon=1e-05), user_embedding_dimension=32, num_users=6040, interaction_mlp_hidden_sizes=[16]) 


RuntimeError: Error(s) in loading state_dict for MovieLensTransformer:
	Missing key(s) in state_dict: "movie_transformer.positional_ids", "movie_transformer.token_embedding.weight", "movie_transformer.positional_embedding.weight", "movie_transformer.layers.0.layer_norm1.weight", "movie_transformer.layers.0.layer_norm1.bias", "movie_transformer.layers.0.attention.qkv.weight", "movie_transformer.layers.0.attention.qkv.bias", "movie_transformer.layers.0.attention.out.weight", "movie_transformer.layers.0.attention.out.bias", "movie_transformer.layers.0.layer_norm2.weight", "movie_transformer.layers.0.layer_norm2.bias", "movie_transformer.layers.0.mlp.fc1.weight", "movie_transformer.layers.0.mlp.fc1.bias", "movie_transformer.layers.0.mlp.fc2.weight", "movie_transformer.layers.0.mlp.fc2.bias", "movie_transformer.layers.1.layer_norm1.weight", "movie_transformer.layers.1.layer_norm1.bias", "movie_transformer.layers.1.attention.qkv.weight", "movie_transformer.layers.1.attention.qkv.bias", "movie_transformer.layers.1.attention.out.weight", "movie_transformer.layers.1.attention.out.bias", "movie_transformer.layers.1.layer_norm2.weight", "movie_transformer.layers.1.layer_norm2.bias", "movie_transformer.layers.1.mlp.fc1.weight", "movie_transformer.layers.1.mlp.fc1.bias", "movie_transformer.layers.1.mlp.fc2.weight", "movie_transformer.layers.1.mlp.fc2.bias", "movie_transformer.layers.2.layer_norm1.weight", "movie_transformer.layers.2.layer_norm1.bias", "movie_transformer.layers.2.attention.qkv.weight", "movie_transformer.layers.2.attention.qkv.bias", "movie_transformer.layers.2.attention.out.weight", "movie_transformer.layers.2.attention.out.bias", "movie_transformer.layers.2.layer_norm2.weight", "movie_transformer.layers.2.layer_norm2.bias", "movie_transformer.layers.2.mlp.fc1.weight", "movie_transformer.layers.2.mlp.fc1.bias", "movie_transformer.layers.2.mlp.fc2.weight", "movie_transformer.layers.2.mlp.fc2.bias", "movie_transformer.layers.3.layer_norm1.weight", "movie_transformer.layers.3.layer_norm1.bias", "movie_transformer.layers.3.attention.qkv.weight", "movie_transformer.layers.3.attention.qkv.bias", "movie_transformer.layers.3.attention.out.weight", "movie_transformer.layers.3.attention.out.bias", "movie_transformer.layers.3.layer_norm2.weight", "movie_transformer.layers.3.layer_norm2.bias", "movie_transformer.layers.3.mlp.fc1.weight", "movie_transformer.layers.3.mlp.fc1.bias", "movie_transformer.layers.3.mlp.fc2.weight", "movie_transformer.layers.3.mlp.fc2.bias", "movie_transformer.layer_norm.weight", "movie_transformer.layer_norm.bias", "user_embedding.weight", "output_layer.fc_layers.0.weight", "output_layer.fc_layers.0.bias", "output_layer.fc_layers.2.weight". 
	Unexpected key(s) in state_dict: "_orig_mod.movie_transformer.positional_ids", "_orig_mod.movie_transformer.token_embedding.weight", "_orig_mod.movie_transformer.positional_embedding.weight", "_orig_mod.movie_transformer.layers.0.layer_norm1.weight", "_orig_mod.movie_transformer.layers.0.layer_norm1.bias", "_orig_mod.movie_transformer.layers.0.attention.qkv.weight", "_orig_mod.movie_transformer.layers.0.attention.qkv.bias", "_orig_mod.movie_transformer.layers.0.attention.out.weight", "_orig_mod.movie_transformer.layers.0.attention.out.bias", "_orig_mod.movie_transformer.layers.0.layer_norm2.weight", "_orig_mod.movie_transformer.layers.0.layer_norm2.bias", "_orig_mod.movie_transformer.layers.0.mlp.fc1.weight", "_orig_mod.movie_transformer.layers.0.mlp.fc1.bias", "_orig_mod.movie_transformer.layers.0.mlp.fc2.weight", "_orig_mod.movie_transformer.layers.0.mlp.fc2.bias", "_orig_mod.movie_transformer.layers.1.layer_norm1.weight", "_orig_mod.movie_transformer.layers.1.layer_norm1.bias", "_orig_mod.movie_transformer.layers.1.attention.qkv.weight", "_orig_mod.movie_transformer.layers.1.attention.qkv.bias", "_orig_mod.movie_transformer.layers.1.attention.out.weight", "_orig_mod.movie_transformer.layers.1.attention.out.bias", "_orig_mod.movie_transformer.layers.1.layer_norm2.weight", "_orig_mod.movie_transformer.layers.1.layer_norm2.bias", "_orig_mod.movie_transformer.layers.1.mlp.fc1.weight", "_orig_mod.movie_transformer.layers.1.mlp.fc1.bias", "_orig_mod.movie_transformer.layers.1.mlp.fc2.weight", "_orig_mod.movie_transformer.layers.1.mlp.fc2.bias", "_orig_mod.movie_transformer.layers.2.layer_norm1.weight", "_orig_mod.movie_transformer.layers.2.layer_norm1.bias", "_orig_mod.movie_transformer.layers.2.attention.qkv.weight", "_orig_mod.movie_transformer.layers.2.attention.qkv.bias", "_orig_mod.movie_transformer.layers.2.attention.out.weight", "_orig_mod.movie_transformer.layers.2.attention.out.bias", "_orig_mod.movie_transformer.layers.2.layer_norm2.weight", "_orig_mod.movie_transformer.layers.2.layer_norm2.bias", "_orig_mod.movie_transformer.layers.2.mlp.fc1.weight", "_orig_mod.movie_transformer.layers.2.mlp.fc1.bias", "_orig_mod.movie_transformer.layers.2.mlp.fc2.weight", "_orig_mod.movie_transformer.layers.2.mlp.fc2.bias", "_orig_mod.movie_transformer.layers.3.layer_norm1.weight", "_orig_mod.movie_transformer.layers.3.layer_norm1.bias", "_orig_mod.movie_transformer.layers.3.attention.qkv.weight", "_orig_mod.movie_transformer.layers.3.attention.qkv.bias", "_orig_mod.movie_transformer.layers.3.attention.out.weight", "_orig_mod.movie_transformer.layers.3.attention.out.bias", "_orig_mod.movie_transformer.layers.3.layer_norm2.weight", "_orig_mod.movie_transformer.layers.3.layer_norm2.bias", "_orig_mod.movie_transformer.layers.3.mlp.fc1.weight", "_orig_mod.movie_transformer.layers.3.mlp.fc1.bias", "_orig_mod.movie_transformer.layers.3.mlp.fc2.weight", "_orig_mod.movie_transformer.layers.3.mlp.fc2.bias", "_orig_mod.movie_transformer.layer_norm.weight", "_orig_mod.movie_transformer.layer_norm.bias", "_orig_mod.user_embedding.weight", "_orig_mod.output_layer.fc_layers.0.weight", "_orig_mod.output_layer.fc_layers.0.bias", "_orig_mod.output_layer.fc_layers.2.weight". 

In [None]:
model_file = "./models/model_500e_4098_32_32_4_4_compiled_cos_ann_lr_from_checkpoint_regularized.pth"
tb_file = "./models/events.out.tfevents.1733057344.big.3280480.0"
model_name = "compiled_cos_ann_lr_with_duplication_penalty_regularized"
metrics = read_tensorboard_logs(tb_file)
metrics_by_run = metrics.groupby(["step", "tag"]).mean().unstack()["value"]
metrics_by_run.columns.name = None
fig = px.line(metrics_by_run, title=f"Model: {model_name}")
fig.show()

In [19]:
config, model, valid_dataloader = load_model_artifacts(
    model_file, "./training_config.yaml"
)
metrics = read_tensorboard_logs(tb_file)
metrics_by_run = metrics.groupby(["step", "tag"]).mean().unstack()["value"]
metrics_by_run.columns.name = None
fig = px.line(metrics_by_run, title=f"Model: {model_name}")
fig.show()

print("======= Ranking Metrics =======")
ranking_metrics = populate_ranking_metrics(
    model_name,
    model,
    valid_dataloader,
    config["movie_transformer_config"]["context_window_size"],
    k_values,
)
print(ranking_metrics.to_markdown())

2025-07-10 13:02:50.368 | INFO     | data:__init__:89 - Creating MovieLensSequenceDataset with validation set: %s
2025-07-10 13:02:50.369 | INFO     | data:read_movielens_data:12 - Reading data from files
2025-07-10 13:02:52.386 | INFO     | data:_add_tokens:140 - Adding tokens to data
2025-07-10 13:02:52.432 | INFO     | data:_generate_sequences:159 - Generating sequences
2025-07-10 13:02:53.426 | INFO     | data:__init__:110 - Train data length: 883931
2025-07-10 13:02:53.426 | INFO     | data:__init__:111 - Validation data length: 98158
2025-07-10 13:02:53.441 | INFO     | model_train:get_model_config:103 - Model config:
MovieLensTransformerConfig(movie_transformer_config=TransformerConfig(vocab_size=3885, context_window_size=5, embedding_dimension=32, num_layers=4, num_heads=4, dropout_embeddings=0.1, dropout_attention=0.1, dropout_residual=0.1, layer_norm_epsilon=1e-05), user_embedding_dimension=32, num_users=6040, interaction_mlp_hidden_sizes=[16]) 


RuntimeError: Error(s) in loading state_dict for MovieLensTransformer:
	Missing key(s) in state_dict: "movie_transformer.positional_ids", "movie_transformer.token_embedding.weight", "movie_transformer.positional_embedding.weight", "movie_transformer.layers.0.layer_norm1.weight", "movie_transformer.layers.0.layer_norm1.bias", "movie_transformer.layers.0.attention.qkv.weight", "movie_transformer.layers.0.attention.qkv.bias", "movie_transformer.layers.0.attention.out.weight", "movie_transformer.layers.0.attention.out.bias", "movie_transformer.layers.0.layer_norm2.weight", "movie_transformer.layers.0.layer_norm2.bias", "movie_transformer.layers.0.mlp.fc1.weight", "movie_transformer.layers.0.mlp.fc1.bias", "movie_transformer.layers.0.mlp.fc2.weight", "movie_transformer.layers.0.mlp.fc2.bias", "movie_transformer.layers.1.layer_norm1.weight", "movie_transformer.layers.1.layer_norm1.bias", "movie_transformer.layers.1.attention.qkv.weight", "movie_transformer.layers.1.attention.qkv.bias", "movie_transformer.layers.1.attention.out.weight", "movie_transformer.layers.1.attention.out.bias", "movie_transformer.layers.1.layer_norm2.weight", "movie_transformer.layers.1.layer_norm2.bias", "movie_transformer.layers.1.mlp.fc1.weight", "movie_transformer.layers.1.mlp.fc1.bias", "movie_transformer.layers.1.mlp.fc2.weight", "movie_transformer.layers.1.mlp.fc2.bias", "movie_transformer.layers.2.layer_norm1.weight", "movie_transformer.layers.2.layer_norm1.bias", "movie_transformer.layers.2.attention.qkv.weight", "movie_transformer.layers.2.attention.qkv.bias", "movie_transformer.layers.2.attention.out.weight", "movie_transformer.layers.2.attention.out.bias", "movie_transformer.layers.2.layer_norm2.weight", "movie_transformer.layers.2.layer_norm2.bias", "movie_transformer.layers.2.mlp.fc1.weight", "movie_transformer.layers.2.mlp.fc1.bias", "movie_transformer.layers.2.mlp.fc2.weight", "movie_transformer.layers.2.mlp.fc2.bias", "movie_transformer.layers.3.layer_norm1.weight", "movie_transformer.layers.3.layer_norm1.bias", "movie_transformer.layers.3.attention.qkv.weight", "movie_transformer.layers.3.attention.qkv.bias", "movie_transformer.layers.3.attention.out.weight", "movie_transformer.layers.3.attention.out.bias", "movie_transformer.layers.3.layer_norm2.weight", "movie_transformer.layers.3.layer_norm2.bias", "movie_transformer.layers.3.mlp.fc1.weight", "movie_transformer.layers.3.mlp.fc1.bias", "movie_transformer.layers.3.mlp.fc2.weight", "movie_transformer.layers.3.mlp.fc2.bias", "movie_transformer.layer_norm.weight", "movie_transformer.layer_norm.bias", "user_embedding.weight", "output_layer.fc_layers.0.weight", "output_layer.fc_layers.0.bias", "output_layer.fc_layers.2.weight". 
	Unexpected key(s) in state_dict: "_orig_mod.movie_transformer.positional_ids", "_orig_mod.movie_transformer.token_embedding.weight", "_orig_mod.movie_transformer.positional_embedding.weight", "_orig_mod.movie_transformer.layers.0.layer_norm1.weight", "_orig_mod.movie_transformer.layers.0.layer_norm1.bias", "_orig_mod.movie_transformer.layers.0.attention.qkv.weight", "_orig_mod.movie_transformer.layers.0.attention.qkv.bias", "_orig_mod.movie_transformer.layers.0.attention.out.weight", "_orig_mod.movie_transformer.layers.0.attention.out.bias", "_orig_mod.movie_transformer.layers.0.layer_norm2.weight", "_orig_mod.movie_transformer.layers.0.layer_norm2.bias", "_orig_mod.movie_transformer.layers.0.mlp.fc1.weight", "_orig_mod.movie_transformer.layers.0.mlp.fc1.bias", "_orig_mod.movie_transformer.layers.0.mlp.fc2.weight", "_orig_mod.movie_transformer.layers.0.mlp.fc2.bias", "_orig_mod.movie_transformer.layers.1.layer_norm1.weight", "_orig_mod.movie_transformer.layers.1.layer_norm1.bias", "_orig_mod.movie_transformer.layers.1.attention.qkv.weight", "_orig_mod.movie_transformer.layers.1.attention.qkv.bias", "_orig_mod.movie_transformer.layers.1.attention.out.weight", "_orig_mod.movie_transformer.layers.1.attention.out.bias", "_orig_mod.movie_transformer.layers.1.layer_norm2.weight", "_orig_mod.movie_transformer.layers.1.layer_norm2.bias", "_orig_mod.movie_transformer.layers.1.mlp.fc1.weight", "_orig_mod.movie_transformer.layers.1.mlp.fc1.bias", "_orig_mod.movie_transformer.layers.1.mlp.fc2.weight", "_orig_mod.movie_transformer.layers.1.mlp.fc2.bias", "_orig_mod.movie_transformer.layers.2.layer_norm1.weight", "_orig_mod.movie_transformer.layers.2.layer_norm1.bias", "_orig_mod.movie_transformer.layers.2.attention.qkv.weight", "_orig_mod.movie_transformer.layers.2.attention.qkv.bias", "_orig_mod.movie_transformer.layers.2.attention.out.weight", "_orig_mod.movie_transformer.layers.2.attention.out.bias", "_orig_mod.movie_transformer.layers.2.layer_norm2.weight", "_orig_mod.movie_transformer.layers.2.layer_norm2.bias", "_orig_mod.movie_transformer.layers.2.mlp.fc1.weight", "_orig_mod.movie_transformer.layers.2.mlp.fc1.bias", "_orig_mod.movie_transformer.layers.2.mlp.fc2.weight", "_orig_mod.movie_transformer.layers.2.mlp.fc2.bias", "_orig_mod.movie_transformer.layers.3.layer_norm1.weight", "_orig_mod.movie_transformer.layers.3.layer_norm1.bias", "_orig_mod.movie_transformer.layers.3.attention.qkv.weight", "_orig_mod.movie_transformer.layers.3.attention.qkv.bias", "_orig_mod.movie_transformer.layers.3.attention.out.weight", "_orig_mod.movie_transformer.layers.3.attention.out.bias", "_orig_mod.movie_transformer.layers.3.layer_norm2.weight", "_orig_mod.movie_transformer.layers.3.layer_norm2.bias", "_orig_mod.movie_transformer.layers.3.mlp.fc1.weight", "_orig_mod.movie_transformer.layers.3.mlp.fc1.bias", "_orig_mod.movie_transformer.layers.3.mlp.fc2.weight", "_orig_mod.movie_transformer.layers.3.mlp.fc2.bias", "_orig_mod.movie_transformer.layer_norm.weight", "_orig_mod.movie_transformer.layer_norm.bias", "_orig_mod.user_embedding.weight", "_orig_mod.output_layer.fc_layers.0.weight", "_orig_mod.output_layer.fc_layers.0.bias", "_orig_mod.output_layer.fc_layers.2.weight". 