In [None]:
# !pip install git+https://github.com/openai/CLIP.git
# !pip install google-generativeai
# !pip install nltk

In [None]:
# !wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
# !unzip tiny-imagenet-200.zip

In [39]:
import os
import torch
import clip
import time
import random
import google.generativeai as genai
import nltk
from nltk.corpus import wordnet as wn
from PIL import Image, ImageDraw
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from tqdm import tqdm
# nltk.download('wordnet')


In [63]:
def _extract_concepts(text, cname):
    """Parse one model reply into a set of cleaned, lower-cased concept strings.

    Args:
        text: Raw reply text, expected to hold one concept per line.
        cname: Class name the reply was generated for. Concepts that contain
            it are dropped so that a concept never names its own class.

    Returns:
        Set of cleaned concepts that are longer than two characters.
    """
    header_words = ("list", "feature", "category")
    class_name = cname.lower()
    concepts = set()

    # splitlines() also copes with "\r\n" line endings.
    for line in text.splitlines():
        lowered = line.lower()
        # Skip empty lines and header/instructional text. This is a substring
        # match, so a genuine concept containing one of these words is dropped too.
        if not line.strip() or any(word in lowered for word in header_words):
            continue

        # Trim bullets, numbering, quotes and surrounding punctuation.
        cleaned = line.strip(" .•-*0123456789:()[]{}\"\',").lower()
        if len(cleaned) > 2 and class_name not in cleaned:
            concepts.add(cleaned)

    return concepts


def generate_concept_list(class_names):
    """Ask Gemini for visual concepts describing each class and return them de-duplicated.

    Three prompts are sent per class (distinguishing features, things commonly
    seen around it, and superclasses). Each reply is parsed line by line and
    the cleaned concepts are pooled across all classes.

    Rate-limit errors (messages containing "429") are retried up to
    ``max_retries`` times, waiting for the delay suggested in the error message
    when one is present and backing off exponentially otherwise. Any other
    error is printed and that prompt is skipped, so one bad request does not
    abort the whole run.

    Args:
        class_names: Iterable of human-readable class names.

    Returns:
        Alphabetically sorted list of unique concepts that are 2-5 words long.
    """
    import re  # function-scope import: only needed to parse rate-limit errors

    retry_delay_pattern = re.compile(r'retry_delay \{\s*seconds: (\d+)')
    generation_config = {
        'temperature': 0.2,  # Lower temperature for more consistent formatting
        'top_p': 0.95,
        'top_k': 40,
        'max_output_tokens': 300,
    }
    max_retries = 5

    concept_set = set()
    model = genai.GenerativeModel('gemini-2.0-flash')

    for cname in tqdm(class_names, desc="Generating concepts"):
        prompts = [
            f"List exactly 5-10 of the most important features for recognizing something as a {cname}. Format as a simple list with one feature per line, no bullets, numbering, or explanations.",
            f"List exactly 5-10 things most commonly seen around a {cname}. Format as a simple list with one item per line, no bullets, numbering, or explanations.",
            f"List exactly 3-5 superclasses or categories for the word {cname}. Format as a simple list with one category per line, no bullets, numbering, or explanations."
        ]

        for prompt in prompts:
            retry_count = 0
            retry_delay = 2  # seconds; doubled after every un-hinted rate-limit error

            while retry_count <= max_retries:
                try:
                    response = model.generate_content(
                        prompt,
                        generation_config=generation_config,
                    )
                    concept_set.update(_extract_concepts(response.text, cname))
                    break

                except Exception as e:
                    error_msg = str(e)
                    print(f"Gemini API error: {error_msg}")

                    # Only rate-limit errors (429) are worth retrying.
                    if "429" not in error_msg:
                        print(f"Error (not retrying): {error_msg}")
                        break

                    delay_match = retry_delay_pattern.search(error_msg)
                    if delay_match:
                        # Use the delay suggested by the API, plus a small random jitter.
                        retry_seconds = int(delay_match.group(1)) + random.uniform(0, 2)
                    else:
                        # Exponential backoff with jitter.
                        retry_seconds = retry_delay + random.uniform(0, retry_delay * 0.1)
                        retry_delay *= 2

                    print(f"Rate limited. Retrying in {retry_seconds:.1f} seconds...")
                    time.sleep(retry_seconds)
                    retry_count += 1
            else:
                # Reached only when every attempt was rate limited (no break above).
                print(f"Giving up on a prompt for '{cname}' after {max_retries + 1} rate-limited attempts.")

            # Add a small delay between requests to avoid hitting rate limits.
            time.sleep(1)

    # Keep only concepts that are between 2 and 5 words long.
    return sorted(c for c in concept_set if 2 <= len(c.split()) <= 5)

