In [None]:
import os
import requests
import torch
from PIL import Image
import matplotlib.pyplot as plt
from io import BytesIO
import numpy as np
from tqdm import tqdm
import pandas as pd
import os
import sys
print(sys.executable)  # just to see which Python the notebook is using

# Install into THIS Python, not some other one
!{sys.executable} -m pip install --upgrade diffusers[torch] transformers
from diffusers import StableUnCLIPImg2ImgPipeline
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

In [2]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# ──────────────────────────────────────────────────────────────
# 1.  Load unCLIP – vision side only (projection_dim = 1024)   ─
# ──────────────────────────────────────────────────────────────
# Half precision is only a sensible default on CUDA. When `device` falls back
# to the CPU, keep the weights in float32 instead. `variant="fp16"` only picks
# which checkpoint files are fetched; `torch_dtype` sets the in-memory dtype.
pipe_dtype = torch.float16 if device.type == "cuda" else torch.float32

pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "sd2-community/stable-diffusion-2-1-unclip",
    torch_dtype=pipe_dtype,
    variant="fp16",
).to(device)

# Shorter alias for the pipeline's image encoder.
vision_encoder = pipe.image_encoder

In [4]:
# OpenCLIP ViT-H/14 checkpoint used for the text side (projection_dim = 1024).
openclip_repo = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
tokenizer = CLIPTokenizer.from_pretrained(openclip_repo)
text_encoder = CLIPTextModelWithProjection.from_pretrained(
    openclip_repo,
    # fp16 only on CUDA; use float32 when `device` is the CPU fallback.
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
).to(device)

# optional: stuff them into the pipe so `pipe.tokenizer` etc. work
# NOTE(review): this overwrites whatever tokenizer / text encoder the pipeline
# loaded for itself. If `pipe` is later used for prompt-conditioned generation,
# confirm that a `...WithProjection` model is an acceptable replacement there.
pipe.tokenizer, pipe.text_encoder = tokenizer, text_encoder

In [5]:
def embed_images(paths, batch_size=8):
    """Embed image files with the pipeline's image encoder.

    Args:
        paths: Sequence of image file paths; processed in slices of
            ``batch_size``.
        batch_size: Number of images sent through the encoder per forward pass.

    Returns:
        The per-batch encoder outputs concatenated along dim 0, left on the
        encoder's device and in its dtype. The rest of the notebook treats
        this as an (N, 1024) tensor.

    Note:
        ``torch.cat`` raises if ``paths`` is empty, because no batch is produced.
    """
    fe, enc = pipe.feature_extractor, pipe.image_encoder
    out = []
    for i in range(0, len(paths), batch_size):
        print(f"Batch {i}/{len(paths)}")
        imgs = []
        for p in paths[i:i + batch_size]:
            # `convert` returns an independent in-memory copy, so the file can
            # be closed right away instead of waiting for garbage collection.
            with Image.open(p) as im:
                imgs.append(im.convert("RGB"))
        # Match the encoder's device and dtype so the forward pass needs no casts.
        px = fe(imgs, return_tensors="pt").pixel_values.to(enc.device, enc.dtype)
        with torch.no_grad():
            v = enc(px)[0]                              # (B,1024)
        out.append(v)
    return torch.cat(out)  # (N,1024)

In [None]:
# ──────────────────────────────────────────────────────────────
#  Deal with file structure and get working paths
# ──────────────────────────────────────────────────────────────
# Layout walked below: <root>/<group>/<animal>/<image>.jpg
root = "THINGS_animalgroups"

# collect all jpgs and keep their group (top-level folder)
groupedimages = {}

# os.listdir returns entries in arbitrary order; sorting keeps the sample
# order identical from run to run.
for group in sorted(os.listdir(root)):
    group_dir = os.path.join(root, group)
    # Stray files next to the group folders (e.g. .DS_Store) would make the
    # nested os.listdir raise NotADirectoryError, so skip anything that is
    # not a directory.
    if not os.path.isdir(group_dir):
        continue
    groupedimages[group] = []

    # go into animal name from category
    for animal in sorted(os.listdir(group_dir)):
        animal_dir = os.path.join(group_dir, animal)
        if not os.path.isdir(animal_dir):
            continue

        # animal images inside animal folders; the extension check is case-insensitive
        for fname in sorted(os.listdir(animal_dir)):
            if fname.lower().endswith(".jpg"):
                full_path = os.path.join(animal_dir, fname)
                groupedimages[group].append(full_path)

