In [None]:
!pip install open_clip_torch transformers

In [None]:
import open_clip
import torch

model, _, transform = open_clip.create_model_and_transforms(
  model_name="coca_ViT-L-14",
  pretrained="mscoco_finetuned_laion2B-s13B-b90k"
)

In [None]:
!wget https://i.imgur.com/8H7XCH0.jpg -O cat.jpg

In [None]:
from IPython.display import Image
Image('cat.jpg')

In [None]:
from PIL import Image
im = Image.open("cat.jpg").convert("RGB")
im = transform(im).unsqueeze(0)

with torch.no_grad(), torch.cuda.amp.autocast():
  generated = model.generate(im)

print(open_clip.decode(generated[0]).split("<end_of_text>")[0].replace("<start_of_text>", ""))

# mini image net sym links


In [2]:
import os
from collections import defaultdict
from datasets import load_dataset, get_dataset_config_info
from PIL import Image

def create_fewshot_splits(dataset_name="timm/mini-imagenet", output_dir="mini_imagenet_fewshot", num_images_per_class=20):
    """
    Loads the mini-imagenet dataset and creates train, validation, and test splits
    with a specified number of images per class, saved into folders.

    This version is optimized for low RAM usage by streaming the dataset.

    Args:
        dataset_name (str): The name of the dataset on the Hugging Face Hub.
        output_dir (str): The root directory to save the new dataset splits.
        num_images_per_class (int): The number of images to save for each class in each split.
    """
    print(f"Loading '{dataset_name}' dataset info...")
    # The original dataset only has 'train' and 'validation' splits.
    # We will create our train set from the original 'train' split,
    # and our validation and test sets from the original 'validation' split.
    
    # For streaming, we can't access .features directly on the dataset object.
    # We get the dataset info first to extract class names.
    try:
        info = get_dataset_config_info(dataset_name, trust_remote_code=True)
        class_names = info.features['label'].names
    except Exception as e:
        print(f"Could not automatically get class names via info: {e}")
        print("Attempting to load a small part of the dataset to infer class names (this may use more RAM).")
        # Fallback for datasets where info might not be straightforward
        temp_ds = load_dataset(dataset_name, split='train', trust_remote_code=True)
        class_names = temp_ds.features['label'].names
        del temp_ds # free up memory

    print(f"Found {len(class_names)} classes.")
    print("Starting to stream and process datasets to save RAM...")

    # Use streaming=True to avoid loading the entire dataset into memory.
    original_train_ds = load_dataset(dataset_name, split='train', trust_remote_code=True, streaming=True)
    original_val_ds = load_dataset(dataset_name, split='validation', trust_remote_code=True, streaming=True)
    
    # --- 1. Process the Training Split ---
    print("\nProcessing 'train' split...")
    process_split(
        ds=original_train_ds,
        split_name='train',
        output_dir=output_dir,
        class_names=class_names,
        num_images=num_images_per_class
    )

    # --- 2. Process the Validation and Test Splits from the original validation set ---
    print("\nProcessing 'validation' and 'test' splits...")
    # This part is rewritten to handle streaming data and avoid loading everything into memory.
    val_counts = defaultdict(int)
    test_counts = defaultdict(int)
    total_classes = len(class_names)
    # Use a set to track completed classes for faster lookups
    completed_classes = set()

    for item in original_val_ds:
        label_idx = item['label']
        
        # If we have already processed this class, skip to the next item
        if label_idx in completed_classes:
            continue

        class_name = class_names[label_idx]

        # Try to add to validation set if not full
        if val_counts[label_idx] < num_images_per_class:
            val_path = os.path.join(output_dir, 'validation', class_name)
            os.makedirs(val_path, exist_ok=True)
            img_path = os.path.join(val_path, f"{val_counts[label_idx]}.jpg")
            item['image'].convert("RGB").save(img_path)
            val_counts[label_idx] += 1
        # If validation is full, try to add to test set
        elif test_counts[label_idx] < num_images_per_class:
            test_path = os.path.join(output_dir, 'test', class_name)
            os.makedirs(test_path, exist_ok=True)
            img_path = os.path.join(test_path, f"{test_counts[label_idx]}.jpg")
            item['image'].convert("RGB").save(img_path)
            test_counts[label_idx] += 1
        
        # Check if the class is now complete for both splits after the potential save
        if val_counts[label_idx] >= num_images_per_class and test_counts[label_idx] >= num_images_per_class:
             if label_idx not in completed_classes:
                completed_classes.add(label_idx)
                print(f"  Finished val/test for class: {class_name} ({len(completed_classes)}/{total_classes})")
        
        # Early exit if all classes are done
        if len(completed_classes) == total_classes:
            print("Successfully created 'validation' and 'test' splits.")
            break
            
    print("\nDataset creation complete.")
    print(f"Your few-shot dataset is ready in the '{output_dir}' directory.")


def process_split(ds, split_name, output_dir, class_names, num_images):
    """
    Helper function to iterate through a dataset split and save a few images per class.
    """
    class_counts = defaultdict(int)
    total_classes = len(class_names)
    classes_done = 0

    for item in ds:
        label_idx = item['label']

        if class_counts[label_idx] < num_images:
            class_name = class_names[label_idx]
            
            # Create the directory path for the class
            class_dir = os.path.join(output_dir, split_name, class_name)
            os.makedirs(class_dir, exist_ok=True)
            
            # Save the image
            image = item['image']
            image_path = os.path.join(class_dir, f"{class_counts[label_idx]}.jpg")
            # Convert to RGB to handle potential RGBA or grayscale images
            image.convert("RGB").save(image_path)
            
            # Increment the counter for this class
            class_counts[label_idx] += 1

            # Check if this class is now complete
            if class_counts[label_idx] == num_images:
                classes_done += 1
                print(f"  Finished class: {class_name} ({classes_done}/{total_classes})")
        
        # Early exit if we have collected enough images for all classes
        if classes_done == total_classes:
            print(f"Successfully created '{split_name}' split.")
            break

if __name__ == '__main__':
    # You might need to install these libraries if you haven't already:
    # pip install datasets Pillow
    create_fewshot_splits()



Loading 'timm/mini-imagenet' dataset info...


README.md: 0.00B [00:00, ?B/s]

Found 100 classes.
Starting to stream and process datasets to save RAM...

Processing 'train' split...
  Finished class: n01532829 (1/100)
  Finished class: n01558993 (2/100)


KeyboardInterrupt: 

In [None]:
import os
import zipfile
import sys

def zip_folder(folder_path, output_path=None):
    """Create a zip file from a folder."""
    if not os.path.isdir(folder_path):
        raise ValueError(f"{folder_path} is not a valid directory")

    if output_path is None:
        output_path = folder_path.rstrip(os.sep) + ".zip"

    with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for root, _, files in os.walk(folder_path):
            for file in files:
                abs_path = os.path.join(root, file)
                rel_path = os.path.relpath(abs_path, os.path.dirname(folder_path))
                zipf.write(abs_path, rel_path)

    print(f"Created zip: {output_path}")


zip_folder("/kaggle/working/mini_imagenet_fewshot", "/kaggle/working/asd2.zip")


In [None]:
import os
from datasets import get_dataset_infos

def view_and_save_dataset_readme(dataset_name="timm/mini-imagenet", output_dir="mini_imagenet_fewshot"):
    """
    Loads the README (dataset card) for a specified Hugging Face dataset,
    prints it to the console, and saves it to a file.
    This is memory-efficient as it only fetches metadata, not the image data.

    Args:
        dataset_name (str): The name of the dataset on the Hugging Face Hub.
        output_dir (str): The directory where the README.md file will be saved.
    """
    print(f"Fetching README for '{dataset_name}'...")
    try:
        # get_dataset_infos returns a dictionary of configurations for the dataset.
        # We'll use the information from the primary (first) configuration.
        all_infos = get_dataset_infos(dataset_name, trust_remote_code=True)
        
        if not all_infos:
            print(f"Could not find any configuration info for '{dataset_name}'.")
            return

        # Get the info object from the first available configuration
        primary_config_key = next(iter(all_infos))
        info = all_infos[primary_config_key]
        
        readme_content = info.description
        if readme_content:
            # --- Print to Console ---
            print("\n" + "="*25 + " README / Dataset Card " + "="*25 + "\n")
            print(readme_content)
            print("\n" + "="*25 + "  End of README  " + "="*25)

            # --- Save to File ---
            # Ensure the output directory exists
            os.makedirs(output_dir, exist_ok=True)
            output_path = os.path.join(output_dir, "README.md")
            
            try:
                with open(output_path, "w", encoding="utf-8") as f:
                    f.write(readme_content)
                print(f"\nSuccessfully saved README to: '{output_path}'")
            except IOError as e:
                print(f"\nError: Could not write README to file '{output_path}'. Reason: {e}")

        else:
            print("No README or description was found for this dataset.")

    except Exception as e:
        print(f"\nAn error occurred while trying to fetch the dataset info: {e}")
        print("Please check if the dataset name is correct and you have an internet connection.")

