In [None]:
import os
from datetime import datetime
from pathlib import Path

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import resnet18
from torchvision.transforms import transforms

# ---- Run configuration ----------------------------------------------------
IMG_SIZE = 224  # side length of the square images fed to the network
NUM_CLASSES = 10  # number of outputs of the replacement classification head
NUM_EPOCHS = 30  # upper bound on epochs; early stopping may end training sooner
MODEL_FOLDER_PATH = Path.cwd()  # checkpoints live in the current working directory
DEVICE = torch.device("cpu")  # device that every batch is moved to

# Load or create the model.
# Each run saves to a *new* timestamped file, so the path built here never exists
# yet. To actually resume, look for the most recent checkpoint from an earlier run.
model_path = MODEL_FOLDER_PATH / f"model_resnet18_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
# The pattern mirrors the %Y%m%d_%H%M%S stamp above; that stamp sorts
# lexicographically in chronological order, so the last entry is the newest.
previous_checkpoints = sorted(MODEL_FOLDER_PATH.glob("model_resnet18_????????_??????.pt"))
if previous_checkpoints:
    latest_checkpoint = previous_checkpoints[-1]
    print(f"Resuming from checkpoint {latest_checkpoint}")
    # SECURITY: this script saves the whole pickled nn.Module, which requires
    # weights_only=False to load (the default became True in torch 2.6). Unpickling
    # can execute arbitrary code, so only load checkpoints you created yourself.
    model = torch.load(latest_checkpoint, map_location=DEVICE, weights_only=False)
    model.eval()
else:
    # NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour
    # of `weights=ResNet18_Weights.DEFAULT`; kept for compatibility with older versions.
    model = resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    # Replace the ImageNet head with one sized for this dataset.
    model.fc = nn.Linear(num_ftrs, NUM_CLASSES)
# Keep the parameters on the same device the training/eval batches are sent to.
model = model.to(DEVICE)

# Define data preprocessing transformation.
# One deterministic pipeline (no augmentation) is shared by every split:
# resize to a fixed IMG_SIZE x IMG_SIZE square, convert to a float tensor, then
# normalise each channel. The mean/std values are the standard ImageNet channel
# statistics used with torchvision's pretrained weights.
train_mean = [0.485, 0.456, 0.406]
train_std = [0.229, 0.224, 0.225]
img_normalize = transforms.Normalize(mean=train_mean, std=train_std)
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        img_normalize,
    ]
)

# Define data paths.
# The dataset root defaults to the original machine-specific location but can be
# overridden with the MILVEHS_DATA_ROOT environment variable instead of editing code.
# Each split is expected in its own sub-folder of the root.
DATA_ROOT = os.environ.get("MILVEHS_DATA_ROOT", "/mnt/d/FY2023/DataSets/milVehs/dataset")
# Kept as str (not Path) so the values passed to ImageFolder are unchanged.
train_data_path = str(Path(DATA_ROOT) / "train")
valid_data_path = str(Path(DATA_ROOT) / "validation")
test_data_path = str(Path(DATA_ROOT) / "test")

# Define dataset and dataloader for training, validation, and testing
train_dataset = ImageFolder(train_data_path, transform=transform)
valid_dataset = ImageFolder(valid_data_path, transform=transform)
test_dataset = ImageFolder(test_data_path, transform=transform)

# ImageFolder derives label indices from each split's own sorted sub-folder names.
# A split with a missing or extra class folder would therefore silently shift
# labels, and a class count different from NUM_CLASSES would not fit the model
# head. Fail fast rather than train or evaluate against mismatched labels.
if len(train_dataset.classes) != NUM_CLASSES:
    raise ValueError(
        f"Expected {NUM_CLASSES} classes but found {len(train_dataset.classes)} "
        f"in {train_data_path}: {train_dataset.classes}"
    )
for split_name, split_dataset in (("validation", valid_dataset), ("test", test_dataset)):
    if split_dataset.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            f"Class folders of the {split_name} split {split_dataset.classes} "
            f"do not match the training split {train_dataset.classes}"
        )

# Only the training data is shuffled; evaluation order is kept deterministic.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Optimisation setup: Adam at lr=1e-3 over all model parameters, cross-entropy
# loss for single-label classification, and a plateau scheduler that multiplies
# the learning rate by 0.1 after more than 5 epochs without improvement of the
# monitored (minimised) value. The scheduler only acts when
# `scheduler.step(<metric>)` is called.
optimizer = Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.1,
    patience=5,
    verbose=True,
)

