# **Layer 1: Baseline DeiT environment**

DeiT’s baseline training script expects a teacher model name and distillation settings via CLI flags in main.py (e.g., --teacher-model, --teacher-path, --distillation-type).
GitHub
+1

So the “base environment” Layer 1 must include:

DeiT repo (cloned)

PyTorch (Colab default) + GPU

timm installed (for both student and teacher models)

compatibility patches if any (because Colab uses new torch/timm)

Install PyTorch without pinning

In [1]:
!pip -q install --upgrade pip
!pip -q install torch torchvision torchaudio

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m69.4 MB/s[0m eta [36m0:00:00[0m
[?25h

Verify

In [2]:
import torch
print(torch.__version__)
print("CUDA:", torch.cuda.is_available())

2.9.0+cu126
CUDA: True


Clone the baseline repo (official DeiT)

In [3]:
%cd /content
!git clone https://github.com/facebookresearch/deit.git
%cd /content/deit
!grep -n "torch" -n requirements.txt || true

/content
Cloning into 'deit'...
remote: Enumerating objects: 456, done.[K
remote: Total 456 (delta 0), reused 0 (delta 0), pack-reused 456 (from 1)[K
Receiving objects: 100% (456/456), 5.73 MiB | 23.20 MiB/s, done.
Resolving deltas: 100% (255/255), done.
/content/deit
1:torch==1.13.1
2:torchvision==0.8.1


Colab Compatibility Fixes

1. torch pin removal

2. timm API changes

3. kwargs popping (pretrained_cfg, cache_dir, etc.)



Patch requirements.txt to remove torch pins

In [4]:
%cd /content/deit

!python - << 'PY'
from pathlib import Path
p = Path("requirements.txt")
lines = p.read_text().splitlines()

filtered = []
removed = []
for line in lines:
    s = line.strip()
    if s.startswith("torch==") or s.startswith("torchvision==") or s.startswith("torchaudio=="):
        removed.append(line)
        continue
    filtered.append(line)

p.write_text("\n".join(filtered) + "\n")
print("✅ Removed these pinned lines:")
for r in removed:
    print("  -", r)

/content/deit
✅ Removed these pinned lines:
  - torch==1.13.1
  - torchvision==0.8.1


Verify Pins are gone!i.e torch==1.13.1 pin was removed

In [5]:
!grep -nE "torch|torchvision|torchaudio" requirements.txt || echo "✅ No torch pins remain"

✅ No torch pins remain


Install the baseline dependencies

In [6]:
pip install "jedi>=0.16,<0.19"

Collecting jedi<0.19,>=0.16
  Downloading jedi-0.18.2-py2.py3-none-any.whl.metadata (22 kB)
Downloading jedi-0.18.2-py2.py3-none-any.whl (1.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m59.3 MB/s[0m  [33m0:00:00[0m
[?25hInstalling collected packages: jedi
Successfully installed jedi-0.18.2


In [7]:
!pip -q uninstall -y timm
!pip -q install "jedi>=0.16,<0.19"
# !pip -q install timm==0.6.13 submitit
!pip -q install timm==0.4.12 submitit


Verify

In [8]:
!python -c "import timm; print('timm:', timm.__version__)"
#0.4.12

timm: 0.4.12


**Restart the Session**

In [1]:
!python - << 'PY'
from pathlib import Path

p = Path("/usr/local/lib/python3.12/dist-packages/timm/data/__init__.py")
txt = p.read_text()

needle = "OPENAI_CLIP_MEAN"
if needle in txt:
    print("✅ timm.data already mentions OPENAI_CLIP_MEAN; no patch needed.")
else:
    patch = """

# --- Colab patch: expose CLIP normalization constants for older exports ---
try:
    from .constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD  # timm versions where defined in constants
except Exception:
    # Standard OpenAI CLIP normalization
    OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    OPENAI_CLIP_STD  = (0.26862954, 0.26130258, 0.27577711)
# --- end patch ---
"""
    p.write_text(txt + patch)
    print("✅ Patched:", p)

✅ Patched: /usr/local/lib/python3.12/dist-packages/timm/data/__init__.py


In [2]:
%cd /content/deit
from models import deit_tiny_patch16_224
m = deit_tiny_patch16_224()
print("✅ DeiT model instantiated successfully")

/content/deit
✅ DeiT model instantiated successfully


In [3]:
import torch, timm
print(torch.__version__)
print(timm.__version__)
print(torch.cuda.is_available())

2.9.0+cu126
0.4.12
True


Download Tiny-ImageNet

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
%cd /content
!wget -q http://cs231n.stanford.edu/tiny-imagenet-200.zip
!unzip -q tiny-imagenet-200.zip

/content


Fix Tiny-ImageNet validation folder

In [6]:
!python - << 'EOF'
import shutil
from pathlib import Path

root = Path("/content/tiny-imagenet-200")
val_dir = root/"val"
img_dir = val_dir/"images"
ann = val_dir/"val_annotations.txt"

with ann.open("r") as f:
    for line in f:
        img, cls = line.strip().split("\t")[:2]
        (val_dir/cls).mkdir(parents=True, exist_ok=True)
        src = img_dir/img
        dst = val_dir/cls/img
        if src.exists():
            shutil.move(str(src), str(dst))

if img_dir.exists():
    shutil.rmtree(img_dir)

print("✅ Tiny-ImageNet val reorganized into class subfolders.")

✅ Tiny-ImageNet val reorganized into class subfolders.


In [7]:
!find /content/tiny-imagenet-200/val -maxdepth 1 -type d | head

/content/tiny-imagenet-200/val
/content/tiny-imagenet-200/val/n02950826
/content/tiny-imagenet-200/val/n02085620
/content/tiny-imagenet-200/val/n01641577
/content/tiny-imagenet-200/val/n04254777
/content/tiny-imagenet-200/val/n02917067
/content/tiny-imagenet-200/val/n03404251
/content/tiny-imagenet-200/val/n03085013
/content/tiny-imagenet-200/val/n02504458
/content/tiny-imagenet-200/val/n03424325


In [8]:
ls -lah /content/tiny-imagenet-200 | head

total 2.6M
drwxrwxr-x   5 root root 4.0K Feb  9  2015 [0m[01;34m.[0m/
drwxr-xr-x   1 root root 4.0K Jan 30 19:31 [01;34m..[0m/
drwxrwxr-x   3 root root 4.0K Dec 12  2014 [01;34mtest[0m/
drwxrwxr-x 202 root root 4.0K Dec 12  2014 [01;34mtrain[0m/
drwxrwxr-x 202 root root 4.0K Jan 30 19:32 [01;34mval[0m/
-rw-rw-r--   1 root root 2.0K Feb  9  2015 wnids.txt
-rw-------   1 root root 2.6M Feb  9  2015 words.txt


Handle timm incompatibilities. Although we can instantiate the model directly, the training script uses timm.create_model(), which injects metadata arguments such as pretrained_cfg and cache_dir.
The original DeiT constructors do not support these arguments, so we remove them
YOUR NOTEBOOK CALL
    |
    v
deit_tiny_patch16_224()          ✅ works (no kwargs)

TRAINING PIPELINE
    |
    v
timm.create_model()
    |
    v
deit_tiny_patch16_224(**kwargs)  ❌ injects extra keys


Patch /content/deit/augment.py (safe compatibility fix)

In [9]:
%cd /content/deit
!python - << 'PY'
from pathlib import Path
p = Path("augment.py")
txt = p.read_text()

old = "from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor"
if old in txt:
    txt = txt.replace(
        old,
        "from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensor\n"
        "try:\n"
        "    from timm.data.transforms import _pil_interp  # older timm\n"
        "except Exception:\n"
        "    _pil_interp = None  # newer timm doesn't expose this\n"
    )
    p.write_text(txt)
    print("✅ Patched augment.py for timm compatibility.")
else:
    print("ℹ️ Expected import line not found; augment.py may already be patched or different.")

/content/deit
✅ Patched augment.py for timm compatibility.


In [10]:
%cd /content/deit
!sed -n '1,200p' models.py

/content/deit
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
import torch.nn as nn
from functools import partial

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_


__all__ = [
    'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
    'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
    'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
    'deit_base_distilled_patch16_384',
]


class DistilledVisionTransformer(VisionTransformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes)

In [11]:
%cd /content/deit
!rm -f multiteacher_loss.py
!ls -l multiteacher_loss.py || echo "✅ old file removed"

/content/deit
ls: cannot access 'multiteacher_loss.py': No such file or directory
✅ old file removed


In [12]:
%cd /content/deit

from pathlib import Path

code = r'''
from __future__ import annotations
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm


def normalize_lambdas(lmb: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    if lmb.dim() == 1:
        return lmb / lmb.sum().clamp_min(eps)
    return lmb / lmb.sum(dim=-1, keepdim=True).clamp_min(eps)


def fuse_logits(
    teacher_logits: Dict[str, torch.Tensor],
    teacher_order: List[str],
    lambdas: torch.Tensor
) -> torch.Tensor:
    logits_list = [teacher_logits[k] for k in teacher_order]
    stacked = torch.stack(logits_list, dim=1)  # (B, T, C)

    lambdas = normalize_lambdas(lambdas).to(stacked.device)
    if lambdas.dim() == 1:
        lambdas = lambdas.unsqueeze(0).expand(stacked.size(0), -1)  # (B, T)

    return (stacked * lambdas.unsqueeze(-1)).sum(dim=1)  # (B, C)


def kd_soft(student_logits: torch.Tensor, teacher_logits: torch.Tensor, T: float) -> torch.Tensor:
    p_t = F.softmax(teacher_logits / T, dim=-1)
    log_p_s = F.log_softmax(student_logits / T, dim=-1)
    return F.kl_div(log_p_s, p_t, reduction="batchmean") * (T * T)


def kd_hard(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    return F.cross_entropy(student_logits, teacher_logits.argmax(dim=-1))


class FrozenTeacherEnsemble(nn.Module):
    def __init__(self, teacher_names: List[str], device: torch.device):
        super().__init__()
        self.models = nn.ModuleDict({
            name: timm.create_model(name, pretrained=True, num_classes=1000).eval().to(device)
            for name in teacher_names
        })
        for m in self.models.values():
            for p in m.parameters():
                p.requires_grad_(False)
        self.teacher_order = list(self.models.keys())

    @torch.no_grad()
    def forward(self, x):
        return {k: m(x) for k, m in self.models.items()}


class TeacherLogitAdapter(nn.Module):
    def __init__(self, teacher_keys: List[str], student_num_classes: int):
        super().__init__()
        self.adapters = nn.ModuleDict({
            k: nn.Linear(1000, student_num_classes, bias=False) for k in teacher_keys
        })

    def forward(self, teacher_logits: Dict[str, torch.Tensor]):
        return {k: self.adapters[k](v) for k, v in teacher_logits.items()}


class HDTSEConfidence(nn.Module):
    def __init__(self, temp: float = 1.0):
        super().__init__()
        self.temp = temp

    @torch.no_grad()
    def forward(self, student_logits, teacher_logits, teacher_order, targets):
        stacked = torch.stack([teacher_logits[k] for k in teacher_order], dim=1)  # (B,T,C)
        probs = F.softmax(stacked / self.temp, dim=-1)  # (B,T,C)

        # Hard labels: (B,)
        if targets.dim() == 1:
            idx = targets.to(dtype=torch.long, device=probs.device)
            conf = probs.gather(-1, idx[:, None, None]).squeeze(-1)  # (B,T)
            return normalize_lambdas(conf)

        # Soft labels (mixup/cutmix): (B,C)
        tgt = targets.to(dtype=probs.dtype, device=probs.device)
        conf = (probs * tgt[:, None, :]).sum(dim=-1)  # (B,T)
        return normalize_lambdas(conf)


class MultiTeacherDistillationLoss(nn.Module):
    def __init__(
        self,
        base_criterion,
        student_num_classes: int,
        teacher_names: List[str],
        distillation_type: str = "soft",
        alpha: float = 0.5,
        tau: float = 2.0,
        device=None,
        use_adapter: bool = True,
        hdtse_warmup_epochs: int = 0,
        lambda_log: bool = True,
    ):
        super().__init__()
        self.base_criterion = base_criterion
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.teachers = FrozenTeacherEnsemble(teacher_names, self.device)
        self.teacher_order = list(self.teachers.teacher_order)

        self.adapter = TeacherLogitAdapter(self.teachers.teacher_order, student_num_classes).to(self.device) if use_adapter else None
        self.hdtse = HDTSEConfidence()

        # ---- New controls ----
        self.epoch: int = 0
        self.hdtse_warmup_epochs = int(hdtse_warmup_epochs)
        self.lambda_log = bool(lambda_log)

        # ---- Logging state (epoch-level) ----
        self._lambda_sum = torch.zeros(len(self.teacher_order), dtype=torch.float32)
        self._lambda_count = 0
        self.last_lambdas: Optional[torch.Tensor] = None  # (B,T) from last forward

    def set_epoch(self, epoch: int):
        self.epoch = int(epoch)

    def _uniform_lambdas(self, batch_size: int, device: torch.device) -> torch.Tensor:
        t = len(self.teacher_order)
        return torch.full((batch_size, t), 1.0 / t, device=device, dtype=torch.float32)

    def pop_lambda_stats(self) -> Optional[Dict[str, float]]:
        """
        Returns mean λ per teacher over the epoch, then resets accumulators.
        Call this once per epoch from main.py.
        """
        if self._lambda_count <= 0:
            return None

        mean_lmb = (self._lambda_sum / float(self._lambda_count)).tolist()
        out = {f"lambda_{name}": float(v) for name, v in zip(self.teacher_order, mean_lmb)}

        # reset
        self._lambda_sum.zero_()
        self._lambda_count = 0
        return out

    def forward(self, inputs, outputs, targets):
        base_loss = self.base_criterion(outputs, targets)

        with torch.no_grad():
            t_logits = self.teachers(inputs)
        if self.adapter is not None:
            t_logits = self.adapter(t_logits)

        # ---- HDTSE delay ----
        if self.epoch < self.hdtse_warmup_epochs:
            lambdas = self._uniform_lambdas(outputs.size(0), outputs.device)  # (B,T)
        else:
            lambdas = self.hdtse(outputs, t_logits, list(t_logits.keys()), targets)  # (B,T)

        self.last_lambdas = lambdas.detach()

        # ---- λ logging ----
        if self.lambda_log:
            # accumulate batch mean λ, weighted by batch size
            batch_mean = lambdas.detach().mean(dim=0).cpu()  # (T,)
            self._lambda_sum += batch_mean * outputs.size(0)
            self._lambda_count += outputs.size(0)

        fused = fuse_logits(t_logits, self.teacher_order, lambdas)

        kd = kd_soft(outputs, fused, self.tau) if self.distillation_type == "soft" else kd_hard(outputs, fused)
        return (1 - self.alpha) * base_loss + self.alpha * kd
'''

path = Path("multiteacher_loss.py")
path.write_text(code)

print("File written:", path)
print("File size (bytes):", path.stat().st_size)

/content/deit
File written: multiteacher_loss.py
File size (bytes): 6517


In [14]:
from pathlib import Path
import re
import py_compile

MAIN = Path("/content/deit/main.py")
assert MAIN.exists(), f"Not found: {MAIN}"

txt = MAIN.read_text()

# ------------------------------------------------------------
# 1) Add import for MultiTeacherDistillationLoss
# ------------------------------------------------------------
if "from multiteacher_loss import MultiTeacherDistillationLoss" not in txt:
    if "from losses import DistillationLoss" in txt:
        txt = txt.replace(
            "from losses import DistillationLoss",
            "from losses import DistillationLoss\nfrom multiteacher_loss import MultiTeacherDistillationLoss",
            1
        )
    else:
        raise RuntimeError("Could not find 'from losses import DistillationLoss' to insert MultiTeacher import.")

# -----------------------------
# 2) Add CLI args after --teacher-path
# -----------------------------
if "--teacher-models" not in txt:
    anchor = "    parser.add_argument('--teacher-path', type=str, default='')"
    if anchor not in txt:
        raise RuntimeError("Couldn't find --teacher-path argument to insert after.")
    insert = (
        "    parser.add_argument('--teacher-models', type=str, default='',\n"
        "                        help='Comma-separated timm model names for multi-teacher distillation')\n"
        "    parser.add_argument('--hdtse-warmup-epochs', default=3, type=int,\n"
        "                        help='Use uniform lambdas for first N epochs, then enable HDTSE weighting')\n"
        "    parser.add_argument('--lambda-log', action='store_true',\n"
        "                        help='Log mean lambda per teacher each epoch (only for multi-teacher)')\n"
    )
    txt = txt.replace(anchor, anchor + "\n" + insert, 1)
    print("✅ Added --teacher-models, --hdtse-warmup-epochs, --lambda-log")

# ------------------------------------------------------------
# 3) Allow finetune + multi-teacher distillation (block only single-teacher)
# ------------------------------------------------------------
txt = re.sub(
    r"if args\.distillation_type != 'none' and args\.finetune and not args\.eval:\s*\n\s*raise NotImplementedError\(\"Finetuning with distillation not yet supported\"\)",
    "if args.distillation_type != 'none' and args.finetune and not args.eval and not args.teacher_models:\n"
    "        raise NotImplementedError(\"Finetuning with distillation not yet supported (single-teacher path)\")",
    txt
)

# ------------------------------------------------------------
# 4) Ensure SoftTargetCrossEntropy is instantiated (it is already () in your base, keep safe)
# ------------------------------------------------------------
txt = txt.replace("criterion = SoftTargetCrossEntropy\n", "criterion = SoftTargetCrossEntropy()\n")

# ------------------------------------------------------------
# 5) Remove early scheduler creation (base does it before distillation + adapter)
# ------------------------------------------------------------
txt = re.sub(
    r"^\s*lr_scheduler\s*,\s*_\s*=\s*create_scheduler\(\s*args\s*,\s*optimizer\s*\)\s*$",
    "",
    txt,
    flags=re.MULTILINE
)

# ============================================================
# 6) Replace distillation section with unified multi + single teacher logic
#    Anchors: "teacher_model = None" up to just before "output_dir = Path(args.output_dir)"
# ============================================================
start_key = "    teacher_model = None"
out_key = "    output_dir = Path(args.output_dir)"

s = txt.find(start_key)
o = txt.find(out_key)

if s == -1 or o == -1 or o <= s:
    raise RuntimeError("❌ Could not locate distillation anchors ('teacher_model = None' / 'output_dir = Path(...)').")


# The base file has the DistillationLoss call ending before output_dir.
# We'll replace everything from teacher_model=None up to just before output_dir.
replacement_block = """    teacher_model = None

    if args.distillation_type != 'none':
        # Allow either teacher-path (single teacher) OR teacher-models (multi teacher)
        assert (args.teacher_path or args.teacher_models), 'need to specify teacher-path OR teacher-models when using distillation'

        # -----------------------
        # Multi-teacher distillation
        # -----------------------
        if args.teacher_models:
            teacher_names = [t.strip() for t in args.teacher_models.split(',') if t.strip()]
            print("✅ Multi-teacher distillation enabled. Teachers:", teacher_names)

            criterion = MultiTeacherDistillationLoss(
                base_criterion=criterion,
                student_num_classes=args.nb_classes,
                teacher_names=teacher_names,
                distillation_type=args.distillation_type,
                alpha=args.distillation_alpha,
                tau=args.distillation_tau,
                device=device,
                use_adapter=True,
                hdtse_warmup_epochs=args.hdtse_warmup_epochs,
                lambda_log=args.lambda_log,
            )

            # IMPORTANT: adapter must be trained
            if hasattr(criterion, "adapter") and criterion.adapter is not None:
                optimizer.add_param_group({
                    "params": criterion.adapter.parameters(),
                    "lr": args.lr,
                    "weight_decay": 0.0
                })
                print("✅ Added adapter parameters to optimizer")

        # -----------------------
        # Single-teacher distillation (original DeiT)
        # -----------------------
        else:
            print(f"Creating teacher model: {args.teacher_model}")
            teacher_model = create_model(
                args.teacher_model,
                pretrained=False,
                num_classes=args.nb_classes,
                global_pool='avg',
            )
            if args.teacher_path.startswith('https'):
                checkpoint = torch.hub.load_state_dict_from_url(
                    args.teacher_path, map_location='cpu', check_hash=True)
            else:
                checkpoint = torch.load(args.teacher_path, map_location='cpu')
            teacher_model.load_state_dict(checkpoint['model'])
            teacher_model.to(device)
            teacher_model.eval()

            criterion = DistillationLoss(
                criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
            )

    else:
        # No distillation
        pass

    # Create scheduler AFTER optimizer has all param groups (incl adapter)
    lr_scheduler, _ = create_scheduler(args, optimizer)

"""

txt = txt[:s] + replacement_block + "\n" + txt[o:]  # keep the 'output_dir...' line onward

# ============================================================
# 6) Add epoch setter + lambda logging in training loop
#    We insert two snippets:
#    (a) before train_one_epoch call
#    (b) after train_one_epoch returns
# ============================================================

# (a) before train_one_epoch(...) call
if "criterion.set_epoch(epoch)" not in txt:
    # find the call site "train_stats = train_one_epoch(" and insert just above it
    marker = "        train_stats = train_one_epoch("
    idx = txt.find(marker)
    if idx == -1:
        print("⚠️ Could not find train_one_epoch call to insert set_epoch(). Skipping.")
    else:
        insert = (
            "        # ---- Multi-teacher: inform loss about current epoch (for HDTSE warmup) ----\n"
            "        if hasattr(criterion, \"set_epoch\"):\n"
            "            criterion.set_epoch(epoch)\n\n"
        )
        txt = txt[:idx] + insert + txt[idx:]
        print("✅ Inserted criterion.set_epoch(epoch) before train_one_epoch")

# (b) after train_one_epoch(...) returns, log lambdas
if "pop_lambda_stats" not in txt:
    # Insert after the line: train_stats = train_one_epoch(...)
    # We'll look for the first occurrence of "train_stats = train_one_epoch(" then find the next blank line
    m = re.search(r"train_stats\s*=\s*train_one_epoch\([\s\S]*?\)\n", txt)
    if not m:
        print("⚠️ Could not locate end of train_one_epoch(...) call to insert lambda logging. Skipping.")
    else:
        insert_after = (
            "\n        # ---- Multi-teacher: lambda (λ) logging for HDTSE transparency ----\n"
            "        if getattr(args, \"lambda_log\", False) and hasattr(criterion, \"pop_lambda_stats\"):\n"
            "            lmb_stats = criterion.pop_lambda_stats()\n"
            "            if lmb_stats:\n"
            "                print(\"λ(mean over epoch):\", lmb_stats)\n"
            "                train_stats.update(lmb_stats)\n"
        )
        txt = txt[:m.end()] + insert_after + txt[m.end():]
        print("✅ Inserted per-epoch λ logging after train_one_epoch")

# ---------------------------------------------------------
# (F) In training loop: set_epoch + lambda stats logging
# ---------------------------------------------------------
loop_anchor = re.search(r"for epoch in range\(args\.start_epoch, args\.epochs\):\n", txt)
if not loop_anchor:
    raise RuntimeError("Could not find training loop to add set_epoch hook.")

set_epoch_hook = (
    "        # Multi-teacher: drive HDTSE warmup + lambda logging\n"
    "        if hasattr(criterion, 'set_epoch'):\n"
    "            criterion.set_epoch(epoch)\n"
)
txt = txt[:loop_anchor.end()] + set_epoch_hook + txt[loop_anchor.end():]

# After train_one_epoch, print λ stats if available
after_train_anchor = re.search(r"train_stats\s*=\s*train_one_epoch\([\s\S]*?\)\n", txt)
if not after_train_anchor:
    raise RuntimeError("Could not find train_one_epoch(...) call block to attach lambda stats reporting.")

lambda_report = (
    "        if hasattr(criterion, 'pop_lambda_stats'):\n"
    "            lmb = criterion.pop_lambda_stats()\n"
    "            if lmb is not None:\n"
    "                print('λ means:', lmb)\n"
)
txt = txt[:after_train_anchor.end()] + lambda_report + txt[after_train_anchor.end():]

path.write_text(txt)
print("✅ Patched main.py written to:", path)

py_compile.compile(str(path), doraise=True)
print("✅ main.py compiles successfully")


✅ Added --teacher-models, --hdtse-warmup-epochs, --lambda-log
✅ Inserted criterion.set_epoch(epoch) before train_one_epoch
✅ Inserted per-epoch λ logging after train_one_epoch
✅ Patched main.py written to: multiteacher_loss.py
✅ main.py compiles successfully


Before constructing the model, remove those keys from kwargs

In [15]:
from pathlib import Path

p = Path("/content/deit/models.py")
lines = p.read_text().splitlines()

out = []
for line in lines:
    out.append(line)
    if line.strip().startswith("def deit_") and "**kwargs" in line:
        out.append("    # Drop timm-injected kwargs not supported by DeiT")
        out.append("    kwargs.pop('pretrained_cfg', None)")
        out.append("    kwargs.pop('pretrained_cfg_overlay', None)")
        out.append("    kwargs.pop('pretrained_cfg_priority', None)")

p.write_text("\n".join(out) + "\n")
print("✅ models.py patched to drop pretrained_cfg kwargs")


✅ models.py patched to drop pretrained_cfg kwargs


Verify

In [16]:
!grep -n "pretrained_cfg" /content/deit/models.py

65:    kwargs.pop('pretrained_cfg', None)
66:    kwargs.pop('pretrained_cfg_overlay', None)
67:    kwargs.pop('pretrained_cfg_priority', None)
84:    kwargs.pop('pretrained_cfg', None)
85:    kwargs.pop('pretrained_cfg_overlay', None)
86:    kwargs.pop('pretrained_cfg_priority', None)
103:    kwargs.pop('pretrained_cfg', None)
104:    kwargs.pop('pretrained_cfg_overlay', None)
105:    kwargs.pop('pretrained_cfg_priority', None)
122:    kwargs.pop('pretrained_cfg', None)
123:    kwargs.pop('pretrained_cfg_overlay', None)
124:    kwargs.pop('pretrained_cfg_priority', None)
141:    kwargs.pop('pretrained_cfg', None)
142:    kwargs.pop('pretrained_cfg_overlay', None)
143:    kwargs.pop('pretrained_cfg_priority', None)
160:    kwargs.pop('pretrained_cfg', None)
161:    kwargs.pop('pretrained_cfg_overlay', None)
162:    kwargs.pop('pretrained_cfg_priority', None)
179:    kwargs.pop('pretrained_cfg', None)
180:    kwargs.pop('pretrained_cfg_overlay', None)
181:    kwargs.pop('pretrained_cfg_p

Fix: Patch /content/deit/models.py to drop pretrained_cfg=...

Patch models.py to also drop cache_dir (and friends)

In [17]:
from pathlib import Path

p = Path("/content/deit/models.py")
lines = p.read_text().splitlines()

# Keys that timm may inject but DeiT constructors don't accept
DROP_KEYS = [
    "cache_dir",
    "hf_hub_id",
    "hf_hub_filename",
    "hf_hub_revision",
]

out = []
for line in lines:
    out.append(line)
    # Right after the comment line we previously inserted, add more pops once per function
    if line.strip() == "# Drop timm-injected kwargs not supported by DeiT":
        for k in DROP_KEYS:
            out.append(f"    kwargs.pop('{k}', None)")

p.write_text("\n".join(out) + "\n")
print("✅ Patched models.py to drop cache_dir/hf_hub* kwargs")


✅ Patched models.py to drop cache_dir/hf_hub* kwargs


Verify

In [18]:
!grep -n "cache_dir" /content/deit/models.py

65:    kwargs.pop('cache_dir', None)
88:    kwargs.pop('cache_dir', None)
111:    kwargs.pop('cache_dir', None)
134:    kwargs.pop('cache_dir', None)
157:    kwargs.pop('cache_dir', None)
180:    kwargs.pop('cache_dir', None)
203:    kwargs.pop('cache_dir', None)
226:    kwargs.pop('cache_dir', None)


In [19]:
# %cd /content/deit
# !python main.py \
#   --model deit_tiny_patch16_224 \
#   --data-path /content/tiny-imagenet-200 \
#   --pretrained \
#   --epochs 1 \
#   --batch-size 64 \
#   --num_workers 2 \
#   --output_dir /content/deit_runs/smoke_test
# %cd /content/deit
# !python main.py \
#   --model deit_tiny_patch16_224 \
#   --data-path /content/tiny-imagenet-200 \
#   --epochs 1 \
#   --batch-size 128 \
#   --num_workers 4 \
#   --input-size 224 \
#   --opt adamw \
#   --lr 5e-4 \
#   --weight-decay 0.05 \
#   --sched cosine \
#   --aa rand-m9-mstd0.5 \
#   --reprob 0.25 \
#   --remode pixel \
#   --recount 1 \
#   --output_dir /content/deit_runs/tiny_imagenet
### correct one
# %cd /content/deit
# !python main.py \
#  --model deit_tiny_patch16_224 \
#  --data-path /content/tiny-imagenet-200 \
#  --finetune https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth \
#  --epochs 10 \
#  --batch-size 128 \
#  --num_workers 4 \
#  --input-size 224 \
#  --opt adamw \
#  --lr 3e-4 \
#  --weight-decay 0.05 \
#  --sched cosine \
#  --warmup-epochs 1 \
#  --smoothing 0.1 \
#  --aa rand-m7-mstd0.5 \
#  --reprob 0.1 \
#  --drop-path 0.1 \
#  --output_dir /content/deit_runs/tiny_imagenet_5ep
%cd /content/deit
!python main.py \
 --model deit_tiny_patch16_224 \
 --data-path /content/tiny-imagenet-200 \
 --finetune https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth \
 --epochs 10 \
 --batch-size 128 \
 --num_workers 4 \
 --input-size 224 \
 --opt adamw \
 --lr 2.5e-4 \
 --weight-decay 0.05 \
 --sched cosine \
 --warmup-epochs 0 \
 --smoothing 0.1 \
 --aa rand-m6-mstd0.5 \
 --reprob 0.1 \
 --drop-path 0.05 \
 --mixup 0.2 \
 --cutmix 0.0 \
 --mixup-prob 0.5 \
 --distillation-type soft \
 --distillation-alpha 0.2 \
 --distillation-tau 2.0 \
 --hdtse-warmup-epochs 3 \
 --lambda-log \
 --output_dir /content/deit_runs/tiny_imagenet \
 --teacher-models "tf_efficientnet_b2,mobilenetv3_large_100,regnety_040"
# %cd /content/deit
# !python main.py \
#  --model deit_tiny_patch16_224 \
#  --data-path /content/tiny-imagenet-200 \
#  --finetune https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth \
#  --epochs 10 \
#  --batch-size 128 \
#  --num_workers 4 \
#  --input-size 224 \
#  --opt adamw \
#  --lr 2.5e-4 \
#  --weight-decay 0.05 \
#  --sched cosine \
#  --warmup-epochs 1 \
#  --smoothing 0.1 \
#  --aa rand-m7-mstd0.5 \
#  --reprob 0.1 \
#  --drop-path 0.1 \
#  --distillation-type hard \
# --teacher-model regnety_160 \
# --teacher-path https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth \
#  --output_dir /content/deit_runs/tiny_imagenet_10ep
# %cd /content/deit
# !python main.py \
#  --model deit_tiny_distilled_patch16_224 \
#  --data-path /content/tiny-imagenet-200 \
#  --epochs 10 \
#  --batch-size 128 \
#  --num_workers 4 \
#  --input-size 224 \
#  --opt adamw \
#  --lr 7e-4 \
#  --weight-decay 0.05 \
#  --sched cosine \
#  --warmup-epochs 1 \
#  --smoothing 0.0 \
#  --aa rand-m7-mstd0.5 \
#  --reprob 0.1 \
#  --drop-path 0.0 \
#  --distillation-type hard \
#  --distillation-alpha 0.7 \
#  --teacher-model regnety_160 \
#  --teacher-path https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth \
#  --output_dir /content/deit_runs/deit_tiny_distilled_10ep





/content/deit
usage: DeiT training and evaluation script [-h] [--batch-size BATCH_SIZE]
                                           [--epochs EPOCHS] [--bce-loss]
                                           [--unscale-lr] [--model MODEL]
                                           [--input-size INPUT_SIZE]
                                           [--drop PCT] [--drop-path PCT]
                                           [--model-ema] [--no-model-ema]
                                           [--model-ema-decay MODEL_EMA_DECAY]
                                           [--model-ema-force-cpu]
                                           [--opt OPTIMIZER]
                                           [--opt-eps EPSILON]
                                           [--opt-betas BETA [BETA ...]]
                                           [--clip-grad NORM] [--momentum M]
                                           [--weight-decay WEIGHT_DECAY]
                                           [--sched SC

# **Layer 2: Base Environment — Teacher Models & Multi-Teacher Adaptations**

Layer 2 extends the baseline DeiT environment to support knowledge distillation from one or more teacher models. This layer is additive: it does not modify the baseline DeiT training loop unless explicitly stated.
It includes
1. Teacher Model Support (Single & Multiple)
2. Teacher Registry / Configuration
3. Multi-Teacher Fusion Mechanism (Adaptation Layer)
4. Distillation Loss Integration