# Decode WNIDs to human-readable class names
def decode_wnid(wnid):
    """Translate a WordNet ID such as 'n02124075' into a readable class name.

    The first character of the ID is passed to WordNet as the part-of-speech
    tag and the remaining characters as the integer offset. The first
    dot-separated field of the synset name is returned with underscores
    replaced by spaces.
    """
    pos_tag, offset = wnid[0], int(wnid[1:])
    lemma, _, _ = wn.synset_from_pos_and_offset(pos_tag, offset).name().partition('.')
    return lemma.replace('_', ' ')

def draw_red_circle(image, center, radius):
    """Return a copy of `image` with a red circle outline drawn around `center`.

    The input image is left untouched. `center` is an (x, y) pair and `radius`
    is the circle radius, both in the image's pixel coordinates. The outline
    is 2 pixels wide.
    """
    annotated = image.copy()
    cx, cy = center
    bounding_box = (cx - radius, cy - radius, cx + radius, cy + radius)
    ImageDraw.Draw(annotated).ellipse(bounding_box, outline="red", width=2)
    return annotated

def compute_spatial_similarity_matrix(images, concept_list, model, preprocess, device,
                                      grid_size=(7, 7), radius=32):
    """Score every concept at every grid location of every image with CLIP.

    For each image, a red circle is drawn at each point of a regular
    ``grid_size`` lattice. The circled image is embedded with
    ``model.encode_image`` and compared (cosine similarity) against the
    embeddings of all concepts.

    All grid locations of one image are encoded in a single batched forward
    pass rather than one pass per location.

    Args:
        images: Sequence of PIL-style images (must expose ``.size``).
        concept_list: Sequence of concept strings.
        model: CLIP-style model providing ``encode_text`` / ``encode_image``.
        preprocess: Callable turning one image into a model input tensor.
        device: Device on which the encoders are run.
        grid_size: ``(rows, cols)`` of the prompt lattice.
        radius: Circle radius in pixels of the *input* image (before
            ``preprocess`` is applied).

    Returns:
        Tensor ``P`` of shape ``[N, M, rows, cols]`` where ``P[n, m, h, w]`` is
        the similarity between concept ``m`` and image ``n`` prompted at grid
        cell ``(h, w)``. The tensor is all zeros-shaped-empty when there are no
        images or no concepts.
    """
    model.eval()
    grid_h, grid_w = grid_size
    num_concepts = len(concept_list)
    P = torch.zeros((len(images), num_concepts, grid_h, grid_w))

    # Nothing to score: skip tokenisation and encoding on empty input.
    if P.numel() == 0:
        return P

    with torch.no_grad():
        # Precompute L2-normalised concept embeddings once for all images.
        text_tokens = clip.tokenize(concept_list).to(device)
        concept_embeddings = model.encode_text(text_tokens)
        concept_embeddings = concept_embeddings / concept_embeddings.norm(dim=1, keepdim=True)

        for n, image in enumerate(tqdm(images, desc="Scoring images")):
            width, height = image.size
            # Spacing that places grid points strictly inside the image.
            dH = height // (grid_h + 1)
            dW = width // (grid_w + 1)

            # Row-major order (h outer, w inner) so the flat batch index is h * grid_w + w.
            prompted_batch = torch.stack([
                preprocess(draw_red_circle(image, ((w + 1) * dW, (h + 1) * dH), radius))
                for h in range(grid_h)
                for w in range(grid_w)
            ]).to(device)

            image_embeddings = model.encode_image(prompted_batch)
            image_embeddings = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)

            sims = image_embeddings @ concept_embeddings.T  # (grid_h * grid_w, M)
            # (grid_h * grid_w, M) -> (grid_h, grid_w, M) -> (M, grid_h, grid_w)
            sims = sims.reshape(grid_h, grid_w, num_concepts).permute(2, 0, 1)
            P[n] = sims.to(device=P.device, dtype=P.dtype)

    return P  # Shape: [N, M, grid_h, grid_w]


In [None]:
# Tensor pipeline used by both loaders: resize to 224x224, then convert to a tensor.
resize_then_tensor = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(resize_then_tensor)

