In [None]:
"""
Program Title: train_alexnet_model.ipynb
Programmer/s: Idan Josh Bosi 

Where the program fits in the general system designs: 
This script is a part of a larger project that involves facial skin disease classification using AlexNet. 
It focuses on training AlexNet from scratch for skin disease detection.

Date Written: October 4, 2024
Date Revised: November 4, 2024

Purpose: 
The primary purpose of this script is to train an AlexNet model to classify skin diseases into multiple classes.
The script loads image datasets, applies augmentation techniques, performs training, and validates the model.

Data Structures, Algorithms, and Control Flow:
- The script primarily relies on PyTorch tensors and DataLoader for image data handling.
- Uses K-means clustering and CLAHE transforms for preprocessing and data augmentation.
- Utilizes a custom AlexNet model defined in PyTorch, with 8 output classes.
- The training loop includes an early stopping mechanism and learning rate scheduler for optimization.
- Implements visualization techniques for dataset samples and training/validation loss.
"""

import torch
import os
import random
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import cv2
import numpy as np
from PIL import Image

# Set the seed for reproducibility
SEED = 123
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), and switches cuDNN to deterministic mode.

    Args:
        seed (int): Seed value shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU, not only the current device;
    # it is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN auto-tuning speed for run-to-run repeatability.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(SEED)  # Apply the seed


In [None]:

# Number of samples per batch for the train/val/test DataLoaders.
BATCH_SIZE = 32
# Upper bound on training epochs; early stopping in train_model may end the run sooner.
EPOCHS = 100
INPUT_SIZE = (227, 227)  # target size passed to transforms.Resize for every image


def count_files_in_directory(directory):
    """Return the number of files found anywhere beneath ``directory``.

    The tree is walked recursively with ``os.walk``; directories themselves
    are not counted. A path that does not exist yields no entries, so the
    result is 0 rather than an error.
    """
    return sum(len(filenames) for _, _, filenames in os.walk(directory))

# Root folder that holds the "train", "val" and "test" sub-directories.
# NOTE(review): this is empty, so the three split folders resolve relative to
# the current working directory — set it to the dataset root before running.
main_data_dir = r""
train_dir, val_dir, test_dir = (
    os.path.join(main_data_dir, split_name) for split_name in ("train", "val", "test")
)

# Count the files in each split (recursively, across all class folders).
train_files, val_files, test_files = (
    count_files_in_directory(split_dir) for split_dir in (train_dir, val_dir, test_dir)
)

print(
    f"Training Dataset: {train_files}",
    f"Validation Dataset: {val_files}",
    f"Test Dataset: {test_files}",
    sep="\n",
)

In [None]:

# Index the training folder (no transform yet) to discover the class labels.
# The class names are the sub-folder names, in the order ImageFolder indexed them.
train_dataset = datasets.ImageFolder(root=train_dir)
class_n = [class_name for class_name in train_dataset.class_to_idx]
print("Class to label mapping:", train_dataset.class_to_idx)

In [None]:
# # Function to calculate mean and std for the dataset
# def calculate_mean_std(loader):
#     mean = 0.0
#     std = 0.0
#     total_images_count = 0
#     for images, _ in loader:
#         batch_samples = images.size(0)  # batch size (the last batch can have smaller size!)
#         images = images.view(batch_samples, images.size(1), -1)  # reshape to (batch_size, channels, height * width)
#         mean += images.mean(2).sum(0)
#         std += images.std(2).sum(0)
#         total_images_count += batch_samples

#     mean /= total_images_count
#     std /= total_images_count
#     return mean, std

# # Temporary transform to load the dataset without normalization for mean and std calculation
# transform_temp = transforms.Compose([
#     transforms.Resize(INPUT_SIZE),
#     transforms.ToTensor()
# ])

# # Load the training dataset without normalization
# train_dataset_temp = datasets.ImageFolder(root=train_dir, transform=transform_temp)
# train_loader_temp = DataLoader(train_dataset_temp, batch_size=BATCH_SIZE, shuffle=False)

# # Calculate mean and std
# mean, std = calculate_mean_std(train_loader_temp)
# print(f"Calculated Mean: {mean}")
# print(f"Calculated Std: {std}")

In [None]:
from sklearn.cluster import KMeans
import numpy as np
from PIL import Image

