In [8]:
# python standard library
import contextlib
import io
import os
import threading
import warnings

import timm
import torch
from safetensors.torch import load_file as safe_load

warnings.filterwarnings("ignore", message=".*_MultiProcessingDataLoaderIter.*")

In [9]:
def freeze_head(model, model_name):
	"""Freeze the backbone so that only the classifier head stays trainable.

	Despite the name, the *head* is left trainable (``requires_grad=True``) and
	every other parameter is frozen (``requires_grad=False``).

	Args:
		model: A ``torch.nn.Module`` from the ViT or ResNet family.
		model_name: Model name; a substring check on it ("vit" / "resnet")
			selects which parameter-name prefix identifies the classifier head.

	Returns:
		The same ``model`` instance, modified in place.

	Raises:
		NotImplementedError: If ``model_name`` matches neither family.
	"""
	if "vit" in model_name:
		head_prefix = "head"  # ViT classifier head
	elif "resnet" in model_name:
		# 'fc' is the final classifier for resnets. Match the top-level "fc."
		# prefix only: a substring test would also hit e.g. SE-block fc1/fc2.
		head_prefix = "fc."
	else:
		raise NotImplementedError(
			f"Freeze head received unknown model type {model_name}"
		)

	for name, param in model.named_parameters():
		param.requires_grad = name.startswith(head_prefix)
	return model

def load_masked_state_dict(state_dict, model, verbose = False):
	"""Load a (possibly pruned) checkpoint into ``model``, skipping the classifier head.

	Checkpoints saved while ``torch.nn.utils.prune`` reparametrization was active
	store a pruned parameter ``p`` as ``p_orig`` plus a ``p_mask`` buffer. Each
	such pair (for ``weight`` and ``bias``) is folded back into a dense
	``p = p_orig * p_mask`` so the checkpoint loads into an un-pruned model.
	Top-level ``fc`` / ``head`` / ``classifier`` weight and bias entries are
	dropped so a model with a different number of classes can still receive
	the backbone weights.

	Args:
		state_dict: Mapping of parameter/buffer names to tensors (not mutated).
		model: ``torch.nn.Module`` to load into (modified in place).
		verbose: If True, print which keys were not restored.

	Returns:
		Tuple ``(model, missing, unexpected)`` where ``missing`` / ``unexpected``
		are the key lists reported by ``model.load_state_dict(..., strict=False)``.
	"""
	head_keys = {
		"fc.weight", "fc.bias",
		"head.weight", "head.bias",
		"classifier.weight", "classifier.bias",
	}

	# --- 1. Fold pruning masks into dense tensors ---
	cleaned_state = {}
	for key, value in state_dict.items():
		if key.endswith((".weight_orig", ".bias_orig")):
			dense_key = key[:-len("_orig")]  # e.g. "layer1.0.conv1.weight"
			mask = state_dict.get(dense_key + "_mask")
			# `is None` on purpose: tensor truthiness is ambiguous.
			cleaned_state[dense_key] = value if mask is None else value * mask
		elif key.endswith((".weight_mask", ".bias_mask")):
			continue  # masks are consumed together with their *_orig partner
		else:
			cleaned_state[key] = value

	# --- 2. Drop classification head weights (fc/head/classifier) ---
	for head_key in head_keys:
		cleaned_state.pop(head_key, None)

	# --- 3. Load with strict=False so missing/unexpected keys are tolerated ---
	# Note: shape mismatches still raise RuntimeError; stderr is merely kept quiet.
	with contextlib.redirect_stderr(io.StringIO()):
		missing, unexpected = model.load_state_dict(cleaned_state, strict=False)

	# --- 4. Optional logging ---
	if verbose:
		# Classify by exact key, so e.g. ViT "blocks.N.mlp.fc1.*" is not
		# mistaken for a dropped head weight.
		ignored = [k for k in missing if k in head_keys]
		other_missing = [k for k in missing if k not in head_keys]
		if ignored:
			print(f"Ignored classification head weights: {ignored}")
		if other_missing:
			print(f"Missing (not restored) weights: {other_missing}")
		if unexpected:
			print(f"Unexpected weights (not used): {unexpected}")
		if not other_missing and not unexpected:
			print("Model state restored successfully.")

	return model, missing, unexpected


