In [None]:
# Cell 1
!pip install timm scikit-learn tqdm matplotlib Pillow

In [None]:
# Cell 2
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import timm
import os
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
import copy

# --- Configuration ---
# Dataset location and the timm identifier of the backbone to load.
DATA_DIR = "/kaggle/input/officehome/OfficeHome"
VIT_MODEL_NAME = 'vit_base_patch16_224'

# Input resolution, loader batch size, LoRA rank and classifier width.
IMAGE_SIZE = 224
BATCH_SIZE = 32
LORA_RANK = 16
NUM_CLASSES = 65

# Prefer the GPU whenever PyTorch can see one.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(f"Using device: {DEVICE}")
print(f"PyTorch version: {torch.__version__}")
if not torch.cuda.is_available():
    print("CUDA not available, running on CPU. This will be very slow for later phases.")
else:
    print(f"CUDA available. GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Cell 3
# Image pipelines: the training pipeline adds light augmentation, the
# validation/test pipeline only resizes and normalizes.
# NOTE(review): confirm these normalization statistics match the pretrained
# config of VIT_MODEL_NAME — TODO verify against the timm model's data config.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
_TARGET_SIZE = (IMAGE_SIZE, IMAGE_SIZE)

_train_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.95, 1.05)),
]

train_transform = transforms.Compose(
    [transforms.Resize(_TARGET_SIZE)]
    + _train_augmentations
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)),
    ]
)

val_test_transform = transforms.Compose(
    [
        transforms.Resize(_TARGET_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)),
    ]
)

In [None]:
# Cell 4
class OfficeHomeDomainDataset(Dataset):
    """One split (train / val / test) of a single domain folder.

    The expected layout is ``root_dir/domain_name/<class_name>/<image files>``.
    Class indices are assigned in sorted order of the class *directories*;
    stray files sitting next to the class folders are ignored so they cannot
    shift the label indices.

    Each class is shuffled with a private ``RandomState(random_seed)`` and then
    cut into consecutive train / val / test slices, so the three splits of the
    same domain are disjoint and reproducible. The global NumPy RNG state is
    left untouched.

    Args:
        root_dir: Directory that contains the domain folders.
        domain_name: Name of the domain sub-folder to read.
        transform: Optional callable applied to the RGB PIL image.
        split_ratios: ``(train, val, test)`` fractions. Only the first two are
            used; the test split is whatever remains. Counts are truncated
            with ``int()``, so very small classes may get an empty val split.
        split_type: ``'train'``, ``'val'`` or ``'test'``.
        random_seed: Seed for the per-class shuffle.

    Raises:
        ValueError: If ``split_type`` is not one of the three valid values.
        FileNotFoundError: If the domain folder does not exist.
    """

    _VALID_SPLITS = ('train', 'val', 'test')
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, domain_name, transform=None, split_ratios=(0.8, 0.1, 0.1), split_type='train', random_seed=42):
        # Validate up front so a bad split name fails even for an empty domain.
        if split_type not in self._VALID_SPLITS:
            raise ValueError("split_type must be 'train', 'val', or 'test'")

        self.domain_path = os.path.join(root_dir, domain_name)
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_to_idx = {}
        self.idx_to_class = {}

        for class_name in sorted(os.listdir(self.domain_path)):
            class_path = os.path.join(self.domain_path, class_name)
            # Only real directories are classes; skip stray files *before*
            # handing out an index so labels stay contiguous.
            if not os.path.isdir(class_path):
                continue

            class_idx = len(self.class_to_idx)
            self.class_to_idx[class_name] = class_idx
            self.idx_to_class[class_idx] = class_name

            class_images = [
                os.path.join(class_path, img_name)
                for img_name in sorted(os.listdir(class_path))
                if img_name.lower().endswith(self._IMAGE_EXTENSIONS)
            ]

            # A local RandomState produces the same permutation as seeding the
            # global legacy RNG, without resetting randomness for other code.
            np.random.RandomState(random_seed).shuffle(class_images)

            n_total = len(class_images)
            n_train = int(n_total * split_ratios[0])
            n_val = int(n_total * split_ratios[1])

            if split_type == 'train':
                selected_images = class_images[:n_train]
            elif split_type == 'val':
                selected_images = class_images[n_train:n_train + n_val]
            else:  # 'test' takes the remainder
                selected_images = class_images[n_train + n_val:]

            self.images.extend(selected_images)
            self.labels.extend([class_idx] * len(selected_images))

    def __len__(self):
        """Number of images in this split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is RGB and transformed if a transform was given."""
        img_path = self.images[idx]
        # The context manager closes the file handle promptly; convert()
        # returns an independent in-memory image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [None]:
# Cell 5
# Build the 'Art' train/val datasets and loaders, then pull one batch as a
# shape smoke test. Failures are reported by printing rather than raising.
try:
    art_train_dataset = OfficeHomeDomainDataset(DATA_DIR, 'Art', transform=train_transform, split_type='train')
    art_val_dataset = OfficeHomeDomainDataset(DATA_DIR, 'Art', transform=val_test_transform, split_type='val')

    if len(art_train_dataset) == 0 or len(art_val_dataset) == 0:
        print("Warning: 'Art' dataset is empty. Check DATA_DIR and domain name.")
    else:
        _art_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
        art_train_loader = DataLoader(art_train_dataset, shuffle=True, **_art_loader_kwargs)
        art_val_loader = DataLoader(art_val_dataset, shuffle=False, **_art_loader_kwargs)

        print(f"Art train dataset size: {len(art_train_dataset)}")
        print(f"Art val dataset size: {len(art_val_dataset)}")
        print(f"Number of classes in Art domain: {len(art_train_dataset.class_to_idx)}")

        # The first batch holds at most BATCH_SIZE samples.
        sample_images, sample_labels = next(iter(art_train_loader))
        print(f"Sample batch - images shape: {sample_images.shape}, labels shape: {sample_labels.shape}")
        _expected_batch = min(BATCH_SIZE, len(art_train_dataset))
        assert sample_images.shape == (_expected_batch, 3, IMAGE_SIZE, IMAGE_SIZE)
        assert sample_labels.shape == (_expected_batch,)
        print("Dataset and DataLoader for 'Art' domain seem OK.")
except FileNotFoundError:
    print(f"ERROR: Dataset not found at {DATA_DIR}. Please ensure OfficeHome is downloaded and extracted there.")
    print("You might need to create dummy folders if you want to proceed without data for now.")
except Exception as load_error:
    print(f"An error occurred during dataset loading: {load_error}")

In [None]:
# Cell 6
# Create the pretrained backbone without a classification head
# (num_classes=0) and place it on the target device.
vit_backbone = timm.create_model(VIT_MODEL_NAME, pretrained=True, num_classes=0).to(DEVICE)

# Freeze every backbone parameter.
vit_backbone.requires_grad_(False)

