In [None]:
def _run_vit_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: Module called as ``model(images)`` to produce logits.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``; expected to
            return a per-batch mean (the caller passes ``nn.CrossEntropyLoss()``).
        device: Device the batches are moved to.
        desc: Label shown on the progress bar.
        optimizer: When given, the pass trains (backward + step); when ``None``
            the pass only evaluates, with gradients disabled.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, or ``(nan, nan)`` when
        the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    num_correct = 0
    num_samples = 0

    with torch.set_grad_enabled(training):
        for images, labels in tqdm(loader, desc=desc):
            images = images.to(device)
            labels = labels.to(device)

            if training:
                optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            if training:
                loss.backward()
                optimizer.step()

            # The loss is a batch mean, so weight it by the batch size to get
            # a correct per-sample average when the last batch is smaller.
            loss_sum += loss.item() * images.size(0)
            num_correct += (logits.argmax(dim=1) == labels).sum().item()
            num_samples += labels.size(0)

    if num_samples == 0:
        # Report "no data" explicitly instead of dividing by zero.
        return float("nan"), float("nan")
    return loss_sum / num_samples, num_correct / num_samples


def train_vit_classifier(
    model,
    train_dataset,
    val_dataset,
    device,
    num_epochs=10,
    batch_size=32,
    lr=1e-4
):
    """Train ``model`` with Adam and cross-entropy, validating after every epoch.

    Args:
        model: Module mapping a batch of images to class logits.
        train_dataset: Dataset of ``(image, label)`` pairs; shuffled each epoch.
        val_dataset: Dataset of ``(image, label)`` pairs; iterated in order.
        device: Device the model and batches are moved to.
        num_epochs: Number of train + validation passes.
        batch_size: Batch size for both loaders.
        lr: Adam learning rate.

    Returns:
        Dict with keys ``"train_loss"``, ``"train_acc"``, ``"val_loss"`` and
        ``"val_acc"``, each a list with one float per epoch. A split that
        yields no samples is recorded as NaN rather than raising
        ``ZeroDivisionError``.
    """
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    for epoch in range(1, num_epochs + 1):
        train_loss, train_acc = _run_vit_epoch(
            model, train_loader, criterion, device,
            desc=f"Epoch {epoch} [Train]", optimizer=optimizer,
        )
        val_loss, val_acc = _run_vit_epoch(
            model, val_loader, criterion, device,
            desc=f"Epoch {epoch} [Val]",
        )

        print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, "
              f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

    return history



