In [None]:
import os

import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score, classification_report
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from torchvision import datasets
from torchvision import models, transforms, datasets
from tqdm import tqdm
from xgboost import XGBClassifier


In [None]:
# Root folder of the Kaggle big-cats dataset; it is handed to datasets.ImageFolder,
# which treats each sub-folder as one class.
dir_ = (
    "/kaggle/input/"
    "big-cats-image-classification-dataset/"
    "animals"
)

In [None]:
# Resize to a fixed 224x224 input (aspect ratio is not preserved), convert to a
# [0, 1] float tensor, then standardise each channel with the ImageNet statistics
# that the ImageNet-pretrained backbones loaded later expect.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Full labelled dataset: labels are derived from the sub-folder names under dir_,
# and every image is passed through `transform` when loaded.
data = datasets.ImageFolder(root=dir_, transform=transform)

In [None]:
# 80/20 random split over positions in the full dataset; the fixed seed keeps the
# split reproducible across kernel restarts.
n_samples = len(data)
indices = list(range(n_samples))

train_idx, test_idx = train_test_split(indices, test_size=0.2, shuffle=True, random_state=42)

# Both splits are views onto `data`, so they share its transform and labels.
train_dataset, test_dataset = (Subset(data, split) for split in (train_idx, test_idx))

In [None]:
# `train_data` does not exist, and the Subset splits do not expose `.classes`,
# so read the class list from the underlying ImageFolder `data`.
# Label i in y_train / y_test corresponds to class_names[i].
class_names = data.classes
print(class_names)

In [None]:
# Both loaders only feed feature extraction, so order is kept fixed (shuffle=False):
# extracted rows then follow train_idx / test_idx order exactly.
_loader_opts = {"batch_size": 32, "shuffle": False, "num_workers": 2}
train_loader = DataLoader(train_dataset, **_loader_opts)
test_loader = DataLoader(test_dataset, **_loader_opts)

In [None]:
# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ImageNet-pretrained ResNet-18 with its last child module (the classification
# head) dropped, so the output is the pooled backbone activation.
# NOTE(review): the feature-extraction cells in this notebook use `mobilenet`, not this.
resnet = nn.Sequential(
    *list(models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).children())[:-1]
)
resnet = resnet.to(device)
resnet.eval()

In [None]:
# ImageNet-pretrained MobileNetV3-Large used as a frozen feature extractor:
# swapping the classifier head for Identity makes forward() return whatever was
# previously fed into that head instead of class logits.
mobilenet = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.IMAGENET1K_V2)
mobilenet.classifier = nn.Identity()
mobilenet = mobilenet.to(device)
mobilenet.eval()

