In [None]:
# Cell 2: Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import timm
import os
import numpy as np
import math # For sinusoidal embedding, sqrt
from PIL import Image
from tqdm.notebook import tqdm
import copy
from functools import partial # For default args in transforms
from einops import rearrange, repeat # If using einops

# For AMP
from torch.cuda.amp import GradScaler, autocast

print("Imports complete.")

In [None]:
# Cell 3: Global Hyperparameters & Configuration
# Every name defined here is a module-level global that other cells read directly.

# --- System ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DATA_DIR = "/kaggle/input/officehome/OfficeHome"  # dataset root; holds one sub-directory per domain
MODEL_SAVE_DIR_BASE = "/kaggle/working/models/"  # per-domain model folders are created below this path
os.makedirs(MODEL_SAVE_DIR_BASE, exist_ok=True)
RANDOM_SEED = 42
# Seed the torch and NumPy global RNGs, plus every CUDA device when one is present.
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

# --- Dataset & DataLoader ---
IMAGE_SIZE = 224   # images are resized to IMAGE_SIZE x IMAGE_SIZE by the transforms
BATCH_SIZE = 32
NUM_WORKERS = 2
NUM_CLASSES = 65   # asserted against the number of class folders found in the Art domain

# --- ViT Backbone ---
VIT_MODEL_NAME = 'vit_base_patch16_224'  # name passed to timm.create_model
VIT_EMBED_DIM = 768                      # input width of the classification head and DAD feature size

# --- LoRA ---
LORA_RANK = 16
LORA_ALPHA = 32.0    # LoRA update is scaled by LORA_ALPHA / LORA_RANK
LORA_DROPOUT = 0.05

# --- DAD Module ---
DAD_K_STEPS = 200                     # number of diffusion steps (length of the beta schedule)
DAD_P_THETA_TIMESTEP_EMBED_DIM = 128  # width of the sinusoidal timestep embedding
DAD_P_THETA_HIDDEN_DIM_MULT = 4       # timestep encoder width = embed dim * this multiplier
DAD_P_THETA_MLP_HIDDEN_DIM = 1024     # hidden width of the noise-prediction MLP
DAD_BETA_START = 1e-4                 # first value of the linear beta schedule
DAD_BETA_END = 2e-2                   # last value of the linear beta schedule

# --- Training (General) ---
SOURCE_DOMAIN_NAME = 'Art'
TARGET_DOMAIN_NAMES_ORDERED = ['Clipart', 'Product', 'RealWorld']
ALL_TRAINABLE_DOMAIN_NAMES = [SOURCE_DOMAIN_NAME] + TARGET_DOMAIN_NAMES_ORDERED

# --- Source Domain Training (Art) ---
ART_EPOCHS = 10
ART_LR_LORA_HEAD_LN = 5e-4  # single AdamW learning rate for LoRA, head and LayerNorm parameters

# --- Continual Adaptation (Per Target Domain) ---
ADAPT_MLS_R_ITER = 10
ADAPT_LR_LORA_HEAD_LN = 1e-4
ADAPT_LR_P_THETA = 1e-4
ADAPT_LTR_EPOCHS = 5
EARLY_STOPPING_PATIENCE_ADAPT = 3 # Patience for early stopping

# --- EMA Teacher & Pseudo-Labeling ---
EMA_DECAY = 0.999
PSEUDO_LABEL_THRESHOLD_START = 0.7
PSEUDO_LABEL_THRESHOLD_END = 0.9

# --- FixMatch ---
FIXMATCH_CONF_THRESHOLD = 0.95
FIXMATCH_LAMBDA = 1.0

# --- SHOT ---
SHOT_LAMBDA_COND_ENT = 0.05
SHOT_LAMBDA_ENT_MAX = 0.05

# --- Experience Replay (Optional) ---
REPLAY_BUFFER_SIZE = 2000
REPLAY_BATCH_SIZE_RATIO = 0.25
REPLAY_LAMBDA = 0.1

# --- Domain Classifier (for Robust Inference) ---
DC_HEAD_LR = 1e-3
DC_HEAD_EPOCHS = 10

# --- Robust Inference Pipeline ---
INFER_DOMAIN_CONF_THRESH = 0.7
INFER_EXPERT_CONF_THRESH = 0.6
INFER_STAGE2_EXPERT_CONF_THRESH = 0.5
INFER_K_EXPERTS_FOR_AVG = 3

# --- Autocast Context for AMP ---
# autocast_ctx is a zero-argument factory used as `with autocast_ctx():`.
# On CUDA it opens a float16 torch.amp.autocast region.
if DEVICE.type == 'cuda':
    autocast_ctx = partial(torch.amp.autocast, device_type='cuda', dtype=torch.float16, enabled=True)
else:
    # On CPU no mixed precision is used: contextlib.nullcontext() is a no-op
    # context manager, so the same `with autocast_ctx():` call sites still work.
    import contextlib
    autocast_ctx = contextlib.nullcontext

print(f"Using device: {DEVICE}")
if torch.cuda.is_available(): print(f"CUDA available. GPU: {torch.cuda.get_device_name(0)}")
print(f"All configurations set. Base model save dir: {MODEL_SAVE_DIR_BASE}")

In [None]:
# Cell 4: Transforms Definition
# All pipelines resize to IMAGE_SIZE x IMAGE_SIZE; the two tensor pipelines also
# normalise with the mean/std values written out below.

# --- "Strong" augmentation ---
# Used for the source training split, and intended as the strong FixMatch view.
train_transform_strong = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), # Moderate
    transforms.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.95, 1.05)),
    # Candidate additions for a genuinely strong FixMatch view (currently disabled):
    # transforms.RandAugment(num_ops=2, magnitude=10), # N=2, M=10 from blueprint
    # transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0), # Or Cutout
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# --- "Weak" augmentation ---
# Used for the source validation split, and intended as the weak FixMatch view.
# NOTE(review): this pipeline contains a random horizontal flip, so validation
# metrics computed with it are not deterministic from run to run - confirm that
# this is intended, or use a flip-free pipeline for evaluation.
val_test_transform_weak = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5), # Standard weak aug
    # transforms.RandomCrop(IMAGE_SIZE, padding=int(IMAGE_SIZE*0.125), padding_mode='reflect'), # Optional crop
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Resize only: the result stays a PIL image (no ToTensor, no Normalize).
pil_load_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)) # Just resize, keep as PIL
])

print("Transforms defined: train_transform_strong, val_test_transform_weak, pil_load_transform.")
# TODO: RandAugment and RandomErasing/Cutout are not enabled yet. They can be added from
# torchvision if the installed version provides them, or taken from timm.data.auto_augment.
# Until then, train_transform_strong is only a proxy for a strong FixMatch augmentation.

In [None]:
# Cell 5 (or a new cell for FixMatch Dataset)
class FixMatchOfficeHomeDataset(Dataset):
    """Office-Home split that yields two augmented views of every image.

    The image list is delegated to an ``OfficeHomeDomainDataset`` built with
    ``load_pil=True``, so each item arrives as an untransformed PIL image.
    ``__getitem__`` then returns ``(weak_view, strong_view)``; the label stored
    by the underlying dataset is discarded.

    Args:
        root_dir, domain_name, split_ratios, split_type, random_seed,
        class_to_idx_mapping: Forwarded unchanged to ``OfficeHomeDomainDataset``.
        transform_weak: Callable producing the weak view from a PIL image.
        transform_strong: Callable producing the strong view from a PIL image.
    """

    def __init__(self, root_dir, domain_name, transform_weak, transform_strong,
                 split_ratios=(0.8, 0.1, 0.1), split_type='train',
                 random_seed=RANDOM_SEED, class_to_idx_mapping=None):
        self.transform_weak = transform_weak
        self.transform_strong = transform_strong
        # The wrapped dataset only supplies PIL images; both transforms run here.
        self.base_dataset = OfficeHomeDomainDataset(
            root_dir,
            domain_name,
            transform=None,
            split_ratios=split_ratios,
            split_type=split_type,
            random_seed=random_seed,
            class_to_idx_mapping=class_to_idx_mapping,
            load_pil=True,
        )

    def __len__(self):
        """Number of images in the wrapped split."""
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Return ``(weak_view, strong_view)`` for the image at ``idx``."""
        source_image, _unused_label = self.base_dataset[idx]
        # Weak view first, then strong view (same order of RNG consumption as before).
        weak_view = self.transform_weak(source_image)
        strong_view = self.transform_strong(source_image)
        return weak_view, strong_view

class OfficeHomeDomainDataset(Dataset):
    """One Office-Home domain, split per class into train / val / test.

    Expects the layout ``root_dir/domain_name/<class_name>/<image file>``. Inside
    every class folder the image paths are sorted, shuffled with ``random_seed``
    and then sliced, so the same seed always produces the same disjoint splits.

    Args:
        root_dir: Directory that contains one sub-directory per domain.
        domain_name: Name of the domain sub-directory to read.
        transform: Callable applied to the RGB PIL image when ``load_pil`` is
            False. When None, the image is converted with ``transforms.ToTensor()``.
        split_ratios: Only the first two entries (train and val fractions) are
            read; the 'test' split is whatever remains of each class.
        split_type: 'train', 'val', 'test', or 'all' (every image of the domain).
        random_seed: Seed of the per-class shuffle.
        class_to_idx_mapping: Optional ``{class_name: label}`` mapping shared
            across domains. Class folders missing from it are skipped. When
            None, labels are assigned from the sorted class folder names.
        load_pil: If True, ``__getitem__`` returns the untransformed PIL image.

    Raises:
        ValueError: If ``split_type`` is not one of the supported values.
    """

    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
    _SPLIT_TYPES = ('train', 'val', 'test', 'all')

    def __init__(self, root_dir, domain_name, transform=None, # General transform
                 split_ratios=(0.8, 0.1, 0.1), split_type='train',
                 random_seed=RANDOM_SEED, class_to_idx_mapping=None, load_pil=False):
        # Validate before touching the file system so a typo fails immediately,
        # even when the domain folder contains no usable class.
        if split_type not in self._SPLIT_TYPES:
            raise ValueError("split_type must be 'train', 'val', 'test', or 'all'")

        self.domain_path = os.path.join(root_dir, domain_name)
        self.transform = transform
        self.load_pil = load_pil # If true, returns PIL image and label

        self.images_paths = [] # Store paths
        self.labels = []

        # Only sub-directories count as classes; sorting keeps labels deterministic.
        class_dirs = [
            entry for entry in sorted(os.listdir(self.domain_path))
            if os.path.isdir(os.path.join(self.domain_path, entry))
        ]

        if class_to_idx_mapping is None:
            self.class_to_idx = {name: idx for idx, name in enumerate(class_dirs)}
        else:
            self.class_to_idx = class_to_idx_mapping
        self.idx_to_class = {idx: name for name, idx in self.class_to_idx.items()}

        for class_name in class_dirs:
            class_label_idx = self.class_to_idx.get(class_name)
            if class_label_idx is None:
                continue # Skip if class not in provided mapping

            class_path = os.path.join(self.domain_path, class_name)
            domain_class_images_paths = [
                os.path.join(class_path, img_name)
                for img_name in sorted(os.listdir(class_path))
                if img_name.lower().endswith(self._IMAGE_EXTENSIONS)
            ]

            # A private RandomState yields the same permutation as seeding the
            # global NumPy RNG, but leaves the global random stream untouched.
            np.random.RandomState(random_seed).shuffle(domain_class_images_paths)

            n_total = len(domain_class_images_paths)
            n_train = int(n_total * split_ratios[0])
            n_val = int(n_total * split_ratios[1])

            if split_type == 'train':
                selected_paths = domain_class_images_paths[:n_train]
            elif split_type == 'val':
                selected_paths = domain_class_images_paths[n_train : n_train + n_val]
            elif split_type == 'test':
                selected_paths = domain_class_images_paths[n_train + n_val:]
            else: # 'all': every image, e.g. for unlabeled target data
                selected_paths = domain_class_images_paths

            self.images_paths.extend(selected_paths)
            self.labels.extend([class_label_idx] * len(selected_paths))

    def __len__(self):
        """Number of images selected for this split."""
        return len(self.images_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is PIL when ``load_pil`` is set, else a tensor."""
        img_path = self.images_paths[idx]
        # convert() returns an independent in-memory image, so the file handle
        # can be released as soon as the conversion is done.
        with Image.open(img_path) as img_file:
            image_pil = img_file.convert('RGB')
        label = self.labels[idx]

        if self.load_pil:
            return image_pil, label # Return PIL image and label

        if self.transform:
            image_tensor = self.transform(image_pil)
        else: # If no transform, fall back to a plain tensor conversion
            image_tensor = transforms.ToTensor()(image_pil)
        return image_tensor, label

print("OfficeHomeDomainDataset class defined.")

In [None]:
# Cell 6: Global Class Mapping & Initial Data Check

# Build one class-name -> label mapping from the Art domain and reuse it for
# every domain, so the same class name always gets the same integer label.
print("Attempting to create GLOBAL_CLASS_TO_IDX from Art domain...")
# Bind both globals up front: on any failure they are then consistently None,
# instead of GLOBAL_IDX_TO_CLASS being undefined or holding a rejected mapping.
GLOBAL_CLASS_TO_IDX = None
GLOBAL_IDX_TO_CLASS = None
try:
    temp_art_dataset_for_map = OfficeHomeDomainDataset(DATA_DIR, 'Art', split_type='all')
    assert len(temp_art_dataset_for_map.class_to_idx) == NUM_CLASSES, \
        f"Mismatch: Expected {NUM_CLASSES} classes, found {len(temp_art_dataset_for_map.class_to_idx)} in Art domain."
    # Publish the mapping only after it passed the class-count check.
    GLOBAL_CLASS_TO_IDX = temp_art_dataset_for_map.class_to_idx
    GLOBAL_IDX_TO_CLASS = temp_art_dataset_for_map.idx_to_class
    print(f"GLOBAL_CLASS_TO_IDX created successfully with {len(GLOBAL_CLASS_TO_IDX)} classes.")
    del temp_art_dataset_for_map
except Exception as e:
    # Deliberately non-fatal here: later cells check GLOBAL_CLASS_TO_IDX and stop.
    print(f"ERROR: Could not create GLOBAL_CLASS_TO_IDX from Art domain: {e}")
    print("Please ensure DATA_DIR is correct and Art domain exists. Cannot proceed without class mapping.")