# Define training loop function
def train(model, train_loader, optimizer, criterion, device=None):
    """Run one training epoch over ``train_loader`` and return its metrics.

    Args:
        model: Network to optimise; it is switched to training mode here.
        train_loader: Iterable of ``(images, labels)`` batches (e.g. a DataLoader).
        optimizer: Optimizer stepping ``model``'s parameters once per batch.
        criterion: Loss function; assumed to return the *mean* loss over the
            batch (the ``nn.CrossEntropyLoss`` default).
        device: Device each batch is moved to. Defaults to the device holding
            the model's parameters (CPU for a parameter-less model), so inputs
            always match the model.

    Returns:
        ``(train_loss, train_accuracy)``: the per-sample mean loss over the
        epoch and the accuracy in percent.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_size = labels.size(0)
        # Weight each batch mean by its size so a short final batch does not
        # skew the epoch average.
        loss_sum += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size
    if total == 0:
        raise ValueError("train_loader yielded no samples")
    return loss_sum / total, 100 * correct / total

# Define validation loop function
def validate(model, valid_loader, criterion, device=None):
    """Evaluate ``model`` on ``valid_loader`` without updating it.

    The model is put in eval mode and gradients are disabled for the pass.

    Args:
        model: Network to evaluate.
        valid_loader: Iterable of ``(images, labels)`` batches (e.g. a DataLoader).
        criterion: Loss function; assumed to return the *mean* loss over the
            batch (the ``nn.CrossEntropyLoss`` default).
        device: Device each batch is moved to. Defaults to the device holding
            the model's parameters (CPU for a parameter-less model).

    Returns:
        ``(valid_loss, valid_accuracy)``: the per-sample mean loss and the
        accuracy in percent.

    Raises:
        ValueError: If ``valid_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            # Weight by batch size so the result is a true per-sample mean.
            loss_sum += criterion(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError("valid_loader yielded no samples")
    return loss_sum / total, 100 * correct / total

# Training loop with early stopping on validation loss.
# ReduceLROnPlateau lowers the LR only after *more than* `scheduler.patience`
# epochs without improvement. With an equal early-stopping patience, training
# would stop before the LR cut could ever help, so keep it strictly larger.
best_valid_loss = float('inf')
patience = 2 * scheduler.patience
counter = 0
# Copy of the weights from the epoch with the lowest validation loss. It is
# restored after the loop so that the test/save steps use the best model, not
# the one left over after `patience` non-improving epochs.
best_state = None
for epoch in range(NUM_EPOCHS):
    train_loss, train_accuracy = train(model, train_loader, optimizer, criterion)
    valid_loss, valid_accuracy = validate(model, valid_loader, criterion)
    # The scheduler only adapts the learning rate when stepped with the metric.
    scheduler.step(valid_loss)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        # Clone so later optimizer steps do not mutate the snapshot in place.
        best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
        counter = 0
    else:
        counter += 1

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Valid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.2f}%")

    if counter >= patience:
        print("Early stopping")
        break

# best_state stays None only if no epoch ever beat +inf (e.g. all losses were NaN).
if best_state is not None:
    model.load_state_dict(best_state)

# Test the model
def test(model, test_loader, device=None):
    """Return the accuracy (in percent) of ``model`` on ``test_loader``.

    The model is put in eval mode and gradients are disabled for the pass.

    Args:
        model: Network to evaluate.
        test_loader: Iterable of ``(images, labels)`` batches (e.g. a DataLoader).
        device: Device each batch is moved to. Defaults to the device holding
            the model's parameters (CPU for a parameter-less model).

    Returns:
        Percentage of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        raise ValueError("test_loader yielded no samples")
    return 100 * correct / total


# The name matches pytest's default `test*` pattern; keep pytest from collecting
# this evaluation helper as a test if this code is imported into a test module.
test.__test__ = False

# Final evaluation on the held-out test split.
test_accuracy = test(model, test_loader)
print(f"Test Accuracy: {test_accuracy:.2f}%")

# Persist the whole nn.Module object (not only its state_dict) to the
# timestamped path chosen when the model was created.
torch.save(obj=model, f=model_path)