# Training split: labels come from the sub-directory names.
train_dataset = ImageFolder('tiny-imagenet-200/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64)

# NOTE(review): ImageFolder derives labels from sub-directory names. Confirm that
# tiny-imagenet-200/val is laid out as one folder per class; otherwise the val
# labels will not line up with train_dataset.class_to_idx.
val_dataset = ImageFolder('tiny-imagenet-200/val', transform=transform)
val_loader = DataLoader(val_dataset, batch_size=64)

# Invert class_to_idx (WordNet ID -> index) into index -> WordNet ID.
idx_to_wnid = {class_idx: wnid for wnid, class_idx in train_dataset.class_to_idx.items()}


In [53]:
# Inspect the class-index -> WordNet-ID mapping derived from train_dataset.class_to_idx.
idx_to_wnid

{0: 'n01443537',
 1: 'n01629819',
 2: 'n01641577',
 3: 'n01644900',
 4: 'n01698640',
 5: 'n01742172',
 6: 'n01768244',
 7: 'n01770393',
 8: 'n01774384',
 9: 'n01774750',
 10: 'n01784675',
 11: 'n01855672',
 12: 'n01882714',
 13: 'n01910747',
 14: 'n01917289',
 15: 'n01944390',
 16: 'n01945685',
 17: 'n01950731',
 18: 'n01983481',
 19: 'n01984695',
 20: 'n02002724',
 21: 'n02056570',
 22: 'n02058221',
 23: 'n02074367',
 24: 'n02085620',
 25: 'n02094433',
 26: 'n02099601',
 27: 'n02099712',
 28: 'n02106662',
 29: 'n02113799',
 30: 'n02123045',
 31: 'n02123394',
 32: 'n02124075',
 33: 'n02125311',
 34: 'n02129165',
 35: 'n02132136',
 36: 'n02165456',
 37: 'n02190166',
 38: 'n02206856',
 39: 'n02226429',
 40: 'n02231487',
 41: 'n02233338',
 42: 'n02236044',
 43: 'n02268443',
 44: 'n02279972',
 45: 'n02281406',
 46: 'n02321529',
 47: 'n02364673',
 48: 'n02395406',
 49: 'n02403003',
 50: 'n02410509',
 51: 'n02415577',
 52: 'n02423022',
 53: 'n02437312',
 54: 'n02480495',
 55: 'n02481823',
 5

In [49]:
# Run on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Prompt lattice and circle size used for the spatial similarity computation.
grid_size = (7, 7)
circle_radius = 32

# Never hard-code API keys in a notebook: they end up in version control and in
# shared copies. Read the key from the environment instead. Any key that was
# ever committed in this file should be treated as leaked and rotated.
my_key = os.environ.get("GOOGLE_API_KEY")
if not my_key:
    raise RuntimeError(
        "Set the GOOGLE_API_KEY environment variable before running this cell."
    )
genai.configure(api_key=my_key)

# Load the CLIP image/text encoders and the matching image preprocessing pipeline.
model, preprocess = clip.load("ViT-B/16", device=device)


100%|███████████████████████████████████████| 335M/335M [01:10<00:00, 4.95MiB/s]


In [22]:
TINY_IMAGENET_ROOT = "tiny-imagenet-200"

# Step 1: Read wnids.txt (one WordNet ID per line).
wnids_path = os.path.join(TINY_IMAGENET_ROOT, 'wnids.txt')
with open(wnids_path, 'r') as wnid_file:
    wnids = [raw_line.strip() for raw_line in wnid_file]

print("Total classes:", len(wnids))

# Step 2: Decode every WordNet ID into a readable class name, keeping file order.
class_names = list(map(decode_wnid, wnids))
print("Exact class names:", class_names[:10])


Total classes: 200
Exact class names: ['egyptian cat', 'reel', 'volleyball', 'rocking chair', 'lemon', 'bullfrog', 'basketball', 'cliff', 'espresso', 'plunger']


In [55]:
N = 10  # images to collect per class
selected_wnids = wnids[:10]  # classes to sample from: the first ten listed in wnids.txt
selected_images = []
counts = dict.fromkeys(selected_wnids, 0)

for path, label in train_dataset.samples:
    wnid = idx_to_wnid[label]
    # `counts` is keyed by exactly the selected WNIDs, so this membership test is
    # O(1) and avoids re-slicing `wnids` for every sample.
    if wnid in counts and counts[wnid] < N:
        # Close the file handle as soon as the RGB copy has been made.
        with Image.open(path) as raw_img:
            selected_images.append(raw_img.convert("RGB"))
        counts[wnid] += 1
        # Quotas can only become complete right after an increment.
        if all(c >= N for c in counts.values()):
            break

# selected_images is now a list of PIL images for the selected classes
print(f"Collected {len(selected_images)} images from {len(selected_wnids)} classes.")

Collected 100 images from 10 classes.


In [None]:
# Generate concepts for the first ten entries of class_names. class_names is built
# from wnids in the same order, so these are the classes whose images were
# collected via wnids[:10].
concepts = generate_concept_list(class_names[:10])  
print("Generated Concepts:", concepts)

In [None]:
# Use the grid/radius constants from the configuration cell instead of repeating
# their literal values here, so there is a single place to tune them.
# NOTE(review): selected_images are passed at their native resolution, so
# `circle_radius` is measured in original-image pixels - confirm that 32 px is
# appropriate for the size of the Tiny ImageNet images.
P = compute_spatial_similarity_matrix(
    images=selected_images,
    concept_list=concepts,
    model=model,
    preprocess=preprocess,
    device=device,
    grid_size=grid_size,
    radius=circle_radius,
)

In [66]:
# Layout is (num_images, num_concepts, grid_rows, grid_cols).
P.shape

torch.Size([100, 87, 7, 7])