## Notes

Using virchow could be better since it is already pretrained in histology images. We just need to finetune that.

We need 256x256 images.

In this notebook I'll analyze what the checkpoint that they give us includes and what we get out of the trainings.

In [1]:
import torch
from collections import defaultdict
import numpy as np

# ---------------------------------------------------------
# 1. Load checkpoint safely
# ---------------------------------------------------------
def load_checkpoint(path):
    print(f"üìÇ Loading checkpoint: {path}")
    ckpt = torch.load(path, map_location="cpu")
    print(f"‚úî Loaded successfully.\n")
    return ckpt

# ---------------------------------------------------------
# 2. Print top-level structure
# ---------------------------------------------------------
def print_top_level(ckpt):
    print("==== üîç TOP-LEVEL KEYS ====")
    for k in ckpt.keys():
        print(" ‚Ä¢", k)
    print()

# ---------------------------------------------------------
# 3. Get model_state_dict and categorize keys
# ---------------------------------------------------------
def analyze_state_dict(ckpt):
    sd = ckpt.get("model_state_dict", {})
    print(f"==== üì¶ MODEL STATE DICT ({len(sd)} params) ====\n")

    categories = defaultdict(list)

    for k, v in sd.items():
        if k.startswith("encoder."):
            categories["encoder"].append(k)
        elif k.startswith("decoder.") or "head" in k.lower():
            categories["decoder"].append(k)
        else:
            categories["other"].append(k)

    for cat, keys in categories.items():
        print(f"--- {cat.upper()} ({len(keys)} keys) ---")
        for k in keys[:20]:            # print only first 20
            print("  -", k)
        if len(keys) > 20:
            print(f"  ... and {len(keys)-20} more")
        print()

    return categories

# ---------------------------------------------------------
# 4. Print parameter shapes
# ---------------------------------------------------------
def print_shapes(ckpt, filter_prefix=None, max_items=20):
    sd = ckpt.get("model_state_dict", {})
    print("==== üî¢ SHAPES ====\n")
    count = 0

    for k, v in sd.items():
        if filter_prefix and not k.startswith(filter_prefix):
            continue
        print(f"{k:60} {tuple(v.shape)}")
        count += 1
        if count >= max_items:
            print("... truncated ...\n")
            break

# ---------------------------------------------------------
# 5. Count total parameters
# ---------------------------------------------------------
def count_params(ckpt):
    sd = ckpt.get("model_state_dict", {})
    total = sum(np.prod(v.shape) for v in sd.values())
    print(f"==== üßÆ TOTAL PARAMETERS ====")
    print(f"Total: {total:,} parameters\n")
    return total

# ---------------------------------------------------------
# 6. Extract encoder-only weights
# ---------------------------------------------------------
def extract_encoder(ckpt):
    sd = ckpt.get("model_state_dict", {})
    encoder_sd = {k.replace("encoder.", ""): v for k, v in sd.items() if k.startswith("encoder.")}
    print(f"==== üéØ ENCODER EXTRACT ====")
    print(f"Extracted {len(encoder_sd)} encoder weights\n")
    return encoder_sd

# ---------------------------------------------------------
# 7. Compare two checkpoints
# ---------------------------------------------------------
def compare_checkpoints(ckpt1, ckpt2):
    sd1 = ckpt1.get("model_state_dict", {})
    sd2 = ckpt2.get("model_state_dict", {})

    print("==== üîç COMPARISON ====\n")

    keys1 = set(sd1.keys())
    keys2 = set(sd2.keys())

    print("Missing in ckpt2:")
    for k in sorted(keys1 - keys2):
        print("  -", k)
    print()

    print("Missing in ckpt1:")
    for k in sorted(keys2 - keys1):
        print("  -", k)
    print()

    print("==== Shape mismatches ====")
    for k in sorted(keys1 & keys2):
        if sd1[k].shape != sd2[k].shape:
            print(f"  {k}: {tuple(sd1[k].shape)}  vs  {tuple(sd2[k].shape)}")

    print()


In [3]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict


# ---------------------------------------------------------
# 0) Paths
# ---------------------------------------------------------
ckpt1_path = "/projectnb/ec500kb/projects/Fall_2025_Projects/Project_2/AI-guided-whole-slide-imaging-analysis/CellViT-plus-plus/checkpoints/Virchow/CellViT-Virchow-x40-AMP.pth"
ckpt2_path = "/projectnb/ec500kb/projects/Fall_2025_Projects/Project_2/AI-guided-whole-slide-imaging-analysis/ProcessedDataset/v1_40x_area20/patches_cellvit_p128_pannuke/logs_local/2025-11-15T150541_tcga_finetune_128_reheat/checkpoints/model_best.pth"