class KMeansSegmentation:
    """Augmentation that blends an image with its K-Means colour segmentation.

    Every pixel is assigned to one of ``n_clusters`` colour clusters, replaced
    by its cluster centre, and the resulting flat-colour image is alpha-blended
    back onto the original.

    Args:
        n_clusters (int): Number of colour clusters fitted per image.
        overlay_alpha (float): Weight of the segmented image in the blend; the
            original image receives ``1 - overlay_alpha``.
        random_state (int | None): Seed forwarded to ``KMeans``. Defaults to 42,
            the value that was previously hard-coded.
    """

    def __init__(self, n_clusters=12, overlay_alpha=0.5, random_state=42):
        self.n_clusters = n_clusters
        self.overlay_alpha = overlay_alpha
        self.random_state = random_state

    def __call__(self, img):
        """Return a PIL image: ``img`` blended with its colour-quantised version."""
        # Grayscale/palette/RGBA PIL inputs would break the (h, w, c) unpacking
        # and the per-pixel reshape below, so normalise them to 3-channel RGB.
        if isinstance(img, Image.Image) and img.mode != "RGB":
            img = img.convert("RGB")

        # Convert PIL image to NumPy array
        img_np = np.array(img)

        # Reshape the image to a 2D array of pixels (one row per pixel)
        h, w, c = img_np.shape
        pixel_values = img_np.reshape((-1, c))

        # Apply K-Means clustering in colour space
        kmeans = KMeans(n_clusters=self.n_clusters, random_state=self.random_state)
        kmeans.fit(pixel_values)
        centers = np.uint8(kmeans.cluster_centers_)
        labels = kmeans.labels_

        # Create the segmented image: each pixel takes its cluster-centre colour
        segmented_image = centers[labels.flatten()].reshape((h, w, c))

        # Blend the segmented regions back onto the original image
        img_segmented = cv2.addWeighted(
            img_np, 1 - self.overlay_alpha, segmented_image, self.overlay_alpha, 0
        )

        # Convert back to PIL Image
        return Image.fromarray(img_segmented)


In [None]:
class CLAHETransform:
    """Apply CLAHE (contrast-limited adaptive histogram equalisation) to an RGB image.

    The image is converted to LAB, only the lightness channel is equalised,
    and the result is converted back to RGB.

    Args:
        clip_limit (float): ``clipLimit`` passed to ``cv2.createCLAHE``.
        tile_grid_size (tuple): ``tileGridSize`` passed to ``cv2.createCLAHE``.
    """

    def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
        self.clip_limit, self.tile_grid_size = clip_limit, tile_grid_size

    def __call__(self, img):
        """Return a new PIL image with the lightness channel equalised."""
        lab_pixels = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2LAB)
        lightness, chan_a, chan_b = cv2.split(lab_pixels)

        equalizer = cv2.createCLAHE(
            clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size
        )
        # The colour channels (a, b) pass through untouched.
        merged = cv2.merge((equalizer.apply(lightness), chan_a, chan_b))

        return Image.fromarray(cv2.cvtColor(merged, cv2.COLOR_LAB2RGB))

In [None]:
# Per-channel mean and standard deviation handed to transforms.Normalize (and used
# again by unnormalize for display).
# NOTE(review): hard-coded; presumably produced by the commented-out
# calculate_mean_std cell on the training split — recompute if the dataset changes.
MEAN = (0.5960, 0.4489, 0.4046)
STD = (0.2102, 0.1782, 0.1719)