In [None]:
def extract_features(dataloader, model, device):
    """Run `model` over every batch of `dataloader` and stack the outputs row-wise.

    Args:
        dataloader: iterable of ``(images, labels)`` batches, where both are torch
            tensors with the batch on the first axis (e.g. a torch DataLoader).
        model: torch module (or callable) applied to each image batch. Callers are
            expected to have moved it to `device` and put it in eval mode; this
            function only disables gradient tracking.
        device: device each image batch is moved to before the forward pass.

    Returns:
        ``(features, labels)``:
        - ``features`` is a 2-D array with one row per sample; each row is that
          sample's model output flattened.
        - ``labels`` is a 1-D array aligned with those rows.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    feature_batches = []
    label_batches = []

    with torch.no_grad():
        for imgs, batch_labels in tqdm(dataloader):
            outputs = model(imgs.to(device))
            # Keep the batch axis so each sample becomes one row. Flattening the whole
            # batch would merge all samples into a single 1-D vector. reshape (unlike
            # view) also works on non-contiguous outputs.
            outputs = outputs.reshape(outputs.size(0), -1)
            feature_batches.append(outputs.cpu().numpy())
            label_batches.append(batch_labels.cpu().numpy())

    if not feature_batches:
        raise ValueError("dataloader yielded no batches; nothing to extract")

    return np.concatenate(feature_batches), np.concatenate(label_batches)

In [None]:
# Push both splits through the MobileNet feature extractor (train first, then test);
# each call returns a feature matrix plus the labels aligned with its rows.
(train_features, y_train), (test_features, y_test) = (
    extract_features(loader, mobilenet, device) for loader in (train_loader, test_loader)
)

In [None]:
from sklearn.model_selection import GridSearchCV

def _show_samples(images, indices, make_title, limit=5):
    """Show up to `limit` of `images[i]` for i in `indices`, one small figure each.

    `make_title(i)` builds the title shown above sample `i`.
    """
    for i in indices[:limit]:
        fig, ax = plt.subplots(figsize=(3, 3))
        ax.imshow(images[i])
        ax.set_title(make_title(i))
        ax.axis("off")
        plt.show()
        # Free the figure explicitly so a long run does not accumulate open figures.
        plt.close(fig)


def _plot_confusion_matrix(y_true, y_pred, class_names, title):
    """Plot a confusion matrix whose rows and columns are labelled with `class_names`.

    Uses integer labels 0..len(class_names)-1, the same convention used when
    `class_names` is indexed by label. Returns the sklearn display object.
    """
    display = ConfusionMatrixDisplay.from_predictions(
        y_true,
        y_pred,
        labels=list(range(len(class_names))),
        display_labels=class_names,
        xticks_rotation=45,
    )
    display.ax_.set_title(title)
    plt.show()
    return display


def train_and_evaluate(
    X_train_feat, X_test_feat, y_train, y_test,
    class_names, name="", use_gridsearch=False, X_test_raw=None
):
    """Fit a RandomForest on pre-extracted features and report test-set results.

    Args:
        X_train_feat: training feature matrix accepted by ``RandomForestClassifier.fit``.
        X_test_feat: test feature matrix with the same columns as `X_train_feat`.
        y_train: training labels.
        y_test: test labels; integer class indices used to index `class_names`.
        class_names: sequence mapping label index -> human-readable class name.
        name: tag used in printed headers, the plot title and the saved file name.
        use_gridsearch: if True, tune the forest with a 3-fold GridSearchCV;
            otherwise fit a fixed 300-tree, depth-30 forest.
        X_test_raw: optional images aligned with the rows of `X_test_feat`. When
            given, up to five correct and five misclassified samples are displayed;
            each item must be something ``matplotlib.pyplot.imshow`` accepts.

    Returns:
        Test-set accuracy as computed by ``accuracy_score``.

    Side effects:
        Prints progress and metrics, saves the fitted model with ``joblib.dump`` to
        ``f"{name}_rf_model.pkl"`` in the working directory, and shows figures.
    """
    if use_gridsearch:
        # 5 * 7 * 4 = 140 candidates, each fitted on 3 folds (420 forest fits).
        param_grid = {
            "n_estimators": [100, 150, 200, 250, 300],
            "max_depth": [10, 15, 20, 25, 30, 35, 40],
            "min_samples_split": [2, 3, 4, 5],
        }

        print(f"\nRunning GridSearchCV for {name}...")
        grid = GridSearchCV(
            estimator=RandomForestClassifier(random_state=168),
            param_grid=param_grid,
            scoring="accuracy",
            cv=3,
            n_jobs=-1,
            verbose=1,
        )
        grid.fit(X_train_feat, y_train)
        # With the default refit=True the best candidate is refitted on all of X_train_feat.
        clf = grid.best_estimator_

        print("\nBest Params:", grid.best_params_)
        print("Best CV Score:", grid.best_score_)
    else:
        clf = RandomForestClassifier(n_estimators=300, max_depth=30, random_state=168)
        clf.fit(X_train_feat, y_train)

    pred = clf.predict(X_test_feat)

    mis_idx = np.where(pred != y_test)[0]
    correct_idx = np.where(pred == y_test)[0]

    print(f"\nTotal misclassified samples: {len(mis_idx)}")

    print("\n=== CORRECTLY predicted samples ===")
    print(correct_idx[:5])
    for i in correct_idx[:5]:
        print(f"Index: {i}, Label: {class_names[y_test[i]]}")

    print("\n=== MISCLASSIFIED samples ===")
    print(mis_idx[:5])
    for i in mis_idx[:5]:
        print(f"Index: {i}, True: {class_names[y_test[i]]}, Pred: {class_names[pred[i]]}")

    model_path = f"{name}_rf_model.pkl"
    joblib.dump(clf, model_path)
    print(f"\nSaved model to {model_path}")

    acc = accuracy_score(y_test, pred)
    print(f"\nAccuracy ({name}): {acc:.4f}")
    print(f"\nClassification Report ({name}):\n", classification_report(y_test, pred))

    if X_test_raw is not None:
        print("\nShowing correct samples:")
        _show_samples(
            X_test_raw,
            correct_idx,
            lambda i: f"CORRECT — Label: {class_names[y_test[i]]}",
        )

        print("\nShowing misclassified samples:")
        _show_samples(
            X_test_raw,
            mis_idx,
            lambda i: f"WRONG — True: {class_names[y_test[i]]}, Pred: {class_names[pred[i]]}",
        )

    _plot_confusion_matrix(y_test, pred, class_names, f"Confusion Matrix – {name}")
    return acc

In [None]:
# Take the class list from the dataset itself rather than a hand-maintained copy,
# so label index i in y_train / y_test always maps to class_names[i].
class_names = data.classes

In [None]:
# Tune and evaluate a RandomForest on the MobileNet features (no raw images shown).
train_and_evaluate(
    X_train_feat=train_features,
    X_test_feat=test_features,
    y_train=y_train,
    y_test=y_test,
    class_names=class_names,
    name="",
    use_gridsearch=True,
)