# In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from torch.optim import SGD
# from cnn_network import CNN  # Not needed: the CNN class is defined directly in this file
from tqdm import tqdm
from multiprocessing import freeze_support

# --- MLP Model Definition (commented out as per request) ---
# class MLP(nn.Module):
#     def __init__(self):
#         super(MLP, self).__init__()
#         self.fc1 = nn.Linear(32*32*3, 1024)
#         self.bn1 = nn.BatchNorm1d(1024)
#         self.fc2 = nn.Linear(1024, 512)
#         self.bn2 = nn.BatchNorm1d(512)
#         self.fc3 = nn.Linear(512, 10)
#         self.dropout = nn.Dropout(0.5)
#         self.relu = nn.ReLU()
#     def forward(self, x):
#         x = x.view(x.size(0), -1)
#         x = self.fc1(x)
#         x = self.bn1(x)
#         x = self.relu(x)
#         x = self.dropout(x)
#         x = self.fc2(x)
#         x = self.bn2(x)
#         x = self.relu(x)
#         x = self.dropout(x)
#         x = self.fc3(x)
#         return x

# --- CNN Model Definition ---
class CNN(nn.Module):
    """Three-stage convolutional classifier with a two-layer fully connected head.

    Each stage is a 3x3 convolution (padding 1, so the spatial size is kept),
    a ReLU, and 2x2 max-pooling (spatial size halved). ``fc1`` expects
    128 * 4 * 4 features, so the network is built for 3-channel 32x32 inputs
    (e.g. CIFAR-10 images) and produces 10 logits per sample.
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)    # 3 -> 32 channels
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)   # 32 -> 64 channels
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)  # 64 -> 128 channels
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)
        self.dropout = nn.Dropout(0.2)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return class logits of shape (batch, 10) for a (batch, 3, 32, 32) input."""
        x = self.pool(self.relu(self.conv1(x)))  # -> (batch, 32, 16, 16)
        x = self.pool(self.relu(self.conv2(x)))  # -> (batch, 64, 8, 8)
        x = self.pool(self.relu(self.conv3(x)))  # -> (batch, 128, 4, 4)
        # Flatten per sample, keeping the batch dimension, so an input of the
        # wrong spatial size fails loudly in fc1 instead of being silently
        # regrouped into a different number of "samples".
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Transformations: ToTensor scales pixel values to [0, 1]; Normalize with
# mean 0.5 and std 0.5 per channel then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load CIFAR-10 dataset (downloaded into ./data if not already present)
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Hold out 10% of the training images for validation. A fixed-seed generator
# makes the split identical on every run, so results are reproducible and a
# separately run script using the same seed (presumably the MLP variant)
# validates on the same images.
split_seed = 42
val_size = int(0.1 * len(train_set))
train_size = len(train_set) - val_size
train_set, val_set = torch.utils.data.random_split(
    train_set,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

print(f"Training set size: {len(train_set)}")
print(f"Validation set size: {len(val_set)}")
print(f"Test set size: {len(test_set)}")

# Data loaders.
# NOTE: the datasets and loaders are built at import time. When this file is
# the main script and worker processes are spawned (e.g. on Windows), each
# worker re-imports it and re-runs this section; only the __main__ block below
# is protected from that.
batch_size = 128  # keep equal to the MLP script's batch size when comparing the two
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

# Class names, indexed by CIFAR-10 label id
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# --- evaluate_model function (common for both) ---
def evaluate_model(model, loader, criterion, device=None):
    """Compute the average loss and accuracy of ``model`` over ``loader``.

    The model runs in ``eval()`` mode without gradient tracking. Its previous
    train/eval mode is restored before returning, so this is safe to call in
    the middle of a training epoch.

    Args:
        model: the network to evaluate.
        loader: iterable of ``(inputs, labels)`` batches, e.g. a DataLoader.
        criterion: loss function. Assumed to return the batch *mean* (the
            default ``reduction='mean'``), because it is re-weighted by batch
            size to get an exact per-sample average.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``(loss, acc)``: per-sample mean loss and accuracy in percent. Both are
        ``nan`` if ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                output = model(images)
                n_batch = labels.size(0)
                # Weight by batch size so a smaller final batch does not skew
                # the average (criterion returns the batch mean).
                loss_sum += criterion(output, labels).item() * n_batch
                _, predicted = torch.max(output, 1)
                correct += (predicted == labels).sum().item()
                total += n_batch
    finally:
        # Restore the caller's mode (e.g. training mode mid-epoch).
        model.train(was_training)

    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, 100 * correct / total

# --- plot_confusion_matrix function (common for both) ---
def plot_confusion_matrix(all_labels, all_preds, classes, model_name="Model"):
    """Draw, save and show a confusion-matrix heatmap.

    The figure is saved to ``{model_name}_confusion_matrix.png`` and closed
    after being shown.

    Args:
        all_labels: true labels, assumed to be integer class indices into
            ``classes`` (0 .. len(classes) - 1).
        all_preds: predicted labels, in the same index space.
        classes: class names used as tick labels, in index order.
        model_name: used in the plot title and the output file name.
    """
    # Pin the label set so the matrix is always len(classes) x len(classes)
    # and its rows/columns line up with the tick labels, even if some class
    # never appears in either the true or the predicted labels.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(classes))))
    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.title(f'Confusion Matrix for {model_name}')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig(f'{model_name}_confusion_matrix.png')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# --- train_model function (common for both) ---
def train_model(model, train_loader, val_loader, num_epochs=15, model_name='Model',
                *, lr=0.001, momentum=0.9, record_metrics_every_batches=10):
    """Train ``model`` with SGD and cross-entropy, then plot its learning curves.

    Every ``record_metrics_every_batches`` batches, the current batch's loss
    and accuracy are recorded and the full validation set is evaluated. These
    points are plotted against the number of training examples seen, and the
    plot is saved to ``{model_name}_learning_curves_loss_acc.png``.

    Args:
        model: network to train; assumed to already be on the global ``device``.
        train_loader: DataLoader of training ``(images, labels)`` batches.
        val_loader: DataLoader used for the periodic validation passes.
        num_epochs: number of passes over ``train_loader``.
        model_name: label for progress output, plot titles and file names.
        lr: SGD learning rate.
        momentum: SGD momentum.
        record_metrics_every_batches: recording/validation interval in
            batches. Each validation pass covers the whole ``val_loader``, so
            small values make training noticeably slower.

    Returns:
        The trained model (the same object that was passed in).

    Raises:
        ValueError: if ``record_metrics_every_batches`` is less than 1.
    """
    if record_metrics_every_batches < 1:
        raise ValueError("record_metrics_every_batches must be >= 1")

    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)

    # Learning-curve samples. Train and validation metrics are recorded at the
    # same steps, so one x-axis list (training examples seen) serves both.
    seen_counts = []
    train_losses = []  # loss of the single batch at each recording step
    train_accs = []    # accuracy (%) of that same batch
    val_losses = []    # mean loss over the full validation set
    val_accs = []      # accuracy (%) over the full validation set
    examples_seen = 0  # training samples processed so far, across epochs

    print(f"\n--- Starting Training for {model_name} ---")
    for epoch in range(num_epochs):  # 0-indexed; printed as epoch + 1
        model.train()
        epoch_loss_sum = 0.0  # batch-mean losses weighted by batch size
        epoch_correct = 0
        epoch_total = 0

        progress = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [{model_name}]')
        for batch_idx, (images, labels) in enumerate(progress):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss_val = criterion(outputs, labels)
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            n_batch = labels.size(0)
            batch_loss = loss_val.item()
            _, predicted = torch.max(outputs.detach(), 1)
            batch_correct = (predicted == labels).sum().item()

            epoch_loss_sum += batch_loss * n_batch
            epoch_correct += batch_correct
            epoch_total += n_batch
            # Count samples actually processed (rather than batch_idx times a
            # global batch size) so the x-axis is right for any loader.
            examples_seen += n_batch

            if batch_idx % record_metrics_every_batches == 0:
                seen_counts.append(examples_seen)
                train_losses.append(batch_loss)
                train_accs.append(100 * batch_correct / n_batch)

                val_loss, val_acc = evaluate_model(model, val_loader, criterion)
                val_losses.append(val_loss)
                val_accs.append(val_acc)
                # evaluate_model calls model.eval(); re-enter training mode
                # explicitly so dropout stays active for the rest of the epoch.
                model.train()

        if epoch_total:
            epoch_train_loss = epoch_loss_sum / epoch_total
            epoch_train_acc = 100 * epoch_correct / epoch_total
        else:
            epoch_train_loss = epoch_train_acc = float('nan')

        # Report the most recent periodic validation result, i.e. from this
        # epoch's last recording step (which may precede its final batch).
        final_epoch_val_loss = val_losses[-1] if val_losses else float('nan')
        final_epoch_val_acc = val_accs[-1] if val_accs else float('nan')

        print(f'Epoch {epoch+1}: Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.2f}%, Val Loss: {final_epoch_val_loss:.4f}, Val Acc: {final_epoch_val_acc:.2f}%')

    print(f"--- Finished Training for {model_name} ---")

    _plot_learning_curves(model_name, seen_counts, train_losses, train_accs, val_losses, val_accs)
    return model


def _plot_learning_curves(model_name, seen_counts, train_losses, train_accs, val_losses, val_accs):
    """Plot train-vs-validation loss and accuracy side by side, then save and show the figure."""
    plt.figure(figsize=(12, 6))

    # Left: loss curves
    plt.subplot(1, 2, 1)
    plt.plot(seen_counts, train_losses, color='blue', label='Training Loss')
    plt.plot(seen_counts, val_losses, color='red', label='Validation Loss')
    plt.legend(loc='upper right')
    plt.xlabel('Number of training examples seen')
    plt.ylabel('Loss (Cross-Entropy Loss)')
    plt.title(f'Learning Curve: {model_name} (Train vs Validation Loss)')
    plt.grid(True)

    # Right: accuracy curves
    plt.subplot(1, 2, 2)
    plt.plot(seen_counts, train_accs, color='blue', label='Training Accuracy')
    plt.plot(seen_counts, val_accs, color='red', label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.xlabel('Number of training examples seen')
    plt.ylabel('Accuracy (%)')
    plt.title(f'Learning Curve: {model_name} (Train vs Validation Accuracy)')
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(f'{model_name}_learning_curves_loss_acc.png')
    plt.show()

# --- Main execution block for CNN ---
if __name__ == '__main__':
    # freeze_support() only has an effect in a frozen Windows executable. It is
    # this __main__ guard that stops spawned DataLoader worker processes from
    # re-running the training below.
    freeze_support()

    cnn_model = CNN().to(device)  # Initialize CNN model

    num_epochs = 15  # Number of epochs for CNN training

    # --- Train CNN and plot its custom learning curve ---
    print("\nTraining CNN and plotting its custom learning curve...")
    cnn_model = train_model(cnn_model, train_loader, val_loader, num_epochs=num_epochs, model_name='CNN')

    # --- Final evaluation on the test set, collecting predictions for the confusion matrix ---
    final_criterion = nn.CrossEntropyLoss()

    print("\nFinal Evaluation for CNN on test set...")
    all_cnn_preds_test = []
    all_cnn_labels_test = []
    cnn_model.eval()
    cnn_test_loss_sum = 0.0
    cnn_test_correct = 0
    cnn_test_total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = cnn_model(images)
            n_batch = labels.size(0)
            # CrossEntropyLoss returns the batch mean; weight it by batch size
            # so a smaller final batch does not skew the per-sample average.
            cnn_test_loss_sum += final_criterion(outputs, labels).item() * n_batch
            _, predicted = torch.max(outputs, 1)
            cnn_test_correct += (predicted == labels).sum().item()
            cnn_test_total += n_batch
            all_cnn_preds_test.extend(predicted.cpu().numpy())
            all_cnn_labels_test.extend(labels.cpu().numpy())
    cnn_final_test_loss = cnn_test_loss_sum / cnn_test_total
    cnn_final_test_acc = 100 * cnn_test_correct / cnn_test_total
    print(f"CNN Test Loss: {cnn_final_test_loss:.4f}, Test Accuracy: {cnn_final_test_acc:.2f}%")
    plot_confusion_matrix(all_cnn_labels_test, all_cnn_preds_test, classes, model_name='CNN_Test')

    # --- Summary and discussion prompts ---
    print("\n--- Model Comparison (for CNN only) ---")
    print(f"CNN Final Test Accuracy: {cnn_final_test_acc:.2f}%")
    print("\nDiscussion points (Analyze after running this CNN code):")
    print("1. Compare the final test accuracies of MLP and CNN (requires running train_mlp.py separately).")
    print("2. Analyze the learning curves (Train vs Validation Loss AND Accuracy) for this CNN model. Look for signs of overfitting or underfitting.")
    print("3. Examine the confusion matrix for this CNN model to understand class-wise performance.")
    print("4. Discuss why CNN generally performs better than MLP for image classification tasks on CIFAR-10.")