In [None]:
# Cell 1: Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Dataset
import os
import numpy as np
import math
from PIL import Image
from tqdm.notebook import tqdm
import copy
from functools import partial

# For AMP
from torch.cuda.amp import GradScaler, autocast

print("Imports complete.")


In [None]:
# Cell 2: Global Hyperparameters & Configuration

# --- System ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Paths default to the Kaggle layout but can be overridden through environment
# variables, so the notebook also runs outside Kaggle without code edits.
DATA_DIR = os.environ.get("OFFICEHOME_DATA_DIR", "/kaggle/input/officehome/OfficeHome")
MODEL_SAVE_DIR_BASE = os.environ.get("OFFICEHOME_MODEL_DIR", "/kaggle/working/models_resnet50/")
os.makedirs(MODEL_SAVE_DIR_BASE, exist_ok=True)

# Seed the RNGs the notebook draws from: torch (weight init, dropout, random
# augmentations) on CPU and every CUDA device, plus numpy's legacy global RNG.
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

# --- Dataset & DataLoader ---
IMAGE_SIZE = 224                 # side length of the square resized input
BATCH_SIZE = 32
NUM_WORKERS = 2                  # DataLoader worker processes
NUM_CLASSES = 65                 # OfficeHome object categories

# --- ResNet-50 Backbone (Changed from ViT) ---
RESNET_MODEL_NAME = 'resnet50'
RESNET_EMBED_DIM = 2048          # width of ResNet-50's pooled feature vector

# --- LoRA: rank, scaling numerator (scaling = alpha / rank), adapter dropout ---
LORA_RANK, LORA_ALPHA, LORA_DROPOUT = 16, 32.0, 0.05

# --- DAD Module (feature-space denoiser) ---
DAD_K_STEPS = 200                          # length of the diffusion schedule
DAD_P_THETA_TIMESTEP_EMBED_DIM = 128       # sinusoidal timestep embedding width
DAD_P_THETA_HIDDEN_DIM_MULT = 4            # widening factor of the timestep encoder
DAD_P_THETA_MLP_HIDDEN_DIM = 1024          # hidden width of the denoiser MLP
DAD_BETA_START, DAD_BETA_END = 1e-4, 2e-2  # endpoints of the linear beta schedule

# --- Training (General) ---
SOURCE_DOMAIN_NAME = 'Art'
TARGET_DOMAIN_NAMES_ORDERED = ['Clipart', 'Product', 'RealWorld']
ALL_TRAINABLE_DOMAIN_NAMES = [SOURCE_DOMAIN_NAME, *TARGET_DOMAIN_NAMES_ORDERED]

# --- Source Domain Training (Art) ---
ART_EPOCHS, ART_LR_LORA_HEAD_BN = 10, 5e-4

# --- Continual Adaptation (Per Target Domain) ---
ADAPT_MLS_R_ITER = 10
ADAPT_LR_LORA_HEAD_BN, ADAPT_LR_P_THETA = 1e-4, 1e-4
ADAPT_LTR_EPOCHS = 5
EARLY_STOPPING_PATIENCE_ADAPT = 3

# --- EMA Teacher & Pseudo-Labeling ---
EMA_DECAY = 0.999
PSEUDO_LABEL_THRESHOLD_START, PSEUDO_LABEL_THRESHOLD_END = 0.7, 0.9

# --- FixMatch ---
FIXMATCH_CONF_THRESHOLD, FIXMATCH_LAMBDA = 0.95, 1.0

# --- SHOT ---
SHOT_LAMBDA_COND_ENT, SHOT_LAMBDA_ENT_MAX = 0.05, 0.05

# --- Experience Replay (Optional) ---
REPLAY_BUFFER_SIZE = 2000
REPLAY_BATCH_SIZE_RATIO = 0.25
REPLAY_LAMBDA = 0.1

# --- Domain Classifier ---
DC_HEAD_LR, DC_HEAD_EPOCHS = 1e-3, 10

# --- Robust Inference Pipeline ---
INFER_DOMAIN_CONF_THRESH = 0.7
INFER_EXPERT_CONF_THRESH = 0.6
INFER_STAGE2_EXPERT_CONF_THRESH = 0.5
INFER_K_EXPERTS_FOR_AVG = 3

# --- Autocast Context for AMP ---
# Use as `with autocast_ctx():` -- fp16 autocast on CUDA, a no-op context elsewhere.
if DEVICE.type != 'cuda':
    import contextlib
    autocast_ctx = contextlib.nullcontext
else:
    autocast_ctx = partial(
        torch.amp.autocast,
        device_type='cuda',
        dtype=torch.float16,
        enabled=True,
    )

print(f"Using device: {DEVICE}")
if torch.cuda.is_available():
    print(f"CUDA available. GPU: {torch.cuda.get_device_name(0)}")
print(f"All configurations set. Base model save dir: {MODEL_SAVE_DIR_BASE}")


In [None]:
# Cell 3: Transforms Definition

# ImageNet normalisation statistics matching the pretrained ResNet-50 weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Strong training augmentation: flip, colour jitter and a small random affine.
train_transform_strong = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.95, 1.05)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

# Evaluation transform. Deliberately deterministic (no random flip) so that
# validation/test accuracy -- and the best-checkpoint selection built on it --
# is reproducible from run to run.
val_test_transform_weak = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

# Weak *training* augmentation (random horizontal flip only), for cases that
# want a stochastic weak view (e.g. FixMatch). Not for evaluation.
train_transform_weak = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

# Resize only; the result is still a PIL image.
pil_load_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE))
])

print("Transforms defined: train_transform_strong, train_transform_weak, val_test_transform_weak, pil_load_transform.")


In [None]:
# Cell 4: Dataset Classes

class FixMatchOfficeHomeDataset(Dataset):
    """Two-view (weak, strong) dataset over one OfficeHome domain split.

    Wraps an ``OfficeHomeDomainDataset`` opened in PIL mode. For each index it
    returns ``(transform_weak(img), transform_strong(img))`` computed from the
    same PIL image; the ground-truth label is discarded.

    Args:
        root_dir, domain_name, split_ratios, split_type, random_seed,
        class_to_idx_mapping: Forwarded unchanged to ``OfficeHomeDomainDataset``.
        transform_weak: Callable producing the weak view.
        transform_strong: Callable producing the strong view.
    """

    def __init__(self, root_dir, domain_name, transform_weak, transform_strong,
                 split_ratios=(0.8, 0.1, 0.1), split_type='train',
                 random_seed=RANDOM_SEED, class_to_idx_mapping=None):
        # The wrapped dataset hands back raw PIL images so both views can be
        # produced independently from the same source image.
        wrapped_kwargs = dict(
            transform=None,
            split_ratios=split_ratios,
            split_type=split_type,
            random_seed=random_seed,
            class_to_idx_mapping=class_to_idx_mapping,
            load_pil=True,
        )
        self.base_dataset = OfficeHomeDomainDataset(root_dir, domain_name, **wrapped_kwargs)
        self.transform_weak, self.transform_strong = transform_weak, transform_strong

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        pil_image, _ = self.base_dataset[idx]
        # Weak view is built first, then the strong one (same order as always,
        # so random augmentations consume the RNG identically).
        weak_view = self.transform_weak(pil_image)
        strong_view = self.transform_strong(pil_image)
        return weak_view, strong_view