print(f"Loaded ViT backbone: {VIT_MODEL_NAME}")
_backbone_param_info = [(p.numel(), p.requires_grad) for p in vit_backbone.parameters()]
total_params_vit = sum(count for count, _ in _backbone_param_info)
trainable_params_vit = sum(count for count, is_trainable in _backbone_param_info if is_trainable)
print(f"Total ViT params: {total_params_vit:,}")
print(f"Trainable ViT params: {trainable_params_vit:,}") # Should be 0

In [None]:
# Cell 7
class LoRALinear(nn.Module):
    """Frozen linear layer plus a trainable low-rank (LoRA) update.

    Computes ``linear(x, W, b) + scaling * (x @ A^T @ B^T)`` where ``W`` / ``b``
    are frozen copies of the wrapped layer's parameters, ``A`` has shape
    ``(rank, in_features)`` and ``B`` has shape ``(out_features, rank)``.
    ``B`` starts at zero, so right after construction the module reproduces
    the wrapped layer exactly.

    Args:
        linear_layer: Layer to wrap; its weight and bias are cloned, not shared.
        rank: LoRA rank, must be a positive integer.
        alpha: LoRA scaling numerator; ``scaling = alpha / rank``. The default
            of 1.0 keeps the original ``1 / rank`` scaling.

    Raises:
        ValueError: If ``rank`` is not positive.
    """

    def __init__(self, linear_layer, rank, alpha=1.0):
        super().__init__()
        if rank <= 0:
            raise ValueError(f"rank must be a positive integer, got {rank}")

        self.in_features = linear_layer.in_features
        self.out_features = linear_layer.out_features
        self.rank = rank

        # Original weight and bias (frozen copies).
        base_weight = linear_layer.weight
        self.weight = nn.Parameter(base_weight.detach().clone(), requires_grad=False)
        if linear_layer.bias is not None:
            self.bias = nn.Parameter(linear_layer.bias.detach().clone(), requires_grad=False)
        else:
            self.register_parameter('bias', None)

        # Create A and B on the wrapped layer's device/dtype so the module is
        # usable immediately, without relying on a later `.to(...)` call.
        factory_kwargs = {'device': base_weight.device, 'dtype': base_weight.dtype}
        self.lora_A = nn.Parameter(torch.empty(self.rank, self.in_features, **factory_kwargs))
        self.lora_B = nn.Parameter(torch.zeros(self.out_features, self.rank, **factory_kwargs))

        # A gets a Kaiming-uniform init; B stays zero so the initial update is zero.
        nn.init.kaiming_uniform_(self.lora_A, a=np.sqrt(5))

        self.scaling = alpha / self.rank

    def forward(self, x):
        """Return the frozen projection plus the scaled low-rank update."""
        out_original = nn.functional.linear(x, self.weight, self.bias)
        lora_adaptation = (x @ self.lora_A.t()) @ self.lora_B.t()
        return out_original + lora_adaptation * self.scaling

    def extra_repr(self):
        return f'in_features={self.in_features}, out_features={self.out_features}, rank={self.rank}'

# Helper function to inject LoRA
def inject_lora_to_vit_attention(vit_model, rank):
    """Freeze ``vit_model`` and wrap each block's ``attn.qkv`` Linear in LoRA.

    All parameters of the incoming model are frozen first. Every entry of
    ``vit_model.blocks`` whose ``attn.qkv`` is an ``nn.Linear`` is then
    replaced in place by ``LoRALinear(qkv, rank)``. The number of replaced
    layers is reported by printing.

    Returns:
        The same ``vit_model`` object, modified in place.
    """
    for frozen_param in vit_model.parameters():
        frozen_param.requires_grad = False

    replaced = 0
    for transformer_block in vit_model.blocks:
        attention = transformer_block.attn
        candidate = attention.qkv
        if not isinstance(candidate, nn.Linear):
            continue
        attention.qkv = LoRALinear(candidate, rank)
        replaced += 1

    if replaced:
        print(f"Injected LoRA (rank={rank}) into {replaced} QKV layers in ViT attention blocks.")
    else:
        print("WARNING: No QKV layers found or replaced with LoRA. Check ViT model structure.")
    return vit_model

In [None]:
# Cell 8
class DomainSpecificHead(nn.Module):
    """Single linear classifier mapping backbone features to class logits.

    Args:
        in_features: Width of the incoming feature vector.
        num_classes: Number of output logits; defaults to the notebook-level
            ``NUM_CLASSES`` at class-definition time.
    """

    def __init__(self, in_features=768, num_classes=NUM_CLASSES):
        super().__init__()
        classifier = nn.Linear(in_features, num_classes)
        # Keep the attribute name `fc`: saved state-dict keys depend on it.
        self.fc = classifier

    def forward(self, x):
        """Return the logits for features ``x``."""
        logits = self.fc(x)
        return logits

In [None]:
# Cell 9
# --- Phase 1 Specific Configurations ---
# Checkpoints are written here.
MODEL_SAVE_DIR = "/kaggle/working/"
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

# Optimisation settings for the Art expert.
ART_EPOCHS = 10
ART_LR = 1e-4
ART_EMBED_DIM = 768  # feature width passed to the head

# Gradient scaler for mixed precision; only enabled when running on CUDA.
_amp_on_cuda = (DEVICE.type == 'cuda')
scaler = torch.cuda.amp.GradScaler(enabled=_amp_on_cuda)

print("Phase 1: Configurations set.")
print(f"Art LoRA+Head training epochs: {ART_EPOCHS}, LR: {ART_LR}")

In [None]:
# Cell 10
# 1. Work on a deep copy so the shared frozen `vit_backbone` stays pristine,
#    inject LoRA into the copy (which also freezes the base weights) and
#    place it on the device.
art_vit_lora = inject_lora_to_vit_attention(copy.deepcopy(vit_backbone), rank=LORA_RANK).to(DEVICE)

# 2. Classification head for the Art domain.
art_head = DomainSpecificHead(in_features=ART_EMBED_DIM, num_classes=NUM_CLASSES).to(DEVICE)

# 3. Optimise only parameters that still require gradients: backbone
#    parameters first, then head parameters.
params_to_train = [
    p
    for module in (art_vit_lora, art_head)
    for p in module.parameters()
    if p.requires_grad
]

optimizer_art = optim.AdamW(params_to_train, lr=ART_LR)
criterion_art = nn.CrossEntropyLoss()

print(f"Number of trainable parameters for Art expert: {sum(p.numel() for p in params_to_train):,}")

