In [None]:
# !pip install git+https://github.com/openai/CLIP.git
# !pip install google-generativeai
# !pip install nltk

In [None]:
# !wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
# !unzip tiny-imagenet-200.zip

In [8]:
import os
import torch
import clip
import time
import random
import google.generativeai as genai
import nltk
from nltk.corpus import wordnet as wn
from PIL import Image, ImageDraw
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.utils.data import random_split
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset, Dataset
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pickle


# nltk.download('wordnet')


In [9]:
def generate_concept_list(class_names):
    
    concept_set = set()
    model = genai.GenerativeModel('gemini-2.0-flash')  
    count = 0
    
    for cname in tqdm(class_names, desc="Generating concepts"):
        prompts = [
            f"List exactly 5-10 of the most important features for recognizing something as a {cname}. Format as a simple list with one feature per line, no bullets, numbering, or explanations.",
            f"List exactly 5-10 things most commonly seen around a {cname}. Format as a simple list with one item per line, no bullets, numbering, or explanations.",
            f"List exactly 3-5 superclasses or categories for the word {cname}. Format as a simple list with one category per line, no bullets, numbering, or explanations."
        ]
        
        for prompt_id, prompt in enumerate(prompts):
            max_retries = 5
            retry_count = 0
            retry_delay = 2  
            
            print("prompt id: ", prompt_id)
            while retry_count <= max_retries:
                try:
                    response = model.generate_content(
                        prompt,
                        generation_config={
                            'temperature': 0.2,  # Lower temperature for more consistent formatting
                            'top_p': 0.95,
                            'top_k': 40,
                            'max_output_tokens': 300,
                        }
                    )
                    content = response.text
                    lines = content.split('\n')
                    
                    for line in lines:
                        # Skip empty lines and header/instructional text
                        if not line.strip() or "list" in line.lower() or "feature" in line.lower() or "category" in line.lower():
                            continue
                            
                        cleaned = line.strip(" .•-*0123456789:()[]{}\"\',").lower()
                        print("---")
                        print(cleaned)
                        print("-.-")

                        if cleaned and len(cleaned) > 2 and cname.lower() not in cleaned:
                            concept_set.add(cleaned)
                    
                    break
                    
                except Exception as e:
                    error_msg = str(e)
                    print(f"Gemini API error: {error_msg}")
                    
                    # Check if it's a rate limit error (429)
                    if "429" in error_msg:
                        # Extract retry delay from error if available
                        import re
                        delay_match = re.search(r'retry_delay \{\s*seconds: (\d+)', error_msg)
                        
                        if delay_match:
                            # Use the suggested delay from the API
                            retry_seconds = int(delay_match.group(1))
                            retry_seconds += random.uniform(0, 2)  # Add small random jitter
                        else:
                            # Exponential backoff with jitter
                            retry_seconds = retry_delay + random.uniform(0, retry_delay * 0.1)
                            retry_delay *= 2  # Double the delay for next retry
                            
                        print(f"Rate limited. Retrying in {retry_seconds:.1f} seconds...")
                        time.sleep(retry_seconds)
                        retry_count += 1
                    else:
                        # For non-rate-limit errors, just print and continue
                        print(f"Error (not retrying): {error_msg}")
                        break
            
            # Add a small delay between successful requests to avoid hitting rate limits
            time.sleep(1)
    
    # Return concepts that are between 2 and 5 words long
    return sorted([c for c in concept_set if 2 <= len(c.split()) <= 5])

# Decode WNIDs to human-readable class names
def decode_wnid(wnid):
    synset = wn.synset_from_pos_and_offset(wnid[0], int(wnid[1:]))
    return synset.name().split('.')[0].replace('_', ' ')

def draw_red_circle(image, center, radius):
    img = image.copy()
    draw = ImageDraw.Draw(img)
    x, y = center
    draw.ellipse((x - radius, y - radius, x + radius, y + radius), outline="red", width=2)
    return img

def compute_spatial_similarity_matrix(images, concept_list, model, preprocess, device,
                                      grid_size=(7, 7), radius=32):
    model.eval()
    H̃, W̃ = grid_size
    P = torch.zeros((len(images), len(concept_list), H̃, W̃))

    with torch.no_grad():
        # Precompute concept embeddings
        text_tokens = clip.tokenize(concept_list).to(device)
        concept_embeddings = model.encode_text(text_tokens)
        concept_embeddings = concept_embeddings / concept_embeddings.norm(dim=1, keepdim=True)

        for n, image in enumerate(images):
            print("--", n)
            width, height = image.size
            dH = height // (H̃ + 1)
            dW = width // (W̃ + 1)

            for h in range(H̃):
                for w in range(W̃):
                    cx = (w + 1) * dW
                    cy = (h + 1) * dH
                    prompted_img = draw_red_circle(image, (cx, cy), radius)
                    input_tensor = preprocess(prompted_img).unsqueeze(0).to(device)

                    image_embedding = model.encode_image(input_tensor)
                    image_embedding = image_embedding / image_embedding.norm(dim=1, keepdim=True)

                    sim = (image_embedding @ concept_embeddings.T).squeeze(0)  # (M,)
                    P[n, :, h, w] = sim

    return P  # Shape: [N, M, H̃, W̃]