# Sanity check: the source domain must yield non-empty train and val splits.
if GLOBAL_CLASS_TO_IDX:
    try:
        art_train_dataset_check = OfficeHomeDomainDataset(DATA_DIR, SOURCE_DOMAIN_NAME,
                                                          transform=train_transform_strong,
                                                          split_type='train',
                                                          class_to_idx_mapping=GLOBAL_CLASS_TO_IDX)
        art_val_dataset_check = OfficeHomeDomainDataset(DATA_DIR, SOURCE_DOMAIN_NAME,
                                                        transform=val_test_transform_weak,
                                                        split_type='val',
                                                        class_to_idx_mapping=GLOBAL_CLASS_TO_IDX)
        if len(art_train_dataset_check) > 0 and len(art_val_dataset_check) > 0:
            print(f"Successfully loaded check datasets for '{SOURCE_DOMAIN_NAME}': "
                  f"Train size {len(art_train_dataset_check)}, Val size {len(art_val_dataset_check)}")
        else:
            print(f"Warning: Check datasets for '{SOURCE_DOMAIN_NAME}' are empty.")
    except Exception as e:
        print(f"Error loading check datasets for '{SOURCE_DOMAIN_NAME}': {e}")

In [None]:
# Cell 7: ViT Backbone Loading Function
def load_frozen_vit_backbone(model_name=VIT_MODEL_NAME, device=DEVICE):
    """Create a pretrained timm ViT without a classifier, frozen and in eval mode.

    Args:
        model_name: Model name handed to ``timm.create_model``.
        device: Device the model is moved to.

    Returns:
        The model on ``device`` with ``requires_grad`` disabled on every parameter.
    """
    # num_classes=0 asks timm for the model without its classification head.
    backbone = timm.create_model(model_name, pretrained=True, num_classes=0)
    for weight in backbone.parameters():
        weight.requires_grad_(False)
    backbone = backbone.to(device).eval()
    print(f"Loaded and froze ViT backbone: {model_name}")
    return backbone

# Load it once globally
# base_vit_frozen_global = load_frozen_vit_backbone() # Will be used for DC head and as base for LoRA experts
print("ViT backbone loading function defined.")

In [None]:
# Cell 8: LoRALinear and LoRA Injection Function (Updated for alpha, dropout)

class LoRALinear(nn.Module):
    """Frozen copy of a linear layer plus a trainable low-rank (LoRA) update.

    Computes ``F.linear(x, W, b) + (alpha / rank) * B(dropout(A x))``, where ``W``
    and ``b`` are detached copies of the wrapped layer's parameters with
    ``requires_grad=False`` and only ``lora_A`` / ``lora_B`` are trainable.
    ``lora_B`` starts at zero, so a freshly built layer reproduces the wrapped
    layer exactly.

    Args:
        linear_layer: Layer providing ``in_features``, ``out_features``,
            ``weight`` and ``bias`` (e.g. an ``nn.Linear``).
        rank: Inner dimension of the low-rank update; must be positive.
        alpha: Scaling numerator; the update is multiplied by ``alpha / rank``.
        lora_dropout_p: Dropout probability applied between ``A`` and ``B``.

    Raises:
        ValueError: If ``rank`` is not positive.
    """

    def __init__(self, linear_layer, rank, alpha, lora_dropout_p=0.0):
        super().__init__()
        if rank <= 0:
            raise ValueError(f"LoRA rank must be a positive integer, got {rank}")
        self.in_features = linear_layer.in_features
        self.out_features = linear_layer.out_features
        self.rank = rank
        self.alpha = alpha
        self.lora_dropout = nn.Dropout(p=lora_dropout_p)

        # Original weight and bias (frozen copies, independent of the wrapped layer)
        self.weight = nn.Parameter(linear_layer.weight.detach().clone(), requires_grad=False)
        if linear_layer.bias is not None:
            self.bias = nn.Parameter(linear_layer.bias.detach().clone(), requires_grad=False)
        else:
            self.register_parameter('bias', None)

        # LoRA matrices: A maps in_features -> rank, B maps rank -> out_features
        self.lora_A = nn.Parameter(torch.zeros(self.rank, self.in_features))
        self.lora_B = nn.Parameter(torch.zeros(self.out_features, self.rank))

        # A gets a random init, B stays zero so the initial update is exactly zero.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

        self.scaling = self.alpha / self.rank

    def forward(self, x):
        """Apply the frozen layer and add the scaled low-rank update.

        ``x`` has shape ``(..., in_features)``; the result has shape
        ``(..., out_features)``. Every step acts on the last dimension only, so
        1-D, 2-D and ViT-style ``(B, N, D)`` inputs all work without transposes.
        """
        out_original = F.linear(x, self.weight, self.bias)
        lora_hidden = F.linear(x, self.lora_A)            # (..., rank)
        lora_hidden = self.lora_dropout(lora_hidden)      # dropout on the rank-sized activations
        lora_effect = F.linear(lora_hidden, self.lora_B)  # (..., out_features)
        return out_original + lora_effect * self.scaling

    def extra_repr(self):
        return f'in_features={self.in_features}, out_features={self.out_features}, rank={self.rank}, alpha={self.alpha}'

def inject_lora_to_vit_attention(vit_model, rank, alpha, lora_dropout_p=0.0):
    """Replace each block's ``attn.qkv`` linear layer with a ``LoRALinear`` wrapper.

    The model is modified in place and also returned. Only ``attn.qkv`` entries
    that are ``nn.Linear`` instances are replaced; ``attn.proj`` is left as is.
    This function does not change ``requires_grad`` on any existing parameter,
    so the base model should already be frozen by the caller.

    Args:
        vit_model: Model exposing ``blocks``, each with an ``attn.qkv`` attribute.
        rank, alpha, lora_dropout_p: Forwarded to ``LoRALinear``.

    Returns:
        The same ``vit_model`` object.
    """
    injected_count = 0
    for block in vit_model.blocks:
        attention = block.attn
        if not isinstance(attention.qkv, nn.Linear):
            continue
        attention.qkv = LoRALinear(attention.qkv, rank, alpha, lora_dropout_p)
        injected_count += 1

    if injected_count:
        print(f"Injected LoRA (rank={rank}, alpha={alpha}) into {injected_count} attention layers (QKV combined).")
    else:
        print("WARNING: No QKV layers found or replaced with LoRA.")
    return vit_model

print("LoRALinear class and injection function defined.")
# Smoke test: a ViT-shaped (batch, tokens, in_features) input must come out as
# (batch, tokens, out_features). The assert makes a shape regression fail the
# cell instead of only printing a wrong shape.
x_test = torch.randn(2, 197, 768) # B, N, D_in
linear_test = nn.Linear(768, 768*3)
lora_linear_test = LoRALinear(linear_test, rank=LORA_RANK, alpha=LORA_ALPHA, lora_dropout_p=LORA_DROPOUT)
out_test = lora_linear_test(x_test)
print(f"Test LoRA output shape: {out_test.shape}") # Expected (2, 197, 768*3)
assert out_test.shape == (2, 197, 768 * 3), f"Unexpected LoRA output shape: {tuple(out_test.shape)}"

In [None]:
# Cell 9: DomainSpecificHead and LayerNormTrainable function

