In [None]:
import random

import numpy as np
import pandas as pd
import torch
import torchvision
import torchvision.models as models

from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch import nn, optim
from tqdm import tqdm

from sklearn.metrics import mean_absolute_error, mean_squared_error
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
class Image_(Dataset):
    """Map-style dataset pairing each image with its label.

    Args:
        images: indexable collection of samples; its length defines the
            dataset size.
        labels: indexable collection of targets, read with the same index
            as ``images``.
        transform: callable applied to every image on access; any falsy
            value (e.g. ``None``) disables the transformation.
    """

    def __init__(self, images, labels, transform):
        self.images, self.labels = images, labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (taken from ``images``)."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, transforming the image if configured."""
        sample = self.images[idx]
        target = self.labels[idx]
        # Truthiness check: a falsy transform means "return the raw sample".
        return (self.transform(sample) if self.transform else sample), target

In [None]:
# Per-sample preprocessing: only a conversion to a torch tensor.
# NOTE(review): ToTensor rescales to [0, 1] only for some input dtypes -
# confirm the dtype stored in the .npy files.
transform = transforms.Compose([transforms.ToTensor()])

In [None]:
# Each split is stored as two arrays, one per class (queen / non-queen).
# Paths are relative to the notebook's working directory.
train_queen, train_nonqueen = np.load('queen_train.npy'), np.load('nonqueen_train.npy')
val_queen, val_nonqueen = np.load('queen_val.npy'), np.load('nonqueen_val.npy')
test_queen, test_nonqueen = np.load('queen_test.npy'), np.load('nonqueen_test.npy')

In [None]:
def build_split(queen_images, nonqueen_images):
    """Stack both class arrays and build the matching binary labels.

    Args:
        queen_images: array of positive-class samples (first axis = sample).
        nonqueen_images: array of negative-class samples; must be stackable
            with ``queen_images`` along the first axis.

    Returns:
        tuple: ``(x, y)`` where ``x`` is the vertical stack of queen samples
        followed by non-queen samples, and ``y`` is a float array holding
        ``1.0`` for every queen sample followed by ``0.0`` for every
        non-queen sample.
    """
    x = np.vstack((queen_images, nonqueen_images))
    y = np.concatenate((np.ones(len(queen_images)), np.zeros(len(nonqueen_images))))
    return x, y


# Same construction for every split; label 1 = queen, 0 = non-queen.
x_train, y_train = build_split(train_queen, train_nonqueen)
x_val, y_val = build_split(val_queen, val_nonqueen)
x_test, y_test = build_split(test_queen, test_nonqueen)