In [10]:
transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),
])




train_dataset = datasets.ImageFolder('tiny-imagenet-200/train', transform=transform)
# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)


val_dataset = datasets.ImageFolder('tiny-imagenet-200/val', transform=transform)
# val_loader = DataLoader(val_dataset, batch_size=64)

idx_to_wnid = {v: k for k, v in train_dataset.class_to_idx.items()}


In [11]:
device = "cuda" if torch.cuda.is_available() else "cpu"
grid_size = (7, 7)
circle_radius = 32
my_key = "AIzaSyBiUwxyp8ASs_UgameBEwv5NgWUlTXLMWA"
genai.configure(api_key=my_key)

model, preprocess = clip.load("ViT-B/16", device=device)


In [12]:
TINY_IMAGENET_ROOT = "tiny-imagenet-200"

# Step 1: Read wnids.txt
with open(os.path.join(TINY_IMAGENET_ROOT, 'wnids.txt'), 'r') as f:
    wnids = [line.strip() for line in f.readlines()]

class_start = 22
class_end = 37

print("Total classes:", len(wnids))
class_names = [decode_wnid(wnid) for wnid in wnids]
print("Exact class names:", class_names[class_start:class_end])




Total classes: 200
Exact class names: ['goldfish', 'potpie', 'hourglass', 'seashore', 'computer keyboard', 'arabian camel', 'ice cream', 'nail', 'space heater', 'cardigan', 'baboon', 'snail', 'coral reef', 'albatross', 'spider web']


In [13]:
# concepts = generate_concept_list(class_names[class_start:class_end])  
# print("Generated Concepts:", concepts)

#loading saved_concepts
with open("concepts_22_37.pkl", "rb") as f:
    concepts = pickle.load(f)

In [27]:
N = 1000
step2_images = []
bottleneck_images = []
val_image_labels = []
counts = {wnid: 0 for wnid in wnids[class_start:class_end]}

for path, label in train_dataset.samples:
    # print(label)
    wnid = idx_to_wnid[label]
    if wnid in wnids[class_start:class_end] and counts[wnid] < N:
        img = Image.open(path)
        rgb_img = Image.open(path).convert("RGB")
        # img = pil_transform(img)
        step2_images.append(rgb_img)
        bottleneck_images.append((img, label))

        counts[wnid] += 1
    if all(c >= N for c in counts.values()):
        break

counts = {wnid: 0 for wnid in wnids[class_start:class_end]}

for path, label in val_dataset.samples:
    # print(label)
    wnid = idx_to_wnid[label]
    if wnid in wnids[class_start:class_end] and counts[wnid] < N:
        img = Image.open(path)
        # img = pil_transform(img)
        val_image_labels.append((img, label))

        counts[wnid] += 1
    if all(c >= N for c in counts.values()):
        break


# selected_images is now a list of PIL images for the first 10 classes
print(f"Collected {len(step2_images)} images from {len(wnids[class_start:class_end])} classes.")
print(f"Collected Val {len(val_image_labels)} images from {len(wnids[class_start:class_end])} classes.")

Collected 7500 images from 15 classes.
Collected Val 1000 images from 15 classes.


In [15]:
# P = compute_spatial_similarity_matrix(
#     images=step2_images,
#     concept_list=concepts,  
#     model=model,
#     preprocess=preprocess,
#     device=device,
#     grid_size=(7, 7),
#     radius=32
# )

P = torch.load("P_22_37.pt")