class CustomRotation:
    """Rotate an image about its centre by a random angle using OpenCV.

    Unlike a plain rotation that leaves empty corners, the exposed area is
    filled according to ``border_mode`` (edge replication by default).

    Args:
        degrees (float): The angle is drawn uniformly from ``[-degrees, degrees]``.
        border_mode (int): OpenCV border flag forwarded to ``cv2.warpAffine``.
    """

    def __init__(self, degrees, border_mode=cv2.BORDER_REPLICATE):
        self.degrees, self.border_mode = degrees, border_mode

    def __call__(self, img):
        """Return ``img`` rotated by a freshly sampled angle, as a PIL image."""
        pixels = np.array(img)
        height, width = pixels.shape[:2]

        # The angle comes from NumPy's global RNG, so np.random.seed controls it.
        theta = np.random.uniform(-self.degrees, self.degrees)
        pivot = (width // 2, height // 2)
        warp = cv2.getRotationMatrix2D(pivot, theta, 1.0)  # scale 1.0: no zoom

        rotated = cv2.warpAffine(
            pixels,
            warp,
            (width, height),  # output keeps the input size
            flags=cv2.INTER_LINEAR,
            borderMode=self.border_mode,
        )
        return transforms.functional.to_pil_image(rotated)


class CustomAffine:
    """Random three-point affine warp (translation/shear) implemented with OpenCV.

    Args:
        degrees: Stored on the instance.
            NOTE(review): never read by ``__call__``, so no rotation is applied.
        translate (sequence): ``(fx, fy)`` fractions of width/height bounding the
            random offsets ``tx`` and ``ty``.
        shear (float): Bound for the random ``shear_x``/``shear_y`` values, which
            are added to corner coordinates in pixels.
        border_mode (int): OpenCV border flag forwarded to ``cv2.warpAffine``.

    NOTE(review): the destination corners do not apply (tx, ty) uniformly — ``tx``
    only moves the top-left corner and ``ty`` only the two top corners — so the
    result is not a pure translation plus shear. Confirm the intended geometry
    before relying on this transform.
    """

    def __init__(self, degrees, translate, shear, border_mode=cv2.BORDER_REPLICATE):
        self.degrees, self.translate = degrees, translate
        self.shear, self.border_mode = shear, border_mode

    def __call__(self, img):
        """Return ``img`` warped by a freshly sampled affine map, as a PIL image."""
        pixels = np.array(img)
        height, width = pixels.shape[:2]

        # Draw order (tx, ty, shear_x, shear_y) is fixed so that seeded runs
        # consume NumPy's global RNG identically.
        max_dx = self.translate[0] * width
        max_dy = self.translate[1] * height
        tx = np.random.uniform(-max_dx, max_dx)
        ty = np.random.uniform(-max_dy, max_dy)
        shear_x = np.random.uniform(-self.shear, self.shear)
        shear_y = np.random.uniform(-self.shear, self.shear)

        # Top-left, top-right and bottom-left corners before and after the warp.
        corners_before = np.float32([[0, 0], [width, 0], [0, height]])
        corners_after = np.float32([
            [tx, ty],
            [width + shear_x, ty],
            [shear_x, height + shear_y],
        ])
        warp = cv2.getAffineTransform(corners_before, corners_after)

        warped = cv2.warpAffine(
            pixels,
            warp,
            (width, height),  # output keeps the input size
            flags=cv2.INTER_LINEAR,
            borderMode=self.border_mode,
        )
        return transforms.functional.to_pil_image(warped)


# Training pipeline: resize, random geometric/photometric augmentation, CLAHE,
# an occasional K-Means overlay, then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.Resize(INPUT_SIZE),  # bring every image to the network input size
    transforms.RandomHorizontalFlip(p=0.5),  # mirror with probability 0.5
    transforms.RandomApply([CustomRotation(degrees=20)], p=0.5),  # rotate within ±20° with probability 0.5
    transforms.RandomChoice([  # one of the three transforms below is chosen per sample
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
        transforms.RandomGrayscale(p=0.1)  # NOTE(review): even when chosen, p=0.1 leaves most samples unchanged
    ]),
    CLAHETransform(clip_limit=2.0, tile_grid_size=(8, 8)),  # always applied; same settings as val/test
    transforms.RandomApply([KMeansSegmentation(n_clusters=12, overlay_alpha=0.5)], p=0.1),  # fits K-Means per image when applied
    transforms.ToTensor(),  # PIL image -> tensor
    transforms.Normalize(mean=MEAN, std=STD)  # standardise with the hard-coded MEAN/STD
])


# Evaluation pipeline: no random augmentation — only the deterministic steps
# shared with training (resize, CLAHE, tensor conversion, normalisation).
transform_val_test = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    CLAHETransform(clip_limit=2.0, tile_grid_size=(8, 8)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD)
])

# Load datasets
# train_dataset is rebound here: the earlier transform-less instance was only
# used to read the class names.
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform_train)
val_dataset = datasets.ImageFolder(root=val_dir, transform=transform_val_test)
test_dataset = datasets.ImageFolder(root=test_dir, transform=transform_val_test)

# Create DataLoaders
# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Sanity check: pull one training batch and show its tensor shape.
# Presumably (batch, 3, 227, 227) given INPUT_SIZE and RGB images — confirm on first run.
images, labels = next(iter(train_loader))
print(images.shape)  

In [None]:

# Function to unnormalize the image for visualization
def unnormalize(image, mean, std):
    """Undo ``transforms.Normalize`` so a tensor image can be shown with ``plt.imshow``.

    Args:
        image (torch.Tensor): Normalised image in channel-first (C, H, W) layout.
        mean (sequence of float): Per-channel mean that was subtracted.
        std (sequence of float): Per-channel standard deviation that was divided by.

    Returns:
        numpy.ndarray: Channel-last (H, W, C) array clipped to ``[0, 1]``.
    """
    # detach()/cpu() make this safe for tensors that track gradients or live on
    # the GPU; calling .numpy() directly raises for both.
    image = image.detach().cpu().numpy().transpose((1, 2, 0))
    # mean/std broadcast over the trailing channel axis.
    image = (image * np.asarray(std)) + np.asarray(mean)
    # Rounding and augmentation can push values slightly outside the displayable range.
    return np.clip(image, 0, 1)


# Visualize a batch of images from the train_loader
def visualize_loader(loader, mean, std, class_names, num_images=6):
    """Show the first ``num_images`` samples of one batch drawn from ``loader``.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.
        mean, std: Normalisation statistics, used to undo the normalisation.
        class_names (sequence of str): Maps a label index to its display name.
        num_images (int): Number of samples to draw. It is capped at the batch
            size, and the grid grows in rows of three so that any count fits.
    """
    images, labels = next(iter(loader))

    # Never index past the end of the batch (e.g. a small or final batch).
    num_images = min(num_images, len(images))
    if num_images <= 0:
        return

    n_cols = 3
    n_rows = -(-num_images // n_cols)  # ceiling division

    # 4 inches per row keeps the original 12x8 figure for the default 6 images.
    plt.figure(figsize=(12, 4 * n_rows))
    for i in range(num_images):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(unnormalize(images[i], mean, std))
        plt.title(f"Class: {class_names[int(labels[i])]}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

visualize_loader(train_loader, mean=MEAN, std=STD, class_names=class_n)

In [None]:
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary

class AlexNet(nn.Module):
    """torchvision AlexNet with its last classifier layer resized to ``num_classes``.

    The network is created with ``weights=None``, i.e. without pretrained weights.

    Args:
        num_classes (int): Number of output logits. The default is the number of
            classes found in the training folder at the time the class is defined.
    """

    def __init__(self, num_classes=len(class_n)):
        super().__init__()

        backbone = models.alexnet(weights=None)
        # Swap the final fully connected layer for one with the required output size.
        backbone.classifier[6] = nn.Linear(4096, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        return self.model(x)

alexnet = AlexNet(num_classes=len(class_n))
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
alexnet.to(device)

# Cross-entropy over the class logits; Adam with a small learning rate and weight decay.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(alexnet.parameters(), lr=1e-4, weight_decay=1e-4)

# Build the summary on the same device as the model, with the same spatial size
# the transforms produce.
summary(alexnet, input_size=(3, *INPUT_SIZE), device=device.type)
print(alexnet)


In [None]:
import os
from tqdm import tqdm
import torch

PATIENCE = 5
MODEL_PATH = 'alexnet.pth'
LOG_FILE = 'training_logs.txt'  # Path to save the logs

# Training Loop
def train_model(model, train_loader, val_loader, epochs, criterion, optimizer, device, patience, save_path=MODEL_PATH, log_file=LOG_FILE):
    """Train ``model`` with validation-accuracy checkpointing and early stopping.

    Each epoch runs one training pass and one validation pass through
    ``run_epoch``. Whenever validation accuracy beats the best value so far,
    the model's ``state_dict`` is written to ``save_path``. Training stops early
    once ``patience`` consecutive epochs fail to improve it.

    Args:
        model: Network to train.
        train_loader, val_loader: Batch iterables for the two phases.
        epochs (int): Maximum number of epochs.
        criterion: Loss function.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.
        patience (int): Non-improving epochs tolerated before stopping.
        save_path (str): Destination of the best checkpoint.
        log_file (str): Text log; an existing file is emptied first.

    Returns:
        tuple[dict, dict]: ``(train_metrics, val_metrics)``, each holding per-epoch
        ``'loss'`` and ``'accuracy'`` lists.
    """
    def append_to_log(text):
        # Opened and closed per write, so each entry reaches the file immediately.
        with open(log_file, 'a') as f:
            f.write(text)

    # Initialize tracking metrics
    train_metrics, val_metrics = ({'loss': [], 'accuracy': []} for _ in range(2))
    best_val_accuracy, patience_counter = 0.0, 0

    # Ensure the log file is empty before starting
    if os.path.exists(log_file):
        with open(log_file, 'w'):
            pass

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        # Train phase
        train_loss, train_accuracy = run_epoch(
            model, train_loader, criterion, optimizer, device, is_training=True
        )
        train_metrics['loss'].append(train_loss)
        train_metrics['accuracy'].append(train_accuracy)

        # Validation phase
        val_loss, val_accuracy = run_epoch(
            model, val_loader, criterion, optimizer, device, is_training=False
        )
        val_metrics['loss'].append(val_loss)
        val_metrics['accuracy'].append(val_accuracy)

        # Logging metrics to console
        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}%")
        print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}%")

        # Save logs to file
        append_to_log(
            f"Epoch {epoch + 1}/{epochs}\n"
            f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}%\n"
            f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}%\n"
            "\n"
        )

        # Improvement: checkpoint, reset the patience counter, move on.
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            patience_counter = 0
            print(f"New best validation accuracy: {best_val_accuracy:.4f}%. Saving model...")
            append_to_log(f"New best validation accuracy: {best_val_accuracy:.4f}% - Model saved.\n")
            torch.save(model.state_dict(), save_path)
            continue

        # Early stopping
        patience_counter += 1
        if patience_counter < patience:
            continue
        print(f"Early stopping triggered at epoch {epoch + 1}.")
        append_to_log(f"Early stopping triggered at epoch {epoch + 1}.\n")
        break

    print(f"\nTraining complete. Best validation accuracy: {best_val_accuracy:.4f}%")
    append_to_log(f"Training complete. Best validation accuracy: {best_val_accuracy:.4f}%\n")
    return train_metrics, val_metrics