class ViTClassifier(nn.Module):
    """Linear classification head on top of a ViT backbone.

    The backbone is called as ``vit_model(x)`` and must return an object with
    a ``last_hidden_state`` tensor indexed as ``(batch, tokens, dim)``. Either
    the token at index 0 (presumably the class token) or all remaining tokens
    flattened together are used as features. Optionally the features are
    projected onto their top right-singular vectors (see ``fit_svd``) before
    the linear layer.

    Args:
        vit_model: Backbone exposing ``embeddings.patch_embedding.out_channels``
            and, when ``use_patch`` is set,
            ``embeddings.position_embedding.num_embeddings``.
        num_classes: Number of output logits.
        use_patch: Use the flattened tokens after index 0 instead of token 0.
        svd: Enable the SVD projection; ``fit_svd`` must then be called before
            the first forward pass.
        svd_dim: Number of singular directions kept when ``svd`` is enabled.
    """

    def __init__(self, vit_model, num_classes, use_patch=False, svd=False, svd_dim=256):
        super().__init__()
        self.vit = vit_model
        self.use_patch = use_patch
        self.svd = svd
        self.svd_dim = svd_dim

        base_feature_dim = vit_model.embeddings.patch_embedding.out_channels
        if use_patch:
            # One position embedding is assumed to belong to the non-patch
            # token that forward() drops with ``[:, 1:, :]``.
            num_patches = vit_model.embeddings.position_embedding.num_embeddings - 1
            self.feature_dim = num_patches * base_feature_dim
        else:
            self.feature_dim = base_feature_dim

        if svd:
            # Placeholder until fit_svd() installs the projection matrix.
            self.svd_projection = None
        classifier_input_dim = svd_dim if svd else self.feature_dim
        self.classifier = nn.Linear(classifier_input_dim, num_classes)

    def fit_svd(self, train_features):
        """Fit the SVD projection from a ``(num_samples, feature_dim)`` matrix.

        The projection keeps the ``svd_dim`` right-singular vectors with the
        largest singular values, stored as a frozen
        ``(feature_dim, svd_dim)`` parameter on the device of
        ``train_features``. Features are not mean-centred.

        Raises:
            ValueError: if SVD is disabled, ``train_features`` is not 2-D, or
                ``svd_dim`` exceeds ``min(num_samples, feature_dim)``.
        """
        if not self.svd:
            raise ValueError("SVD is not enabled for this model")
        if train_features.dim() != 2:
            raise ValueError(
                "train_features must be 2-D (num_samples, feature_dim), "
                f"got shape {tuple(train_features.shape)}"
            )

        matrix = train_features.detach().cpu()
        if matrix.dtype in (torch.float16, torch.bfloat16):
            # Run the decomposition in float32; reduced precision is not
            # reliably supported for SVD on CPU.
            matrix = matrix.float()

        # torch.linalg.svd returns Vh whose ROWS are the right-singular
        # vectors (the deprecated torch.svd returns V with them as columns).
        _, _, Vh = torch.linalg.svd(matrix, full_matrices=False)
        if self.svd_dim > Vh.shape[0]:
            raise ValueError(
                f"svd_dim={self.svd_dim} exceeds the {Vh.shape[0]} singular "
                "vectors available from train_features"
            )

        projection = Vh[: self.svd_dim].T.contiguous()
        self.svd_projection = nn.Parameter(
            projection.to(device=train_features.device, dtype=train_features.dtype),
            requires_grad=False
        )
        print(f"SVD fitted: {train_features.shape[1]} -> {self.svd_dim} dimensions")

    def forward(self, x):
        """Return class logits for the input batch ``x``.

        Raises:
            RuntimeError: if SVD is enabled but ``fit_svd`` was never called.
        """
        hidden_states = self.vit(x).last_hidden_state
        if self.use_patch:
            # Drop token 0 and flatten the rest into one vector per sample.
            patch_tokens = hidden_states[:, 1:, :]
            B, N, D = patch_tokens.shape
            features = patch_tokens.reshape(B, N * D)
        else:
            features = hidden_states[:, 0, :]

        if self.svd:
            if self.svd_projection is None:
                raise RuntimeError("SVD is enabled but not fitted. Call fit_svd() first!")
            features = features @ self.svd_projection
        return self.classifier(features)


