In [5]:
# Install notebook dependencies
%pip install transformers huggingface_hub torch omegaconf
# Optional packages for better user experience
%pip install tqdm ipywidgets nbconvert

213.14s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


Note: you may need to restart the kernel to use updated packages.


219.58s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


Note: you may need to restart the kernel to use updated packages.


In [6]:
import torch
import omegaconf
import collections
import typing
from huggingface_hub import hf_hub_download

In [7]:
# Classes that must be allow-listed before torch.load(..., weights_only=True)
# will agree to deserialize them from the checkpoint.
checkpoint_safe_globals = [
    omegaconf.dictconfig.ContainerMetadata,
    omegaconf.dictconfig.DictConfig,
    omegaconf.base.Metadata,
    omegaconf.nodes.AnyNode,
    omegaconf.listconfig.ListConfig,
    collections.defaultdict,
    typing.Any,
    dict,
    list,
    int,
]
torch.serialization.add_safe_globals(checkpoint_safe_globals)

# Fetch the checkpoint file from the Hugging Face Hub; the call returns the
# local path of the downloaded file.
checkpoint_path = hf_hub_download(
    repo_id="NTU-NLP-sg/xCodeEval-nl-code-starencoder-ckpt-37",
    filename="dpr_biencoder.37.pt",
    repo_type="model",
)

# Deserialize the checkpoint onto the CPU, restricted to the allow-listed
# types registered above (weights_only=True).
cpu_device = torch.device("cpu")
model_weights = torch.load(
    checkpoint_path,
    map_location=cpu_device,
    weights_only=True,
)

In [9]:
# Print a human-readable summary of the loaded checkpoint (`model_weights`):
# its top-level keys, training progress counters, encoder configuration,
# model parameter names, optimizer state and scheduler state.
# Every lookup is guarded, so a checkpoint that lacks one of the expected
# keys prints a notice instead of raising KeyError.
import pprint

print("=== Model Checkpoint Structure ===")
pprint.pprint(list(model_weights.keys()), indent=2)
print()

print("=== Parameters ===")
# Single quotes inside the f-string: reusing the outer double quote inside a
# replacement field is only valid on Python 3.12+ (PEP 701) and is a
# SyntaxError on older kernels. `.get` mirrors the guarded lookups below.
print(f"Epoch: {model_weights.get('epoch', 'Not found')}")
print(f"Offset: {model_weights.get('offset', 'Not found')}")
print()

print("=== Encoder Parameters ===")
if "encoder_params" in model_weights:
    encoder_params = model_weights["encoder_params"]
    print(f"do_lower_case: {encoder_params.get('do_lower_case', 'Not found')}")

    if "encoder" in encoder_params:
        print("\nEncoder configuration:")
        # Convert to a plain dict so pprint renders it as key/value pairs.
        pprint.pprint(dict(encoder_params["encoder"]), indent=2, width=80)
    else:
        print("No 'encoder' key found in encoder_params")
        print("Available encoder_params keys:")
        pprint.pprint(list(encoder_params.keys()), indent=2)
else:
    print("No 'encoder_params' found in model weights")

print("\n=== Model Dictionary Keys ===")
if "model_dict" in model_weights:
    model_dict_keys = list(model_weights["model_dict"].keys())
    print(f"Number of model parameters: {len(model_dict_keys)}")
    # Only a short preview is printed to keep the output readable.
    print("First 10 parameter names:")
    for key in model_dict_keys[:10]:
        print(f"  - {key}")
    if len(model_dict_keys) > 10:
        print(f"  ... and {len(model_dict_keys) - 10} more")
else:
    print("No 'model_dict' found in model weights")
print()

print("=== Optimizer dictionary ===")
if "optimizer_dict" in model_weights:
    optimizer_dict = model_weights["optimizer_dict"]
    print(f"Optimizer keys: {list(optimizer_dict.keys())}")

    if "state" in optimizer_dict:
        state_dict = optimizer_dict["state"]
        # Assuming the standard torch.optim state_dict layout, "state" has one
        # entry per parameter; the parameter groups live under "param_groups"
        # and are reported separately below.
        print(f"Number of parameters with optimizer state: {len(state_dict)}")

        # Show the state keys of the first entry as a representative example
        if state_dict:
            first_key = next(iter(state_dict.keys()))
            first_state = state_dict[first_key]
            print(f"Example parameter state keys: {list(first_state.keys())}")

    if "param_groups" in optimizer_dict:
        param_groups = optimizer_dict["param_groups"]
        print(f"Number of parameter groups: {len(param_groups)}")
        if param_groups:
            print("First parameter group configuration:")
            first_group = param_groups[0]
            # Remove 'params' key for cleaner output as it contains parameter IDs
            group_config = {k: v for k, v in first_group.items() if k != 'params'}
            pprint.pprint(group_config, indent=2)
            print(f"Number of parameters in first group: {len(first_group.get('params', []))}")
else:
    print("No 'optimizer_dict' found in model weights")
print()

print("=== Scheduler dictionary ===")
if "scheduler_dict" in model_weights:
    scheduler_dict = model_weights["scheduler_dict"]
    print(f"Scheduler keys: {list(scheduler_dict.keys())}")

    # Common scheduler state information, printed only when present
    if "last_epoch" in scheduler_dict:
        print(f"Last epoch: {scheduler_dict['last_epoch']}")

    if "_step_count" in scheduler_dict:
        print(f"Step count: {scheduler_dict['_step_count']}")

    if "base_lrs" in scheduler_dict:
        print(f"Base learning rates: {scheduler_dict['base_lrs']}")

    if "_last_lr" in scheduler_dict:
        print(f"Last learning rate: {scheduler_dict['_last_lr']}")

    # Print all scheduler state for completeness
    print("\nFull scheduler state:")
    pprint.pprint(scheduler_dict, indent=2, width=80)
else:
    print("No 'scheduler_dict' found in model weights")
print()

=== Model Checkpoint Structure ===
[ 'model_dict',
  'optimizer_dict',
  'scheduler_dict',
  'offset',
  'epoch',
  'encoder_params']

=== Parameters ===
Epoch: 37
Offset: 5158

=== Encoder Parameters ===
do_lower_case: False

Encoder configuration:
{ 'dropout': 0.1,
  'encoder_model_type': 'hf_bert',
  'fix_ctx_encoder': False,
  'pretrained': True,
  'pretrained_file': None,
  'pretrained_model_cfg': 'bigcode/starencoder',
  'projection_dim': 0,
  'sequence_length': 1024}

=== Model Dictionary Keys ===
Number of model parameters: 400
First 10 parameter names:
  - question_model.embeddings.position_ids
  - question_model.embeddings.word_embeddings.weight
  - question_model.embeddings.position_embeddings.weight
  - question_model.embeddings.token_type_embeddings.weight
  - question_model.embeddings.LayerNorm.weight
  - question_model.embeddings.LayerNorm.bias
  - question_model.encoder.layer.0.attention.self.query.weight
  - question_model.encoder.layer.0.attention.self.query.bias
  - 