In [1]:
import os
import time
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch import optim
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

from model import LRFCNN
from loss import FrequencyAwareAdaptiveFocalLoss
from evaluation_metrics import compute_metrics

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Data directories.
# Point the SPLIT_DATASET_DIR environment variable at a dataset root that contains
# train/, val/ and test/ sub-folders (ImageFolder layout). When it is unset, the
# original author's local path is used as a fallback.
DATA_ROOT = os.environ.get('SPLIT_DATASET_DIR', r'C:\Users\linglingyuan\Desktop\split_dataset2')
train_dir = os.path.join(DATA_ROOT, 'train')
val_dir   = os.path.join(DATA_ROOT, 'val')
test_dir  = os.path.join(DATA_ROOT, 'test')

In [3]:
# Define data transformations (resize images to 299x299, convert to Tensor, and normalize).
# NOTE(review): the mean/std below are not the ImageNet defaults. They look like per-channel
# statistics of this dataset; confirm they were computed on the training split only.
data_transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5203, 0.3638, 0.6511],
                         std=[0.2339, 0.2535, 0.1455])
])

# Use ImageFolder to load the datasets
train_dataset = datasets.ImageFolder(train_dir, transform=data_transform)
val_dataset   = datasets.ImageFolder(val_dir, transform=data_transform)
test_dataset  = datasets.ImageFolder(test_dir, transform=data_transform)

# Each ImageFolder derives its class -> index mapping from its own sub-folder names.
# If a class folder is missing or misspelled in one split, the same integer label would
# silently denote different classes across splits. Fail fast instead.
if not (train_dataset.class_to_idx == val_dataset.class_to_idx == test_dataset.class_to_idx):
    raise ValueError(
        "Class folders differ between splits: "
        f"train={train_dataset.class_to_idx}, "
        f"val={val_dataset.class_to_idx}, "
        f"test={test_dataset.class_to_idx}"
    )

batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

print(f'Classes: {train_dataset.classes}')
print(f'Training set size: {len(train_dataset)}')
print(f'Validation set size: {len(val_dataset)}')
print(f'Test set size: {len(test_dataset)}')

Training set size: 1059
Validation set size: 349
Test set size: 349


In [4]:
# Model, loss function, optimizer, and learning rate scheduler

# Seed torch's global RNG (CPU and CUDA) before building the model. This fixes weight
# initialisation and the default generator that the shuffling DataLoader draws from.
# NOTE: this alone does not make cuDNN kernels deterministic.
SEED = 42
torch.manual_seed(SEED)

# Single source of truth for the class count used by both the model head and the loss.
NUM_CLASSES = 5
if len(train_dataset.classes) != NUM_CLASSES:
    raise ValueError(
        f"NUM_CLASSES={NUM_CLASSES} but the training set has "
        f"{len(train_dataset.classes)} class folders: {train_dataset.classes}"
    )

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LRFCNN(num_classes=NUM_CLASSES, dropconnect_prob=0.5).to(device)

