In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os

In [14]:
class ModifiedCNN(nn.Module):
    """Three-stage convolutional classifier with a global-average-pooled head.

    Each stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2)``
    with 32, 64 and 128 output channels respectively. The resulting feature map
    is averaged down to 1x1 per channel and passed through three fully
    connected layers, with dropout after the first two.

    Args:
        num_classes: Number of output logits produced by the final layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Stage 1: 3 -> 32 channels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=32)

        # Stage 2: 32 -> 64 channels.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=64)

        # Stage 3: 64 -> 128 channels.
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=128)

        # Collapses every channel to a single value, whatever the spatial size.
        self.global_avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

        # Classifier head: three fully connected layers are kept.
        self.fc1 = nn.Linear(in_features=128, out_features=256)
        self.dropout1 = nn.Dropout(p=0.7)  # Heavy dropout to reduce overfitting
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.dropout2 = nn.Dropout(p=0.7)  # Heavy dropout to reduce overfitting
        self.fc3 = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) class logits for a batch of images.

        Args:
            x: Image batch laid out as ``(batch, 3, height, width)``; the
                channel count is fixed by ``conv1``.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        conv_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in conv_stages:
            activated = nn.functional.relu(norm(conv(x)))
            # Each stage halves the spatial resolution.
            x = nn.functional.max_pool2d(activated, kernel_size=2)

        pooled = self.global_avg_pool(x)
        features = pooled.reshape(pooled.size(0), -1)

        hidden = self.dropout1(nn.functional.relu(self.fc1(features)))
        hidden = self.dropout2(nn.functional.relu(self.fc2(hidden)))
        return self.fc3(hidden)

In [16]:
def _build_transforms():
    """Build the ``(train, eval)`` torchvision pipelines for 32x32 inputs.

    Geometric and colour augmentations are applied to the decoded image
    *before* ``ToTensor``/``Normalize``: colour jitter on float tensors assumes
    values in [0, 1] and clamps to that range, so running it on already
    normalised tensors corrupts the image. ``RandomErasing`` only accepts
    tensors, so it stays after normalisation.

    Returns:
        Tuple ``(transform_train, transform_test)``.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    transform_train = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),  # Rotation range reduced from 20 to 10 degrees
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Mild colour noise
        transforms.RandomResizedCrop(size=32, scale=(0.8, 1.0)),  # Mild crop-size variation
        transforms.RandomAffine(degrees=5, translate=(0.1, 0.1)),  # Mild shift and rotation
        # Blur is applied to only half of the samples.
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.2, 1.0))], p=0.5),
        transforms.ToTensor(),
        normalize,
        transforms.RandomErasing(p=0.4),  # Reduced random-erasing frequency
    ])

    # Evaluation uses the same resize + normalisation, with no randomness.
    transform_test = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        normalize,
    ])
    return transform_train, transform_test


