In [6]:
# python standard library
import itertools
import json
import os
import threading
from abc import ABC, abstractmethod
from collections.abc import Sequence
from enum import Enum, auto

# yes i like static typing why do you ask?
from typing import Any

# not python standard library
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.optim as optim
from baseline_model import get_flowers_dataloaders, set_seed, train_epoch, validate
from connection_test import hugging_face_connectivity_test
from mmcv.ops import DeformConv2d
from safetensors.torch import load_file as safe_load
from tqdm.notebook import tqdm


Dataset found at: ../data\flowers-102
Loading Flowers102 dataset...


  from pkg_resources import packaging  # type: ignore[attr-defined]


In [8]:
def freeze_head(model, model_name):
	"""
	Freeze the backbone of ``model`` so that only its classifier head trains.

	Note: despite the name, the classifier head is the part that stays
	trainable (``requires_grad = True``); every other parameter is frozen.

	Args:
		model: Object exposing ``named_parameters()`` that yields
			``(name, parameter)`` pairs. It is modified in place.
		model_name: Architecture name. A name containing ``"vit"`` treats
			parameters whose name starts with ``"head"`` as the classifier;
			a name containing ``"resnet"`` treats parameters whose name
			contains ``"fc"`` as the classifier.

	Returns:
		The same ``model`` object that was passed in.

	Raises:
		NotImplementedError: If ``model_name`` contains neither ``"vit"``
			nor ``"resnet"``. No parameter is touched in that case.
	"""
	if "vit" in model_name:

		def is_head(param_name):
			# ViT classifier head
			return param_name.startswith("head")

	elif "resnet" in model_name:

		def is_head(param_name):
			# 'fc' is the final classifier for resnets
			return "fc" in param_name

	else:
		raise NotImplementedError(
			f"Freeze head received unknown model type {model_name}"
		)

	# Previously the resnet branch set requires_grad = False on the 'fc'
	# parameters as well, which left the model with nothing to train.
	for name, param in model.named_parameters():
		param.requires_grad = is_head(name)
	return model