In [16]:
class BackboneDataset(Dataset):
    def __init__(self, image_label_tuples, transform=None):
        """
        Dataset for image-label tuples.

        Args:
            image_label_tuples: List of tuples (PIL image, label)
            transform: Torchvision transforms
        """
        self.image_label_tuples = image_label_tuples
        self.transform = transform if transform is not None else transforms.Compose([
            transforms.Lambda(lambda image: image.convert("RGB")),  # Ensures 3 channels

            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
    
    def __len__(self):
        return len(self.image_label_tuples)
    
    def __getitem__(self, idx):
        image, label = self.image_label_tuples[idx]
        
        if self.transform:
            image = self.transform(image)
        
        return image, torch.tensor(label)
    
    
def load_backbone_model(device, num_classes=200):
    """
    Load a pre-trained ResNet model and modify it for TinyImageNet.
    
    Args:
        device: Device to load the model on
        num_classes: Number of classes in TinyImageNet (default: 200)
    
    Returns:
        model: Modified ResNet model
    """
    # Load pre-trained ResNet18
    backbone = models.resnet18(weights='IMAGENET1K_V1')
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    
    # Move model to device
    backbone = backbone.to(device)
    
    return backbone

def extract_backbone_features(backbone, dataloader, device):
    features = []
    labels = []
    
    # Create a hook to capture the features
    activation = {}
    def get_activation(name):
        def hook(model, input, output):
            activation[name] = output.detach()
        return hook
    
    backbone.layer4.register_forward_hook(get_activation('layer4'))
    
    backbone.eval()
    with torch.no_grad():
        for images, image_labels in tqdm(dataloader, desc="Extracting features"):
            images = images.to(device)
            backbone(images)  # Forward pass to trigger the hook
            
            # Store the features and labels
            features.append(activation['layer4'].cpu())
            labels.append(image_labels.cpu())
    
    return torch.cat(features), torch.cat(labels)

In [18]:

class BottleNeckDataset(Dataset):
    def __init__(self, features, target_conept_maps):

        self.features = features
        self.target_conept_maps = target_conept_maps
    
    def __len__(self):
        return len(self.features)
    
    def __getitem__(self, idx):

        return self.features[idx], self.target_conept_maps[idx] 

# Define the Spatial Concept Bottleneck Layer
class SpatialConceptBottleneckLayer(nn.Module):
    def __init__(self, in_channels, num_concepts, grid_size=(7, 7)):
        super(SpatialConceptBottleneckLayer, self).__init__()
        self.grid_size = grid_size
        
        # 1x1 convolution to project features to concept maps
        self.bottleneck = nn.Conv2d(in_channels, num_concepts, kernel_size=1)
    
    def forward(self, features):
        # Resize features to match grid size
        resized_features = F.interpolate(features, size=self.grid_size, mode='bilinear', align_corners=False)
        
        # Apply 1x1 convolution
        concept_maps = self.bottleneck(resized_features)  # Shape: [N, M, H̃, W̃]
        
        return concept_maps


In [19]:
class SparseLinearClassifierDataset(Dataset):
    def __init__(self, concept_maps, target_classes):

        self.concept_maps = concept_maps
        self.target_classes = target_classes
    
    def __len__(self):
        return len(self.concept_maps)
    
    def __getitem__(self, idx):

        return self.concept_maps[idx], self.target_classes[idx] 

class SparseLinearClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.linear(x)

def elastic_net_regularization(W, alpha=0.5):
    frob_norm = torch.norm(W, p='fro')**2
    l1_norm = torch.norm(W, p=1)
    return (1 - alpha) / 2 * frob_norm + alpha * l1_norm




In [35]:
# Define the Cubic Cosine Similarity Loss
class CubicCosineSimilarityLoss(nn.Module):
    def __init__(self):
        super(CubicCosineSimilarityLoss, self).__init__()
    
    def forward(self, concept_maps, target_similarities):
        N, M, H, W = concept_maps.shape
        loss = 0.0
        
        for m in range(M):
            for h in range(H):
                for w in range(W):
                    # Extract q[m,h,w] and p[m,h,w] as described in the paper
                    q = concept_maps[:, m, h, w]  # C[:, m, h, w] shape: [N]
                    p = target_similarities[:, m, h, w]  # P[:, m, h, w] shape: [N]
                    
                    # Zero-mean normalization
                    q_norm = q - q.mean()
                    p_norm = p - p.mean()
                    
                    # Cubic transformation (raise to power of 3)
                    q_cubic = q_norm ** 3
                    p_cubic = p_norm ** 3
                    
                    # L2 normalization
                    q_cubic_norm = q_cubic / (torch.norm(q_cubic) + 1e-8)
                    p_cubic_norm = p_cubic / (torch.norm(p_cubic) + 1e-8)
                    
                    # Cosine similarity
                    sim = torch.dot(q_cubic_norm, p_cubic_norm)
                    
                    # Negative similarity for loss minimization
                    loss -= sim
        
        return loss / (M * H * W)


def train_concept_bottleneck_layer(features, labels, P, device, grid_size=(7, 7), num_epochs=10):
    # step 3

    in_channels = features.shape[1]  # Channels in backbone features
    bottleneck = SpatialConceptBottleneckLayer(in_channels, len(concepts), grid_size).to(device)


    criterion = CubicCosineSimilarityLoss()
    optimizer = optim.Adam(bottleneck.parameters(), lr=0.001)

    # print(features.shape)
    # features = features.to(device)
    # P = P.to(device)

    bottleneck_dataset = BottleNeckDataset(features=features, target_conept_maps=P)
    bottleneck_dataloader = DataLoader(bottleneck_dataset, batch_size=128, shuffle=True)

    bottleneck.train()
    print("training bottleneck")
    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch_features, target_concept_maps in  tqdm(bottleneck_dataloader):
            batch_features = batch_features.to(device)
            target_concept_maps = target_concept_maps.to(device)

            optimizer.zero_grad()
            output_concept_maps = bottleneck(batch_features)
            loss = criterion(output_concept_maps, target_concept_maps)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * batch_features.size(0)

        epoch_loss = running_loss / len(bottleneck_dataloader.dataset)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    # step 4
    all_concept_maps_list = []
    with torch.no_grad():
        for batch_features, target_concept_maps in  bottleneck_dataloader:
            batch_features = batch_features.to(device)
            concept_maps = bottleneck(batch_features)  
            all_concept_maps_list.append(concept_maps.cpu())  

    all_concept_maps = torch.cat(all_concept_maps_list, dim=0)  # 

    print(f"Generated concept maps shape: {all_concept_maps.shape}")
    
    concept_activations = F.adaptive_avg_pool2d(concept_maps, 1).squeeze(-1).squeeze(-1)  
    mapped_labels = labels - 22

    print(mapped_labels.shape)

    classifier_head_dataset = SparseLinearClassifierDataset(concept_activations, mapped_labels)
    classifier_head_dataloader = DataLoader(classifier_head_dataset, batch_size=128, shuffle=True)

    classifier_head = SparseLinearClassifier(input_dim=all_concept_maps.shape[1], num_classes=15).to(device)
    classifier_head_optimizer = torch.optim.SGD(classifier_head.parameters(), lr=0.01)  # placeholder for GLM-SAGA
    lambda_reg = 1e-4
    alpha = 0.5
    classifier_head_criterion = nn.CrossEntropyLoss()

    classifier_head.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        correct = 0
        total = 0

        for x_concepts, x_concepts_label in tqdm(classifier_head_dataloader):
            x_concepts, x_concepts_label = x_concepts.to(device), x_concepts_label.to(device)
            classifier_head_optimizer.zero_grad()

            logits = classifier_head(x_concepts)  # shape: [batch_size, num_classes]
            loss_ce = classifier_head_criterion(logits, x_concepts_label)

            W = classifier_head.linear.weight
            loss_reg = lambda_reg * elastic_net_regularization(W, alpha)
            loss = loss_ce + loss_reg

            loss.backward()
            classifier_head_optimizer.step()

            total_loss += loss.item() * x_concepts.size(0)

            # Accuracy computation
            predictions = torch.argmax(logits, dim=1)
            correct += (predictions == x_concepts_label).sum().item()
            total += x_concepts_label.size(0)

        avg_loss = total_loss / len(classifier_head_dataloader.dataset)
        accuracy = 100.0 * correct / total
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
        return bottleneck, classifier_head



In [None]:
backbone = load_backbone_model(device)
backbone.eval() 
backbone_data = BackboneDataset(bottleneck_images)
backbone_dataloader = DataLoader(backbone_data, batch_size=128, shuffle=False)
features, labels = extract_backbone_features(backbone, backbone_dataloader, device)


Extracting features:  93%|█████████▎| 55/59 [02:28<00:10,  2.52s/it]

In [None]:
grid_size = (7, 7) 
num_concepts = len(concepts)
bottleneck, classifier_head = train_concept_bottleneck_layer(
    features=features,
    labels = labels,
    dataloader=backbone_dataloader, 
    P=P,  
    device=device,
    grid_size=(7, 7), 
    num_epochs=1
)


100%|██████████| 59/59 [02:32<00:00,  2.58s/it]
100%|██████████| 1/1 [02:32<00:00, 152.08s/it]


Epoch [1/1], Loss: -0.2338
Generated concept maps shape: torch.Size([7500, 134, 7, 7])
torch.Size([7500])


TypeError: empty(): argument 'size' failed to unpack the object at pos 2 with error "type must be tuple of ints,but got Tensor"

In [None]:
#this is incomplete
def predict_class(backbone, bottleneck, classifier_head, image, preprocess, device):
    backbone.eval()
    bottleneck.eval()
    classifier_head.eval()
    
    # Create a hook to capture the features
    activation = {}
    def get_activation(name):
        def hook(model, input, output):
            activation[name] = output.detach()
        return hook
    
    # Register the hook
    backbone.layer4.register_forward_hook(get_activation('layer4'))
    
    with torch.no_grad():
        # Preprocess and forward through backbone
        input_tensor = preprocess(image).unsqueeze(0).to(device)
        backbone(input_tensor)
        
        # Extract features and predict concept maps
        features = activation['layer4']
        concept_maps = bottleneck(features)
        concept_activation = F.adaptive_avg_pool2d(concept_maps, 1).squeeze(-1).squeeze(-1)  
        logits = classifier_head(concept_activation) 
        prediction = torch.argmax(logits, dim=1)
    return prediction 