### 1. Imports

In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models

from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt

### 2. Device

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


### 3. Data loading & preprocessing

In [4]:
data_dir = "dataset_10"  # <- change this

# ImageNet mean/std (works well with pretrained ResNet)
IMG_SIZE = 224
mean = [0.485, 0.456, 0.406]
std  = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

val_test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Full dataset (for splitting)
full_dataset = datasets.ImageFolder(root=data_dir, transform=train_transform)
class_names = full_dataset.classes
num_classes = len(class_names)
print("Classes:", class_names)
print("Total images:", len(full_dataset))

# Train/val/test split: 70/15/15
total_size = len(full_dataset)
train_size = int(0.7 * total_size)
val_size   = int(0.15 * total_size)
test_size  = total_size - train_size - val_size

generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=generator
)

# Override transforms for val/test
val_dataset.dataset.transform = val_test_transform
test_dataset.dataset.transform = val_test_transform

batch_size = 64

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


Classes: ['single_prediction', 'test_set', 'training_set', 'val']
Total images: 26181


### 4. Model: CNN with transfer learning(ResNet-18)

In [5]:
# Load pretrained ResNet-18
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Freeze all convolutional layers
for param in model.parameters():
    param.requires_grad = False

# Replace the final classification layer
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)

model = model.to(device)
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

### 5. Training loop (with validation)

In [6]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    running_corrects = 0
    total = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        _, preds = torch.max(outputs, 1)
        running_loss += loss.item() * images.size(0)
        running_corrects += torch.sum(preds == labels).item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = running_corrects / total
    return epoch_loss, epoch_acc


def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum(preds == labels).item()
            total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = running_corrects / total
    return epoch_loss, epoch_acc

In [7]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.3, patience=3, verbose=True
)



In [8]:
num_epochs = 15

best_val_acc = 0.0
best_model_state = None

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    scheduler.step(val_loss)

    print(f"Epoch {epoch+1}/{num_epochs} "
          f"| Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} "
          f"| Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model_state = model.state_dict().copy()

print("Best validation accuracy:", best_val_acc)
# Restore best model
if best_model_state is not None:
    model.load_state_dict(best_model_state)


Epoch 1/15 | Train Loss: 0.6847, Train Acc: 0.7935 | Val Loss: 0.6713, Val Acc: 0.8029
Epoch 2/15 | Train Loss: 0.6605, Train Acc: 0.7982 | Val Loss: 0.6547, Val Acc: 0.8029
Epoch 3/15 | Train Loss: 0.6558, Train Acc: 0.7983 | Val Loss: 0.6528, Val Acc: 0.8029
Epoch 4/15 | Train Loss: 0.6558, Train Acc: 0.7983 | Val Loss: 0.6558, Val Acc: 0.8029
Epoch 5/15 | Train Loss: 0.6504, Train Acc: 0.7983 | Val Loss: 0.6552, Val Acc: 0.8029
Epoch 6/15 | Train Loss: 0.6478, Train Acc: 0.7983 | Val Loss: 0.6579, Val Acc: 0.8029
Epoch 7/15 | Train Loss: 0.6482, Train Acc: 0.7982 | Val Loss: 0.6536, Val Acc: 0.8029
Epoch 8/15 | Train Loss: 0.6338, Train Acc: 0.7983 | Val Loss: 0.6522, Val Acc: 0.8029
Epoch 9/15 | Train Loss: 0.6323, Train Acc: 0.7983 | Val Loss: 0.6547, Val Acc: 0.8029
Epoch 10/15 | Train Loss: 0.6327, Train Acc: 0.7983 | Val Loss: 0.6576, Val Acc: 0.8026
Epoch 11/15 | Train Loss: 0.6332, Train Acc: 0.7983 | Val Loss: 0.6528, Val Acc: 0.8029
Epoch 12/15 | Train Loss: 0.6327, Train A

### 6. Evaluation on the test set

### 6.1 Overall accuracy

In [9]:
test_loss, test_acc = evaluate(model, test_loader, criterion, device)
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

Test Loss: 0.6466, Test Accuracy: 0.8035


### 6.2 Confusion matrix & classification report

In [10]:
all_labels = []
all_preds = []

model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, preds = torch.max(outputs, 1)

        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())

cm = confusion_matrix(all_labels, all_preds)
print("Classification report:")
print(classification_report(all_labels, all_preds, target_names=class_names))

fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(cm, interpolation='nearest')
ax.figure.colorbar(im, ax=ax)
ax.set(
    xticks=np.arange(len(class_names)),
    yticks=np.arange(len(class_names)),
    xticklabels=class_names,
    yticklabels=class_names,
    xlabel='Predicted label',
    ylabel='True label',
    title='Confusion Matrix'
)
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
plt.tight_layout()
plt.show()

Classification report:


ValueError: Number of classes, 3, does not match size of target_names, 4. Try specifying the labels parameter