In [25]:
# libraries

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from glob import glob

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from transformers import CLIPProcessor, CLIPModel, AutoTokenizer

from sklearn.metrics import confusion_matrix, accuracy_score, classification_report

import csv
import seaborn as sns

from tqdm import tqdm
import random

import torch
import torch.nn as nn
import clip

import config


## 1 - Dataset and Selection of a Small Subset for Few-Shot Learning

### Selecting a Small Subset
For fine-tuning, we will use a small, balanced subset of the EuroSAT dataset. Instead of using the full dataset, we will randomly sample 10 images per class to simulate a **few-shot learning** scenario. The goal is to test whether training on a small number of images can improve classification performance.



In [66]:
classes = config.CLASSES
print("Classes found: ", classes)

Classes found:  ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']


In [4]:
def get_class_indices(class_list, class_names):
    """ Convert class names to numerical labels. """
    return torch.tensor([class_names.index(cls) for cls in class_list], dtype=torch.long)

In [18]:
device = (
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("cpu")
)
print("Device: ", device)

Device:  mps


In [44]:
class FewShotDataset(Dataset):
    def __init__(self, dataframe, image_dir):
        self.dataframe = dataframe
        self.image_dir = image_dir

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_path = self.dataframe.iloc[idx]["Image Path"]
        label = self.dataframe.iloc[idx]["Class"]
        label_tensor = torch.tensor([label], dtype=torch.long).to(device)
        
        return img_path, label_tensor


In [19]:
def sample_few_shot_data(dataset_path, classes, num_samples):
    sampled_data = []
    for cls in classes:
        image_paths = glob(os.path.join(dataset_path, cls, "*.jpg"))
        sampled_images = random.sample(image_paths, num_samples)
        sampled_data.extend([(img, cls) for img in sampled_images])
    return sampled_data

In [46]:
# Initialize the dataset and dataloader
test_df = pd.read_csv("test_data/test_set.csv")

test_dataset = FewShotDataset(dataframe=test_df, image_dir="2750")
test_dataloader = DataLoader(test_dataset, batch_size=60)  # Adjust batch_size as needed


In [6]:
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

clip_model.to(device)

for param in clip_model.parameters():
    param.requires_grad = False

print("CLIP model loaded and backbone frozen.")

CLIP model loaded and backbone frozen.


In [7]:
class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.text_model
        self.encoder = clip_model.text_model.encoder
        self.positional_embedding = clip_model.text_model.embeddings.position_embedding
        self.ln_final = clip_model.text_model.final_layer_norm
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts):
        seq_length = prompts.size(1)  # Sequence length (number of tokens)
        positions = torch.arange(seq_length, device=prompts.device).unsqueeze(0)  # Shape [1, seq_length]
        
        # Get the positional embeddings for each position
        pos_embeddings = self.positional_embedding(positions)  # Shape [1, seq_length, embedding_dim]
        
        x = prompts + pos_embeddings
        x = x.reshape(x.size(0), x.size(1), -1)  # NLD -> LND
        encoder_outputs = self.encoder(x)

        last_hidden_state = encoder_outputs[0]
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.ln_final(pooled_output)

        text_features = self.text_projection(pooled_output)

        return text_features


In [37]:
class SimplePromptLearner(nn.Module):
    def __init__(self, clip_model, classnames, n_ctx=8, ctx_init=None):
        super().__init__()

        self.n_cls = len(classnames)
        self.n_ctx = n_ctx
        self.ctx_dim = clip_model.text_model.final_layer_norm.weight.shape[0]
        dtype = clip_model.dtype

        # Initialize context vectors
        if ctx_init:
            # Initialize with provided context string
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # print("Initializing class-specific contexts")
            ctx_vectors = torch.empty(self.n_cls, n_ctx, self.ctx_dim, dtype=dtype).to(device)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        # print(f'Initial context: "{prompt_prefix}"')
        # print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors, requires_grad=True).to(device)  # to be optimized

        # Prepare the prompts
        classnames = [name.replace("_", " ") for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        # Tokenize and get embeddings for class names
        tokenized_prompts = tokenizer(prompts, padding=True, return_tensors="pt", truncation=True, max_length=77).to(device)
        
        with torch.no_grad():
            input_ids = tokenized_prompts["input_ids"]
            embedding = clip_model.text_model.embeddings.token_embedding(input_ids).type(dtype)

        # Store prefix and suffix
        self.register_buffer("token_prefix", embedding[:, :1, :])  # Start of sequence token
        self.register_buffer("token_suffix", embedding[:, 1+n_ctx:, :])  # Class names and rest

        self.tokenized_prompts = tokenized_prompts

    def forward(self):
        # The context vectors
        ctx = self.ctx

        # Concatenate the prefix, context, and suffix
        prompts = torch.cat([self.token_prefix, ctx, self.token_suffix], dim=1)

        return prompts



In [9]:
import gc
torch.cuda.empty_cache()
gc.collect()


0

In [10]:
validation_df = pd.read_csv("validation_data/validation_set.csv")

validation_labels = get_class_indices(validation_df["Class"].tolist(), classes).to(device)

validation_images = [Image.open(img_path).convert("RGB") for img_path in validation_df["Image Path"].tolist()]
validation_inputs = clip_processor(images=validation_images, return_tensors="pt", padding=True).to(device)

In [39]:
num_epochs=500

save_dir = "saved_prompt_learners"

losses = {}
validation_losses = {}
validation_accuracies = {}

for few_shot in [1,2,4,8,16]:
    few_shot_df = pd.read_csv(f"few_shot_data/few_shot_{few_shot}.csv")
    few_shot_images = [Image.open(img_path).convert("RGB") for img_path in few_shot_df["Image Path"].tolist()]
    inputs = clip_processor(images=few_shot_images, return_tensors="pt", padding=True).to(device)

    labels = get_class_indices(few_shot_df["Class"].tolist(), classes)
    labels = labels.to(device)

    prompt_learner = SimplePromptLearner(clip_model, modified_classes, n_ctx=16, ctx_init=None).to(device)
    text_encoder = TextEncoder(clip_model=clip_model).to(device)

    optimizer = optim.Adam(prompt_learner.parameters(), lr=0.01)

    criterion = nn.CrossEntropyLoss()

    losses[few_shot] = []
    validation_losses[few_shot] = []
    validation_accuracies[few_shot] = []


    for epoch in range(num_epochs+1):
        optimizer.zero_grad()
        
        prompts = prompt_learner()
        
        # Get updated text features using learned prompts, and image features
        text_features = text_encoder(prompts)
        image_features = clip_model.get_image_features(**inputs)

        # Normalize features
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # Compute cosine similarity
        logit_scale = clip_model.logit_scale.exp()
        logits = image_features @ text_features.T
        logits *= logit_scale
        
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        losses[few_shot].append(loss.item())

        if epoch % 50 == 0:
            with torch.no_grad():
                prompts = prompt_learner()

                text_features = text_encoder(prompts)
                image_features = clip_model.get_image_features(**validation_inputs)

                # Normalize features
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                # Compute cosine similarity
                logit_scale = clip_model.logit_scale.exp()
                val_logits = image_features @ text_features.T
                val_logits *= logit_scale

                val_loss = criterion(val_logits, validation_labels)
                
                predictions = torch.argmax(val_logits, dim=1)
            
            accuracy = (predictions == validation_labels).float().mean().item()
            
            validation_accuracies[few_shot].append(accuracy)
            validation_losses[few_shot].append((epoch, val_loss.item()))
            
            if epoch % 250 == 0:    
                print(f"{few_shot}-shot Epoch [{epoch}/{num_epochs}]:") 
                print(f"\tTraining Loss: {loss.item():.4f}")
                print(f"\tValidation Loss: {val_loss.item():.4f}")
                print(f"\tValidation Accuracy: {accuracy * 100:.2f}%\n")

    save_path = os.path.join(save_dir, f"prompt_learner_{few_shot}_shot.pth")
    torch.save(prompt_learner.state_dict(), save_path)
    print(f"Prompt learner for {few_shot}-shot saved to {save_path}")

1-shot Epoch [0/500]:
	Training Loss: 4.7444
	Validation Loss: 4.7915
	Validation Accuracy: 11.20%

1-shot Epoch [250/500]:
	Training Loss: 0.0007
	Validation Loss: 2.4287
	Validation Accuracy: 44.00%

1-shot Epoch [500/500]:
	Training Loss: 0.0003
	Validation Loss: 2.4757
	Validation Accuracy: 45.20%

Prompt learner for 1-shot saved to saved_prompt_learners/prompt_learner_1_shot.pth
2-shot Epoch [0/500]:
	Training Loss: 3.0558
	Validation Loss: 4.8715
	Validation Accuracy: 20.20%

2-shot Epoch [250/500]:
	Training Loss: 0.0004
	Validation Loss: 1.2689
	Validation Accuracy: 67.20%

2-shot Epoch [500/500]:
	Training Loss: 0.0002
	Validation Loss: 1.2730
	Validation Accuracy: 67.40%

Prompt learner for 2-shot saved to saved_prompt_learners/prompt_learner_2_shot.pth
4-shot Epoch [0/500]:
	Training Loss: 3.9506
	Validation Loss: 3.7038
	Validation Accuracy: 14.40%

4-shot Epoch [250/500]:
	Training Loss: 0.0012
	Validation Loss: 0.9381
	Validation Accuracy: 74.80%

4-shot Epoch [500/500]:


In [41]:
import gc
torch.cuda.empty_cache()
torch.mps.empty_cache()
gc.collect()


18

In [47]:
test_accuracies = []
for few_shot in [1,2,4,8,16]:
    load_path = os.path.join(save_dir, f"prompt_learner_{few_shot}_shot.pth")
    loaded_prompt_learner = SimplePromptLearner(clip_model, MO, n_ctx=16, ctx_init=None).to(device)
    loaded_prompt_learner.load_state_dict(torch.load(load_path, map_location=device))

    prompts = loaded_prompt_learner()
    text_encoder = TextEncoder(clip_model=clip_model).to(device)

    batch_accuracies = []
    for image_paths, labels in tqdm(test_dataloader, leave=False, desc=f"Testing {few_shot}-shot classifier"):
        with torch.no_grad():
            test_images = [Image.open(img_path).convert("RGB") for img_path in image_paths]
            test_inputs = clip_processor(images=test_images, return_tensors="pt", padding=True).to(device)


            text_features = text_encoder(prompts)
            image_features = clip_model.get_image_features(**test_inputs)
            
            logit_scale = clip_model.logit_scale.exp()
            test_logits = image_features @ text_features.T
            test_logits *= logit_scale
            
            predictions = torch.argmax(test_logits, dim=1)
            accuracy = (predictions == labels).float().mean().item()
        batch_accuracies.append(accuracy)
    print(f"Test Accuracy on {few_shot}-shot CoOP CLIP: {np.mean(batch_accuracies) * 100:.2f}%")



                                                                          

Test Accuracy on 1-shot CoOP CLIP: 51.97%


                                                                          

Test Accuracy on 2-shot CoOP CLIP: 61.67%


                                                                          

Test Accuracy on 4-shot CoOP CLIP: 71.97%


                                                                          

Test Accuracy on 8-shot CoOP CLIP: 81.03%


                                                                           

Test Accuracy on 16-shot CoOP CLIP: 83.30%




In [53]:
clip_model.text_model.embeddings.token_embedding.weight.shape

torch.Size([49408, 512])

In [65]:
import torch
import clip
from torch.nn.functional import cosine_similarity

def decode_prompt(prompt_learner, clip_model, tokenizer):
    # Get the optimized context embeddings
    ctx_vectors = prompt_learner.ctx  # Shape: (num_classes, n_ctx, dim)
    prompt = prompt_learner()
    print(prompt.shape)
    
    # Get all possible token embeddings from CLIP's vocabulary
    # vocab_size = clip_model.token_embedding.weight.shape[0]  # Total vocab size
    token_embeddings = clip_model.text_model.embeddings.token_embedding.weight  # Shape: (vocab_size, dim)

    # Find the closest tokens for each context vector
    decoded_prompts = []
    for class_idx in range(prompt.shape[0]):  # Iterate over classes
        class_prompt = []
        for ctx_token in prompt[class_idx]:  # Iterate over context words
            similarities = cosine_similarity(ctx_token.unsqueeze(0), token_embeddings)  # Cosine similarity with vocab
            closest_token_id = similarities.argmax().item()  # Get most similar token
            closest_token = tokenizer.convert_ids_to_tokens([closest_token_id])[0]  # Convert to text
            class_prompt.append(closest_token)

        decoded_prompts.append(" ".join(class_prompt))  # Reconstruct approximate text prompt

    return decoded_prompts

# Example Usage
load_path = os.path.join(save_dir, f"prompt_learner_16_shot.pth")
loaded_prompt_learner = SimplePromptLearner(clip_model, modified_classes, n_ctx=16, ctx_init=None).to(device)
loaded_prompt_learner.load_state_dict(torch.load(load_path, map_location=device))
decoded_prompts = decode_prompt(loaded_prompt_learner, clip_model, tokenizer)
for class_prompt in decoded_prompts:
    print(class_prompt)


torch.Size([10, 22, 512])
<|startoftext|> within</w> syndrome</w> cos fooled</w> hot</w> ludo fathers</w> graci vy</w> louder</w> minute azar</w> poche rau havent</w> sbball</w> annual</w> crop</w> .</w> <|endoftext|> <|endoftext|>
<|startoftext|> thenight</w> deals</w> famili le increased</w> bet</w> auditions</w> openings</w> monetary</w> blanco</w> damon uka</w> demon</w> fallen</w> yelled</w> iom</w> forest</w> .</w> <|endoftext|> <|endoftext|> <|endoftext|>
<|startoftext|> fields</w> prince ðŁĩ®ðŁĩ³ asen welcoming</w> opar pist tyran âĻ¡</w> ðŁĹ£</w> sexually</w> piles</w> output</w> atism</w> invite</w> shiva herb aceous</w> vegetation</w> .</w> <|endoftext|>
<|startoftext|> gate</w> blazing</w> ¨ babe</w> ancing</w> complex</w> âĿ¤ï¸ı enough</w> liberals</w> dra gosling</w> mean</w> low</w> langley</w> cept</w> clearly</w> highway</w> .</w> <|endoftext|> <|endoftext|> <|endoftext|>
<|startoftext|> strata</w> ag</w> progress</w> passes</w> exploit</w> atlantis</w> characteristics

In [64]:
load_path = os.path.join(save_dir, f"prompt_learner_{few_shot}_shot.pth")
loaded_prompt_learner = SimplePromptLearner(clip_model, modified_classes, n_ctx=16, ctx_init=None).to(device)
loaded_prompt_learner.load_state_dict(torch.load(load_path, map_location=device))



with torch.no_grad():
    prompts = loaded_prompt_learner()
    text_encoder = TextEncoder(clip_model=clip_model).to(device)
    test_images = [Image.open("2750/HerbaceousVegetation/HerbaceousVegetation_2.jpg").convert("RGB")]
    test_inputs = clip_processor(images=test_images, return_tensors="pt", padding=True).to(device)


    text_features = text_encoder(prompts)
    image_features = clip_model.get_image_features(**test_inputs)
    
    logit_scale = clip_model.logit_scale.exp()
    test_logits = image_features @ text_features.T
    # test_logits *= logit_scale
    print(test_logits)
    print(torch.softmax(test_logits,dim=-1))
    
    predictions = torch.argmax(test_logits, dim=1)
    print(predictions)


tensor([[ 9.5461,  4.6747, 13.0687,  7.8658,  9.7408,  8.1585, 10.2043,  5.1856,
          9.5310,  8.7160]], device='mps:0')
tensor([[2.5066e-02, 1.9206e-04, 8.4902e-01, 4.6703e-03, 3.0453e-02, 6.2581e-03,
         4.8407e-02, 3.2013e-04, 2.4689e-02, 1.0929e-02]], device='mps:0')
tensor([2], device='mps:0')
