In [1]:
print("File started")

File started


In [2]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_api_key")

import wandb
wandb.login(key=secret_value_0)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mpratham3992[0m ([33mpratham3992-plaksha[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
print("Wandb login done")

Wandb login done


In [4]:
import os
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
import heapq
import wandb  # Import wandb

device = torch.device('cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))
print(f"Using device: {device}")

# Global variables for image size - will be set in main()
img_size = 224  # default for standard resolution
resize_size = 256  # default for standard resolution

class ProductDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        self.df = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        
        if 'group' not in self.df.columns:
            self.df['group'] = -1
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        img_name = os.path.join(self.img_dir, self.df.iloc[idx, 0])
        image = Image.open(img_name).convert('RGB')
        class_id = self.df.iloc[idx, 1]
        group_id = self.df.iloc[idx, 2] if 'group' in self.df.columns else -1
        
        sample = {'image': image, 'class': class_id, 'group': group_id, 'filename': self.df.iloc[idx, 0]}
        
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        
        return sample

def get_data_transforms(high_res=False):
    resize_size = 512 if high_res else 256
    crop_size = 448 if high_res else 224
    
    transform = transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    return transform

class FeatureExtractor(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super(FeatureExtractor, self).__init__()
        
        model = models.efficientnet_b1(weights='DEFAULT' if pretrained else None)
        
        self.features = nn.Sequential(*list(model.children())[:-1])
        
        in_features = model.classifier[1].in_features
        self.fc = nn.Linear(in_features, num_classes)
        
    def forward(self, x):
        features = self.features(x)
        features = features.flatten(start_dim=1)
        
        classifier = self.fc(features)
        
        return features, classifier
    
    def extract_features(self, x):
        features = self.features(x)
        features = features.flatten(start_dim=1)
        return features

def compute_similarity(query_features, gallery_features, method='cosine'):
    # Convert to PyTorch tensors if they're numpy arrays
    if isinstance(query_features, np.ndarray):
        query_features = torch.from_numpy(query_features).to(device)
    if isinstance(gallery_features, np.ndarray):
        gallery_features = torch.from_numpy(gallery_features).to(device)
    
    if method == 'cosine':
        query_norm = query_features / torch.norm(query_features, dim=1, keepdim=True)
        gallery_norm = gallery_features / torch.norm(gallery_features, dim=1, keepdim=True)
        similarity = torch.mm(query_norm, gallery_norm.T)
    elif method == 'euclidean':
        # Compute pairwise euclidean distance using PyTorch
        similarity = -torch.cdist(query_features, gallery_features, p=2.0)
    elif method == 'dot':
        similarity = torch.mm(query_features, gallery_features.T)
    
    # Return as numpy for compatibility with the rest of the code
    return similarity.cpu().numpy()

def extract_all_features(model, dataloader):
    model.eval()
    all_features = []
    all_classes = []
    all_groups = []
    all_filenames = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Extracting features"):
            images = batch['image'].to(device)
            classes = batch['class']
            groups = batch['group']
            filenames = batch['filename']
            
            features = model.extract_features(images)
            
            # Keep the features on CPU as NumPy arrays for consistency with the rest of the code
            all_features.append(features.cpu().numpy())
            all_classes.extend(classes.numpy())
            all_groups.extend(groups.numpy())
            all_filenames.extend(filenames)
    
    all_features = np.concatenate(all_features, axis=0)
    all_classes = np.array(all_classes)
    all_groups = np.array(all_groups)
    
    return all_features, all_classes, all_groups, all_filenames

def visualize_retrieval_results(query_img_dir, gallery_img_dir, query_filename, retrieval_filenames, 
                                query_class, retrieval_classes, query_group, retrieval_groups, 
                                similarity_scores, k=5):
    """Visualize image retrieval results"""
    os.makedirs('retrieval_results', exist_ok=True)
    
    query_img_path = os.path.join(query_img_dir, query_filename)
    query_img = Image.open(query_img_path).convert('RGB')
    
    fig, axes = plt.subplots(1, k + 1, figsize=(15, 3))
    
    axes[0].imshow(query_img)
    axes[0].set_title(f'Query Image\nClass: {query_class}, Group: {query_group}')
    axes[0].axis('off')
    
    for i in range(k):
        if i < len(retrieval_filenames):
            gallery_img_path = os.path.join(gallery_img_dir, retrieval_filenames[i])
            gallery_img = Image.open(gallery_img_path).convert('RGB')
            
            axes[i+1].imshow(gallery_img)
            match_text = "Same Class" if retrieval_classes[i] == query_class else "Different Class"
            group_text = f"Group: {retrieval_groups[i]}"
            axes[i+1].set_title(f'Rank {i+1}: {match_text}\n{group_text}\nSimilarity: {similarity_scores[i]:.4f}')
            axes[i+1].axis('off')
    
    plt.tight_layout()
    result_filepath = f'retrieval_results/{os.path.splitext(query_filename)[0]}_retrieval.png'
    plt.savefig(result_filepath)
    plt.close()
    
    # Log the visualization to wandb
    wandb.log({
        f"retrieval_{os.path.splitext(query_filename)[0]}": wandb.Image(
            result_filepath,
            caption=f"Query: {query_filename}, Class: {query_class}, Group: {query_group}"
        )
    })

def perform_retrieval_visualization(query_features, query_classes, query_groups, query_filenames,
                               gallery_features, gallery_classes, gallery_groups, gallery_filenames,
                               distance_method, query_img_dir, gallery_img_dir, top_k=5, num_queries=10):
    """Visualizes retrieval results for selected queries"""
    # Ensure we don't try to retrieve more queries than exist
    num_queries = min(num_queries, len(query_features))
    
    # Convert features to tensors if they're not already
    if isinstance(query_features, np.ndarray):
        query_features = torch.from_numpy(query_features).to(device)
    if isinstance(gallery_features, np.ndarray):
        gallery_features = torch.from_numpy(gallery_features).to(device)
    
    # Randomly select subset of queries for visualization
    query_indices = np.random.choice(len(query_features), num_queries, replace=False)
    
    for idx in query_indices:
        query_feature = query_features[idx:idx+1]
        query_class = query_classes[idx]
        query_group = query_groups[idx]
        query_filename = query_filenames[idx]
        
        # Compute similarity between query and all gallery images
        similarity = compute_similarity(query_feature, gallery_features, method=distance_method)
        similarity = similarity[0]  # Take the first row as we only have one query
        
        # Get top k indices
        if distance_method in ['cosine', 'dot']:
            top_indices = np.argsort(similarity)[::-1][:top_k]
        else:  # euclidean - smaller distance is better
            top_indices = np.argsort(similarity)[:top_k]
            
        retrieval_filenames = [gallery_filenames[i] for i in top_indices]
        retrieval_classes = [gallery_classes[i] for i in top_indices]
        retrieval_groups = [gallery_groups[i] for i in top_indices]
        retrieval_scores = [similarity[i] for i in top_indices]
        
        # Visualize the retrieval results
        visualize_retrieval_results(
            query_img_dir, gallery_img_dir, query_filename, retrieval_filenames,
            query_class, retrieval_classes, query_group, retrieval_groups,
            retrieval_scores, k=top_k
        )

def calculate_precision(query_features, query_classes, gallery_features, gallery_classes, method='cosine', top_k=5):
    """Calculate precision@k for a given distance method"""
    # Convert inputs to tensors on device if they're not already
    if isinstance(query_features, np.ndarray):
        query_features = torch.from_numpy(query_features).to(device)
    if isinstance(gallery_features, np.ndarray):
        gallery_features = torch.from_numpy(gallery_features).to(device)
    
    query_classes_np = query_classes  # Keep a numpy version for indexing
    if isinstance(gallery_classes, np.ndarray):
        gallery_classes_tensor = torch.from_numpy(gallery_classes).to(device)
    else:
        gallery_classes_tensor = gallery_classes
    
    batch_size = 100  # Process queries in batches to avoid memory issues
    num_queries = len(query_features)
    all_precision = []
    
    for batch_start in range(0, num_queries, batch_size):
        batch_end = min(batch_start + batch_size, num_queries)
        batch_query_features = query_features[batch_start:batch_end]
        
        # Compute similarity for the entire batch
        similarity_matrix = compute_similarity(batch_query_features, gallery_features, method=method)
        
        # Process each query in the batch
        for i in range(batch_end - batch_start):
            idx = batch_start + i
            query_class = query_classes_np[idx]
            similarity = similarity_matrix[i]
            
            if method in ['cosine', 'dot']:
                top_indices = np.argsort(similarity)[::-1][:top_k]
            else:
                top_indices = np.argsort(similarity)[:top_k]
            
            retrieval_classes = [gallery_classes[j] for j in top_indices]
            
            # Calculate precision
            correct = sum([1 for c in retrieval_classes if c == query_class])
            precision = correct / top_k
            all_precision.append(precision)
    
    avg_precision = np.mean(all_precision)
    print(f"Average Precision@{top_k} ({method} distance): {avg_precision:.4f}")
    
    # Log precision metric to wandb
    wandb.log({f"precision@{top_k}_{method}": avg_precision})
    
    return avg_precision

def calculate_ap_at_k(relevant_scores, k=5):
    """
    Calculate Average Precision @ K.
    
    Args:
        relevant_scores: Binary list indicating if each retrieval is relevant (same class)
        k: Cut-off rank (consider only top k retrievals)
    
    Returns:
        AP@k score
    """
    relevant_scores = relevant_scores[:k]  # Consider only top k
    
    if not any(relevant_scores):  # No relevant retrievals
        return 0.0
    
    # Calculate precision at each relevant position
    precisions = []
    num_relevant = 0
    
    for i, is_relevant in enumerate(relevant_scores):
        if is_relevant:
            num_relevant += 1
            # Precision at position i+1 (0-indexed to 1-indexed)
            precision_at_i = num_relevant / (i + 1)
            precisions.append(precision_at_i)
    
    # AP is the average of precisions at relevant positions
    ap = sum(precisions) / min(sum(relevant_scores), k)
    return ap

def calculate_map(query_features, query_classes, gallery_features, gallery_classes, method='cosine', top_k=5):
    """Calculate mAP@k for a given distance method"""
    # Convert inputs to tensors on device if they're not already
    if isinstance(query_features, np.ndarray):
        query_features = torch.from_numpy(query_features).to(device)
    if isinstance(gallery_features, np.ndarray):
        gallery_features = torch.from_numpy(gallery_features).to(device)
    
    query_classes_np = query_classes if isinstance(query_classes, np.ndarray) else query_classes.cpu().numpy()
    gallery_classes_np = gallery_classes if isinstance(gallery_classes, np.ndarray) else gallery_classes.cpu().numpy()
    
    batch_size = 100  # Process queries in batches to avoid memory issues
    num_queries = len(query_features)
    all_ap = []
    
    for batch_start in range(0, num_queries, batch_size):
        batch_end = min(batch_start + batch_size, num_queries)
        batch_query_features = query_features[batch_start:batch_end]
        
        # Compute similarity for the entire batch
        similarity_matrix = compute_similarity(batch_query_features, gallery_features, method=method)
        
        # Process each query in the batch
        for i in range(batch_end - batch_start):
            idx = batch_start + i
            query_class = query_classes_np[idx]
            similarity = similarity_matrix[i]
            
            # Get ranking indices
            if method in ['cosine', 'dot']:
                top_indices = np.argsort(similarity)[::-1][:top_k]
            else:
                top_indices = np.argsort(similarity)[:top_k]
                
            # Get relevance scores (1 for same class, 0 for different)
            relevant_scores = [1 if gallery_classes_np[j] == query_class else 0 for j in top_indices]
            
            # Calculate AP@k for this query
            ap = calculate_ap_at_k(relevant_scores, top_k)
            all_ap.append(ap)
    
    # Calculate mAP@k (mean of APs)
    map_k = np.mean(all_ap)
    print(f"mAP@{top_k} ({method} distance): {map_k:.4f}")
    
    # Log mAP metric to wandb
    wandb.log({f"map@{top_k}_{method}": map_k})
    
    return map_k

def perform_retrieval(model, query_dataloader, gallery_dataloader, distance_method, query_img_dir, gallery_img_dir, top_k=5, num_queries=10, 
                  pre_extracted_features=None):
    # Use pre-extracted features if provided, otherwise extract them
    if pre_extracted_features:
        query_features, query_classes, query_groups, query_filenames, gallery_features, gallery_classes, gallery_groups, gallery_filenames = pre_extracted_features
    else:
        # Extract features for gallery and query images
        gallery_features, gallery_classes, gallery_groups, gallery_filenames = extract_all_features(model, gallery_dataloader)
        query_features, query_classes, query_groups, query_filenames = extract_all_features(model, query_dataloader)
    
    # Do visualization with the primary distance method
    perform_retrieval_visualization(
        query_features, query_classes, query_groups, query_filenames,
        gallery_features, gallery_classes, gallery_groups, gallery_filenames,
        distance_method=distance_method, 
        top_k=top_k,
        query_img_dir=query_img_dir,
        gallery_img_dir=gallery_img_dir,
        num_queries=num_queries
    )
    
    # Calculate metrics for all distance methods
    all_distances = ['cosine', 'euclidean', 'dot']
    results = {}
    
    for dist_method in all_distances:
        avg_precision = calculate_precision(
            query_features, query_classes,
            gallery_features, gallery_classes,
            method=dist_method,
            top_k=top_k
        )
        results[dist_method] = avg_precision
    
    # Calculate mAP for all distance methods
    map_results = {}
    for dist_method in all_distances:
        map_k = calculate_map(
            query_features, query_classes,
            gallery_features, gallery_classes,
            method=dist_method,
            top_k=top_k
        )
        map_results[dist_method] = map_k
    
    # Print comparison table
    print("\nComparison of distance metrics (Precision@{} and mAP@{}):".format(top_k, top_k))
    print("=" * 50)
    print("{:<15} {:<10} {:<10}".format("Distance Metric", f"Precision@{top_k}", f"mAP@{top_k}"))
    print("-" * 50)
    for dist_method in all_distances:
        print("{:<15} {:<10.4f} {:<10.4f}".format(dist_method, results[dist_method], map_results[dist_method]))
    print("=" * 50)
    
    # Log comparison table to wandb
    wandb.log({f"precision_map_comparison_at_{top_k}": wandb.Table(
        columns=["Distance Metric", f"Precision@{top_k}", f"mAP@{top_k}"],
        data=[[method, results[method], map_results[method]] for method in all_distances]
    )})
    
    return results

def main():
    # Parameters that were previously handled by argparse
    train_dir = '/kaggle/input/visual-product-recognition/train/train'
    test_dir = '/kaggle/input/visual-product-recognition/test/test'
    train_csv = '/kaggle/input/visual-product-recognition/train.csv'
    test_csv = '/kaggle/input/visual-product-recognition/test.csv'
    batch_size = 16
    high_res = False  # Set to True to use high resolution images
    top_k = 5
    distance = 'cosine'  # Primary distance for visualization: 'cosine', 'euclidean', 'dot'
    
    # Initialize wandb
    wandb.init(
        project="visual-product-recognition",
        name="image-retrieval",
        config={
            "model": "efficientnet_b1",
            "batch_size": batch_size,
            "high_res": high_res,
            "top_k": top_k,
            "primary_distance": distance
        }
    )
    
    # Update global variables based on high_res setting
    global img_size, resize_size
    img_size = 448 if high_res else 224
    resize_size = 512 if high_res else 256
    
    # Get data transformations
    transform = get_data_transforms(high_res=high_res)
    
    # Load datasets
    train_dataset = ProductDataset(train_csv, train_dir, transform=transform)
    test_dataset = ProductDataset(test_csv, test_dir, transform=transform)
    
    # Calculate number of classes
    num_classes = len(pd.read_csv(train_csv)['class'].unique())
    print(f"Number of classes: {num_classes}")
    
    # Create data loaders
    gallery_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    query_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    
    # Initialize the model
    model = FeatureExtractor(num_classes=num_classes)
    
    # Load pre-trained weights if available, otherwise just use ImageNet weights
    try:
        checkpoint = torch.load('/kaggle/input/metricloss-false/pytorch/default/1/product_model.pth', map_location=device)
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint, strict=False)
        print("Loaded pre-trained product model weights")
    except Exception as e:
        print(f"Could not load pre-trained weights: {e}")
        print("Using ImageNet pre-trained weights")
    
    model = model.to(device)
    
    # Extract features once (to avoid re-computing them for each distance metric)
    print("Extracting gallery features...")
    gallery_features, gallery_classes, gallery_groups, gallery_filenames = extract_all_features(model, gallery_loader)
    print("Extracting query features...")
    query_features, query_classes, query_groups, query_filenames = extract_all_features(model, query_loader)
    print("Feature extraction complete.")

    # Package extracted features
    pre_extracted_features = (
        query_features, query_classes, query_groups, query_filenames,
        gallery_features, gallery_classes, gallery_groups, gallery_filenames
    )

    # Perform retrieval and visualization with pre-extracted features
    perform_retrieval(
        model,
        query_loader,
        gallery_loader,
        distance_method=distance,
        query_img_dir=test_dir,
        gallery_img_dir=train_dir,
        top_k=top_k,
        num_queries=10,
        pre_extracted_features=pre_extracted_features
    )
    
    # Finish wandb run
    wandb.finish()

if __name__ == "__main__":
    main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Using device: cuda


[34m[1mwandb[0m: Tracking run with wandb version 0.19.6
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20250510_213616-exsmq3nl[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mimage-retrieval[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/pratham3992-plaksha/visual-product-recognition[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/pratham3992-plaksha/visual-product-recognition/runs/exsmq3nl[0m


Number of classes: 9691


Downloading: "https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b1-c27df63c.pth
100%|██████████| 30.1M/30.1M [00:00<00:00, 132MB/s]
  checkpoint = torch.load('/kaggle/input/metricloss-false/pytorch/default/1/product_model.pth', map_location=device)


Loaded pre-trained product model weights
Extracting gallery features...


Extracting features: 100%|██████████| 8871/8871 [17:13<00:00,  8.59it/s]


Extracting query features...


Extracting features: 100%|██████████| 3461/3461 [05:32<00:00, 10.42it/s]


Feature extraction complete.
Average Precision@5 (cosine distance): 0.4849
Average Precision@5 (euclidean distance): 0.0000
Average Precision@5 (dot distance): 0.4008
mAP@5 (cosine distance): 0.6363
mAP@5 (euclidean distance): 0.0000
mAP@5 (dot distance): 0.5486

Comparison of distance metrics (Precision@5 and mAP@5):
Distance Metric Precision@5 mAP@5     
--------------------------------------------------
cosine          0.4849     0.6363    
euclidean       0.0000     0.0000    
dot             0.4008     0.5486    


[34m[1mwandb[0m: uploading artifact run-exsmq3nl-precision_map_comparison_at_5
[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:          map@5_cosine ▁
[34m[1mwandb[0m:             map@5_dot ▁
[34m[1mwandb[0m:       map@5_euclidean ▁
[34m[1mwandb[0m:    precision@5_cosine ▁
[34m[1mwandb[0m:       precision@5_dot ▁
[34m[1mwandb[0m: precision@5_euclidean ▁
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:          map@5_cosine 0.63631
[34m[1mwandb[0m:             map@5_dot 0.54863
[34m[1mwandb[0m:       map@5_euclidean 0
[34m[1mwandb[0m:    precision@5_cosine 0.48493
[34m[1mwandb[0m:       precision@5_dot 0.40081
[34m[1mwandb[0m: precision@5_euclidean 0
[34m[1mwandb[0m: 
[34m[1mwandb[0m: 🚀 View run [33mimage-retrieval[0m at: [34m[4mhttps://wandb.ai/pratham3992-plaksha/visual-product-recognition/runs