# ---------------------------------------------------------
# 1) Load checkpoint safely
# ---------------------------------------------------------
def load_checkpoint(path):
    print(f"üìÇ Loading checkpoint: {path}")
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    print(f"‚úî Loaded successfully.\n")
    return ckpt


# ---------------------------------------------------------
# 2) Print top-level structure
# ---------------------------------------------------------
def print_top_level(ckpt, name="Checkpoint"):
    print(f"==== üîç TOP-LEVEL KEYS ‚Äî {name} ====")
    for k in ckpt.keys():
        print(" ‚Ä¢", k)
    print()


# ---------------------------------------------------------
# 3) Categorize weights into encoder/decoder/other
# ---------------------------------------------------------
def analyze_state_dict(ckpt, name=""):
    sd = ckpt.get("model_state_dict", {})
    print(f"==== üì¶ STATE DICT ANALYSIS ({name}) ‚Äî {len(sd)} tensors ====\n")

    categories = defaultdict(list)

    for k, v in sd.items():
        if k.startswith("encoder."):
            categories["encoder"].append(k)
        elif "head" in k.lower():
            categories["head"].append(k)
        elif "decoder" in k.lower():
            categories["decoder"].append(k)
        else:
            categories["other"].append(k)

    # Print grouped keys
    for cat, keys in categories.items():
        print(f"--- {cat.upper()} ({len(keys)} keys) ---")
        for k in keys[:15]:
            print("  -", k)
        if len(keys) > 15:
            print(f"  ... +{len(keys)-15} more")
        print()

    return categories


# ---------------------------------------------------------
# 4) Print parameter shapes
# ---------------------------------------------------------
def print_shapes(ckpt, prefix=None, max_items=20):
    sd = ckpt.get("model_state_dict", {})
    print(f"==== üî¢ SHAPES (filter: {prefix}) ====\n")
    c = 0
    for k, v in sd.items():
        if prefix and not k.startswith(prefix):
            continue
        print(f"{k:60} {tuple(v.shape)}")
        c += 1
        if c >= max_items:
            print("... truncated ...\n")
            break


# ---------------------------------------------------------
# 5) Count parameters
# ---------------------------------------------------------
def count_params(ckpt, name=""):
    sd = ckpt.get("model_state_dict", {})
    total = sum(np.prod(v.shape) for v in sd.values())
    print(f"==== üßÆ PARAMETER COUNT ‚Äî {name} ====")
    print(f"{total:,} parameters.\n")
    return total


# ---------------------------------------------------------
# 6) Extract encoder-only weights
# ---------------------------------------------------------
def extract_encoder(ckpt):
    sd = ckpt.get("model_state_dict", {})
    enc = {k.replace("encoder.", ""): v for k, v in sd.items() if k.startswith("encoder.")}
    print(f"üéØ Extracted {len(enc)} encoder weights.\n")
    return enc


# ---------------------------------------------------------
# 7) Compare two checkpoints in depth
# ---------------------------------------------------------
def compare_checkpoints(ck1, ck2):
    sd1 = ck1.get("model_state_dict", {})
    sd2 = ck2.get("model_state_dict", {})

    keys1 = set(sd1.keys())
    keys2 = set(sd2.keys())

    print("==== üîç COMPARISON ‚Äî Missing Keys ====\n")

    missing_in_2 = sorted(keys1 - keys2)
    missing_in_1 = sorted(keys2 - keys1)

    print(f"Missing in checkpoint2 ({len(missing_in_2)}):")
    for k in missing_in_2[:15]:
        print("  -", k)
    if len(missing_in_2) > 15:
        print(f"  ... +{len(missing_in_2)-15}\n")

    print(f"Missing in checkpoint1 ({len(missing_in_1)}):")
    for k in missing_in_1[:15]:
        print("  -", k)
    if len(missing_in_1) > 15:
        print(f"  ... +{len(missing_in_1)-15}\n")

    print("\n==== üîÑ SHAPE MISMATCHES ====\n")
    for k in sorted(keys1 & keys2):
        if sd1[k].shape != sd2[k].shape:
            print(f"{k:50}  {tuple(sd1[k].shape)} ‚Üí {tuple(sd2[k].shape)}")

    print()


# ---------------------------------------------------------
# 8) Weight similarity (L2 difference)
# ---------------------------------------------------------
def compute_layer_similarity(ck1, ck2):
    sd1 = ck1.get("model_state_dict", {})
    sd2 = ck2.get("model_state_dict", {})

    print("==== üìâ LAYER SIMILARITY (L2 norm) ====\n")
    overlaps = sorted(set(sd1.keys()) & set(sd2.keys()))

    diffs = []
    for k in overlaps:
        v1 = sd1[k].float().view(-1)
        v2 = sd2[k].float().view(-1)
        diff = torch.norm(v1 - v2).item()
        diffs.append((k, diff))

    diffs_sorted = sorted(diffs, key=lambda x: -x[1])[:30]

    for k, d in diffs_sorted:
        print(f"{k:50}  Œî={d:.4f}")

    print()


# ---------------------------------------------------------
# 9) Plot weight histograms of encoder
# ---------------------------------------------------------
def plot_histograms(ck1, ck2, layer="encoder.pos_embed"):
    sd1 = ck1["model_state_dict"]
    sd2 = ck2["model_state_dict"]

    if layer not in sd1 or layer not in sd2:
        print(f"Layer {layer} not found in both checkpoints.")
        return

    w1 = sd1[layer].cpu().numpy().flatten()
    w2 = sd2[layer].cpu().numpy().flatten()

    plt.figure(figsize=(12,5))
    plt.hist(w1, bins=100, alpha=0.5, label="ckpt1")
    plt.hist(w2, bins=100, alpha=0.5, label="ckpt2")
    plt.title(f"Weight Histogram ‚Äî {layer}")
    plt.legend()
    plt.show()



# =========================================================
#               ANALYSIS STARTS HERE
# =========================================================

ck1 = load_checkpoint(ckpt1_path)
ck2 = load_checkpoint(ckpt2_path)

print_top_level(ck1, "Virchow Official")
print_top_level(ck2, "Your Fine-tuned Model")

cats1 = analyze_state_dict(ck1, "Virchow Official")
cats2 = analyze_state_dict(ck2, "Your Model")

count_params(ck1, "Virchow Official")
count_params(ck2, "Your Model")

print_shapes(ck1, "encoder.")
print_shapes(ck2, "encoder.")

compare_checkpoints(ck1, ck2)

compute_layer_similarity(ck1, ck2)

# Try comparing pos_embed histograms (often very telling)
plot_histograms(ck1, ck2, "encoder.pos_embed")


üìÇ Loading checkpoint: /projectnb/ec500kb/projects/Fall_2025_Projects/Project_2/AI-guided-whole-slide-imaging-analysis/CellViT-plus-plus/checkpoints/Virchow/CellViT-Virchow-x40-AMP.pth
‚úî Loaded successfully.

üìÇ Loading checkpoint: /projectnb/ec500kb/projects/Fall_2025_Projects/Project_2/AI-guided-whole-slide-imaging-analysis/ProcessedDataset/v1_40x_area20/patches_cellvit_p128_pannuke/logs_local/2025-11-15T150541_tcga_finetune_128_reheat/checkpoints/model_best.pth
‚úî Loaded successfully.

==== üîç TOP-LEVEL KEYS ‚Äî Virchow Official ====
 ‚Ä¢ arch
 ‚Ä¢ epoch
 ‚Ä¢ model_state_dict
 ‚Ä¢ config
 ‚Ä¢ scaler_state_dict

==== üîç TOP-LEVEL KEYS ‚Äî Your Fine-tuned Model ====
 ‚Ä¢ arch
 ‚Ä¢ epoch
 ‚Ä¢ model_state_dict
 ‚Ä¢ optimizer_state_dict
 ‚Ä¢ scheduler_state_dict
 ‚Ä¢ best_metric
 ‚Ä¢ best_epoch
 ‚Ä¢ config
 ‚Ä¢ wandb_id
 ‚Ä¢ logdir
 ‚Ä¢ run_name
 ‚Ä¢ scaler_state_dict

==== üì¶ STATE DICT ANALYSIS (Virchow Official) ‚Äî 743 tensors ====

--- ENCODER (456 keys) ---
  - encoder

RuntimeError: The size of tensor a (752640) must match the size of tensor b (983040) at non-singleton dimension 0

In [4]:
ckpt1_path = "/projectnb/ec500kb/projects/Fall_2025_Projects/Project_2/AI-guided-whole-slide-imaging-analysis/CellViT-plus-plus/checkpoints/SAM/encoder_only_CellViT-SAM-H-x40-AMP.pth"
# =========================================================
#               ANALYSIS STARTS HERE
# =========================================================

ck1 = load_checkpoint(ckpt1_path)
ck2 = load_checkpoint(ckpt2_path)

print_top_level(ck1, "Virchow Official")
print_top_level(ck2, "Your Fine-tuned Model")

cats1 = analyze_state_dict(ck1, "Virchow Official")
cats2 = analyze_state_dict(ck2, "Your Model")

