# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import os
import pickle
import matplotlib.pyplot as plt
import copy
import numpy as np

# Dataset roots (one ImageFolder tree per split).
TRAIN_DIR = '/kaggle/input/mp-data/MP_Data/train'
VAL_DIR = '/kaggle/input/mp-data/MP_Data/val'

# Artifacts written by train_model(): best weights and pickled class names.
MODEL_PATH = 'model_mobilenet_stable.pth'
LABEL_PATH = 'label_map.pkl'

# Images are resized to IMG_SIZE x IMG_SIZE before entering the network.
IMG_SIZE = 224

# Optimisation hyperparameters.
BATCH_SIZE = 256
LEARNING_RATE = 3e-4
EPOCHS = 20
PATIENCE = 5  # epochs without val-loss improvement before early stopping

# Compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch. Without one it is put in eval mode and gradients are
    disabled. Batches are moved to the module-level ``device``.
    """
    is_train = optimizer is not None
    model.train(is_train)

    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(is_train):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if is_train:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # With the default 'mean' reduction the loss is a batch average;
            # re-weight it so the epoch value is a per-sample mean even when
            # the last batch is smaller.
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size

    return loss_sum / seen, 100 * correct / seen


def _plot_history(history):
    """Plot train/val loss and accuracy side by side and save to 'training_result.png'."""
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.title('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['val_acc'], label='Val Acc')
    plt.title('Accuracy')
    plt.legend()

    plt.savefig('training_result.png')
    plt.show()


def train_model():
    """Fine-tune an ImageNet-pretrained MobileNetV2 on grayscale ImageFolder data.

    Reads TRAIN_DIR / VAL_DIR. Writes the training class names to LABEL_PATH
    (pickle), the weights with the lowest validation loss to MODEL_PATH, and
    the loss/accuracy curves to 'training_result.png'. Training stops early
    after PATIENCE epochs without validation-loss improvement.

    Returns:
        The per-epoch history dict (keys 'train_loss', 'train_acc',
        'val_loss', 'val_acc'). Returns None if a data folder is missing or
        the train/val class folders do not match.
    """
    print(f"Đang khởi động trên thiết bị: {device}")

    # Check that both data folders exist.
    if not os.path.exists(TRAIN_DIR) or not os.path.exists(VAL_DIR):
        print("LỖI: Không tìm thấy đường dẫn!")
        print(f"Train: {TRAIN_DIR}")
        print(f"Val: {VAL_DIR}")
        return

    # Data augmentation (train only); both pipelines produce 1-channel
    # tensors normalised to roughly [-1, 1].
    train_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomRotation(15),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    val_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    # Load each split from its own folder.
    print("Đang load dữ liệu từ folder...")
    train_dataset = datasets.ImageFolder(root=TRAIN_DIR, transform=train_transform)
    val_dataset = datasets.ImageFolder(root=VAL_DIR, transform=val_transform)

    class_names = train_dataset.classes
    print(f"Tìm thấy {len(class_names)} lớp: {class_names}")
    print(f"Số lượng ảnh - Train: {len(train_dataset)} | Val: {len(val_dataset)}")

    # ImageFolder numbers classes by the sorted folder names under each root
    # independently. If val is missing (or adds) a class folder, every label
    # after it shifts and validation silently scores against wrong targets.
    if val_dataset.class_to_idx != train_dataset.class_to_idx:
        print("LỖI: Danh sách lớp của train và val không khớp!")
        print(f"Train: {train_dataset.classes}")
        print(f"Val: {val_dataset.classes}")
        return

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Persist the index -> class-name mapping for the evaluation step.
    with open(LABEL_PATH, 'wb') as f:
        pickle.dump(class_names, f)

    class MobileNetSignLanguage(nn.Module):
        """MobileNetV2 with a 1-channel stem and a num_classes-way classifier."""

        def __init__(self, num_classes):
            super().__init__()
            self.model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)

            # Swap the 3-channel stem for a 1-channel one but keep its
            # pretrained filters. Summing the RGB kernels gives exactly the
            # response the original layer has to a gray image copied into
            # all three channels, instead of restarting this layer from
            # random init.
            rgb_stem = self.model.features[0][0]
            gray_stem = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
            with torch.no_grad():
                gray_stem.weight.copy_(rgb_stem.weight.sum(dim=1, keepdim=True))
            self.model.features[0][0] = gray_stem

            self.model.classifier = nn.Sequential(
                nn.Dropout(p=0.3),
                nn.Linear(1280, num_classes)
            )

        def forward(self, x):
            return self.model(x)

    model = MobileNetSignLanguage(len(class_names))
    model.to(device)

    # Kaggle multi-GPU: replicate across all visible GPUs.
    if torch.cuda.device_count() > 1:
        print(f"Đã kích hoạt chế độ 2 GPU ({torch.cuda.device_count()} cards)!")
        model = nn.DataParallel(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    # Halve the LR after 2 epochs without val-loss improvement.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

    min_val_loss = np.inf
    patience_counter = 0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    print("Bắt đầu Train...")

    for epoch in range(EPOCHS):
        epoch_loss, epoch_acc = _run_epoch(model, train_loader, criterion, optimizer)
        epoch_val_loss, epoch_val_acc = _run_epoch(model, val_loader, criterion)

        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)
        history['val_loss'].append(epoch_val_loss)
        history['val_acc'].append(epoch_val_acc)

        scheduler.step(epoch_val_loss)
        curr_lr = optimizer.param_groups[0]['lr']

        print(f"Epoch [{epoch + 1}/{EPOCHS}] LR:{curr_lr:.6f} | "
              f"Train Loss:{epoch_loss:.4f} Acc:{epoch_acc:.1f}% | "
              f"Val Loss:{epoch_val_loss:.4f} Acc:{epoch_val_acc:.1f}%")

        if epoch_val_loss < min_val_loss:
            min_val_loss = epoch_val_loss
            patience_counter = 0

            # Save the unwrapped module so checkpoint keys carry no
            # DataParallel 'module.' prefix.
            to_save = model.module if isinstance(model, nn.DataParallel) else model
            torch.save(to_save.state_dict(), MODEL_PATH)

            print("Đã lưu model tốt nhất!")
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print("Early Stopping: Model không cải thiện nữa.")
                break

    _plot_history(history)
    return history

if __name__ == "__main__":
    # Train when this file/cell is executed directly, not when it is imported.
    train_model()

# Test lại model (re-evaluate the saved model)

# In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix
from collections import OrderedDict

# Configuration for the evaluation cell.
# NOTE(review): TEST_DIR is a different Kaggle dataset from the training
# cell's VAL_DIR; confirm this is the intended held-out split.
TEST_DIR = '/kaggle/input/bw-trainval/BW/Val'

# Artifacts produced by the training cell.
MODEL_PATH = 'model_mobilenet_stable.pth'
LABEL_PATH = 'label_map.pkl'

IMG_SIZE = 224   # must match the size used during training
BATCH_SIZE = 64

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def evaluate_test_set():
    """Evaluate the saved checkpoint on the ImageFolder data in TEST_DIR.

    Loads class names from LABEL_PATH and weights from MODEL_PATH, prints the
    overall accuracy, and draws a confusion matrix via plot_confusion_matrix.
    Test labels are mapped into the training index space by class name.
    Prints an error and returns None early if a required file is missing or
    TEST_DIR contains a class the model was not trained on.
    """
    print(f"Đang chạy kiểm tra trên thiết bị: {DEVICE}")

    # Load the label map and make sure the checkpoint exists.
    if not os.path.exists(LABEL_PATH):
        print("Lỗi: Không tìm thấy file nhãn (label_map.pkl)")
        return
    if not os.path.exists(MODEL_PATH):
        print(f"Lỗi: Không tìm thấy file model ({MODEL_PATH})")
        return

    # NOTE: pickle.load can execute arbitrary code. Only load a label map
    # produced by the training cell, never one from an untrusted source.
    with open(LABEL_PATH, 'rb') as f:
        class_names = pickle.load(f)
    print(f"Class Names: {class_names}")

    # Same deterministic preprocessing as the training cell's val split.
    test_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    test_dataset = datasets.ImageFolder(root=TEST_DIR, transform=test_transform)

    # ImageFolder numbers classes by the sorted folder names under TEST_DIR,
    # which only equals the training numbering when the folder sets are
    # identical. Translate each test label to the training index of the
    # same class name so accuracy and the confusion matrix stay aligned.
    train_index = {name: i for i, name in enumerate(class_names)}
    unknown = [c for c in test_dataset.classes if c not in train_index]
    if unknown:
        print(f"Lỗi: Thư mục test có lớp không có trong dữ liệu train: {unknown}")
        return
    label_remap = [train_index[c] for c in test_dataset.classes]
    test_dataset.target_transform = label_remap.__getitem__

    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Must mirror the architecture built in the training cell. Weights come
    # from the checkpoint, so no pretrained download is needed.
    class MobileNetSignLanguage(nn.Module):
        def __init__(self, num_classes):
            super().__init__()
            self.model = models.mobilenet_v2(weights=None)
            self.model.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
            self.model.classifier = nn.Sequential(
                nn.Dropout(p=0.3),
                nn.Linear(1280, num_classes)
            )

        def forward(self, x):
            return self.model(x)

    model = MobileNetSignLanguage(len(class_names))

    # Strip a leading DataParallel 'module.' prefix (and only that prefix)
    # in case the checkpoint was saved from a wrapped model.
    prefix = "module."
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
    cleaned = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    model.load_state_dict(cleaned)
    model.to(DEVICE)
    model.eval()

    all_preds = []
    all_labels = []

    print("Đang tính toán Confusion Matrix...")
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(DEVICE))
            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            # Labels are not needed on the device; keep them on the CPU.
            all_labels.extend(labels.numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    acc = np.mean(all_preds == all_labels) * 100
    print(f"\nĐỘ CHÍNH XÁC: {acc:.2f}%")

    plot_confusion_matrix(all_labels, all_preds, class_names)

def plot_confusion_matrix(y_true, y_pred, classes):
    """Draw a labelled heat-map of the confusion matrix and return the matrix.

    Args:
        y_true: true class indices in 0 .. len(classes) - 1.
        y_pred: predicted class indices in the same index space.
        classes: class names; classes[i] labels index i on both axes.

    Returns:
        The len(classes) x len(classes) count matrix (rows: true label,
        columns: predicted label).
    """
    # Pin the label set. Without `labels`, sklearn keeps only classes that
    # occur in y_true or y_pred, so a class absent from the test split
    # would shrink the matrix and misalign it with the tick labels below.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))

    plt.figure(figsize=(20, 15))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes)

    plt.title('Confusion Matrix (Ma trận nhầm lẫn)', fontsize=20)
    plt.ylabel('Nhãn thực tế (True Label)', fontsize=15)
    plt.xlabel('Nhãn dự đoán (Predicted Label)', fontsize=15)
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.show()
    return cm

if __name__ == "__main__":
    # Evaluate when this file/cell is executed directly, not when it is imported.
    evaluate_test_set()