# **Layer 1: Baseline DeiT environment**

DeiT’s baseline training script expects a teacher model name and distillation settings via CLI flags in main.py (e.g., --teacher-model, --teacher-path, --distillation-type).
GitHub
+1

So the “base environment” Layer 1 must include:

DeiT repo (cloned)

PyTorch (Colab default) + GPU

timm installed (for both student and teacher models)

compatibility patches if any (because Colab uses new torch/timm)

Install PyTorch without pinning

In [1]:
!pip -q install --upgrade pip
!pip -q install torch torchvision torchaudio

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m77.0 MB/s[0m eta [36m0:00:00[0m
[?25h

Verify

In [2]:
import torch
print(torch.__version__)
print("CUDA:", torch.cuda.is_available())

2.9.0+cu128
CUDA: True


Clone the baseline repo (official DeiT)

In [3]:
%cd /content

# Check if 'deit' folder exists → delete it
!if [ -d "deit" ]; then rm -rf deit; fi

!git clone https://github.com/facebookresearch/deit.git
%cd /content/deit
!grep -n "torch" -n requirements.txt || true

/content
Cloning into 'deit'...
remote: Enumerating objects: 456, done.[K
remote: Total 456 (delta 0), reused 0 (delta 0), pack-reused 456 (from 1)[K
Receiving objects: 100% (456/456), 5.73 MiB | 6.39 MiB/s, done.
Resolving deltas: 100% (255/255), done.
/content/deit
1:torch==1.13.1
2:torchvision==0.8.1


Colab Compatibility Fixes

1. torch pin removal

2. timm API changes

3. kwargs popping (pretrained_cfg, cache_dir, etc.)



Patch requirements.txt to remove torch pins

In [4]:
%cd /content/deit

!python - << 'PY'
from pathlib import Path
p = Path("requirements.txt")
lines = p.read_text().splitlines()

filtered = []
removed = []
for line in lines:
    s = line.strip()
    if s.startswith("torch==") or s.startswith("torchvision==") or s.startswith("torchaudio=="):
        removed.append(line)
        continue
    filtered.append(line)

p.write_text("\n".join(filtered) + "\n")
print("✅ Removed these pinned lines:")
for r in removed:
    print("  -", r)

/content/deit
✅ Removed these pinned lines:
  - torch==1.13.1
  - torchvision==0.8.1


Verify Pins are gone!i.e torch==1.13.1 pin was removed

In [5]:
!grep -nE "torch|torchvision|torchaudio" requirements.txt || echo "✅ No torch pins remain"

✅ No torch pins remain


Install the baseline dependencies

In [6]:
pip install "jedi>=0.16,<0.19"

Collecting jedi<0.19,>=0.16
  Downloading jedi-0.18.2-py2.py3-none-any.whl.metadata (22 kB)
Downloading jedi-0.18.2-py2.py3-none-any.whl (1.6 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m52.0 MB/s[0m  [33m0:00:00[0m
[?25hInstalling collected packages: jedi
Successfully installed jedi-0.18.2


In [7]:
!pip -q uninstall -y timm
!pip -q install "jedi>=0.16,<0.19"
!pip -q install timm==0.6.13 submitit
#!pip -q install timm==0.4.12 submitit


Verify

In [8]:
!python -c "import timm; print('timm:', timm.__version__)"

timm: 0.6.13


**Restart the Session**

In [9]:
!python - << 'PY'
from pathlib import Path

p = Path("/usr/local/lib/python3.12/dist-packages/timm/data/__init__.py")
txt = p.read_text()

needle = "OPENAI_CLIP_MEAN"
if needle in txt:
    print("✅ timm.data already mentions OPENAI_CLIP_MEAN; no patch needed.")
else:
    patch = """

# --- Colab patch: expose CLIP normalization constants for older exports ---
try:
    from .constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD  # timm versions where defined in constants
except Exception:
    # Standard OpenAI CLIP normalization
    OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    OPENAI_CLIP_STD  = (0.26862954, 0.26130258, 0.27577711)
# --- end patch ---
"""
    p.write_text(txt + patch)
    print("✅ Patched:", p)

✅ Patched: /usr/local/lib/python3.12/dist-packages/timm/data/__init__.py


In [10]:
%cd /content/deit
from models import deit_tiny_patch16_224
m = deit_tiny_patch16_224()
print("✅ DeiT model instantiated successfully")

/content/deit
✅ DeiT model instantiated successfully


In [11]:
import torch, timm
print(torch.__version__)
print(timm.__version__)
print(torch.cuda.is_available())

2.9.0+cu128
0.6.13
True


Download Tiny-ImageNet

In [12]:
%cd /content
!wget -q http://cs231n.stanford.edu/tiny-imagenet-200.zip
!unzip -q tiny-imagenet-200.zip

/content


Fix Tiny-ImageNet validation folder

In [13]:
!python - << 'EOF'
import shutil
from pathlib import Path

root = Path("/content/tiny-imagenet-200")
val_dir = root/"val"
img_dir = val_dir/"images"
ann = val_dir/"val_annotations.txt"

with ann.open("r") as f:
    for line in f:
        img, cls = line.strip().split("\t")[:2]
        (val_dir/cls).mkdir(parents=True, exist_ok=True)
        src = img_dir/img
        dst = val_dir/cls/img
        if src.exists():
            shutil.move(str(src), str(dst))

if img_dir.exists():
    shutil.rmtree(img_dir)

print("✅ Tiny-ImageNet val reorganized into class subfolders.")

✅ Tiny-ImageNet val reorganized into class subfolders.


In [14]:
!find /content/tiny-imagenet-200/val -maxdepth 1 -type d | head

/content/tiny-imagenet-200/val
/content/tiny-imagenet-200/val/n02002724
/content/tiny-imagenet-200/val/n03447447
/content/tiny-imagenet-200/val/n07715103
/content/tiny-imagenet-200/val/n02486410
/content/tiny-imagenet-200/val/n01770393
/content/tiny-imagenet-200/val/n02666196
/content/tiny-imagenet-200/val/n02892201
/content/tiny-imagenet-200/val/n04356056
/content/tiny-imagenet-200/val/n02364673


In [15]:
ls -lah /content/tiny-imagenet-200 | head

total 2.6M
drwxrwxr-x   5 root root 4.0K Feb  9  2015 [0m[01;34m.[0m/
drwxr-xr-x   1 root root 4.0K Feb 13 07:46 [01;34m..[0m/
drwxrwxr-x   3 root root 4.0K Dec 12  2014 [01;34mtest[0m/
drwxrwxr-x 202 root root 4.0K Dec 12  2014 [01;34mtrain[0m/
drwxrwxr-x 202 root root 4.0K Feb 13 07:47 [01;34mval[0m/
-rw-rw-r--   1 root root 2.0K Feb  9  2015 wnids.txt
-rw-------   1 root root 2.6M Feb  9  2015 words.txt


Handle timm incompatibilities. Although we can instantiate the model directly, the training script uses timm.create_model(), which injects metadata arguments such as pretrained_cfg and cache_dir.
The original DeiT constructors do not support these arguments, so we remove them
YOUR NOTEBOOK CALL
    |
    v
deit_tiny_patch16_224()          ✅ works (no kwargs)

TRAINING PIPELINE
    |
    v
timm.create_model()
    |
    v
deit_tiny_patch16_224(**kwargs)  ❌ injects extra keys


Patch /content/deit/augment.py (safe compatibility fix)

In [16]:
%cd /content/deit
!python - << 'PY'
from pathlib import Path
p = Path("augment.py")
txt = p.read_text()

old = "from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor"
if old in txt:
    txt = txt.replace(
        old,
        "from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensor\n"
        "try:\n"
        "    from timm.data.transforms import _pil_interp  # older timm\n"
        "except Exception:\n"
        "    _pil_interp = None  # newer timm doesn't expose this\n"
    )
    p.write_text(txt)
    print("✅ Patched augment.py for timm compatibility.")
else:
    print("ℹ️ Expected import line not found; augment.py may already be patched or different.")

/content/deit
✅ Patched augment.py for timm compatibility.


In [17]:
%cd /content/deit
!rm -f multiteacher_loss.py
!ls -l multiteacher_loss.py || echo "✅ old file removed"

/content/deit
ls: cannot access 'multiteacher_loss.py': No such file or directory
✅ old file removed


In [18]:
%cd /content/deit

from pathlib import Path

code = r'''
from __future__ import annotations

from typing import Dict, List, Optional
import json
from pathlib import Path as _Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm


# -----------------------------
# Utilities
# -----------------------------
def normalize_lambdas(lmb: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Normalize teacher weights so they sum to 1 (supports shape (T,) or (B,T)).
    """
    if lmb.dim() == 1:
        return lmb / lmb.sum().clamp_min(eps)
    return lmb / lmb.sum(dim=-1, keepdim=True).clamp_min(eps)


def fuse_logits(
    teacher_logits: Dict[str, torch.Tensor],
    teacher_order: List[str],
    lambdas: torch.Tensor,
) -> torch.Tensor:
    """
    Weighted sum of teacher logits.
    teacher_logits[k]: (B,C)
    lambdas: (B,T) or (T,)
    returns: (B,C)
    """
    logits_list = [teacher_logits[k] for k in teacher_order]
    stacked = torch.stack(logits_list, dim=1)  # (B,T,C)

    lmb = normalize_lambdas(lambdas).to(stacked.device)
    if lmb.dim() == 1:
        lmb = lmb.unsqueeze(0).expand(stacked.size(0), -1)  # (B,T)

    return (stacked * lmb.unsqueeze(-1)).sum(dim=1)


def kd_soft(student_logits: torch.Tensor, teacher_logits: torch.Tensor, T: float) -> torch.Tensor:
    """
    Standard KL-based soft distillation loss with temperature scaling.
    """
    p_t = F.softmax(teacher_logits / T, dim=-1)
    log_p_s = F.log_softmax(student_logits / T, dim=-1)
    return F.kl_div(log_p_s, p_t, reduction="batchmean") * (T * T)


def kd_hard(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    """
    Hard distillation: cross-entropy against teacher argmax.
    """
    return F.cross_entropy(student_logits, teacher_logits.argmax(dim=-1))


# -----------------------------
# Teachers
# -----------------------------
class FrozenTeacherEnsemble(nn.Module):
    """
    Loads a list of timm pretrained teachers and freezes them.
    """
    def __init__(self, teacher_names: List[str], device: torch.device):
        super().__init__()
        self.models = nn.ModuleDict(
            {
                name: timm.create_model(name, pretrained=True, num_classes=1000).eval().to(device)
                for name in teacher_names
            }
        )
        for m in self.models.values():
            for p in m.parameters():
                p.requires_grad_(False)
        self.teacher_order = list(self.models.keys())

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {k: m(x) for k, m in self.models.items()}


# -----------------------------
# Teacher logits mapping: ImageNet-1k -> Tiny-ImageNet (wnid-aligned gather)
# -----------------------------
def build_tiny_imagenet_im1k_indices(
    tiny_root: str,
    class_index_json: str = "/content/imagenet_class_index.json",
) -> torch.Tensor:
    """
    Returns a LongTensor of shape (200,) containing the ImageNet-1k class indices
    corresponding to Tiny-ImageNet wnids.txt ordering.

    Requires torchvision's imagenet_class_index.json (wnid->index via JSON).
    """
    tiny_root_p = _Path(tiny_root)
    wnids_path = tiny_root_p / "wnids.txt"
    if not wnids_path.exists():
        raise FileNotFoundError(f"Could not find Tiny-ImageNet wnids.txt at: {wnids_path}")

    wnids = wnids_path.read_text().strip().splitlines()

    class_index_path = _Path(class_index_json)
    if not class_index_path.exists():
        raise FileNotFoundError(
            f"Missing {class_index_json}. Download it before training.\n"
            "Example:\n"
            "  !wget -q https://raw.githubusercontent.com/pytorch/vision/main/torchvision/models/imagenet_class_index.json "
            f"-O {class_index_json}"
        )

    class_index = json.loads(class_index_path.read_text())
    # class_index: {"0": ["n01440764", "tench"], ...}
    wnid_to_idx = {v[0]: int(k) for k, v in class_index.items()}

    indices: List[int] = []
    missing: List[str] = []
    for w in wnids:
        if w in wnid_to_idx:
            indices.append(wnid_to_idx[w])
        else:
            missing.append(w)

    if missing:
        raise ValueError(
            f"{len(missing)} Tiny-ImageNet wnids were not found in ImageNet-1k mapping. "
            f"First few missing: {missing[:10]}"
        )

    return torch.tensor(indices, dtype=torch.long)


class TeacherLogitMapper(nn.Module):
    """
    Maps ImageNet-1k teacher logits (B,1000) -> Tiny-ImageNet logits (B,200)
    by selecting the 200 corresponding ImageNet indices (gather/index_select).
    """
    def __init__(self, teacher_keys: List[str], im1k_indices: torch.Tensor):
        super().__init__()
        self.teacher_keys = list(teacher_keys)
        self.register_buffer("im1k_indices", im1k_indices)  # (200,)

    def forward(self, teacher_logits: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        out: Dict[str, torch.Tensor] = {}
        idx = self.im1k_indices
        for k, v in teacher_logits.items():
            # v: (B,1000) -> (B,200)
            out[k] = v.index_select(dim=-1, index=idx)
        return out


# -----------------------------
# HDTSE confidence weighting
# -----------------------------
class HDTSEConfidence(nn.Module):
    """
    Computes per-sample teacher weights based on each teacher's confidence
    on the (possibly soft) targets.
    """
    def __init__(self, temp: float = 1.0):
        super().__init__()
        self.temp = float(temp)

    @torch.no_grad()
    def forward(
        self,
        teacher_logits: Dict[str, torch.Tensor],
        teacher_order: List[str],
        targets: torch.Tensor,
    ) -> torch.Tensor:
        stacked = torch.stack([teacher_logits[k] for k in teacher_order], dim=1)  # (B,T,C)
        probs = F.softmax(stacked / self.temp, dim=-1)  # (B,T,C)

        # Hard labels: (B,)
        if targets.dim() == 1:
            idx = targets.to(dtype=torch.long, device=probs.device)
            conf = probs.gather(-1, idx[:, None, None]).squeeze(-1)  # (B,T)
            return normalize_lambdas(conf)

        # Soft labels (mixup/cutmix): (B,C)
        tgt = targets.to(dtype=probs.dtype, device=probs.device)
        conf = (probs * tgt[:, None, :]).sum(dim=-1)  # (B,T)
        return normalize_lambdas(conf)


# -----------------------------
# Multi-teacher distillation loss
# -----------------------------
class MultiTeacherDistillationLoss(nn.Module):
    def __init__(
        self,
        base_criterion,
        student_num_classes: int,
        teacher_names: List[str],
        distillation_type: str = "soft",
        alpha: float = 0.5,
        tau: float = 2.0,
        device=None,
        use_adapter: bool = True,
        hdtse_warmup_epochs: int = 0,
        lambda_log: bool = True,
    ):
        """
        base_criterion: supervised loss (CE or soft-target CE when mixup is enabled)
        distillation_type: "soft" or "hard"
        alpha: final KD weight
        tau: KD temperature
        use_adapter: if True, expects Tiny-ImageNet mapping via set_tiny_root() before training
        hdtse_warmup_epochs: use uniform lambdas until this epoch (exclusive)
        """
        super().__init__()
        self.base_criterion = base_criterion
        self.distillation_type = str(distillation_type)
        self.tau = float(tau)
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # teachers (frozen)
        self.teachers = FrozenTeacherEnsemble(teacher_names, self.device)
        self.teacher_order = list(self.teachers.teacher_order)

        # teacher->student class mapping (ImageNet-1k -> dataset classes)
        self.use_adapter = bool(use_adapter)
        self.adapter: Optional[nn.Module] = None  # created by set_tiny_root()

        # HDTSE teacher weighting
        self.hdtse = HDTSEConfidence()

        # epoch state
        self.epoch: int = 0
        self.hdtse_warmup_epochs = int(hdtse_warmup_epochs)

        # alpha schedule (KD weight ramp)
        self.alpha_final = float(alpha)
        self.alpha_start = 0.0
        self.alpha_ramp_epochs = 20  # default ramp duration

        # lambda logging (epoch-level)
        self.lambda_log = bool(lambda_log)
        self._lambda_sum = torch.zeros(len(self.teacher_order), dtype=torch.float32)
        self._lambda_count = 0

    # ---- Public setters ----
    def set_epoch(self, epoch: int):
        self.epoch = int(epoch)

    def set_alpha_schedule(self, alpha_start: float = 0.0, alpha_ramp_epochs: int = 20):
        self.alpha_start = float(alpha_start)
        self.alpha_ramp_epochs = int(alpha_ramp_epochs)

    def set_tiny_root(self, tiny_root: str, class_index_json: str = "/content/imagenet_class_index.json"):
        """
        Call once (from main.py) after constructing this loss, before training starts.
        Creates the gather-based teacher logits mapper: (B,1000)->(B,C).
        """
        im1k_indices = build_tiny_imagenet_im1k_indices(tiny_root, class_index_json=class_index_json).to(self.device)
        self.adapter = TeacherLogitMapper(self.teacher_order, im1k_indices).to(self.device)

    # ---- Logging ----
    def pop_lambda_stats(self) -> Optional[Dict[str, float]]:
        """
        Returns mean λ per teacher over the epoch, then resets accumulators.
        Call once per epoch from main.py.
        """
        if self._lambda_count <= 0:
            return None

        mean_lmb = (self._lambda_sum / float(self._lambda_count)).tolist()
        out = {f"lambda_{name}": float(v) for name, v in zip(self.teacher_order, mean_lmb)}

        self._lambda_sum.zero_()
        self._lambda_count = 0
        return out

    # ---- Internals ----
    def _uniform_lambdas(self, batch_size: int, device: torch.device) -> torch.Tensor:
        t = len(self.teacher_order)
        return torch.full((batch_size, t), 1.0 / t, device=device, dtype=torch.float32)

    def _alpha_effective(self) -> float:
        if self.alpha_ramp_epochs <= 0:
            return self.alpha_final
        t = min(1.0, float(self.epoch) / float(self.alpha_ramp_epochs))
        return self.alpha_start + t * (self.alpha_final - self.alpha_start)

    # ---- Forward ----
    def forward(self, inputs: torch.Tensor, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        inputs: images (B,3,H,W)
        outputs: student logits (B,C)
        targets: hard labels (B,) or soft labels (B,C) when mixup/cutmix is enabled
        """
        base_loss = self.base_criterion(outputs, targets)

        with torch.no_grad():
            t_logits = self.teachers(inputs)  # dict: teacher -> (B,1000)

        student_C = outputs.shape[-1]
        any_teacher = next(iter(t_logits.values()))
        teacher_C = any_teacher.shape[-1]

        if teacher_C != student_C:
            if self.adapter is None:
                raise RuntimeError(
                f"Teacher logits have {teacher_C} classes but student has {student_C}. "
                "Adapter not initialized. Call criterion.set_tiny_root(args.data_path)."
            )
            t_logits = self.adapter(t_logits)  # dict: teacher -> (B,student_C)

        # ---- Teacher weights (λ) ----
        if self.epoch < self.hdtse_warmup_epochs:
            lambdas = self._uniform_lambdas(outputs.size(0), outputs.device)  # (B,T)
        else:
            lambdas = self.hdtse(t_logits, self.teacher_order, targets)  # (B,T)

        # ---- λ logging ----
        if self.lambda_log:
            batch_mean = lambdas.detach().mean(dim=0).cpu()  # (T,)
            self._lambda_sum += batch_mean * outputs.size(0)
            self._lambda_count += outputs.size(0)

        fused = fuse_logits(t_logits, self.teacher_order, lambdas)  # (B,C)

        kd = kd_soft(outputs, fused, self.tau) if self.distillation_type == "soft" else kd_hard(outputs, fused)

        alpha_eff = self._alpha_effective()
        return (1.0 - alpha_eff) * base_loss + alpha_eff * kd
'''

path = Path("multiteacher_loss.py")
path.write_text(code)

print("File written:", path)
print("File size (bytes):", path.stat().st_size)

/content/deit
File written: multiteacher_loss.py
File size (bytes): 11996


In [19]:
from pathlib import Path
import re, py_compile

MAIN = Path("/content/deit/main.py")
txt = MAIN.read_text()

# ------------------------------------------------------------
# Helpers (line-safe insertions to avoid indentation/newline bugs)
# ------------------------------------------------------------
def fix_broken_import_concatenation():
    global txt
    # Fix exact failure mode:
    txt = txt.replace(
        "from multiteacher_loss import MultiTeacherDistillationLossfrom samplers import RASampler",
        "from multiteacher_loss import MultiTeacherDistillationLoss\nfrom samplers import RASampler"
    )

def ensure_line_after(match_line_regex: str, new_line: str):
    """Insert `new_line` as a full line right AFTER the first line matching regex."""
    global txt
    if new_line.strip() in txt:
        return
    lines = txt.splitlines(True)  # keep line endings
    for i, line in enumerate(lines):
        if re.search(match_line_regex, line):
            # insert after this line
            if not new_line.endswith("\n"):
                new_line2 = new_line + "\n"
            else:
                new_line2 = new_line
            lines.insert(i + 1, new_line2)
            txt = "".join(lines)
            return
    raise RuntimeError(f"Could not find line to insert after: {match_line_regex}")

def ensure_block_after_line(match_line_regex: str, block: str):
    """Insert a multi-line block after first line matching regex."""
    global txt
    # Heuristic: if first unique token already exists, don't re-add
    if "--teacher-models" in block and "--teacher-models" in txt and "--hdtse-warmup-epochs" in txt and "--lambda-log" in txt:
        return
    lines = txt.splitlines(True)
    for i, line in enumerate(lines):
        if re.search(match_line_regex, line):
            if not block.endswith("\n"):
                block2 = block + "\n"
            else:
                block2 = block
            lines.insert(i + 1, block2)
            txt = "".join(lines)
            return
    raise RuntimeError(f"Could not find line to insert block after: {match_line_regex}")

def replace_first(pattern: str, repl: str, flags=re.DOTALL):
    global txt
    m = re.search(pattern, txt, flags)
    if not m:
        return False
    txt = txt[:m.start()] + repl + txt[m.end():]
    return True

def remove_first_line_matching(line_regex: str):
    global txt
    lines = txt.splitlines(True)
    for i, line in enumerate(lines):
        if re.search(line_regex, line):
            del lines[i]
            txt = "".join(lines)
            return True
    return False

# ------------------------------------------------------------
# 0) Repair if prior patch created the exact SyntaxError
# ------------------------------------------------------------
fix_broken_import_concatenation()

# ------------------------------------------------------------
# 1) Ensure MultiTeacherDistillationLoss import (safe line insertion)
# Insert after: from losses import DistillationLoss
# ------------------------------------------------------------
ensure_line_after(
    r"^\s*from\s+losses\s+import\s+DistillationLoss\s*$",
    "from multiteacher_loss import MultiTeacherDistillationLoss"
)

# ------------------------------------------------------------
# 2) Ensure CLI args after --teacher-path
# ------------------------------------------------------------
cli_block = """\
    parser.add_argument('--teacher-models', type=str, default='',
                        help='Comma-separated timm model names for multi-teacher distillation')
    parser.add_argument('--hdtse-warmup-epochs', type=int, default=0,
                        help='Use uniform teacher weights for first N epochs, then enable HDTSE weighting')
    parser.add_argument('--lambda-log', action='store_true', default=False,
                        help='Log mean λ (teacher weights) each epoch for multi-teacher distillation')
"""
ensure_block_after_line(r"^\s*parser\.add_argument\('--teacher-path'", cli_block)

# ------------------------------------------------------------
# 3) Allow finetune + distillation ONLY when multi-teacher is used
# Base guard is:
# if args.distillation_type != 'none' and args.finetune and not args.eval:
#     raise NotImplementedError(...)
# ------------------------------------------------------------
replace_first(
    r"^\s*if\s+args\.distillation_type\s*!=\s*'none'\s+and\s+args\.finetune\s+and\s+not\s+args\.eval\s*:\s*\n\s*raise\s+NotImplementedError\([^\n]*\)\s*$",
    "    if args.distillation_type != 'none' and args.finetune and not args.eval and not getattr(args, 'teacher_models', ''):\n"
    "        raise NotImplementedError(\"Finetuning with distillation not yet supported (single-teacher path)\")\n",
    flags=re.MULTILINE
)

# ------------------------------------------------------------
# 4) Move scheduler creation to AFTER adapter param-group add:
# Remove early: lr_scheduler, _ = create_scheduler(args, optimizer)
# ------------------------------------------------------------
remove_first_line_matching(r"^\s*lr_scheduler,\s*_\s*=\s*create_scheduler\(\s*args\s*,\s*optimizer\s*\)\s*$")

# ------------------------------------------------------------
# 5) Unify distillation region (multi-teacher vs single-teacher)
# We'll replace from "teacher_model = None" up to "output_dir = Path(args.output_dir)"
# This avoids indentation mistakes and prevents teacher_path='' crash.
# ------------------------------------------------------------
m_start = re.search(r"^\s*teacher_model\s*=\s*None\s*$", txt, flags=re.MULTILINE)
m_end   = re.search(r"^\s*output_dir\s*=\s*Path\(args\.output_dir\)\s*$", txt, flags=re.MULTILINE)
if not (m_start and m_end and m_start.start() < m_end.start()):
    raise RuntimeError("Could not locate distillation region anchors (teacher_model=None ... output_dir=Path(...))")

unified = """\
    teacher_model = None

    # -------------------------------
    # Unified single + multi-teacher distillation
    # -------------------------------
    teacher_models_str = getattr(args, 'teacher_models', '').strip()

    if args.distillation_type != 'none' and teacher_models_str:
        teacher_names = [t.strip() for t in teacher_models_str.split(',') if t.strip()]
        print(f"✅ Multi-teacher distillation enabled. Teachers: {teacher_names}")

        criterion = MultiTeacherDistillationLoss(
            base_criterion=criterion,
            student_num_classes=args.nb_classes,
            teacher_names=teacher_names,
            distillation_type=args.distillation_type,
            alpha=args.distillation_alpha,
            tau=args.distillation_tau,
            device=device,
            use_adapter=True,
            hdtse_warmup_epochs=getattr(args, 'hdtse_warmup_epochs', 0),
            lambda_log=getattr(args, 'lambda_log', False),
        )

        # Initialize Tiny-ImageNet wnid -> ImageNet-1k index mapping for teacher logits
        if hasattr(criterion, "set_tiny_root"):
            criterion.set_tiny_root(args.data_path)

        # Optional: alpha ramp if you add args later
        if hasattr(criterion, "set_alpha_schedule") and hasattr(args, "alpha_ramp_epochs"):
            criterion.set_alpha_schedule(
                alpha_start=getattr(args, "alpha_start", 0.0),
                alpha_ramp_epochs=getattr(args, "alpha_ramp_epochs", 20),
            )

    else:
        if args.distillation_type != 'none':
            assert args.teacher_path, 'need to specify teacher-path when using single-teacher distillation'
            print(f"Creating teacher model: {args.teacher_model}")
            teacher_model = create_model(
                args.teacher_model,
                pretrained=False,
                num_classes=args.nb_classes,
                global_pool='avg',
            )
            if args.teacher_path.startswith('https'):
                checkpoint = torch.hub.load_state_dict_from_url(
                    args.teacher_path, map_location='cpu', check_hash=True)
            else:
                checkpoint = torch.load(args.teacher_path, map_location='cpu')
            teacher_model.load_state_dict(checkpoint['model'])
            teacher_model.to(device)
            teacher_model.eval()

        criterion = DistillationLoss(
            criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
        )

    # Scheduler must be created AFTER all optimizer param groups are finalized
    lr_scheduler, _ = create_scheduler(args, optimizer)
"""

txt = txt[:m_start.start()] + unified + "\n    output_dir = Path(args.output_dir)\n" + txt[m_end.end():]

# ----------------------------
# 2) Ensure loss call uses (samples, outputs, targets)
# ----------------------------
# Patch ONLY the simple 2-arg form if present.
if "criterion(samples, outputs, targets)" not in txt:
    txt = re.sub(
        r"loss\s*=\s*criterion\(\s*outputs\s*,\s*targets\s*\)",
        r"loss = criterion(samples, outputs, targets)",
        txt
    )

# ------------------------------------------------------------
# 6) Insert criterion.set_epoch(epoch) before train_one_epoch
# We add it inside the epoch loop, after sampler.set_epoch if present.
# ------------------------------------------------------------
if "criterion.set_epoch(epoch)" not in txt:
    # If distributed block exists, insert after it
    if re.search(r"^\s*if\s+args\.distributed\s*:\s*\n\s*data_loader_train\.sampler\.set_epoch\(epoch\)\s*$", txt, flags=re.MULTILINE):
        txt = re.sub(
            r"(^\s*if\s+args\.distributed\s*:\s*\n\s*data_loader_train\.sampler\.set_epoch\(epoch\)\s*$)",
            r"\1\n        if hasattr(criterion, 'set_epoch'):\n            criterion.set_epoch(epoch)",
            txt,
            flags=re.MULTILINE,
            count=1
        )
    else:
        # Otherwise put at top of loop
        txt = re.sub(
            r"(^\s*for\s+epoch\s+in\s+range\(args\.start_epoch,\s*args\.epochs\)\s*:\s*$)",
            r"\1\n        if hasattr(criterion, 'set_epoch'):\n            criterion.set_epoch(epoch)",
            txt,
            flags=re.MULTILINE,
            count=1
        )

# ------------------------------------------------------------
# 7) Per-epoch λ logging after train_one_epoch call
# ------------------------------------------------------------
if "print('λ means:'" not in txt:
    txt = re.sub(
        r"(train_stats\s*=\s*train_one_epoch\([\s\S]*?\)\s*)\n",
        r"\1\n\n"
        r"        # Optional: log mean λ per teacher (multi-teacher only)\n"
        r"        if getattr(args, 'lambda_log', False) and hasattr(criterion, 'pop_lambda_stats'):\n"
        r"            lambda_means = criterion.pop_lambda_stats()\n"
        r"            if lambda_means:\n"
        r"                print('λ means:', lambda_means)\n",
        txt,
        count=1
    )

# ------------------------------------------------------------
# Write + compile check
# ------------------------------------------------------------
MAIN.write_text(txt)
py_compile.compile(str(MAIN), doraise=True)
print("✅ Patched main.py written and compiles:", MAIN)


✅ Patched main.py written and compiles: /content/deit/main.py


In [20]:
# Consolidated patcher for /content/deit/main.py
# Adds: Cosine α schedule CLI args + per-epoch α update (works for DistillationLoss + MultiTeacherDistillationLoss)
#
# Base file reference (as you shared): :contentReference[oaicite:0]{index=0}

from pathlib import Path
import re, py_compile

MAIN = Path("/content/deit/main.py")
txt = MAIN.read_text()

# ------------------------------------------------------------
# Helpers (line-safe insertions to avoid indentation/newline bugs)
# ------------------------------------------------------------
def ensure_import(module_name: str):
    """Ensure `import <module_name>` exists near the top-level imports."""
    global txt
    pat = rf"^\s*import\s+{re.escape(module_name)}\s*$"
    if re.search(pat, txt, flags=re.MULTILINE):
        return
    # Insert after "import time" if present, else after argparse
    lines = txt.splitlines(True)
    for i, line in enumerate(lines):
        if re.match(r"^\s*import\s+time\s*$", line):
            lines.insert(i + 1, f"import {module_name}\n")
            txt = "".join(lines)
            return
    for i, line in enumerate(lines):
        if re.match(r"^\s*import\s+argparse\s*$", line):
            lines.insert(i + 1, f"import {module_name}\n")
            txt = "".join(lines)
            return
    # fallback: very top
    txt = f"import {module_name}\n" + txt

def ensure_line_after(match_line_regex: str, new_line: str):
    """Insert `new_line` as a full line RIGHT AFTER the first line matching regex."""
    global txt
    if new_line.strip() in txt:
        return
    lines = txt.splitlines(True)
    for i, line in enumerate(lines):
        if re.search(match_line_regex, line):
            lines.insert(i + 1, (new_line if new_line.endswith("\n") else new_line + "\n"))
            txt = "".join(lines)
            return
    raise RuntimeError(f"Could not find line to insert after: {match_line_regex}")

def ensure_block_after_line(match_line_regex: str, block: str, unique_guard: str = None):
    """Insert a multi-line block after first line matching regex."""
    global txt
    if unique_guard and unique_guard in txt:
        return
    if (not unique_guard) and block.strip() in txt:
        return
    lines = txt.splitlines(True)
    for i, line in enumerate(lines):
        if re.search(match_line_regex, line):
            lines.insert(i + 1, (block if block.endswith("\n") else block + "\n"))
            txt = "".join(lines)
            return
    raise RuntimeError(f"Could not find line to insert block after: {match_line_regex}")

def replace_first(pattern: str, repl: str, flags=re.DOTALL):
    global txt
    m = re.search(pattern, txt, flags)
    if not m:
        return False
    txt = txt[:m.start()] + repl + txt[m.end():]
    return True

# ------------------------------------------------------------
# 0) Ensure `import math` (needed for cosine schedule)
# ------------------------------------------------------------
ensure_import("math")

# ------------------------------------------------------------
# 1) Add CLI args for alpha scheduling (after distillation params)
#    Insert right after: parser.add_argument('--distillation-tau' ...)
# ------------------------------------------------------------
alpha_cli_block = """\
    # ---- Distillation alpha scheduling (epoch-level) ----
    parser.add_argument('--alpha-schedule', default='none',
                        choices=['none', 'cosine'],
                        type=str, help="Schedule distillation alpha across epochs.")
    parser.add_argument('--alpha-start', default=0.05, type=float,
                        help="Starting alpha for alpha schedule (ignored if alpha-schedule=none).")
    parser.add_argument('--alpha-end', default=0.7, type=float,
                        help="Final alpha for alpha schedule (ignored if alpha-schedule=none).")
"""
ensure_block_after_line(
    r"^\s*parser\.add_argument\('--distillation-tau'.*\)\s*$",
    alpha_cli_block,
    unique_guard="--alpha-schedule"
)

# ------------------------------------------------------------
# 2) Add helper function for cosine schedule (module-level, before main())
#    Insert right after get_args_parser() returns `parser`
# ------------------------------------------------------------
alpha_fn_block = """\

def _cosine_alpha(epoch: int, total_epochs: int, alpha_start: float, alpha_end: float) -> float:
    # Smoothly increases alpha from alpha_start to alpha_end over training.
    total_epochs = int(total_epochs)
    if total_epochs <= 1:
        return alpha_end

    progress = float(epoch) / float(total_epochs-1)
    progress = min(max(progress, 0.0), 1.0)
    cosine = 0.5 * (1.0 - math.cos(math.pi * progress))
    return alpha_start + cosine * (alpha_end - alpha_start)
"""
# Insert after the "return parser" line inside get_args_parser()
ensure_block_after_line(
    r"^\s*return\s+parser\s*$",
    alpha_fn_block,
    unique_guard="def _cosine_alpha"
)

# ------------------------------------------------------------
# 3) Update α each epoch BEFORE train_one_epoch call, and push into criterion
#    Key detail: DistillationLoss() captures alpha at init in this base main.py,
#    so we must update criterion.alpha (or criterion.set_alpha) per epoch.
# ------------------------------------------------------------
epoch_hook_block = """\
        # ---- Cosine α schedule (distillation weight) ----
        if args.distillation_type != 'none' and getattr(args, 'alpha_schedule', 'none') == 'cosine':
            args.distillation_alpha = _cosine_alpha(
                epoch=epoch,
                total_epochs=args.epochs,
                alpha_start=getattr(args, 'alpha_start', 0.05),
                alpha_end=getattr(args, 'alpha_end', 0.7),
            )
            # Propagate into the actual loss wrapper (DistillationLoss / MultiTeacherDistillationLoss)
            if hasattr(criterion, "alpha"):
                criterion.alpha = args.distillation_alpha
            elif hasattr(criterion, "set_alpha"):
                criterion.set_alpha(args.distillation_alpha)

            if (not getattr(args, "distributed", False)) or getattr(args, "rank", 0) == 0:
                print(f"[alpha-schedule=cosine] epoch={epoch} distillation_alpha={args.distillation_alpha:.4f}")
"""

# Insert this block after the distributed sampler epoch set (if present),
# otherwise right after the for-loop line.
if "alpha-schedule=cosine" not in txt:
    # Case A: distributed sampler.set_epoch exists
    if re.search(
        r"^\s*if\s+args\.distributed\s*:\s*\n\s*data_loader_train\.sampler\.set_epoch\(epoch\)\s*$",
        txt,
        flags=re.MULTILINE
    ):
        txt = re.sub(
            r"(^\s*if\s+args\.distributed\s*:\s*\n\s*data_loader_train\.sampler\.set_epoch\(epoch\)\s*$)",
            r"\1\n\n" + epoch_hook_block.rstrip("\n"),
            txt,
            flags=re.MULTILINE,
            count=1
        )
    else:
        # Case B: no distributed stanza → insert at top of loop
        txt = re.sub(
            r"(^\s*for\s+epoch\s+in\s+range\(args\.start_epoch,\s*args\.epochs\)\s*:\s*$)",
            r"\1\n" + epoch_hook_block.rstrip("\n"),
            txt,
            flags=re.MULTILINE,
            count=1
        )

# ------------------------------------------------------------
# Write + compile check
# ------------------------------------------------------------
MAIN.write_text(txt)
py_compile.compile(str(MAIN), doraise=True)
print("✅ Patched main.py written and compiles:", MAIN)


✅ Patched main.py written and compiles: /content/deit/main.py


Before constructing the model, remove those keys from kwargs

In [21]:
from pathlib import Path

p = Path("/content/deit/models.py")
lines = p.read_text().splitlines()

out = []
for line in lines:
    out.append(line)
    if line.strip().startswith("def deit_") and "**kwargs" in line:
        out.append("    # Drop timm-injected kwargs not supported by DeiT")
        out.append("    kwargs.pop('pretrained_cfg', None)")
        out.append("    kwargs.pop('pretrained_cfg_overlay', None)")
        out.append("    kwargs.pop('pretrained_cfg_priority', None)")

p.write_text("\n".join(out) + "\n")
print("✅ models.py patched to drop pretrained_cfg kwargs")


✅ models.py patched to drop pretrained_cfg kwargs


Verify

Fix: Patch /content/deit/models.py to drop pretrained_cfg=...

Patch models.py to also drop cache_dir (and friends)

In [22]:
from pathlib import Path

p = Path("/content/deit/models.py")
lines = p.read_text().splitlines()

# Keys that timm may inject but DeiT constructors don't accept
DROP_KEYS = [
    "cache_dir",
    "hf_hub_id",
    "hf_hub_filename",
    "hf_hub_revision",
]

out = []
for line in lines:
    out.append(line)
    # Right after the comment line we previously inserted, add more pops once per function
    if line.strip() == "# Drop timm-injected kwargs not supported by DeiT":
        for k in DROP_KEYS:
            out.append(f"    kwargs.pop('{k}', None)")

p.write_text("\n".join(out) + "\n")
print("✅ Patched models.py to drop cache_dir/hf_hub* kwargs")


✅ Patched models.py to drop cache_dir/hf_hub* kwargs


In [23]:
!rm -f /content/imagenet_class_index.json
!wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json \
  -O /content/imagenet_class_index.json

!python - <<'PY'
import json
p="/content/imagenet_class_index.json"
with open(p,"r",encoding="utf-8") as f:
    obj=json.load(f)
print("Loaded OK. Entries:", len(obj))
print("Example 0:", obj["0"])

--2026-02-13 07:47:02--  https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json
Resolving s3.amazonaws.com (s3.amazonaws.com)... 3.5.0.81, 52.217.235.72, 52.216.40.224, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|3.5.0.81|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 35363 (35K) [application/octet-stream]
Saving to: ‘/content/imagenet_class_index.json’


2026-02-13 07:47:03 (158 KB/s) - ‘/content/imagenet_class_index.json’ saved [35363/35363]

Loaded OK. Entries: 1000
Example 0: ['n01440764', 'tench']


In [24]:
# %cd /content/deit
# !python main.py \
#   --model deit_tiny_patch16_224 \
#   --data-path /content/tiny-imagenet-200 \
#   --pretrained \
#   --epochs 1 \
#   --batch-size 64 \
#   --num_workers 2 \
#   --output_dir /content/deit_runs/smoke_test
# %cd /content/deit
# !python main.py \
#   --model deit_tiny_patch16_224 \
#   --data-path /content/tiny-imagenet-200 \
#   --epochs 1 \
#   --batch-size 128 \
#   --num_workers 4 \
#   --input-size 224 \
#   --opt adamw \
#   --lr 5e-4 \
#   --weight-decay 0.05 \
#   --sched cosine \
#   --aa rand-m9-mstd0.5 \
#   --reprob 0.25 \
#   --remode pixel \
#   --recount 1 \
#   --output_dir /content/deit_runs/tiny_imagenet
### correct one
# %cd /content/deit
# !python main.py \
#  --model deit_tiny_patch16_224 \
#  --data-path /content/tiny-imagenet-200 \
#  --finetune https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth \
#  --epochs 10 \
#  --batch-size 128 \
#  --num_workers 4 \
#  --input-size 224 \
#  --opt adamw \
#  --lr 3e-4 \
#  --weight-decay 0.05 \
#  --sched cosine \
#  --warmup-epochs 1 \
#  --smoothing 0.1 \
#  --aa rand-m7-mstd0.5 \
#  --reprob 0.1 \
#  --drop-path 0.1 \
#  --output_dir /content/deit_runs/tiny_imagenet_5ep
%cd /content/deit
!python main.py \
 --model deit_tiny_patch16_224 \
 --data-path /content/tiny-imagenet-200 \
 --finetune https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth \
 --epochs 90 \
 --batch-size 128 \
 --num_workers 4 \
 --input-size 224 \
 --opt adamw \
 --lr 2.5e-4 \
 --weight-decay 0.05 \
 --sched cosine \
 --warmup-epochs 4 \
 --smoothing 0.1 \
 --aa rand-m6-mstd0.5 \
 --reprob 0.2 \
 --model-ema \
 --model-ema-decay 0.9999 \
 --drop-path 0.05 \
 --mixup 0.2 \
 --cutmix 0.0 \
 --mixup-prob 0.5 \
 --distillation-type soft \
 --alpha-schedule cosine --alpha-start 0.1 --alpha-end 0.6 \
 --distillation-alpha 0.5 \
 --distillation-tau 3.5 \
 --hdtse-warmup-epochs 8 \
 --lambda-log \
 --output_dir /content/deit_runs/tiny_imagenet \
 --teacher-models "swin_base_patch4_window7_224,convnext_base,tf_efficientnetv2_l"
# %cd /content/deit
# !python main.py \
#  --model deit_tiny_patch16_224 \
#  --data-path /content/tiny-imagenet-200 \
#  --finetune https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth \
#  --epochs 10 \
#  --batch-size 128 \
#  --num_workers 4 \
#  --input-size 224 \
#  --opt adamw \
#  --lr 2.5e-4 \
#  --weight-decay 0.05 \
#  --sched cosine \
#  --warmup-epochs 1 \
#  --smoothing 0.1 \
#  --aa rand-m7-mstd0.5 \
#  --reprob 0.1 \
#  --drop-path 0.1 \
#  --distillation-type hard \
# --teacher-model regnety_160 \
# --teacher-path https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth \
#  --output_dir /content/deit_runs/tiny_imagenet_10ep
# %cd /content/deit
# !python main.py \
#  --model deit_tiny_distilled_patch16_224 \
#  --data-path /content/tiny-imagenet-200 \
#  --epochs 10 \
#  --batch-size 128 \
#  --num_workers 4 \
#  --input-size 224 \
#  --opt adamw \
#  --lr 7e-4 \
#  --weight-decay 0.05 \
#  --sched cosine \
#  --warmup-epochs 1 \
#  --smoothing 0.0 \
#  --aa rand-m7-mstd0.5 \
#  --reprob 0.1 \
#  --drop-path 0.0 \
#  --distillation-type hard \
#  --distillation-alpha 0.7 \
#  --teacher-model regnety_160 \
#  --teacher-path https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth \
#  --output_dir /content/deit_runs/deit_tiny_distilled_10ep





[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch: [31]  [760/781]  eta: 0:00:06  lr: 0.000048  loss: 1.4568 (1.5607)  time: 0.3323  data: 0.0003  max mem: 6459
Epoch: [31]  [770/781]  eta: 0:00:03  lr: 0.000048  loss: 1.5040 (1.5617)  time: 0.3321  data: 0.0003  max mem: 6459
Epoch: [31]  [780/781]  eta: 0:00:00  lr: 0.000048  loss: 1.5088 (1.5629)  time: 0.3320  data: 0.0006  max mem: 6459
Epoch: [31] Total time: 0:04:20 (0.3332 s / it)
Averaged stats: lr: 0.000048  loss: 1.5088 (1.5629)
λ means: {'lambda_swin_base_patch4_window7_224': 0.32951927185058594, 'lambda_convnext_base': 0.2587185204029083, 'lambda_tf_efficientnetv2_l': 0.41176265478134155}
Test:  [ 0/53]  eta: 0:00:44  loss: 0.8254 (0.8254)  acc1: 81.7708 (81.7708)  acc5: 95.3125 (95.3125)  time: 0.8467  data: 0.8161  max mem: 6459
Test:  [10/53]  eta: 0:00:07  loss: 1.0174 (1.0066)  acc1: 80.2083 (79.6875)  acc5: 94.2708 (92.8504)  time: 0.1783  data: 0.1479  max mem: 6459
Test:  [20/53]  eta: 0:00:05 

In [25]:
# Show lines first (sanity check)
!grep -n "torch.load(args.resume" /content/deit/main.py
!grep -n "torch.load(args.finetune" /content/deit/main.py

# Patch resume loader
!sed -i "s/torch.load(args.resume, map_location='cpu')/torch.load(args.resume, map_location='cpu', weights_only=False)/" /content/deit/main.py

# Patch finetune loader
!sed -i "s/torch.load(args.finetune, map_location='cpu')/torch.load(args.finetune, map_location='cpu', weights_only=False)/" /content/deit/main.py

462:            checkpoint = torch.load(args.resume, map_location='cpu')
305:            checkpoint = torch.load(args.finetune, map_location='cpu')


In [None]:
%cd /content/deit

!python main.py \
  --model deit_tiny_patch16_224 \
  --data-path /content/tiny-imagenet-200 \
  --finetune /content/deit_runs/tiny_imagenet/best_checkpoint.pth \
  --epochs 15 \
  --batch-size 128 \
  --num_workers 4 \
  --input-size 224 \
  --opt adamw \
  --lr 2.5e-5 \
  --weight-decay 0.02 \
  --sched cosine \
  --warmup-epochs 0 \
  --smoothing 0.05 \
  --aa rand-m3-mstd0.5 \
  --reprob 0.0 \
  --model-ema \
  --model-ema-decay 0.9999 \
  --drop-path 0.0 \
  --mixup 0.0 \
  --cutmix 0.0 \
  --mixup-prob 0.0 \
  --distillation-type soft \
  --distillation-alpha 0.25 \
  --distillation-tau 2.0 \
  --hdtse-warmup-epochs 0 \
  --lambda-log \
  --output_dir /content/deit_runs/tiny_imagenet_tail20_best \
  --teacher-models "swin_base_patch4_window7_224,convnext_base,tf_efficientnetv2_l"

/content/deit
Not using distributed mode
Namespace(batch_size=128, epochs=20, bce_loss=False, unscale_lr=False, model='deit_tiny_patch16_224', input_size=224, drop=0.0, drop_path=0.0, model_ema=True, model_ema_decay=0.9999, model_ema_force_cpu=False, opt='adamw', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.02, sched='cosine', lr=2.5e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-06, min_lr=1e-05, decay_epochs=30, warmup_epochs=0, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m3-mstd0.5', smoothing=0.05, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=False, src=False, reprob=0.0, remode='pixel', recount=1, resplit=False, mixup=0.0, cutmix=0.0, cutmix_minmax=None, mixup_prob=0.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', teacher_models='swin_base_patch4_window7_224,convnext_base,tf_efficientnetv2_l', hdtse_warmup_epoch

# **Layer 2: Base Environment — Teacher Models & Multi-Teacher Adaptations**

Layer 2 extends the baseline DeiT environment to support knowledge distillation from one or more teacher models. This layer is additive: it does not modify the baseline DeiT training loop unless explicitly stated.
It includes
1. Teacher Model Support (Single & Multiple)
2. Teacher Registry / Configuration
3. Multi-Teacher Fusion Mechanism (Adaptation Layer)
4. Distillation Loss Integration