criterion = FrequencyAwareAdaptiveFocalLoss(num_classes=NUM_CLASSES, base_alpha=0.5, base_beta=3.0, margin_const=1.0, reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
# scheduler.step() is called once per epoch in the training loop, so the LR is halved every 20 epochs.
scheduler = StepLR(optimizer, step_size=20, gamma=0.5)

In [None]:
num_epochs = 100
# Start below any reachable accuracy so epoch 1 always writes best_model.pth. With 0.0 and a
# strict `>` comparison, a run whose validation accuracy stayed at 0 would never save a model.
best_val_accuracy = float('-inf')
# Wall-clock seconds spent in the training phase only. Validation and checkpoint writes are
# excluded, so the "time per 100 images" figure reflects training throughput.
total_train_time = 0.0
total_train_images = 0

if len(train_loader) == 0:
    raise ValueError("train_loader yields no batches; check the training data directory")

for epoch in range(num_epochs):
    model.train()
    total_correct = 0
    total_samples = 0
    running_loss = 0.0  # sum of per-batch losses, each weighted by its batch size
    epoch_start_time = time.time()

    # Training loop (accuracy and mean loss are tracked during training)
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} Training"):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n_batch = labels.size(0)
        # The criterion is built with reduction='mean' (presumably a per-sample mean), so weight by
        # batch size. Otherwise a smaller final batch would be over-weighted in the epoch mean.
        running_loss += loss.item() * n_batch
        _, predicted = torch.max(outputs, 1)
        total_correct += (predicted == labels).sum().item()
        total_samples += n_batch

    # .item() above synchronizes with the device every batch, so this measures completed work.
    epoch_time = time.time() - epoch_start_time
    total_train_time += epoch_time
    total_train_images += total_samples

    train_accuracy = total_correct / total_samples
    train_loss = running_loss / total_samples
    print(f"Epoch {epoch+1}/{num_epochs} - Training Loss: {train_loss:.4f} - "
          f"Training Accuracy: {train_accuracy:.4f} - Time: {epoch_time:.1f}s")

    scheduler.step()

    # Validation process (compute evaluation metrics)
    model.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            val_preds.extend(predicted.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    metrics = compute_metrics(np.array(val_labels), np.array(val_preds))
    print("Validation Metrics:")
    print(f"  Accuracy  : {metrics['accuracy']:.4f}")
    print(f"  Precision : {metrics['precision']:.4f}")
    print(f"  Recall    : {metrics['recall']:.4f}")
    print(f"  F1-score  : {metrics['f1_score']:.4f}")
    print(f"  G-mean    : {metrics['gmean']:.4f}")
    print(f"  MCC       : {metrics['mcc']:.4f}")

    # Save best model (selected on validation accuracy) if improved
    if metrics['accuracy'] > best_val_accuracy:
        best_val_accuracy = metrics['accuracy']
        torch.save(model.state_dict(), "best_model.pth")
        print(f"Saved best model with validation accuracy: {best_val_accuracy:.4f}")

    # Save a checkpoint for the current epoch
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_accuracy': train_accuracy,
        'val_accuracy': metrics['accuracy'],
        # Mean per-sample training loss over the whole epoch (not just the last batch).
        'loss': train_loss,
    }
    checkpoint_path = f'checkpoint_epoch_{epoch+1}.pth'
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved for epoch {epoch+1} to {checkpoint_path}")

# After training, report training throughput: training-phase time per 100 training images.
if total_train_images > 0:
    train_time_per_100 = (total_train_time / total_train_images) * 100
    print(f"Total Training Time per 100 images: {train_time_per_100:.4f} seconds")
print(f"Best validation accuracy: {best_val_accuracy:.4f}")

Epoch 1/100 Training:   5%|██▋                                                         | 6/133 [00:04<01:31,  1.39it/s]

In [None]:
# Evaluate the weights that were *selected* on validation accuracy. After the training loop,
# `model` holds the last epoch's weights, which are generally not the saved best model.
# NOTE: if training was skipped or interrupted, this file may come from an earlier run.
BEST_MODEL_PATH = "best_model.pth"
if os.path.exists(BEST_MODEL_PATH):
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
    print(f"Loaded best validation weights from {BEST_MODEL_PATH}")
else:
    print(f"WARNING: {BEST_MODEL_PATH} not found; evaluating the in-memory model weights")

model.eval()
test_preds, test_labels = [], []
# End-to-end timing: includes image loading/transforms, host->device copies and the forward pass.
test_start_time = time.time()
with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Testing"):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        test_preds.extend(predicted.cpu().numpy())
        test_labels.extend(labels.cpu().numpy())
test_end_time = time.time()

total_test_time = test_end_time - test_start_time
avg_test_time = total_test_time / len(test_dataset)

metrics = compute_metrics(np.array(test_labels), np.array(test_preds))
print("Test Metrics:")
print(f"  Accuracy  : {metrics['accuracy']:.4f}")
print(f"  Precision : {metrics['precision']:.4f}")
print(f"  Recall    : {metrics['recall']:.4f}")
print(f"  F1-score  : {metrics['f1_score']:.4f}")
print(f"  G-mean    : {metrics['gmean']:.4f}")
print(f"  MCC       : {metrics['mcc']:.4f}")
print(f"Average Test Time per image: {avg_test_time:.4f} seconds")