In [31]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import os
from tqdm import tqdm

In [None]:
class_dict = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}

In [32]:
# การแปลงข้อมูล
data_transforms = {
    "train": transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    "val": transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
}

In [33]:
# โหลดข้อมูล CIFAR-10
train_dataset = datasets.CIFAR10(
    root="./data", train=True, transform=data_transforms["train"], download=True
)
val_dataset = datasets.CIFAR10(
    root="./data", train=False, transform=data_transforms["val"], download=True
)

Files already downloaded and verified
Files already downloaded and verified


In [34]:
dataloaders = {
    "train": DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4),
    "val": DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4),
}

dataset_sizes = {"train": len(train_dataset), "val": len(val_dataset)}
class_names = train_dataset.classes

In [35]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [36]:
# สร้างโมเดล MobileNetV3 Large
model = models.mobilenet_v3_large(pretrained=True)
num_ftrs = model.classifier[3].in_features
model.classifier[3] = nn.Linear(num_ftrs, len(class_names))



In [42]:
# ตั้งค่า optimizer และ loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
model.to(device)

MobileNetV3(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (2): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bi

In [43]:
def validate_model(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = running_loss / len(val_loader)
    val_accuracy = 100 * correct / total
    return val_loss, val_accuracy

In [44]:
def train_model(
    model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10
):
    train_losses = []
    val_losses = []
    val_accuracies = []
    model.train()

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in tqdm(train_loader, "Training"):
            inputs, labels = inputs.to(device), labels.to(device)

            # ตั้งค่า gradient เป็นศูนย์
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # สะสม loss
            running_loss += loss.item()

        # ทำการ validate หลังจากจบแต่ละ epoch
        val_loss, val_accuracy = validate_model(model, val_loader, criterion, device)

        train_losses.append(running_loss / len(train_loader))
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)
        print(
            f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader)}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}%"
        )

    return train_losses, val_losses, val_accuracies

In [45]:
num_epochs = 5
train_losses, val_losses, val_accuracies = train_model(
    model, dataloaders["train"], dataloaders["val"], criterion, optimizer, device, num_epochs
)

Training: 100%|██████████| 1563/1563 [10:03<00:00,  2.59it/s]


Epoch [1/5], Loss: 0.47999694525852316, Val Loss: 0.38884497695742326, Val Accuracy: 87.28%


Training:  80%|████████  | 1254/1563 [08:41<02:08,  2.41it/s]


KeyboardInterrupt: 

In [None]:
# ทดสอบโมเดลกับข้อมูล validation
model.eval()
y_true = []
y_pred = []

for inputs, labels in dataloaders["val"]:
    outputs = model(inputs)
    _, preds = torch.max(outputs, 1)
    y_true.extend(labels.tolist())
    y_pred.extend(preds.tolist())

In [None]:
# คำนวณ F1 score, Precision, Recall และ Confusion Matrix
f1 = f1_score(y_true, y_pred, average="weighted")
precision = precision_score(y_true, y_pred, average="weighted")
recall = recall_score(y_true, y_pred, average="weighted")
cm = confusion_matrix(y_true, y_pred)

print(f"F1 Score: {f1:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"Confusion Matrix:\n{cm}")

In [None]:
# แสดงผล Confusion Matrix
plt.figure(figsize=(8, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.show()