In [5]:
import sys
import os
sys.path.append(os.path.abspath(".."))
import torch
from models.change_classifier import ChangeClassifier
model = ChangeClassifier(weights=None)
ckpt = torch.load("../pretrained_models/checkpoint_086.pth", map_location="cpu", weights_only=False)
model.load_state_dict(ckpt["model_state_dict"])   # <- key bit
model.eval()
torch.save(model.state_dict(), "../pretrained_models/model_entire.pth")

In [6]:
import torch
from thop import profile, clever_format
from models.change_classifier import ChangeClassifier

# Dummy input
ref = torch.randn(1, 3, 256, 256)
test = torch.randn(1, 3, 256, 256)

# Model init
model = ChangeClassifier(
    bkbn_name="efficientnet_b4",
    weights=None,
    output_layer_bkbn="3",
    freeze_backbone=False
)

model.eval()

# Forward pass test
with torch.no_grad():
    out = model(ref, test)

print("✅ Forward pass OK")
print("Test shape:", test.shape)
print("Ref shape:", ref.shape)

print("Output shape:", out.shape)

# Params count
total_params = sum(p.numel() for p in model.parameters())
print(f"Total Parameters: {total_params:,}")

# FLOPs + MACs
flops, params = profile(model, inputs=(ref, test))
flops, params = clever_format([flops, params], "%.3f")
print(f"FLOPs: {flops}")
print(f"THOP Params: {params}")


✅ Forward pass OK
Test shape: torch.Size([1, 3, 256, 256])
Ref shape: torch.Size([1, 3, 256, 256])
Output shape: torch.Size([1, 1, 256, 256])
Total Parameters: 285,803
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_prelu() for <class 'torch.nn.modules.activation.PReLU'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.




FLOPs: 1.633G
THOP Params: 285.803K


In [7]:
import torch
from thop import profile, clever_format
from models.change_classifier import ChangeClassifier

ref = torch.randn(1, 3, 256, 256)
test = torch.randn(1, 3, 256, 256)

model = ChangeClassifier(
    bkbn_name="efficientnet_b4",
    weights=None,
    output_layer_bkbn="3",
    freeze_backbone=False
)

model.eval()
# Forward test
with torch.no_grad():
    out = model(ref, test)

print("✅ Forward pass OK")
print("Test shape:", test.shape)
print("Ref shape:", ref.shape)
print("Output shape:", out.shape)

# Param count
#total_params = sum(p.numel() for p in model.parameters())
#print(f"🧠 Total Parameters: {total_params:,}")

# FLOPs + MACs
flops, params = profile(model, inputs=(ref, test))
flops, params = clever_format([flops, params], "%.3f")
print(f"⚙️ FLOPs: {flops}")
print(f"📦 THOP Params: {params}")

# ✅ Print the full model architecture
print("🔍 Full Model Structure:\n")
print(model)
print("\n" + "="*60 + "\n")

✅ Forward pass OK
Test shape: torch.Size([1, 3, 256, 256])
Ref shape: torch.Size([1, 3, 256, 256])
Output shape: torch.Size([1, 1, 256, 256])
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_prelu() for <class 'torch.nn.modules.activation.PReLU'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.
⚙️ FLOPs: 1.633G
📦 THOP Params: 285.803K
🔍 Full Model Structure:

ChangeClassifier(
  (_retina): RetinaSimBlock(
    (dog): Conv2d(3, 3, kernel_size=(15, 15), stride=(1, 1), padding=(7, 7), groups=3, bias=False)
    (adapt): InstanceNorm2d(3, eps=1e-05, 

In [None]:
from torchviz import make_dot
# example with dummy input (make sure inputs match your model)
t1_input = torch.randn(1, 3, 256, 256)
t2_input = torch.randn(1, 3, 256, 256)
out = model(t1_input, t2_input)
# Create graph
dot = make_dot(out, params=dict(model.named_parameters()))
# Save to file
dot.format = 'png'
dot.render('../pretrained_models/model_graph')


'model_graph.png'

In [None]:
t1_input = torch.randn(1, 3, 256, 256)
t2_input = torch.randn(1, 3, 256, 256)
model.eval()
torch.onnx.export(
    model,
    (t1_input, t2_input),  # <--- pack as a tuple
    "../pretrained_models/model.onnx",
    opset_version=11,
    input_names=["img1", "img2"],
    output_names=["pred"]
)

In [8]:
for name, module in model.named_modules():
    print(name)


_retina
_retina.dog
_retina.adapt
_retina.act
_backbone
_backbone.0
_backbone.0.0
_backbone.0.1
_backbone.0.2
_backbone.1
_backbone.1.0
_backbone.1.0.block
_backbone.1.0.block.0
_backbone.1.0.block.0.0
_backbone.1.0.block.0.1
_backbone.1.0.block.0.2
_backbone.1.0.block.1
_backbone.1.0.block.1.avgpool
_backbone.1.0.block.1.fc1
_backbone.1.0.block.1.fc2
_backbone.1.0.block.1.activation
_backbone.1.0.block.1.scale_activation
_backbone.1.0.block.2
_backbone.1.0.block.2.0
_backbone.1.0.block.2.1
_backbone.1.0.stochastic_depth
_backbone.1.1
_backbone.1.1.block
_backbone.1.1.block.0
_backbone.1.1.block.0.0
_backbone.1.1.block.0.1
_backbone.1.1.block.0.2
_backbone.1.1.block.1
_backbone.1.1.block.1.avgpool
_backbone.1.1.block.1.fc1
_backbone.1.1.block.1.fc2
_backbone.1.1.block.1.activation
_backbone.1.1.block.1.scale_activation
_backbone.1.1.block.2
_backbone.1.1.block.2.0
_backbone.1.1.block.2.1
_backbone.1.1.stochastic_depth
_backbone.2
_backbone.2.0
_backbone.2.0.block
_backbone.2.0.block.0

In [13]:
# state_dict = torch.load("../pretrained_models/checkpoint_086.pth", map_location="cpu", weights_only=False)
# for k, v in state_dict.items():
#    if len(v.shape) == 4:  # Conv weight tensors
#        print(f"{k}: {v.shape}")
#print("all")

ckpt = torch.load("../pretrained_models/checkpoint_086.pth", map_location="cpu", weights_only=False)
state_dict = model.load_state_dict(ckpt["model_state_dict"])   # <- key bit

for k, v in state_dict.items():
    print(f"{k}: {v.shape}")

RuntimeError: Error(s) in loading state_dict for ChangeClassifier:
	Missing key(s) in state_dict: "total_ops", "total_params", "_retina.total_ops", "_retina.total_params", "_retina.act.total_ops", "_retina.act.total_params", "_backbone.total_ops", "_backbone.total_params", "_backbone.0.total_ops", "_backbone.0.total_params", "_backbone.0.2.total_ops", "_backbone.0.2.total_params", "_backbone.1.0.total_ops", "_backbone.1.0.total_params", "_backbone.1.0.block.0.total_ops", "_backbone.1.0.block.0.total_params", "_backbone.1.0.block.0.2.total_ops", "_backbone.1.0.block.0.2.total_params", "_backbone.1.0.block.1.total_ops", "_backbone.1.0.block.1.total_params", "_backbone.1.0.block.1.activation.total_ops", "_backbone.1.0.block.1.activation.total_params", "_backbone.1.0.block.1.scale_activation.total_ops", "_backbone.1.0.block.1.scale_activation.total_params", "_backbone.1.0.block.2.total_ops", "_backbone.1.0.block.2.total_params", "_backbone.1.0.stochastic_depth.total_ops", "_backbone.1.0.stochastic_depth.total_params", "_backbone.1.1.total_ops", "_backbone.1.1.total_params", "_backbone.1.1.block.0.total_ops", "_backbone.1.1.block.0.total_params", "_backbone.1.1.block.0.2.total_ops", "_backbone.1.1.block.0.2.total_params", "_backbone.1.1.block.1.total_ops", "_backbone.1.1.block.1.total_params", "_backbone.1.1.block.1.activation.total_ops", "_backbone.1.1.block.1.activation.total_params", "_backbone.1.1.block.1.scale_activation.total_ops", "_backbone.1.1.block.1.scale_activation.total_params", "_backbone.1.1.block.2.total_ops", "_backbone.1.1.block.2.total_params", "_backbone.1.1.stochastic_depth.total_ops", "_backbone.1.1.stochastic_depth.total_params", "_backbone.2.0.total_ops", "_backbone.2.0.total_params", "_backbone.2.0.block.0.total_ops", "_backbone.2.0.block.0.total_params", "_backbone.2.0.block.0.2.total_ops", "_backbone.2.0.block.0.2.total_params", "_backbone.2.0.block.1.total_ops", "_backbone.2.0.block.1.total_params", "_backbone.2.0.block.1.2.total_ops", "_backbone.2.0.block.1.2.total_params", "_backbone.2.0.block.2.total_ops", "_backbone.2.0.block.2.total_params", "_backbone.2.0.block.2.activation.total_ops", "_backbone.2.0.block.2.activation.total_params", "_backbone.2.0.block.2.scale_activation.total_ops", "_backbone.2.0.block.2.scale_activation.total_params", "_backbone.2.0.block.3.total_ops", "_backbone.2.0.block.3.total_params", "_backbone.2.0.stochastic_depth.total_ops", "_backbone.2.0.stochastic_depth.total_params", "_backbone.2.1.total_ops", "_backbone.2.1.total_params", "_backbone.2.1.block.0.total_ops", "_backbone.2.1.block.0.total_params", "_backbone.2.1.block.0.2.total_ops", "_backbone.2.1.block.0.2.total_params", "_backbone.2.1.block.1.total_ops", "_backbone.2.1.block.1.total_params", "_backbone.2.1.block.1.2.total_ops", "_backbone.2.1.block.1.2.total_params", "_backbone.2.1.block.2.total_ops", "_backbone.2.1.block.2.total_params", "_backbone.2.1.block.2.activation.total_ops", "_backbone.2.1.block.2.activation.total_params", "_backbone.2.1.block.2.scale_activation.total_ops", "_backbone.2.1.block.2.scale_activation.total_params", "_backbone.2.1.block.3.total_ops", "_backbone.2.1.block.3.total_params", "_backbone.2.1.stochastic_depth.total_ops", "_backbone.2.1.stochastic_depth.total_params", "_backbone.2.2.total_ops", "_backbone.2.2.total_params", "_backbone.2.2.block.0.total_ops", "_backbone.2.2.block.0.total_params", "_backbone.2.2.block.0.2.total_ops", "_backbone.2.2.block.0.2.total_params", "_backbone.2.2.block.1.total_ops", "_backbone.2.2.block.1.total_params", "_backbone.2.2.block.1.2.total_ops", "_backbone.2.2.block.1.2.total_params", "_backbone.2.2.block.2.total_ops", "_backbone.2.2.block.2.total_params", "_backbone.2.2.block.2.activation.total_ops", "_backbone.2.2.block.2.activation.total_params", "_backbone.2.2.block.2.scale_activation.total_ops", "_backbone.2.2.block.2.scale_activation.total_params", "_backbone.2.2.block.3.total_ops", "_backbone.2.2.block.3.total_params", "_backbone.2.2.stochastic_depth.total_ops", "_backbone.2.2.stochastic_depth.total_params", "_backbone.2.3.total_ops", "_backbone.2.3.total_params", "_backbone.2.3.block.0.total_ops", "_backbone.2.3.block.0.total_params", "_backbone.2.3.block.0.2.total_ops", "_backbone.2.3.block.0.2.total_params", "_backbone.2.3.block.1.total_ops", "_backbone.2.3.block.1.total_params", "_backbone.2.3.block.1.2.total_ops", "_backbone.2.3.block.1.2.total_params", "_backbone.2.3.block.2.total_ops", "_backbone.2.3.block.2.total_params", "_backbone.2.3.block.2.activation.total_ops", "_backbone.2.3.block.2.activation.total_params", "_backbone.2.3.block.2.scale_activation.total_ops", "_backbone.2.3.block.2.scale_activation.total_params", "_backbone.2.3.block.3.total_ops", "_backbone.2.3.block.3.total_params", "_backbone.2.3.stochastic_depth.total_ops", "_backbone.2.3.stochastic_depth.total_params", "_backbone.3.0.total_ops", "_backbone.3.0.total_params", "_backbone.3.0.block.0.total_ops", "_backbone.3.0.block.0.total_params", "_backbone.3.0.block.0.2.total_ops", "_backbone.3.0.block.0.2.total_params", "_backbone.3.0.block.1.total_ops", "_backbone.3.0.block.1.total_params", "_backbone.3.0.block.1.2.total_ops", "_backbone.3.0.block.1.2.total_params", "_backbone.3.0.block.2.total_ops", "_backbone.3.0.block.2.total_params", "_backbone.3.0.block.2.activation.total_ops", "_backbone.3.0.block.2.activation.total_params", "_backbone.3.0.block.2.scale_activation.total_ops", "_backbone.3.0.block.2.scale_activation.total_params", "_backbone.3.0.block.3.total_ops", "_backbone.3.0.block.3.total_params", "_backbone.3.0.stochastic_depth.total_ops", "_backbone.3.0.stochastic_depth.total_params", "_backbone.3.1.total_ops", "_backbone.3.1.total_params", "_backbone.3.1.block.0.total_ops", "_backbone.3.1.block.0.total_params", "_backbone.3.1.block.0.2.total_ops", "_backbone.3.1.block.0.2.total_params", "_backbone.3.1.block.1.total_ops", "_backbone.3.1.block.1.total_params", "_backbone.3.1.block.1.2.total_ops", "_backbone.3.1.block.1.2.total_params", "_backbone.3.1.block.2.total_ops", "_backbone.3.1.block.2.total_params", "_backbone.3.1.block.2.activation.total_ops", "_backbone.3.1.block.2.activation.total_params", "_backbone.3.1.block.2.scale_activation.total_ops", "_backbone.3.1.block.2.scale_activation.total_params", "_backbone.3.1.block.3.total_ops", "_backbone.3.1.block.3.total_params", "_backbone.3.1.stochastic_depth.total_ops", "_backbone.3.1.stochastic_depth.total_params", "_backbone.3.2.total_ops", "_backbone.3.2.total_params", "_backbone.3.2.block.0.total_ops", "_backbone.3.2.block.0.total_params", "_backbone.3.2.block.0.2.total_ops", "_backbone.3.2.block.0.2.total_params", "_backbone.3.2.block.1.total_ops", "_backbone.3.2.block.1.total_params", "_backbone.3.2.block.1.2.total_ops", "_backbone.3.2.block.1.2.total_params", "_backbone.3.2.block.2.total_ops", "_backbone.3.2.block.2.total_params", "_backbone.3.2.block.2.activation.total_ops", "_backbone.3.2.block.2.activation.total_params", "_backbone.3.2.block.2.scale_activation.total_ops", "_backbone.3.2.block.2.scale_activation.total_params", "_backbone.3.2.block.3.total_ops", "_backbone.3.2.block.3.total_params", "_backbone.3.2.stochastic_depth.total_ops", "_backbone.3.2.stochastic_depth.total_params", "_backbone.3.3.total_ops", "_backbone.3.3.total_params", "_backbone.3.3.block.0.total_ops", "_backbone.3.3.block.0.total_params", "_backbone.3.3.block.0.2.total_ops", "_backbone.3.3.block.0.2.total_params", "_backbone.3.3.block.1.total_ops", "_backbone.3.3.block.1.total_params", "_backbone.3.3.block.1.2.total_ops", "_backbone.3.3.block.1.2.total_params", "_backbone.3.3.block.2.total_ops", "_backbone.3.3.block.2.total_params", "_backbone.3.3.block.2.activation.total_ops", "_backbone.3.3.block.2.activation.total_params", "_backbone.3.3.block.2.scale_activation.total_ops", "_backbone.3.3.block.2.scale_activation.total_params", "_backbone.3.3.block.3.total_ops", "_backbone.3.3.block.3.total_params", "_backbone.3.3.stochastic_depth.total_ops", "_backbone.3.3.stochastic_depth.total_params", "_first_mix.total_ops", "_first_mix.total_params", "_first_mix._mixing.total_ops", "_first_mix._mixing.total_params", "_first_mix._linear.total_ops", "_first_mix._linear.total_params", "_mixing_mask.total_ops", "_mixing_mask.total_params", "_mixing_mask.0.total_ops", "_mixing_mask.0.total_params", "_mixing_mask.0._mixing.total_ops", "_mixing_mask.0._mixing.total_params", "_mixing_mask.0._linear.total_ops", "_mixing_mask.0._linear.total_params", "_mixing_mask.1.total_ops", "_mixing_mask.1.total_params", "_mixing_mask.1._mixing.total_ops", "_mixing_mask.1._mixing.total_params", "_mixing_mask.1._linear.total_ops", "_mixing_mask.1._linear.total_params", "_mixing_mask.2.total_ops", "_mixing_mask.2.total_params", "_up.total_ops", "_up.total_params", "_up.0.total_ops", "_up.0.total_params", "_up.1.total_ops", "_up.1.total_params", "_up.2.total_ops", "_up.2.total_params", "_classify.total_ops", "_classify.total_params", "_classify._linears.2.1.total_ops", "_classify._linears.2.1.total_params". 