In [None]:
# Cell 11
# --- Training Loop ---
# Trains `art_vit_lora` + `art_head` for ART_EPOCHS epochs on `art_train_loader`
# and evaluates on `art_val_loader` after every epoch. Both passes run under
# autocast, which (like the GradScaler) is only enabled when DEVICE is CUDA.
# After the loop, `epoch_val_acc` / `epoch_val_loss` hold the last epoch's
# validation metrics.
# NOTE(review): with an empty loader the accumulators stay plain ints, so the
# `.double()` calls and the divisions by the sample count below would fail.
print("\n--- Training Art Expert (Art-LoRA + Art-Head) ---")
for epoch in range(ART_EPOCHS):
    art_vit_lora.train()
    art_head.train()

    # Per-epoch accumulators; the loss is weighted by batch size so the epoch
    # figure is a per-sample mean.
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    progress_bar = tqdm(art_train_loader, desc=f"Epoch {epoch+1}/{ART_EPOCHS} [Art Train]")
    for images, labels in progress_bar:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        optimizer_art.zero_grad()

        with torch.cuda.amp.autocast(enabled=(DEVICE.type == 'cuda')):
            features = art_vit_lora(images)
            # A 3-D output is treated as a token sequence and its first token
            # is used; a 2-D output is used as-is.
            cls_features = features[:, 0] if features.ndim == 3 else features
            logits = art_head(cls_features)
            loss = criterion_art(logits, labels)

        # Scaled backward pass, optimizer step through the scaler, then
        # scaler update — keep this order.
        scaler.scale(loss).backward()
        scaler.step(optimizer_art)
        scaler.update()

        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(logits, 1)
        # After the first batch this accumulator is a tensor, not an int.
        correct_predictions += torch.sum(preds == labels.data)
        total_samples += images.size(0)

        progress_bar.set_postfix(loss=loss.item(), acc=correct_predictions.double().item()/total_samples if total_samples > 0 else 0.0)

    epoch_loss = running_loss / total_samples
    epoch_acc = correct_predictions.double() / total_samples
    print(f"Epoch {epoch+1} Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}")

    # Validation: eval mode, no gradients, same forward path as training.
    art_vit_lora.eval()
    art_head.eval()
    val_loss = 0.0
    val_correct_predictions = 0
    val_total_samples = 0
    with torch.no_grad():
        for images, labels in tqdm(art_val_loader, desc=f"Epoch {epoch+1}/{ART_EPOCHS} [Art Val]"):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            with torch.cuda.amp.autocast(enabled=(DEVICE.type == 'cuda')):
                features = art_vit_lora(images)
                cls_features = features[:, 0] if features.ndim == 3 else features
                logits = art_head(cls_features)
                loss = criterion_art(logits, labels)

            val_loss += loss.item() * images.size(0)
            _, preds = torch.max(logits, 1)
            val_correct_predictions += torch.sum(preds == labels.data)
            val_total_samples += images.size(0)

    # `epoch_val_acc` is a 0-dim tensor; a later cell compares it to 0.1.
    epoch_val_loss = val_loss / val_total_samples
    epoch_val_acc = val_correct_predictions.double() / val_total_samples
    print(f"Epoch {epoch+1} Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

In [None]:
# Cell 12
# Persist the Art expert: only the trainable LoRA parameters of the backbone
# (names containing 'lora_') plus the complete head state dict.
art_lora_state_dict = {}
for lora_name, lora_param in art_vit_lora.named_parameters():
    if lora_param.requires_grad and 'lora_' in lora_name:
        art_lora_state_dict[lora_name] = lora_param

_art_lora_ckpt = os.path.join(MODEL_SAVE_DIR, "art_lora.pth")
_art_head_ckpt = os.path.join(MODEL_SAVE_DIR, "art_head.pth")
torch.save(art_lora_state_dict, _art_lora_ckpt)
torch.save(art_head.state_dict(), _art_head_ckpt)
print("Art-LoRA and Art-Head models saved.")

# --- Sanity Check: Art Expert Training ---
# Uses the validation accuracy of the last training epoch.
assert epoch_val_acc > 0.1, "Validation accuracy for Art is too low. Training might have failed." # Basic check
print("Art expert training sanity check passed (accuracy > 0.1).")

In [None]:
# Cell 13

print("\n--- Baseline Validation of Art Expert on All Domains ---")

# --- 1. Put the in-memory Art expert into evaluation mode ---
for _eval_module in (art_vit_lora, art_head):
    _eval_module.eval()

# --- 2. Domains to evaluate and the result container ---
domain_names_all = ['Art', 'Clipart', 'Product', 'RealWorld']
baseline_accuracies = {}

# Define the autocast context factory only if no earlier cell provided one.
if 'autocast_ctx' not in globals():
    def autocast_ctx():
        return torch.amp.autocast(device_type=DEVICE.type, enabled=(DEVICE.type == 'cuda'))

# Reuse the training criterion when it exists, otherwise create a fresh one.
criterion_val = criterion_art if 'criterion_art' in globals() else nn.CrossEntropyLoss()


# --- 3. Iterate Through Domains and Evaluate ---
# For every domain: build its 'val' split and loader, run the Art expert over
# it without gradients, and store the accuracy in `baseline_accuracies`.
# Sentinel values: 0.0 for an empty split (or no processed samples), -1.0 when
# the dataset could not be loaded.
for domain_name in domain_names_all:
    print(f"\nValidating Art Expert on {domain_name} domain...")

    # Load validation dataset for the current domain; any failure here is
    # recorded as -1.0 and the loop moves on to the next domain.
    try:
        val_dataset_current_domain = OfficeHomeDomainDataset(
            DATA_DIR,
            domain_name,
            transform=val_test_transform, # deterministic validation/test transform
            split_type='val' # the 'val' slice of the per-class split
        )
        if len(val_dataset_current_domain) == 0:
            print(f"Warning: Validation dataset for {domain_name} is empty. Skipping.")
            baseline_accuracies[domain_name] = 0.0
            continue

        val_loader_current_domain = DataLoader(
            val_dataset_current_domain,
            batch_size=BATCH_SIZE, # Use the global BATCH_SIZE
            shuffle=False,
            num_workers=2,
            pin_memory=True
        )
    except FileNotFoundError:
        print(f"ERROR: Dataset for domain '{domain_name}' not found at {DATA_DIR}. Skipping.")
        baseline_accuracies[domain_name] = -1.0 # Indicate error
        continue
    except Exception as e:
        print(f"An error occurred loading dataset for {domain_name}: {e}. Skipping.")
        baseline_accuracies[domain_name] = -1.0 # Indicate error
        continue


    # Per-domain accumulators; the loss is weighted by batch size.
    val_loss_domain = 0.0
    val_correct_predictions_domain = 0
    val_total_samples_domain = 0

    with torch.no_grad():
        for images, labels in tqdm(val_loader_current_domain, desc=f"Val on {domain_name}", leave=False):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            with autocast_ctx():
                # Pass images through the Art-LoRA ViT
                features = art_vit_lora(images)
                # 3-D output: use the first token; 2-D output: use as-is.
                cls_features = features[:, 0] if features.ndim == 3 else features
                # Then through the Art-Head
                logits = art_head(cls_features)
                loss = criterion_val(logits, labels)

            val_loss_domain += loss.item() * images.size(0)
            _, preds = torch.max(logits, 1)
            # Becomes a tensor after the first batch.
            val_correct_predictions_domain += torch.sum(preds == labels.data)
            val_total_samples_domain += images.size(0)

    if val_total_samples_domain > 0:
        epoch_val_loss_domain = val_loss_domain / val_total_samples_domain
        epoch_val_acc_domain = val_correct_predictions_domain.double() / val_total_samples_domain
        # Stored as a plain Python float.
        baseline_accuracies[domain_name] = epoch_val_acc_domain.item()
        print(f"{domain_name} Val Loss: {epoch_val_loss_domain:.4f}, Val Acc: {epoch_val_acc_domain:.4f}")
    else:
        print(f"No samples processed for {domain_name} validation.")
        baseline_accuracies[domain_name] = 0.0


# --- 4. Print Summary of Baseline Accuracies ---
# An accuracy of -1.0 is the sentinel for "dataset could not be loaded".
print("\n--- Baseline Art Expert Performance Summary ---")
for domain, acc in baseline_accuracies.items():
    if acc == -1.0:
        print(f"Accuracy on {domain}: ERROR (Dataset not found or loading issue)")
        continue
    print(f"Accuracy on {domain}: {acc:.4f}")


# Adapt to new domain

In [None]:
# Cell 14
# --- Phase 2 Specific Configurations for 'CLIPART' ---
TARGET_DOMAIN_NAME = 'Clipart'
PREV_DOMAIN_NAMES = ['Art'] # Keep track of all previous domains

ADAPT_EPOCHS = 10
ADAPT_LR_LORA = 1e-4
ADAPT_LR_HEAD = 1e-4
ADAPT_LR_LN = 1e-5 # Learning rate for LayerNorm affine parameters

EMA_DECAY = 0.999
PSEUDO_LABEL_START_THRESHOLD = 0.6
PSEUDO_LABEL_END_THRESHOLD = 0.8 # Reaches this at the last epoch

# Make a subdirectory for CLIPART models within the main model save dir
CLIPART_MODEL_SAVE_DIR = os.path.join(MODEL_SAVE_DIR, TARGET_DOMAIN_NAME)
os.makedirs(CLIPART_MODEL_SAVE_DIR, exist_ok=True)

# For AMP: reuse the scaler created in Phase 1 when it exists; otherwise
# create one here so this cell also works when run on its own.
if 'scaler' not in globals():
    _amp_enabled = (DEVICE.type == 'cuda')
    if hasattr(torch.amp, 'GradScaler'):
        # torch.amp.GradScaler names its first parameter `device`; the
        # previous `device_type=` keyword raised TypeError.
        scaler = torch.amp.GradScaler(DEVICE.type, enabled=_amp_enabled)
    else:
        # Older PyTorch releases only ship the CUDA-specific scaler, which is
        # also what the Phase 1 configuration uses.
        scaler = torch.cuda.amp.GradScaler(enabled=_amp_enabled)


In [None]:
# Cell 15
# --- 1. Load Frozen ViT Backbone ---
# Independent copy of the frozen `vit_backbone`, switched to eval mode
# (`eval()` returns the module itself).
base_vit_frozen = copy.deepcopy(vit_backbone).eval()

# --- Helper function to make LayerNorm affine parameters trainable ---
def set_layernorm_affine_trainable(model):
    """Enable gradients on the weight/bias of every ``nn.LayerNorm`` in ``model``.

    LayerNorm modules whose weight or bias is absent (``None``) are left
    alone. All other parameters keep their current ``requires_grad`` state.

    Returns:
        The same ``model`` object, modified in place.
    """
    layer_norms = (m for m in model.modules() if isinstance(m, nn.LayerNorm))
    for norm in layer_norms:
        for attr_name in ('weight', 'bias'):
            affine = getattr(norm, attr_name, None)
            if affine is not None:
                affine.requires_grad = True
    print("Made LayerNorm affine parameters trainable.") # Added print statement
    return model

In [None]:
# Cell 16
# --- 2. Load Art Expert weights from disk ---
# Only the raw state dicts are kept; a dict that fails to load is set to None
# so the following cells can detect the failure.
art_lora_state_dict_path, art_head_state_dict_path = (
    os.path.join(MODEL_SAVE_DIR, checkpoint_name)
    for checkpoint_name in ("art_lora.pth", "art_head.pth")
)

# Art LoRA parameters (weights_only=True restricts what the unpickler accepts).
try:
    art_lora_trained_weights = torch.load(art_lora_state_dict_path, map_location=DEVICE, weights_only=True)
    _lora_key_preview = list(art_lora_trained_weights.keys())[:5]
    print(f"Loaded Art LoRA weights dictionary with keys: {_lora_key_preview}...") # Print first few keys
except FileNotFoundError:
    art_lora_trained_weights = None
    print(f"ERROR: Art LoRA weights not found at {art_lora_state_dict_path}. Cannot initialize.")
except Exception as lora_load_error:
    art_lora_trained_weights = None
    print(f"Error loading Art LoRA weights: {lora_load_error}")


# Art head parameters.
try:
    art_head_trained_weights = torch.load(art_head_state_dict_path, map_location=DEVICE, weights_only=True)
    print("Loaded Art Head weights dictionary.")
except FileNotFoundError:
    art_head_trained_weights = None
    print(f"ERROR: Art Head weights not found at {art_head_state_dict_path}. Cannot initialize.")
except Exception as head_load_error:
    art_head_trained_weights = None
    print(f"Error loading Art Head weights: {head_load_error}")

In [None]:
# Cell 17
# --- 3. Initialize CLIPART Student Model (Trainable LoRA and LayerNorm affines) ---
# The student is the model that gets optimized: a copy of the frozen base ViT
# with freshly injected LoRA layers, plus a new head.
student_vit_CLIPART = inject_lora_to_vit_attention(
    copy.deepcopy(base_vit_frozen), rank=LORA_RANK
).to(DEVICE)

# Head parameters are trainable by default.
student_head_CLIPART = DomainSpecificHead(in_features=ART_EMBED_DIM, num_classes=NUM_CLASSES).to(DEVICE)

print(f"Initialized STUDENT model structure for {TARGET_DOMAIN_NAME}.")

# --- Load Art weights into STUDENT ---
weights_loaded = False
if art_lora_trained_weights is None or art_head_trained_weights is None:
    print("Skipping weight initialization from Art due to previous loading errors.")
else:
    try:
        # strict=False: the saved dict holds only LoRA tensors, so every base
        # ViT weight is expected to show up as a missing key.
        missing_keys, unexpected_keys = student_vit_CLIPART.load_state_dict(art_lora_trained_weights, strict=False)
        print(f"Loaded Art LoRA into student ViT. Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")
        _loaded_lora_keys = [
            key
            for key in student_vit_CLIPART.state_dict()
            if key not in missing_keys and 'lora' in key
        ]
        if not _loaded_lora_keys:
             print("Warning: It seems no LoRA weights were actually loaded.")

        # The head must match exactly.
        student_head_CLIPART.load_state_dict(art_head_trained_weights, strict=True)
        print("Loaded Art Head weights into student Head.")
        weights_loaded = True
    except Exception as student_load_error:
        print(f"ERROR loading Art weights into student: {student_load_error}")

# --- Ensure required parameters are trainable AFTER loading ---
student_vit_CLIPART = set_layernorm_affine_trainable(student_vit_CLIPART)
# Re-assert that every LoRA parameter requires gradients.
_student_lora_params = [
    lora_param
    for lora_name, lora_param in student_vit_CLIPART.named_parameters()
    if 'lora_' in lora_name
]
for lora_param in _student_lora_params:
    lora_param.requires_grad = True
lora_trainable_count = len(_student_lora_params)
print(f"Ensured {lora_trainable_count} LoRA parameters in student ViT are trainable.")

if not weights_loaded:
    print(f"Proceeding with STUDENT model for {TARGET_DOMAIN_NAME} without Art initialization.")
else:
    print(f"Initialized STUDENT model weights for {TARGET_DOMAIN_NAME} from Art expert.")


# --- Report Trainable Params ---
trainable_params_student_vit = sum(p.numel() for p in student_vit_CLIPART.parameters() if p.requires_grad)
trainable_params_student_head = sum(p.numel() for p in student_head_CLIPART.parameters() if p.requires_grad)
_total_trainable_student = trainable_params_student_vit + trainable_params_student_head
print(f"\nTrainable params in student_vit_{TARGET_DOMAIN_NAME}: {trainable_params_student_vit:,}")
print(f"Trainable params in student_head_{TARGET_DOMAIN_NAME}: {trainable_params_student_head:,}")
print(f"Total trainable STUDENT params: {_total_trainable_student:,}")

In [None]:
# Cell 18
# --- 4. Initialize EMA Teacher Model for CLIPART ---
# The teacher starts as an exact copy of the current student, is fully frozen
# and kept in eval mode; it is only changed through EMA updates.
teacher_vit_CLIPART = copy.deepcopy(student_vit_CLIPART).requires_grad_(False).to(DEVICE).eval()
teacher_head_CLIPART = copy.deepcopy(student_head_CLIPART).requires_grad_(False).to(DEVICE).eval()
print(f"Initialized EMA TEACHER for {TARGET_DOMAIN_NAME} (from student's initial state).")

# --- EMA Update Function ---
def update_ema_teacher(student_vit, student_head, teacher_vit, teacher_head, decay):
    """Blend student weights into the teacher in place.

    Applies ``teacher = decay * teacher + (1 - decay) * student`` to

    * every ViT parameter that is trainable in the student and has a
      same-named parameter in the teacher (others are skipped silently), and
    * every head parameter, paired by position.

    Nothing is returned; the teacher modules are modified in place.
    """
    student_weight = 1 - decay
    with torch.no_grad():
        teacher_lookup = dict(teacher_vit.named_parameters())
        for param_name, source in student_vit.named_parameters():
            if not source.requires_grad:
                continue
            target = teacher_lookup.get(param_name)
            if target is None:
                continue
            target.data.mul_(decay).add_(source.data, alpha=student_weight)

        head_pairs = zip(student_head.parameters(), teacher_head.parameters())
        for source, target in head_pairs:
            target.data.mul_(decay).add_(source.data, alpha=student_weight)


# Initial sync: with decay 0.0 the update reduces to teacher = student.
update_ema_teacher(
    student_vit=student_vit_CLIPART,
    student_head=student_head_CLIPART,
    teacher_vit=teacher_vit_CLIPART,
    teacher_head=teacher_head_CLIPART,
    decay=0.0,
)
print("Initial EMA update complete (Teacher = Student).")

In [None]:
# Cell 19
# --- 5. Initialize EMA Teacher Model for CLIPART ---
# NOTE(review): the previous cell already builds the same teacher objects;
# this cell rebuilds them from the student, discarding the earlier copies.
teacher_vit_CLIPART = copy.deepcopy(student_vit_CLIPART) # Has CLIPART LoRA and LN structure
teacher_head_CLIPART = copy.deepcopy(student_head_CLIPART)

# Freeze both teacher modules and put them in eval mode.
for _teacher_module in (teacher_vit_CLIPART, teacher_head_CLIPART):
    _teacher_module.requires_grad_(False)
    _teacher_module.eval()
print("Initialized EMA teacher for CLIPART.")

# --- EMA Update Function ---
# NOTE(review): this redefinition shadows the earlier `update_ema_teacher`
# in this notebook; the only behavioural difference is the warning below.
def update_ema_teacher(student_vit, student_head, teacher_vit, teacher_head, decay):
    """In-place EMA step: ``teacher = decay * teacher + (1 - decay) * student``.

    ViT parameters are matched by name and only those trainable in the student
    are blended; a trainable student parameter with no counterpart in the
    teacher triggers a printed warning. Head parameters are paired by position
    and always blended. Returns ``None``.
    """
    with torch.no_grad():
        teacher_by_name = dict(teacher_vit.named_parameters())
        trainable_student = [
            (name, param)
            for name, param in student_vit.named_parameters()
            if param.requires_grad
        ]

        for name, stud_param in trainable_student:
            if name not in teacher_by_name:
                print(f"Warning: Param {name} from student_vit not found in teacher_vit for EMA update.")
                continue
            teacher_by_name[name].data.mul_(decay).add_(stud_param.data, alpha=1 - decay)

        # Head: every parameter, in declaration order.
        teacher_head_params = list(teacher_head.parameters())
        for position, stud_param in enumerate(student_head.parameters()):
            if position >= len(teacher_head_params):
                break
            teacher_head_params[position].data.mul_(decay).add_(stud_param.data, alpha=1 - decay)

# One EMA step right after set-up. The teacher was deep-copied from the student above,
# so both hold the same weights at this point and the step leaves them essentially as is.
update_ema_teacher(
    student_vit=student_vit_CLIPART,
    student_head=student_head_CLIPART,
    teacher_vit=teacher_vit_CLIPART,
    teacher_head=teacher_head_CLIPART,
    decay=EMA_DECAY,
)
print("Initial EMA update complete.")

In [None]:
# Cell 20
# Target-domain data. The 'train' split is the pool that gets pseudo-labelled (the
# adaptation loop ignores its labels); the 'val' split is used for model selection.
# NOTE(review): the pseudo-label pool uses train_transform, which contains random
# augmentations, so teacher predictions are made on augmented images.
CLIPART_train_dataset_unlabeled = OfficeHomeDomainDataset(
    DATA_DIR,
    TARGET_DOMAIN_NAME,
    transform=train_transform,
    split_type='train',
)
CLIPART_val_dataset = OfficeHomeDomainDataset(
    DATA_DIR,
    TARGET_DOMAIN_NAME,
    transform=val_test_transform,
    split_type='val',
)

# Both loaders iterate in a fixed order with the same worker / pinning settings; only the
# batch size differs (pseudo-label generation is inference-only, so it uses a double batch).
eval_loader_kwargs = dict(shuffle=False, num_workers=2, pin_memory=True)
CLIPART_train_loader_unlabeled_images = DataLoader(
    CLIPART_train_dataset_unlabeled,
    batch_size=BATCH_SIZE * 2,
    **eval_loader_kwargs,
)
CLIPART_val_loader = DataLoader(
    CLIPART_val_dataset,
    batch_size=BATCH_SIZE,
    **eval_loader_kwargs,
)

print(f"{TARGET_DOMAIN_NAME} (target) unlabeled train dataset size: {len(CLIPART_train_dataset_unlabeled)}")
print(f"{TARGET_DOMAIN_NAME} (target) val dataset size: {len(CLIPART_val_dataset)}")

In [None]:
# Cell 21
# --- Optimizer for CLIPART Adaptation (Student Only) ---
# Three parameter groups so that LoRA, LayerNorm and head parameters each get their own LR.
# Backbone parameters are assigned by name; only those with requires_grad=True are considered.
student_trainable_named = [
    (name, p) for name, p in student_vit_CLIPART.named_parameters() if p.requires_grad
]
lora_params = [p for name, p in student_trainable_named if 'lora_' in name]
# Any name containing "norm" (case-insensitive) counts as a normalisation parameter;
# this also covers names containing "layernorm".
norm_params = [p for name, p in student_trainable_named if 'norm' in name.lower()]

params_to_train_student = [
    {'params': lora_params, 'lr': ADAPT_LR_LORA},                       # Group 0: LoRA
    {'params': norm_params, 'lr': ADAPT_LR_LN},                         # Group 1: LayerNorm affine
    {'params': student_head_CLIPART.parameters(), 'lr': ADAPT_LR_HEAD}, # Group 2: head (unfiltered)
]

# Only the student is optimised; the teacher is updated through the EMA helper instead.
optimizer_CLIPART_student = optim.AdamW(params_to_train_student)

# Loss on pseudo-labelled target batches: hard labels -> cross-entropy, mean over the batch.
# (Use KLDivLoss instead if soft teacher probabilities are collected.)
criterion_pseudo = nn.CrossEntropyLoss(reduction='mean')
# Loss reported on the labelled validation split.
criterion_val = nn.CrossEntropyLoss()

# Report how many parameters each group holds.
print(f"Optimizer groups for STUDENT ({TARGET_DOMAIN_NAME}):")
group_param_counts = [
    sum(p.numel() for p in group['params'])
    for group in optimizer_CLIPART_student.param_groups
]
for i, (group, group_params) in enumerate(zip(optimizer_CLIPART_student.param_groups, group_param_counts)):
    print(f"  Group {i}: LR={group['lr']}, Params={group_params:,}")
total_params_optimized = sum(group_param_counts)
print(f"Total trainable parameters optimized for {TARGET_DOMAIN_NAME} STUDENT: {total_params_optimized:,}")

print(f"\n--- Adapting to {TARGET_DOMAIN_NAME} (Optimizing STUDENT: LoRA, LN, Head) ---")

In [None]:
# Cell 22
# Mixed precision: every forward pass below is wrapped in this context.
def autocast_ctx():
    """Return a new autocast context for DEVICE; autocasting is enabled only on CUDA."""
    return torch.amp.autocast(device_type=DEVICE.type, enabled=(DEVICE.type == 'cuda'))

# Early-stopping state
EARLY_STOPPING_PATIENCE = 5 # Example patience value
best_val_acc = 0.0          # best validation accuracy seen so far
patience_counter = 0        # validated epochs since the last improvement

print(f"\n--- Training Adaptation Loop for {TARGET_DOMAIN_NAME} ---")
print(f"Using Mean Teacher (EMA Decay: {EMA_DECAY}) for pseudo-labeling.")

# Adaptation loop. Each epoch:
#   1. the EMA teacher pseudo-labels the whole target training loader and keeps the samples
#      whose top softmax probability reaches the current threshold;
#   2. the student is trained on those samples with cross-entropy, and the teacher receives
#      an EMA update after every optimizer step;
#   3. the student is validated; an improvement saves a checkpoint, otherwise the patience
#      counter grows and the loop stops once it reaches EARLY_STOPPING_PATIENCE.
for epoch in range(ADAPT_EPOCHS):
    # Calculate current pseudo-label threshold (linear schedule from START at the first epoch
    # to END at the last one; a single-epoch run uses END directly, which also avoids a
    # division by ADAPT_EPOCHS - 1 == 0).
    if ADAPT_EPOCHS == 1:
        current_pseudo_threshold = PSEUDO_LABEL_END_THRESHOLD
    else:
        current_pseudo_threshold = PSEUDO_LABEL_START_THRESHOLD + \
                                   (PSEUDO_LABEL_END_THRESHOLD - PSEUDO_LABEL_START_THRESHOLD) * (epoch / (ADAPT_EPOCHS - 1))
    print(f"\nEpoch {epoch+1}/{ADAPT_EPOCHS}, Pseudo-label threshold: {current_pseudo_threshold:.3f}")

    # --- 1. Pseudo-Label Generation using EMA TEACHER ---
    # Both lists hold one CPU tensor per batch that contributed at least one confident sample.
    pseudo_labeled_images_collected = []
    pseudo_labels_collected = [] # Store hard labels (indices)

    teacher_vit_CLIPART.eval() # Ensure teacher is in eval mode
    teacher_head_CLIPART.eval()

    print(f"Collecting pseudo-labels for {TARGET_DOMAIN_NAME} using EMA TEACHER...")
    with torch.no_grad():
        # Iterate over the *entire* unlabeled target training set; the dataset's own labels
        # are discarded (`_`).
        for images, _ in tqdm(CLIPART_train_loader_unlabeled_images, desc=f"Pseudo-Labeling {TARGET_DOMAIN_NAME}", leave=False):
            images = images.to(DEVICE)
            with autocast_ctx():
                # Get features and logits from the TEACHER.
                # If the backbone returns a 3-D tensor, index 0 along dim 1 is used (presumably
                # the CLS token of a (batch, tokens, dim) output — TODO confirm); otherwise the
                # output is passed to the head unchanged.
                features_teacher = teacher_vit_CLIPART(images)
                cls_features_teacher = features_teacher[:, 0] if features_teacher.ndim == 3 and features_teacher.shape[1] > 0 else features_teacher
                logits_teacher = teacher_head_CLIPART(cls_features_teacher)

            # Confidence = highest class probability; prediction = its index.
            probs_teacher = torch.softmax(logits_teacher, dim=1)
            max_probs, pred_labels = torch.max(probs_teacher, dim=1)

            # Filter based on threshold
            mask = max_probs >= current_pseudo_threshold
            if mask.any():
                # The stored tensors are the images exactly as the loader produced them
                # (already transformed), moved back to CPU.
                pseudo_labeled_images_collected.append(images[mask].cpu())
                pseudo_labels_collected.append(pred_labels[mask].cpu()) # Store hard labels

    if not pseudo_labeled_images_collected:
        print("Warning: No pseudo-labels collected in this epoch with current threshold. Skipping training steps.")
        # Optionally: Reduce threshold slightly for next epoch, or just continue
        # NOTE(review): `continue` jumps straight to the next epoch, so this epoch's validation,
        # checkpointing and early-stopping bookkeeping are skipped as well as the training steps.
        continue # Skip the rest of this epoch

    pseudo_labeled_images_cat = torch.cat(pseudo_labeled_images_collected, dim=0)
    pseudo_labels_cat = torch.cat(pseudo_labels_collected, dim=0)
    print(f"Collected {len(pseudo_labeled_images_cat)} pseudo-labeled samples for {TARGET_DOMAIN_NAME} using EMA Teacher.")

    # Create DataLoader for the pseudo-labeled subset (in-memory tensors, reshuffled each pass)
    pseudo_batch_size_actual = BATCH_SIZE
    pseudo_dataset = torch.utils.data.TensorDataset(pseudo_labeled_images_cat, pseudo_labels_cat)
    # num_workers=0 is often recommended for TensorDataset
    pseudo_loader = DataLoader(pseudo_dataset, batch_size=pseudo_batch_size_actual, shuffle=True, num_workers=0, pin_memory=True)

    # --- 2. STUDENT Training on Pseudo-Labels ---
    student_vit_CLIPART.train() # Set student model to train mode
    student_head_CLIPART.train()

    # Sample-weighted running sums, used for the epoch's average loss.
    running_loss_pseudo_epoch = 0.0
    total_pseudo_samples_processed = 0

    progress_bar = tqdm(pseudo_loader, desc=f"Epoch {epoch+1} [Adapt Train {TARGET_DOMAIN_NAME}]", leave=False)
    for pseudo_batch_images, pseudo_batch_labels in progress_bar:
        pseudo_batch_images = pseudo_batch_images.to(DEVICE)
        pseudo_batch_labels = pseudo_batch_labels.to(DEVICE) # Hard labels from teacher
        current_pseudo_batch_size = pseudo_batch_images.size(0)

        with autocast_ctx():
            # Features from STUDENT ViT (same feature selection rule as for the teacher above)
            features_student = student_vit_CLIPART(pseudo_batch_images)
            cls_features_student = features_student[:, 0] if features_student.ndim == 3 and features_student.shape[1] > 0 else features_student
            # Logits from STUDENT Head
            logits_student = student_head_CLIPART(cls_features_student)

            # --- Loss on Pseudo-Labeled Data (Using CrossEntropy with hard labels) ---
            loss_pseudo = criterion_pseudo(logits_student, pseudo_batch_labels)

            # --- (Optional: Add Entropy Minimization here if needed) ---
            # loss_ent = ... calculated on student outputs for unlabeled data

        # --- Backpropagation (Optimizing STUDENT) ---
        total_loss_step = loss_pseudo # + ENT_WEIGHT * loss_ent
        # A non-finite loss skips the whole step: no backward pass, no optimizer step, no EMA
        # update, and the batch is not counted in the running averages.
        if torch.isnan(total_loss_step) or torch.isinf(total_loss_step):
             print(f"Warning: NaN or Inf loss detected (pseudo loss: {loss_pseudo.item()}). Skipping step.")
             continue

        optimizer_CLIPART_student.zero_grad()
        # NOTE(review): `scaler` is not created in this cell; it has to exist from an earlier
        # cell (presumably an AMP GradScaler — confirm) or this line raises NameError.
        scaler.scale(total_loss_step).backward()
        # Optional: Gradient clipping
        # scaler.unscale_(optimizer_CLIPART_student)
        # torch.nn.utils.clip_grad_norm_(params_to_train_student_flat, max_norm=1.0) # Need to flatten param groups first
        scaler.step(optimizer_CLIPART_student)
        scaler.update()

        # --- EMA Update of TEACHER (based on updated STUDENT) ---
        update_ema_teacher(student_vit_CLIPART, student_head_CLIPART, teacher_vit_CLIPART, teacher_head_CLIPART, EMA_DECAY)

        running_loss_pseudo_epoch += loss_pseudo.item() * current_pseudo_batch_size
        total_pseudo_samples_processed += current_pseudo_batch_size

        progress_bar.set_postfix(L_pseudo=loss_pseudo.item())

    if total_pseudo_samples_processed > 0:
        avg_loss_pseudo = running_loss_pseudo_epoch / total_pseudo_samples_processed
        print(f"Epoch {epoch+1} Adapt Train Avg Pseudo Loss: {avg_loss_pseudo:.4f}")
    else:
        # Reached only when every batch was skipped by the NaN/Inf guard above; an epoch with
        # no pseudo-labels at all never gets here because of the earlier `continue`.
        pass

    # --- 3. Validation on Target Domain (using STUDENT model) ---
    student_vit_CLIPART.eval()
    student_head_CLIPART.eval()
    val_loss = 0.0
    val_correct_predictions = 0
    val_total_samples = 0
    with torch.no_grad():
        for images, labels in tqdm(CLIPART_val_loader, desc=f"Epoch {epoch+1} [{TARGET_DOMAIN_NAME} Val]", leave=False):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            with autocast_ctx():
                # Use STUDENT model for validation
                features = student_vit_CLIPART(images)
                cls_features = features[:, 0] if features.ndim == 3 and features.shape[1] > 0 else features
                logits = student_head_CLIPART(cls_features)
                loss = criterion_val(logits, labels) # Use standard CE loss for validation

            # Weight the batch-mean loss by batch size so the epoch value is a per-sample mean.
            val_loss += loss.item() * images.size(0)
            _, preds = torch.max(logits, 1)
            # After the first batch this accumulator is a tensor (int + tensor), which is why
            # `.double()` can be called on it below.
            val_correct_predictions += torch.sum(preds == labels.data)
            val_total_samples += images.size(0)

    if val_total_samples > 0:
        epoch_val_loss = val_loss / val_total_samples
        epoch_val_acc = val_correct_predictions.double() / val_total_samples
        print(f"Epoch {epoch+1} {TARGET_DOMAIN_NAME} Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

        # --- Early Stopping & Saving Best STUDENT Model ---
        # Strict improvement is required; because best_val_acc starts at 0.0, an accuracy of
        # exactly 0 never triggers a save.
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            patience_counter = 0
            # Save the best student: only the backbone parameters that require grad (by name),
            # plus the full head state dict, as two separate files.
            CLIPART_best_lora_ln_state_dict = {name: param for name, param in student_vit_CLIPART.named_parameters() if param.requires_grad}
            torch.save(CLIPART_best_lora_ln_state_dict, os.path.join(CLIPART_MODEL_SAVE_DIR, f"{TARGET_DOMAIN_NAME.lower()}_best_lora_ln_student.pth"))
            torch.save(student_head_CLIPART.state_dict(), os.path.join(CLIPART_MODEL_SAVE_DIR, f"{TARGET_DOMAIN_NAME.lower()}_best_head_student.pth"))
            print(f"    -> New best validation accuracy: {best_val_acc:.4f}. Saved best student model.")
        else:
            patience_counter += 1
            print(f"    -> Validation accuracy did not improve. Patience: {patience_counter}/{EARLY_STOPPING_PATIENCE}")
            if patience_counter >= EARLY_STOPPING_PATIENCE:
                print(f"--- Early stopping triggered after epoch {epoch+1} ---")
                break # Exit training loop
    else:
        print(f"Epoch {epoch+1} {TARGET_DOMAIN_NAME} Val: No validation samples processed.")

# --- End of Training ---
print(f"\n--- Adaptation training for {TARGET_DOMAIN_NAME} finished ---")

# Locations of the two checkpoint files the training loop writes on every validation improvement.
best_lora_ln_path = os.path.join(CLIPART_MODEL_SAVE_DIR, f"{TARGET_DOMAIN_NAME.lower()}_best_lora_ln_student.pth")
best_head_path = os.path.join(CLIPART_MODEL_SAVE_DIR, f"{TARGET_DOMAIN_NAME.lower()}_best_head_student.pth")

# best_val_acc starts at 0.0 and is only changed together with a checkpoint save, so a value
# that is still 0 means this run wrote nothing. Loading anyway would either fail with an
# opaque error or silently pick up a stale checkpoint left behind by an earlier run.
if float(best_val_acc) <= 0.0:
    raise RuntimeError(
        f"No {TARGET_DOMAIN_NAME} student checkpoint was saved during this run "
        "(validation accuracy never rose above 0, or no epoch reached validation). "
        "Check the pseudo-label thresholds and the validation split before continuing."
    )
missing_checkpoint_files = [path for path in (best_lora_ln_path, best_head_path) if not os.path.isfile(path)]
if missing_checkpoint_files:
    raise FileNotFoundError(f"Best student checkpoint file(s) not found: {missing_checkpoint_files}")

# Load the best saved student model weights for final use/evaluation
print("Loading best student model weights...")
best_student_lora_ln_weights = torch.load(best_lora_ln_path, map_location=DEVICE)
best_student_head_weights = torch.load(best_head_path, map_location=DEVICE)

# Rebuild a student backbone with the same construction steps, then overlay the saved weights.
final_student_vit = copy.deepcopy(base_vit_frozen)
final_student_vit = inject_lora_to_vit_attention(final_student_vit, rank=LORA_RANK)
final_student_vit = set_layernorm_affine_trainable(final_student_vit)
# The checkpoint holds only the trainable parameters, so strict=False is required (every other
# key is reported as "missing", which is expected). Saved keys that match nothing in the
# rebuilt model are NOT expected: they are trained weights that would be dropped silently.
load_result = final_student_vit.load_state_dict(best_student_lora_ln_weights, strict=False)
if load_result.unexpected_keys:
    print(
        f"Warning: {len(load_result.unexpected_keys)} saved parameter(s) have no counterpart in the "
        f"rebuilt student ViT and were NOT restored, e.g. {load_result.unexpected_keys[:5]}"
    )
final_student_vit = final_student_vit.to(DEVICE).eval()

# The head checkpoint is a full state dict, so it is loaded strictly.
final_student_head = DomainSpecificHead(in_features=ART_EMBED_DIM, num_classes=NUM_CLASSES).to(DEVICE)
final_student_head.load_state_dict(best_student_head_weights)
final_student_head = final_student_head.eval()

print("Best student model loaded.")

# Now `final_student_vit` and `final_student_head` hold the best adapted model for Clipart
# You would use these for final evaluation or as the starting point for the *next* domain adaptation.

In [None]:
# cell 23
# --- Confirmation of the BEST STUDENT checkpoint ---
# Nothing is written in this cell: the checkpoint files were saved inside the training loop
# whenever validation accuracy improved, and final_student_vit / final_student_head already
# hold those weights. This cell reports the file names and re-validates the loaded model.

print(f"Best {TARGET_DOMAIN_NAME} STUDENT model weights were saved during training to {CLIPART_MODEL_SAVE_DIR}")
print(f"  ViT LoRA/LN file: {TARGET_DOMAIN_NAME.lower()}_best_lora_ln_student.pth")
print(f"  Head file: {TARGET_DOMAIN_NAME.lower()}_best_head_student.pth")

# --- Sanity Check: CLIPART Adaptation (using the final loaded best student) ---
# One more pass over the validation loader with the reloaded model.
final_val_loss, final_val_correct, final_val_total = 0.0, 0, 0
final_student_vit.eval()
final_student_head.eval()
with torch.no_grad():
    for images, labels in tqdm(CLIPART_val_loader, desc=f"Final Validation [{TARGET_DOMAIN_NAME}]", leave=False):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        samples_in_batch = images.size(0)

        with autocast_ctx():
            features = final_student_vit(images)
            # A 3-D backbone output contributes only index 0 along dim 1; anything else is
            # handed to the head unchanged.
            if features.ndim == 3 and features.shape[1] > 0:
                cls_features = features[:, 0]
            else:
                cls_features = features
            logits = final_student_head(cls_features)
            loss = criterion_val(logits, labels)

        preds = torch.max(logits, 1)[1]
        # Batch-mean loss weighted by batch size -> per-sample mean after the final division.
        final_val_loss += loss.item() * samples_in_batch
        final_val_correct += torch.sum(preds == labels.data)
        final_val_total += samples_in_batch

if final_val_total == 0:
    print(f"Warning: {TARGET_DOMAIN_NAME} validation set was empty or not processed, cannot perform final sanity check.")
else:
    final_epoch_val_loss = final_val_loss / final_val_total
    final_epoch_val_acc = final_val_correct.double() / final_val_total
    print(f"Confirmed Best {TARGET_DOMAIN_NAME} Student Val Loss: {final_epoch_val_loss:.4f}, Val Acc: {final_epoch_val_acc:.4f}")
    # Coarse sanity floor only: the accuracy is compared against a fixed 0.02, not against
    # the best accuracy recorded during training.
    assert final_epoch_val_acc > 0.02, f"{TARGET_DOMAIN_NAME} validation accuracy ({final_epoch_val_acc:.4f}) is too low. Adaptation might have failed."
    print(f"{TARGET_DOMAIN_NAME} adaptation sanity check passed (best accuracy > 0.02).")

# IMPORTANT: For the *next* adaptation step (e.g., Clipart -> Product),
# you would initialize the Product student/teacher based on the *final_student_vit*
# and *final_student_head* from this Clipart adaptation phase.