class DomainSpecificHead(nn.Module):
    """Single linear classifier applied to an already normalised feature vector.

    No normalisation happens inside this module; callers are expected to pass
    features that already went through the ViT's own final norm layer.

    Args:
        in_features: Size of the incoming feature vector.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=VIT_EMBED_DIM, num_classes=NUM_CLASSES):
        super().__init__()
        self.fc = nn.Linear(in_features=in_features, out_features=num_classes)

    def forward(self, x_normed_cls_token):
        """Map the normalised features to class logits."""
        logits = self.fc(x_normed_cls_token)
        return logits

def set_layernorm_affine_trainable(model, trainable=True):
    """Toggle ``requires_grad`` on the weight and bias of every ``nn.LayerNorm``.

    LayerNorm modules without affine parameters (weight/bias of None) are left
    alone. A one-line summary is printed; the reported count covers the touched
    elements only when ``trainable`` is True and is 0 otherwise.

    Args:
        model: Module whose sub-modules are scanned.
        trainable: Value assigned to ``requires_grad``.

    Returns:
        The same ``model`` object.
    """
    ln_param_count = 0
    for module in model.modules():
        if not isinstance(module, nn.LayerNorm):
            continue
        for attr_name in ('weight', 'bias'):
            affine_param = getattr(module, attr_name, None)
            if affine_param is None:
                continue
            affine_param.requires_grad = trainable
            if trainable:
                ln_param_count += affine_param.numel()

    if trainable:
        status = "trainable"
    else:
        status = "frozen"
    print(f"Set LayerNorm affine parameters {status}. Total LN params affected: {ln_param_count:,}")
    return model
    
print("DomainSpecificHead class and set_layernorm_affine_trainable function defined.")

In [None]:
# Cell 10: DAD Module - Timestep Embedding & p_theta Denoiser MLP

class SinusoidalTimestepEmbedding(nn.Module):
    """Parameter-free sinusoidal embedding of integer timesteps.

    For ``half = dim // 2`` frequencies ``f_i = exp(-ln(max_period) * i / half)``
    the output row for timestep ``t`` is ``[cos(t * f_0..), sin(t * f_0..)]``,
    padded with one zero column when ``dim`` is odd.

    Args:
        dim: Width of the returned embedding.
        max_period: Controls the lowest frequency of the table.
    """

    def __init__(self, dim, max_period=10000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period

    def forward(self, timesteps): # timesteps: (B,) or scalar tensor
        """Return a ``(B, dim)`` float embedding; a 0-dim input is treated as ``B == 1``."""
        if timesteps.ndim == 0:
            timesteps = timesteps.unsqueeze(0)
        half_dim = self.dim // 2
        # Build the frequency table directly on the input's device, which avoids
        # a CPU allocation plus host-to-device copy on every forward pass.
        freqs = torch.exp(
            -math.log(self.max_period)
            * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
            / half_dim
        )
        args = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2: # Odd dimension: pad with a single zero column
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding # Shape: (B, dim)

class DAD_P_Theta_MLP(nn.Module):
    """MLP denoiser: predicts the noise in a feature vector at a given timestep.

    The timestep goes through a sinusoidal embedding and a two-layer encoder,
    is concatenated with the noisy features along the last dimension, and the
    result is mapped back to ``feature_dim`` by a two-layer MLP.

    Args:
        feature_dim: Size of the feature vectors (input and output width).
        timestep_embed_dim: Width of the sinusoidal timestep embedding.
        ts_embed_hidden_mult: The timestep encoder's hidden and output width is
            ``timestep_embed_dim * ts_embed_hidden_mult``.
        mlp_hidden_dim: Hidden width of the noise-prediction MLP.
    """

    def __init__(self, feature_dim=VIT_EMBED_DIM,
                 timestep_embed_dim=DAD_P_THETA_TIMESTEP_EMBED_DIM,
                 ts_embed_hidden_mult=DAD_P_THETA_HIDDEN_DIM_MULT,
                 mlp_hidden_dim=DAD_P_THETA_MLP_HIDDEN_DIM):
        super().__init__()
        self.timestep_embed_dim_actual = timestep_embed_dim

        self.timestep_encoder = nn.Sequential(
            SinusoidalTimestepEmbedding(timestep_embed_dim),
            nn.Linear(timestep_embed_dim, timestep_embed_dim * ts_embed_hidden_mult),
            nn.GELU(),
            nn.Linear(timestep_embed_dim * ts_embed_hidden_mult, timestep_embed_dim * ts_embed_hidden_mult) # Projected dim
        )

        combined_dim = feature_dim + (timestep_embed_dim * ts_embed_hidden_mult)

        self.mlp = nn.Sequential(
            nn.Linear(combined_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, feature_dim) # Predicts epsilon (noise)
        )

    def forward(self, noisy_features, timesteps_long):
        """Predict the noise contained in ``noisy_features``.

        Args:
            noisy_features: ``(B, feature_dim)`` tensor.
            timesteps_long: ``(B,)`` timesteps, or a scalar / single-element
                tensor that is shared by the whole batch.

        Returns:
            Tensor with the same shape as ``noisy_features``.
        """
        # The embedding is built on the timestep tensor's device, so align it first.
        timesteps_long = timesteps_long.to(noisy_features.device)

        ts_embedding = self.timestep_encoder(timesteps_long) # (B or 1, ts_embed_dim * mult)

        # The timestep encoder always returns a 2-D tensor: a scalar timestep
        # comes back as a single row. Broadcast that row over the batch so the
        # concatenation below does not fail on mismatched batch sizes.
        if (noisy_features.ndim > 1
                and ts_embedding.shape[0] == 1
                and noisy_features.shape[0] != 1):
            ts_embedding = ts_embedding.expand(noisy_features.shape[0], -1)

        combined_input = torch.cat([noisy_features, ts_embedding], dim=-1)
        predicted_noise = self.mlp(combined_input)
        return predicted_noise

print("DAD p_theta MLP denoiser and Timestep Embedding defined.")
# Smoke test: the denoiser must return one noise vector per input feature
# vector. The assert makes a shape regression fail the cell instead of only
# printing a wrong shape.
p_theta_test = DAD_P_Theta_MLP().to(DEVICE)
test_noisy_feat = torch.randn(BATCH_SIZE, VIT_EMBED_DIM).to(DEVICE)
test_timesteps = torch.randint(0, DAD_K_STEPS, (BATCH_SIZE,), dtype=torch.long).to(DEVICE)
pred_noise_test = p_theta_test(test_noisy_feat, test_timesteps)
print(f"p_theta test output shape: {pred_noise_test.shape}") # Expected (BATCH_SIZE, VIT_EMBED_DIM)
assert pred_noise_test.shape == (BATCH_SIZE, VIT_EMBED_DIM), \
    f"Unexpected p_theta output shape: {tuple(pred_noise_test.shape)}"

In [None]:
# Cell 11: DAD Diffusion Utilities (q_sample, q_reconstruct, schedules)

def make_dad_schedule(beta_start=DAD_BETA_START, beta_end=DAD_BETA_END, num_steps=DAD_K_STEPS, device=DEVICE):
    """Build a linear beta schedule and its derived quantities.

    Args:
        beta_start: First beta value.
        beta_end: Last beta value.
        num_steps: Number of schedule entries.
        device: Device on which the tensors are created.

    Returns:
        ``(betas, alphas, alphas_cumprod)``: three float32 tensors of length
        ``num_steps`` with ``alphas = 1 - betas`` and ``alphas_cumprod`` the
        running product of ``alphas``.
    """
    betas = torch.linspace(
        start=beta_start,
        end=beta_end,
        steps=num_steps,
        dtype=torch.float32,
        device=device,
    )
    alphas = 1.0 - betas
    return betas, alphas, alphas.cumprod(dim=0)

# Build the DAD schedules once as module-level tensors; q_sample_dad and
# q_reconstruct_dad index into DAD_ALPHAS_CUMPROD.
# DEVICE is set in the configuration cell and is only 'cuda' when CUDA is
# available, so the CPU fallback below is purely defensive (for example if
# DEVICE was overridden by hand).
if DEVICE.type == 'cuda' and not torch.cuda.is_available(): # CUDA requested but not available
    print("Warning: DEVICE is cuda but not available. DAD schedules on CPU.")
    _sched_device = torch.device('cpu')
else:
    _sched_device = DEVICE

DAD_BETAS, DAD_ALPHAS, DAD_ALPHAS_CUMPROD = make_dad_schedule(device=_sched_device)

def q_sample_dad(x_start, k_indices, noise=None, alphas_cumprod=None):
    """Forward diffusion: noise ``x_start`` to step ``k`` in closed form.

    Returns ``sqrt(abar_k) * x_start + sqrt(1 - abar_k) * noise`` where
    ``abar_k = alphas_cumprod[k_indices]``.

    Args:
        x_start: Clean tensor; the first dimension is the batch when ``ndim > 1``.
        k_indices: 0-indexed timesteps, shape ``(B,)`` or a 0-dim tensor.
        noise: Noise with the shape of ``x_start``; standard normal noise is
            drawn when None.
        alphas_cumprod: Optional 1-D cumulative alpha schedule. Defaults to the
            module-level ``DAD_ALPHAS_CUMPROD``.

    Returns:
        Noised tensor with the shape of ``x_start``.
    """
    if alphas_cumprod is None:
        alphas_cumprod = DAD_ALPHAS_CUMPROD
    if noise is None:
        noise = torch.randn_like(x_start)

    # Indices must be long and live on the schedule's device to index into it.
    k_indices = k_indices.long().to(alphas_cumprod.device)
    # Gather once, then move the coefficients next to the data so a schedule
    # kept on another device does not break the arithmetic below.
    alpha_bar_k = alphas_cumprod[k_indices].to(x_start.device)

    sqrt_alphas_cumprod_t = torch.sqrt(alpha_bar_k)
    sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1.0 - alpha_bar_k)

    # Reshape for broadcasting: (B,) -> (B, 1, ..., 1) to match x_start's rank.
    if x_start.ndim > 1 and sqrt_alphas_cumprod_t.ndim == 1:
        broadcast_shape = (-1,) + (1,) * (x_start.ndim - 1)
        sqrt_alphas_cumprod_t = sqrt_alphas_cumprod_t.view(broadcast_shape)
        sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod_t.view(broadcast_shape)

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

def q_reconstruct_dad(x_k, epsilon_theta_hat, k_indices, alphas_cumprod=None):
    """Estimate the clean tensor from a noised one and a noise prediction.

    Inverts the closed-form forward step:
    ``x0_hat = (x_k - sqrt(1 - abar_k) * epsilon_theta_hat) / sqrt(abar_k)``
    with ``abar_k = alphas_cumprod[k_indices]``.

    Args:
        x_k: Noised tensor; the first dimension is the batch when ``ndim > 1``.
        epsilon_theta_hat: Predicted noise, same shape as ``x_k``.
        k_indices: 0-indexed timesteps, shape ``(B,)`` or a 0-dim tensor.
        alphas_cumprod: Optional 1-D cumulative alpha schedule. Defaults to the
            module-level ``DAD_ALPHAS_CUMPROD``.

    Returns:
        Estimate of the clean tensor with the shape of ``x_k``.
    """
    if alphas_cumprod is None:
        alphas_cumprod = DAD_ALPHAS_CUMPROD

    # Indices must be long and live on the schedule's device to index into it.
    k_indices = k_indices.long().to(alphas_cumprod.device)
    # Gather once, then move the coefficients next to the data so a schedule
    # kept on another device does not break the arithmetic below.
    alpha_bar_k = alphas_cumprod[k_indices].to(x_k.device)

    sqrt_alphas_cumprod_t = torch.sqrt(alpha_bar_k)
    sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1.0 - alpha_bar_k)

    # Reshape for broadcasting: (B,) -> (B, 1, ..., 1) to match x_k's rank.
    if x_k.ndim > 1 and sqrt_alphas_cumprod_t.ndim == 1:
        broadcast_shape = (-1,) + (1,) * (x_k.ndim - 1)
        sqrt_alphas_cumprod_t = sqrt_alphas_cumprod_t.view(broadcast_shape)
        sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod_t.view(broadcast_shape)

    x0_hat = (x_k - sqrt_one_minus_alphas_cumprod_t * epsilon_theta_hat) / sqrt_alphas_cumprod_t
    return x0_hat

print("DAD diffusion utilities (schedules, q_sample, q_reconstruct) defined.")
print(f"DAD schedules (betas, alphas, alphas_cumprod) created on device: {DAD_BETAS.device}")

In [None]:
# Cell 12: Utility Functions (EMA Update, Replay Buffer (optional))

def update_ema_teacher_components(student_vit, student_head, teacher_vit, teacher_head, decay):
    """Blend student parameters into the teacher with an exponential moving average.

    Each updated teacher parameter becomes ``decay * teacher + (1 - decay) * student``,
    in place and without recording gradients. Parameters are paired purely by
    their position in ``.parameters()``.

    * ViT: only positions where the *student* parameter has ``requires_grad``
      set are updated; all other teacher parameters are left untouched.
    * Head: every parameter is updated.

    A warning is printed for each updated teacher parameter that itself has
    ``requires_grad=True``.
    """
    student_share = 1 - decay

    with torch.no_grad():
        # ViT: follow only what the student actually trains.
        vit_pairs = zip(student_vit.parameters(), teacher_vit.parameters())
        for stud_param, teach_param in vit_pairs:
            if not stud_param.requires_grad:
                continue
            if teach_param.requires_grad:
                print("Warning: Teacher ViT parameter has requires_grad=True during EMA update.")
            teach_param.data.mul_(decay)
            teach_param.data.add_(stud_param.data, alpha=student_share)

        # Head: every parameter takes part in the average.
        head_pairs = zip(student_head.parameters(), teacher_head.parameters())
        for stud_param, teach_param in head_pairs:
            if teach_param.requires_grad:
                print("Warning: Teacher Head parameter has requires_grad=True during EMA update.")
            teach_param.data.mul_(decay)
            teach_param.data.add_(stud_param.data, alpha=student_share)

print("Utility functions (EMA update, Replay Buffer) defined.")

# Optional: Replay Buffer Class
class ExperienceReplayBuffer:
    """Fixed-capacity ring buffer of (image, label) tensor pairs kept on the CPU.

    Items are appended until ``buffer_size`` is reached; after that the oldest
    slot is overwritten. ``sample`` moves the drawn items to ``device``.

    Args:
        buffer_size: Maximum number of stored pairs. A non-positive value
            disables storage: ``add`` becomes a no-op.
        device: Device that sampled batches are moved to.
    """

    def __init__(self, buffer_size, device=DEVICE):
        self.buffer_size = buffer_size
        self.device = device
        self.buffer_images = []
        self.buffer_labels = []
        self.position = 0  # next slot to overwrite once the buffer is full

    def add(self, images_tensor, labels_tensor): # Expects tensors
        """Store every (image, label) pair of a batch, indexed along dimension 0."""
        if self.buffer_size <= 0:
            # Nothing can be stored; avoids the modulo-by-zero / index error below.
            return
        batch_size = images_tensor.size(0)
        for i in range(batch_size):
            # detach(): never keep an autograd graph alive inside the buffer.
            # copy=True: stored items own their memory even when the batch is
            # already on the CPU, so later in-place edits of the batch cannot
            # alter the buffer. Stored on CPU to save GPU VRAM.
            img = images_tensor[i].detach().to('cpu', copy=True)
            lbl = labels_tensor[i].detach().to('cpu', copy=True)
            if len(self.buffer_images) < self.buffer_size:
                self.buffer_images.append(img)
                self.buffer_labels.append(lbl)
            else:
                self.buffer_images[self.position] = img
                self.buffer_labels[self.position] = lbl
            self.position = (self.position + 1) % self.buffer_size

    def sample(self, batch_size):
        """Draw up to ``batch_size`` distinct items.

        Returns:
            ``(images, labels)`` stacked on ``self.device`` with labels cast to
            long, or ``(None, None)`` when the buffer is empty. When fewer than
            ``batch_size`` items are stored, all of them are returned.
        """
        available = len(self.buffer_images)
        if available == 0:
            return None, None
        n_draw = min(batch_size, available)
        indices = np.random.choice(available, n_draw, replace=False)

        sampled_images = torch.stack([self.buffer_images[i] for i in indices]).to(self.device)
        sampled_labels = torch.stack([self.buffer_labels[i] for i in indices]).to(self.device)
        return sampled_images, sampled_labels.long()

    def __len__(self):
        """Number of pairs currently stored."""
        return len(self.buffer_images)

print("Utility functions (EMA update, Replay Buffer) defined.")

# Part 3 - Source Domain (Art) Supervised Training

In [None]:
# Cell 13: Source Domain - Dataset & DataLoader (Art)

print(f"\n--- Preparing Source Domain: {SOURCE_DOMAIN_NAME} ---")

# Folder for this domain's saved models, created under MODEL_SAVE_DIR_BASE.
source_domain_model_save_dir = os.path.join(MODEL_SAVE_DIR_BASE, SOURCE_DOMAIN_NAME)
os.makedirs(source_domain_model_save_dir, exist_ok=True)

# The shared class mapping is None when the class-mapping cell failed.
if GLOBAL_CLASS_TO_IDX is None:
    raise RuntimeError("GLOBAL_CLASS_TO_IDX is not defined. Cannot proceed with source domain training.")

# Train and val use the dataset's default seed and split ratios, so they are
# disjoint slices of the same per-class shuffle.
source_train_dataset = OfficeHomeDomainDataset(
    DATA_DIR, SOURCE_DOMAIN_NAME,
    transform=train_transform_strong, # Use strong aug for source training too
    split_type='train',
    class_to_idx_mapping=GLOBAL_CLASS_TO_IDX
)
# NOTE(review): val_test_transform_weak includes a random horizontal flip, so
# validation inputs differ from epoch to epoch - confirm this is intended.
source_val_dataset = OfficeHomeDomainDataset(
    DATA_DIR, SOURCE_DOMAIN_NAME,
    transform=val_test_transform_weak,
    split_type='val',
    class_to_idx_mapping=GLOBAL_CLASS_TO_IDX
)

# Training batches are shuffled and the last incomplete batch is dropped.
source_train_loader = DataLoader(
    source_train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=True, drop_last=True
)
# Validation keeps a fixed order and every sample.
source_val_loader = DataLoader(
    source_val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True
)

print(f"Source domain '{SOURCE_DOMAIN_NAME}': Train size {len(source_train_dataset)}, Val size {len(source_val_dataset)}")
print(f"Train loader batches: {len(source_train_loader)}, Val loader batches: {len(source_val_loader)}")

In [None]:
# Cell 14: Source Domain - Model Instantiation & Optimizer

# 1. Load base frozen ViT
base_vit_source_train = load_frozen_vit_backbone(device=DEVICE) # Fresh frozen backbone

# 2. Create a new LoRA-adapted ViT for the source domain.
#    Work on a deep copy so the frozen base above stays unmodified.
source_vit_lora = copy.deepcopy(base_vit_source_train) # Start with frozen base
source_vit_lora = inject_lora_to_vit_attention(source_vit_lora, rank=LORA_RANK, alpha=LORA_ALPHA, lora_dropout_p=LORA_DROPOUT)
source_vit_lora = set_layernorm_affine_trainable(source_vit_lora, trainable=True) # Make LN affine trainable
source_vit_lora = source_vit_lora.to(DEVICE) # also moves the newly created LoRA matrices

# 3. Instantiate Source-Specific Head
source_head = DomainSpecificHead(in_features=VIT_EMBED_DIM, num_classes=NUM_CLASSES).to(DEVICE)

# 4. Optimizer for Source LoRA parameters, Source Head, and LayerNorm affine parameters
params_to_train_source = []
for name, param in source_vit_lora.named_parameters():
    if param.requires_grad: # LoRA A/B and LayerNorm affine params
        params_to_train_source.append(param)
params_to_train_source.extend(list(source_head.parameters())) # All head params

optimizer_source = optim.AdamW(params_to_train_source, lr=ART_LR_LORA_HEAD_LN, weight_decay=0.05)
criterion_source = nn.CrossEntropyLoss(label_smoothing=0.1)

# LR Scheduler (Poly Decay): lr factor = (1 - step / total_steps) ** 0.9.
# The scheduler is stepped once per training batch.
num_total_steps_source = ART_EPOCHS * len(source_train_loader)
scheduler_source = optim.lr_scheduler.LambdaLR(
    optimizer_source,
    # max(1, ...) avoids a division by zero when the loader is empty.
    # max(0.0, ...) keeps the base non-negative once `step` passes the planned
    # total (e.g. the training cell is re-run); a negative base raised to 0.9
    # would otherwise produce a complex number.
    lr_lambda=lambda step: max(0.0, 1.0 - step / max(1, num_total_steps_source)) ** 0.9 # Poly power 0.9
)

# Loss scaling for float16 autocast; disabled (pass-through) when not on CUDA.
grad_scaler_source = GradScaler(enabled=(DEVICE.type == 'cuda'))

print(f"Source expert models (LoRA ViT, Head) instantiated for '{SOURCE_DOMAIN_NAME}'.")
trainable_params_count_source = sum(p.numel() for p in params_to_train_source)
print(f"Total trainable parameters for source expert: {trainable_params_count_source:,}")

In [None]:
# Cell 15: Source Domain - Training Loop
#
# Supervised training of the source expert (LoRA-ViT + head) for ART_EPOCHS epochs, with a
# validation pass after every epoch. Whenever validation accuracy improves, both state dicts
# are written to source_domain_model_save_dir; at the end of the cell the best checkpoint is
# loaded back and both modules are put in eval mode.

print(f"\n--- Training Source Expert ({SOURCE_DOMAIN_NAME}) ---")
best_source_val_acc = 0.0  # kept as a plain Python float throughout

source_vit_best_path = os.path.join(source_domain_model_save_dir, f"{SOURCE_DOMAIN_NAME.lower()}_vit_lora_best.pth")
source_head_best_path = os.path.join(source_domain_model_save_dir, f"{SOURCE_DOMAIN_NAME.lower()}_head_best.pth")
source_checkpoint_saved_this_run = False  # True once this run has written a best checkpoint

for epoch in range(ART_EPOCHS):
    source_vit_lora.train()
    source_head.train()

    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    progress_bar = tqdm(source_train_loader, desc=f"Epoch {epoch+1}/{ART_EPOCHS} [{SOURCE_DOMAIN_NAME} Train]")
    for images, labels in progress_bar:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        optimizer_source.zero_grad()

        with autocast_ctx():
            # Features from LoRA-ViT
            all_features = source_vit_lora.forward_features(images)  # Shape: (B, Num_Tokens, Embed_Dim)
            cls_features = all_features[:, 0]  # Shape: (B, Embed_Dim)

            # Apply ViT's own final norm before the head
            normed_cls_features = source_vit_lora.norm(cls_features)  # Shape: (B, Embed_Dim)
            logits = source_head(normed_cls_features)  # Shape: (B, NUM_CLASSES)

            loss = criterion_source(logits, labels)

        grad_scaler_source.scale(loss).backward()
        grad_scaler_source.step(optimizer_source)
        grad_scaler_source.update()
        scheduler_source.step()  # LR schedule advances once per batch

        # Accumulate sample-weighted loss and the number of correct predictions as Python
        # numbers, so the counters have the same type whether or not any batch was processed.
        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(logits, 1)
        correct_predictions += torch.sum(preds == labels).item()
        total_samples += images.size(0)

        progress_bar.set_postfix(
            loss=loss.item(),
            acc=correct_predictions / total_samples if total_samples > 0 else 0.0,
            lr=optimizer_source.param_groups[0]['lr'],
        )

    # max(..., 1) keeps an empty training loader from raising ZeroDivisionError.
    epoch_loss = running_loss / max(total_samples, 1)
    epoch_acc = correct_predictions / max(total_samples, 1)
    print(f"Epoch {epoch+1} Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, LR: {optimizer_source.param_groups[0]['lr']:.6f}")

    # Validation
    source_vit_lora.eval()
    source_head.eval()
    val_loss = 0.0
    val_correct_predictions = 0
    val_total_samples = 0
    with torch.no_grad():
        for images, labels in tqdm(source_val_loader, desc=f"Epoch {epoch+1}/{ART_EPOCHS} [{SOURCE_DOMAIN_NAME} Val]", leave=False):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            with autocast_ctx():
                all_features_val = source_vit_lora.forward_features(images)  # Shape: (B, Num_Tokens, Embed_Dim)
                cls_features_val = all_features_val[:, 0]  # Shape: (B, Embed_Dim)
                normed_cls_features_val = source_vit_lora.norm(cls_features_val)  # Shape: (B, Embed_Dim)
                logits = source_head(normed_cls_features_val)  # Shape: (B, NUM_CLASSES)
                loss = criterion_source(logits, labels)

            val_loss += loss.item() * images.size(0)
            _, preds = torch.max(logits, 1)
            val_correct_predictions += torch.sum(preds == labels).item()
            val_total_samples += images.size(0)

    # Same empty-loader guard as for the training statistics.
    epoch_val_loss = val_loss / max(val_total_samples, 1)
    epoch_val_acc = val_correct_predictions / max(val_total_samples, 1)
    print(f"Epoch {epoch+1} Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

    if epoch_val_acc > best_source_val_acc:
        best_source_val_acc = epoch_val_acc
        torch.save(source_vit_lora.state_dict(), source_vit_best_path)
        torch.save(source_head.state_dict(), source_head_best_path)
        source_checkpoint_saved_this_run = True
        print(f"    -> New best Val Acc: {best_source_val_acc:.4f}. Models saved.")

print(f"\nSource domain '{SOURCE_DOMAIN_NAME}' training finished. Best Val Acc: {best_source_val_acc:.4f}")

# Load best models for subsequent use. The files only exist if some epoch improved on the
# initial best of 0.0 (or a previous run left them behind), so the load is guarded the same
# way as in the target-domain adaptation cell instead of failing with FileNotFoundError.
if os.path.exists(source_vit_best_path) and os.path.exists(source_head_best_path):
    if not source_checkpoint_saved_this_run:
        print("Warning: No checkpoint was written during this run; loading checkpoint files already present on disk.")
    # map_location makes the load independent of the device the checkpoint was saved from.
    source_vit_lora.load_state_dict(torch.load(source_vit_best_path, map_location=DEVICE))
    source_head.load_state_dict(torch.load(source_head_best_path, map_location=DEVICE))
    source_vit_lora.eval()
    source_head.eval()
    print("Best source expert models loaded and set to eval mode.")
else:
    source_vit_lora.eval()
    source_head.eval()
    print(f"Warning: No best saved models found for {SOURCE_DOMAIN_NAME}. Using last state of source models (set to eval mode).")

In [None]:
# Cell 15b: Baseline Validation of Source Expert on All Domains
#
# Scores the source expert, as it stands, on the 'val' split of every domain listed in
# ALL_TRAINABLE_DOMAIN_NAMES. Accuracies are collected in baseline_accuracies_source_expert
# as percentages; -1.0 marks a domain whose dataset/loader could not be built.

print(f"\n\n--- Baseline Validation of Source Expert ('{SOURCE_DOMAIN_NAME}') on All Domains ---")

# Both parts of the expert must exist in the notebook namespace and be non-None.
_source_expert_missing = any(
    globals().get(required_name) is None
    for required_name in ('source_vit_lora', 'source_head')
)

if _source_expert_missing:
    print(f"Warning: Source expert ({SOURCE_DOMAIN_NAME}) models not available. Skipping baseline validation.")
else:
    # Evaluation only: put both modules in eval mode before scoring.
    source_vit_lora.eval()
    source_head.eval()

    baseline_accuracies_source_expert = {}
    # Separate criterion instance for this cell (same label smoothing value of 0.1).
    criterion_val_baseline = nn.CrossEntropyLoss(label_smoothing=0.1)

    for domain_name_eval in ALL_TRAINABLE_DOMAIN_NAMES:
        print(f"\n  Validating Source Expert ('{SOURCE_DOMAIN_NAME}') on '{domain_name_eval}' domain...")

        # Any failure while building the dataset or loader is recorded as -1.0 and skipped.
        try:
            val_dataset_current_domain = OfficeHomeDomainDataset(
                DATA_DIR, domain_name_eval,
                transform=val_test_transform_weak,
                split_type='val',
                class_to_idx_mapping=GLOBAL_CLASS_TO_IDX
            )
            if not len(val_dataset_current_domain):
                print(f"    Warning: Validation dataset for '{domain_name_eval}' is empty. Skipping.")
                baseline_accuracies_source_expert[domain_name_eval] = 0.0
                continue

            val_loader_current_domain = DataLoader(
                val_dataset_current_domain,
                batch_size=BATCH_SIZE, shuffle=False,
                num_workers=NUM_WORKERS, pin_memory=True
            )
        except Exception as e:
            print(f"    ERROR: Could not load validation dataset for '{domain_name_eval}': {e}. Skipping.")
            baseline_accuracies_source_expert[domain_name_eval] = -1.0
            continue

        # Running totals over this domain's validation set.
        val_loss_domain, val_total_samples_domain = 0.0, 0
        val_correct_predictions_domain = 0

        with torch.no_grad():
            for images, labels in tqdm(val_loader_current_domain, desc=f"Val on {domain_name_eval}", leave=False):
                images = images.to(DEVICE)
                labels = labels.to(DEVICE)

                with autocast_ctx():
                    # LoRA-ViT -> first token -> ViT final norm -> source head
                    cls_features = source_vit_lora.forward_features(images)[:, 0]
                    logits = source_head(source_vit_lora.norm(cls_features))
                    loss = criterion_val_baseline(logits, labels)

                samples_in_batch = images.size(0)
                val_loss_domain += loss.item() * samples_in_batch
                preds = torch.max(logits, 1)[1]
                val_correct_predictions_domain += torch.sum(preds == labels.data)
                val_total_samples_domain += samples_in_batch

        if val_total_samples_domain <= 0:
            print(f"    No samples processed for '{domain_name_eval}' validation.")
            baseline_accuracies_source_expert[domain_name_eval] = 0.0
        else:
            epoch_val_loss_domain = val_loss_domain / val_total_samples_domain
            epoch_val_acc_domain = val_correct_predictions_domain.double() / val_total_samples_domain
            # Stored (and printed) as a percentage.
            baseline_accuracies_source_expert[domain_name_eval] = epoch_val_acc_domain.item() * 100
            print(f"    '{domain_name_eval}' Val Loss: {epoch_val_loss_domain:.4f}, Val Acc: {epoch_val_acc_domain.item()*100:.2f}%")

    # --- Print Summary of Baseline Accuracies ---
    print("\n--- Source Expert Baseline Performance Summary ---")
    avg_baseline_acc = 0
    count_valid_domains = 0
    for domain, acc in baseline_accuracies_source_expert.items():
        if acc == -1.0:  # sentinel written when the dataset could not be loaded
            print(f"  Accuracy of '{SOURCE_DOMAIN_NAME}' expert on '{domain}': ERROR (Dataset issue)")
            continue
        print(f"  Accuracy of '{SOURCE_DOMAIN_NAME}' expert on '{domain}': {acc:.2f}%")
        # The average covers target domains only (source-on-source is excluded).
        if domain != SOURCE_DOMAIN_NAME:
            avg_baseline_acc += acc
            count_valid_domains += 1

    if count_valid_domains <= 0:
        print("  Could not calculate average target domain accuracy.")
    else:
        avg_target_acc = avg_baseline_acc / count_valid_domains
        print(f"  Average Target Domain Accuracy (Source Expert): {avg_target_acc:.2f}%")

    # Sanity check (optional, based on expected performance)
    # if SOURCE_DOMAIN_NAME in baseline_accuracies_source_expert and baseline_accuracies_source_expert[SOURCE_DOMAIN_NAME] < 10.0: # Very low
    #     print(f"WARNING: Source expert validation accuracy on its own domain ('{SOURCE_DOMAIN_NAME}') is very low ({baseline_accuracies_source_expert[SOURCE_DOMAIN_NAME]:.2f}%). Training might have issues.")

In [None]:
# Cell 16: Prepare for Continual Learning - Store Experts
#
# Sets up the state shared by the adaptation loop:
#   adapted_experts            - filled per target domain with that domain's best ViT + head
#   current_expert_*           - the expert the next adaptation stage is initialised from
#   base_vit_frozen_global     - frozen ViT backbone without LoRA
#   global_replay_buffer       - optional experience replay, disabled here

# Best trained expert (LoRA ViT + Head) per adapted domain; empty until the adaptation loop runs.
# (The source expert itself is not stored here; it can be added to `all_task_experts`
# later, in the inference setup cell.)
adapted_experts = dict()

# The first adaptation stage starts from the source expert, which was already trained and
# reloaded as "best". CPU-side copies are kept so they do not occupy GPU memory while idle.
current_expert_domain_name = SOURCE_DOMAIN_NAME
current_expert_vit = copy.deepcopy(source_vit_lora)
current_expert_vit = current_expert_vit.cpu()
current_expert_head = copy.deepcopy(source_head)
current_expert_head = current_expert_head.cpu()

# Global frozen ViT backbone (no LoRA): used for the DC head and as the base that new
# LoRA experts are built on.
base_vit_frozen_global = load_frozen_vit_backbone(device=DEVICE)

# Replay buffer stays off by default for simplicity. To enable:
# global_replay_buffer = ExperienceReplayBuffer(REPLAY_BUFFER_SIZE, device=DEVICE) if REPLAY_BUFFER_SIZE > 0 else None
global_replay_buffer = None

print(f"Preparation for continual learning complete. Current expert: '{current_expert_domain_name}'.")
print("Base frozen ViT ('base_vit_frozen_global') loaded for future use.")

In [None]:
# Cell 17: Continual Adaptation Loop for Target Domains
#
# For every target domain in TARGET_DOMAIN_NAMES_ORDERED:
#   1. build the data loaders (labelled previous domain, target train views, target val);
#   2. warm-start a student LoRA-ViT + head from the previous expert and clone an EMA teacher;
#   3. pre-train the DAD noise predictor p_theta on target CLS features (LTR);
#   4. run up to DAD_K_STEPS adaptation steps. Each step performs ADAPT_MLS_R_ITER alternating
#      p_theta / student updates on previous-domain data, then one student update on target
#      data (EMA pseudo-labels + FixMatch + SHOT), with periodic validation / early stopping;
#   5. reload the best-accuracy checkpoint and promote it to "current expert".
#
# Requires current_expert_vit, current_expert_head and current_expert_domain_name
# to be initialized from the source domain training (after Cell 16).

for target_domain_idx, target_domain_name_current_loop in enumerate(TARGET_DOMAIN_NAMES_ORDERED):

    TARGET_DOMAIN_NAME_CURRENT = target_domain_name_current_loop

    # The stage always starts from whatever expert the previous stage left behind; a warning
    # is printed if that is not the domain we expect at this position in the sequence.
    expected_previous_domain_name = (
        SOURCE_DOMAIN_NAME if target_domain_idx == 0
        else TARGET_DOMAIN_NAMES_ORDERED[target_domain_idx - 1]
    )
    if current_expert_domain_name != expected_previous_domain_name:
        # This case should ideally not happen if current_expert_domain_name is updated correctly
        print(f"Warning: Mismatch or unexpected previous domain. current_expert_domain_name='{current_expert_domain_name}', expected previous for '{TARGET_DOMAIN_NAME_CURRENT}'")
    PREVIOUS_DOMAIN_NAME = current_expert_domain_name

    print(f"\n\n=== Starting Adaptation: {PREVIOUS_DOMAIN_NAME} -> {TARGET_DOMAIN_NAME_CURRENT} ===")

    target_model_save_dir = os.path.join(MODEL_SAVE_DIR_BASE, TARGET_DOMAIN_NAME_CURRENT)
    os.makedirs(target_model_save_dir, exist_ok=True)

    # --- 1. Datasets & DataLoaders ---
    print(f"  Loading datasets for Previous ('{PREVIOUS_DOMAIN_NAME}') and Target ('{TARGET_DOMAIN_NAME_CURRENT}')...")
    previous_domain_train_dataset = OfficeHomeDomainDataset(
        DATA_DIR, PREVIOUS_DOMAIN_NAME, transform=train_transform_strong,
        split_type='train', class_to_idx_mapping=GLOBAL_CLASS_TO_IDX
    )
    previous_domain_train_loader = DataLoader(
        previous_domain_train_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, pin_memory=True, drop_last=True
    )

    # Weakly transformed target images; the labels this loader yields are ignored below.
    target_unlabeled_loader_for_ema_shot = DataLoader(
        OfficeHomeDomainDataset(DATA_DIR, TARGET_DOMAIN_NAME_CURRENT, transform=val_test_transform_weak,
                                split_type='train', class_to_idx_mapping=GLOBAL_CLASS_TO_IDX),
        batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True
    )
    # Yields (weak view, strong view) pairs, as unpacked in the FixMatch section below.
    fixmatch_target_dataset = FixMatchOfficeHomeDataset(
        DATA_DIR, TARGET_DOMAIN_NAME_CURRENT,
        transform_weak=val_test_transform_weak,
        transform_strong=train_transform_strong,
        split_type='train',
        class_to_idx_mapping=GLOBAL_CLASS_TO_IDX
    )
    fixmatch_target_loader = DataLoader(
        fixmatch_target_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, pin_memory=True, drop_last=True
    )
    target_val_dataset = OfficeHomeDomainDataset(
        DATA_DIR, TARGET_DOMAIN_NAME_CURRENT, transform=val_test_transform_weak,
        split_type='val', class_to_idx_mapping=GLOBAL_CLASS_TO_IDX
    )
    target_val_loader = DataLoader(target_val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    print(f"  Previous domain ('{PREVIOUS_DOMAIN_NAME}') train loader: {len(previous_domain_train_loader)} batches.")
    print(f"  Target domain ('{TARGET_DOMAIN_NAME_CURRENT}'):")
    print(f"    EMA/SHOT loader dataset size: {len(target_unlabeled_loader_for_ema_shot.dataset)}")
    print(f"    FixMatch loader dataset size: {len(fixmatch_target_loader.dataset)}")
    print(f"    Val dataset size: {len(target_val_dataset)}")

    # --- 2. Model Initialization ---
    print(f"  Initializing Student and Teacher models for '{TARGET_DOMAIN_NAME_CURRENT}'...")
    student_vit = copy.deepcopy(base_vit_frozen_global)
    student_vit = inject_lora_to_vit_attention(student_vit, rank=LORA_RANK, alpha=LORA_ALPHA, lora_dropout_p=LORA_DROPOUT)

    # Warm start: copy LoRA and normalisation weights/biases from the previous expert when the
    # key exists in the fresh student with the same shape.
    prev_expert_vit_state_dict = current_expert_vit.cpu().state_dict()
    student_vit_state_dict_new = student_vit.state_dict()
    load_dict_student_vit_warm_start = {}
    for k, v in prev_expert_vit_state_dict.items():
        # Match on 'norm' rather than '.norm': a top-level final norm is keyed 'norm.weight' /
        # 'norm.bias' (no leading dot; assuming the usual timm ViT naming - TODO confirm) and
        # would otherwise be skipped, although `student_vit.norm` feeds the head that is
        # loaded from the same previous expert just below.
        is_lora_or_norm_param = 'lora_' in k or ('norm' in k and (k.endswith('.weight') or k.endswith('.bias')))
        if is_lora_or_norm_param and \
           k in student_vit_state_dict_new and student_vit_state_dict_new[k].shape == v.shape:
            load_dict_student_vit_warm_start[k] = v
    if load_dict_student_vit_warm_start:
        missing_keys_vit, unexpected_keys_vit = student_vit.load_state_dict(load_dict_student_vit_warm_start, strict=False)
        print(f"  Loaded LoRA/LN from previous expert '{current_expert_domain_name}'. Missing: {len(missing_keys_vit)}, Unexpected: {len(unexpected_keys_vit)}")
    else:
        print(f"  Warning: No LoRA/LN weights loaded from previous expert '{current_expert_domain_name}'. Starting fresh LoRA/LN for student ViT.")

    student_vit = set_layernorm_affine_trainable(student_vit, trainable=True)
    student_vit = student_vit.to(DEVICE)

    student_head = DomainSpecificHead(in_features=VIT_EMBED_DIM, num_classes=NUM_CLASSES)
    student_head.load_state_dict(current_expert_head.cpu().state_dict())
    student_head = student_head.to(DEVICE)

    # The teacher is a frozen copy of the student; it is only changed through
    # update_ema_teacher_components (called here with decay 0.0, and with EMA_DECAY per step).
    teacher_vit = copy.deepcopy(student_vit).to(DEVICE)
    teacher_head = copy.deepcopy(student_head).to(DEVICE)
    for param in teacher_vit.parameters(): param.requires_grad = False
    for param in teacher_head.parameters(): param.requires_grad = False
    teacher_vit.eval(); teacher_head.eval()
    update_ema_teacher_components(student_vit, student_head, teacher_vit, teacher_head, 0.0)

    p_theta_dad = DAD_P_Theta_MLP().to(DEVICE)
    print(f"  Student, Teacher, and DAD p_theta models ready for '{TARGET_DOMAIN_NAME_CURRENT}'.")

    # --- 3. Optimizers ---
    # Student optimizer: every ViT parameter that requires grad, plus all head parameters.
    params_to_train_student = [param for _, param in student_vit.named_parameters() if param.requires_grad]
    params_to_train_student.extend(list(student_head.parameters()))

    optimizer_student = optim.AdamW(params_to_train_student, lr=ADAPT_LR_LORA_HEAD_LN, weight_decay=0.05)
    optimizer_p_theta = optim.SGD(p_theta_dad.parameters(), lr=ADAPT_LR_P_THETA, momentum=0.9, weight_decay=4.5e-3)

    grad_scaler_adapt = GradScaler(enabled=(DEVICE.type == 'cuda'))
    criterion_adapt_ce = nn.CrossEntropyLoss(label_smoothing=0.1)
    criterion_adapt_mse = nn.MSELoss()
    print("  Optimizers created.")

    # --- 4. DAD LTR Pre-training ---
    # p_theta learns to predict the noise added to the student's target CLS features;
    # the student itself is not updated here (features are computed under no_grad).
    print(f"  Starting DAD LTR Pre-training for p_theta ({PREVIOUS_DOMAIN_NAME} -> {TARGET_DOMAIN_NAME_CURRENT})...")
    p_theta_dad.train()
    student_vit.eval(); student_head.eval()
    for ltr_epoch in range(ADAPT_LTR_EPOCHS):
        ltr_running_loss = 0.0
        ltr_samples_seen = 0
        progress_bar_ltr = tqdm(target_unlabeled_loader_for_ema_shot, desc=f"LTR Epoch {ltr_epoch+1}/{ADAPT_LTR_EPOCHS} for {TARGET_DOMAIN_NAME_CURRENT}", leave=False)
        for target_images_weak, _ in progress_bar_ltr:
            target_images_weak = target_images_weak.to(DEVICE)
            optimizer_p_theta.zero_grad()
            with torch.no_grad():
                all_target_features = student_vit.forward_features(target_images_weak)
                F_T_cls = all_target_features[:, 0].to(torch.float32)
            # One random timestep per sample.
            t_indices = torch.randint(0, DAD_K_STEPS, (F_T_cls.size(0),), device=DEVICE, dtype=torch.long)
            noise = torch.randn_like(F_T_cls)
            F_T_noisy = q_sample_dad(F_T_cls, t_indices, noise)
            predicted_noise = p_theta_dad(F_T_noisy, t_indices)
            loss_ltr = criterion_adapt_mse(predicted_noise, noise)
            loss_ltr.backward()
            optimizer_p_theta.step()
            ltr_running_loss += loss_ltr.item() * target_images_weak.size(0)
            ltr_samples_seen += target_images_weak.size(0)
            progress_bar_ltr.set_postfix(ltr_loss=loss_ltr.item())
        # Average over the samples actually processed: the loader uses drop_last=True, so
        # dividing by the dataset size would understate the loss.
        avg_ltr_loss = ltr_running_loss / ltr_samples_seen if ltr_samples_seen > 0 else 0
        print(f"  LTR Epoch {ltr_epoch+1} ({TARGET_DOMAIN_NAME_CURRENT}) Avg Loss: {avg_ltr_loss:.4f}")
    print(f"  DAD LTR Pre-training for '{TARGET_DOMAIN_NAME_CURRENT}' finished.")

    # --- 5. Main Adaptation Loop (formerly Cell 18) ---
    print(f"  Starting Main Adaptation Loop for '{TARGET_DOMAIN_NAME_CURRENT}' ({DAD_K_STEPS} DAD steps, {ADAPT_MLS_R_ITER} MLS iters/step)...")

    best_target_val_acc = 0.0  # checkpoints are selected by validation accuracy
    val_loss_at_best_acc = float('inf')  # validation loss recorded at the best accuracy
    patience_counter_adapt = 0

    # The scheduler is stepped once per MLS iteration plus once per target update.
    num_total_adapt_optimizer_steps_current = DAD_K_STEPS * (ADAPT_MLS_R_ITER + 1)
    scheduler_student_adapt = optim.lr_scheduler.CosineAnnealingLR(
        optimizer_student, T_max=num_total_adapt_optimizer_steps_current, eta_min=1e-6
    )
    print(f"  Student optimizer LR scheduler for '{TARGET_DOMAIN_NAME_CURRENT}': CosineAnnealingLR with T_max={num_total_adapt_optimizer_steps_current}")

    # Validate roughly ten times per run; max(1, ...) avoids a modulo-by-zero for DAD_K_STEPS < 10.
    val_every_n_dad_steps = max(1, DAD_K_STEPS // 10)

    source_iter = iter(previous_domain_train_loader)
    target_iter_ema_shot = iter(target_unlabeled_loader_for_ema_shot)
    fixmatch_iter = iter(fixmatch_target_loader)
    outer_progress_bar = tqdm(range(DAD_K_STEPS), desc=f"DAD Steps for {TARGET_DOMAIN_NAME_CURRENT}")

    for k_dad_step_idx in outer_progress_bar:
        # ---- 5a. MLS iterations on labelled previous-domain data ----
        for mls_iter_idx in range(ADAPT_MLS_R_ITER):
            try:
                source_images, source_labels = next(source_iter)
            except StopIteration:
                # Loader exhausted: restart it and keep cycling.
                source_iter = iter(previous_domain_train_loader)
                source_images, source_labels = next(source_iter)
            source_images, source_labels = source_images.to(DEVICE), source_labels.to(DEVICE)

            # C->D: update p_theta so that the reconstructed features are classified correctly.
            p_theta_dad.train(); student_vit.eval(); student_head.eval()
            with torch.no_grad():
                all_source_features = student_vit.forward_features(source_images)
                F_S_cls = all_source_features[:, 0].to(torch.float32)
            t_k_current_long = torch.full((F_S_cls.size(0),), k_dad_step_idx, device=DEVICE, dtype=torch.long)
            noise_cd = torch.randn_like(F_S_cls)
            F_S_cls_noisy_cd = q_sample_dad(F_S_cls, t_k_current_long, noise_cd)
            predicted_noise_cd = p_theta_dad(F_S_cls_noisy_cd, t_k_current_long)
            # The reconstruction must stay in the autograd graph: wrapping it in no_grad cuts
            # the only path from loss_cd back to p_theta, which turns optimizer_p_theta.step()
            # into a no-op. (Assumes q_reconstruct_dad is differentiable w.r.t. the predicted
            # noise - TODO confirm against its definition.)
            x0_hat_cd = q_reconstruct_dad(F_S_cls_noisy_cd, predicted_noise_cd, t_k_current_long)
            with autocast_ctx():
                logits_for_ptheta_update = student_head(student_vit.norm(x0_hat_cd))
            loss_cd = criterion_adapt_ce(logits_for_ptheta_update, source_labels)
            # Gradients that land on student parameters here are discarded by
            # optimizer_student.zero_grad() before the student's own backward pass below.
            optimizer_p_theta.zero_grad(); loss_cd.backward(); optimizer_p_theta.step()

            # D->C: update the student on features reconstructed by the (now fixed) p_theta.
            p_theta_dad.eval(); student_vit.train(); student_head.train()
            with torch.no_grad():
                predicted_noise_dc = p_theta_dad(F_S_cls_noisy_cd, t_k_current_long)
                x0_hat_dc = q_reconstruct_dad(F_S_cls_noisy_cd, predicted_noise_dc, t_k_current_long)
            with autocast_ctx():
                logits_for_student_update = student_head(student_vit.norm(x0_hat_dc.detach()))
                loss_dc_val = criterion_adapt_ce(logits_for_student_update, source_labels)
                # Optional replay term; skipped while global_replay_buffer is None/too small.
                if global_replay_buffer and len(global_replay_buffer) >= (BATCH_SIZE * REPLAY_BATCH_SIZE_RATIO):
                    replayed_images, replayed_labels = global_replay_buffer.sample(int(BATCH_SIZE * REPLAY_BATCH_SIZE_RATIO))
                    all_replayed_F_S_cls = student_vit.forward_features(replayed_images)
                    cls_replayed_F_S_cls = all_replayed_F_S_cls[:,0]
                    replayed_logits = student_head(student_vit.norm(cls_replayed_F_S_cls))
                    loss_replay = criterion_adapt_ce(replayed_logits, replayed_labels)
                    loss_dc_val += REPLAY_LAMBDA * loss_replay
            optimizer_student.zero_grad()
            grad_scaler_adapt.scale(loss_dc_val).backward()
            grad_scaler_adapt.step(optimizer_student)
            grad_scaler_adapt.update()
            scheduler_student_adapt.step()
            if global_replay_buffer: global_replay_buffer.add(source_images, source_labels)

        # Losses of the last MLS iteration (none exist when ADAPT_MLS_R_ITER is 0).
        current_mls_postfix = {}
        if ADAPT_MLS_R_ITER > 0:
            current_mls_postfix = {"MLS_D->C": f"{loss_dc_val.item():.3f}", "MLS_C->D": f"{loss_cd.item():.3f}"}
        outer_progress_bar.set_postfix(current_mls_postfix)

        # ---- 5b. One student update on target data ----
        student_vit.train(); student_head.train()
        try:
            target_images_for_ema_shot, _ = next(target_iter_ema_shot)
        except StopIteration:
            target_iter_ema_shot = iter(target_unlabeled_loader_for_ema_shot)
            target_images_for_ema_shot, _ = next(target_iter_ema_shot)
        target_images_for_ema_shot = target_images_for_ema_shot.to(DEVICE)
        try:
            target_images_weak_fixmatch, target_images_strong_fixmatch = next(fixmatch_iter)
        except StopIteration:
            fixmatch_iter = iter(fixmatch_target_loader)
            target_images_weak_fixmatch, target_images_strong_fixmatch = next(fixmatch_iter)
        target_images_weak_fixmatch = target_images_weak_fixmatch.to(DEVICE)
        target_images_strong_fixmatch = target_images_strong_fixmatch.to(DEVICE)

        optimizer_student.zero_grad()
        # Set as soon as any target loss has been backpropagated in this step, so the optimizer
        # step below is taken exactly when there are gradients to apply.
        student_target_backward_done = False
        loss_pl_val_ema = torch.tensor(0.0, device=DEVICE)
        loss_fixmatch_val_calc = torch.tensor(0.0, device=DEVICE)
        # Confidence threshold moves linearly from ..._START to ..._END over the DAD steps.
        current_progress_ratio = k_dad_step_idx / max(1, DAD_K_STEPS - 1)
        current_pseudo_label_thresh = PSEUDO_LABEL_THRESHOLD_START + \
                                     (PSEUDO_LABEL_THRESHOLD_END - PSEUDO_LABEL_THRESHOLD_START) * current_progress_ratio

        # (i) EMA-teacher pseudo-labels: teacher predicts, student is trained on confident samples.
        num_pseudo_labels_ema = 0
        with torch.no_grad(), autocast_ctx():
            teacher_vit.eval(); teacher_head.eval()
            all_F_T_teacher = teacher_vit.forward_features(target_images_for_ema_shot)
            cls_F_T_teacher = all_F_T_teacher[:, 0]
            norm_cls_F_T_teacher = teacher_vit.norm(cls_F_T_teacher)
            logits_teacher = teacher_head(norm_cls_F_T_teacher)
            probs_teacher = F.softmax(logits_teacher, dim=1)
            max_confidence_values, predicted_class_indices = torch.max(probs_teacher, dim=1)
            confidence_mask_ema = max_confidence_values >= current_pseudo_label_thresh
            pseudo_labels_for_loss = predicted_class_indices
        if confidence_mask_ema.any():
            num_pseudo_labels_ema = confidence_mask_ema.sum().item()
            with autocast_ctx():
                selected_target_images_ema = target_images_for_ema_shot[confidence_mask_ema]
                all_F_T_student_ema = student_vit.forward_features(selected_target_images_ema)
                cls_F_T_student_ema = all_F_T_student_ema[:, 0]
                logits_student_on_pseudo = student_head(student_vit.norm(cls_F_T_student_ema))
                loss_pl_val_ema = criterion_adapt_ce(logits_student_on_pseudo, pseudo_labels_for_loss[confidence_mask_ema])

        # (ii) FixMatch: student's own confident predictions on the weak view supervise the strong view.
        num_pseudo_labels_fixmatch = 0
        with torch.no_grad(), autocast_ctx():
            all_F_T_student_weak_fixmatch = student_vit.forward_features(target_images_weak_fixmatch)
            cls_F_T_student_weak_fixmatch = all_F_T_student_weak_fixmatch[:, 0]
            logits_weak_student_fixmatch = student_head(student_vit.norm(cls_F_T_student_weak_fixmatch))
            probs_weak_student_fixmatch = F.softmax(logits_weak_student_fixmatch, dim=1)
            max_probs_weak_fixmatch, pseudo_labels_weak_fixmatch = torch.max(probs_weak_student_fixmatch, dim=1)
            confidence_mask_fixmatch = max_probs_weak_fixmatch >= FIXMATCH_CONF_THRESHOLD
        if confidence_mask_fixmatch.any():
            num_pseudo_labels_fixmatch = confidence_mask_fixmatch.sum().item()
            with autocast_ctx():
                selected_target_images_strong_fixmatch = target_images_strong_fixmatch[confidence_mask_fixmatch]
                all_F_T_student_strong_fixmatch = student_vit.forward_features(selected_target_images_strong_fixmatch)
                cls_F_T_student_strong_fixmatch = all_F_T_student_strong_fixmatch[:, 0]
                logits_student_on_strong_fixmatch = student_head(student_vit.norm(cls_F_T_student_strong_fixmatch))
                loss_fixmatch_val_calc = criterion_adapt_ce(logits_student_on_strong_fixmatch, pseudo_labels_weak_fixmatch[confidence_mask_fixmatch])
                loss_fixmatch_val_calc = FIXMATCH_LAMBDA * loss_fixmatch_val_calc

        # Pseudo-label losses are backpropagated together. Their graph is not used again
        # (SHOT runs its own forward pass below), so retain_graph is not needed.
        loss_for_head_update = torch.tensor(0.0, device=DEVICE)
        if loss_pl_val_ema.item() > 0: loss_for_head_update += loss_pl_val_ema
        if loss_fixmatch_val_calc.item() > 0: loss_for_head_update += loss_fixmatch_val_calc
        if loss_for_head_update.item() > 0:
            grad_scaler_adapt.scale(loss_for_head_update).backward()
            student_target_backward_done = True

        # (iii) SHOT-style entropy terms, computed with the head temporarily in eval mode and
        # its parameters frozen, so this loss does not produce gradients for the head.
        _temp_head_training_mode_shot = student_head.training
        _temp_head_params_req_grad_shot = [p.requires_grad for p in student_head.parameters()]
        student_head.eval()
        for p_head in student_head.parameters(): p_head.requires_grad = False
        with autocast_ctx():
            all_F_T_student_shot = student_vit.forward_features(target_images_for_ema_shot)
            cls_F_T_student_shot = all_F_T_student_shot[:, 0]
            logits_student_shot = student_head(student_vit.norm(cls_F_T_student_shot))
            probs_student_shot = F.softmax(logits_student_shot, dim=1)
            # Mean per-sample prediction entropy.
            loss_cond_ent = - (probs_student_shot * torch.log(probs_student_shot + 1e-9)).sum(1).mean()
            # Entropy of the batch-averaged prediction.
            mean_probs_shot = probs_student_shot.mean(0)
            loss_ent_max = - (mean_probs_shot * torch.log(mean_probs_shot + 1e-9)).sum()
            loss_shot_val_calculated = SHOT_LAMBDA_COND_ENT * loss_cond_ent + SHOT_LAMBDA_ENT_MAX * loss_ent_max
        if loss_shot_val_calculated.item() != 0:
            grad_scaler_adapt.scale(loss_shot_val_calculated).backward()
            student_target_backward_done = True
        # Restore the head's previous mode and per-parameter requires_grad flags.
        student_head.train(_temp_head_training_mode_shot)
        for p_head, p_rg_status_shot in zip(student_head.parameters(), _temp_head_params_req_grad_shot):
            p_head.requires_grad = p_rg_status_shot

        # Step whenever something was backpropagated. Gating on "loss > 0" would silently drop
        # the SHOT gradients whenever the weighted SHOT loss is negative and no pseudo-label
        # loss is present.
        if student_target_backward_done:
            grad_scaler_adapt.step(optimizer_student)
            grad_scaler_adapt.update()
        scheduler_student_adapt.step()
        update_ema_teacher_components(student_vit, student_head, teacher_vit, teacher_head, EMA_DECAY)
        current_target_postfix = {
            "PL_EMA": f"{loss_pl_val_ema.item():.3f}({num_pseudo_labels_ema})",
            "PL_FixM": f"{loss_fixmatch_val_calc.item() / FIXMATCH_LAMBDA if FIXMATCH_LAMBDA > 0 and loss_fixmatch_val_calc.item() > 0 else 0.0:.3f}({num_pseudo_labels_fixmatch})",
            "SHOT": f"{loss_shot_val_calculated.item():.3f}",
            "LR_stud": f"{optimizer_student.param_groups[0]['lr']:.2e}"}
        combined_postfix = {**current_mls_postfix, **current_target_postfix}
        outer_progress_bar.set_postfix(combined_postfix)

        # ---- 5c. Periodic validation on the target 'val' split ----
        if (k_dad_step_idx + 1) % val_every_n_dad_steps == 0 or k_dad_step_idx == DAD_K_STEPS - 1:
            student_vit.eval(); student_head.eval()
            val_loss_target_epoch = 0.0; val_correct_target_epoch = 0; val_total_target_epoch = 0
            val_progress_bar = tqdm(target_val_loader, desc=f"Val {TARGET_DOMAIN_NAME_CURRENT} (DAD {k_dad_step_idx+1})", leave=False)
            with torch.no_grad():
                for images_val, labels_val in val_progress_bar:
                    images_val, labels_val = images_val.to(DEVICE), labels_val.to(DEVICE)
                    with autocast_ctx():
                        all_features_val = student_vit.forward_features(images_val)
                        cls_features_val = all_features_val[:,0]
                        logits_val = student_head(student_vit.norm(cls_features_val))
                        loss_val = criterion_adapt_ce(logits_val, labels_val)
                    val_loss_target_epoch += loss_val.item() * images_val.size(0)
                    _, preds_val = torch.max(logits_val, 1)
                    # Plain int counter, so the accuracy below is a float even for an empty val set.
                    val_correct_target_epoch += torch.sum(preds_val == labels_val).item()
                    val_total_target_epoch += images_val.size(0)
            epoch_val_loss_target = val_loss_target_epoch / val_total_target_epoch if val_total_target_epoch > 0 else 0
            epoch_val_acc_target = val_correct_target_epoch / val_total_target_epoch if val_total_target_epoch > 0 else 0.0
            current_val_postfix_info = {
                "ValAcc": f"{epoch_val_acc_target:.4f}",
                "ValLoss": f"{epoch_val_loss_target:.4f}"}
            # Merge with this step's postfix dict; the bar's own `postfix` attribute is a
            # formatted string, not a dict, so it cannot be merged from.
            outer_progress_bar.set_postfix({**combined_postfix, **current_val_postfix_info})
            print(f"  DAD Step {k_dad_step_idx+1}/{DAD_K_STEPS} - Target '{TARGET_DOMAIN_NAME_CURRENT}' Val Loss: {epoch_val_loss_target:.4f}, Val Acc: {epoch_val_acc_target:.4f}")

            # Save on best validation accuracy; otherwise count towards early stopping.
            if epoch_val_acc_target > best_target_val_acc and val_total_target_epoch > 0:
                best_target_val_acc = epoch_val_acc_target
                val_loss_at_best_acc = epoch_val_loss_target
                torch.save(student_vit.state_dict(), os.path.join(target_model_save_dir, f"{TARGET_DOMAIN_NAME_CURRENT.lower()}_vit_lora_best.pth"))
                torch.save(student_head.state_dict(), os.path.join(target_model_save_dir, f"{TARGET_DOMAIN_NAME_CURRENT.lower()}_head_best.pth"))
                print(f"    -> New best Val Acc for '{TARGET_DOMAIN_NAME_CURRENT}': {best_target_val_acc:.4f} (Loss: {val_loss_at_best_acc:.4f}). Models saved.")
                patience_counter_adapt = 0  # Reset patience
            elif val_total_target_epoch > 0:  # Accuracy did not improve
                patience_counter_adapt += 1
                print(f"    Val acc did not improve. Patience: {patience_counter_adapt}/{EARLY_STOPPING_PATIENCE_ADAPT}")

        if patience_counter_adapt >= EARLY_STOPPING_PATIENCE_ADAPT:
            print(f"  Early stopping triggered for '{TARGET_DOMAIN_NAME_CURRENT}' after {k_dad_step_idx+1} DAD steps.")
            break  # Break the outer DAD step loop for the current target domain

    print(f"\nAdaptation to '{TARGET_DOMAIN_NAME_CURRENT}' finished. Best Val Acc: {best_target_val_acc:.4f} (at Val Loss: {val_loss_at_best_acc:.4f})")

    # --- 6. Reload the best checkpoint and promote it to "current expert" ---
    if os.path.exists(os.path.join(target_model_save_dir, f"{TARGET_DOMAIN_NAME_CURRENT.lower()}_vit_lora_best.pth")):
        student_vit.load_state_dict(torch.load(os.path.join(target_model_save_dir, f"{TARGET_DOMAIN_NAME_CURRENT.lower()}_vit_lora_best.pth")))
        student_head.load_state_dict(torch.load(os.path.join(target_model_save_dir, f"{TARGET_DOMAIN_NAME_CURRENT.lower()}_head_best.pth")))
        print(f"Loaded best saved models (by accuracy) for {TARGET_DOMAIN_NAME_CURRENT}.")
    else:
        print(f"Warning: No best saved models found for {TARGET_DOMAIN_NAME_CURRENT}. Using last state of student model.")
    student_vit.eval(); student_head.eval()

    # CPU copies: one stored per domain, one used to initialise the next adaptation stage.
    adapted_experts[TARGET_DOMAIN_NAME_CURRENT] = {
        'vit': copy.deepcopy(student_vit).cpu(),
        'head': copy.deepcopy(student_head).cpu()
    }
    current_expert_vit = copy.deepcopy(student_vit).cpu()
    current_expert_head = copy.deepcopy(student_head).cpu()
    current_expert_domain_name = TARGET_DOMAIN_NAME_CURRENT
    print(f"Best models for '{TARGET_DOMAIN_NAME_CURRENT}' loaded and stored. Ready for next adaptation or evaluation.")

# End of the loop over TARGET_DOMAIN_NAMES_ORDERED
print("\n\n=== All Target Domain Adaptations Complete ===")

# Part 5

In [None]:
# Cell X1: Domain Classifier - Prepare Multi-Domain Dataset & DataLoader

print("\n\n=== Part 5: Domain Classifier Training ===")

# Every entry of ALL_TRAINABLE_DOMAIN_NAMES (Cell 3: [SOURCE_DOMAIN_NAME] + TARGET_DOMAIN_NAMES_ORDERED,
# e.g. ['Art', 'Clipart', 'Product', 'RealWorld']) gets its list position as the integer
# domain id used by the domain classifier (DC). Both lookup directions are kept.
domain_idx_to_name_dc = dict(enumerate(ALL_TRAINABLE_DOMAIN_NAMES))
domain_name_to_idx_dc = {
    domain_name: domain_idx for domain_idx, domain_name in domain_idx_to_name_dc.items()
}
num_total_domains_for_dc = len(ALL_TRAINABLE_DOMAIN_NAMES)
print(f"Domain Classifier - Domain to Index Mapping: {domain_name_to_idx_dc}")

class MultiDomainDatasetForDC(Dataset):
    """Images from several domains, labelled by *domain* index, for the domain classifier (DC).

    Image paths are collected per domain through `OfficeHomeDomainDataset`
    (only its `images_paths` attribute is used); images themselves are opened
    lazily in `__getitem__`.

    Args:
        root_dir: Dataset root, forwarded to `OfficeHomeDomainDataset`.
        all_domain_names: Iterable of domain names to include.
        domain_to_idx_map: Mapping domain name -> integer DC label. Domains
            missing from this mapping are skipped with a warning.
        class_to_idx_overall_map: Forwarded to `OfficeHomeDomainDataset` as
            `class_to_idx_mapping`.
        transform: Callable applied to the RGB PIL image; when falsy, the image
            is converted with `transforms.ToTensor()` only.
        split_type: Split name forwarded to `OfficeHomeDomainDataset`.
    """

    def __init__(self, root_dir, all_domain_names, domain_to_idx_map, 
                 class_to_idx_overall_map, transform, split_type='train'):
        self.images_paths = []
        self.domain_labels = []  # Integer domain labels for DC, parallel to images_paths
        self.transform = transform

        for domain_name in all_domain_names:
            domain_idx = domain_to_idx_map.get(domain_name)
            if domain_idx is None:
                print(f"Warning: Domain '{domain_name}' not in domain_to_idx_map. Skipping for DC dataset.")
                continue

            # The per-domain dataset is used only as a source of image paths for
            # the requested split; the DC transform is applied in __getitem__.
            temp_domain_dataset = OfficeHomeDomainDataset(
                root_dir=root_dir, domain_name=domain_name,
                transform=None,
                split_type=split_type,
                class_to_idx_mapping=class_to_idx_overall_map,
                load_pil=False
            )
            domain_image_paths = temp_domain_dataset.images_paths
            self.images_paths.extend(domain_image_paths)
            self.domain_labels.extend([domain_idx] * len(domain_image_paths))
            print(f"  Added {len(domain_image_paths)} images from '{domain_name}' ({split_type} split) for DC training.")

    def __len__(self):
        """Number of images across all included domains."""
        return len(self.images_paths)

    def __getitem__(self, idx):
        """Return `(image_tensor, domain_label)` where the label is a 0-dim int64 tensor."""
        # Close the underlying file deterministically; convert() returns an
        # independent in-memory image, so it stays valid after the file is closed.
        with Image.open(self.images_paths[idx]) as image_file:
            image_pil = image_file.convert('RGB')

        if self.transform:
            image_tensor = self.transform(image_pil)
        else:
            image_tensor = transforms.ToTensor()(image_pil)

        return image_tensor, torch.tensor(self.domain_labels[idx], dtype=torch.long)

# Both DC splits are built identically apart from the split name. The weak
# val/test transform is used for DC training as well, since the DC is about
# general domain features rather than augmentation robustness.
_dc_dataset_kwargs = dict(
    root_dir=DATA_DIR,
    all_domain_names=ALL_TRAINABLE_DOMAIN_NAMES,
    domain_to_idx_map=domain_name_to_idx_dc,
    class_to_idx_overall_map=GLOBAL_CLASS_TO_IDX,
    transform=val_test_transform_weak,
)
dc_train_dataset = MultiDomainDatasetForDC(split_type='train', **_dc_dataset_kwargs)
# Validation set for the DC, taken from the 'val' split of every domain.
dc_val_dataset = MultiDomainDatasetForDC(split_type='val', **_dc_dataset_kwargs)

# Loader settings shared by both splits; only training shuffles and drops the
# last incomplete batch.
_dc_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
dc_train_loader = DataLoader(dc_train_dataset, shuffle=True, drop_last=True, **_dc_loader_kwargs)
dc_val_loader = DataLoader(dc_val_dataset, shuffle=False, **_dc_loader_kwargs)

print(f"Domain Classifier: Train size {len(dc_train_dataset)}, Val size {len(dc_val_dataset)}")
print(f"DC Train loader batches: {len(dc_train_loader)}, DC Val loader batches: {len(dc_val_loader)}")

In [None]:
# Cell X2: Domain Classifier - Model, Optimizer, Training Loop

# The DC head sits on top of the original frozen ViT (no LoRA), which an earlier
# cell is expected to have loaded; it is re-loaded here if it is missing.
if globals().get('base_vit_frozen_global') is None:
    print("ERROR: base_vit_frozen_global not found. Re-loading.")
    base_vit_frozen_global = load_frozen_vit_backbone(device=DEVICE)

# One output logit per domain.
domain_classifier_head = DomainSpecificHead(
    in_features=VIT_EMBED_DIM,
    num_classes=num_total_domains_for_dc,
)
domain_classifier_head = domain_classifier_head.to(DEVICE)

# Only the head's parameters are optimised.
optimizer_dc = optim.AdamW(
    domain_classifier_head.parameters(),
    lr=DC_HEAD_LR,
    weight_decay=0.01,
)
criterion_dc = nn.CrossEntropyLoss(label_smoothing=0.1)
grad_scaler_dc = GradScaler(enabled=(DEVICE.type == 'cuda'))

# Cosine schedule stepped once per batch, hence T_max counts optimiser steps.
_dc_total_train_steps = DC_HEAD_EPOCHS * len(dc_train_loader)
scheduler_dc = optim.lr_scheduler.CosineAnnealingLR(optimizer_dc, T_max=_dc_total_train_steps)

# Best-checkpoint bookkeeping used by the training loop.
dc_model_save_path = os.path.join(MODEL_SAVE_DIR_BASE, "domain_classifier_head_best.pth")
best_dc_val_acc = 0.0

print(f"Training Domain Classifier Head for {DC_HEAD_EPOCHS} epochs...")

for epoch in range(DC_HEAD_EPOCHS):
    # Only the head is trained; the backbone is kept in eval mode throughout.
    domain_classifier_head.train()
    base_vit_frozen_global.eval()

    # Plain Python accumulators so the derived metrics are ordinary floats.
    running_loss_dc = 0.0
    correct_preds_dc = 0
    total_samples_dc = 0

    progress_bar_dc = tqdm(dc_train_loader, desc=f"DC Epoch {epoch+1}/{DC_HEAD_EPOCHS} [Train]")
    for images, domain_labels_true in progress_bar_dc:
        images, domain_labels_true = images.to(DEVICE), domain_labels_true.to(DEVICE)
        batch_size_dc = images.size(0)

        optimizer_dc.zero_grad()
        with autocast_ctx():
            # The whole backbone path (token features, CLS selection and the
            # backbone's norm layer) is a fixed feature extractor, so none of it
            # needs an autograd graph.
            with torch.no_grad():
                all_features_dc = base_vit_frozen_global.forward_features(images)
                cls_features_dc = all_features_dc[:, 0]  # Select CLS token
                normed_cls_features_dc = base_vit_frozen_global.norm(cls_features_dc)

            domain_logits = domain_classifier_head(normed_cls_features_dc)
            loss_dc = criterion_dc(domain_logits, domain_labels_true)

        grad_scaler_dc.scale(loss_dc).backward()
        grad_scaler_dc.step(optimizer_dc)
        grad_scaler_dc.update()
        scheduler_dc.step()  # per-batch schedule (T_max was set in optimiser steps)

        batch_loss_dc = loss_dc.item()
        running_loss_dc += batch_loss_dc * batch_size_dc
        _, preds_dc = torch.max(domain_logits, 1)
        correct_preds_dc += (preds_dc == domain_labels_true).sum().item()
        total_samples_dc += batch_size_dc
        progress_bar_dc.set_postfix(
            loss=batch_loss_dc,
            acc=correct_preds_dc / total_samples_dc if total_samples_dc > 0 else 0.0,
        )

    # Guard against an empty train loader (possible with drop_last=True and a
    # dataset smaller than one batch).
    epoch_loss_dc = running_loss_dc / total_samples_dc if total_samples_dc > 0 else 0.0
    epoch_acc_dc = correct_preds_dc / total_samples_dc if total_samples_dc > 0 else 0.0
    print(f"DC Epoch {epoch+1} Train Loss: {epoch_loss_dc:.4f}, Train Acc: {epoch_acc_dc:.4f}")

    # Validation for DC
    domain_classifier_head.eval()
    val_loss_dc_epoch = 0.0
    val_correct_dc = 0
    val_total_dc = 0
    with torch.no_grad():
        for images_val, domain_labels_val_true in tqdm(dc_val_loader, desc=f"DC Epoch {epoch+1}/{DC_HEAD_EPOCHS} [Val]", leave=False):
            images_val, domain_labels_val_true = images_val.to(DEVICE), domain_labels_val_true.to(DEVICE)
            with autocast_ctx():
                all_features_val_dc = base_vit_frozen_global.forward_features(images_val)
                cls_features_val_dc = all_features_val_dc[:, 0]
                normed_cls_features_val_dc = base_vit_frozen_global.norm(cls_features_val_dc)
                domain_logits_val = domain_classifier_head(normed_cls_features_val_dc)
                loss_val_dc = criterion_dc(domain_logits_val, domain_labels_val_true)

            val_loss_dc_epoch += loss_val_dc.item() * images_val.size(0)
            _, preds_val_dc = torch.max(domain_logits_val, 1)
            val_correct_dc += (preds_val_dc == domain_labels_val_true).sum().item()
            val_total_dc += images_val.size(0)

    epoch_val_loss_dc = val_loss_dc_epoch / val_total_dc if val_total_dc > 0 else 0
    epoch_val_acc_dc = val_correct_dc / val_total_dc if val_total_dc > 0 else 0.0
    print(f"DC Epoch {epoch+1} Val Loss: {epoch_val_loss_dc:.4f}, Val Acc: {epoch_val_acc_dc:.4f}")

    # Keep the checkpoint with the best validation accuracy seen so far.
    if epoch_val_acc_dc > best_dc_val_acc:
        best_dc_val_acc = epoch_val_acc_dc
        torch.save(domain_classifier_head.state_dict(), dc_model_save_path)
        print(f"    -> New best DC Val Acc: {best_dc_val_acc:.4f}. DC Head saved to {dc_model_save_path}")

print(f"\nDomain Classifier training finished. Best Val Acc: {best_dc_val_acc:.4f}")
# Restore the best checkpoint (by validation accuracy) for inference. map_location
# keeps the load working when the checkpoint was written on a different device
# than the one in use now (e.g. saved on GPU, reloaded in a CPU-only session).
if os.path.exists(dc_model_save_path):
    domain_classifier_head.load_state_dict(torch.load(dc_model_save_path, map_location=DEVICE))
    domain_classifier_head.eval()
    print("Best Domain Classifier Head loaded and set to eval mode.")
else:
    # No checkpoint on disk: keep the weights from the final epoch.
    print(f"Warning: Best DC Head model not found at {dc_model_save_path}. Using last state.")
    domain_classifier_head.eval()

In [None]:
# Cell X3: Robust Inference - TTA Definitions & Expert Population

print("\n\n=== Part 6: Robust Inference Pipeline Setup & Evaluation ===")
print("\n--- Defining Test-Time Augmentations and Populating Task Experts ---")

# --- 1. Test-Time Augmentation (TTA) Definitions ---
# Every hand-built TTA view ends with the same resize -> tensor -> normalize
# tail; the views differ only in the PIL-level operations applied before it.
_TTA_NORM_MEAN = [0.485, 0.456, 0.406]
_TTA_NORM_STD = [0.229, 0.224, 0.225]


def _make_tta_view(*leading_ops):
    """Compose `leading_ops` (applied to the PIL image) with the shared resize/tensor/normalize tail."""
    return transforms.Compose([
        *leading_ops,
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_TTA_NORM_MEAN, std=_TTA_NORM_STD),
    ])


# TTA-Lite: the standard weak val/test view plus a deterministic horizontal flip.
tta_lite_transforms = [
    val_test_transform_weak,
    _make_tta_view(transforms.RandomHorizontalFlip(p=1.0)),
]

# TTA-Full: the lite views plus a centre crop, a colour jitter and a small rotation.
# ColorJitter and RandomRotation draw random parameters on every call.
tta_full_transforms_manual = [
    val_test_transform_weak,
    _make_tta_view(transforms.RandomHorizontalFlip(p=1.0)),
    # Resize BEFORE cropping so the crop keeps the central 87.5% of the image.
    # Cropping the raw image first would cut a fixed-size pixel window whose
    # coverage depends on the original resolution (and zero-pads small images).
    _make_tta_view(
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.CenterCrop(int(IMAGE_SIZE * 0.875)),
    ),
    _make_tta_view(transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)),
    _make_tta_view(transforms.RandomRotation(degrees=(-15, 15))),
]

NUM_TTA_LITE_AUGMENTATIONS = len(tta_lite_transforms)
NUM_TTA_FULL_AUGMENTATIONS = len(tta_full_transforms_manual)
print(f"TTA Lite augmentations: {NUM_TTA_LITE_AUGMENTATIONS}, Full TTA augmentations: {NUM_TTA_FULL_AUGMENTATIONS}")

# --- 2. Populate all_task_experts dictionary ---
# Maps domain name -> {'vit': ..., 'head': ...}; every model stored here is an
# independent copy that has been moved to DEVICE and put in eval mode.
all_task_experts = {}
print("\nAttempting to populate all_task_experts...")

# 2.1. Source expert (Art). globals().get() covers both "never defined" and "None".
_source_vit_candidate = globals().get('source_vit_lora')
_source_head_candidate = globals().get('source_head')
if _source_vit_candidate is None or _source_head_candidate is None:
    print(f"  Warning: Source expert for '{SOURCE_DOMAIN_NAME}' not fully available. Skipping.")
else:
    try:
        _source_vit_on_device = copy.deepcopy(_source_vit_candidate).to(DEVICE).eval()
        _source_head_on_device = copy.deepcopy(_source_head_candidate).to(DEVICE).eval()
        all_task_experts[SOURCE_DOMAIN_NAME] = {
            'vit': _source_vit_on_device,
            'head': _source_head_on_device,
        }
        print(f"  Successfully added Source Expert: '{SOURCE_DOMAIN_NAME}'")
    except Exception as e:
        print(f"  Error adding Source Expert '{SOURCE_DOMAIN_NAME}': {e}")

# 2.2. Adapted target experts (the `adapted_experts` dictionary filled in Part 4,
# where the models were stored on CPU).
_adapted_experts_candidate = globals().get('adapted_experts')
if not (isinstance(_adapted_experts_candidate, dict) and _adapted_experts_candidate):
    print("  Warning: `adapted_experts` dictionary not found or empty. No adapted experts added.")
else:
    print(f"  Found adapted_experts with keys: {list(adapted_experts.keys())}")
    for domain_name_adapted, expert_models_adapted in adapted_experts.items():
        _entry_complete = (
            isinstance(expert_models_adapted, dict)
            and expert_models_adapted.get('vit') is not None
            and expert_models_adapted.get('head') is not None
        )
        if not _entry_complete:
            print(f"    Warning: Incomplete models for adapted expert '{domain_name_adapted}'. Skipping.")
            continue
        try:
            _adapted_vit_on_device = copy.deepcopy(expert_models_adapted['vit']).to(DEVICE).eval()
            _adapted_head_on_device = copy.deepcopy(expert_models_adapted['head']).to(DEVICE).eval()
            all_task_experts[domain_name_adapted] = {
                'vit': _adapted_vit_on_device,
                'head': _adapted_head_on_device,
            }
            print(f"    Successfully added Adapted Expert: '{domain_name_adapted}'")
        except Exception as e:
            print(f"    Error adding Adapted Expert '{domain_name_adapted}': {e}")

if all_task_experts:
    print(f"\nSuccessfully populated all_task_experts with {len(all_task_experts)} expert(s): {list(all_task_experts.keys())}")
else:
    print("\nCRITICAL WARNING: all_task_experts is EMPTY. Inference will fail.")

# The frozen base ViT and the domain classifier head must also be on DEVICE and
# in eval mode for the inference pipeline.
if globals().get('base_vit_frozen_global') is None:
    print("CRITICAL WARNING: base_vit_frozen_global not found for inference pipeline.")
else:
    base_vit_frozen_global.to(DEVICE).eval()

if globals().get('domain_classifier_head') is None:
    print("CRITICAL WARNING: domain_classifier_head not found for inference pipeline.")
else:
    domain_classifier_head.to(DEVICE).eval()

print("Expert population and TTA definitions complete.")

In [None]:
# Cell X4: Robust Inference Pipeline Function Definition

def _expert_task_probs(expert_models, image_tensor):
    """Run one task expert on a single-image batch and return its 1-D class probabilities.

    `expert_models` is a dict with 'vit' and 'head' entries. The CLS token
    (index 0 of `forward_features`) goes through the expert ViT's own `norm`
    layer and then the head. The caller supplies the no-grad / autocast context.
    """
    expert_vit = expert_models['vit']
    expert_head = expert_models['head']
    cls_features = expert_vit.forward_features(image_tensor)[:, 0]
    task_logits = expert_head(expert_vit.norm(cls_features))
    return F.softmax(task_logits, dim=-1).squeeze(0)


def _tta_weighted_expert_probs(image_pil, tta_transforms_list, expert_names, expert_weights,
                               task_experts_dict, num_total_classes, device):
    """Average, over all TTA views, the weight-blended class probabilities of the given experts."""
    aggregated_probs = torch.zeros(num_total_classes, device=device)
    for tta_transform in tta_transforms_list:
        aug_image_tensor = tta_transform(image_pil).unsqueeze(0).to(device)
        view_probs = torch.zeros(num_total_classes, device=device)
        with torch.no_grad(), autocast_ctx():
            for weight, domain_name in zip(expert_weights, expert_names):
                view_probs += weight * _expert_task_probs(task_experts_dict[domain_name], aug_image_tensor)
        aggregated_probs += view_probs
    # max(1, ...) avoids a division by zero when the transform list is empty.
    return aggregated_probs / max(1, len(tta_transforms_list))


def robust_inference_pipeline(image_pil, base_vit, dc_head, task_experts_dict,
                              domain_map_idx_to_name_dc, # Renamed for clarity
                              num_total_classes=NUM_CLASSES, # Pass NUM_CLASSES
                              tta_lite_transforms_list=tta_lite_transforms, # Use global TTA lists
                              tta_full_transforms_list=tta_full_transforms_manual,
                              domain_confidence_thresh=INFER_DOMAIN_CONF_THRESH,
                              expert_confidence_thresh=INFER_EXPERT_CONF_THRESH,
                              stage2_expert_confidence_thresh=INFER_STAGE2_EXPERT_CONF_THRESH,
                              k_experts_for_avg=INFER_K_EXPERTS_FOR_AVG,
                              device=DEVICE):
    """Classify one PIL image with a three-stage, confidence-gated expert cascade.

    Stage 1: the domain classifier (`dc_head` on `base_vit` CLS features) picks a
        domain. If it is confident and an expert exists for that domain, that
        expert is run once on the weakly transformed image; a confident
        prediction is returned immediately. Otherwise either that single expert
        (weight 1.0) or the top-`k_experts_for_avg` available experts (weighted
        by renormalised domain probabilities) are carried forward.
    Stage 2: the selected experts are averaged over `tta_lite_transforms_list`;
        a confident result is returned.
    Stage 3: the same experts and weights are averaged over
        `tta_full_transforms_list`; this result is always returned.

    Args:
        image_pil: Input image (passed to the transforms, which here expect PIL).
        base_vit: Backbone used for domain classification.
        dc_head: Domain classifier head applied to the normed CLS features.
        task_experts_dict: Mapping domain name -> {'vit': ..., 'head': ...}.
        domain_map_idx_to_name_dc: Mapping DC output index -> domain name.
        num_total_classes: Size of the task probability vectors.
        tta_lite_transforms_list / tta_full_transforms_list: TTA views for
            stages 2 and 3.
        domain_confidence_thresh: Minimum top domain probability to trust a
            single expert in stage 1.
        expert_confidence_thresh: Minimum top class probability to return from
            stage 1.
        stage2_expert_confidence_thresh: Minimum top class probability to return
            from stage 2.
        k_experts_for_avg: Maximum number of experts blended when stage 1 is not
            confident.
        device: Device on which inference runs.

    Returns:
        `(label_idx, confidence)`; `(None, -1.0)` when the required models are
        missing or no expert can be selected.
    """
    # Validate before touching the modules. Identity checks are used because the
    # truthiness of a module can depend on __len__ (e.g. container modules).
    if base_vit is None or dc_head is None:
        print("Error: Base ViT or Domain Classifier Head not provided to pipeline.")
        return None, -1.0
    if not task_experts_dict:
        print("Error: task_experts_dict is empty in pipeline.")
        return None, -1.0

    base_vit.to(device).eval()
    dc_head.to(device).eval()

    # ---- Stage 1: domain classification on the standard (non-TTA) view ----
    initial_transformed_image = val_test_transform_weak(image_pil).unsqueeze(0).to(device)

    with torch.no_grad(), autocast_ctx():
        base_cls_features = base_vit.forward_features(initial_transformed_image)[:, 0]
        # DC head expects the normed CLS features.
        domain_logits = dc_head(base_vit.norm(base_cls_features))
        domain_probs = F.softmax(domain_logits, dim=-1).squeeze(0)

    top_domain_prob, predicted_domain_idx_tensor = torch.max(domain_probs, dim=-1)
    predicted_domain_idx = predicted_domain_idx_tensor.item()
    predicted_domain_name = domain_map_idx_to_name_dc.get(predicted_domain_idx, f"UnknownDomainIdx{predicted_domain_idx}")

    if top_domain_prob >= domain_confidence_thresh and predicted_domain_name in task_experts_dict:
        with torch.no_grad(), autocast_ctx():
            task_probs = _expert_task_probs(task_experts_dict[predicted_domain_name], initial_transformed_image)
        top_task_prob, final_label_idx_tensor = torch.max(task_probs, dim=-1)
        if top_task_prob >= expert_confidence_thresh:
            return final_label_idx_tensor.item(), top_task_prob.item()
        # Not confident enough: continue with this single expert.
        selected_domain_names_for_next_stage = [predicted_domain_name]
        expert_weights_for_next_stage = torch.tensor([1.0], device=device)
    else:
        # Low domain confidence, or no expert for the top domain: blend the
        # top-k domains *among those that actually have an expert*, so a missing
        # expert for a highly ranked domain cannot empty the selection.
        num_domain_outputs = domain_probs.numel()
        candidate_domain_indices = sorted(
            idx for idx, name in domain_map_idx_to_name_dc.items()
            if name in task_experts_dict and 0 <= idx < num_domain_outputs
        )
        actual_k = min(k_experts_for_avg, len(candidate_domain_indices))
        if actual_k <= 0:
            return None, -1.0

        top_k_domain_probs, top_k_positions = torch.topk(domain_probs[candidate_domain_indices], actual_k, dim=-1)
        selected_domain_names_for_next_stage = [
            domain_map_idx_to_name_dc[candidate_domain_indices[pos]] for pos in top_k_positions.tolist()
        ]

        # Renormalise the selected domain probabilities into expert weights;
        # fall back to uniform weights if they are numerically ~0.
        expert_weights_for_next_stage = top_k_domain_probs
        if expert_weights_for_next_stage.sum() > 1e-6:
            expert_weights_for_next_stage = expert_weights_for_next_stage / expert_weights_for_next_stage.sum()
        else:
            expert_weights_for_next_stage = torch.ones(len(selected_domain_names_for_next_stage), device=device) / max(1, len(selected_domain_names_for_next_stage))

    # ---- Stage 2: TTA-Lite ----
    final_averaged_probs_stage2 = _tta_weighted_expert_probs(
        image_pil, tta_lite_transforms_list,
        selected_domain_names_for_next_stage, expert_weights_for_next_stage,
        task_experts_dict, num_total_classes, device,
    )
    top_task_prob_stage2, final_label_idx_stage2_tensor = torch.max(final_averaged_probs_stage2, dim=-1)
    if top_task_prob_stage2 >= stage2_expert_confidence_thresh:
        return final_label_idx_stage2_tensor.item(), top_task_prob_stage2.item()

    # ---- Stage 3: TTA-Full, same experts and weights as stage 2 ----
    ultimate_final_probs = _tta_weighted_expert_probs(
        image_pil, tta_full_transforms_list,
        selected_domain_names_for_next_stage, expert_weights_for_next_stage,
        task_experts_dict, num_total_classes, device,
    )
    top_task_prob_stage3, ultimate_final_label_idx_tensor = torch.max(ultimate_final_probs, dim=-1)

    return ultimate_final_label_idx_tensor.item(), top_task_prob_stage3.item()

print("Robust inference pipeline function defined.")

In [None]:
# Cell X5: Evaluate Robust Inference Pipeline on Combined Validation Set

print("\n--- Evaluating Robust Inference Pipeline on Combined Validation Set ---")

# Prerequisites for the evaluation, checked in order. Each row is
# (global name, must_be_truthy, message). With must_be_truthy=False the object
# only has to exist and be non-None; with True it must also be non-empty.
_robust_eval_prerequisites = (
    ('domain_classifier_head', False, "ERROR: Domain Classifier Head not trained/loaded."),
    ('all_task_experts', True, "ERROR: No task experts available."),
    ('base_vit_frozen_global', False, "ERROR: Base ViT backbone not loaded."),
    ('GLOBAL_CLASS_TO_IDX', False, "ERROR: GLOBAL_CLASS_TO_IDX not defined."),
    # DC-specific index -> name map
    ('domain_idx_to_name_dc', False, "ERROR: domain_idx_to_name_dc mapping not defined (from DC training)."),
)

prereq_missing = False
for _prereq_name, _prereq_must_be_truthy, _prereq_message in _robust_eval_prerequisites:
    # A name that was never defined is treated like None.
    _prereq_value = globals().get(_prereq_name)
    if _prereq_must_be_truthy:
        _prereq_unavailable = not _prereq_value
    else:
        _prereq_unavailable = _prereq_value is None
    if _prereq_unavailable:
        print(_prereq_message)
        prereq_missing = True

if prereq_missing:
    print("Skipping robust inference pipeline evaluation due to missing prerequisites.")
else:
    # --- 1. Prepare Combined Validation Dataset ---
    class CombinedValDatasetForRobustEval(Dataset): # Renamed for clarity
        """Validation samples from several domains, kept as PIL images for the robust pipeline.

        All samples are read eagerly in `__init__` by indexing an
        `OfficeHomeDomainDataset` built with `transform=None, load_pil=True`,
        which is expected to yield `(pil_image, class_label)` pairs. A domain
        whose loading raises is skipped with a warning.
        NOTE(review): every image is held in memory at once — confirm this fits
        the available RAM for the chosen split.

        `__getitem__` returns `(pil_image, class_label_tensor, domain_name)`.

        Raises:
            RuntimeError: if no image could be loaded from any domain.
        """

        def __init__(self, root_dir, all_domain_names_list, 
                     class_to_idx_map_overall, split_type='val'):
            # Three parallel lists, one entry per sample.
            self.images_pil = [] 
            self.true_class_labels = []
            self.original_domain_names = [] 

            for domain_name_iter in all_domain_names_list:
                try:
                    # Use OfficeHomeDomainDataset to get PIL images and class labels
                    domain_val_dataset_temp = OfficeHomeDomainDataset(
                        root_dir=root_dir, domain_name=domain_name_iter,
                        transform=None, # Load PIL
                        split_type=split_type, 
                        class_to_idx_mapping=class_to_idx_map_overall,
                        load_pil=True # Get PIL images
                    )
                    
                    for i in range(len(domain_val_dataset_temp)):
                        pil_img, class_lbl = domain_val_dataset_temp[i]
                        self.images_pil.append(pil_img)
                        self.true_class_labels.append(class_lbl)
                        self.original_domain_names.append(domain_name_iter)
                        
                    print(f"  Added {len(domain_val_dataset_temp)} validation images from '{domain_name_iter}' for robust eval.")
                except Exception as e:
                    # Best effort: one unreadable domain must not abort the whole evaluation.
                    print(f"  Warning: Error loading val data for '{domain_name_iter}' (robust eval): {e}. Skipping.")
            
            if not self.images_pil:
                raise RuntimeError("No validation images found for combined robust evaluation.")

        def __len__(self):
            return len(self.images_pil)

        def __getitem__(self, idx):
            # Domain label for DC is not needed here as pipeline handles it internally
            return (self.images_pil[idx], 
                    torch.tensor(self.true_class_labels[idx]).long(),
                    self.original_domain_names[idx])
    
    try:
        combined_val_dataset_robust = CombinedValDatasetForRobustEval(
            root_dir=DATA_DIR,
            all_domain_names_list=ALL_TRAINABLE_DOMAIN_NAMES,
            class_to_idx_map_overall=GLOBAL_CLASS_TO_IDX,
            split_type='val'
        )
        print(f"Total combined validation images for robust eval: {len(combined_val_dataset_robust)}")
    except RuntimeError as e:
        print(f"Error creating combined validation dataset for robust eval: {e}. Aborting.")
        combined_val_dataset_robust = None
    
    # --- 2. Run Inference ---
    if combined_val_dataset_robust and len(combined_val_dataset_robust) > 0:
        all_predictions_robust = []   # predicted class index, or -1 when the pipeline gave no prediction
        all_true_labels_robust = []
        all_original_domains_robust = []
        
        for i in tqdm(range(len(combined_val_dataset_robust)), desc="Robust Pipeline Eval"):
            pil_image, true_class_label, original_domain = combined_val_dataset_robust[i]
            
            predicted_label_idx, confidence = robust_inference_pipeline(
                image_pil=pil_image,
                base_vit=base_vit_frozen_global, # Original frozen ViT
                dc_head=domain_classifier_head,
                task_experts_dict=all_task_experts,
                domain_map_idx_to_name_dc=domain_idx_to_name_dc, # From DC training
                device=DEVICE
            )
            all_predictions_robust.append(predicted_label_idx if predicted_label_idx is not None else -1)
            all_true_labels_robust.append(true_class_label.item())
            all_original_domains_robust.append(original_domain)

        # --- 3. Calculate and Report Accuracies ---
        # Every evaluated sample counts towards the totals. A sample for which the
        # pipeline returned no prediction (-1) is an error, not a sample to drop:
        # excluding it would inflate the reported accuracy.
        print("\n--- Robust Inference Pipeline Evaluation Results ---")
        overall_correct_robust = 0
        overall_total_robust = 0
        overall_abstained_robust = 0
        domain_wise_stats_robust = {name: {'correct': 0, 'total': 0, 'abstained': 0} for name in ALL_TRAINABLE_DOMAIN_NAMES}

        for pred_idx, true_idx, domain_name in zip(all_predictions_robust, all_true_labels_robust, all_original_domains_robust):
            overall_total_robust += 1
            domain_wise_stats_robust[domain_name]['total'] += 1
            if pred_idx == -1:
                overall_abstained_robust += 1
                domain_wise_stats_robust[domain_name]['abstained'] += 1
            elif pred_idx == true_idx:
                overall_correct_robust += 1
                domain_wise_stats_robust[domain_name]['correct'] += 1
        
        overall_accuracy_robust = (overall_correct_robust / overall_total_robust * 100) if overall_total_robust > 0 else 0.0
        print(f"Overall Accuracy (Robust Pipeline): {overall_accuracy_robust:.2f}% ({overall_correct_robust}/{overall_total_robust})")
        if overall_abstained_robust > 0:
            print(f"  Note: {overall_abstained_robust} sample(s) received no prediction and are counted as incorrect.")
        print("\nDomain-wise Accuracies (Robust Pipeline):")
        for domain_name, stats in domain_wise_stats_robust.items():
            acc = (stats['correct'] / stats['total'] * 100) if stats['total'] > 0 else 0.0
            print(f"  Domain '{domain_name}': {acc:.2f}% ({stats['correct']}/{stats['total']})")