count_params(ck1, "Virchow Official")
count_params(ck2, "Your Model")

print_shapes(ck1, "encoder.")
print_shapes(ck2, "encoder.")

compare_checkpoints(ck1, ck2)

compute_layer_similarity(ck1, ck2)

# Try comparing pos_embed histograms (often very telling)
plot_histograms(ck1, ck2, "encoder.pos_embed")

üìÇ Loading checkpoint: /projectnb/ec500kb/projects/Fall_2025_Projects/Project_2/AI-guided-whole-slide-imaging-analysis/CellViT-plus-plus/checkpoints/SAM/encoder_only_CellViT-SAM-H-x40-AMP.pth
‚úî Loaded successfully.

üìÇ Loading checkpoint: /projectnb/ec500kb/projects/Fall_2025_Projects/Project_2/AI-guided-whole-slide-imaging-analysis/ProcessedDataset/v1_40x_area20/patches_cellvit_p128_pannuke/logs_local/2025-11-15T150541_tcga_finetune_128_reheat/checkpoints/model_best.pth
‚úî Loaded successfully.

==== üîç TOP-LEVEL KEYS ‚Äî Virchow Official ====
 ‚Ä¢ pos_embed
 ‚Ä¢ patch_embed.proj.weight
 ‚Ä¢ patch_embed.proj.bias
 ‚Ä¢ blocks.0.norm1.weight
 ‚Ä¢ blocks.0.norm1.bias
 ‚Ä¢ blocks.0.attn.rel_pos_h
 ‚Ä¢ blocks.0.attn.rel_pos_w
 ‚Ä¢ blocks.0.attn.qkv.weight
 ‚Ä¢ blocks.0.attn.qkv.bias
 ‚Ä¢ blocks.0.attn.proj.weight
 ‚Ä¢ blocks.0.attn.proj.bias
 ‚Ä¢ blocks.0.norm2.weight
 ‚Ä¢ blocks.0.norm2.bias
 ‚Ä¢ blocks.0.mlp.lin1.weight
 ‚Ä¢ blocks.0.mlp.lin1.bias
 ‚Ä¢ blocks.0.mlp.lin2.weight
 ‚

KeyError: 'model_state_dict'

## Extract encoder keys from checkpoint!

In [6]:
import torch
from pathlib import Path

def extract_encoder_only(root_dir, checkpoint_name):
    root = Path(root_dir)
    input_path = root / checkpoint_name

    if not input_path.exists():
        raise FileNotFoundError(f"‚ùå Checkpoint not found:\n{input_path}")

    ckpt = torch.load(input_path, map_location="cpu")

    if "model_state_dict" not in ckpt:
        raise ValueError("‚ùå Checkpoint does not contain model_state_dict!")

    full_sd = ckpt["model_state_dict"]

    # ---- Extract encoder.* keys ----
    encoder_only = {}
    for k, v in full_sd.items():
        if k.startswith("encoder."):
            new_k = k.replace("encoder.", "")   # SAME FORMAT CellViT expects
            encoder_only[new_k] = v

    output_path = root / f"encoder_only_{checkpoint_name}"

    # Save ONLY the pure encoder state dict
    torch.save(encoder_only, output_path)

    print("====================================")
    print(f"üì¶ Input checkpoint : {input_path}")
    print(f"üîë Extracted keys   : {len(encoder_only)}")
    print(f"üíæ Saved encoder to : {output_path}")
    print("====================================")

    return output_path


In [7]:
root_dir = "/projectnb/ec500kb/projects/Fall_2025_Projects/Project_2/AI-guided-whole-slide-imaging-analysis/CellViT-plus-plus/checkpoints/SAM"
checkpoint_name= "CellViT-SAM-H-x40-AMP.pth"
extract_encoder_only(root_dir, checkpoint_name)


üì¶ Input checkpoint : /projectnb/ec500kb/projects/Fall_2025_Projects/Project_2/AI-guided-whole-slide-imaging-analysis/CellViT-plus-plus/checkpoints/SAM/CellViT-SAM-H-x40-AMP.pth
üîë Extracted keys   : 457
üíæ Saved encoder to : /projectnb/ec500kb/projects/Fall_2025_Projects/Project_2/AI-guided-whole-slide-imaging-analysis/CellViT-plus-plus/checkpoints/SAM/encoder_only_CellViT-SAM-H-x40-AMP.pth


PosixPath('/projectnb/ec500kb/projects/Fall_2025_Projects/Project_2/AI-guided-whole-slide-imaging-analysis/CellViT-plus-plus/checkpoints/SAM/encoder_only_CellViT-SAM-H-x40-AMP.pth')