def create_model_with_timeout(model_name, num_classes, device, timeout=10):
	"""
	Create a timm model from local weights, falling back to an online download.

	The local attempt runs in a worker thread: it builds the architecture with
	``pretrained=False``, loads ``../original_weights/<model_name>.safetensors``
	(or ``.pth`` if the former is absent), drops the classifier weights
	(``fc.*`` / ``head.*``) so the freshly initialised head is kept, and moves
	the model to ``device``. If that attempt fails, Hugging Face connectivity is
	checked and, when reachable, the model is re-created with
	``pretrained=True``. The resulting model is passed through ``freeze_head``
	before being returned.

	Args:
		model_name: timm architecture name; also the stem of the weight file.
		num_classes: Passed to ``timm.create_model``.
		device: Target passed to ``model.to(...)``.
		timeout: Seconds to wait for the local attempt. The online retry is
			not covered by this timeout.

	Returns:
		The model on ``device`` after ``freeze_head`` has been applied.

	Raises:
		TimeoutError: The local attempt is still running after ``timeout`` seconds.
		RuntimeError: The local attempt failed and either the connectivity
			test raised, Hugging Face was unreachable, or the online
			download failed.
	"""
	model_container = {}

	def target():
		try:
			model = timm.create_model(
				model_name, pretrained=False, num_classes=num_classes
			)

			# --- Locate local weights (safetensors is preferred over .pth) ---
			weights_dir = os.path.join("..", "original_weights")
			candidates = (
				os.path.join(weights_dir, f"{model_name}.safetensors"),
				os.path.join(weights_dir, f"{model_name}.pth"),
			)
			weight_path = next(
				(path for path in candidates if os.path.exists(path)), None
			)
			if weight_path is None:
				raise FileNotFoundError(
					f"No local weight file found for '{model_name}' in '{weights_dir}'"
				)

			# --- Load local weights safely ---
			try:
				if weight_path.endswith(".safetensors"):
					state_dict = safe_load(weight_path)
				else:
					# PyTorch .pth / .pt file
					# SECURITY NOTE(review): weights_only=False lets torch.load
					# unpickle arbitrary objects, i.e. run code embedded in the
					# file. Only use with checkpoints from a trusted source;
					# consider weights_only=True if the files are plain state dicts.
					state_dict = torch.load(
						weight_path, map_location="cpu", weights_only=False
					)
			except Exception as e:
				raise RuntimeError(
					f"Error loading local weight file '{weight_path}': {e}"
				) from e

			# --- Apply weights ---
			try:
				# Drop the checkpoint's classifier so its shape cannot clash
				# with the num_classes-sized head created above.
				checkpoint = state_dict.copy()
				head_keys = [
					key for key in checkpoint if key.startswith(("fc.", "head."))
				]
				for key in head_keys:
					del checkpoint[key]

				# strict=False: the removed head keys are expected to be missing
				missing, unexpected = model.load_state_dict(checkpoint, strict=False)
				print(f"Ignored missing keys: {missing}")
				print(f"Ignored unexpected keys: {unexpected}")
			except Exception as e:
				# Chain the cause so the original traceback is preserved
				raise RuntimeError(f"Weight mismatch for '{model_name}': {e}") from e

			# --- Move to device ---
			model_container["model"] = model.to(device)

		except Exception as e:
			# Hand the failure to the calling thread; it decides on the fallback
			model_container["error"] = str(e)

	# --- Run model creation in separate thread with timeout ---
	# daemon=True: a load that hangs past the timeout must not keep the
	# interpreter alive at shutdown.
	thread = threading.Thread(target=target, daemon=True)
	thread.start()
	thread.join(timeout)

	if thread.is_alive():
		raise TimeoutError(f"Creating model '{model_name}' timed out after {timeout}s.")

	# --- Handle failure: fallback path ---
	if "error" in model_container:
		print(f"Local load failed: {model_container['error']}")

		try:
			print("\n")
			hf_status = hugging_face_connectivity_test()
		except Exception as e:
			raise RuntimeError(f"Hugging Face connectivity test failed to run: {e}") from e

		# interpret test output if no connection
		if not hf_status:
			raise RuntimeError(
				"Local Weights not found and could not reach Hugging Face. Check your DNS, VPN, or network access."
			)

		# --- Retry with online download ---
		print("Retrying model creation via TIMM (pretrained=True)...")
		try:
			model = timm.create_model(
				model_name, pretrained=True, num_classes=num_classes
			)
			model_container["model"] = model.to(device)
			print(
				f"✅ Successfully downloaded and loaded '{model_name}' from Hugging Face."
			)
		except Exception as e:
			raise RuntimeError(
				f"Retried online download but still failed for '{model_name}': {e}"
			) from e

	model_container["model"] = freeze_head(model_container["model"], model_name)
	return model_container["model"]

In [4]:
# === Model Setup ===

# Full experiment roster: ResNets by depth, then ViT-Base by patch size.
# NOTE: resnet152 and vit_base_patch16_224 don't fit in my machine's gpu.
model_names_default = [
	*(f"resnet{depth}" for depth in (18, 34, 50, 101, 152)),
	*(f"vit_base_patch{patch}_224" for patch in (32, 16)),
]

# Short roster for quick debugging runs
model_names_debug = ["resnet18", "resnet34", "vit_base_patch32_224"]

# Use the GPU when one is visible to torch, otherwise stay on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [9]:
# ViT: build the patch-32 base model with a 102-class head and list every
# submodule so layer names can be looked up.
model = create_model_with_timeout(
    "vit_base_patch32_224",
    num_classes=102,
    device=device,
)

# named_modules() yields (qualified name, module) pairs
for name, module in model.named_modules():
    print(name, module)

Ignored missing keys: ['head.weight', 'head.bias']
Ignored unexpected keys: []
 VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
     

In [None]:
# Resnet