In [None]:
# Use the first CUDA device when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
def set_seed(seed):
    """Seed every random number generator this notebook draws from.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), then switches cuDNN to deterministic kernels.
    Seeding only the CUDA generators leaves the CPU generator unseeded.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for repeatable convolution results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed = 42
set_seed(seed)

In [None]:
# --- Datasets and loaders -------------------------------------------------
batch_size = 256

train_set = Image_(x_train, y_train, transform=transform)
val_set = Image_(x_val, y_val, transform=transform)
test_set = Image_(x_test, y_test, transform=transform)

# Only the training loader shuffles; validation/test keep a fixed order.
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

# --- Model ----------------------------------------------------------------
# AlexNet loaded with the ImageNet weights, then adapted at both ends.
model = models.alexnet(weights='IMAGENET1K_V1')

# First convolution swapped for one that accepts a single input channel.
# NOTE(review): this layer is freshly initialised, so the pretrained
# first-layer filters are not reused.
model.features[0] = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=11,
    stride=4,
    padding=2,
    bias=False,
)

# Final classifier layer swapped for a 2-output head.
num_ftrs = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_ftrs, 2)

model = model.to(device)

# --- Loss and optimiser ---------------------------------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.004)

In [None]:
# Training budget and per-epoch history buffers (filled by the training loop).
num_epochs = 300
train_accs, val_accs = [], []
train_losses, val_losses = [], []

In [None]:
# Early stopping: give up after `patience` consecutive epochs without a new
# best validation loss.
patience = 10
best_val_loss = float('inf')
early_stop_counter = 0


def _snapshot_weights(module):
    """Return an independent copy of ``module``'s state dict.

    ``state_dict()`` returns references to the live parameter and buffer
    tensors, so keeping that dict directly would make the "best" snapshot
    follow every later optimizer step. Cloning freezes the values.
    """
    return {name: tensor.detach().clone() for name, tensor in module.state_dict().items()}


best_model_wts = _snapshot_weights(model)

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0.0
    train_corrects = 0

    for train_images, train_labels in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):
        train_images = train_images.to(device)
        # CrossEntropyLoss needs integer class indices; labels arrive as floats.
        train_labels = train_labels.long()
        train_labels = train_labels.to(device)

        optimizer.zero_grad()

        outputs = model(train_images)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, train_labels)

        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean; weight it by the batch size so
        # the epoch figure is a true per-sample average.
        train_loss += loss.item() * train_images.size(0)
        train_corrects += (preds == train_labels).sum().item()

    train_loss = train_loss / len(y_train)
    train_losses.append(train_loss)
    train_acc = train_corrects / len(y_train)
    train_accs.append(train_acc)

    # ---- validation pass ----
    model.eval()
    val_loss = 0.0
    val_corrects = 0
    with torch.no_grad():
        for val_images, val_labels in tqdm(val_dataloader, desc=f"Epoch {epoch+1} Validation"):
            val_images = val_images.to(device)
            val_labels = val_labels.long()
            val_labels = val_labels.to(device)

            val_outputs = model(val_images)
            _, preds = torch.max(val_outputs, 1)
            loss = criterion(val_outputs, val_labels)

            val_loss += loss.item() * val_images.size(0)
            val_corrects += (preds == val_labels).sum().item()

    val_loss = val_loss / len(y_val)
    val_losses.append(val_loss)
    val_acc = val_corrects / len(y_val)
    val_accs.append(val_acc)

    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # ---- early-stopping bookkeeping ----
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_wts = _snapshot_weights(model)
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print("Early stopping triggered.")
            break

# Restore the best weights whether the loop stopped early or used every
# epoch, so later cells evaluate the best checkpoint rather than the last one.
model.load_state_dict(best_model_wts)

In [None]:
def _as_number(value):
    """Return ``value`` as a plain Python number, unwrapping torch tensors."""
    return value.cpu().item() if isinstance(value, torch.Tensor) else value


# History entries may be tensors (possibly on the GPU); matplotlib needs numbers.
train_losses = list(map(_as_number, train_losses))
val_losses = list(map(_as_number, val_losses))
train_accs = list(map(_as_number, train_accs))
val_accs = list(map(_as_number, val_accs))

# Draw the learning curves: loss on the left, accuracy on the right.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(train_losses, label='Train Loss')
loss_ax.plot(val_losses, label='Validation Loss')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.set_title('Train vs Validation Loss')

acc_ax.plot(train_accs, label='Train Accuracy')
acc_ax.plot(val_accs, label='Validation Accuracy')
acc_ax.set_xlabel('Epochs')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()
acc_ax.set_title('Train vs Validation Accuracy')

plt.show()

In [None]:
# Evaluate on the test split using whatever weights `model` currently holds.
model.eval()
test_loss = 0.0
test_corrects = 0

with torch.no_grad():  # evaluation only: no gradients needed
    for images, targets in tqdm(test_dataloader, desc="Testing"):
        images = images.to(device)
        # The criterion expects integer class indices.
        targets = targets.long().to(device)

        logits = model(images)
        preds = logits.argmax(dim=1)
        loss = criterion(logits, targets)

        # Undo the batch-mean reduction so the final figure is per sample.
        test_loss += loss.item() * images.size(0)
        test_corrects += (preds == targets).sum()

test_loss = test_loss / len(y_test)
test_acc = test_corrects.double() / len(y_test)

print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")