In [1]:
# libraries

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from glob import glob

import torch
import torch.nn as nn
import torch.optim as optim
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer

from sklearn.metrics import confusion_matrix, accuracy_score, classification_report

import seaborn as sns

from tqdm import tqdm
import random
from sklearn.linear_model import LogisticRegression

import torch
import torch.nn as nn
import clip


## 1 - Dataset and Selection of a Small Subset for Few-Shot Learning

### Selecting a Small Subset
For fine-tuning, we will use a small, balanced subset of the EuroSAT dataset. Instead of using the full dataset, we will randomly sample 10 images per class to simulate a **few-shot learning** scenario. The goal is to test whether training on a small number of images can improve classification performance.



In [2]:
DATASET_PATH = "2750/"

classes = sorted([cls for cls in os.listdir(DATASET_PATH) if not cls.startswith('.')])

print(f"Classes found: {classes}")

Classes found: ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']


In [3]:
CLASSNAME_DICT = {'AnnualCrop': 'Annual Crop', 
                    'Forest': 'Forest', 
                    'HerbaceousVegetation': 'Herbaceous Vegetation', 
                    'Highway': 'Highway', 
                    'Industrial': 'Industrial', 
                    'Pasture': 'Pasture', 
                    'PermanentCrop': 'Permanent Crop', 
                    'Residential': 'Residential', 'River': 'River', 
                    'SeaLake': 'Sea or Lake'}
modified_classes = [CLASSNAME_DICT[c] for c in classes]
modified_classes


['Annual Crop',
 'Forest',
 'Herbaceous Vegetation',
 'Highway',
 'Industrial',
 'Pasture',
 'Permanent Crop',
 'Residential',
 'River',
 'Sea or Lake']

In [4]:
DATASET_PATH = "2750/"

classes = sorted([cls for cls in os.listdir(DATASET_PATH) if not cls.startswith('.')])

print(f"Classes found: {classes}")

Classes found: ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']


In [5]:
def sample_few_shot_data(dataset_path, classes, num_samples=10):
    """
    Randomly sample a small, balanced dataset for few-shot learning.
    
    Parameters:
        dataset_path (str): Path to the dataset directory.
        classes (list): List of class names.
        num_samples (int): Number of images per class to sample.
    
    Returns:
        sampled_data (list): List of (image_path, class_name) tuples.
    """
    sampled_data = []
    
    for cls in classes:
        image_paths = glob(os.path.join(dataset_path, cls, "*.jpg"))
        sampled_images = random.sample(image_paths, num_samples)
        sampled_data.extend([(img, cls) for img in sampled_images])
    
    return sampled_data

In [6]:
few_shot_data = sample_few_shot_data(DATASET_PATH, classes, num_samples=10)

few_shot_df = pd.DataFrame(few_shot_data, columns=["Image Path", "Class"])
few_shot_df

Unnamed: 0,Image Path,Class
0,2750/AnnualCrop\AnnualCrop_292.jpg,AnnualCrop
1,2750/AnnualCrop\AnnualCrop_120.jpg,AnnualCrop
2,2750/AnnualCrop\AnnualCrop_2503.jpg,AnnualCrop
3,2750/AnnualCrop\AnnualCrop_2383.jpg,AnnualCrop
4,2750/AnnualCrop\AnnualCrop_487.jpg,AnnualCrop
...,...,...
95,2750/SeaLake\SeaLake_1594.jpg,SeaLake
96,2750/SeaLake\SeaLake_2535.jpg,SeaLake
97,2750/SeaLake\SeaLake_5.jpg,SeaLake
98,2750/SeaLake\SeaLake_2793.jpg,SeaLake


In [7]:
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model.to(device)

for param in clip_model.parameters():
    param.requires_grad = False

print("CLIP model loaded and backbone frozen.")

CLIP model loaded and backbone frozen.


In [8]:
clip_model

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPSdpaAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e