class ImageDatasetVit(Dataset):
    """Image dataset whose labels and categories are encoded in file names.

    File names are split on ``_``: the first piece is the category and the
    first piece that is all digits with a value below 100 is the label.
    Entries that are not regular files, or whose name carries no such label,
    are skipped so every sample has an integer label.

    Args:
        image_base_path: Directory that directly contains the image files.
        nums: Optional collection of labels to keep; ``None`` keeps all.
        categories: Optional collection of categories to keep; ``None`` keeps all.
        transform: Optional callable applied to the RGB PIL image in
            ``__getitem__`` (for example to turn it into a tensor).
    """

    def __init__(self, image_base_path, nums=None, categories=None, transform=None):
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # os.listdir returns entries in arbitrary order; sort so that sample
        # indices are reproducible across runs and machines.
        for img_file in sorted(os.listdir(image_base_path)):
            full_path = os.path.join(image_base_path, img_file)
            if not os.path.isfile(full_path):
                continue

            true_num = self._get_true_num(img_file)
            if true_num is None:
                # A ``None`` label cannot be batched or fed to a loss.
                continue
            if nums is not None and true_num not in nums:
                continue
            if categories is not None and self._get_category(img_file) not in categories:
                continue

            self.image_paths.append(full_path)
            self.labels.append(true_num)

    def _get_true_num(self, img_name):
        """Return the first ``_``-separated all-digit piece below 100, else ``None``."""
        for part in img_name.split('_'):
            if part.isdigit() and int(part) < 100:
                return int(part)
        return None

    def _get_category(self, img_name):
        """Return the text before the first underscore in ``img_name``."""
        return img_name.split('_')[0]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is RGB, transformed if configured."""
        # The context manager releases the file handle once the pixels have
        # been copied into the converted image.
        with Image.open(self.image_paths[idx]) as raw:
            image = raw.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

def partially_unfreeze_vit(model, unfreeze_last_n_layers=4):
    """Freeze ``model`` in place, then re-enable gradients on its last encoder layers.

    Every parameter of ``model`` is frozen first. Floating-point parameters of
    the final ``unfreeze_last_n_layers`` entries of ``model.encoder.layers``
    are then made trainable, and a parameter count summary is printed.

    Args:
        model: Module exposing ``encoder.layers`` as a sliceable sequence.
        unfreeze_last_n_layers: How many trailing encoder layers to unfreeze.
            ``0`` leaves the whole model frozen; a value larger than the
            number of layers unfreezes all of them.

    Raises:
        ValueError: if ``unfreeze_last_n_layers`` is negative.
    """
    if unfreeze_last_n_layers < 0:
        raise ValueError(
            f"unfreeze_last_n_layers must be >= 0, got {unfreeze_last_n_layers}"
        )

    for param in model.parameters():
        param.requires_grad = False

    # ``layers[-0:]`` is the entire sequence, so zero must skip the slice
    # altogether or every layer would be unfrozen.
    if unfreeze_last_n_layers > 0:
        for layer in model.encoder.layers[-unfreeze_last_n_layers:]:
            for param in layer.parameters():
                # Non-floating-point tensors cannot require gradients.
                if param.is_floating_point():
                    param.requires_grad = True

    # Count parameters in a single pass.
    total_params = 0
    unfrozen_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            unfrozen_params += count
    frozen_params = total_params - unfrozen_params

    print(f"Total parameters: {total_params}")
    print(f"Frozen parameters: {frozen_params}")
    print(f"Unfrozen parameters: {unfrozen_params}")

# Example usage: freeze `baseline_vit` in place, leaving only its last 4
# encoder layers trainable, and print the resulting parameter counts.
# NOTE(review): `baseline_vit` is not defined in the visible part of this file;
# presumably it is created in an earlier cell — confirm before running alone.
partially_unfreeze_vit(baseline_vit, unfreeze_last_n_layers=4)


from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

def setup_peft_vit_mlp(model, r=8, lora_alpha=16, lora_dropout=0.1,
                       target_modules=None, use_gradient_checkpointing=True):
    """Wrap ``model`` with LoRA adapters and return the PEFT model.

    The model is first passed through ``prepare_model_for_kbit_training`` and
    then through ``get_peft_model`` with a ``LoraConfig`` built from the
    arguments. A summary of trainable parameters is printed before returning.

    Args:
        model: Base model to adapt.
        r: LoRA rank.
        lora_alpha: LoRA scaling factor.
        lora_dropout: Dropout applied inside the LoRA layers.
        target_modules: Module names to adapt; when ``None`` the two MLP
            projections ``"mlp.fc1"`` and ``"mlp.fc2"`` are targeted.
        use_gradient_checkpointing: Forwarded to
            ``prepare_model_for_kbit_training``.

    Returns:
        The model returned by ``get_peft_model``.
    """
    # Default to the first and second linear layers of each MLP block.
    lora_targets = ["mlp.fc1", "mlp.fc2"] if target_modules is None else target_modules

    # Preparation step for training a k-bit (quantized) model.
    prepared_model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=use_gradient_checkpointing,
    )

    adapter_config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=lora_targets,
        bias="none",
        task_type="FEATURE_EXTRACTION",
        inference_mode=False,
    )

    adapted_model = get_peft_model(prepared_model, adapter_config)
    adapted_model.print_trainable_parameters()
    return adapted_model