# Sanity check: number of images found per group
for g, imgs in groupedimages.items():
    print(g, "->", len(imgs), "images")

In [None]:
# Give every group an integer label; sorting the names keeps the mapping deterministic.
groups = sorted(groupedimages)
group_to_idx = {name: idx for idx, name in enumerate(groups)}
print("Label mapping:", group_to_idx)

# Flatten the per-group path lists into two parallel sequences:
# one path per image and the label index of the group it came from.
all_paths = [path for members in groupedimages.values() for path in members]
all_labels = torch.tensor(
    [group_to_idx[name] for name, members in groupedimages.items() for _ in members],
    dtype=torch.long,
)
print("Total training images:", len(all_paths))

In [None]:
# Embed every training image, cast to float32, then L2-normalise each row
# (same normalization as notebook).
with torch.no_grad():
    img_feats = F.normalize(
        embed_images(all_paths).to(torch.float32),  # expected shape (N, 1024)
        dim=-1,
    )


print("Embedding tensor:", img_feats.shape)

In [None]:
num_classes = len(groups)
embed_dim = img_feats.shape[1]

print("num classes:", num_classes)
print("embedding dim:", embed_dim)

# Fix the RNG so the classifier's initial weights and the shuffle order repeat
# across Restart & Run All.
torch.manual_seed(0)

# Move features and labels to `device` once; batches drawn from the dataset
# are then already on the right device.
train_ds = TensorDataset(img_feats.to(device), all_labels.to(device))
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)

# basic classifier: a single linear layer over the precomputed image embeddings
classifier = nn.Linear(embed_dim, num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3, weight_decay=1e-4)
epochs = 100
log_every = 50  # report every `log_every` epochs, plus the very first one