In [26]:
class SimplePromptLearner(nn.Module):
    def __init__(self, clip_model, classnames, n_ctx=8, ctx_init=None):
        super().__init__()

        self.n_cls = len(classnames)
        self.n_ctx = n_ctx
        self.ctx_dim = clip_model.text_model.final_layer_norm.weight.shape[0]
        dtype = clip_model.dtype
        device = clip_model.device  # Ensure we use the same device as the model

        # Initialize context vectors
        if ctx_init:
            # Initialize with provided context string
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).to(device)  # Ensure tokenization is on the same device
            with torch.no_grad():
                embedding = clip_model.text_model.embeddings.token_embedding(prompt).type(dtype).to(device)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            print("Initializing class-specific contexts")
            ctx_vectors = torch.empty(self.n_cls, n_ctx, self.ctx_dim, dtype=dtype, device=device)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)  # To be optimized

        # Prepare the prompts
        classnames = [name.replace("_", " ") for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        print(prompts)
        # Tokenize and get embeddings for class names
        tokenized_prompts = tokenizer(prompts, padding=True, return_tensors="pt").to(device)  # Move to the same device
        
        with torch.no_grad():
            input_ids = tokenized_prompts["input_ids"].to(device)  # Ensure this is on the same device
            embedding = clip_model.text_model.embeddings.token_embedding(input_ids).type(dtype).to(device)

        # Store prefix and suffix
        self.register_buffer("token_prefix", embedding[:, :1, :])  # Start of sequence token
        self.register_buffer("token_suffix", embedding[:, 1:, :])  # Class names and rest

    def forward(self):
        # The context vectors
        ctx = self.ctx

        # Expand context to be per-class
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        # Concatenate the prefix, context, and suffix
        prompts = torch.cat([self.token_prefix, ctx, self.token_suffix], dim=1)
        return prompts


In [27]:
def get_text_features(prompt_learner, clip_model):
    prompts = prompt_learner()  # Forward pass through prompt learner
    text_features = clip_model.text_model(prompts).last_hidden_state  # Extract features
    return text_features


In [28]:
def get_class_indices(class_list, class_names):
    """ Convert class names to numerical labels. """
    return torch.tensor([class_names.index(cls) for cls in class_list], dtype=torch.long)

In [29]:
labels = get_class_indices(few_shot_df["Class"].tolist(), classes)
labels = labels.to(device)

In [30]:
images = [Image.open(img_path).convert("RGB") for img_path in few_shot_df["Image Path"].tolist()]
inputs = clip_processor(images=images, return_tensors="pt", padding=True).to(device)

prompt_learner = SimplePromptLearner(clip_model, modified_classes, n_ctx=8, ctx_init=None)

optimizer = optim.Adam(prompt_learner.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

losses = []

num_epochs=100
for epoch in range(num_epochs):
    optimizer.zero_grad()
    
    # Get updated text features using learned prompts
    text_features = get_text_features(prompt_learner, clip_model)
    
    # Get image features
    image_features = clip_model.encode_image(images)

    # Normalize features
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    # Compute cosine similarity
    logits = image_features @ text_features.T
    loss = criterion(logits, labels)  # Some classification loss like CrossEntropyLoss

    loss.backward()
    optimizer.step()

    losses.append(loss.item())

    if epoch % 100 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {loss.item():.4f}")



Initializing class-specific contexts
Initial context: "X X X X X X X X"
Number of context words (tokens): 8
['X X X X X X X X Annual Crop.', 'X X X X X X X X Forest.', 'X X X X X X X X Herbaceous Vegetation.', 'X X X X X X X X Highway.', 'X X X X X X X X Industrial.', 'X X X X X X X X Pasture.', 'X X X X X X X X Permanent Crop.', 'X X X X X X X X Residential.', 'X X X X X X X X River.', 'X X X X X X X X Sea or Lake.']


ValueError: Sequence length must be less than max_position_embeddings (got `sequence length`: 512 and max_position_embeddings: 77