# Import libraries

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from torchvision.datasets.folder import default_loader
from torch.optim.lr_scheduler import StepLR
from collections import Counter

(Onle applicable for google colab to fix some issues)

In [None]:
import shutil
checkpoint_path = os.path.join("train", ".ipynb_checkpoints")
if os.path.exists(checkpoint_path):
    shutil.rmtree(checkpoint_path)

checkpoint_path = os.path.join("test", ".ipynb_checkpoints")
if os.path.exists(checkpoint_path):
    shutil.rmtree(checkpoint_path)

checkpoint_path = os.path.join("validation", ".ipynb_checkpoints")
if os.path.exists(checkpoint_path):
    shutil.rmtree(checkpoint_path)

# Transforms

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

train_dir = "train"
test_dir = "test"

# Load Dataset

In [None]:
full_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
val_dataset.dataset.transform = val_test_transform

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

Class distribution: Counter({0: 700, 1: 700})


# Model

In [None]:
model = models.resnet50(weights='DEFAULT')
model.fc = nn.Linear(model.fc.in_features, 1)
model = model.to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)

early_stopping_patience = 5
best_val_loss = float('inf')
patience_counter = 0
num_epochs = 20


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 197MB/s]


Optmizers, loss and early stopping

In [7]:
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay=1e-4)
scheduler = StepLR(optimizer, step_size=7, gamma=0.1)

#Eearly stopping
early_stopping_patience = 5
best_val_loss = float('inf')
patience_counter = 0
num_epochs = 20

# Training

In [8]:
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device).float()

        optimizer.zero_grad()
        outputs = model(images).squeeze(1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device).float()

            outputs = model(images).squeeze(1)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            preds = (torch.sigmoid(outputs) > 0.5).long()
            correct += (preds.cpu() == labels.cpu().long()).sum().item()
            total += labels.size(0)

    val_loss /= len(val_loader)
    val_accuracy = correct / total
    print(f"Epoch {epoch+1}, Train Loss: {running_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "Model.pth")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print("Early stopping triggered.")
            break

    scheduler.step()

Epoch 1, Train Loss: 9.5753, Val Loss: 0.3537, Val Acc: 0.8857
Epoch 2, Train Loss: 2.9231, Val Loss: 0.0682, Val Acc: 0.9821
Epoch 3, Train Loss: 2.7341, Val Loss: 0.1843, Val Acc: 0.9536
Epoch 4, Train Loss: 0.6082, Val Loss: 0.0243, Val Acc: 0.9964
Epoch 5, Train Loss: 0.2080, Val Loss: 0.0087, Val Acc: 0.9964
Epoch 6, Train Loss: 0.1514, Val Loss: 0.0196, Val Acc: 0.9929
Epoch 7, Train Loss: 0.8212, Val Loss: 0.0336, Val Acc: 0.9964
Epoch 8, Train Loss: 0.8710, Val Loss: 0.0320, Val Acc: 0.9929
Epoch 9, Train Loss: 0.1634, Val Loss: 0.0138, Val Acc: 0.9964
Epoch 10, Train Loss: 0.0546, Val Loss: 0.0167, Val Acc: 0.9964
Early stopping triggered.


# Testing

In [9]:
model.load_state_dict(torch.load("Model.pth"))
model.eval()
correct = 0
total = 0

test_images = [f for f in os.listdir(test_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
print(f"\nFound {len(test_images)} test images.")

with torch.no_grad():
    for filename in test_images:
        file_path = os.path.join(test_dir, filename)
        image = default_loader(file_path)
        image = val_test_transform(image).unsqueeze(0).to(device)

        output = model(image).squeeze()
        pred = (torch.sigmoid(output) > 0.5).long().item()
        actual_label = 1 if filename.lower().startswith("s_") else 0
        is_correct = (pred == actual_label)
        correct += is_correct
        total += 1

        pred_label_text = "Sick" if pred == 1 else "Healthy"
        actual_label_text = "Sick" if actual_label == 1 else "Healthy"
        print(f"File: {filename} | Predicted: {pred_label_text} | Actual: {actual_label_text} | {'✔' if is_correct else '✘'}")

accuracy = (correct / total) * 100
print(f"\nTest Accuracy: {accuracy:.2f}% ({correct}/{total} correct)")


Found 80 test images.
File: 40.jpg | Predicted: Healthy | Actual: Healthy | ✔
File: S_75.jpg | Predicted: Sick | Actual: Sick | ✔
File: 38.jpg | Predicted: Healthy | Actual: Healthy | ✔
File: S_67.jpg | Predicted: Sick | Actual: Sick | ✔
File: S_55.jpg | Predicted: Sick | Actual: Sick | ✔
File: 3.jpg | Predicted: Sick | Actual: Healthy | ✘
File: 34.jpg | Predicted: Sick | Actual: Healthy | ✘
File: 16.jpg | Predicted: Healthy | Actual: Healthy | ✔
File: S_36.jpg | Predicted: Sick | Actual: Sick | ✔
File: S_72.jpg | Predicted: Healthy | Actual: Sick | ✘
File: 18.jpg | Predicted: Healthy | Actual: Healthy | ✔
File: S_34.jpg | Predicted: Sick | Actual: Sick | ✔
File: 37.jpg | Predicted: Sick | Actual: Healthy | ✘
File: S_38.jpg | Predicted: Healthy | Actual: Sick | ✘
File: S_49.jpg | Predicted: Sick | Actual: Sick | ✔
File: 27.jpg | Predicted: Healthy | Actual: Healthy | ✔
File: 13.jpg | Predicted: Healthy | Actual: Healthy | ✔
File: 14.jpg | Predicted: Healthy | Actual: Healthy | ✔
File: