**Assignment 2 - Garbage Classifier**

In [None]:
# Standard library / numerical stack
import os
import random
import numpy as np
import torch
import torch.nn as nn  # was misspelled as "toch.nn", which fails at import time
import pandas as pd
from torch.utils.data import DataLoader

# Imaging, plotting, progress bars and evaluation metrics
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
# The scikit-learn metrics module is "sklearn.metrics" (plural).
from sklearn.metrics import confusion_matrix, classification_report

# Project-local modules
from config import OUT_DIR, CLASS_NAMES, TRAIN_DIR, VAL_DIR, TEST_DIR, set_seed
from preprocessor import transform, build_vocab_from_dirs, ImageTextGarbageDataset
from model import EfficientNetV2MMultimodalClassifier

In [None]:
# Seed through the project's helper so that repeated runs are comparable.
set_seed(42)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)
print("Classes", CLASS_NAMES)

In [None]:
# The vocabulary is built from the training and validation directories only;
# the test directory is not passed to the builder.
VOCAB = build_vocab_from_dirs(
    [TRAIN_DIR, VAL_DIR],
    CLASS_NAMES,
)
VOCAB_SIZE = len(VOCAB)
print("Vocabulary size:", VOCAB_SIZE)

In [None]:
# Test split, using the "test" entry of the project's transform mapping.
test_dataset = ImageTextGarbageDataset(
    TEST_DIR,
    transform["test"],
    VOCAB,
    CLASS_NAMES,
)

# shuffle=False keeps the sample order fixed across runs; memory is pinned
# only when the target device is CUDA.
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    pin_memory=device.type == "cuda",
)
print("Test dataset size:", len(test_dataset))

In [None]:
# Restore the trained classifier from the checkpoint stored under OUT_DIR.
MODEL_PATH = os.path.join(OUT_DIR, "best_model.pth")
print("Checkpoint exists:", os.path.exists(MODEL_PATH))

# Instantiate the network with the vocabulary size and class count used in
# this notebook, then move it to the selected device.
model = EfficientNetV2MMultimodalClassifier(vocab_size=VOCAB_SIZE, num_classes=len(CLASS_NAMES))
model = model.to(device)

# map_location remaps the stored tensors onto the active device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; consider weights_only=True if the installed PyTorch supports it.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))

# Switch to evaluation mode before running inference.
model.eval()

In [None]:
# Run the model over the whole test loader, collecting per-sample results.
# The four lists stay index-aligned: entry i of each refers to the same sample.
all_preds, all_labels, all_paths, all_texts = [], [], [], []

with torch.no_grad():  # gradients are not needed for inference
    for batch in tqdm(test_loader):
        images = batch["image"].to(device)
        text_vec = batch["text_vec"].to(device)
        labels = batch["label"].to(device)

        # Predicted class = index of the largest score along dim 1.
        outputs = model(images, text_vec)
        predicted = torch.argmax(outputs, dim=1)

        # Bring tensors back to the CPU before converting to NumPy.
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        # Paths and raw text are carried along for the later figure and CSV.
        all_paths.extend(batch["path"])
        all_texts.extend(batch["text"])

In [None]:
#Print overall accuracy and classification report
# Accuracy is the percentage of samples whose prediction equals the label.
accuracy = 100 * (np.array(all_preds) == np.array(all_labels)).mean()
print(f"\nAccuracy on test set: {accuracy:.2f}%\n")

print("Classification Report:")
# Labels are integer indices into CLASS_NAMES. Passing the full index range
# explicitly keeps the report rows aligned with target_names even when some
# class appears in neither the test labels nor the predictions; without it
# scikit-learn infers the label set from the data and the name count can
# mismatch.
print(classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(CLASS_NAMES))),
    target_names=CLASS_NAMES,
))

In [None]:
# Confusion matrix over the test set, saved to OUT_DIR and displayed.
# Labels are integer indices into CLASS_NAMES. Fixing the label set makes the
# matrix always len(CLASS_NAMES) x len(CLASS_NAMES), so the tick labels below
# line up with its rows/columns even if a class is absent from both the test
# labels and the predictions.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(CLASS_NAMES))))

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
plt.title("Confusion Matrix (Test)")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.tight_layout()

# Make sure the output directory exists before writing the figure.
os.makedirs(OUT_DIR, exist_ok=True)
plt.savefig(os.path.join(OUT_DIR, "confusion_matrix.png"), dpi=200, bbox_inches="tight")
plt.show()

In [None]:
#Misclassified examples figure
# Per-channel statistics used to undo normalisation for display.
# NOTE(review): assumes transform["test"] normalises with these same values;
# confirm against the preprocessor module.
mean = np.array([0.485, 0.456, 0.406])
std  = np.array([0.229, 0.224, 0.225])

# Maximum number of misclassified samples drawn per true class (one grid row).
EXAMPLES_PER_CLASS = 3


def _load_display_image(path):
    """Return the image at *path* as a channels-last array clipped to [0, 1].

    The image is passed through transform["test"] and then de-normalised with
    the module-level ``mean``/``std`` so it can be handed to ``plt.imshow``.
    """
    # The context manager closes the underlying file as soon as the pixel
    # data has been copied out by convert().
    with Image.open(path) as raw:
        rgb = raw.convert("RGB")
    channels_first = transform["test"](rgb).cpu().numpy()
    channels_last = channels_first.transpose(1, 2, 0)
    return np.clip((channels_last * std) + mean, 0, 1)


# Group lightweight records of the wrong predictions by their true class.
# Images are deliberately NOT loaded here: only the few sampled examples are
# decoded later, instead of every misclassified image in the test set.
misclassified = {name: [] for name in CLASS_NAMES}
for y, p, path, text in zip(all_labels, all_preds, all_paths, all_texts):
    if y != p:
        true_name = CLASS_NAMES[y]
        misclassified[true_name].append({
            "path": path,
            "true": true_name,
            "pred": CLASS_NAMES[p],
            "text": text,
        })

plt.figure(figsize=(15, 12))
rows = len(CLASS_NAMES)
for row, cname in enumerate(CLASS_NAMES):
    examples = misclassified[cname]
    if not examples:
        # No mistakes for this class: its grid row is left empty.
        continue
    selected = random.sample(examples, min(EXAMPLES_PER_CLASS, len(examples)))
    for col, ex in enumerate(selected):
        # Subplot indices are 1-based and run row by row.
        plt.subplot(rows, EXAMPLES_PER_CLASS, row * EXAMPLES_PER_CLASS + col + 1)
        plt.imshow(_load_display_image(ex["path"]))
        plt.title(f"True: {ex['true']}\nPred: {ex['pred']}\n{ex['text'][:20]}")
        plt.axis("off")

plt.tight_layout()
# Make sure the output directory exists before writing the figure.
os.makedirs(OUT_DIR, exist_ok=True)
plt.savefig(os.path.join(OUT_DIR, "misclassified_examples.png"), dpi=200, bbox_inches="tight")
plt.show()

In [None]:
# Write one row per test sample: file path, raw text, true and predicted
# class names (label indices are mapped through CLASS_NAMES).
df = pd.DataFrame({"path": all_paths, "text": all_texts})
df["true"] = [CLASS_NAMES[idx] for idx in all_labels]
df["pred"] = [CLASS_NAMES[idx] for idx in all_preds]

csv_path = os.path.join(OUT_DIR, "test_predictions.csv")
df.to_csv(csv_path, index=False)  # index=False: no row-number column in the file
print("Saved predictions CSV:", csv_path)