for epoch in range(epochs):
    classifier.train()
    total_loss = 0.0

    for feats_batch, labels_batch in train_loader:
        logits = classifier(feats_batch)
        loss = criterion(logits, labels_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # `loss` is a batch mean; weight it by the batch size so the epoch
        # average below is per sample even when the last batch is smaller.
        total_loss += loss.item() * labels_batch.size(0)

    avg_loss = total_loss / len(train_ds)

    # monitor training accuracy (computed on the training set itself)
    # `epoch` is 0-based, so the first epoch is `epoch == 0`.
    if epoch == 0 or (epoch + 1) % log_every == 0:
        classifier.eval()
        with torch.no_grad():
            logits_all = classifier(img_feats.to(device))
            preds = logits_all.argmax(dim=-1).cpu()
            acc = (preds == all_labels.cpu()).float().mean().item()
        print(f"Epoch {epoch+1:3d}/{epochs} | loss = {avg_loss:.4f} | train acc = {acc:.3f}")

print("Training finished.")

In [10]:
# Inverse of `group_to_idx`: class index -> group name.
idx_to_group = {idx: name for name, idx in group_to_idx.items()}
# One summary row is appended here for every call to `classifyNewImage`.
classificationResults = []

def classifyNewImage(img_path, topk=None):
    """Classify a single image file with the trained linear classifier.

    The image is embedded with `embed_images`, cast to float32, L2-normalised,
    and passed through `classifier`; a softmax turns the logits into
    per-class probabilities.

    Side effects:
        Appends a row (file name, path, predicted group, confidence in
        percent) to the module-level `classificationResults` list and prints
        the prediction plus the full probability breakdown.

    Args:
        img_path: Path of the image to classify.
        topk: When not None, also return the `topk` most probable groups.

    Returns:
        ``(pred_group, probs)`` when `topk` is None, otherwise
        ``(pred_group, probs, topk_labels, topk_probs)``. `probs` is a NumPy
        array indexed by class index.
    """
    classifier.eval()
    with torch.no_grad():
        embedding = F.normalize(embed_images([img_path]).to(torch.float32), dim=-1)
        # Send the embedding to wherever the classifier's weights live.
        clf_device = next(classifier.parameters()).device
        probs = classifier(embedding.to(clf_device)).softmax(dim=-1)[0].cpu().numpy()

    pred_idx = probs.argmax()
    pred_group = idx_to_group[pred_idx]
    pred_conf_pct = float(probs[pred_idx]) * 100.0

    # Row for the summary table built from `classificationResults`.
    classificationResults.append({
        "image": os.path.basename(img_path),
        "path": img_path,
        "pred_group": pred_group,
        "confidence_pct": pred_conf_pct,
    })

    print(f"Image: {img_path}")
    print(f"Predicted group: {pred_group} ({pred_conf_pct:.1f}% confidence)")
    print("Class probabilities:")
    for class_idx, group_name in enumerate(groups):
        print(f"  {group_name:>12}: {probs[class_idx]*100:.1f}%")

    if topk is None:
        return pred_group, probs

    # Indices of the `topk` largest probabilities, most probable first.
    ranked = probs.argsort()[::-1][:topk]
    return pred_group, probs, [idx_to_group[j] for j in ranked], probs[ranked]

In [None]:
# Classify the alligator image; the filename suggests this is the unmodified original (TODO confirm).
classifyNewImage('PCA_images/alligator2_original_.jpg')

In [None]:
# Classify an alligator variant; the "1_8" suffix presumably encodes an edit setting (TODO confirm).
classifyNewImage('PCA_images/alligator2_1_8_.jpg')

In [None]:
# Classify the beetle image; the filename suggests this is the unmodified original (TODO confirm).
classifyNewImage('PCA_images/beetle_original.jpg')

In [None]:
# Classify a beetle variant; the "0_10" suffix presumably encodes an edit setting (TODO confirm).
classifyNewImage('PCA_images/beetle_0_10.jpg')

In [None]:
# Classify a beetle variant; the "10_15" suffix presumably encodes an edit setting (TODO confirm).
classifyNewImage('PCA_images/beetle_10_15.jpg')

In [None]:
# Classify the elephant image; the filename suggests this is the unmodified original (TODO confirm).
classifyNewImage('PCA_images/elephant7_original_.jpg')

In [None]:
# Classify an elephant variant; the "1_5" suffix presumably encodes an edit setting (TODO confirm).
classifyNewImage('PCA_images/elephant_1_5_.jpg')

In [None]:
# Classify the hawk image; the filename suggests this is the unmodified original (TODO confirm).
classifyNewImage('PCA_images/hawk5_original_.jpg')

In [None]:
# Classify a hawk variant; the "1_5" suffix presumably encodes an edit setting (TODO confirm).
classifyNewImage('PCA_images/hawk5_1_5_.jpg')

In [20]:
# Gather every row logged by classifyNewImage into one table; the bare name
# on the last line makes the notebook render it.
results = pd.DataFrame(data=classificationResults)
results

Unnamed: 0,image,path,pred_group,confidence_pct
0,alligator2_original_.jpg,PCA_images/alligator2_original_.jpg,reptiles,85.019684
1,alligator2_1_8_.jpg,PCA_images/alligator2_1_8_.jpg,reptiles,68.069148
2,beetle_original.jpg,PCA_images/beetle_original.jpg,bugs,95.490265
3,beetle_0_10.jpg,PCA_images/beetle_0_10.jpg,bugs,84.802097
4,beetle_10_15.jpg,PCA_images/beetle_10_15.jpg,bugs,61.293793
5,elephant7_original_.jpg,PCA_images/elephant7_original_.jpg,mammals,95.410377
6,elephant_1_5_.jpg,PCA_images/elephant_1_5_.jpg,mammals,93.283194
7,hawk5_original_.jpg,PCA_images/hawk5_original_.jpg,birds,95.783722
8,hawk5_1_5_.jpg,PCA_images/hawk5_1_5_.jpg,birds,90.796012
