# Imports

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm

import torch
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split, Subset, ConcatDataset
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary

!pip install optuna
import optuna

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Global Variables

In [None]:
# --- Reproducibility ---
# One seed for the global numpy and torch RNGs (EDA sampling, weight init,
# DataLoader shuffling, augmentation) so a fresh Restart & Run All is repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Dataset variables ---
BATCH_SIZE = 64
NUM_WORKERS = 4
#ROOT = '/content/drive/My Drive/Colab Notebooks/156 Project/Rice_Image_Dataset'
ROOT = '/content/drive/My Drive/156 Project/Rice_Image_Dataset' #this is the root for Farrel

# Untransformed ImageFolder: items are (PIL image, class index). Used for EDA,
# splitting, and class names.
FULL_DATASET = ImageFolder(root=ROOT)
NUM_CLASSES = len(FULL_DATASET.classes)

# --- DEVICE ---
# Prefer Apple MPS, then CUDA, otherwise fall back to CPU.
DEVICE = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print("Using device:", DEVICE)


# --- Train and Validation Set Proportions ---
P_TRAIN = 0.6   # fraction of FULL_DATASET used for training
P_VAL = 0.2     # fraction used for validation; the test split gets the remainder
P_TUNE = 0.3    # fraction of the training split used during hyperparameter tuning
assert P_TRAIN > 0 and P_VAL > 0 and P_TRAIN + P_VAL < 1, "P_TRAIN + P_VAL must leave data for a test split"
assert 0 < P_TUNE <= 1, "P_TUNE must be in (0, 1]"

# --- Training Parameters ---
CRITERION = nn.CrossEntropyLoss()
LEARNING_RATE = 0.001
NUM_EPOCHS = 20
HIDDEN_SIZE = 32    # channels of the first conv layer (the CNN doubles it for the second)
OPTIMIZER = optim.Adam

# --- Validation Parameters ---
NUM_TRIALS = 20     # number of Optuna trials

# Exploratory Data Analysis

In this dataset, each of the five classes contributes exactly 15,000 images (75,000 in total), so the classes are perfectly balanced. Image sizes are also consistent: every image is 250x250 pixels.

In [None]:
# Show SAMPLES_PER_CLASS random images from every class: one row per class.
SAMPLES_PER_CLASS = 5
targets = np.asarray(FULL_DATASET.targets)

# For each class, draw distinct random indices among the images carrying that label.
# Deriving them from the labels avoids assuming every class is a contiguous block
# of exactly 15,000 images.
idx = np.stack([
    np.random.choice(np.flatnonzero(targets == cls), SAMPLES_PER_CLASS, replace=False)
    for cls in range(NUM_CLASSES)
])

# Create a NUM_CLASSES x SAMPLES_PER_CLASS grid of subplots
fig, ax = plt.subplots(NUM_CLASSES, SAMPLES_PER_CLASS, figsize=(15, 20),
                       sharex=True, sharey=True, squeeze=False)
title_col = SAMPLES_PER_CLASS // 2  # middle column carries the class title

for r, row in enumerate(idx):
    for c, i in enumerate(row):
        im, lbl = FULL_DATASET[i]

        ax[r, c].imshow(im)

        # Label the first column and title the middle column
        if c == 0:
            ax[r, c].set_ylabel(f'{FULL_DATASET.classes[lbl]}')
        if c == title_col:
            ax[r, c].set_title(f'{FULL_DATASET.classes[lbl]}')

# Preprocessing

We compute the per-channel mean and standard deviation over the entire dataset and use these values to normalize the inputs. These are low-level, label-agnostic statistics that encode no target information. Estimating them on the full dataset, including the validation and test images, therefore introduces at most a negligible amount of leakage. In practice, using the full dataset also avoids recomputing the statistics on the training subset alone, while still giving stable, representative normalization parameters.

In [None]:
# Per-channel normalization statistics for the transforms below.
# The stored values were computed once by an earlier version of this cell. That version
# averaged per-image standard deviations, which typically understates the dataset-wide std
# slightly. Set RECOMPUTE_STATS = True to recompute the true statistics; this takes one
# full pass over the dataset.
RECOMPUTE_STATS = False

if RECOMPUTE_STATS:
    # Same geometry as the training pipeline, but no augmentation and no normalization.
    pre_transform = transforms.Compose([
        transforms.Resize((250, 250)),  # Resize all images to 250x250
        transforms.ToTensor(),          # Convert images to PyTorch tensors (0-1 range)
    ])
    pre_dataset = ImageFolder(root=ROOT, transform=pre_transform)
    pre_loader = DataLoader(pre_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

    # Accumulate per-channel sums and squared sums over every pixel of every image,
    # in float64 for numerical stability. Then std = sqrt(E[x^2] - E[x]^2).
    channel_sum = torch.zeros(3, dtype=torch.float64)
    channel_sq_sum = torch.zeros(3, dtype=torch.float64)
    n_pixels = 0
    for d, _ in pre_loader:  # d: (batch, channels, H, W) from ToTensor
        d = d.to(torch.float64)
        channel_sum += d.sum(dim=(0, 2, 3))
        channel_sq_sum += (d ** 2).sum(dim=(0, 2, 3))
        n_pixels += d.size(0) * d.size(2) * d.size(3)

    mean_t = channel_sum / n_pixels
    std_t = (channel_sq_sum / n_pixels - mean_t ** 2).clamp_min(0).sqrt()
    mean = mean_t.tolist()
    std = std_t.tolist()

    # Print per-channel mean and standard deviation
    print(f"mean:{mean}")
    print(f'std: {std}')
else:
    mean = [0.1179, 0.1189, 0.1229]
    std = [0.2851, 0.2875, 0.2989]

## Transforms

In [None]:
# Deterministic steps shared by both pipelines: fixed 250x250 geometry and
# dataset-level normalization with the `mean` / `std` computed above.
_resize_250 = transforms.Resize((250, 250))
_normalize = transforms.Normalize(mean=mean, std=std)

# Training pipeline: light geometric + photometric augmentation before tensor conversion.
train_transform = transforms.Compose([
    _resize_250,
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    _normalize,
])

# Evaluation pipeline: no random steps, so validation/test inputs are deterministic.
val_test_transform = transforms.Compose([_resize_250, transforms.ToTensor(), _normalize])

## Train, Validation, and Test Set Division

In [None]:
# Split sizes; the test split absorbs any rounding remainder.
total_size = len(FULL_DATASET)
train_size, val_size = (int(p * total_size) for p in (P_TRAIN, P_VAL))
test_size = total_size - (train_size + val_size)

# Seeded random split so the same images land in the same split on every run.
split_generator = torch.Generator().manual_seed(42)
train_sub, val_sub, test_sub = random_split(
    FULL_DATASET, [train_size, val_size, test_size], generator=split_generator
)

# random_split only gives index lists into FULL_DATASET, which has no transform.
# Re-wrap those indices around ImageFolder copies that carry the right transform.
train_idx, val_idx, test_idx = train_sub.indices, val_sub.indices, test_sub.indices

train_set = Subset(ImageFolder(ROOT, transform=train_transform), train_idx)
val_set = Subset(ImageFolder(ROOT, transform=val_test_transform), val_idx)
test_set = Subset(ImageFolder(ROOT, transform=val_test_transform), test_idx)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader, val_loader, test_loader = (
    DataLoader(split, BATCH_SIZE, shuffle=do_shuffle, num_workers=NUM_WORKERS)
    for split, do_shuffle in ((train_set, True), (val_set, False), (test_set, False))
)

In [None]:
from sklearn.decomposition import PCA
import pandas as pd
import seaborn as sns

MAX_PCA_BATCHES = 6  # only the first few validation batches, to keep PCA fast

images_flat = []  # one flattened pixel vector per image
labels_list = []  # matching integer class labels

print("Gathering data for PCA...")

# zip() stops as soon as the batch budget is used up.
for _, (batch_imgs, batch_lbls) in zip(range(MAX_PCA_BATCHES), val_loader):
    images_flat.extend(batch_imgs.reshape(batch_imgs.size(0), -1).numpy())
    labels_list.extend(batch_lbls.numpy())

# Project the flattened (normalized) pixel vectors onto their first two principal components.
pca = PCA(n_components=2)
pca_result = pca.fit_transform(images_flat)

# Tabulate the projection with human-readable class names for seaborn.
df_pca = pd.DataFrame(pca_result, columns=['PC1', 'PC2'])
df_pca['Rice Type'] = [FULL_DATASET.classes[lbl] for lbl in labels_list]

plt.figure(figsize=(10, 8))
sns.scatterplot(data=df_pca, x='PC1', y='PC2', hue='Rice Type', palette='tab10', alpha=0.7)
plt.title('PCA of Raw Rice Images (2D Projection)')
plt.show()

# Creating the CNN

In [None]:
class CNN(nn.Module):
    """Small two-stage convolutional classifier.

    Feature extractor: two [Conv2d(5x5) -> ReLU -> MaxPool(2)] stages with
    ``hidden_size`` and ``2 * hidden_size`` output channels.
    Classifier: global average pooling -> flatten -> linear layer to ``num_classes`` logits.

    Module indices (and therefore state_dict keys such as ``features.0.weight`` or
    ``classifier.2.weight``) are laid out so saved checkpoints stay loadable.
    """

    def __init__(self, num_classes, hidden_size):
        super().__init__()

        # Build both conv stages in a loop; the second stage doubles the channel count.
        conv_layers = []
        in_channels = 3
        for out_channels in (hidden_size, 2 * hidden_size):
            conv_layers += [
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=5),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
            in_channels = out_channels
        self.features = nn.Sequential(*conv_layers)

        # Global pooling makes the head independent of the input's spatial size.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(in_channels, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.classifier(self.features(x))

## Summary

In [None]:
# Instantiate the CNN (parameters start on the CPU).
cnn = CNN(NUM_CLASSES, HIDDEN_SIZE)

# Print a summary of the model (layers, output shapes, params).
# torchsummary only understands device="cuda"/"cpu" and defaults to "cuda". A model
# already moved to MPS would therefore receive CPU inputs and fail. Summarizing the CPU
# copy works for every DEVICE choice in the config cell.
summary(cnn, (3, 250, 250), device="cpu")

# Move to the training device, then set up the optimizer for training.
cnn = cnn.to(DEVICE)
optimizer = OPTIMIZER(cnn.parameters(), lr=LEARNING_RATE)

## Evaluate Function

In [None]:
def evaluate(model, loader, device, leave=True, criterion=None):
    """Compute the per-sample average loss and the accuracy of ``model`` on ``loader``.

    Args:
        model: classifier returning one logit vector per input image.
        loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to (should match the model's device).
        leave: whether the tqdm progress bar stays visible after finishing.
        criterion: loss function. Defaults to the notebook-wide ``CRITERION``. The
            batch loss is re-weighted by the batch size, so a mean-reduced loss
            (as with the default ``nn.CrossEntropyLoss()``) is assumed.

    Returns:
        ``(avg_loss, accuracy)``: loss averaged over samples, and the fraction of
        samples whose arg-max prediction equals the label.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if criterion is None:
        criterion = CRITERION

    model.eval()  # Set model to evaluation mode
    total_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():  # Disable gradient computation
        for images, labels in tqdm(loader, desc='Evaluation: ', leave=leave):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)

            # Undo the per-batch mean so the final average is per sample.
            total_loss += criterion(outputs, labels).item() * batch_size

            # Get predicted class and count correct predictions
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("evaluate() received an empty loader; cannot compute loss/accuracy.")

    avg_loss = total_loss / total
    accuracy = correct / total
    return avg_loss, accuracy

# Training the CNN

In [None]:
# Per-epoch history, consumed by the plotting cells further down.
train_loss_list, val_loss_list, val_accuracy_list = [], [], []

cnn = cnn.to(DEVICE)

for epoch in range(NUM_EPOCHS):
    epoch_tag = f'Epoch [{epoch+1}/{NUM_EPOCHS}]'
    cnn.train()  # Set model to training mode
    loss_sum = 0.0

    progress = tqdm(train_loader, desc=epoch_tag, leave=True)
    for images, labels in progress:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        # Standard step: clear grads, forward, loss, backprop, update.
        optimizer.zero_grad()
        batch_loss = CRITERION(cnn(images), labels)
        batch_loss.backward()
        optimizer.step()

        # CRITERION averages within the batch; re-weight to accumulate a per-sample sum.
        loss_sum += batch_loss.item() * images.size(0)
        progress.set_postfix(batch_loss=batch_loss.item())

    # Epoch-level metrics: mean training loss, then a full validation pass.
    train_loss = loss_sum / len(train_loader.dataset)
    val_loss, val_acc = evaluate(cnn, val_loader, DEVICE)

    train_loss_list.append(train_loss)
    val_loss_list.append(val_loss)
    val_accuracy_list.append(val_acc)

    print(f'{epoch_tag}: Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}')


In [None]:
# Save the weights of the CNN trained in the loop above to Google Drive.
# `final_model` is only created later, in the retraining section. Referencing it here
# raises a NameError on a fresh top-to-bottom run.
drive_path = '/content/drive/My Drive/156 Project/rice_cnn_best_model.pth'
torch.save(cnn.state_dict(), drive_path)
print(f"Model saved permanently to {drive_path}")

In [None]:
# Download the saved model file to your local machine
from google.colab import files
files.download('rice_cnn_best_model.pth')

In [None]:
# Instantiate the same CNN architecture and move to device
loaded_model = CNN(num_classes=NUM_CLASSES, hidden_size=32)
loaded_model = loaded_model.to(DEVICE)

# Path to the saved model weights
load_path = '/content/drive/My Drive/156 Project/rice_cnn_best_model.pth'

# Load the saved weights into the model
loaded_model.load_state_dict(torch.load(load_path, map_location=DEVICE))

# Set model to evaluation mode
loaded_model.eval()

print("Model loaded successfully!")

In [None]:
# Training history recorded from an earlier 20-epoch run, so the curves can be
# plotted without retraining.
# NOTE: this replaces whatever the training loop above stored in these lists.
train_loss_list = [
    0.2678, 0.2573, 0.2396, 0.2255, 0.2166,
    0.1958, 0.1873, 0.1752, 0.1631, 0.1517,
    0.1550, 0.1347, 0.1353, 0.1287, 0.1206,
    0.1212, 0.1113, 0.1220, 0.1076, 0.1004
]

val_loss_list = [
    0.4059, 0.1638, 0.1423, 0.1969, 0.0901,
    0.0941, 0.1325, 0.1939, 0.0527, 0.0917,
    0.0432, 0.0497, 0.0899, 0.0448, 0.0713,
    0.1269, 0.0615, 0.0663, 0.0371, 0.0616
]

val_accuracy_list = [
    0.8170, 0.9378, 0.9503, 0.9231, 0.9746,
    0.9795, 0.9497, 0.9334, 0.9860, 0.9691,
    0.9880, 0.9860, 0.9726, 0.9910, 0.9829,
    0.9637, 0.9902, 0.9876, 0.9905, 0.9847
]

# The three histories must line up epoch by epoch.
assert len(train_loss_list) == len(val_loss_list) == len(val_accuracy_list)

# 1-based epoch axis derived from the data, instead of re-assigning the global NUM_EPOCHS.
epochs = range(1, len(train_loss_list) + 1)
fig, ax = plt.subplots(1, 2, sharex=True, figsize=(12, 5))

# Plot Training vs Validation Loss
ax[0].plot(epochs, train_loss_list, label="Training Loss")
ax[0].plot(epochs, val_loss_list, label="Validation Loss")
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Loss')
ax[0].set_title('Training vs Validation Loss')
ax[0].legend()
ax[0].grid(True)

# Plot Validation Accuracy
ax[1].plot(epochs, val_accuracy_list, label="Validation Accuracy", color='green')
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Accuracy')
ax[1].set_title('Validation Accuracy')
ax[1].legend()
ax[1].grid(True)

plt.tight_layout()
plt.show()

## Visuals

In [None]:
# Loss and accuracy curves for the histories currently stored in the three lists.
# The epoch axis is 1-based (matching the "Epoch [k/N]" training log) and comes from the
# data, so it always lines up with the lists.
epochs = range(1, len(train_loss_list) + 1)
fig, ax = plt.subplots(1, 2, sharex=True)

# Plot training and validation loss
ax[0].plot(epochs, train_loss_list, label="Training Loss")
ax[0].plot(epochs, val_loss_list, label='Validation Loss')

# Plot validation accuracy
ax[1].plot(epochs, val_accuracy_list, label='Validation Accuracy')

# Set axis labels and titles
ax[0].set_xlabel('Epoch')
ax[1].set_xlabel('Epoch')
ax[0].set_ylabel('Loss')
ax[1].set_ylabel('Accuracy')
ax[0].set_title('Loss')
ax[1].set_title('Accuracy')

# Add legends
ax[0].legend()
ax[1].legend()

# Adjust layout to prevent overlap
fig.tight_layout()

# Print rounded metrics
print(np.round(train_loss_list, 3))
print(np.round(val_loss_list, 3))
print(np.round(val_accuracy_list, 3))

# Validation and Hyperparameter Tuning

In [None]:
# Select a random subset of the training set for tuning 
indices = np.random.choice(len(train_set), int(len(train_set)*P_TUNE), replace=False) 

# Create a Subset and corresponding DataLoader for tuning
tune_train_set = Subset(train_set, indices)
tune_train_loader = DataLoader(tune_train_set, BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

In [None]:
def objective(trial, num_epochs=3):
    """Optuna objective: briefly train a CNN with sampled hyperparameters.

    Samples a learning rate, hidden size and weight decay. Trains on the tuning
    subset (``tune_train_loader``) for ``num_epochs`` epochs. Reports the validation
    loss after every epoch so the pruner can stop unpromising trials early.

    Args:
        trial: the ``optuna.Trial`` supplied by ``study.optimize``.
        num_epochs: training epochs per trial (default 3).

    Returns:
        Validation loss after the final epoch (the value being minimized).

    Raises:
        optuna.TrialPruned: when the pruner decides the trial is unpromising.
        ValueError: if ``num_epochs`` is less than 1.
    """
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1")

    # Suggest hyperparameters for this trial
    lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)
    hidden_size = trial.suggest_categorical('hidden_size', [16, 32, 64])
    weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-3, log=True)

    # Fresh model + optimizer per trial (locals; the notebook's global `cnn` is untouched).
    cnn = CNN(NUM_CLASSES, hidden_size).to(DEVICE)
    optimizer = OPTIMIZER(cnn.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(num_epochs):
        cnn.train()
        for images, labels in tune_train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = CRITERION(cnn(images), labels)
            loss.backward()
            optimizer.step()

        # Evaluate on the validation set and report to Optuna for pruning.
        val_loss, _ = evaluate(cnn, val_loader, DEVICE, leave=False)
        trial.report(val_loss, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

    # No training happens after the last evaluation, so its loss is the final score.
    # Evaluating again would return the same value at the cost of another full pass.
    return val_loss

# Minimize validation loss. The median pruner stops a trial whose intermediate value
# is worse than the median of earlier trials at the same step.
pruner = optuna.pruners.MedianPruner()
# Seeded TPE sampler (Optuna's default algorithm) so the sequence of proposals is reproducible.
sampler = optuna.samplers.TPESampler(seed=42)
study = optuna.create_study(direction='minimize', sampler=sampler, pruner=pruner)
study.optimize(objective, n_trials=NUM_TRIALS, show_progress_bar=True)

# Keep and print the best hyperparameters
best = study.best_trial.params
print("Best trial:", best)

# Testing the CNN

## Retraining the Entire Model

In [None]:
# Best hyperparameters from Optuna
best = {'lr': 0.001, 'hidden_size': 32, 'weight_decay': 1e-5}
best_hidden = best['hidden_size']
best_lr = best['lr']
best_wd = best['weight_decay']

# Initialize final model with best hyperparameters
final_model = CNN(NUM_CLASSES, best_hidden)
final_model = final_model.to(DEVICE)

# Optimizer for final training
final_optimizer = OPTIMIZER(final_model.parameters(), lr = best_lr, weight_decay = best_wd)

# Combine training and validation sets for final training
train_val_set = ConcatDataset([train_set,val_set])
train_val_loader = DataLoader(train_val_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# Train final model on combined dataset
for epoch in range(NUM_EPOCHS):
    final_model.train()
    running_loss = 0.0
    for images, labels in tqdm(train_val_loader, desc=f'Final Training Epoch [{epoch+1}/{NUM_EPOCHS}]'):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        final_optimizer.zero_grad()
        outputs = final_model(images)
        loss = CRITERION(outputs, labels)
        loss.backward()
        final_optimizer.step()
        running_loss += loss.item() * images.size(0)
    epoch_loss = running_loss / len(train_val_loader.dataset)
    print(f"Epoch {epoch+1}, Combined Train Loss: {epoch_loss:.4f}")

## Save the Model


In [None]:
# Save the trained final model's weights
model_save_path = 'rice_cnn_best_model.pth'
torch.save(final_model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

## Testing Results

In [None]:
# Evaluate the final trained model on the test set
test_loss, test_acc = evaluate(final_model, test_loader, DEVICE)
print(f"Final Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

## Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
all_preds = []
all_labels = []

# Inference only: set eval mode explicitly (don't depend on an earlier cell having done it)
# and skip autograd bookkeeping.
final_model.eval()
with torch.no_grad():
    # Collect predictions and true labels from the test set
    for images, labels in test_loader:
        outputs = final_model(images.to(DEVICE))
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())  # labels never left the CPU

# Compute confusion matrix. Pinning `labels` keeps the matrix NUM_CLASSES x NUM_CLASSES,
# matching display_labels, even if some class never appears in the predictions or labels.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(NUM_CLASSES)))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=FULL_DATASET.classes)

# Plot confusion matrix
disp.plot(cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.show()


In [None]:
# Extract the best hidden size from the hyperparameter tuning results
best = {'lr': 0.001, 'hidden_size': 32, 'weight_decay': 1e-5}
best_hidden = best['hidden_size']

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import torch
import numpy as np

# Gather test-set predictions and ground truth for the per-class report.
y_true, y_pred = [], []
final_model.eval()  # Set model to evaluation mode
print("Generating predictions...")

with torch.no_grad():  # inference only, no autograd graph
    for batch_images, batch_labels in test_loader:
        logits = final_model(batch_images.to(DEVICE))
        predicted = torch.max(logits, 1)[1]  # index of the highest logit per image
        y_pred.extend(predicted.cpu().numpy())
        y_true.extend(batch_labels.numpy())
print("Predictions generated successfully.")

# Per-class precision / recall / F1, labelled with the class names.
print("\n--- Classification Report ---")
print(classification_report(y_true, y_pred, target_names=FULL_DATASET.classes))

### The Basmati vs. Jasmine confusion, quantified

*The figures in this section come from the run recorded when this analysis was written; re-running the notebook may change them.*

**Basmati recall (0.80):** this is our lowest score. It means 20% of true Basmati grains are missed, and the confusion matrix shows they are being mislabeled as Jasmine.

**Jasmine precision (0.77):** this is also low. When the model predicts "Jasmine", it is right only 77% of the time, because it often predicts Jasmine for grains that are actually Basmati.


Although the model achieved an overall accuracy of 92%, its performance was not uniform across classes. In particular, it struggled to distinguish Basmati from Jasmine, which shows up as lower recall for Basmati (0.80) and lower precision for Jasmine (0.77). This suggests the model relies heavily on features (likely grain length or color) that these two rice varieties share.

### Ipsala as a "control group"

**Ipsala (precision 1.00, recall 0.96):** the model identifies Ipsala almost perfectly.

This suggests the architecture itself can learn discriminative features. The weaker results on Basmati are more likely due to how visually similar Basmati and Jasmine grains are than to a fundamental flaw in the CNN.

In [None]:
# Visualize Data Augmentation Effects
def visualize_augmentation(dataset, idx=0):
    raw_img, label = dataset[idx] # Original image and its label

    fig, axes = plt.subplots(1, 6, figsize=(15, 3))

    # Placeholder for original image
    axes[0].text(0.5, 0.5, "Original\n(See Raw)", ha='center')
    axes[0].axis('off')

    # Generate and display augmented images
    for i in range(1, 6):
        img_aug, _ = train_set[idx] # Apply training transforms

        # Undo normalization for visualization
        img_display = img_aug.clone()
        mean = [0.1179, 0.1189, 0.1229]
        std = [0.2851, 0.2875, 0.2989]
        for c in range(3):
            img_display[c] = img_display[c] * std[c] + mean[c]

        # Convert to HxWxC and clip values
        img_display = img_display.permute(1, 2, 0).numpy()
        img_display = np.clip(img_display, 0, 1)

        axes[i].imshow(img_display)
        axes[i].set_title(f"Augmentation {i}")
        axes[i].axis('off')

    plt.suptitle(f"Effect of Data Augmentation on {FULL_DATASET.classes[label]}")
    plt.show()

# Visualize augmentations on the 101st image
visualize_augmentation(FULL_DATASET, idx=100)

In [None]:
# Transform for test images: resize, tensor, normalize
test_transform = transforms.Compose([
    transforms.Resize((250,250)),
    transforms.ToTensor(),

    transforms.Normalize(mean=[0.1179, 0.1189, 0.1229], std=[0.2851, 0.2875, 0.2989])
])

# Class indices
basmati_idx = FULL_DATASET.class_to_idx['Basmati']
jasmine_idx = FULL_DATASET.class_to_idx['Jasmine']

print("Scanning for Basmati -> Jasmine errors... (This might take a moment)")
found_images = []
final_model.eval()

with torch.no_grad():
    for i in range(len(FULL_DATASET)):
        if len(found_images) >= 5: break


        path, label = FULL_DATASET.samples[i]

        # Only check Basmati images
        if label == basmati_idx:
            img_tensor, _ = FULL_DATASET[i]

            # Load raw PIL image and apply test transforms
            img_raw = FULL_DATASET.loader(path) # Load raw PIL
            img_input = test_transform(img_raw).unsqueeze(0).to(DEVICE)

            # Predict class
            output = final_model(img_input)
            _, pred = torch.max(output, 1)

            # If misclassified as Jasmine, save the image
            if pred.item() == jasmine_idx:
                found_images.append(img_input.cpu().squeeze(0))

print(f"Found {len(found_images)} misclassified examples.")

# Visualize misclassified images
if len(found_images) > 0:
    plt.figure(figsize=(15, 4))
    for i, img_tensor in enumerate(found_images):
        ax = plt.subplot(1, 5, i+1)

        # Un-normalize
        img_display = img_tensor.clone()
        mean = [0.1179, 0.1189, 0.1229]
        std = [0.2851, 0.2875, 0.2989]
        for c in range(3):
            img_display[c] = img_display[c] * std[c] + mean[c]

        img_display = img_display.permute(1, 2, 0).numpy()
        img_display = np.clip(img_display, 0, 1)

        ax.imshow(img_display)
        ax.set_title("True: Basmati\nPred: Jasmine")
        ax.axis('off')

    plt.suptitle("Error Analysis: Basmati grains misclassified as Jasmine")
    plt.show()
else:
    print("No errors found in the scanned samples")

In [None]:
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import adjusted_rand_score

feature_list = []
label_list = []

final_model.eval()
print("Extracting CNN features...")
with torch.no_grad():
    for images, labels in test_loader:
        # Run the conv stack, then every classifier layer except the final Linear
        # (AdaptiveAvgPool -> Flatten). The result is the penultimate embedding of
        # size 2 * hidden_size per image.
        x = final_model.features(images.to(DEVICE))
        x = final_model.classifier[:-1](x)

        feature_list.extend(x.cpu().numpy())
        label_list.extend(labels.numpy())

features = np.array(feature_list)
labels = np.array(label_list)

# One cluster per rice class. n_init is explicit so the number of k-means restarts
# does not depend on the installed scikit-learn version's default.
kmeans = KMeans(n_clusters=NUM_CLASSES, n_init=10, random_state=42)
clusters = kmeans.fit_predict(features)

# Evaluate clustering quality with Adjusted Rand Index
ari = adjusted_rand_score(labels, clusters)
print(f"K-Means Adjusted Rand Index (clustering quality): {ari:.4f}")

# Reduce features to 2D for visualization
pca_features = PCA(n_components=2).fit_transform(features)

# Plot 2D PCA scatter colored by cluster assignment
plt.figure(figsize=(10, 8))
scatter = plt.scatter(pca_features[:, 0], pca_features[:, 1], c=clusters, cmap='tab10', alpha=0.6)
plt.title(f'K-Means Clustering of CNN Features (Bishop Sec 9.1)\nARI Score: {ari:.3f}')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.colorbar(scatter, label='Cluster ID')
plt.show()