def _evaluate(model, loader, criterion, device):
    """Return ``(mean batch loss, accuracy fraction)`` of ``model`` on ``loader``.

    The model is switched to eval mode and gradients are disabled. The loss is
    averaged over batches and the accuracy over ``len(loader.dataset)``.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            total_loss += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    return total_loss / len(loader), correct / len(loader.dataset)


def train_model(data_path):
    """Train ``ModifiedCNN`` on an image-folder dataset and plot the curves.

    The images under ``data_path`` are split 80/20 into training and
    validation subsets. The training subset is augmented; the validation
    subset only gets the deterministic resize + normalisation so that
    validation loss (which drives the LR scheduler, checkpointing and early
    stopping) is not perturbed by random augmentation. The weights with the
    lowest validation loss are written to ``./checkpoints/best_model.pth``.
    Training stops early after ``patience`` consecutive epochs without a
    strict decrease in validation loss.

    Args:
        data_path: Root directory in the layout expected by
            ``torchvision.datasets.ImageFolder`` (one sub-directory per class).

    Returns:
        None. Progress is printed and the loss/accuracy curves are shown.

    Raises:
        Exception: If training fails for ``data_path`` the error is printed and
            a single retry is made with the built-in fallback directory. If
            ``data_path`` already is that fallback, the error is re-raised
            instead of recursing forever.
    """
    # NOTE(review): hard-coded, machine-specific fallback kept for backward
    # compatibility; consider turning it into a parameter or config value.
    fallback_path = 'C:/Users/admin/Desktop/video/anhsinhvien'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size = 32
    num_epochs = 100
    learning_rate = 0.00002
    weight_decay = 1e-3
    patience = 3  # Early stopping

    transform_train, transform_test = _build_transforms()

    try:
        # Two views over the same files so that each split gets its own
        # transform; ImageFolder indexes files deterministically, so one index
        # permutation is valid for both views.
        train_source = datasets.ImageFolder(data_path, transform=transform_train)
        val_source = datasets.ImageFolder(data_path, transform=transform_test)

        num_samples = len(train_source)
        train_size = int(0.8 * num_samples)
        if train_size == 0 or train_size == num_samples:
            raise ValueError(
                f"Dataset at {data_path!r} has {num_samples} image(s); "
                "too few for an 80/20 train/validation split."
            )

        shuffled = torch.randperm(num_samples).tolist()
        train_dataset = torch.utils.data.Subset(train_source, shuffled[:train_size])
        val_dataset = torch.utils.data.Subset(val_source, shuffled[train_size:])

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        print(f"Train dataset size: {len(train_dataset)}")
        print(f"Validation dataset size: {len(val_dataset)}")

        num_classes = len(train_source.classes)
        model = ModifiedCNN(num_classes).to(device)
        print(f"Model architecture:\n{model}")

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        # NOTE(review): the scheduler patience equals the early-stopping
        # patience, so the LR is unlikely to be reduced before training
        # stops -- confirm this is intended.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3)

        best_val_loss = float('inf')
        epochs_without_improvement = 0

        train_losses = []
        val_losses = []
        val_accuracies = []

        # Directory to save model checkpoints
        checkpoint_dir = './checkpoints'
        os.makedirs(checkpoint_dir, exist_ok=True)

        for epoch in range(num_epochs):
            print(f"Starting epoch {epoch + 1}/{num_epochs}")
            model.train()
            epoch_loss = 0.0
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)

                optimizer.zero_grad()
                loss = criterion(model(images), labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
                optimizer.step()

                epoch_loss += loss.item()

            train_loss = epoch_loss / len(train_loader)
            val_loss, val_accuracy = _evaluate(model, val_loader, criterion, device)

            train_losses.append(train_loss)
            val_losses.append(val_loss)
            val_accuracies.append(val_accuracy * 100)

            print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy * 100:.2f}%')

            scheduler.step(val_loss)

            # Save the best model checkpoint
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'best_model.pth'))
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

            if epochs_without_improvement >= patience:
                print("Early stopping triggered due to lack of improvement.")
                break

        # Accuracy of the last evaluated epoch (not necessarily the checkpointed one).
        print(f'Final Validation Accuracy: {val_accuracy * 100:.2f}%')

        fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

        ax_loss.plot(train_losses, label='Training Loss')
        ax_loss.plot(val_losses, label='Validation Loss')
        ax_loss.set_title('Loss vs Epochs')
        ax_loss.set_xlabel('Epoch')
        ax_loss.set_ylabel('Loss')
        ax_loss.legend()

        ax_acc.plot(val_accuracies, label='Validation Accuracy')
        ax_acc.set_title('Validation Accuracy vs Epochs')
        ax_acc.set_xlabel('Epoch')
        ax_acc.set_ylabel('Accuracy (%)')
        ax_acc.legend()

        plt.show()

    except Exception as e:
        print(f"An error occurred: {e}")
        if data_path == fallback_path:
            # The fallback itself failed: surface the error rather than
            # recursing until the interpreter hits its recursion limit.
            raise
        train_model(fallback_path)