if __name__ == '__main__':
    # You might need to install this library if you haven't already:
    # pip install datasets
    view_and_save_dataset_readme()

In [None]:
import os

root = "/kaggle/input/mini-image-net-fewshot-dataset/mini_imagenet_fewshot/train"
folders = os.listdir(root)
print(folders)  # copy these and fill them in the dict above



# start here

In [1]:
synset_to_caption = {
    "n02457408": "three-toed sloth",
    "n02101006": "Gordon setter dog",
    "n02950826": "cannon",
    "n03854065": "organ keyboard instrument",
    "n02219486": "ant",
    "n03888605": "parallel bars",
    "n03017168": "chime or bell",
    "n01558993": "robin bird",
    "n02108551": "Tibetan mastiff dog",
    "n03676483": "lipstick",
    "n03400231": "frying pan",
    "n03838899": "oboe",
    "n03347037": "fire screen",
    "n04509417": "unicycle",
    "n03062245": "cocktail shaker",
    "n04149813": "scoreboard",
    "n03337140": "file cabinet",
    "n04522168": "vase",
    "n02966193": "car radiator grille",
    "n13133613": "ear of corn",
    "n03207743": "dishwasher",
    "n02091831": "Saluki dog",
    "n04612504": "yawl sailboat",
    "n03770439": "minivan",
    "n02091244": "Ibizan hound",
    "n03924679": "photocopier",
    "n02111277": "Newfoundland dog",
    "n02981792": "chain saw",
    "n07747607": "orange fruit",
    "n04418357": "theater curtain",
    "n01981276": "king crab",
    "n02113712": "miniature poodle",
    "n01855672": "goose",
    "n04604644": "worm fence",
    "n04251144": "snorkel",
    "n02108915": "French bulldog",
    "n04275548": "spatula",
    "n03272010": "electric guitar",
    "n02795169": "barber chair",
    "n13054560": "bolete mushroom",
    "n02823428": "beer bottle",
    "n04596742": "wok",
    "n03775546": "mixing bowl",
    "n09256479": "coral reef",
    "n03476684": "hair slide",
    "n02099601": "golden retriever",
    "n02110063": "malamute dog",
    "n04146614": "school bus",
    "n01930112": "nematode worm",
    "n01532829": "house sparrow",
    "n02165456": "ladybug",
    "n07697537": "hotdog sandwich",
    "n04515003": "upright piano",
    "n03544143": "hourglass",
    "n03527444": "holster",
    "n06794110": "street sign",
    "n02129165": "lion",
    "n04067472": "reflex camera",
    "n01704323": "triceratops",
    "n03047690": "clog shoe",
    "n03417042": "garbage truck",
    "n02089867": "Walker hound",
    "n02074367": "dugong",
    "n02120079": "Arctic fox",
    "n03220513": "dome",
    "n02114548": "white wolf",
    "n03075370": "combination lock",
    "n02443484": "black-footed ferret",
    "n04243546": "slot machine",
    "n01749939": "green mamba",
    "n01770081": "harvestman spider",
    "n03584254": "iPod",
    "n02687172": "aircraft carrier",
    "n02971356": "cardigan sweater",
    "n02606052": "rock beauty fish",
    "n02174001": "rhinoceros beetle",
    "n01910747": "jellyfish",
    "n09246464": "cliff",
    "n03127925": "crash helmet",
    "n07613480": "trifle dessert",
    "n02110341": "dalmatian dog",
    "n04296562": "stage",
    "n04443257": "tobacco shop",
    "n02105505": "komondor dog",
    "n03535780": "horizontal bar",
    "n04389033": "tank",
    "n02116738": "African hunting dog",
    "n02747177": "ashcan trash bin",
    "n03773504": "missile",
    "n02108089": "boxer dog",
    "n03146219": "cuirass armor",
    "n03980874": "poncho",
    "n03908618": "pencil box",
    "n01843383": "toucan",
    "n04258138": "solar dish",
    "n03998194": "prayer rug",
    "n07584110": "consomme soup",
    "n04435653": "tile roof",
    "n02138441": "meerkat",
    "n02871525": "book jacket"
}
import os

src_root = "/kaggle/input/mini-image-net-fewshot-dataset/asd2/mini_imagenet_fewshot"
dst_root = "/kaggle/working/mini_imagenet_fewshot_renamed"

splits = ["train", "validation", "test"]

# create output dirs
for split in splits:
    os.makedirs(os.path.join(dst_root, split), exist_ok=True)

# symlink each class folder with a new readable name
for split in splits:
    split_src = os.path.join(src_root, split)
    split_dst = os.path.join(dst_root, split)

    for synset in os.listdir(split_src):
        readable = synset_to_caption.get(synset, synset).replace(" ", "_")
        src_path = os.path.join(split_src, synset)
        dst_path = os.path.join(split_dst, readable)

        # create symlink
        if not os.path.exists(dst_path):
            os.symlink(src_path, dst_path)
            print(f"🔗 Linked {synset} → {readable}")


# libraries install

In [2]:
!pip install open_clip_torch transformers



In [3]:
!pip install peft



# prototype baseline

In [None]:
import open_clip
import torch
import torch.nn.functional as F
from PIL import Image
from pathlib import Path
from typing import List, Dict, Tuple
import numpy as np
from tqdm import tqdm
import json