class OfficeHomeDomainDataset(Dataset):
    """One OfficeHome domain (``<root_dir>/<domain_name>/<class>/<image>``) with per-class splits.

    The images of each class are listed in sorted order, shuffled with a fixed
    seed and cut into train/val/test partitions by ``split_ratios``. The
    ``'all'`` split keeps every image.

    Args:
        root_dir: Dataset root containing one sub-directory per domain.
        domain_name: Domain sub-directory to read (e.g. ``'Art'``).
        transform: Callable applied to the RGB PIL image. When ``None``, the
            image is only converted with ``transforms.ToTensor()``.
        split_ratios: ``(train, val, test)`` fractions. Train and val counts are
            floored; the test split receives the remainder.
        split_type: One of ``'train'``, ``'val'``, ``'test'``, ``'all'``.
        random_seed: Seed for the per-class shuffle. The same seed always yields
            the same split.
        class_to_idx_mapping: Optional ``{class_name: index}`` mapping, e.g.
            shared across domains. When omitted, indices are assigned to this
            domain's class directories in sorted order. Class directories not
            in the mapping are skipped.
        load_pil: If True, ``__getitem__`` returns the untransformed PIL image.

    Raises:
        ValueError: If ``split_type`` is not supported.
    """

    _VALID_SPLITS = ('train', 'val', 'test', 'all')
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, domain_name, transform=None,
                 split_ratios=(0.8, 0.1, 0.1), split_type='train',
                 random_seed=RANDOM_SEED, class_to_idx_mapping=None, load_pil=False):
        # Validate before touching the filesystem, so a bad split_type is
        # reported even for an empty domain directory.
        if split_type not in self._VALID_SPLITS:
            raise ValueError("split_type must be 'train', 'val', 'test', or 'all'")

        self.domain_path = os.path.join(root_dir, domain_name)
        self.transform = transform
        self.load_pil = load_pil
        self.images_paths = []
        self.labels = []

        entry_names = sorted(os.listdir(self.domain_path))

        if class_to_idx_mapping is None:
            class_dirs = [name for name in entry_names
                          if os.path.isdir(os.path.join(self.domain_path, name))]
            self.class_to_idx = {name: idx for idx, name in enumerate(class_dirs)}
            self.idx_to_class = {idx: name for name, idx in self.class_to_idx.items()}
        else:
            self.class_to_idx = class_to_idx_mapping
            self.idx_to_class = {v: k for k, v in class_to_idx_mapping.items()}

        for class_name in entry_names:
            class_label_idx = self.class_to_idx.get(class_name)
            if class_label_idx is None:
                continue
            class_path = os.path.join(self.domain_path, class_name)
            if not os.path.isdir(class_path):
                continue

            class_image_paths = [
                os.path.join(class_path, img_name)
                for img_name in sorted(os.listdir(class_path))
                if img_name.lower().endswith(self._IMAGE_EXTENSIONS)
            ]

            # A private RandomState seeded with the same value gives the same
            # permutation as `np.random.seed(s); np.random.shuffle(...)`, so the
            # splits are unchanged, without resetting numpy's global RNG.
            np.random.RandomState(random_seed).shuffle(class_image_paths)

            n_total = len(class_image_paths)
            n_train = int(n_total * split_ratios[0])
            n_val = int(n_total * split_ratios[1])
            split_bounds = {
                'train': (0, n_train),
                'val': (n_train, n_train + n_val),
                'test': (n_train + n_val, n_total),
                'all': (0, n_total),
            }
            start, stop = split_bounds[split_type]
            selected_paths = class_image_paths[start:stop]

            self.images_paths.extend(selected_paths)
            self.labels.extend([class_label_idx] * len(selected_paths))

    def __len__(self):
        return len(self.images_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``: a PIL image if ``load_pil`` else a tensor."""
        img_path = self.images_paths[idx]
        # The context manager closes the file handle even if decoding fails;
        # convert() returns an independent, fully loaded copy.
        with Image.open(img_path) as img:
            image_pil = img.convert('RGB')
        label = self.labels[idx]

        if self.load_pil:
            return image_pil, label

        if self.transform:
            image_tensor = self.transform(image_pil)
        else:
            image_tensor = transforms.ToTensor()(image_pil)
        return image_tensor, label

print("Dataset classes defined.")


In [None]:
# Cell 5: Global Class Mapping & Initial Data Check

print("Attempting to create GLOBAL_CLASS_TO_IDX from Art domain...")
# Initialise both globals here, so a failed attempt never leaves
# GLOBAL_IDX_TO_CLASS undefined (fresh kernel) or stale (re-run).
GLOBAL_CLASS_TO_IDX = None
GLOBAL_IDX_TO_CLASS = None
try:
    temp_art_dataset_for_map = OfficeHomeDomainDataset(DATA_DIR, 'Art', split_type='all')
    n_found_classes = len(temp_art_dataset_for_map.class_to_idx)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if n_found_classes != NUM_CLASSES:
        raise ValueError(
            f"Mismatch: Expected {NUM_CLASSES} classes, found {n_found_classes} in Art domain."
        )
    GLOBAL_CLASS_TO_IDX = temp_art_dataset_for_map.class_to_idx
    GLOBAL_IDX_TO_CLASS = temp_art_dataset_for_map.idx_to_class
    print(f"GLOBAL_CLASS_TO_IDX created successfully with {len(GLOBAL_CLASS_TO_IDX)} classes.")
    del temp_art_dataset_for_map
except Exception as e:
    # Best-effort: report the problem and leave both mappings as None, so
    # later cells can test for None instead of failing on a NameError.
    print(f"ERROR: Could not create GLOBAL_CLASS_TO_IDX from Art domain: {e}")
    GLOBAL_CLASS_TO_IDX = None
    GLOBAL_IDX_TO_CLASS = None

# Sanity check: the source domain's train/val splits load with the shared mapping.
if GLOBAL_CLASS_TO_IDX:
    try:
        art_train_dataset_check = OfficeHomeDomainDataset(
            DATA_DIR, SOURCE_DOMAIN_NAME,
            transform=train_transform_strong,
            split_type='train',
            class_to_idx_mapping=GLOBAL_CLASS_TO_IDX
        )
        art_val_dataset_check = OfficeHomeDomainDataset(
            DATA_DIR, SOURCE_DOMAIN_NAME,
            transform=val_test_transform_weak,
            split_type='val',
            class_to_idx_mapping=GLOBAL_CLASS_TO_IDX
        )
        if len(art_train_dataset_check) > 0 and len(art_val_dataset_check) > 0:
            print(f"Successfully loaded check datasets for '{SOURCE_DOMAIN_NAME}': "
                  f"Train size {len(art_train_dataset_check)}, Val size {len(art_val_dataset_check)}")
        else:
            print(f"Warning: Check datasets for '{SOURCE_DOMAIN_NAME}' are empty.")
    except Exception as e:
        print(f"Error loading check datasets for '{SOURCE_DOMAIN_NAME}': {e}")


In [None]:
# Cell 6: ResNet-50 Backbone Loading Function

def load_frozen_resnet_backbone(device=DEVICE):
    """Build an ImageNet-pretrained (IMAGENET1K_V2) ResNet-50 with every weight frozen.

    The final ``fc`` layer is replaced by ``nn.Identity`` so the model yields
    pooled features; task-specific heads are attached separately.

    Args:
        device: Device the returned model is moved to.

    Returns:
        The ResNet-50 on ``device``, in eval mode, with ``requires_grad=False``
        on all parameters.
    """
    backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    backbone.fc = nn.Identity()      # drop the ImageNet classifier
    backbone.requires_grad_(False)   # freeze every parameter
    backbone = backbone.to(device)
    backbone.eval()
    print("Loaded and froze ResNet-50 backbone")
    return backbone


def get_resnet_features(resnet_model, images):
    """Run ``images`` through a torchvision-style ResNet up to global average pooling.

    The stem (conv1 -> bn1 -> relu -> maxpool), the four residual stages and
    ``avgpool`` are applied explicitly, so the model's ``fc`` is never used.
    The result is flattened to ``(B, C)``, i.e. ``(B, 2048)`` for ResNet-50.
    """
    stage_names = (
        'conv1', 'bn1', 'relu', 'maxpool',
        'layer1', 'layer2', 'layer3', 'layer4',
        'avgpool',
    )
    out = images
    for stage_name in stage_names:
        out = getattr(resnet_model, stage_name)(out)
    return out.flatten(start_dim=1)


print("ResNet-50 backbone loading and feature extraction functions defined.")


In [None]:
# Cell 7: LoRA for Convolutional Layers (ResNet adaptation)

class LoRAConv2d(nn.Module):
    """Frozen ``nn.Conv2d`` plus a trainable low-rank (LoRA) residual branch.

    ``forward(x) = conv(x; W_frozen, b_frozen) + scaling * B(dropout(A(x')))``

    - ``A`` (in_channels -> rank) and ``B`` (rank -> out_channels) are bias-free
      1x1 convolutions.
    - ``scaling = alpha / rank``.
    - ``x'`` is ``x`` adaptively average-pooled to the frozen conv's output
      resolution whenever the two differ (e.g. strided convolutions).

    ``B`` is zero-initialised, so a freshly wrapped layer reproduces the
    original conv exactly.

    Args:
        conv_layer: Conv layer whose weight/bias are copied and frozen.
        rank: Inner dimension of the low-rank branch.
        alpha: Numerator of the LoRA scaling factor.
        lora_dropout_p: ``nn.Dropout2d`` probability applied between A and B.
    """
    def __init__(self, conv_layer, rank, alpha, lora_dropout_p=0.0):
        super().__init__()
        # Geometry of the wrapped conv, reused verbatim in forward().
        self.in_channels = conv_layer.in_channels
        self.out_channels = conv_layer.out_channels
        self.kernel_size = conv_layer.kernel_size
        self.stride = conv_layer.stride
        self.padding = conv_layer.padding
        self.dilation = conv_layer.dilation
        self.groups = conv_layer.groups
        self.rank = rank
        self.alpha = alpha
        self.lora_dropout = nn.Dropout2d(p=lora_dropout_p)

        # Frozen copies of the original weight and bias, registered under the
        # names `weight` / `bias` (requires_grad=False).
        self.weight = nn.Parameter(conv_layer.weight.detach().clone(), requires_grad=False)
        if conv_layer.bias is not None:
            self.bias = nn.Parameter(conv_layer.bias.detach().clone(), requires_grad=False)
        else:
            self.register_parameter('bias', None)

        # Low-rank factors as 1x1 convolutions.
        # Down-projection: in_channels -> rank
        self.lora_A = nn.Conv2d(self.in_channels, rank, kernel_size=1, bias=False)
        # Up-projection: rank -> out_channels
        self.lora_B = nn.Conv2d(rank, self.out_channels, kernel_size=1, bias=False)

        # Random A, zero B: the LoRA branch contributes nothing at initialisation.
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

        self.scaling = self.alpha / self.rank

    def forward(self, x):
        # Original (frozen) convolution
        out_original = F.conv2d(
            x, self.weight, self.bias,
            stride=self.stride, padding=self.padding,
            dilation=self.dilation, groups=self.groups
        )

        # Move the input onto the frozen conv's output grid *before* the LoRA
        # projections. A and B are per-pixel linear maps and Dropout2d scales
        # whole channels, so all three commute with spatial average pooling.
        # This equals pooling the LoRA output afterwards, but runs the branch
        # at the (smaller) output resolution of strided convs.
        target_hw = out_original.shape[-2:]
        lora_in = x if x.shape[-2:] == target_hw else F.adaptive_avg_pool2d(x, target_hw)

        lora_out = self.lora_A(lora_in)
        lora_out = self.lora_dropout(lora_out)
        lora_out = self.lora_B(lora_out)

        return out_original + lora_out * self.scaling

    def extra_repr(self):
        return f'in_ch={self.in_channels}, out_ch={self.out_channels}, rank={self.rank}, alpha={self.alpha}'


def inject_lora_to_resnet(resnet_model, rank, alpha, lora_dropout_p=0.0,
                          target_layers=('layer3', 'layer4')):
    """Replace Conv2d layers in selected ResNet stages with ``LoRAConv2d``, in place.

    For every block of each named stage, the plain ``nn.Conv2d`` modules at
    ``conv1``/``conv2``/``conv3`` and inside ``block.downsample`` are wrapped.
    Modules that are not plain ``nn.Conv2d`` (e.g. already LoRA-wrapped) are
    left untouched.

    Args:
        resnet_model: torchvision-style ResNet with stages named ``layer1``..``layer4``.
        rank, alpha, lora_dropout_p: Forwarded to ``LoRAConv2d``.
        target_layers: Names of the stages to adapt; defaults to the last two.
            Names missing from the model are reported and skipped.

    Returns:
        ``resnet_model`` (mutated in place), for chaining.
    """
    injected_count = 0

    def _wrap(conv):
        # Put the new adapter on the same device as the conv it replaces, so
        # the model remains usable without a follow-up `.to(device)`.
        return LoRAConv2d(conv, rank, alpha, lora_dropout_p).to(device=conv.weight.device)

    for layer_name in target_layers:
        stage = getattr(resnet_model, layer_name, None)
        if stage is None:
            print(f"WARNING: '{layer_name}' not found on model; skipping.")
            continue

        for block in stage:
            # conv1, conv2, conv3 of each Bottleneck block
            for conv_name in ('conv1', 'conv2', 'conv3'):
                conv = getattr(block, conv_name, None)
                if isinstance(conv, nn.Conv2d):
                    setattr(block, conv_name, _wrap(conv))
                    injected_count += 1

            # Projection shortcut, if any; getattr also tolerates blocks that
            # have no `downsample` attribute at all.
            downsample = getattr(block, 'downsample', None)
            if downsample is not None:
                for idx, module in enumerate(downsample):
                    if isinstance(module, nn.Conv2d):
                        downsample[idx] = _wrap(module)
                        injected_count += 1

    if injected_count == 0:
        print("WARNING: No Conv2d layers found or replaced with LoRA.")
    else:
        print(f"Injected LoRA (rank={rank}, alpha={alpha}) into {injected_count} Conv2d layers in {list(target_layers)}.")

    return resnet_model


def set_batchnorm_affine_trainable(model, trainable=True):
    """Set ``requires_grad`` on the affine weight/bias of every BatchNorm layer.

    Covers BatchNorm1d/2d/3d and SyncBatchNorm. Layers created with
    ``affine=False`` have no such parameters and are skipped. Running
    statistics (buffers) are not touched.

    Args:
        model: Module whose submodules (including itself) are scanned.
        trainable: ``True`` to unfreeze, ``False`` to freeze.

    Returns:
        ``model``, for chaining.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    bn_param_count = 0
    for mod in model.modules():
        if not isinstance(mod, bn_types):
            continue
        for param in (mod.weight, mod.bias):
            if param is not None:
                param.requires_grad = trainable
                # Count in both directions, so freezing no longer reports 0.
                bn_param_count += param.numel()
    status = "trainable" if trainable else "frozen"
    print(f"Set BatchNorm affine parameters {status}. Total BN params affected: {bn_param_count:,}")
    return model


print("LoRAConv2d class and injection function defined.")

# Smoke test: a freshly wrapped conv must keep the original output shape and,
# because lora_B starts at zero, reproduce the original conv's output exactly.
with torch.no_grad():  # no autograd graph needed for a check
    x_test = torch.randn(2, 64, 56, 56)
    conv_test = nn.Conv2d(64, 128, kernel_size=3, padding=1)
    lora_conv_test = LoRAConv2d(conv_test, rank=LORA_RANK, alpha=LORA_ALPHA, lora_dropout_p=LORA_DROPOUT)
    out_test = lora_conv_test(x_test)
    expected_test = conv_test(x_test)
assert out_test.shape == expected_test.shape, (out_test.shape, expected_test.shape)
assert torch.allclose(out_test, expected_test, atol=1e-6), "Fresh LoRAConv2d should match the wrapped conv"
print(f"Test LoRA Conv output shape: {out_test.shape}")


In [None]:
# Cell 8: DomainSpecificHead for ResNet

class DomainSpecificHead(nn.Module):
    """Single linear layer mapping backbone features to class logits.

    Args:
        in_features: Width of the incoming feature vectors.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=RESNET_EMBED_DIM, num_classes=NUM_CLASSES):
        super(DomainSpecificHead, self).__init__()
        self.fc = nn.Linear(in_features=in_features, out_features=num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits


print("DomainSpecificHead class defined.")


In [None]:
# Cell 9: DAD Module - Timestep Embedding & p_theta Denoiser MLP

class SinusoidalTimestepEmbedding(nn.Module):
    """Fixed (parameter-free) sinusoidal embedding of integer timesteps.

    With ``half = dim // 2`` frequencies ``f_i = max_period ** (-i / half)``, a
    timestep ``t`` maps to ``[cos(t*f_0), ..., cos(t*f_{half-1}),
    sin(t*f_0), ..., sin(t*f_{half-1})]``. One zero column is appended when
    ``dim`` is odd.

    Input: a scalar, a ``(B,)`` vector, or any tensor with ``B`` elements
    (e.g. ``(B, 1)``). Output: a float tensor of shape ``(B, dim)``, or
    ``(1, dim)`` for a scalar.
    """

    def __init__(self, dim, max_period=10000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period

    def forward(self, timesteps):
        # Flatten: scalars become (1,) and (B, 1) columns become (B,).
        timesteps = timesteps.reshape(-1)
        device = timesteps.device
        half_dim = self.dim // 2
        # Build the frequencies directly on the target device instead of on
        # the CPU followed by a host->device copy on every call.
        freqs = torch.exp(
            -math.log(self.max_period)
            * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=device)
            / half_dim
        )
        args = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding


class DAD_P_Theta_MLP(nn.Module):
    """Noise-prediction MLP ``eps_theta(x_k, k)`` operating on feature vectors.

    The timestep is encoded by a sinusoidal embedding followed by a two-layer
    MLP of width ``timestep_embed_dim * ts_embed_hidden_mult``. That encoding is
    concatenated to the noisy features, and a three-layer GELU MLP maps the
    result back to ``feature_dim``.

    Args:
        feature_dim: Width of the (noisy) feature vectors and of the output.
        timestep_embed_dim: Width of the sinusoidal timestep embedding.
        ts_embed_hidden_mult: Widening factor of the timestep encoder.
        mlp_hidden_dim: Hidden width of the denoiser; its first layer is 2x this.
    """

    def __init__(self, feature_dim=RESNET_EMBED_DIM,
                 timestep_embed_dim=DAD_P_THETA_TIMESTEP_EMBED_DIM,
                 ts_embed_hidden_mult=DAD_P_THETA_HIDDEN_DIM_MULT,
                 mlp_hidden_dim=DAD_P_THETA_MLP_HIDDEN_DIM):
        super().__init__()
        self.timestep_embed_dim_actual = timestep_embed_dim

        self.timestep_encoder = nn.Sequential(
            SinusoidalTimestepEmbedding(timestep_embed_dim),
            nn.Linear(timestep_embed_dim, timestep_embed_dim * ts_embed_hidden_mult),
            nn.GELU(),
            nn.Linear(timestep_embed_dim * ts_embed_hidden_mult, timestep_embed_dim * ts_embed_hidden_mult)
        )

        combined_dim = feature_dim + (timestep_embed_dim * ts_embed_hidden_mult)

        # Larger MLP for ResNet's 2048-dim features
        self.mlp = nn.Sequential(
            nn.Linear(combined_dim, mlp_hidden_dim * 2),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim * 2, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, feature_dim)
        )

    def forward(self, noisy_features, timesteps_long):
        """Predict the noise in ``noisy_features`` at diffusion step(s) ``timesteps_long``.

        Args:
            noisy_features: ``(B, feature_dim)`` noisy features.
            timesteps_long: Integer timesteps -- one per row, or a single
                scalar / 1-element tensor shared by the whole batch.

        Returns:
            ``(B, feature_dim)`` predicted noise.
        """
        timesteps_long = timesteps_long.to(noisy_features.device)
        ts_embedding = self.timestep_encoder(timesteps_long)

        # The sinusoidal embedding always yields a 2-D (T, D) tensor (a scalar
        # timestep gives T == 1). A single shared timestep therefore has to be
        # broadcast across the batch explicitly, or torch.cat below fails.
        if ts_embedding.ndim == 1:
            ts_embedding = ts_embedding.unsqueeze(0)
        if noisy_features.ndim > 1 and ts_embedding.shape[0] == 1 and noisy_features.shape[0] > 1:
            ts_embedding = ts_embedding.expand(noisy_features.shape[0], -1)

        combined_input = torch.cat([noisy_features, ts_embedding], dim=-1)
        predicted_noise = self.mlp(combined_input)
        return predicted_noise


print("DAD p_theta MLP denoiser and Timestep Embedding defined.")

# Smoke test: the denoiser returns one noise prediction per feature vector.
p_theta_test = DAD_P_Theta_MLP().to(DEVICE)
test_noisy_feat = torch.randn(BATCH_SIZE, RESNET_EMBED_DIM).to(DEVICE)
test_timesteps = torch.randint(0, DAD_K_STEPS, (BATCH_SIZE,), dtype=torch.long).to(DEVICE)
with torch.no_grad():  # shape check only; no autograd graph needed
    pred_noise_test = p_theta_test(test_noisy_feat, test_timesteps)
assert pred_noise_test.shape == test_noisy_feat.shape, pred_noise_test.shape
print(f"p_theta test output shape: {pred_noise_test.shape}")


In [None]:
# Cell 10: DAD Diffusion Utilities

def make_dad_schedule(beta_start=DAD_BETA_START, beta_end=DAD_BETA_END, num_steps=DAD_K_STEPS, device=DEVICE):
    """Build a linear beta (noise) schedule and its derived quantities.

    Returns:
        ``(betas, alphas, alphas_cumprod)``: float32 tensors of length
        ``num_steps`` on ``device``. ``betas`` rise linearly from ``beta_start``
        to ``beta_end``, ``alphas = 1 - betas``, and ``alphas_cumprod`` is the
        running product of ``alphas``.
    """
    betas = torch.linspace(beta_start, beta_end, steps=num_steps, dtype=torch.float32, device=device)
    alphas = 1.0 - betas
    return betas, alphas, alphas.cumprod(dim=0)


# Schedules live on DEVICE, except when DEVICE names CUDA on a machine without
# it (only possible if DEVICE was edited by hand); then they fall back to CPU.
_cuda_requested_but_missing = DEVICE.type == 'cuda' and not torch.cuda.is_available()
if _cuda_requested_but_missing:
    print("Warning: DEVICE is cuda but not available. DAD schedules on CPU.")
_sched_device = torch.device('cpu') if _cuda_requested_but_missing else DEVICE

DAD_BETAS, DAD_ALPHAS, DAD_ALPHAS_CUMPROD = make_dad_schedule(device=_sched_device)


def q_sample_dad(x_start, k_indices, noise=None, alphas_cumprod=None):
    """Forward-diffuse clean features to step ``k``.

    Computes ``x_k = sqrt(abar_k) * x_start + sqrt(1 - abar_k) * noise``.

    Args:
        x_start: Clean features. When ``k_indices`` is a vector, dim 0 is the batch.
        k_indices: Integer step(s): a scalar, or one index per row of ``x_start``.
        noise: Noise shaped like ``x_start``; standard normal noise is drawn if ``None``.
        alphas_cumprod: Cumulative-product schedule to index. Defaults to the
            global ``DAD_ALPHAS_CUMPROD``.

    Returns:
        The noised tensor, on ``x_start``'s device.
    """
    if noise is None:
        noise = torch.randn_like(x_start)
    if alphas_cumprod is None:
        alphas_cumprod = DAD_ALPHAS_CUMPROD

    k_indices = k_indices.long().to(alphas_cumprod.device)
    # Move the gathered coefficients to the data's device, so CPU features
    # work with a GPU-resident schedule (and vice versa).
    abar = alphas_cumprod[k_indices].to(x_start.device)
    sqrt_alphas_cumprod_t = torch.sqrt(abar)
    sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1.0 - abar)

    # One coefficient per sample: reshape (B,) -> (B, 1, ..., 1) for broadcasting.
    if x_start.ndim > 1 and sqrt_alphas_cumprod_t.ndim == 1:
        bshape = (-1,) + (1,) * (x_start.ndim - 1)
        sqrt_alphas_cumprod_t = sqrt_alphas_cumprod_t.view(bshape)
        sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod_t.view(bshape)

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise


def q_reconstruct_dad(x_k, epsilon_theta_hat, k_indices, alphas_cumprod=None):
    """Estimate clean features from ``x_k`` and a noise prediction.

    Inverts the forward process: ``x0_hat = (x_k - sqrt(1 - abar_k) * eps_hat) / sqrt(abar_k)``.

    Args:
        x_k: Noisy features. When ``k_indices`` is a vector, dim 0 is the batch.
        epsilon_theta_hat: Predicted noise, broadcastable to ``x_k``.
        k_indices: Integer step(s): a scalar, or one index per row of ``x_k``.
        alphas_cumprod: Cumulative-product schedule to index. Defaults to the
            global ``DAD_ALPHAS_CUMPROD``.

    Returns:
        ``x0_hat``, on ``x_k``'s device.
    """
    if alphas_cumprod is None:
        alphas_cumprod = DAD_ALPHAS_CUMPROD

    k_indices = k_indices.long().to(alphas_cumprod.device)
    # Move the gathered coefficients to the data's device (mixed-device safe).
    abar = alphas_cumprod[k_indices].to(x_k.device)
    sqrt_alphas_cumprod_t = torch.sqrt(abar)
    sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1.0 - abar)

    # One coefficient per sample: reshape (B,) -> (B, 1, ..., 1) for broadcasting.
    if x_k.ndim > 1 and sqrt_alphas_cumprod_t.ndim == 1:
        bshape = (-1,) + (1,) * (x_k.ndim - 1)
        sqrt_alphas_cumprod_t = sqrt_alphas_cumprod_t.view(bshape)
        sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod_t.view(bshape)

    x0_hat = (x_k - sqrt_one_minus_alphas_cumprod_t * epsilon_theta_hat) / sqrt_alphas_cumprod_t
    return x0_hat


print("DAD diffusion utilities defined.",
      f"DAD schedules created on device: {DAD_BETAS.device}",
      sep="\n")


In [None]:
# Cell 11: Utility Functions (EMA Update, Replay Buffer)

def update_ema_teacher_components(student_resnet, student_head, teacher_resnet, teacher_head, decay):
    """Blend student weights into the teacher in place: ``t <- decay * t + (1 - decay) * s``.

    Backbone parameters are paired positionally. Only pairs whose *student*
    parameter has ``requires_grad=True`` are blended; frozen ones are left
    as-is. Every head parameter is blended. Buffers (e.g. BatchNorm running
    statistics) are not touched.
    """
    blend = 1 - decay
    with torch.no_grad():
        backbone_pairs = zip(student_resnet.parameters(), teacher_resnet.parameters())
        for s_param, t_param in backbone_pairs:
            if not s_param.requires_grad:
                continue
            t_param.data.mul_(decay).add_(s_param.data, alpha=blend)

        head_pairs = zip(student_head.parameters(), teacher_head.parameters())
        for s_param, t_param in head_pairs:
            t_param.data.mul_(decay).add_(s_param.data, alpha=blend)


class ExperienceReplayBuffer:
    """Fixed-capacity FIFO ring buffer of ``(image, label)`` samples stored on the CPU.

    Once the buffer is full, each new sample overwrites the oldest slot.
    ``sample`` draws distinct stored samples uniformly at random (numpy's
    global RNG) and moves them to ``device``.

    Args:
        buffer_size: Maximum number of stored samples. A value ``<= 0`` gives a
            buffer that stores nothing.
        device: Device that sampled batches are moved to.
    """

    def __init__(self, buffer_size, device=DEVICE):
        self.buffer_size = buffer_size
        self.device = device
        self.buffer_images = []   # one CPU tensor per stored sample
        self.buffer_labels = []   # matching CPU label tensors
        self.position = 0         # slot overwritten next once the buffer is full

    def add(self, images_tensor, labels_tensor):
        """Store every sample of a batch, overwriting the oldest entries when full."""
        if self.buffer_size <= 0:
            # Nothing can be stored, and `% self.buffer_size` below would raise.
            return
        # detach(): never keep an autograd graph alive through the buffer.
        images_cpu = images_tensor.detach().cpu()
        labels_cpu = labels_tensor.detach().cpu()
        for i in range(images_cpu.size(0)):
            # clone(): each slot owns its memory, instead of holding a view that
            # would pin the whole incoming batch's storage.
            img = images_cpu[i].clone()
            lbl = labels_cpu[i].clone()
            if len(self.buffer_images) < self.buffer_size:
                self.buffer_images.append(img)
                self.buffer_labels.append(lbl)
            else:
                self.buffer_images[self.position] = img
                self.buffer_labels[self.position] = lbl
            self.position = (self.position + 1) % self.buffer_size

    def sample(self, batch_size):
        """Draw up to ``batch_size`` distinct stored samples.

        Returns:
            ``(images, labels)`` on ``self.device``, with labels as int64. Fewer
            than ``batch_size`` samples are returned when fewer are stored.
            Returns ``(None, None)`` when the buffer is empty or
            ``batch_size <= 0``.
        """
        n_stored = len(self.buffer_images)
        if n_stored == 0 or batch_size <= 0:
            return None, None
        indices = np.random.choice(n_stored, min(batch_size, n_stored), replace=False)

        sampled_images = torch.stack([self.buffer_images[i] for i in indices]).to(self.device)
        sampled_labels = torch.stack([self.buffer_labels[i] for i in indices]).to(self.device)
        return sampled_images, sampled_labels.long()

    def __len__(self):
        return len(self.buffer_images)


print("Utility functions defined.")


# Part 3 - Source Domain (Art) Supervised Training


In [None]:
# Cell 12: Source Domain - Dataset & DataLoader

print(f"\n--- Preparing Source Domain: {SOURCE_DOMAIN_NAME} ---")

# Fail fast, before creating directories or scanning datasets.
if GLOBAL_CLASS_TO_IDX is None:
    raise RuntimeError("GLOBAL_CLASS_TO_IDX is not defined. Cannot proceed.")

source_domain_model_save_dir = os.path.join(MODEL_SAVE_DIR_BASE, SOURCE_DOMAIN_NAME)
os.makedirs(source_domain_model_save_dir, exist_ok=True)

source_train_dataset = OfficeHomeDomainDataset(
    DATA_DIR, SOURCE_DOMAIN_NAME,
    transform=train_transform_strong,
    split_type='train',
    class_to_idx_mapping=GLOBAL_CLASS_TO_IDX
)
source_val_dataset = OfficeHomeDomainDataset(
    DATA_DIR, SOURCE_DOMAIN_NAME,
    transform=val_test_transform_weak,
    split_type='val',
    class_to_idx_mapping=GLOBAL_CLASS_TO_IDX
)

# Pinned host memory only speeds up host->GPU copies; on CPU it just warns.
_pin_memory = DEVICE.type == 'cuda'

source_train_loader = DataLoader(
    source_train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=_pin_memory, drop_last=True
)
source_val_loader = DataLoader(
    source_val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=_pin_memory
)

# With drop_last=True, a split smaller than one batch yields zero batches,
# which would make training a silent no-op. Report it clearly instead.
if len(source_train_loader) == 0:
    raise RuntimeError(
        f"Source train split for '{SOURCE_DOMAIN_NAME}' has {len(source_train_dataset)} images, "
        f"fewer than BATCH_SIZE={BATCH_SIZE}; no training batches would be produced."
    )

print(f"Source domain '{SOURCE_DOMAIN_NAME}': Train size {len(source_train_dataset)}, Val size {len(source_val_dataset)}")
print(f"Train loader batches: {len(source_train_loader)}, Val loader batches: {len(source_val_loader)}")


In [None]:
# Cell 13: Source Domain - Model Instantiation & Optimizer

# 1. Load base frozen ResNet-50
base_resnet_source_train = load_frozen_resnet_backbone(device=DEVICE)

# 2. LoRA-adapted deep copy for the source domain (the frozen base stays unmodified)
source_resnet_lora = copy.deepcopy(base_resnet_source_train)
source_resnet_lora = inject_lora_to_resnet(
    source_resnet_lora, rank=LORA_RANK, alpha=LORA_ALPHA, lora_dropout_p=LORA_DROPOUT
)
source_resnet_lora = set_batchnorm_affine_trainable(source_resnet_lora, trainable=True)
# New LoRA sub-modules may have been created on the CPU; move everything to DEVICE.
source_resnet_lora = source_resnet_lora.to(DEVICE)

# 3. Instantiate Source-Specific Head
source_head = DomainSpecificHead(in_features=RESNET_EMBED_DIM, num_classes=NUM_CLASSES).to(DEVICE)

# 4. Optimizer over what is trainable: LoRA adapters, BN affine params, and the head.
params_to_train_source = [p for p in source_resnet_lora.parameters() if p.requires_grad]
params_to_train_source.extend(source_head.parameters())

optimizer_source = optim.AdamW(params_to_train_source, lr=ART_LR_LORA_HEAD_BN, weight_decay=0.05)
criterion_source = nn.CrossEntropyLoss(label_smoothing=0.1)

# Polynomial (power 0.9) decay from the base LR down to 0 over all steps.
# max(1, ...) avoids a ZeroDivisionError for an empty loader. Clamping the
# base at 0 stops extra scheduler steps from yielding a complex-valued LR,
# since (negative float) ** 0.9 is complex in Python.
num_total_steps_source = max(1, ART_EPOCHS * len(source_train_loader))
scheduler_source = optim.lr_scheduler.LambdaLR(
    optimizer_source,
    lr_lambda=lambda step: max(0.0, 1 - step / num_total_steps_source) ** 0.9
)

grad_scaler_source = GradScaler(enabled=(DEVICE.type == 'cuda'))

print(f"Source expert models (LoRA ResNet, Head) instantiated for '{SOURCE_DOMAIN_NAME}'.")
trainable_params_count_source = sum(p.numel() for p in params_to_train_source)
print(f"Total trainable parameters for source expert: {trainable_params_count_source:,}")


In [None]:
# Cell 14: Source Domain - Training Loop

print(f"\n--- Training Source Expert ({SOURCE_DOMAIN_NAME}) ---")

# Best-checkpoint paths, built once and reused for saving and reloading.
source_lora_ckpt_path = os.path.join(
    source_domain_model_save_dir, f"{SOURCE_DOMAIN_NAME.lower()}_resnet_lora_best.pth")
source_head_ckpt_path = os.path.join(
    source_domain_model_save_dir, f"{SOURCE_DOMAIN_NAME.lower()}_head_best.pth")

# Start below any reachable accuracy so the first epoch always writes a
# checkpoint. With 0.0 and a strict `>`, a run whose validation accuracy never
# rose above 0 saved nothing, and the reload at the end crashed.
best_source_val_acc = -1.0

for epoch in range(ART_EPOCHS):
    source_resnet_lora.train()
    source_head.train()

    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    progress_bar = tqdm(source_train_loader, desc=f"Epoch {epoch+1}/{ART_EPOCHS} [{SOURCE_DOMAIN_NAME} Train]")
    for images, labels in progress_bar:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        optimizer_source.zero_grad()

        with autocast_ctx():
            features = get_resnet_features(source_resnet_lora, images)
            logits = source_head(features)
            loss = criterion_source(logits, labels)

        # AMP step: scale the loss, then step (skipped on inf/NaN gradients),
        # then adapt the scale factor.
        grad_scaler_source.scale(loss).backward()
        grad_scaler_source.step(optimizer_source)
        grad_scaler_source.update()
        scheduler_source.step()

        n_batch = images.size(0)
        running_loss += loss.item() * n_batch
        correct_predictions += (logits.argmax(dim=1) == labels).sum().item()
        total_samples += n_batch

        progress_bar.set_postfix(
            loss=loss.item(),
            acc=correct_predictions / total_samples,
            lr=optimizer_source.param_groups[0]['lr']
        )

    # max(..., 1) guards against an epoch with no batches.
    epoch_loss = running_loss / max(total_samples, 1)
    epoch_acc = correct_predictions / max(total_samples, 1)
    print(f"Epoch {epoch+1} Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}")

    # Validation
    source_resnet_lora.eval()
    source_head.eval()
    val_loss = 0.0
    val_correct_predictions = 0
    val_total_samples = 0

    with torch.no_grad():
        for images, labels in tqdm(source_val_loader, desc=f"Epoch {epoch+1}/{ART_EPOCHS} [{SOURCE_DOMAIN_NAME} Val]", leave=False):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            with autocast_ctx():
                features = get_resnet_features(source_resnet_lora, images)
                logits = source_head(features)
                loss = criterion_source(logits, labels)

            n_batch = images.size(0)
            val_loss += loss.item() * n_batch
            val_correct_predictions += (logits.argmax(dim=1) == labels).sum().item()
            val_total_samples += n_batch

    # max(..., 1) keeps an empty validation split from raising ZeroDivisionError.
    epoch_val_loss = val_loss / max(val_total_samples, 1)
    epoch_val_acc = val_correct_predictions / max(val_total_samples, 1)
    print(f"Epoch {epoch+1} Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

    if epoch_val_acc > best_source_val_acc:
        best_source_val_acc = epoch_val_acc
        torch.save(source_resnet_lora.state_dict(), source_lora_ckpt_path)
        torch.save(source_head.state_dict(), source_head_ckpt_path)
        print(f"    -> New best Val Acc: {best_source_val_acc:.4f}. Models saved.")

print(f"\nSource domain '{SOURCE_DOMAIN_NAME}' training finished. Best Val Acc: {best_source_val_acc:.4f}")

# Load best models. map_location lets checkpoints written on another device
# load onto DEVICE.
if os.path.exists(source_lora_ckpt_path) and os.path.exists(source_head_ckpt_path):
    source_resnet_lora.load_state_dict(torch.load(source_lora_ckpt_path, map_location=DEVICE))
    source_head.load_state_dict(torch.load(source_head_ckpt_path, map_location=DEVICE))
    print("Best source expert models loaded and set to eval mode.")
else:
    print("Warning: no best checkpoint found (was ART_EPOCHS 0?); keeping current weights.")
source_resnet_lora.eval()
source_head.eval()


In [None]:
# Cell 15: Baseline Validation of Source Expert on All Domains
#
# Evaluates the source-domain expert (LoRA ResNet + head) on the validation
# split of every domain in ALL_TRAINABLE_DOMAIN_NAMES, as a no-adaptation
# baseline for the continual-adaptation results.
#
# `baseline_accuracies_source_expert` maps domain -> accuracy in percent, with
# two sentinel values:
#   -1.0  the validation dataset could not be built (printed as ERROR);
#    0.0  the validation split was empty / yielded no samples.
# Empty domains are excluded from the target-domain average so a missing split
# is not reported as a genuine 0% score.


def _evaluate_source_expert_on_loader(resnet, head, loader, criterion, desc):
    """Evaluate ``head(get_resnet_features(resnet, x))`` on `loader` without gradients.

    Args:
        resnet: backbone passed to ``get_resnet_features``.
        head: classification head applied to the extracted features.
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss function called as ``criterion(logits, labels)``.
        desc: tqdm progress-bar description.

    Returns:
        ``(mean_loss, accuracy_fraction, num_samples)``; loss and accuracy are
        ``None`` when the loader yields no samples.
    """
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc, leave=False):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            with autocast_ctx():
                features = get_resnet_features(resnet, images)
                logits = head(features)
                loss = criterion(logits, labels)
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            _, preds = torch.max(logits, 1)
            # Accumulate as a plain Python int instead of mixing int and tensor counters.
            total_correct += (preds == labels).sum().item()
            total_samples += batch_size
    if total_samples == 0:
        return None, None, 0
    return total_loss / total_samples, total_correct / total_samples, total_samples


print(f"\n\n--- Baseline Validation of Source Expert ('{SOURCE_DOMAIN_NAME}') on All Domains ---")

if 'source_resnet_lora' not in globals() or source_resnet_lora is None or \
   'source_head' not in globals() or source_head is None:
    print(f"Warning: Source expert ({SOURCE_DOMAIN_NAME}) models not available. Skipping baseline validation.")
else:
    source_resnet_lora.eval()
    source_head.eval()

    baseline_accuracies_source_expert = {}
    # Domains whose validation split produced no samples; excluded from the average.
    empty_baseline_domains = set()
    criterion_val_baseline = nn.CrossEntropyLoss(label_smoothing=0.1)

    for domain_name_eval in ALL_TRAINABLE_DOMAIN_NAMES:
        print(f"\n  Validating Source Expert ('{SOURCE_DOMAIN_NAME}') on '{domain_name_eval}' domain...")

        try:
            val_dataset_current_domain = OfficeHomeDomainDataset(
                DATA_DIR, domain_name_eval,
                transform=val_test_transform_weak,
                split_type='val',
                class_to_idx_mapping=GLOBAL_CLASS_TO_IDX
            )
            if len(val_dataset_current_domain) == 0:
                print(f"    Warning: Validation dataset for '{domain_name_eval}' is empty. Skipping.")
                baseline_accuracies_source_expert[domain_name_eval] = 0.0
                empty_baseline_domains.add(domain_name_eval)
                continue

            val_loader_current_domain = DataLoader(
                val_dataset_current_domain,
                batch_size=BATCH_SIZE,
                shuffle=False,
                num_workers=NUM_WORKERS,
                pin_memory=True
            )
        except Exception as e:
            print(f"    ERROR: Could not load validation dataset for '{domain_name_eval}': {e}. Skipping.")
            baseline_accuracies_source_expert[domain_name_eval] = -1.0
            continue

        epoch_val_loss_domain, epoch_val_acc_domain, val_total_samples_domain = _evaluate_source_expert_on_loader(
            source_resnet_lora, source_head, val_loader_current_domain,
            criterion_val_baseline, desc=f"Val on {domain_name_eval}"
        )

        if val_total_samples_domain > 0:
            baseline_accuracies_source_expert[domain_name_eval] = epoch_val_acc_domain * 100
            print(f"    '{domain_name_eval}' Val Loss: {epoch_val_loss_domain:.4f}, Val Acc: {epoch_val_acc_domain*100:.2f}%")
        else:
            print(f"    No samples processed for '{domain_name_eval}' validation.")
            baseline_accuracies_source_expert[domain_name_eval] = 0.0
            empty_baseline_domains.add(domain_name_eval)

    print("\n--- Source Expert Baseline Performance Summary ---")
    avg_baseline_acc = 0  # running sum of target-domain accuracies
    count_valid_domains = 0
    for domain, acc in baseline_accuracies_source_expert.items():
        if acc == -1.0:
            print(f"  Accuracy of '{SOURCE_DOMAIN_NAME}' expert on '{domain}': ERROR")
            continue
        print(f"  Accuracy of '{SOURCE_DOMAIN_NAME}' expert on '{domain}': {acc:.2f}%")
        # Only real target-domain measurements contribute to the average.
        if domain != SOURCE_DOMAIN_NAME and domain not in empty_baseline_domains:
            avg_baseline_acc += acc
            count_valid_domains += 1

    if count_valid_domains > 0:
        avg_target_acc = avg_baseline_acc / count_valid_domains
        print(f"  Average Target Domain Accuracy (Source Expert): {avg_target_acc:.2f}%")


In [None]:
# Cell 16: Prepare for Continual Learning - Store Experts
#
# State carried through the target-domain loop in the next cell:
#   adapted_experts            - receives one {'resnet', 'head'} entry per target domain
#   current_expert_*           - the expert the next adaptation starts from
#                                (initially the source expert)
#   base_resnet_frozen_global  - frozen backbone, reused by the domain-classifier head

adapted_experts = {}

# Independent CPU copies, so later adaptation never mutates the source expert objects.
current_expert_resnet, current_expert_head = (
    copy.deepcopy(model).cpu() for model in (source_resnet_lora, source_head)
)
current_expert_domain_name = SOURCE_DOMAIN_NAME

base_resnet_frozen_global = load_frozen_resnet_backbone(device=DEVICE)

global_replay_buffer = None  # placeholder; nothing in this cell populates it

print(f"Preparation for continual learning complete. Current expert: '{current_expert_domain_name}'.")
print(f"Base frozen ResNet ('base_resnet_frozen_global') loaded for future use.")


In [None]:
# Cell 17: Continual Adaptation Loop for Target Domains
#
# For each domain in TARGET_DOMAIN_NAMES_ORDERED, the current expert (the
# source expert first, then the most recently adapted one) is adapted to the
# next domain:
#   1. build loaders for the previous (labeled) and target (unlabeled) domains;
#   2. warm-start a new LoRA student from the previous expert's LoRA/BN weights
#      and head, and create an EMA teacher copy of it;
#   3. pre-train the DAD noise predictor p_theta on target features (LTR);
#   4. for DAD_K_STEPS outer steps, run ADAPT_MLS_R_ITER alternating
#      C->D (updates p_theta) / D->C (updates the student head) steps on the
#      previous domain, then one EMA pseudo-label + FixMatch + SHOT update on
#      the target domain;
#   5. validate periodically, checkpoint the best student, and early-stop.
# The best student becomes the new `current_expert_*` and is stored in
# `adapted_experts[<target domain>]`.

_BATCHNORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)


class _RequiresGradDisabled:
    """Context manager that temporarily sets ``requires_grad=False`` on a module's parameters.

    Gradients can still flow *through* the module to its inputs, but nothing is
    accumulated into the module's own ``.grad`` buffers. Each parameter's
    original flag is restored on exit, even if an exception is raised.
    """

    def __init__(self, module):
        self._params = list(module.parameters())
        self._saved_flags = []

    def __enter__(self):
        self._saved_flags = [p.requires_grad for p in self._params]
        for p in self._params:
            p.requires_grad_(False)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        for p, flag in zip(self._params, self._saved_flags):
            p.requires_grad_(flag)
        return False


def _next_cycling(batch_iter, loader):
    """Return ``(batch, iterator)``, restarting `loader` when `batch_iter` is exhausted."""
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


def _lora_bn_warm_start_state(prev_resnet, new_resnet):
    """Select LoRA and BatchNorm-affine tensors of `prev_resnet` that can be loaded into `new_resnet`.

    A key is kept when it is a LoRA tensor (``'lora_'`` in the name) or a BN
    weight/bias, and `new_resnet` has the same key with the same shape. BN
    layers are identified by module type as well as by a ``'bn'`` name
    substring, so BN layers without "bn" in their name (e.g. torchvision
    ResNet's ``downsample.1``) are carried over too. BN running statistics are
    not copied.
    """
    prev_state = prev_resnet.state_dict()
    new_state = new_resnet.state_dict()
    bn_affine_keys = {
        f"{module_name}.{param_name}" if module_name else param_name
        for module_name, module in prev_resnet.named_modules()
        if isinstance(module, _BATCHNORM_TYPES)
        for param_name in ("weight", "bias")
    }

    def _wanted(key):
        is_named_bn_affine = 'bn' in key.lower() and key.endswith(('.weight', '.bias'))
        return 'lora_' in key or key in bn_affine_keys or is_named_bn_affine

    return {
        key: value
        for key, value in prev_state.items()
        if _wanted(key) and key in new_state and new_state[key].shape == value.shape
    }


def _evaluate_on_loader(resnet, head, loader, criterion):
    """Return ``(mean_loss, accuracy_fraction, num_samples)`` of ``head(features(resnet, x))`` on `loader`.

    Runs under ``torch.no_grad()`` and ``autocast_ctx()``; callers put the
    modules in eval mode first. Loss and accuracy are 0.0 for an empty loader.
    """
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.no_grad():
        for images_val, labels_val in loader:
            images_val, labels_val = images_val.to(DEVICE), labels_val.to(DEVICE)
            with autocast_ctx():
                features_val = get_resnet_features(resnet, images_val)
                logits_val = head(features_val)
                loss_val = criterion(logits_val, labels_val)
            total_loss += loss_val.item() * images_val.size(0)
            _, preds_val = torch.max(logits_val, 1)
            total_correct += (preds_val == labels_val).sum().item()
            total_seen += images_val.size(0)
    if total_seen == 0:
        return 0.0, 0.0, 0
    return total_loss / total_seen, total_correct / total_seen, total_seen


for target_domain_idx, target_domain_name_current_loop in enumerate(TARGET_DOMAIN_NAMES_ORDERED):

    TARGET_DOMAIN_NAME_CURRENT = target_domain_name_current_loop
    # We always adapt *from* the current expert's domain: the source domain on
    # the first iteration, then the previously adapted target domain.
    PREVIOUS_DOMAIN_NAME = current_expert_domain_name

    print(f"\n\n=== Starting Adaptation: {PREVIOUS_DOMAIN_NAME} -> {TARGET_DOMAIN_NAME_CURRENT} ===")

    target_model_save_dir = os.path.join(MODEL_SAVE_DIR_BASE, TARGET_DOMAIN_NAME_CURRENT)
    os.makedirs(target_model_save_dir, exist_ok=True)
    best_resnet_ckpt_path = os.path.join(target_model_save_dir, f"{TARGET_DOMAIN_NAME_CURRENT.lower()}_resnet_lora_best.pth")
    best_head_ckpt_path = os.path.join(target_model_save_dir, f"{TARGET_DOMAIN_NAME_CURRENT.lower()}_head_best.pth")

    # --- 1. Datasets & DataLoaders ---
    print(f"  Loading datasets for Previous ('{PREVIOUS_DOMAIN_NAME}') and Target ('{TARGET_DOMAIN_NAME_CURRENT}')...")
    previous_domain_train_dataset = OfficeHomeDomainDataset(
        DATA_DIR, PREVIOUS_DOMAIN_NAME, transform=train_transform_strong,
        split_type='train', class_to_idx_mapping=GLOBAL_CLASS_TO_IDX
    )
    previous_domain_train_loader = DataLoader(
        previous_domain_train_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, pin_memory=True, drop_last=True
    )

    target_unlabeled_loader_for_ema_shot = DataLoader(
        OfficeHomeDomainDataset(DATA_DIR, TARGET_DOMAIN_NAME_CURRENT, transform=val_test_transform_weak,
                                split_type='train', class_to_idx_mapping=GLOBAL_CLASS_TO_IDX),
        batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True
    )
    fixmatch_target_dataset = FixMatchOfficeHomeDataset(
        DATA_DIR, TARGET_DOMAIN_NAME_CURRENT,
        transform_weak=val_test_transform_weak,
        transform_strong=train_transform_strong,
        split_type='train',
        class_to_idx_mapping=GLOBAL_CLASS_TO_IDX
    )
    fixmatch_target_loader = DataLoader(
        fixmatch_target_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, pin_memory=True, drop_last=True
    )
    target_val_dataset = OfficeHomeDomainDataset(
        DATA_DIR, TARGET_DOMAIN_NAME_CURRENT, transform=val_test_transform_weak,
        split_type='val', class_to_idx_mapping=GLOBAL_CLASS_TO_IDX
    )
    target_val_loader = DataLoader(target_val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    print(f"  Previous domain ('{PREVIOUS_DOMAIN_NAME}') train loader: {len(previous_domain_train_loader)} batches.")
    print(f"  Target domain ('{TARGET_DOMAIN_NAME_CURRENT}'):")
    print(f"    EMA/SHOT loader dataset size: {len(target_unlabeled_loader_for_ema_shot.dataset)}")
    print(f"    FixMatch loader dataset size: {len(fixmatch_target_loader.dataset)}")
    print(f"    Val dataset size: {len(target_val_dataset)}")

    # --- 2. Model Initialization ---
    print(f"  Initializing Student and Teacher models for '{TARGET_DOMAIN_NAME_CURRENT}'...")
    student_resnet = copy.deepcopy(base_resnet_frozen_global)
    student_resnet = inject_lora_to_resnet(student_resnet, rank=LORA_RANK, alpha=LORA_ALPHA, lora_dropout_p=LORA_DROPOUT)

    # Warm start: carry the previous expert's LoRA and BN-affine weights over to the fresh student.
    load_dict_student_resnet_warm_start = _lora_bn_warm_start_state(current_expert_resnet.cpu(), student_resnet)
    if load_dict_student_resnet_warm_start:
        missing_keys, unexpected_keys = student_resnet.load_state_dict(load_dict_student_resnet_warm_start, strict=False)
        print(f"  Loaded LoRA/BN from previous expert '{current_expert_domain_name}'. Missing: {len(missing_keys)}, Unexpected: {len(unexpected_keys)}")
    else:
        print(f"  Warning: No LoRA/BN weights loaded from previous expert '{current_expert_domain_name}'.")

    student_resnet = set_batchnorm_affine_trainable(student_resnet, trainable=True)
    student_resnet = student_resnet.to(DEVICE)

    student_head = DomainSpecificHead(in_features=RESNET_EMBED_DIM, num_classes=NUM_CLASSES)
    student_head.load_state_dict(current_expert_head.cpu().state_dict())
    student_head = student_head.to(DEVICE)

    # The EMA teacher is a frozen copy of the student, updated only via update_ema_teacher_components.
    teacher_resnet = copy.deepcopy(student_resnet).to(DEVICE)
    teacher_head = copy.deepcopy(student_head).to(DEVICE)
    for param in teacher_resnet.parameters():
        param.requires_grad = False
    for param in teacher_head.parameters():
        param.requires_grad = False
    teacher_resnet.eval()
    teacher_head.eval()
    update_ema_teacher_components(student_resnet, student_head, teacher_resnet, teacher_head, 0.0)

    p_theta_dad = DAD_P_Theta_MLP().to(DEVICE)
    print(f"  Student, Teacher, and DAD p_theta models ready for '{TARGET_DOMAIN_NAME_CURRENT}'.")

    # --- 3. Optimizers ---
    params_to_train_student = [param for param in student_resnet.parameters() if param.requires_grad]
    params_to_train_student.extend(student_head.parameters())

    optimizer_student = optim.AdamW(params_to_train_student, lr=ADAPT_LR_LORA_HEAD_BN, weight_decay=0.05)
    optimizer_p_theta = optim.SGD(p_theta_dad.parameters(), lr=ADAPT_LR_P_THETA, momentum=0.9, weight_decay=4.5e-3)

    grad_scaler_adapt = GradScaler(enabled=(DEVICE.type == 'cuda'))
    criterion_adapt_ce = nn.CrossEntropyLoss(label_smoothing=0.1)
    criterion_adapt_mse = nn.MSELoss()
    print("  Optimizers created.")

    # --- 4. DAD LTR Pre-training ---
    # p_theta learns to predict the noise added to (frozen) student features of target images.
    print(f"  Starting DAD LTR Pre-training for p_theta ({PREVIOUS_DOMAIN_NAME} -> {TARGET_DOMAIN_NAME_CURRENT})...")
    p_theta_dad.train()
    student_resnet.eval()
    student_head.eval()
    for ltr_epoch in range(ADAPT_LTR_EPOCHS):
        ltr_running_loss = 0.0
        ltr_samples_seen = 0
        progress_bar_ltr = tqdm(target_unlabeled_loader_for_ema_shot, desc=f"LTR Epoch {ltr_epoch+1}/{ADAPT_LTR_EPOCHS} for {TARGET_DOMAIN_NAME_CURRENT}", leave=False)
        for target_images_weak, _ in progress_bar_ltr:
            target_images_weak = target_images_weak.to(DEVICE)
            optimizer_p_theta.zero_grad()
            with torch.no_grad():
                F_T_cls = get_resnet_features(student_resnet, target_images_weak).to(torch.float32)
            t_indices = torch.randint(0, DAD_K_STEPS, (F_T_cls.size(0),), device=DEVICE, dtype=torch.long)
            noise = torch.randn_like(F_T_cls)
            F_T_noisy = q_sample_dad(F_T_cls, t_indices, noise)
            predicted_noise = p_theta_dad(F_T_noisy, t_indices)
            loss_ltr = criterion_adapt_mse(predicted_noise, noise)
            loss_ltr.backward()
            optimizer_p_theta.step()
            ltr_running_loss += loss_ltr.item() * target_images_weak.size(0)
            ltr_samples_seen += target_images_weak.size(0)
            progress_bar_ltr.set_postfix(ltr_loss=loss_ltr.item())
        # Average over samples actually processed (drop_last=True skips the final partial batch).
        avg_ltr_loss = ltr_running_loss / ltr_samples_seen if ltr_samples_seen > 0 else 0
        print(f"  LTR Epoch {ltr_epoch+1} ({TARGET_DOMAIN_NAME_CURRENT}) Avg Loss: {avg_ltr_loss:.4f}")
    print(f"  DAD LTR Pre-training for '{TARGET_DOMAIN_NAME_CURRENT}' finished.")

    # --- 5. Main Adaptation Loop ---
    print(f"  Starting Main Adaptation Loop for '{TARGET_DOMAIN_NAME_CURRENT}'...")

    best_target_val_acc = 0.0
    val_loss_at_best_acc = float('inf')
    patience_counter_adapt = 0
    # Only reload a checkpoint written during *this* run, never a stale file
    # left in target_model_save_dir by an earlier notebook execution.
    best_checkpoint_saved_this_run = False
    # Validate ~10 times per domain; max(1, ...) avoids ZeroDivisionError when DAD_K_STEPS < 10.
    val_interval_dad_steps = max(1, DAD_K_STEPS // 10)

    # The student scheduler steps ADAPT_MLS_R_ITER times (D->C) + once (target update) per DAD step.
    num_total_adapt_optimizer_steps_current = DAD_K_STEPS * (ADAPT_MLS_R_ITER + 1)
    scheduler_student_adapt = optim.lr_scheduler.CosineAnnealingLR(
        optimizer_student, T_max=num_total_adapt_optimizer_steps_current, eta_min=1e-6
    )

    source_iter = iter(previous_domain_train_loader)
    target_iter_ema_shot = iter(target_unlabeled_loader_for_ema_shot)
    fixmatch_iter = iter(fixmatch_target_loader)
    outer_progress_bar = tqdm(range(DAD_K_STEPS), desc=f"DAD Steps for {TARGET_DOMAIN_NAME_CURRENT}")

    for k_dad_step_idx in outer_progress_bar:
        loss_cd = None
        loss_dc_val = None
        for mls_iter_idx in range(ADAPT_MLS_R_ITER):
            (source_images, source_labels), source_iter = _next_cycling(source_iter, previous_domain_train_loader)
            source_images, source_labels = source_images.to(DEVICE), source_labels.to(DEVICE)

            # --- C->D: train p_theta so its reconstructions of noised source features classify correctly ---
            p_theta_dad.train()
            student_resnet.eval()
            student_head.eval()
            with torch.no_grad():
                F_S_cls = get_resnet_features(student_resnet, source_images).to(torch.float32)
            t_k_current_long = torch.full((F_S_cls.size(0),), k_dad_step_idx, device=DEVICE, dtype=torch.long)
            noise_cd = torch.randn_like(F_S_cls)
            F_S_cls_noisy_cd = q_sample_dad(F_S_cls, t_k_current_long, noise_cd)
            predicted_noise_cd = p_theta_dad(F_S_cls_noisy_cd, t_k_current_long)
            # The reconstruction must stay in the autograd graph so loss_cd reaches p_theta;
            # under no_grad, optimizer_p_theta.step() would only apply momentum/weight decay.
            # Assumes q_reconstruct_dad is differentiable w.r.t. the predicted noise — TODO confirm.
            x0_hat_cd = q_reconstruct_dad(F_S_cls_noisy_cd, predicted_noise_cd, t_k_current_long)
            # Head frozen (gradient flows through it into p_theta only) and run in fp32,
            # because this backward pass is not protected by a GradScaler.
            with _RequiresGradDisabled(student_head):
                logits_for_ptheta_update = student_head(x0_hat_cd)
                loss_cd = criterion_adapt_ce(logits_for_ptheta_update, source_labels)
                if loss_cd.requires_grad:
                    optimizer_p_theta.zero_grad()
                    loss_cd.backward()
                    optimizer_p_theta.step()

            # --- D->C: train the student head on p_theta's (fixed) reconstructions ---
            p_theta_dad.eval()
            student_resnet.train()
            student_head.train()
            with torch.no_grad():
                predicted_noise_dc = p_theta_dad(F_S_cls_noisy_cd, t_k_current_long)
                x0_hat_dc = q_reconstruct_dad(F_S_cls_noisy_cd, predicted_noise_dc, t_k_current_long)
            with autocast_ctx():
                logits_for_student_update = student_head(x0_hat_dc)
                loss_dc_val = criterion_adapt_ce(logits_for_student_update, source_labels)
            optimizer_student.zero_grad()
            grad_scaler_adapt.scale(loss_dc_val).backward()
            grad_scaler_adapt.step(optimizer_student)
            grad_scaler_adapt.update()
            scheduler_student_adapt.step()

        if loss_dc_val is not None and loss_cd is not None:
            current_mls_postfix = {"MLS_D->C": f"{loss_dc_val.item():.3f}", "MLS_C->D": f"{loss_cd.item():.3f}"}
            outer_progress_bar.set_postfix(current_mls_postfix)

        # EMA, FixMatch, SHOT updates
        student_resnet.train()
        student_head.train()
        (target_images_for_ema_shot, _), target_iter_ema_shot = _next_cycling(
            target_iter_ema_shot, target_unlabeled_loader_for_ema_shot)
        target_images_for_ema_shot = target_images_for_ema_shot.to(DEVICE)
        (target_images_weak_fixmatch, target_images_strong_fixmatch), fixmatch_iter = _next_cycling(
            fixmatch_iter, fixmatch_target_loader)
        target_images_weak_fixmatch = target_images_weak_fixmatch.to(DEVICE)
        target_images_strong_fixmatch = target_images_strong_fixmatch.to(DEVICE)

        optimizer_student.zero_grad()
        loss_pl_val_ema = torch.tensor(0.0, device=DEVICE)
        loss_fixmatch_val_calc = torch.tensor(0.0, device=DEVICE)
        head_loss_terms = []
        did_backward_this_step = False
        # Pseudo-label threshold ramps linearly from START to END over the DAD steps.
        current_progress_ratio = k_dad_step_idx / max(1, DAD_K_STEPS - 1)
        current_pseudo_label_thresh = PSEUDO_LABEL_THRESHOLD_START + \
                                     (PSEUDO_LABEL_THRESHOLD_END - PSEUDO_LABEL_THRESHOLD_START) * current_progress_ratio

        # (a) EMA-teacher pseudo-labels on the weakly augmented target batch.
        num_pseudo_labels_ema = 0
        with torch.no_grad(), autocast_ctx():
            teacher_resnet.eval()
            teacher_head.eval()
            F_T_teacher = get_resnet_features(teacher_resnet, target_images_for_ema_shot)
            logits_teacher = teacher_head(F_T_teacher)
            probs_teacher = F.softmax(logits_teacher, dim=1)
            max_confidence_values, predicted_class_indices = torch.max(probs_teacher, dim=1)
            confidence_mask_ema = max_confidence_values >= current_pseudo_label_thresh
            pseudo_labels_for_loss = predicted_class_indices
        if confidence_mask_ema.any():
            num_pseudo_labels_ema = confidence_mask_ema.sum().item()
            with autocast_ctx():
                selected_target_images_ema = target_images_for_ema_shot[confidence_mask_ema]
                F_T_student_ema = get_resnet_features(student_resnet, selected_target_images_ema)
                logits_student_on_pseudo = student_head(F_T_student_ema)
                loss_pl_val_ema = criterion_adapt_ce(logits_student_on_pseudo, pseudo_labels_for_loss[confidence_mask_ema])
            head_loss_terms.append(loss_pl_val_ema)

        # (b) FixMatch: confident student predictions on weak views supervise strong views.
        num_pseudo_labels_fixmatch = 0
        with torch.no_grad(), autocast_ctx():
            F_T_student_weak_fixmatch = get_resnet_features(student_resnet, target_images_weak_fixmatch)
            logits_weak_student_fixmatch = student_head(F_T_student_weak_fixmatch)
            probs_weak_student_fixmatch = F.softmax(logits_weak_student_fixmatch, dim=1)
            max_probs_weak_fixmatch, pseudo_labels_weak_fixmatch = torch.max(probs_weak_student_fixmatch, dim=1)
            confidence_mask_fixmatch = max_probs_weak_fixmatch >= FIXMATCH_CONF_THRESHOLD
        if confidence_mask_fixmatch.any():
            num_pseudo_labels_fixmatch = confidence_mask_fixmatch.sum().item()
            with autocast_ctx():
                selected_target_images_strong_fixmatch = target_images_strong_fixmatch[confidence_mask_fixmatch]
                F_T_student_strong_fixmatch = get_resnet_features(student_resnet, selected_target_images_strong_fixmatch)
                logits_student_on_strong_fixmatch = student_head(F_T_student_strong_fixmatch)
                loss_fixmatch_val_calc = criterion_adapt_ce(logits_student_on_strong_fixmatch, pseudo_labels_weak_fixmatch[confidence_mask_fixmatch])
                loss_fixmatch_val_calc = FIXMATCH_LAMBDA * loss_fixmatch_val_calc
            head_loss_terms.append(loss_fixmatch_val_calc)

        loss_for_head_update = torch.tensor(0.0, device=DEVICE)
        if head_loss_terms:
            loss_for_head_update = sum(head_loss_terms)
            # The SHOT term below runs its own forward pass, so this graph need not be retained.
            grad_scaler_adapt.scale(loss_for_head_update).backward()
            did_backward_this_step = True

        # (c) SHOT information maximization on the target batch. The head is put in
        # eval mode and frozen, so only the backbone's trainable parameters get this gradient.
        _temp_head_training_mode_shot = student_head.training
        student_head.eval()
        with _RequiresGradDisabled(student_head):
            with autocast_ctx():
                F_T_student_shot = get_resnet_features(student_resnet, target_images_for_ema_shot)
                logits_student_shot = student_head(F_T_student_shot)
                probs_student_shot = F.softmax(logits_student_shot, dim=1)
                loss_cond_ent = - (probs_student_shot * torch.log(probs_student_shot + 1e-9)).sum(1).mean()
                mean_probs_shot = probs_student_shot.mean(0)
                loss_ent_max = - (mean_probs_shot * torch.log(mean_probs_shot + 1e-9)).sum()
                loss_shot_val_calculated = SHOT_LAMBDA_COND_ENT * loss_cond_ent + SHOT_LAMBDA_ENT_MAX * loss_ent_max
            if loss_shot_val_calculated.item() != 0:
                grad_scaler_adapt.scale(loss_shot_val_calculated).backward()
                did_backward_this_step = True
        student_head.train(_temp_head_training_mode_shot)

        # Step whenever any term produced gradients; a negative SHOT value still
        # carries a valid gradient and must not be silently discarded.
        if did_backward_this_step:
            grad_scaler_adapt.step(optimizer_student)
            grad_scaler_adapt.update()
        scheduler_student_adapt.step()
        update_ema_teacher_components(student_resnet, student_head, teacher_resnet, teacher_head, EMA_DECAY)

        # Validation at intervals
        if (k_dad_step_idx + 1) % val_interval_dad_steps == 0 or k_dad_step_idx == DAD_K_STEPS - 1:
            student_resnet.eval()
            student_head.eval()
            epoch_val_loss_target, epoch_val_acc_target, val_total_target_epoch = _evaluate_on_loader(
                student_resnet, student_head, target_val_loader, criterion_adapt_ce
            )
            print(f"  DAD Step {k_dad_step_idx+1}/{DAD_K_STEPS} - Target '{TARGET_DOMAIN_NAME_CURRENT}' Val Loss: {epoch_val_loss_target:.4f}, Val Acc: {epoch_val_acc_target:.4f}")

            if val_total_target_epoch > 0 and epoch_val_acc_target > best_target_val_acc:
                best_target_val_acc = epoch_val_acc_target
                val_loss_at_best_acc = epoch_val_loss_target
                torch.save(student_resnet.state_dict(), best_resnet_ckpt_path)
                torch.save(student_head.state_dict(), best_head_ckpt_path)
                best_checkpoint_saved_this_run = True
                print(f"    -> New best Val Acc for '{TARGET_DOMAIN_NAME_CURRENT}': {best_target_val_acc:.4f}. Models saved.")
                patience_counter_adapt = 0
            elif val_total_target_epoch > 0:
                patience_counter_adapt += 1
                print(f"    Val acc did not improve. Patience: {patience_counter_adapt}/{EARLY_STOPPING_PATIENCE_ADAPT}")

        if patience_counter_adapt >= EARLY_STOPPING_PATIENCE_ADAPT:
            print(f"  Early stopping triggered for '{TARGET_DOMAIN_NAME_CURRENT}' after {k_dad_step_idx+1} DAD steps.")
            break

    print(f"\nAdaptation to '{TARGET_DOMAIN_NAME_CURRENT}' finished. Best Val Acc: {best_target_val_acc:.4f}")

    if best_checkpoint_saved_this_run:
        student_resnet.load_state_dict(torch.load(best_resnet_ckpt_path, map_location=DEVICE))
        student_head.load_state_dict(torch.load(best_head_ckpt_path, map_location=DEVICE))
        print(f"Loaded best saved models for {TARGET_DOMAIN_NAME_CURRENT}.")
    student_resnet.eval()
    student_head.eval()

    adapted_experts[TARGET_DOMAIN_NAME_CURRENT] = {
        'resnet': copy.deepcopy(student_resnet).cpu(),
        'head': copy.deepcopy(student_head).cpu()
    }
    current_expert_resnet = copy.deepcopy(student_resnet).cpu()
    current_expert_head = copy.deepcopy(student_head).cpu()
    current_expert_domain_name = TARGET_DOMAIN_NAME_CURRENT
    print(f"Best models for '{TARGET_DOMAIN_NAME_CURRENT}' stored. Ready for next adaptation.")

print("\n\n=== All Target Domain Adaptations Complete ===")


# Part 5 - Domain Classifier Training


In [None]:
# Cell 18: Domain Classifier - Prepare Multi-Domain Dataset & DataLoader
#
# The domain classifier (DC) predicts which domain an image comes from; each
# trainable domain gets an integer label following ALL_TRAINABLE_DOMAIN_NAMES order.

print("\n\n=== Part 5: Domain Classifier Training ===")

num_total_domains_for_dc = len(ALL_TRAINABLE_DOMAIN_NAMES)
domain_idx_to_name_dc = dict(enumerate(ALL_TRAINABLE_DOMAIN_NAMES))
domain_name_to_idx_dc = {name: idx for idx, name in domain_idx_to_name_dc.items()}
print(f"Domain Classifier - Domain to Index Mapping: {domain_name_to_idx_dc}")


class MultiDomainDatasetForDC(Dataset):
    """Pool images from several OfficeHome domains, labeled by domain index.

    Each item is ``(image_tensor, domain_label)``, where ``domain_label`` is the
    ``torch.long`` index that `domain_to_idx_map` assigns to the image's domain.
    Per-image class labels are not returned: this dataset trains a *domain*
    classifier.

    Args:
        root_dir: dataset root forwarded to ``OfficeHomeDomainDataset``.
        all_domain_names: domains to include; names missing from
            `domain_to_idx_map` are skipped.
        domain_to_idx_map: domain name -> integer domain label.
        class_to_idx_overall_map: class mapping forwarded to ``OfficeHomeDomainDataset``.
        transform: callable applied to the RGB PIL image; if falsy, the image
            is converted with ``transforms.ToTensor()``.
        split_type: split forwarded to ``OfficeHomeDomainDataset`` (e.g. 'train', 'val').
    """

    def __init__(self, root_dir, all_domain_names, domain_to_idx_map,
                 class_to_idx_overall_map, transform, split_type='train'):
        self.images_paths = []
        self.domain_labels = []

        for domain_name in all_domain_names:
            domain_idx = domain_to_idx_map.get(domain_name)
            if domain_idx is None:
                continue

            # Only the per-domain path list is used here; load_pil=False presumably
            # skips eager image loading in OfficeHomeDomainDataset — TODO confirm.
            temp_domain_dataset = OfficeHomeDomainDataset(
                root_dir=root_dir, domain_name=domain_name,
                transform=None,
                split_type=split_type,
                class_to_idx_mapping=class_to_idx_overall_map,
                load_pil=False
            )
            self.images_paths.extend(temp_domain_dataset.images_paths)
            self.domain_labels.extend([domain_idx] * len(temp_domain_dataset.images_paths))
            print(f"  Added {len(temp_domain_dataset.images_paths)} images from '{domain_name}' for DC training.")

        self.transform = transform

    def __len__(self):
        """Number of pooled images across all included domains."""
        return len(self.images_paths)

    def __getitem__(self, idx):
        """Load image `idx` as RGB, transform it, and return it with its domain label."""
        img_path = self.images_paths[idx]
        # Image.open is lazy and keeps the file open until the image object is
        # garbage-collected; the context manager closes it as soon as the RGB copy
        # has been decoded, so DataLoader workers do not accumulate open files.
        with Image.open(img_path) as img:
            image_pil = img.convert('RGB')
        domain_label = self.domain_labels[idx]

        if self.transform:
            image_tensor = self.transform(image_pil)
        else:
            image_tensor = transforms.ToTensor()(image_pil)

        return image_tensor, torch.tensor(domain_label, dtype=torch.long)


# Domain-labeled pools over every trainable domain. Both splits use
# val_test_transform_weak; only the training loader shuffles and drops the
# last incomplete batch.
dc_train_dataset, dc_val_dataset = (
    MultiDomainDatasetForDC(
        DATA_DIR, ALL_TRAINABLE_DOMAIN_NAMES, domain_name_to_idx_dc,
        GLOBAL_CLASS_TO_IDX, transform=val_test_transform_weak, split_type=split_name
    )
    for split_name in ('train', 'val')
)

dc_train_loader = DataLoader(
    dc_train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    shuffle=True,
    drop_last=True,
)
dc_val_loader = DataLoader(
    dc_val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    shuffle=False,
)

print(f"Domain Classifier: Train size {len(dc_train_dataset)}, Val size {len(dc_val_dataset)}")


In [None]:
# Cell 19: Domain Classifier - Model, Optimizer, Training Loop
#
# Trains a DomainSpecificHead on features from the frozen base ResNet to predict
# which trainable domain an image comes from. The head with the best validation
# accuracy from *this* run is saved to dc_model_save_path and reloaded at the end.

if 'base_resnet_frozen_global' not in globals() or base_resnet_frozen_global is None:
    print("Re-loading base_resnet_frozen_global.")
    base_resnet_frozen_global = load_frozen_resnet_backbone(device=DEVICE)

domain_classifier_head = DomainSpecificHead(
    in_features=RESNET_EMBED_DIM,
    num_classes=num_total_domains_for_dc
).to(DEVICE)

optimizer_dc = optim.AdamW(domain_classifier_head.parameters(), lr=DC_HEAD_LR, weight_decay=0.01)
criterion_dc = nn.CrossEntropyLoss(label_smoothing=0.1)
grad_scaler_dc = GradScaler(enabled=(DEVICE.type == 'cuda'))
# One scheduler step per training batch, annealed over the whole run.
scheduler_dc = optim.lr_scheduler.CosineAnnealingLR(optimizer_dc, T_max=DC_HEAD_EPOCHS * len(dc_train_loader))

print(f"Training Domain Classifier Head for {DC_HEAD_EPOCHS} epochs...")
best_dc_val_acc = 0.0
dc_model_save_path = os.path.join(MODEL_SAVE_DIR_BASE, "domain_classifier_head_best.pth")
# Guards the final reload so a stale checkpoint from an earlier notebook run
# is never loaded in place of this run's model.
dc_best_saved_this_run = False

for epoch in range(DC_HEAD_EPOCHS):
    domain_classifier_head.train()
    base_resnet_frozen_global.eval()

    running_loss_dc = 0.0
    correct_preds_dc = 0
    total_samples_dc = 0

    progress_bar_dc = tqdm(dc_train_loader, desc=f"DC Epoch {epoch+1}/{DC_HEAD_EPOCHS} [Train]")
    for images, domain_labels_true in progress_bar_dc:
        images, domain_labels_true = images.to(DEVICE), domain_labels_true.to(DEVICE)

        optimizer_dc.zero_grad()
        with autocast_ctx():
            # The backbone is frozen, so no graph is needed for feature extraction.
            with torch.no_grad():
                features_dc = get_resnet_features(base_resnet_frozen_global, images)
            domain_logits = domain_classifier_head(features_dc)
            loss_dc = criterion_dc(domain_logits, domain_labels_true)

        grad_scaler_dc.scale(loss_dc).backward()
        grad_scaler_dc.step(optimizer_dc)
        grad_scaler_dc.update()
        scheduler_dc.step()

        batch_size = images.size(0)
        running_loss_dc += loss_dc.item() * batch_size
        _, preds_dc = torch.max(domain_logits, 1)
        correct_preds_dc += (preds_dc == domain_labels_true).sum().item()
        total_samples_dc += batch_size
        progress_bar_dc.set_postfix(loss=loss_dc.item(), acc=correct_preds_dc / total_samples_dc)

    # drop_last=True can leave the train loader empty for a tiny dataset; avoid dividing by zero.
    epoch_loss_dc = running_loss_dc / total_samples_dc if total_samples_dc > 0 else 0.0
    epoch_acc_dc = correct_preds_dc / total_samples_dc if total_samples_dc > 0 else 0.0
    print(f"DC Epoch {epoch+1} Train Loss: {epoch_loss_dc:.4f}, Train Acc: {epoch_acc_dc:.4f}")

    # Validation
    domain_classifier_head.eval()
    val_loss_dc_epoch = 0.0
    val_correct_dc = 0
    val_total_dc = 0
    with torch.no_grad():
        for images_val, domain_labels_val_true in tqdm(dc_val_loader, desc=f"DC Epoch {epoch+1}/{DC_HEAD_EPOCHS} [Val]", leave=False):
            images_val, domain_labels_val_true = images_val.to(DEVICE), domain_labels_val_true.to(DEVICE)
            with autocast_ctx():
                features_val_dc = get_resnet_features(base_resnet_frozen_global, images_val)
                domain_logits_val = domain_classifier_head(features_val_dc)
                loss_val_dc = criterion_dc(domain_logits_val, domain_labels_val_true)

            val_loss_dc_epoch += loss_val_dc.item() * images_val.size(0)
            _, preds_val_dc = torch.max(domain_logits_val, 1)
            val_correct_dc += (preds_val_dc == domain_labels_val_true).sum().item()
            val_total_dc += images_val.size(0)

    epoch_val_loss_dc = val_loss_dc_epoch / val_total_dc if val_total_dc > 0 else 0
    epoch_val_acc_dc = val_correct_dc / val_total_dc if val_total_dc > 0 else 0.0
    print(f"DC Epoch {epoch+1} Val Loss: {epoch_val_loss_dc:.4f}, Val Acc: {epoch_val_acc_dc:.4f}")

    if val_total_dc > 0 and epoch_val_acc_dc > best_dc_val_acc:
        best_dc_val_acc = epoch_val_acc_dc
        torch.save(domain_classifier_head.state_dict(), dc_model_save_path)
        dc_best_saved_this_run = True
        print(f"    -> New best DC Val Acc: {best_dc_val_acc:.4f}. DC Head saved.")

print(f"\nDomain Classifier training finished. Best Val Acc: {best_dc_val_acc:.4f}")
if dc_best_saved_this_run:
    domain_classifier_head.load_state_dict(torch.load(dc_model_save_path, map_location=DEVICE))
    print("Best Domain Classifier Head loaded.")
domain_classifier_head.eval()


# Part 6 - Robust Inference Pipeline


In [None]:
# Cell 20: Robust Inference - TTA Definitions & Expert Population

print("\n\n=== Part 6: Robust Inference Pipeline Setup & Evaluation ===")
print("\n--- Defining Test-Time Augmentations and Populating Task Experts ---")

# ImageNet channel statistics used to normalize every TTA view below.
_TTA_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_TTA_IMAGENET_STD = [0.229, 0.224, 0.225]


def _tta_view(*pil_ops):
    """Build one TTA pipeline: `pil_ops` on the PIL image, then Resize(IMAGE_SIZE), ToTensor, Normalize."""
    return transforms.Compose([
        *pil_ops,
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_TTA_IMAGENET_MEAN, std=_TTA_IMAGENET_STD),
    ])


# TTA Definitions
tta_lite_transforms = [
    val_test_transform_weak,
    _tta_view(transforms.RandomHorizontalFlip(p=1.0)),
]

tta_full_transforms_manual = [
    val_test_transform_weak,
    _tta_view(transforms.RandomHorizontalFlip(p=1.0)),
    # Resize before CenterCrop so the crop keeps the central 87.5% of the image.
    # Cropping a fixed IMAGE_SIZE*0.875-pixel window from the raw, arbitrarily
    # sized image would keep only a small patch of large photos.
    _tta_view(transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), transforms.CenterCrop(int(IMAGE_SIZE * 0.875))),
    _tta_view(transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)),
    _tta_view(transforms.RandomRotation(degrees=(-15, 15))),
]

NUM_TTA_LITE_AUGMENTATIONS = len(tta_lite_transforms)
NUM_TTA_FULL_AUGMENTATIONS = len(tta_full_transforms_manual)
print(f"TTA Lite augmentations: {NUM_TTA_LITE_AUGMENTATIONS}, Full TTA augmentations: {NUM_TTA_FULL_AUGMENTATIONS}")

# Populate all_task_experts: domain name -> {'resnet': ..., 'head': ...}.
# Every entry is a deep copy moved to DEVICE and put in eval mode, so the
# source expert and the adapted_experts originals are left untouched.
all_task_experts = {}
print("\nAttempting to populate all_task_experts...")


def _is_complete_expert(entry):
    """Return True when `entry` is a dict holding non-None 'resnet' and 'head' models."""
    return isinstance(entry, dict) and all(entry.get(part) is not None for part in ('resnet', 'head'))


# Add Source Expert (only when both of its parts exist in the notebook namespace).
_source_parts_present = all(
    globals().get(part_name) is not None for part_name in ('source_resnet_lora', 'source_head')
)
if not _source_parts_present:
    print(f"  Warning: Source expert for '{SOURCE_DOMAIN_NAME}' not fully available. Skipping.")
else:
    try:
        all_task_experts[SOURCE_DOMAIN_NAME] = {
            'resnet': copy.deepcopy(source_resnet_lora).to(DEVICE).eval(),
            'head': copy.deepcopy(source_head).to(DEVICE).eval()
        }
        print(f"  Successfully added Source Expert: '{SOURCE_DOMAIN_NAME}'")
    except Exception as e:
        print(f"  Error adding Source Expert '{SOURCE_DOMAIN_NAME}': {e}")

# Add Adapted Target Experts
_adapted_available = (
    'adapted_experts' in globals() and isinstance(adapted_experts, dict) and adapted_experts
)
if not _adapted_available:
    print("  Warning: `adapted_experts` dictionary not found or empty.")
else:
    print(f"  Found adapted_experts with keys: {list(adapted_experts.keys())}")
    for domain_name_adapted, expert_models_adapted in adapted_experts.items():
        if not _is_complete_expert(expert_models_adapted):
            print(f"    Warning: Incomplete models for adapted expert '{domain_name_adapted}'. Skipping.")
            continue
        try:
            all_task_experts[domain_name_adapted] = {
                'resnet': copy.deepcopy(expert_models_adapted['resnet']).to(DEVICE).eval(),
                'head': copy.deepcopy(expert_models_adapted['head']).to(DEVICE).eval()
            }
            print(f"    Successfully added Adapted Expert: '{domain_name_adapted}'")
        except Exception as e:
            print(f"    Error adding Adapted Expert '{domain_name_adapted}': {e}")

if all_task_experts:
    print(f"\nSuccessfully populated all_task_experts with {len(all_task_experts)} expert(s): {list(all_task_experts.keys())}")
else:
    print("\nCRITICAL WARNING: all_task_experts is EMPTY. Inference will fail.")

# Make sure the shared base backbone and domain-classifier head are on DEVICE in eval mode.
if globals().get('base_resnet_frozen_global') is None:
    print("CRITICAL WARNING: base_resnet_frozen_global not found for inference pipeline.")
else:
    base_resnet_frozen_global.to(DEVICE).eval()

if globals().get('domain_classifier_head') is None:
    print("CRITICAL WARNING: domain_classifier_head not found for inference pipeline.")
else:
    domain_classifier_head.to(DEVICE).eval()

print("Expert population and TTA definitions complete.")


In [None]:
# Cell 21: Robust Inference Pipeline Function Definition

def _select_topk_experts(domain_probs, domain_map_idx_to_name_dc, task_experts_dict, k_experts):
    """Choose up to ``k_experts`` most probable domains that have a registered expert.

    Args:
        domain_probs: 1-D tensor of domain-classifier probabilities (batch of one, squeezed).
        domain_map_idx_to_name_dc: maps an int domain index to a domain name.
        task_experts_dict: domain name -> {'resnet': ..., 'head': ...}.
        k_experts: maximum number of experts to return.

    Returns:
        ``(names, weights)``: selected domain names (most probable first) and a
        1-D tensor of their probabilities renormalised to sum to 1 (uniform when
        the selected probabilities sum to ~0). ``([], None)`` when no domain
        index in ``domain_probs`` maps to an available expert or ``k_experts`` <= 0.
    """
    # Rank only domains that actually have an expert: otherwise a highly ranked
    # domain without an expert would displace a usable one, and k could exceed
    # the number of classifier outputs (torch.topk would raise).
    candidate_indices = [
        idx for idx in range(domain_probs.numel())
        if domain_map_idx_to_name_dc.get(idx) in task_experts_dict
    ]
    num_selected = min(k_experts, len(candidate_indices))
    if num_selected <= 0:
        return [], None

    top_probs, top_positions = torch.topk(domain_probs[candidate_indices], num_selected, dim=-1)
    names = [domain_map_idx_to_name_dc[candidate_indices[pos]] for pos in top_positions.tolist()]
    total = top_probs.sum()
    if total > 1e-6:
        weights = top_probs / total
    else:
        weights = torch.full((num_selected,), 1.0 / num_selected, device=domain_probs.device)
    return names, weights


def _expert_tta_probs(image_pil, transforms_list, fallback_view, domain_names, weights,
                      task_experts_dict, num_total_classes, device):
    """Average a weighted mixture of expert predictions over test-time-augmented views.

    For each view, every selected expert contributes ``weight * softmax(logits)``;
    the per-view mixtures are summed and divided by the number of views.

    Args:
        image_pil: the input PIL image.
        transforms_list: sequence of callables mapping the PIL image to an
            unbatched image tensor (a batch dim is added here).
        fallback_view: already-batched, on-device tensor used as the single view
            when ``transforms_list`` is empty, so the result is never all zeros.
        domain_names: names of the selected experts.
        weights: mixture weights aligned with ``domain_names``.
        task_experts_dict: domain name -> {'resnet': ..., 'head': ...}.
        num_total_classes: length of the returned probability vector.
        device: device for the views and the accumulator.

    Returns:
        1-D tensor of length ``num_total_classes`` with the averaged probabilities.
    """
    num_views = len(transforms_list)
    if num_views == 0:
        views = [fallback_view]
        num_views = 1
    else:
        views = (transform(image_pil).unsqueeze(0).to(device) for transform in transforms_list)

    summed_probs = torch.zeros(num_total_classes, device=device)
    with torch.no_grad(), autocast_ctx():
        for view in views:
            for domain_name, weight in zip(domain_names, weights):
                expert = task_experts_dict[domain_name]
                features = get_resnet_features(expert['resnet'], view)
                logits = expert['head'](features)
                summed_probs += weight * F.softmax(logits.squeeze(0), dim=-1)
    return summed_probs / num_views


def robust_inference_pipeline(image_pil, base_resnet, dc_head, task_experts_dict,
                              domain_map_idx_to_name_dc,
                              num_total_classes=NUM_CLASSES,
                              tta_lite_transforms_list=tta_lite_transforms,
                              tta_full_transforms_list=tta_full_transforms_manual,
                              domain_confidence_thresh=INFER_DOMAIN_CONF_THRESH,
                              expert_confidence_thresh=INFER_EXPERT_CONF_THRESH,
                              stage2_expert_confidence_thresh=INFER_STAGE2_EXPERT_CONF_THRESH,
                              k_experts_for_avg=INFER_K_EXPERTS_FOR_AVG,
                              device=DEVICE):
    """Classify one PIL image with domain routing and escalating test-time augmentation.

    Stage 1: the frozen backbone + domain-classifier head predict the image's
    domain from the weak validation transform. If that domain is confident
    (>= ``domain_confidence_thresh``) and has an expert, that single expert
    classifies the weak view; a prediction >= ``expert_confidence_thresh`` is
    returned directly. Otherwise up to ``k_experts_for_avg`` experts are chosen
    among domains that have one, weighted by renormalised domain probability.

    Stage 2: the selected experts are averaged over ``tta_lite_transforms_list``;
    a result >= ``stage2_expert_confidence_thresh`` is returned.

    Stage 3: the selected experts are averaged over ``tta_full_transforms_list``
    and that result is returned unconditionally.

    An empty TTA list falls back to the weak view instead of yielding an
    all-zero distribution.

    Returns:
        ``(label_idx, confidence)`` as Python ``int``/``float``, or
        ``(None, -1.0)`` when required models are missing or no expert is usable.
    """
    if base_resnet is None or dc_head is None:
        print("Error: Base ResNet or Domain Classifier Head not provided.")
        return None, -1.0
    base_resnet.to(device).eval()
    dc_head.to(device).eval()
    if not task_experts_dict:
        print("Error: task_experts_dict is empty.")
        return None, -1.0

    # Stage 1: domain routing on the weakly transformed image.
    weak_view = val_test_transform_weak(image_pil).unsqueeze(0).to(device)
    with torch.no_grad(), autocast_ctx():
        base_features = get_resnet_features(base_resnet, weak_view)
        domain_logits = dc_head(base_features)
        domain_probs = F.softmax(domain_logits, dim=-1).squeeze(0)

    top_domain_prob, predicted_domain_idx_tensor = torch.max(domain_probs, dim=-1)
    predicted_domain_idx = predicted_domain_idx_tensor.item()
    predicted_domain_name = domain_map_idx_to_name_dc.get(predicted_domain_idx, f"UnknownDomain{predicted_domain_idx}")

    if top_domain_prob >= domain_confidence_thresh and predicted_domain_name in task_experts_dict:
        expert = task_experts_dict[predicted_domain_name]
        with torch.no_grad(), autocast_ctx():
            expert_features = get_resnet_features(expert['resnet'], weak_view)
            task_probs = F.softmax(expert['head'](expert_features), dim=-1).squeeze(0)
        top_task_prob, label_idx_tensor = torch.max(task_probs, dim=-1)
        if top_task_prob >= expert_confidence_thresh:
            return label_idx_tensor.item(), top_task_prob.item()
        # Confident domain but unsure expert: keep only that expert for the TTA stages.
        selected_names = [predicted_domain_name]
        selected_weights = torch.tensor([1.0], device=device)
    else:
        selected_names, selected_weights = _select_topk_experts(
            domain_probs, domain_map_idx_to_name_dc, task_experts_dict, k_experts_for_avg)
        if not selected_names:
            return None, -1.0

    # Stage 2: TTA-Lite.
    stage2_probs = _expert_tta_probs(image_pil, tta_lite_transforms_list, weak_view,
                                     selected_names, selected_weights, task_experts_dict,
                                     num_total_classes, device)
    stage2_conf, stage2_label = torch.max(stage2_probs, dim=-1)
    if stage2_conf >= stage2_expert_confidence_thresh:
        return stage2_label.item(), stage2_conf.item()

    # Stage 3: TTA-Full (final answer, no further threshold).
    stage3_probs = _expert_tta_probs(image_pil, tta_full_transforms_list, weak_view,
                                     selected_names, selected_weights, task_experts_dict,
                                     num_total_classes, device)
    stage3_conf, stage3_label = torch.max(stage3_probs, dim=-1)
    return stage3_label.item(), stage3_conf.item()


print("Robust inference pipeline function defined.")


In [None]:
# Cell 22: Evaluate Robust Inference Pipeline on Combined Validation Set


def summarize_robust_predictions(predictions, true_labels, domain_names, domain_order=()):
    """Tally overall and per-domain accuracy counts for the robust pipeline.

    Args:
        predictions: predicted class indices; ``-1`` marks a sample for which
            the pipeline returned no label.
        true_labels: ground-truth class indices aligned with ``predictions``.
        domain_names: source-domain name of each sample, aligned with ``predictions``.
        domain_order: domain names to pre-register, so domains without samples
            still appear (with zero counts) and the mapping follows this order.

    Returns:
        ``(overall, per_domain)``: ``overall`` is a dict with integer
        ``'correct'``, ``'total'`` and ``'failed'`` counts; ``per_domain`` maps
        each domain name to a dict of the same form.

    Raises:
        ValueError: if the three input sequences differ in length.

    Failed predictions (``-1``) are counted in ``total`` as incorrect, so pipeline
    failures lower accuracy instead of silently shrinking the denominator.
    """
    predictions, true_labels, domain_names = list(predictions), list(true_labels), list(domain_names)
    if not (len(predictions) == len(true_labels) == len(domain_names)):
        raise ValueError("predictions, true_labels and domain_names must have equal length.")

    def _empty_counts():
        return {'correct': 0, 'total': 0, 'failed': 0}

    overall = _empty_counts()
    per_domain = {name: _empty_counts() for name in domain_order}
    for pred_idx, true_idx, domain_name in zip(predictions, true_labels, domain_names):
        domain_counts = per_domain.setdefault(domain_name, _empty_counts())
        for counts in (overall, domain_counts):
            counts['total'] += 1
            if pred_idx == -1:
                counts['failed'] += 1
            elif pred_idx == true_idx:
                counts['correct'] += 1
    return overall, per_domain


def _accuracy_pct(counts):
    """Return correct/total as a percentage, or 0.0 when there are no samples."""
    return (counts['correct'] / counts['total'] * 100) if counts['total'] > 0 else 0.0


print("\n--- Evaluating Robust Inference Pipeline on Combined Validation Set ---")

prereq_missing = False
if 'domain_classifier_head' not in globals() or domain_classifier_head is None:
    print("ERROR: Domain Classifier Head not trained/loaded.")
    prereq_missing = True
if 'all_task_experts' not in globals() or not all_task_experts:
    print("ERROR: No task experts available.")
    prereq_missing = True
if 'base_resnet_frozen_global' not in globals() or base_resnet_frozen_global is None:
    print("ERROR: Base ResNet backbone not loaded.")
    prereq_missing = True
if 'GLOBAL_CLASS_TO_IDX' not in globals() or GLOBAL_CLASS_TO_IDX is None:
    print("ERROR: GLOBAL_CLASS_TO_IDX not defined.")
    prereq_missing = True
if 'domain_idx_to_name_dc' not in globals() or domain_idx_to_name_dc is None:
    print("ERROR: domain_idx_to_name_dc mapping not defined.")
    prereq_missing = True

if prereq_missing:
    print("Skipping robust inference pipeline evaluation due to missing prerequisites.")
else:
    class CombinedValDatasetForRobustEval(Dataset):
        """In-memory pool of PIL validation images from several domains.

        Items are ``(pil_image, label_tensor_long, domain_name)``. A domain that
        fails to load is skipped entirely; none of its images enter the pool.
        """

        def __init__(self, root_dir, all_domain_names_list, class_to_idx_map_overall, split_type='val'):
            self.images_pil = []
            self.true_class_labels = []
            self.original_domain_names = []

            for domain_name_iter in all_domain_names_list:
                try:
                    domain_val_dataset_temp = OfficeHomeDomainDataset(
                        root_dir=root_dir, domain_name=domain_name_iter,
                        transform=None,
                        split_type=split_type,
                        class_to_idx_mapping=class_to_idx_map_overall,
                        load_pil=True
                    )
                    # Collect into locals first so a failure midway through a
                    # domain does not leave a partial set of its images behind.
                    domain_images, domain_labels = [], []
                    for i in range(len(domain_val_dataset_temp)):
                        pil_img, class_lbl = domain_val_dataset_temp[i]
                        domain_images.append(pil_img)
                        domain_labels.append(class_lbl)
                except Exception as e:
                    print(f"  Warning: Error loading val data for '{domain_name_iter}': {e}. Skipping.")
                    continue

                self.images_pil.extend(domain_images)
                self.true_class_labels.extend(domain_labels)
                self.original_domain_names.extend([domain_name_iter] * len(domain_images))
                print(f"  Added {len(domain_images)} validation images from '{domain_name_iter}'.")

            if not self.images_pil:
                raise RuntimeError("No validation images found.")

        def __len__(self):
            return len(self.images_pil)

        def __getitem__(self, idx):
            # as_tensor accepts ints or existing tensors without a copy warning.
            return (self.images_pil[idx],
                    torch.as_tensor(self.true_class_labels[idx], dtype=torch.long),
                    self.original_domain_names[idx])

    try:
        combined_val_dataset_robust = CombinedValDatasetForRobustEval(
            root_dir=DATA_DIR,
            all_domain_names_list=ALL_TRAINABLE_DOMAIN_NAMES,
            class_to_idx_map_overall=GLOBAL_CLASS_TO_IDX,
            split_type='val'
        )
        print(f"Total combined validation images for robust eval: {len(combined_val_dataset_robust)}")
    except RuntimeError as e:
        print(f"Error creating combined validation dataset: {e}. Aborting.")
        combined_val_dataset_robust = None

    if combined_val_dataset_robust and len(combined_val_dataset_robust) > 0:
        all_predictions_robust = []
        all_true_labels_robust = []
        all_original_domains_robust = []

        for i in tqdm(range(len(combined_val_dataset_robust)), desc="Robust Pipeline Eval"):
            pil_image, true_class_label, original_domain = combined_val_dataset_robust[i]

            predicted_label_idx, confidence = robust_inference_pipeline(
                image_pil=pil_image,
                base_resnet=base_resnet_frozen_global,
                dc_head=domain_classifier_head,
                task_experts_dict=all_task_experts,
                domain_map_idx_to_name_dc=domain_idx_to_name_dc,
                device=DEVICE
            )
            # -1 marks "no prediction"; it is scored as incorrect below.
            all_predictions_robust.append(predicted_label_idx if predicted_label_idx is not None else -1)
            all_true_labels_robust.append(true_class_label.item())
            all_original_domains_robust.append(original_domain)

        print("\n--- Robust Inference Pipeline Evaluation Results ---")
        overall_stats_robust, domain_wise_stats_robust = summarize_robust_predictions(
            all_predictions_robust, all_true_labels_robust, all_original_domains_robust,
            domain_order=ALL_TRAINABLE_DOMAIN_NAMES,
        )
        overall_correct_robust = overall_stats_robust['correct']
        overall_total_robust = overall_stats_robust['total']
        overall_accuracy_robust = _accuracy_pct(overall_stats_robust)
        print(f"Overall Accuracy (Robust Pipeline): {overall_accuracy_robust:.2f}% ({overall_correct_robust}/{overall_total_robust})")
        if overall_stats_robust['failed']:
            print(f"  Note: {overall_stats_robust['failed']} sample(s) produced no prediction and were counted as incorrect.")
        print("\nDomain-wise Accuracies (Robust Pipeline):")
        for domain_name, stats in domain_wise_stats_robust.items():
            acc = _accuracy_pct(stats)
            print(f"  Domain '{domain_name}': {acc:.2f}% ({stats['correct']}/{stats['total']})")