In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from collections import OrderedDict
from network import HDRTransformer

# Define a simple neural network
class SimpleNN(nn.Module):
    """Two-layer MLP: 10 inputs -> 20 hidden units (ReLU) -> 5 outputs.

    NOTE(review): no visible cell instantiates this class (the model built
    below is an HDRTransformer); it looks like leftover scaffolding.
    """

    def __init__(self):
        super().__init__()
        # Attribute names fix the state_dict keys: fc1.weight/bias, fc2.weight/bias.
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x):
        """Apply fc1, a ReLU, then fc2 to ``x``; returns the fc2 output."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

# Build the model whose weights are round-tripped through two files below.
model = HDRTransformer(
    embed_dim=60,
    depths=[6, 6, 6],
    num_heads=[6, 6, 6],
    mlp_ratio=2,
    in_chans=6,
)
# NOTE(review): the optimizer is never used by the comparison in this cell.
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Write the same state_dict once per extension; torch.save's on-disk format
# does not depend on the file extension, so both files use the same format.
for checkpoint_file in ("model.pth", "model.pkl"):
    torch.save(model.state_dict(), checkpoint_file)

# Read both files back onto the CPU for comparison.
pth_data = torch.load("model.pth", map_location="cpu")
pkl_data = torch.load("model.pkl", map_location="cpu")

# Function to remove 'module.' prefix if needed
def remove_module_prefix(state_dict, prefix="module."):
    """Return a copy of ``state_dict`` with a leading ``prefix`` stripped from its keys.

    Only a prefix at the very start of a key is removed. Occurrences later in
    the key (e.g. a layer named ``submodule``) are left intact, which a plain
    ``str.replace`` would corrupt (``head.submodule.w`` -> ``head.subw``).
    Keys that do not start with ``prefix`` are copied unchanged.

    Args:
        state_dict: Mapping of parameter/buffer names to values. Not modified.
        prefix: Leading text to strip. Defaults to ``"module."``, the prefix
            commonly carried by checkpoints saved from an ``nn.DataParallel``
            wrapped model.

    Returns:
        A new OrderedDict with the same values, in the same insertion order.
        If two keys collide after stripping, the later one wins.
    """
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = value
    return new_state_dict

# Strip a leading "module." from BOTH checkpoints (presumably left by training
# under nn.DataParallel — confirm) so their keys are comparable. Normalizing
# only one side would report spurious key mismatches.
if any(k.startswith("module.") for k in pth_data.keys()):
    pth_data = remove_module_prefix(pth_data)
if any(k.startswith("module.") for k in pkl_data.keys()):
    pkl_data = remove_module_prefix(pkl_data)

# Compare the key sets in both directions.
pth_keys = set(pth_data.keys())
pkl_keys = set(pkl_data.keys())
if pth_keys == pkl_keys:
    print("✅ Keys match in both files.")
else:
    print("❌ Key mismatch!")
    print("Keys in .pth but not in .pkl:", pth_keys - pkl_keys)
    print("Keys in .pkl but not in .pth:", pkl_keys - pth_keys)

# Compare tensor values only for keys present in both files: indexing
# pkl_data unconditionally raises KeyError on a key missing from the .pkl.
# A key-set mismatch already means the files differ.
all_match = pth_keys == pkl_keys
for key in pth_data.keys():
    if key in pkl_data and not torch.equal(pth_data[key], pkl_data[key]):
        print(f"❌ Mismatch in {key}")
        all_match = False

if all_match:
    print("✅ The weights are identical in both files!")
else:
    print("⚠️ There are differences in the model weights.")


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


✅ Keys match in both files.
✅ The weights are identical in both files!


In [11]:
import torch

from collections import OrderedDict

# Paths of the two trained checkpoints to compare.
PTH_PATH = "./Results/HDRTransformer.pth"
PKL_PATH = "./Results/HDRTransformer.pkl"

# torch.load unpickles the file, so only point it at trusted checkpoints.
# Both are mapped onto the CPU so no GPU is needed to inspect them.
pth_data = torch.load(PTH_PATH, map_location="cpu")
pkl_data = torch.load(PKL_PATH, map_location="cpu")

# Compare the raw key sets exactly as stored on disk. No prefix normalization
# happens here, so a "module." prefix on one side shows up as a key mismatch.
both_dicts = isinstance(pth_data, dict) and isinstance(pkl_data, dict)
if both_dicts:
    pth_keys, pkl_keys = set(pth_data), set(pkl_data)
    if pth_keys != pkl_keys:
        print("Key mismatch found!")
        print("Keys in .pth but not in .pkl:", pth_keys - pkl_keys)
        print("Keys in .pkl but not in .pth:", pkl_keys - pth_keys)
    else:
        print("Keys in both files match!")
else:
    print(f"Unexpected types: .pth = {type(pth_data)}, .pkl = {type(pkl_data)}")


Key mismatch found!
Keys in .pth but not in .pkl: {'module.layers.2.residual_group.blocks.0.lce.fc.0.weight', 'module.layers.2.residual_group.blocks.5.attn.proj.weight', 'module.layers.1.residual_group.blocks.1.mlp.fc2.weight', 'module.layers.0.residual_group.blocks.5.attn.qkv.weight', 'module.layers.0.residual_group.blocks.1.norm2.bias', 'module.layers.2.residual_group.blocks.3.lce.conv_block.0.weight', 'module.layers.1.residual_group.blocks.0.norm2.weight', 'module.layers.1.residual_group.blocks.2.attn.relative_position_index', 'module.layers.2.residual_group.blocks.1.mlp.fc2.weight', 'module.layers.0.residual_group.blocks.4.lce.conv_block.2.bias', 'module.layers.1.residual_group.blocks.5.mlp.fc1.bias', 'module.layers.1.residual_group.blocks.4.lce.fc.2.weight', 'module.layers.2.residual_group.blocks.5.attn.qkv.bias', 'module.layers.2.residual_group.blocks.3.lce.conv_block.2.bias', 'module.layers.0.residual_group.blocks.2.attn.proj.bias', 'module.layers.1.residual_group.blocks.2.lce.c

In [13]:
def remove_module_prefix(state_dict, prefix="module."):
    """Return a copy of ``state_dict`` with a leading ``prefix`` stripped from its keys.

    NOTE(review): this function is also defined in an earlier cell; keep a
    single definition (ideally in the imports/helpers cell) to avoid drift.

    Only a prefix at the very start of a key is removed. Occurrences later in
    the key (e.g. a layer named ``submodule``) are left intact, which a plain
    ``str.replace`` would corrupt (``head.submodule.w`` -> ``head.subw``).
    Keys that do not start with ``prefix`` are copied unchanged.

    Args:
        state_dict: Mapping of parameter/buffer names to values. Not modified.
        prefix: Leading text to strip. Defaults to ``"module."``, the prefix
            commonly carried by checkpoints saved from an ``nn.DataParallel``
            wrapped model.

    Returns:
        A new OrderedDict with the same values, in the same insertion order.
        If two keys collide after stripping, the later one wins.
    """
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = value
    return new_state_dict



# Strip a leading "module." from BOTH checkpoints (presumably left by training
# under nn.DataParallel — confirm) so their keys line up. One pass per dict
# is enough.
if any(k.startswith("module.") for k in pth_data.keys()):
    pth_data = remove_module_prefix(pth_data)
if any(k.startswith("module.") for k in pkl_data.keys()):
    pkl_data = remove_module_prefix(pkl_data)

# Check that every key exists in both files and that shared tensors are equal.
all_match = True
for key in pth_data.keys():
    if key not in pkl_data:
        print(f"❌ Key {key} is missing in .pkl")
        all_match = False
    elif not torch.equal(pth_data[key], pkl_data[key]):
        print(f"❌ Mismatch in {key}")
        all_match = False
# Keys present only in the .pkl would otherwise go unreported.
for key in pkl_data.keys():
    if key not in pth_data:
        print(f"❌ Key {key} is missing in .pth")
        all_match = False

if all_match:
    print("✅ Both files are identical!")
else:
    print("⚠️ There are differences between the files.")

❌ Mismatch in conv_f1.weight
❌ Mismatch in conv_f1.bias
❌ Mismatch in conv_f2.weight
❌ Mismatch in conv_f2.bias
❌ Mismatch in conv_f3.weight
❌ Mismatch in conv_f3.bias
❌ Mismatch in att_module_l.att1.weight
❌ Mismatch in att_module_l.att1.bias
❌ Mismatch in att_module_l.att2.weight
❌ Mismatch in att_module_l.att2.bias
❌ Mismatch in att_module_h.att1.weight
❌ Mismatch in att_module_h.att1.bias
❌ Mismatch in att_module_h.att2.weight
❌ Mismatch in att_module_h.att2.bias
❌ Mismatch in conv_first.weight
❌ Mismatch in conv_first.bias
❌ Mismatch in patch_embed.norm.weight
❌ Mismatch in patch_embed.norm.bias
❌ Mismatch in layers.0.residual_group.blocks.0.norm1.weight
❌ Mismatch in layers.0.residual_group.blocks.0.norm1.bias
❌ Mismatch in layers.0.residual_group.blocks.0.attn.relative_position_bias_table
❌ Mismatch in layers.0.residual_group.blocks.0.attn.qkv.weight
❌ Mismatch in layers.0.residual_group.blocks.0.attn.qkv.bias
❌ Mismatch in layers.0.residual_group.blocks.0.attn.proj.weight
❌ Mis

In [3]:
# Display the checkpoint's keys (the cell's last expression is its output).
# NOTE(review): the recorded output shows "module."-prefixed keys, i.e. a raw,
# un-normalized checkpoint. The result depends on which cell last assigned
# pth_data, so it is not reproducible under Restart & Run All as ordered here.
pth_data.keys()

odict_keys(['module.conv_f1.weight', 'module.conv_f1.bias', 'module.conv_f2.weight', 'module.conv_f2.bias', 'module.conv_f3.weight', 'module.conv_f3.bias', 'module.att_module_l.att1.weight', 'module.att_module_l.att1.bias', 'module.att_module_l.att2.weight', 'module.att_module_l.att2.bias', 'module.att_module_h.att1.weight', 'module.att_module_h.att1.bias', 'module.att_module_h.att2.weight', 'module.att_module_h.att2.bias', 'module.conv_first.weight', 'module.conv_first.bias', 'module.patch_embed.norm.weight', 'module.patch_embed.norm.bias', 'module.layers.0.residual_group.blocks.0.norm1.weight', 'module.layers.0.residual_group.blocks.0.norm1.bias', 'module.layers.0.residual_group.blocks.0.attn.relative_position_bias_table', 'module.layers.0.residual_group.blocks.0.attn.relative_position_index', 'module.layers.0.residual_group.blocks.0.attn.qkv.weight', 'module.layers.0.residual_group.blocks.0.attn.qkv.bias', 'module.layers.0.residual_group.blocks.0.attn.proj.weight', 'module.layers.0.

In [4]:
# Same inspection as the preceding pth_data.keys() cell; presumably a re-run.
# NOTE(review): drop one of the two cells unless pth_data changes in between.
pth_data.keys()

odict_keys(['module.conv_f1.weight', 'module.conv_f1.bias', 'module.conv_f2.weight', 'module.conv_f2.bias', 'module.conv_f3.weight', 'module.conv_f3.bias', 'module.att_module_l.att1.weight', 'module.att_module_l.att1.bias', 'module.att_module_l.att2.weight', 'module.att_module_l.att2.bias', 'module.att_module_h.att1.weight', 'module.att_module_h.att1.bias', 'module.att_module_h.att2.weight', 'module.att_module_h.att2.bias', 'module.conv_first.weight', 'module.conv_first.bias', 'module.patch_embed.norm.weight', 'module.patch_embed.norm.bias', 'module.layers.0.residual_group.blocks.0.norm1.weight', 'module.layers.0.residual_group.blocks.0.norm1.bias', 'module.layers.0.residual_group.blocks.0.attn.relative_position_bias_table', 'module.layers.0.residual_group.blocks.0.attn.relative_position_index', 'module.layers.0.residual_group.blocks.0.attn.qkv.weight', 'module.layers.0.residual_group.blocks.0.attn.qkv.bias', 'module.layers.0.residual_group.blocks.0.attn.proj.weight', 'module.layers.0.