# Function to run a single epoch
def run_epoch(model, data_loader, criterion, optimizer, device, is_training):
    """Run one pass over ``data_loader`` and return ``(mean_batch_loss, accuracy_percent)``.

    Args:
        model: Network to train or evaluate.
        data_loader: Iterable of ``(inputs, labels)`` batches; must support ``len()``.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer; only used when ``is_training`` is true.
        device: Device the batches are moved to.
        is_training (bool): True to update weights, False for evaluation only.

    Returns:
        tuple[float, float]: The loss averaged over batches, and the accuracy in percent.
    """
    phase = "Training" if is_training else "Validation"
    model.train(is_training)  # train(False) is the same as eval()

    running_loss = 0.0
    correct = 0
    total = 0

    # Gradients are only tracked while training; validation skips building the
    # autograd graph, which saves memory and time.
    with torch.set_grad_enabled(is_training), tqdm(data_loader, unit="batch") as tepoch:
        tepoch.set_description(f"{phase} Phase")
        for batches_seen, (inputs, labels) in enumerate(tepoch, start=1):
            inputs, labels = inputs.to(device), labels.to(device)

            if is_training:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Average over the batches processed so far; dividing by the full
            # loader length made the live loss read too low until the last batch.
            tepoch.set_postfix(loss=f"{running_loss / batches_seen:.4f}", accuracy=f"{100 * correct / total:.4f}")

    epoch_loss = running_loss / len(data_loader)
    epoch_accuracy = 100 * correct / total
    return epoch_loss, epoch_accuracy


# Launch training: at most EPOCHS epochs, stopping early after PATIENCE epochs
# without a validation-accuracy improvement. The best weights go to MODEL_PATH
# and the per-epoch log to LOG_FILE.
train_metrics, val_metrics = train_model(
    alexnet, train_loader, val_loader, EPOCHS, criterion, optimizer, device, patience=PATIENCE, save_path=MODEL_PATH, log_file=LOG_FILE
)

In [None]:
import matplotlib.pyplot as plt

# Extract loss values for plotting
train_losses, val_losses = train_metrics['loss'], val_metrics['loss']

def plot_losses(train_losses, val_losses):
    """Plot the training and validation loss curves on one set of axes.

    Args:
        train_losses (sequence of float): Training loss per epoch.
        val_losses (sequence of float): Validation loss per epoch.
    """
    _, ax = plt.subplots(figsize=(10, 6))
    for curve, legend_label in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
        ax.plot(curve, label=legend_label)
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss Over Epochs')
    ax.legend()
    ax.grid(True)
    plt.show()


plot_losses(train_losses, val_losses)