def create_model_with_timeout(
	model_name: str,
	num_classes: int,
	device: torch.device,
	timeout: int = 10,
	verbose: bool = False,
	weights_dir: str = "../original_weights",
):
	"""Create a timm model, preferring local (optionally pruned) weights.

	Local path, run in a daemon worker thread bounded by ``timeout``:
	``timm.create_model(pretrained=False)`` builds the architecture. The first
	``<weights_dir>/<model_name>{.safetensors,.pth,.pt}`` that exists is then
	loaded through ``load_masked_state_dict``, which folds pruning masks and
	skips the classifier head. Finally the model is moved to ``device``.

	If the local path fails for any reason, this falls back to
	``timm.create_model(pretrained=True)`` in the calling thread.
	NOTE: the fallback is NOT bounded by ``timeout``.

	Args:
		model_name: timm model identifier, also the weight-file stem.
		num_classes: Number of output classes for the classifier head.
		device: Device the returned model is moved to.
		timeout: Seconds to wait for the local path before giving up.
		verbose: Forwarded to ``load_masked_state_dict`` for key logging.
		weights_dir: Directory searched for local weight files.

	Returns:
		The model, on ``device``.

	Raises:
		TimeoutError: If the local path does not finish within ``timeout``.
			The worker thread cannot be cancelled and keeps running.
		RuntimeError: If no model was produced.
	"""
	result = {}

	def target():
		"""Worker: build from local weights; store the model or the error in ``result``."""
		try:
			model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)

			# --- Find local weight file (first matching extension wins) ---
			weight_path = None
			for ext in (".safetensors", ".pth", ".pt"):
				path = os.path.join(weights_dir, f"{model_name}{ext}")
				if os.path.isfile(path):
					weight_path = path
					break
			if weight_path is None:
				raise FileNotFoundError(f"No local weights found for {model_name} in {weights_dir}")

			# --- Load weights (supporting masks) ---
			try:
				if weight_path.endswith(".safetensors"):
					state_dict = safe_load(weight_path)
				else:
					# NOTE(review): torch.load unpickles the file; only point
					# weights_dir at trusted checkpoints.
					state_dict = torch.load(weight_path, map_location="cpu")
			except Exception as e:
				raise RuntimeError(f"Failed to load {weight_path}: {e}") from e

			model, _missing, _unexpected = load_masked_state_dict(state_dict, model, verbose)
			result["model"] = model.to(device)

		except Exception as e:
			# Record the failure (with its type); the caller decides on fallback.
			result["error"] = f"{type(e).__name__}: {e}"

	# --- Run local model creation in a separate thread with timeout ---
	# daemon=True keeps a hung worker from blocking interpreter exit.
	thread = threading.Thread(target=target, daemon=True)
	thread.start()
	thread.join(timeout)

	if thread.is_alive():
		raise TimeoutError(f"Creating model '{model_name}' timed out after {timeout}s.")

	# --- Fallback path: pretrained weights via timm (runs without a timeout) ---
	if "error" in result:
		print(f"Local load failed: {result['error']}")
		print("Retrying with pretrained=True via TIMM ...")
		model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
		result["model"] = model.to(device)

	# --- Sanity check: some path must have produced a model ---
	if "model" not in result:
		raise RuntimeError(f"Model creation failed for '{model_name}'.")

	return result["model"]


In [10]:
# === Model Setup ===

# Full experiment sweep (the two commented entries were noted by the author
# as not fitting on their local GPU).
model_names_default = [
	"resnet18", "resnet34", "resnet50", "resnet101",
	"resnet152",  # doesn't fit in my machine's gpu
	"vit_base_patch32_224",
	"vit_base_patch16_224",  # doesn't fit in my machine's gpu
]

# Reduced subset for quick debugging runs.
model_names_debug = ["resnet18", "resnet34", "vit_base_patch32_224"]

# Use the GPU when torch can see one, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [13]:
# ViT: build vit_base_patch32_224 with a 102-class head and show its transformer blocks.
model = create_model_with_timeout("vit_base_patch32_224", num_classes=102, device=device)

print(model.blocks)

The history saving thread hit an unexpected error (OperationalError('database or disk is full')).History will not be written to the database.
Sequential(
  (0): Block(
    (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (attn): Attention(
      (qkv): Linear(in_features=768, out_features=2304, bias=True)
      (q_norm): Identity()
      (k_norm): Identity()
      (attn_drop): Dropout(p=0.0, inplace=False)
      (norm): Identity()
      (proj): Linear(in_features=768, out_features=768, bias=True)
      (proj_drop): Dropout(p=0.0, inplace=False)
    )
    (ls1): Identity()
    (drop_path1): Identity()
    (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (mlp): Mlp(
      (fc1): Linear(in_features=768, out_features=3072, bias=True)
      (act): GELU(approximate='none')
      (drop1): Dropout(p=0.0, inplace=False)
      (norm): Identity()
      (fc2): Linear(in_features=3072, out_features=768, bias=True)
      (drop2): Dropout(p=0.0, inplace=False)
    

In [12]:
# ResNet: build resnet18 with a 102-class head and list every named submodule
# (the root module comes first, under the empty name "").
model = create_model_with_timeout("resnet18", num_classes=102, device=device)

for name, module in model.named_modules():
    print(f"{name} {module}")

 ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, 