class CoCaFewShotBenchmark:
    def __init__(
        self,
        model_name: str = "coca_ViT-L-14",
        pretrained: str = "mscoco_finetuned_laion2B-s13B-b90k",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        """
        Initialize CoCa model for few-shot learning.
        
        Args:
            model_name: CoCa model architecture
            pretrained: Pretrained weights to use
            device: Device to run inference on
        """
        self.device = device
        print(f"Loading CoCa model on {device}...")
        
        self.model, _, self.transform = open_clip.create_model_and_transforms(
            model_name=model_name,
            pretrained=pretrained
        )
        self.model = self.model.to(device)
        self.model.eval()
        
        self.tokenizer = open_clip.get_tokenizer(model_name)
        
    def create_text_prompts(self, class_name: str) -> List[str]:
        """
        Create multiple prompt templates for a class name.
        
        Args:
            class_name: Name of the class
            
        Returns:
            List of text prompts
        """
        # Clean up class name (replace underscores with spaces)
        clean_name = class_name.replace("_", " ")
        
        templates = [
            f"a photo of a {clean_name}",
            f"an image of a {clean_name}",
            f"a picture of a {clean_name}",
            f"a {clean_name}",
            f"the {clean_name}",
        ]
        return templates
    
    def encode_text_prompts(self, prompts: List[str]) -> torch.Tensor:
        """
        Encode text prompts and average them.
        
        Args:
            prompts: List of text prompts
            
        Returns:
            Averaged text embedding
        """
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_tokens = self.tokenizer(prompts).to(self.device)
            text_features = self.model.encode_text(text_tokens)
            text_features = F.normalize(text_features, dim=-1)
            # Average across prompts
            text_prototype = text_features.mean(dim=0)
            text_prototype = F.normalize(text_prototype, dim=-1)
        return text_prototype
    
    def encode_image(self, image_path: Path) -> torch.Tensor:
        """
        Encode a single image.
        
        Args:
            image_path: Path to image file
            
        Returns:
            Image embedding
        """
        with torch.no_grad(), torch.cuda.amp.autocast():
            image = Image.open(image_path).convert("RGB")
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)
            image_features = self.model.encode_image(image_tensor)
            image_features = F.normalize(image_features, dim=-1)
        return image_features.squeeze(0)
    
    def build_class_prototypes(
        self,
        train_dir: Path,
        n_shot: int = 5,
        text_weight: float = 0.5
    ) -> Tuple[Dict[str, torch.Tensor], List[str]]:
        """
        Build class prototypes from few-shot training data.
        
        Args:
            train_dir: Path to training directory
            n_shot: Number of examples per class
            text_weight: Weight for text embeddings (1 - text_weight for images)
            
        Returns:
            Dictionary mapping class names to prototypes, and list of class names
        """
        class_prototypes = {}
        class_names = []
        
        # Get all class directories
        class_dirs = sorted([d for d in train_dir.iterdir() if d.is_dir()])
        
        print(f"\nBuilding {n_shot}-shot prototypes for {len(class_dirs)} classes...")
        
        for class_dir in tqdm(class_dirs, desc="Processing classes"):
            class_name = class_dir.name
            class_names.append(class_name)
            
            # Get image paths (limit to n_shot)
            image_paths = sorted(class_dir.glob("*.jpg"))[:n_shot]
            
            if len(image_paths) < n_shot:
                print(f"Warning: {class_name} has only {len(image_paths)} images (expected {n_shot})")
            
            # Encode images
            image_embeddings = []
            for img_path in image_paths:
                img_emb = self.encode_image(img_path)
                image_embeddings.append(img_emb)
            
            # Average image embeddings
            image_prototype = torch.stack(image_embeddings).mean(dim=0)
            image_prototype = F.normalize(image_prototype, dim=-1)
            
            # Create and encode text prompts
            text_prompts = self.create_text_prompts(class_name)
            text_prototype = self.encode_text_prompts(text_prompts)
            
            # Combine text and image prototypes
            class_prototype = (
                text_weight * text_prototype + 
                (1 - text_weight) * image_prototype
            )
            class_prototype = F.normalize(class_prototype, dim=-1)
            
            class_prototypes[class_name] = class_prototype
        
        return class_prototypes, class_names
    
    def evaluate(
        self,
        test_dir: Path,
        class_prototypes: Dict[str, torch.Tensor],
        class_names: List[str]
    ) -> Dict[str, float]:
        """
        Evaluate on test set.
        
        Args:
            test_dir: Path to test directory
            class_prototypes: Dictionary of class prototypes
            class_names: List of class names (in order)
            
        Returns:
            Dictionary of evaluation metrics
        """
        correct = 0
        total = 0
        per_class_correct = {name: 0 for name in class_names}
        per_class_total = {name: 0 for name in class_names}
        
        # Stack all prototypes for efficient similarity computation
        prototype_matrix = torch.stack([class_prototypes[name] for name in class_names])
        
        print(f"\nEvaluating on test set...")
        
        # Get all test class directories
        test_class_dirs = sorted([d for d in test_dir.iterdir() if d.is_dir()])
        
        for test_class_dir in tqdm(test_class_dirs, desc="Testing classes"):
            true_class = test_class_dir.name
            
            # Get all test images
            test_images = sorted(test_class_dir.glob("*.jpg"))
            
            for img_path in test_images:
                # Encode test image
                img_embedding = self.encode_image(img_path)
                
                # Compute similarities with all class prototypes
                similarities = img_embedding @ prototype_matrix.T
                
                # Get predicted class
                pred_idx = similarities.argmax().item()
                pred_class = class_names[pred_idx]
                
                # Update metrics
                total += 1
                per_class_total[true_class] += 1
                
                if pred_class == true_class:
                    correct += 1
                    per_class_correct[true_class] += 1
        
        # Calculate metrics
        overall_accuracy = (correct / total) * 100 if total > 0 else 0
        
        per_class_accuracy = {
            name: (per_class_correct[name] / per_class_total[name] * 100) 
            if per_class_total[name] > 0 else 0
            for name in class_names
        }
        
        mean_per_class_accuracy = np.mean(list(per_class_accuracy.values()))
        
        results = {
            "overall_accuracy": overall_accuracy,
            "mean_per_class_accuracy": mean_per_class_accuracy,
            "total_samples": total,
            "correct_predictions": correct,
            "per_class_accuracy": per_class_accuracy
        }
        
        return results
    
    def run_benchmark(
        self,
        dataset_dir: Path,
        n_shot: int = 5,
        text_weight: float = 0.5,
        save_results: bool = True
    ) -> Dict:
        """
        Run complete few-shot benchmark.
        
        Args:
            dataset_dir: Path to dataset root
            n_shot: Number of examples per class
            text_weight: Weight for text vs image prototypes
            save_results: Whether to save results to JSON
            
        Returns:
            Evaluation results dictionary
        """
        train_dir = dataset_dir / "train"
        test_dir = dataset_dir / "test"
        
        # Build prototypes
        class_prototypes, class_names = self.build_class_prototypes(
            train_dir, n_shot, text_weight
        )
        
        # Evaluate
        results = self.evaluate(test_dir, class_prototypes, class_names)
        
        # Add experiment metadata
        results["experiment_config"] = {
            "n_shot": n_shot,
            "text_weight": text_weight,
            "num_classes": len(class_names),
            "model": "coca_ViT-L-14"
        }
        
        # Print results
        print("\n" + "="*60)
        print(f"Few-Shot Learning Results ({n_shot}-shot)")
        print("="*60)
        print(f"Overall Accuracy: {results['overall_accuracy']:.2f}%")
        print(f"Mean Per-Class Accuracy: {results['mean_per_class_accuracy']:.2f}%")
        print(f"Total Samples: {results['total_samples']}")
        print(f"Correct Predictions: {results['correct_predictions']}")
        print("="*60)
        
        # Save results
        if save_results:
            output_file = f"coca_fewshot_results_{n_shot}shot.json"
            with open(output_file, 'w') as f:
                # Convert per-class accuracies to serializable format
                serializable_results = results.copy()
                serializable_results["per_class_accuracy"] = {
                    k: float(v) for k, v in results["per_class_accuracy"].items()
                }
                json.dump(serializable_results, f, indent=2)
            print(f"Results saved to {output_file}")
        
        return results


def main():
    """Example usage of the benchmark."""
    
    # Configuration
    DATASET_DIR = Path("/kaggle/working/mini_imagenet_fewshot_renamed")
    N_SHOTS = [10,20]  # Different shot configurations to test
    TEXT_WEIGHT = 0.7  # Balance between text and image prototypes
    
    # Initialize benchmark
    benchmark = CoCaFewShotBenchmark()
    
    # Run experiments for different shot settings
    all_results = {}
    
    for n_shot in N_SHOTS:
        print(f"\n{'='*60}")
        print(f"Running {n_shot}-shot experiment")
        print(f"{'='*60}")
        
        results = benchmark.run_benchmark(
            dataset_dir=DATASET_DIR,
            n_shot=n_shot,
            text_weight=TEXT_WEIGHT,
            save_results=True
        )
        
        all_results[f"{n_shot}_shot"] = results
    
    # Print summary comparison
    print("\n" + "="*60)
    print("SUMMARY COMPARISON")
    print("="*60)
    print(f"{'N-Shot':<10} {'Overall Acc':<15} {'Mean Per-Class Acc':<20}")
    print("-"*60)
    for n_shot in N_SHOTS:
        res = all_results[f"{n_shot}_shot"]
        print(f"{n_shot:<10} {res['overall_accuracy']:<15.2f} {res['mean_per_class_accuracy']:<20.2f}")
    print("="*60)


if __name__ == "__main__":
    main()

# finetuning 1 classification head

In [4]:
import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from pathlib import Path
from typing import List, Dict, Tuple, Optional
import numpy as np
from tqdm import tqdm
import json
from torchvision import transforms
import random


class FewShotDataset(Dataset):
    """Dataset for few-shot learning."""
    
    def __init__(self, root_dir: Path, transform=None, augment: bool = False):
        self.root_dir = root_dir
        self.transform = transform
        self.augment = augment
        
        # Collect all images and labels
        self.samples = []
        self.class_to_idx = {}
        
        class_dirs = sorted([d for d in root_dir.iterdir() if d.is_dir()])
        for idx, class_dir in enumerate(class_dirs):
            self.class_to_idx[class_dir.name] = idx
            for img_path in sorted(class_dir.glob("*.jpg")):
                self.samples.append((img_path, idx))
        
        # Augmentation pipeline
        if self.augment:
            self.augment_transform = transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            ])
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert("RGB")
        
        if self.augment:
            image = self.augment_transform(image)
        
        if self.transform:
            image = self.transform(image)
        
        return image, label


class LinearClassifier(nn.Module):
    """Linear classification head."""
    
    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        self.classifier = nn.Linear(input_dim, num_classes)
        
        # Initialize with small values
        nn.init.normal_(self.classifier.weight, std=0.01)
        nn.init.zeros_(self.classifier.bias)
    
    def forward(self, x):
        return self.classifier(x)


