In [1]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

In [6]:
 import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image

import clip 
import numpy as np

In [7]:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, classification_report

In [8]:
class AnimalsDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        self.image_paths = []
        self.labels = []
        self.class_to_idx = {}

        # sorted for consistency
        classes = sorted(os.listdir(root_dir))
        for idx, cls in enumerate(classes):
            cls_path = os.path.join(root_dir, cls)
            if not os.path.isdir(cls_path):
                continue

            self.class_to_idx[cls] = idx

            for file in os.listdir(cls_path):
                if file.lower().endswith((".jpg", ".jpeg", ".png")):
                    self.image_paths.append(os.path.join(cls_path, file))
                    self.labels.append(idx)

        print(f"Found {len(self.image_paths)} images across {len(self.class_to_idx)} classes.")
        print("Class mapping:", self.class_to_idx)

In [9]:

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        img = Image.open(img_path).convert("RGB")

        if self.transform:
            img = self.transform(img)

        return img, label


In [12]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

model, preprocess = clip.load("ViT-B/32", device=device)  # downloads if needed
model.eval()

Using device: cuda


100%|███████████████████████████████████████| 338M/338M [00:11<00:00, 31.4MiB/s]


CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [13]:
root_dir = "/home/akash/Downloads/Animals-10"
dataset = AnimalsDataset(root_dir=root_dir, transform=preprocess)

loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,      # important: we want features aligned with labels
    num_workers=4
)

Found 498 images across 6 classes.
Class mapping: {'butterfly': 0, 'cat': 1, 'dog': 2, 'elephant': 3, 'horse': 4, 'squirrel': 5}


In [17]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

model, preprocess = clip.load("ViT-B/32", device=device)
model.eval()

class AnimalImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        self.image_paths = []
        self.labels = []
        self.class_to_idx = {}

        classes = sorted(os.listdir(root_dir))
        for idx, cls in enumerate(classes):
            cls_path = os.path.join(root_dir, cls)
            if not os.path.isdir(cls_path):
                continue

            self.class_to_idx[cls] = idx

            for file in os.listdir(cls_path):
                if file.lower().endswith((".jpg", ".jpeg", ".png")):
                    self.image_paths.append(os.path.join(cls_path, file))
                    self.labels.append(idx)

        print(f"Found {len(self.image_paths)} images across {len(self.class_to_idx)} classes.")
        print("Class mapping:", self.class_to_idx)

    def __len__(self):
        # <- THIS is critical, must be at same indent level as __init__
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            img = self.transform(img)

        return img, label


Using device: cuda


In [19]:
root_dir = "/home/akash/Downloads/Animals-10"

dataset = AnimalImageDataset(
    root_dir=root_dir,
    transform=preprocess
)

print("len(dataset) =", len(dataset))  # should print 498 or similar

loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,
    num_workers=0
)

# quick test: just grab one batch
for images, labels in loader:
    print("Batch images shape:", images.shape)
    print("Batch labels shape:", labels.shape)
    break


Found 498 images across 6 classes.
Class mapping: {'butterfly': 0, 'cat': 1, 'dog': 2, 'elephant': 3, 'horse': 4, 'squirrel': 5}
len(dataset) = 498
Batch images shape: torch.Size([64, 3, 224, 224])
Batch labels shape: torch.Size([64])


In [20]:
import numpy as np
import torch

all_features = []
all_labels = []

model.eval()
with torch.no_grad():
    for images, labels in loader:
        images = images.to(device)

        # Encode using CLIP
        feats = model.encode_image(images)          # [batch, 512]
        feats = feats / feats.norm(dim=-1, keepdim=True)  # L2 normalize

        all_features.append(feats.cpu())
        all_labels.append(labels)

all_features = torch.cat(all_features, dim=0).numpy()   # (N, 512)
all_labels = torch.cat(all_labels, dim=0).numpy()       # (N,)

print("Embeddings shape:", all_features.shape)
print("Labels shape:", all_labels.shape)