class CoCaFewShotFinetune:
    def __init__(
        self,
        model_name: str = "coca_ViT-L-14",
        pretrained: str = "mscoco_finetuned_laion2B-s13B-b90k",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        """
        Initialize CoCa model for few-shot fine-tuning.
        
        Args:
            model_name: CoCa model architecture
            pretrained: Pretrained weights to use
            device: Device to run on
        """
        self.device = device
        print(f"Loading CoCa model on {device}...")
        
        self.model, _, self.transform = open_clip.create_model_and_transforms(
            model_name=model_name,
            pretrained=pretrained
        )
        self.model = self.model.to(device)
        self.model.eval()  # Keep in eval mode (frozen)
        
        # Freeze all parameters
        for param in self.model.parameters():
            param.requires_grad = False
        
        self.classifier = None
        self.class_names = None
    
    @torch.no_grad()
    def extract_features(self, dataloader: DataLoader) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Extract image features from frozen encoder.
        
        Args:
            dataloader: DataLoader for images
            
        Returns:
            Features and labels tensors
        """
        all_features = []
        all_labels = []
        
        for images, labels in tqdm(dataloader, desc="Extracting features"):
            images = images.to(self.device)
            
            with torch.cuda.amp.autocast():
                features = self.model.encode_image(images)
                features = F.normalize(features, dim=-1)
            
            all_features.append(features.cpu())
            all_labels.append(labels)
        
        return torch.cat(all_features), torch.cat(all_labels)
    
    def train_classifier(
        self,
        train_dir: Path,
        n_shot: int = 5,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        num_epochs: int = 100,
        batch_size: int = 32,
        label_smoothing: float = 0.1,
        use_augmentation: bool = True,
        warmup_epochs: int = 10,
        patience: int = 20
    ) -> Dict:
        """
        Train linear classifier on few-shot data.
        
        Args:
            train_dir: Path to training directory
            n_shot: Number of examples per class
            learning_rate: Learning rate for optimizer
            weight_decay: Weight decay coefficient
            num_epochs: Maximum number of epochs
            batch_size: Batch size for training
            label_smoothing: Label smoothing factor
            use_augmentation: Whether to use data augmentation
            warmup_epochs: Number of warmup epochs
            patience: Early stopping patience
            
        Returns:
            Training history dictionary
        """
        # Create limited dataset (n-shot per class)
        train_samples = []
        class_dirs = sorted([d for d in train_dir.iterdir() if d.is_dir()])
        self.class_names = [d.name for d in class_dirs]
        num_classes = len(self.class_names)
        
        print(f"\nPreparing {n_shot}-shot training data...")
        class_to_idx = {name: idx for idx, name in enumerate(self.class_names)}
        
        for class_dir in class_dirs:
            class_name = class_dir.name
            class_idx = class_to_idx[class_name]
            image_paths = sorted(class_dir.glob("*.jpg"))[:n_shot]
            
            for img_path in image_paths:
                train_samples.append((img_path, class_idx))
        
        print(f"Total training samples: {len(train_samples)}")
        
        # Create dataset with or without augmentation
        class TempDataset(Dataset):
            def __init__(self, samples, transform, augment):
                self.samples = samples
                self.transform = transform
                self.augment = augment
                
                if self.augment:
                    self.augment_transform = transforms.Compose([
                        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                        transforms.RandomHorizontalFlip(),                   
                        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), 
                    ])
            
            def __len__(self):
                return len(self.samples)
            
            def __getitem__(self, idx):
                img_path, label = self.samples[idx]
                image = Image.open(img_path).convert("RGB")
                
                if self.augment:
                    image = self.augment_transform(image)
                
                if self.transform:
                    image = self.transform(image)
                
                return image, label
        
        train_dataset = TempDataset(train_samples, self.transform, use_augmentation)
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )
        
        # Extract features once (for validation without augmentation)
        print("\nExtracting validation features (no augmentation)...")
        val_dataset = TempDataset(train_samples, self.transform, augment=False)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        val_features, val_labels = self.extract_features(val_loader)
        val_features = val_features.to(self.device)
        val_labels = val_labels.to(self.device)
        
        # Get feature dimension
        with torch.no_grad(), torch.cuda.amp.autocast():
            sample_img = train_dataset[0][0].unsqueeze(0).to(self.device)
            sample_feat = self.model.encode_image(sample_img)
            feature_dim = sample_feat.shape[1]
        
        # Initialize classifier
        self.classifier = LinearClassifier(feature_dim, num_classes).to(self.device)
        
        # Setup optimizer and loss
        optimizer = torch.optim.AdamW(
            self.classifier.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )
        
        criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        
        # Learning rate scheduler with warmup
        def lr_lambda(epoch):
            if epoch < warmup_epochs:
                return (epoch + 1) / warmup_epochs
            else:
                # Cosine decay after warmup
                progress = (epoch - warmup_epochs) / (num_epochs - warmup_epochs)
                return 0.5 * (1.0 + np.cos(np.pi * progress))
        
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        
        # Training loop
        history = {
            "train_loss": [],
            "val_accuracy": [],
            "learning_rates": []
        }
        
        best_val_acc = 0.0
        patience_counter = 0
        
        print(f"\nTraining linear classifier...")
        print(f"Config: LR={learning_rate}, WD={weight_decay}, Label Smoothing={label_smoothing}")
        print(f"Augmentation: {use_augmentation}, Warmup Epochs: {warmup_epochs}\n")
        
        for epoch in range(num_epochs):
            # Training phase
            self.classifier.train()
            train_loss = 0.0
            
            for images, labels in train_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                
                # Extract features (frozen encoder)
                with torch.no_grad(), torch.cuda.amp.autocast():
                    features = self.model.encode_image(images)
                    features = F.normalize(features, dim=-1)
                
                # Forward pass through classifier
                logits = self.classifier(features)
                loss = criterion(logits, labels)
                
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                train_loss += loss.item()
            
            train_loss /= len(train_loader)
            
            # Validation phase (on training data without augmentation)
            self.classifier.eval()
            with torch.no_grad():
                val_logits = self.classifier(val_features)
                val_preds = val_logits.argmax(dim=1)
                val_accuracy = (val_preds == val_labels).float().mean().item() * 100
            
            # Update scheduler
            current_lr = optimizer.param_groups[0]['lr']
            scheduler.step()
            
            # Record history
            history["train_loss"].append(train_loss)
            history["val_accuracy"].append(val_accuracy)
            history["learning_rates"].append(current_lr)
            
            # Early stopping check
            if val_accuracy > best_val_acc:
                best_val_acc = val_accuracy
                patience_counter = 0
                # Save best model
                best_state = self.classifier.state_dict()
            else:
                patience_counter += 1
            
            # Print progress
            if (epoch + 1) % 10 == 0 or epoch == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}] "
                      f"Loss: {train_loss:.4f} | "
                      f"Val Acc: {val_accuracy:.2f}% | "
                      f"LR: {current_lr:.6f}")
            
            # Early stopping
            if patience_counter >= patience:
                print(f"\nEarly stopping at epoch {epoch+1}")
                break
        
        # Load best model
        self.classifier.load_state_dict(best_state)
        print(f"\nTraining complete! Best validation accuracy: {best_val_acc:.2f}%")
        
        return history
    
    def evaluate(self, test_dir: Path, batch_size: int = 32) -> Dict[str, float]:
        """
        Evaluate classifier on test set.
        
        Args:
            test_dir: Path to test directory
            batch_size: Batch size for evaluation
            
        Returns:
            Dictionary of evaluation metrics
        """
        if self.classifier is None:
            raise ValueError("Classifier not trained. Call train_classifier first.")
        
        self.classifier.eval()
        
        # Create test dataset
        test_dataset = FewShotDataset(test_dir, transform=self.transform, augment=False)
        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=4
        )
        
        # Extract features
        print("\nExtracting test features...")
        test_features, test_labels = self.extract_features(test_loader)
        test_features = test_features.to(self.device)
        test_labels = test_labels.to(self.device)
        
        # Predict
        with torch.no_grad():
            logits = self.classifier(test_features)
            preds = logits.argmax(dim=1)
        
        # Calculate metrics
        correct = (preds == test_labels).sum().item()
        total = len(test_labels)
        overall_accuracy = (correct / total) * 100
        
        # Per-class accuracy
        per_class_correct = {}
        per_class_total = {}
        
        for pred, label in zip(preds.cpu().numpy(), test_labels.cpu().numpy()):
            class_name = self.class_names[label]
            per_class_total[class_name] = per_class_total.get(class_name, 0) + 1
            if pred == label:
                per_class_correct[class_name] = per_class_correct.get(class_name, 0) + 1
        
        per_class_accuracy = {
            name: (per_class_correct.get(name, 0) / per_class_total[name] * 100)
            for name in per_class_total
        }
        
        mean_per_class_accuracy = np.mean(list(per_class_accuracy.values()))
        
        results = {
            "overall_accuracy": overall_accuracy,
            "mean_per_class_accuracy": mean_per_class_accuracy,
            "total_samples": total,
            "correct_predictions": correct,
            "per_class_accuracy": per_class_accuracy
        }
        
        return results
    
    def run_finetuning_experiment(
        self,
        dataset_dir: Path,
        n_shot: int = 5,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        num_epochs: int = 100,
        label_smoothing: float = 0.1,
        use_augmentation: bool = True,
        save_results: bool = True
    ) -> Dict:
        """
        Run complete fine-tuning experiment.
        
        Args:
            dataset_dir: Path to dataset root
            n_shot: Number of examples per class
            learning_rate: Learning rate
            weight_decay: Weight decay
            num_epochs: Number of training epochs
            label_smoothing: Label smoothing factor
            use_augmentation: Whether to use data augmentation
            save_results: Whether to save results
            
        Returns:
            Complete results dictionary
        """
        train_dir = dataset_dir / "train"
        test_dir = dataset_dir / "test"
        
        # Train classifier
        history = self.train_classifier(
            train_dir=train_dir,
            n_shot=n_shot,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            num_epochs=num_epochs,
            label_smoothing=label_smoothing,
            use_augmentation=use_augmentation
        )
        
        # Evaluate on test set
        results = self.evaluate(test_dir)
        
        # Combine results
        results["training_history"] = history
        results["experiment_config"] = {
            "n_shot": n_shot,
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "num_epochs": num_epochs,
            "label_smoothing": label_smoothing,
            "use_augmentation": use_augmentation,
            "model": "coca_ViT-L-14"
        }
        
        # Print results
        print("\n" + "="*60)
        print(f"Fine-tuning Results ({n_shot}-shot)")
        print("="*60)
        print(f"Overall Accuracy: {results['overall_accuracy']:.2f}%")
        print(f"Mean Per-Class Accuracy: {results['mean_per_class_accuracy']:.2f}%")
        print(f"Total Samples: {results['total_samples']}")
        print(f"Correct Predictions: {results['correct_predictions']}")
        print("="*60)
        
        # Save results
        if save_results:
            output_file = f"coca_finetune_{n_shot}shot_aug{use_augmentation}.json"
            with open(output_file, 'w') as f:
                serializable_results = results.copy()
                serializable_results["per_class_accuracy"] = {
                    k: float(v) for k, v in results["per_class_accuracy"].items()
                }
                json.dump(serializable_results, f, indent=2)
            print(f"Results saved to {output_file}")
        
        return results


def main():
    """Example usage demonstrating different configurations."""
    
    DATASET_DIR = Path("/kaggle/working/mini_imagenet_fewshot_renamed")
    
    # Configuration options
    configs = [
        # Compare different shot numbers with augmentation
        {"n_shot": 1, "use_augmentation": False, "num_epochs": 200},
        {"n_shot": 3, "use_augmentation": False, "num_epochs": 150},
        {"n_shot": 1, "use_augmentation": True, "num_epochs": 200},
        {"n_shot": 3, "use_augmentation": True, "num_epochs": 150},
        #{"n_shot": 5, "use_augmentation": True, "num_epochs": 200},
        
        # Compare with/without augmentation for 5-shot
        #{"n_shot": 10, "use_augmentation": True, "num_epochs": 200},
        # {"n_shot": 10, "use_augmentation": False, "num_epochs": 200},
        #{"n_shot": 20, "use_augmentation": True, "num_epochs": 200},
        # {"n_shot": 20, "use_augmentation": False, "num_epochs": 200}
    ]
    
    all_results = {}
    
    for config in configs:
        n_shot = config["n_shot"]
        use_aug = config["use_augmentation"]
        
        print(f"\n{'='*70}")
        print(f"Running {n_shot}-shot experiment (Augmentation: {use_aug})")
        print(f"{'='*70}")
        
        # Initialize new instance for each experiment
        finetuner = CoCaFewShotFinetune()
        
        results = finetuner.run_finetuning_experiment(
            dataset_dir=DATASET_DIR,
            n_shot=n_shot,
            learning_rate=5e-5,
            weight_decay=0.05,
            num_epochs=config["num_epochs"],
            label_smoothing=0.1,
            use_augmentation=use_aug,
            save_results=True
        )
        
        key = f"{n_shot}_shot_aug{use_aug}"
        all_results[key] = results
    
    # Print summary
    print("\n" + "="*80)
    print("SUMMARY COMPARISON")
    print("="*80)
    print(f"{'Config':<25} {'Overall Acc':<15} {'Mean Per-Class Acc':<20}")
    print("-"*80)
    for key, res in all_results.items():
        print(f"{key:<25} {res['overall_accuracy']:<15.2f} {res['mean_per_class_accuracy']:<20.2f}")
    print("="*80)


if __name__ == "__main__":
    main()


Running 1-shot experiment (Augmentation: False)
Loading CoCa model on cuda...

Preparing 1-shot training data...
Total training samples: 100

Extracting validation features (no augmentation)...


  with torch.cuda.amp.autocast():
Extracting features: 100%|██████████| 4/4 [00:04<00:00,  1.04s/it]


Training linear classifier...
Config: LR=5e-05, WD=0.05, Label Smoothing=0.1
Augmentation: False, Warmup Epochs: 10




  with torch.no_grad(), torch.cuda.amp.autocast():
  with torch.no_grad(), torch.cuda.amp.autocast():


Epoch [1/200] Loss: 4.6048 | Val Acc: 2.00% | LR: 0.000005
Epoch [10/200] Loss: 4.5987 | Val Acc: 13.00% | LR: 0.000050
Epoch [20/200] Loss: 4.5835 | Val Acc: 42.00% | LR: 0.000050
Epoch [30/200] Loss: 4.5714 | Val Acc: 71.00% | LR: 0.000049
Epoch [40/200] Loss: 4.5646 | Val Acc: 81.00% | LR: 0.000047
Epoch [50/200] Loss: 4.5526 | Val Acc: 89.00% | LR: 0.000045
Epoch [60/200] Loss: 4.5413 | Val Acc: 93.00% | LR: 0.000042
Epoch [70/200] Loss: 4.5303 | Val Acc: 94.00% | LR: 0.000039
Epoch [80/200] Loss: 4.5235 | Val Acc: 95.00% | LR: 0.000035
Epoch [90/200] Loss: 4.5167 | Val Acc: 95.00% | LR: 0.000032

Early stopping at epoch 93

Training complete! Best validation accuracy: 95.00%

Extracting test features...


Extracting features: 100%|██████████| 63/63 [01:00<00:00,  1.05it/s]



Fine-tuning Results (1-shot)
Overall Accuracy: 55.05%
Mean Per-Class Accuracy: 55.05%
Total Samples: 2000
Correct Predictions: 1101
Results saved to coca_finetune_1shot_augFalse.json

Running 3-shot experiment (Augmentation: False)
Loading CoCa model on cuda...

Preparing 3-shot training data...
Total training samples: 300

Extracting validation features (no augmentation)...


Extracting features: 100%|██████████| 10/10 [00:11<00:00,  1.17s/it]


Training linear classifier...
Config: LR=5e-05, WD=0.05, Label Smoothing=0.1
Augmentation: False, Warmup Epochs: 10






Epoch [1/150] Loss: 4.6049 | Val Acc: 0.67% | LR: 0.000005
Epoch [10/150] Loss: 4.5928 | Val Acc: 17.67% | LR: 0.000050
Epoch [20/150] Loss: 4.5675 | Val Acc: 78.00% | LR: 0.000049
Epoch [30/150] Loss: 4.5430 | Val Acc: 93.33% | LR: 0.000048
Epoch [40/150] Loss: 4.5197 | Val Acc: 97.67% | LR: 0.000045
Epoch [50/150] Loss: 4.4974 | Val Acc: 98.67% | LR: 0.000041
Epoch [60/150] Loss: 4.4786 | Val Acc: 99.67% | LR: 0.000036
Epoch [70/150] Loss: 4.4609 | Val Acc: 99.67% | LR: 0.000031

Early stopping at epoch 79

Training complete! Best validation accuracy: 99.67%

Extracting test features...


Extracting features: 100%|██████████| 63/63 [01:00<00:00,  1.04it/s]



Fine-tuning Results (3-shot)
Overall Accuracy: 82.55%
Mean Per-Class Accuracy: 82.55%
Total Samples: 2000
Correct Predictions: 1651
Results saved to coca_finetune_3shot_augFalse.json

Running 1-shot experiment (Augmentation: True)
Loading CoCa model on cuda...

Preparing 1-shot training data...
Total training samples: 100

Extracting validation features (no augmentation)...


Extracting features: 100%|██████████| 4/4 [00:03<00:00,  1.07it/s]


Training linear classifier...
Config: LR=5e-05, WD=0.05, Label Smoothing=0.1
Augmentation: True, Warmup Epochs: 10






Epoch [1/200] Loss: 4.6051 | Val Acc: 0.00% | LR: 0.000005
Epoch [10/200] Loss: 4.5994 | Val Acc: 3.00% | LR: 0.000050
Epoch [20/200] Loss: 4.5873 | Val Acc: 38.00% | LR: 0.000050
Epoch [30/200] Loss: 4.5747 | Val Acc: 67.00% | LR: 0.000049
Epoch [40/200] Loss: 4.5659 | Val Acc: 82.00% | LR: 0.000047
Epoch [50/200] Loss: 4.5550 | Val Acc: 90.00% | LR: 0.000045
Epoch [60/200] Loss: 4.5480 | Val Acc: 96.00% | LR: 0.000042
Epoch [70/200] Loss: 4.5355 | Val Acc: 99.00% | LR: 0.000039
Epoch [80/200] Loss: 4.5319 | Val Acc: 99.00% | LR: 0.000035
Epoch [90/200] Loss: 4.5207 | Val Acc: 99.00% | LR: 0.000032

Early stopping at epoch 94

Training complete! Best validation accuracy: 100.00%

Extracting test features...


Extracting features: 100%|██████████| 63/63 [01:00<00:00,  1.04it/s]



Fine-tuning Results (1-shot)
Overall Accuracy: 58.30%
Mean Per-Class Accuracy: 58.30%
Total Samples: 2000
Correct Predictions: 1166
Results saved to coca_finetune_1shot_augTrue.json

Running 3-shot experiment (Augmentation: True)
Loading CoCa model on cuda...

Preparing 3-shot training data...
Total training samples: 300

Extracting validation features (no augmentation)...


Extracting features: 100%|██████████| 10/10 [00:11<00:00,  1.11s/it]


Training linear classifier...
Config: LR=5e-05, WD=0.05, Label Smoothing=0.1
Augmentation: True, Warmup Epochs: 10






Epoch [1/150] Loss: 4.6050 | Val Acc: 0.67% | LR: 0.000005
Epoch [10/150] Loss: 4.5932 | Val Acc: 17.33% | LR: 0.000050
Epoch [20/150] Loss: 4.5683 | Val Acc: 79.67% | LR: 0.000049
Epoch [30/150] Loss: 4.5443 | Val Acc: 90.67% | LR: 0.000048
Epoch [40/150] Loss: 4.5210 | Val Acc: 95.67% | LR: 0.000045
Epoch [50/150] Loss: 4.5002 | Val Acc: 97.00% | LR: 0.000041
Epoch [60/150] Loss: 4.4804 | Val Acc: 98.67% | LR: 0.000036
Epoch [70/150] Loss: 4.4646 | Val Acc: 99.00% | LR: 0.000031
Epoch [80/150] Loss: 4.4498 | Val Acc: 99.00% | LR: 0.000026

Early stopping at epoch 82

Training complete! Best validation accuracy: 99.00%

Extracting test features...


Extracting features: 100%|██████████| 63/63 [01:00<00:00,  1.04it/s]


Fine-tuning Results (3-shot)
Overall Accuracy: 82.10%
Mean Per-Class Accuracy: 82.10%
Total Samples: 2000
Correct Predictions: 1642
Results saved to coca_finetune_3shot_augTrue.json

SUMMARY COMPARISON
Config                    Overall Acc     Mean Per-Class Acc  
--------------------------------------------------------------------------------
1_shot_augFalse           55.05           55.05               
3_shot_augFalse           82.55           82.55               
1_shot_augTrue            58.30           58.30               
3_shot_augTrue            82.10           82.10               





# ----------------------------------------LoRa-----------------------------------------------------

In [7]:
import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Union
import numpy as np
from tqdm import tqdm
import json
from torchvision import transforms
import random
from peft import LoraConfig, get_peft_model


class LinearClassifier(nn.Module):
    """Simple linear classifier head."""
    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes)
        
    def forward(self, x):
        return self.fc(x)


class FewShotDataset(Dataset):
    """Dataset for few-shot learning."""
    def __init__(self, data_dir: Path, transform=None, augment=False):
        self.samples = []
        self.transform = transform
        self.class_to_idx = {}
        
        class_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir()])
        
        for idx, class_dir in enumerate(class_dirs):
            self.class_to_idx[class_dir.name] = idx
            image_paths = list(class_dir.glob("*.jpg")) + list(class_dir.glob("*.png")) + list(class_dir.glob("*.JPEG"))
            for img_path in image_paths:
                self.samples.append((img_path, idx))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert("RGB")
        
        if self.transform:
            image = self.transform(image)
        
        return image, label


class LoRACoCaFinetune:
    def __init__(
        self,
        model_name: str = "coca_ViT-L-14",
        pretrained: str = "mscoco_finetuned_laion2B-s13B-b90k",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """
        Initialize CoCa model for LoRA fine-tuning with configurable loss functions.
        """
        self.device = device
        print(f"Loading CoCa model on {device}...")
        
        self.model, _, self.transform = open_clip.create_model_and_transforms(
            model_name=model_name,
            pretrained=pretrained
        )
        self.model = self.model.to(device)
        
        # Store original model for reference
        self.original_model = self.model
        self.lora_model = None
        self.classifier = None
        self.class_names = None

    def run_lora_experiment(
        self,
        dataset_dir: Path,
        n_shot: int = 5,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        num_epochs: int = 100,
        label_smoothing: float = 0.1,
        loss_type: str = "cross_entropy",
        use_augmentation: bool = True,
        temperature: float = 0.1,
        save_results: bool = True
    ) -> Dict:
        """Run complete LoRA fine-tuning experiment."""
        train_dir = dataset_dir / "train"
        test_dir = dataset_dir / "test"
        
        # Auto-select loss type
        if loss_type == "auto":
            if n_shot <= 2:
                loss_type = "prototypical"
            elif n_shot <= 5:
                loss_type = "contrastive"
            else:
                loss_type = "cross_entropy"
        
        print(f"Using {loss_type} loss for {n_shot}-shot learning")
        
        # Train
        history = self.train_with_lora(
            train_dir=train_dir,
            n_shot=n_shot,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            num_epochs=num_epochs,
            label_smoothing=label_smoothing,
            loss_type=loss_type,
            use_augmentation=use_augmentation,
            temperature=temperature
        )
        
        # Evaluate
        results = self.evaluate(test_dir)
        
        # Combine results
        results["training_history"] = history
        results["experiment_config"] = {
            "n_shot": n_shot,
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "num_epochs": num_epochs,
            "label_smoothing": label_smoothing,
            "loss_type": loss_type,
            "use_augmentation": use_augmentation,
            "temperature": temperature,
            "model": "coca_ViT-L-14 + LoRA",
        }
        
        # Print results
        print("\n" + "="*60)
        print(f"LoRA Fine-tuning Results ({n_shot}-shot, {loss_type} loss)")
        print("="*60)
        print(f"Overall Accuracy: {results['overall_accuracy']:.2f}%")
        print(f"Mean Per-Class Accuracy: {results['mean_per_class_accuracy']:.2f}%")
        print(f"Total Samples: {results['total_samples']}")
        print(f"Correct Predictions: {results['correct_predictions']}")
        print("="*60)
        
        # Save results
        if save_results:
            output_file = f"lora_coca_{n_shot}shot_{loss_type}_aug{use_augmentation}.json"
            with open(output_file, 'w') as f:
                serializable_results = results.copy()
                serializable_results["per_class_accuracy"] = {
                    k: float(v) for k, v in results["per_class_accuracy"].items()
                }
                json.dump(serializable_results, f, indent=2)
            print(f"Results saved to {output_file}")
        
        return results
        
    
        
    def setup_lora_for_fewshot(self, n_shot: int) -> LoraConfig:
        """
        Configure LoRA based on number of shots.
        Target the out_proj and MLP layers since MultiheadAttention doesn't expose q/k/v directly.
        
        Args:
            n_shot: Number of examples per class
            
        Returns:
            LoRA configuration
        """
        if n_shot <= 2:
            # 1-2 shot: out_proj only, small rank
            config = LoraConfig(
                r=4,
                lora_alpha=16,
                target_modules=["attn.out_proj"],  # Attention output projection
                # REMOVED: layers_to_transform=list(range(20, 24)),
                lora_dropout=0.15,
                bias="none",
            )
            print(f"LoRA Config: attn.out_proj all layers, r={config.r}")
            
        elif n_shot <= 10:
            # 5-10 shot: out_proj in all layers, medium rank
            config = LoraConfig(
                r=4,
                lora_alpha=16,
                target_modules=["attn.out_proj"],  # Attention output in all layers
                lora_dropout=0.15,
                bias="none",
            )
            print(f"LoRA Config: attn.out_proj all layers, r={config.r}, droupout={config.lora_dropout}")
            
        else:  # 10-20+ shot
            # out_proj + MLP layers, larger rank
            config = LoraConfig(
                r=8,
                lora_alpha=32,
                target_modules=["attn.out_proj", "mlp.c_fc", "mlp.c_proj"],
                lora_dropout=0.15,
                bias="none",
            )
            print(f"LoRA Config: attn.out_proj all layers, r={config.r}")
            
        return config

    def apply_lora(self, n_shot: int):
        """Apply LoRA to the model based on few-shot configuration."""
        lora_config = self.setup_lora_for_fewshot(n_shot)
        
        # Apply LoRA to the visual encoder only
        self.lora_model = get_peft_model(self.model.visual, lora_config)
        self.lora_model.print_trainable_parameters()
    
    def prototypical_loss(self, features: torch.Tensor, labels: torch.Tensor, n_support: int = 2) -> torch.Tensor:
        """
        Prototypical loss for few-shot learning.
        """
        unique_labels = torch.unique(labels)
        n_classes = len(unique_labels)
        
        # Create label to index mapping
        label_to_idx = {label.item(): idx for idx, label in enumerate(unique_labels)}
        
        # Create prototypes for each class
        prototypes = []
        query_features_list = []
        query_labels_list = []
        
        for label in unique_labels:
            mask = labels == label
            class_features = features[mask]
            
            if len(class_features) <= n_support:
                # Not enough samples for this class, use all as support
                prototype = class_features.mean(dim=0)
                prototypes.append(prototype)
            else:
                # Split into support and query
                support_features = class_features[:n_support]
                query_features = class_features[n_support:]
                
                prototype = support_features.mean(dim=0)
                prototypes.append(prototype)
                
                query_features_list.append(query_features)
                query_labels_list.extend([label_to_idx[label.item()]] * len(query_features))
        
        if len(query_features_list) == 0:
            # No query samples, fallback to cross-entropy with prototypes
            distances = torch.cdist(features, torch.stack(prototypes), p=2)
            logits = -distances
            mapped_labels = torch.tensor([label_to_idx[l.item()] for l in labels], device=features.device)
            return F.cross_entropy(logits, mapped_labels)
        
        prototypes = torch.stack(prototypes)
        query_features = torch.cat(query_features_list)
        query_labels = torch.tensor(query_labels_list, device=features.device)
        
        # Calculate distances to prototypes
        distances = torch.cdist(query_features, prototypes, p=2)
        logits = -distances
        
        return F.cross_entropy(logits, query_labels)
    
    def contrastive_loss(self, features: torch.Tensor, labels: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
        """
        Supervised contrastive loss (SupCon).
        """
        batch_size = features.size(0)
        
        if batch_size < 2:
            return torch.tensor(0.0, device=features.device)
        
        # Normalize features
        features = F.normalize(features, dim=-1)
        
        # Compute similarity matrix
        similarity_matrix = torch.matmul(features, features.T) / temperature
        
        # Create mask for positive pairs (same class)
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(features.device)
        
        # Remove diagonal
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size).view(-1, 1).to(features.device),
            0
        )
        mask = mask * logits_mask
        
        # Compute log_prob
        exp_logits = torch.exp(similarity_matrix) * logits_mask
        log_prob = similarity_matrix - torch.log(exp_logits.sum(1, keepdim=True))
        
        # Compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-6)
        
        loss = -mean_log_prob_pos.mean()
        return loss
    
    def train_with_lora(
        self,
        train_dir: Path,
        n_shot: int = 5,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        num_epochs: int = 100,
        batch_size: int = 32,
        label_smoothing: float = 0.1,
        loss_type: str = "cross_entropy",
        use_augmentation: bool = True,
        temperature: float = 0.1,
        patience: int = 20
    ) -> Dict:
        """
        Train with LoRA fine-tuning and configurable loss functions.
        """
        # Apply LoRA configuration based on few-shot setting
        self.apply_lora(n_shot)
        
        # Create dataset
        train_samples = []
        class_dirs = sorted([d for d in train_dir.iterdir() if d.is_dir()])
        self.class_names = [d.name for d in class_dirs]
        num_classes = len(self.class_names)
        class_to_idx = {name: idx for idx, name in enumerate(self.class_names)}
        
        print(f"\nPreparing {n_shot}-shot training data...")
        for class_dir in class_dirs:
            class_name = class_dir.name
            class_idx = class_to_idx[class_name]
            image_files = list(class_dir.glob("*.jpg")) + list(class_dir.glob("*.png")) + list(class_dir.glob("*.JPEG"))
            image_paths = sorted(image_files)[:n_shot]
            
            for img_path in image_paths:
                train_samples.append((img_path, class_idx))
        
        print(f"Total training samples: {len(train_samples)}")
        
        # Create dataset with augmentation
        class TempDataset(Dataset):
            def __init__(self, samples, transform, augment):
                self.samples = samples
                self.transform = transform
                self.augment = augment
                
                if self.augment:
                    self.augment_transform = transforms.Compose([
                        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                        transforms.RandomHorizontalFlip(),
                        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
                        transforms.RandomGrayscale(p=0.1),
                        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
                    ])
            
            def __len__(self):
                return len(self.samples)
            
            def __getitem__(self, idx):
                img_path, label = self.samples[idx]
                image = Image.open(img_path).convert("RGB")
                
                if self.augment:
                    image = self.augment_transform(image)
                
                if self.transform:
                    image = self.transform(image)
                
                return image, label
        
        train_dataset = TempDataset(train_samples, self.transform, use_augmentation)
        train_loader = DataLoader(
            train_dataset,
            batch_size=min(batch_size, len(train_dataset)),
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )
        
        # Get feature dimension and setup classifier
        with torch.no_grad(), torch.cuda.amp.autocast():
            sample_img = train_dataset[0][0].unsqueeze(0).to(self.device)
            sample_feat = self.lora_model(sample_img)
            feature_tensor = sample_feat[0]
            feature_dim = feature_tensor.shape[1]
            
        
        self.classifier = LinearClassifier(feature_dim, num_classes).to(self.device)
        
        # Setup optimizer
        lora_params = [p for p in self.lora_model.parameters() if p.requires_grad]
        classifier_params = [p for p in self.classifier.parameters() if p.requires_grad]
        
        optimizer = torch.optim.AdamW(
            [
                {'params': lora_params, 'lr': learning_rate},
                {'params': classifier_params, 'lr': learning_rate * 10}
            ],
            weight_decay=weight_decay
        )
        
        # Loss functions
        if loss_type == "cross_entropy":
            criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        elif loss_type == "prototypical":
            criterion = lambda features, labels: self.prototypical_loss(features, labels, n_support=max(1, n_shot//2))
        elif loss_type == "contrastive":
            criterion = lambda features, labels: self.contrastive_loss(features, labels, temperature)
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")
        
        # Learning rate scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
        
        # Training history
        history = {
            "train_loss": [],
            "val_accuracy": [],
            "learning_rates": []
        }
        
        best_val_acc = 0.0
        patience_counter = 0
        
        print(f"\nTraining with LoRA and {loss_type} loss...")
        print(f"Config: LR={learning_rate}, WD={weight_decay}, Label Smoothing={label_smoothing}")
        print(f"Augmentation: {use_augmentation}, Temperature: {temperature}")
        
        for epoch in range(num_epochs):
            # Training phase
            self.lora_model.train()
            self.classifier.train()
            train_loss = 0.0
            
            for images, labels in train_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                
                # Forward pass
                with torch.cuda.amp.autocast(): # Consider updating to torch.amp.autocast('cuda')
                    features = self.lora_model(images)[0]
                    # It's better to normalize inside loss functions if needed, 
                    # but keeping it here is fine.
                    # features = F.normalize(features, dim=-1)
                    
                    if loss_type == "cross_entropy":
                        logits = self.classifier(features)
                        loss = criterion(logits, labels)
                    else:
                        # For metric learning losses, use a hybrid approach
                        
                        # 1. Metric Loss on features to structure the embedding space
                        metric_loss = criterion(features, labels)
                        
                        # 2. CrossEntropy Loss on classifier to train the head
                        # THIS MUST BE OUTSIDE a no_grad() block
                        logits = self.classifier(features)
                        ce_loss = F.cross_entropy(logits, labels, label_smoothing=label_smoothing)
                        
                        # 3. Combine the losses
                        loss = metric_loss + ce_loss
                        
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                # Update the train_loss tracking
                train_loss += loss.item()
            
            train_loss /= len(train_loader)
            
            # Validation phase
            self.lora_model.eval()
            self.classifier.eval()
            val_accuracy = self._validate_on_train(train_samples, batch_size)
            
            # Update scheduler
            current_lr = optimizer.param_groups[0]['lr']
            scheduler.step()
            
            # Record history
            history["train_loss"].append(train_loss)
            history["val_accuracy"].append(val_accuracy)
            history["learning_rates"].append(current_lr)
            
            # Early stopping check
            if val_accuracy > best_val_acc:
                best_val_acc = val_accuracy
                patience_counter = 0
                best_lora_state = self.lora_model.state_dict()
                best_classifier_state = self.classifier.state_dict()
            else:
                patience_counter += 1
            
            # Print progress
            if (epoch + 1) % 10 == 0 or epoch == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}] "
                      f"Loss: {train_loss:.4f} | "
                      f"Val Acc: {val_accuracy:.2f}% | "
                      f"LR: {current_lr:.6f}")
            
            # Early stopping
            if patience_counter >= patience:
                print(f"\nEarly stopping at epoch {epoch+1}")
                break
        
        # Load best model
        self.lora_model.load_state_dict(best_lora_state)
        self.classifier.load_state_dict(best_classifier_state)
        print(f"\nTraining complete! Best validation accuracy: {best_val_acc:.2f}%")
        
        return history
    
    def _validate_on_train(self, train_samples: List, batch_size: int) -> float:
        """Validate on training data (without augmentation)."""
        class TempValDataset(Dataset):
            def __init__(self, samples, transform):
                self.samples = samples
                self.transform = transform
            
            def __len__(self):
                return len(self.samples)
            
            def __getitem__(self, idx):
                img_path, label = self.samples[idx]
                image = Image.open(img_path).convert("RGB")
                if self.transform:
                    image = self.transform(image)
                return image, label
        
        val_dataset = TempValDataset(train_samples, self.transform)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                
                with torch.cuda.amp.autocast():
                    features = self.lora_model(images)[0]
                    logits = self.classifier(features)
                    preds = logits.argmax(dim=1)
                
                correct += (preds == labels).sum().item()
                total += len(labels)
        
        return (correct / total) * 100
    
    def evaluate(self, test_dir: Path, batch_size: int = 32) -> Dict[str, float]:
        """Evaluate the fine-tuned model on test set."""
        if self.lora_model is None or self.classifier is None:
            raise ValueError("Model not trained. Call train_with_lora first.")
        
        self.lora_model.eval()
        self.classifier.eval()
        
        # Create test dataset
        test_dataset = FewShotDataset(test_dir, transform=self.transform, augment=False)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
        
        all_preds = []
        all_labels = []
        
        print("\nEvaluating on test set...")
        with torch.no_grad():
            for images, labels in tqdm(test_loader, desc="Testing"):
                images = images.to(self.device)
                
                with torch.cuda.amp.autocast():
                    features = self.lora_model(images)[0]
                    logits = self.classifier(features)
                    preds = logits.argmax(dim=1)
                
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.numpy())
        
        all_preds = np.array(all_preds)
        all_labels = np.array(all_labels)
        
        # Calculate metrics
        overall_accuracy = (all_preds == all_labels).mean() * 100
        
        # Per-class accuracy
        per_class_accuracy = {}
        for class_idx, class_name in enumerate(self.class_names):
            class_mask = all_labels == class_idx
            if class_mask.sum() > 0:
                class_acc = (all_preds[class_mask] == class_idx).mean() * 100
                per_class_accuracy[class_name] = class_acc
        
        mean_per_class_accuracy = np.mean(list(per_class_accuracy.values()))
        
        results = {
            # Convert NumPy floats to Python floats
            "overall_accuracy": float(overall_accuracy),
            "mean_per_class_accuracy": float(mean_per_class_accuracy),
            
            "total_samples": len(all_labels), # This is already a Python int
            
            # Convert the numpy.int64 to a Python int
            "correct_predictions": int((all_preds == all_labels).sum()),
            
            "per_class_accuracy": per_class_accuracy
        }
        
        return results


def compare_strategies():
    """Compare different fine-tuning strategies."""
    DATASET_DIR = Path("/kaggle/working/mini_imagenet_fewshot_renamed")
    
    strategies = [
        # Few-shot scenarios with different loss functions
        # {"n_shot": 1, "loss_type": "prototypical", "use_augmentation": False},
        # {"n_shot": 1, "loss_type": "contrastive", "use_augmentation": False},
        # {"n_shot": 1, "loss_type": "cross_entropy", "use_augmentation": False},
        
        # {"n_shot": 5, "loss_type": "contrastive", "use_augmentation": True},
        # {"n_shot": 5, "loss_type": "cross_entropy", "use_augmentation": True},
        
        {"n_shot": 10, "loss_type": "contrastive", "use_augmentation": False},
        # {"n_shot": 20, "loss_type": "cross_entropy", "use_augmentation": True},
    ]
    
    all_results = {}
    
    for strategy in strategies:
        n_shot = strategy["n_shot"]
        loss_type = strategy["loss_type"]
        use_aug = strategy["use_augmentation"]
        
        print(f"\n{'='*70}")
        print(f"Running {n_shot}-shot with {loss_type} loss (Augmentation: {use_aug})")
        print(f"{'='*70}")
        
        finetuner = LoRACoCaFinetune()
        
        results = finetuner.run_lora_experiment(
            dataset_dir=DATASET_DIR,
            n_shot=n_shot,
            learning_rate=5e-5,
            weight_decay=0.01,
            num_epochs=100,
            label_smoothing=0.1,
            loss_type=loss_type,
            use_augmentation=use_aug,
            temperature=0.1,
            save_results=True
        )
        
        key = f"{n_shot}shot_{loss_type}_aug{use_aug}"
        all_results[key] = results
    
    # Print summary
    print("\n" + "="*80)
    print("STRATEGY COMPARISON SUMMARY")
    print("="*80)
    print(f"{'Config':<35} {'Overall Acc':<12} {'Mean Per-Class':<15}")
    print("-"*80)
    for key, res in all_results.items():
        print(f"{key:<35} {res['overall_accuracy']:<12.2f} {res['mean_per_class_accuracy']:<15.2f}")
    print("="*80)
    
    return all_results

results = compare_strategies()


# def run_single_experiment():
#     """Run a single experiment with specific configuration."""
#     DATASET_DIR = Path("/kaggle/working/mini_imagenet_fewshot_renamed")
    
#     # Configuration
#     config = {
#         "n_shot": 5,
#         "learning_rate": 1e-4,
#         "weight_decay": 0.01,
#         "num_epochs": 100,
#         "batch_size": 32,
#         "label_smoothing": 0.1,
#         "loss_type": "cross_entropy",  # Options: "cross_entropy", "prototypical", "contrastive", "auto"
#         "use_augmentation": True,
#         "temperature": 0.1,  # For contrastive loss
#         "save_results": True
#     }
    
#     print("="*70)
#     print("Running LoRA Fine-tuning Experiment")
#     print("="*70)
#     print(f"Configuration:")
#     for key, value in config.items():
#         print(f"  {key}: {value}")
#     print("="*70)
    
#     # Initialize model
#     finetuner = LoRACoCaFinetune(
#         model_name="coca_ViT-L-14",
#         pretrained="mscoco_finetuned_laion2B-s13B-b90k",
#     )
    
#     # Run experiment
#     results = finetuner.run_lora_experiment(
#         dataset_dir=DATASET_DIR,
#         **config
#     )
    
#     return results


# def main():
#     """Main function with options."""
#     import argparse
    
#     parser = argparse.ArgumentParser(description='LoRA Fine-tuning for CoCa')
#     parser.add_argument('--mode', type=str, default='single', 
#                         choices=['single', 'compare'],
#                         help='Run single experiment or compare strategies')
#     parser.add_argument('--dataset_dir', type=str, 
#                         default='/kaggle/working/mini_imagenet_fewshot_renamed',
#                         help='Path to dataset directory')
#     parser.add_argument('--n_shot', type=int, default=5,
#                         help='Number of shots per class')
#     parser.add_argument('--loss_type', type=str, default='cross_entropy',
#                         choices=['cross_entropy', 'prototypical', 'contrastive', 'auto'],
#                         help='Loss function to use')
#     parser.add_argument('--lr', type=float, default=1e-4,
#                         help='Learning rate')
#     parser.add_argument('--epochs', type=int, default=100,
#                         help='Number of training epochs')
#     parser.add_argument('--batch_size', type=int, default=32,
#                         help='Batch size')
#     parser.add_argument('--no_augmentation', action='store_true',
#                         help='Disable data augmentation')
#     parser.add_argument('--temperature', type=float, default=0.1,
#                         help='Temperature for contrastive loss')
    
#     args = parser.parse_args()
    
#     if args.mode == 'compare':
#         # Run comparison of multiple strategies
#         results = compare_strategies()
#     else:
#         # Run single experiment
#         DATASET_DIR = Path(args.dataset_dir)
        
#         config = {
#             "n_shot": args.n_shot,
#             "learning_rate": args.lr,
#             "weight_decay": 0.01,
#             "num_epochs": args.epochs,
#             "batch_size": args.batch_size,
#             "label_smoothing": 0.1,
#             "loss_type": args.loss_type,
#             "use_augmentation": not args.no_augmentation,
#             "temperature": args.temperature,
#             "save_results": True
#         }
        
#         print("="*70)
#         print("Running LoRA Fine-tuning Experiment")
#         print("="*70)
#         print(f"Configuration:")
#         for key, value in config.items():
#             print(f"  {key}: {value}")
#         print("="*70)
        
#         finetuner = LoRACoCaFinetune()
#         results = finetuner.run_lora_experiment(dataset_dir=DATASET_DIR, **config)
    
#     print("\n" + "="*70)
#     print("Experiment Complete!")
#     print("="*70)


# if __name__ == "__main__":
#     # For Kaggle/Jupyter notebook, use these simple functions:
    
#     # Option 1: Run a single experiment with default settings
#     # results = run_single_experiment()
    
#     # Option 2: Run comparison of multiple strategies
#     results = compare_strategies()
    
#     # Option 3: For command line usage
#     # main()


Running 10-shot with contrastive loss (Augmentation: False)
Loading CoCa model on cuda...
Using contrastive loss for 10-shot learning
LoRA Config: attn.out_proj all layers, r=4, droupout=0.15
trainable params: 202,752 || all params: 306,927,616 || trainable%: 0.0661

Preparing 10-shot training data...
Total training samples: 1000

Training with LoRA and contrastive loss...
Config: LR=5e-05, WD=0.01, Label Smoothing=0.1
Augmentation: False, Temperature: 0.1


  with torch.no_grad(), torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast(): # Consider updating to torch.amp.autocast('cuda')
  with torch.cuda.amp.autocast():


Epoch [1/100] Loss: 4.6193 | Val Acc: 61.30% | LR: 0.000050
Epoch [10/100] Loss: 1.3777 | Val Acc: 99.40% | LR: 0.000049
Epoch [20/100] Loss: 1.2056 | Val Acc: 99.90% | LR: 0.000046
Epoch [30/100] Loss: 1.1277 | Val Acc: 100.00% | LR: 0.000040
Epoch [40/100] Loss: 1.1269 | Val Acc: 100.00% | LR: 0.000033

Early stopping at epoch 44

Training complete! Best validation accuracy: 100.00%

Evaluating on test set...


  with torch.cuda.amp.autocast():
Testing: 100%|██████████| 63/63 [01:00<00:00,  1.04it/s]


LoRA Fine-tuning Results (10-shot, contrastive loss)
Overall Accuracy: 91.95%
Mean Per-Class Accuracy: 91.95%
Total Samples: 2000
Correct Predictions: 1839
Results saved to lora_coca_10shot_contrastive_augFalse.json

STRATEGY COMPARISON SUMMARY
Config                              Overall Acc  Mean Per-Class 
--------------------------------------------------------------------------------
10shot_contrastive_augFalse         91.95        91.95          