Embeddings shape: (498, 512)
Labels shape: (498,)


In [21]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    all_features,
    all_labels,
    test_size=0.2,        # 20% test
    stratify=all_labels,  # keep class balance
    random_state=42
)

print("Train:", X_train.shape, "Test:", X_test.shape)


Train: (398, 512) Test: (100, 512)


In [22]:
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, classification_report

# --- Logistic Regression ---
log_reg = LogisticRegression(
    max_iter=1000,
    n_jobs=-1,
    multi_class="multinomial"
)
log_reg.fit(X_train, y_train)
y_pred_lr = log_reg.predict(X_test)

print("=== Logistic Regression ===")
print("Accuracy:", accuracy_score(y_test, y_pred_lr))
print(classification_report(y_test, y_pred_lr))


# --- KNN (k=5, cosine distance) ---
knn = KNeighborsClassifier(
    n_neighbors=5,
    metric="cosine"
)
knn.fit(X_train, y_train)
y_pred_knn = knn.predict(X_test)

print("\n=== KNN (k=5, cosine) ===")
print("Accuracy:", accuracy_score(y_test, y_pred_knn))
print(classification_report(y_test, y_pred_knn))


# --- Gaussian Naive Bayes ---
nb = GaussianNB()
nb.fit(X_train, y_train)
y_pred_nb = nb.predict(X_test)

print("\n=== Gaussian Naive Bayes ===")
print("Accuracy:", accuracy_score(y_test, y_pred_nb))
print(classification_report(y_test, y_pred_nb))


Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/home/akash/miniconda3/envs/ml/lib/python3.10/multiprocessing/queues.py", line 239, in _feed
Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/home/akash/miniconda3/envs/ml/lib/python3.10/multiprocessing/queues.py", line 239, in _feed
    reader_close()
  File "/home/akash/miniconda3/envs/ml/lib/python3.10/multiprocessing/connection.py", line 177, in close
Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/home/akash/miniconda3/envs/ml/lib/python3.10/multiprocessing/queues.py", line 239, in _feed
    reader_close()
  File "/home/akash/miniconda3/envs/ml/lib/python3.10/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/home/akash/miniconda3/envs/ml/lib/python3.10/multiprocessing/connection.py", line 361, in _close
    reader_close()
  File "/home/akash/miniconda3/envs/ml/lib/python3.10/multiprocessing/connection.py"

=== Logistic Regression ===
Accuracy: 1.0
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        28
           1       1.00      1.00      1.00        20
           2       1.00      1.00      1.00        10
           3       1.00      1.00      1.00        22
           4       1.00      1.00      1.00        10
           5       1.00      1.00      1.00        10

    accuracy                           1.00       100
   macro avg       1.00      1.00      1.00       100
weighted avg       1.00      1.00      1.00       100


=== KNN (k=5, cosine) ===
Accuracy: 1.0
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        28
           1       1.00      1.00      1.00        20
           2       1.00      1.00      1.00        10
           3       1.00      1.00      1.00        22
           4       1.00      1.00      1.00        10
           5       1.00      1.00      1.00      

In [23]:
import joblib
import os

# Make an output folder (optional, but tidy)
os.makedirs("saved_models", exist_ok=True)

# 1. Save the three classifiers
joblib.dump(log_reg, "saved_models/logreg_clip_animals.pkl")
joblib.dump(knn, "saved_models/knn_clip_animals.pkl")
joblib.dump(nb, "saved_models/nb_clip_animals.pkl")

print("Models saved!")

# 2. Save class names in index order: index -> class name
class_names = [None] * len(dataset.class_to_idx)
for cls_name, idx in dataset.class_to_idx.items():
    class_names[idx] = cls_name

joblib.dump(class_names, "saved_models/class_names.pkl")
print("Class names saved:", class_names)


Models saved!
Class names saved: ['butterfly', 'cat', 'dog', 'elephant', 'horse', 'squirrel']
