In [7]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# Thiết lập thiết bị
device = "cuda" if torch.cuda.is_available() else "cpu"

# Đường dẫn tới thư mục dữ liệu
# data_folder = os.path.expanduser('~/data/dogs-vs-cats')
# train_folder = os.path.join(data_folder, 'train')
# test_folder = os.path.join(data_folder, 'test1')

# Các biến đổi cho dữ liệu
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

# Lớp dataset tùy chỉnh cho Dogs vs Cats
class DogsVSCatsDataset(Dataset):
    def __init__(self, folder, transform=None):
        self.folder = folder
        self.transform = transform
        self.images = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]
        self.labels = [1 if 'dog' in img else 0 for img in self.images]

    def __getitem__(self, idx):
        img_name = os.path.join(self.folder, self.images[idx])
        image = Image.open(img_name)
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        return image.to(device), torch.tensor(label, dtype=torch.long).to(device)

    def __len__(self):
        return len(self.images)

# Tạo datasets
train_dataset = DogsVSCatsDataset('dogs-vs-cats/train', transform=transform)
test_dataset = DogsVSCatsDataset('dogs-vs-cats/test1', transform=transform)

# Tạo DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Mô hình neural network
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 32 * 32, 512)
        self.fc2 = nn.Linear(512, 2)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 64 * 32 * 32)
        x = self.relu(self.fc1(x))
        x = self.softmax(self.fc2(x))
        return x

# Khởi tạo mô hình, hàm mất mát và tối ưu hóa
model = SimpleCNN().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=0.01)

# Hàm huấn luyện một batch
def train_batch(x, y, model, opt, loss_fn):
    model.train()
    prediction = model(x)
    batch_loss = loss_fn(prediction, y)
    batch_loss.backward()
    opt.step()
    opt.zero_grad()
    return batch_loss.item()

# Hàm tính độ chính xác
@torch.no_grad()
def accuracy(x, y, model):
    model.eval()
    prediction = model(x)
    max_values, argmaxes = prediction.max(-1)
    is_correct = argmaxes == y
    return is_correct.cpu().numpy().tolist()

# Huấn luyện mô hình
losses, accuracies = [], []
epochs = 250
for epoch in range(epochs):
    print(f'Epoch {epoch + 1}/{epochs}')
    epoch_losses, epoch_accuracies = [], []
    for ix, batch in enumerate(train_loader):
        x, y = batch
        batch_loss = train_batch(x, y, model, optimizer, loss_fn)
        epoch_losses.append(batch_loss)
    epoch_loss = np.mean(epoch_losses)
    for ix, batch in enumerate(train_loader):
        x, y = batch
        is_correct = accuracy(x, y, model)
        epoch_accuracies.extend(is_correct)
    epoch_accuracy = np.mean(epoch_accuracies)
    losses.append(epoch_loss)
    accuracies.append(epoch_accuracy)
    print(f'Loss: {epoch_loss}, Accuracy: {epoch_accuracy}')

# Vẽ biểu đồ
epochs_range = np.arange(1, epochs + 1)
plt.figure(figsize=(20, 5))
plt.subplot(121)
plt.title('Loss value over increasing epochs')
plt.plot(epochs_range, losses, label='Training Loss')
plt.legend()
plt.subplot(122)
plt.title('Accuracy value over increasing epochs')
plt.plot(epochs_range, accuracies, label='Training Accuracy')
plt.gca().set_yticklabels(['{:.0f}%'.format(x*100) for x in plt.gca().get_yticks()])
plt.legend()
plt.show()

Epoch 1/250
Loss: 0.3324408435728401, Accuracy: 1.0
Epoch 2/250
Loss: 0.31390142044983804, Accuracy: 1.0
Epoch 3/250
Loss: 0.3135893309954554, Accuracy: 1.0
Epoch 4/250
Loss: 0.31347793713212013, Accuracy: 1.0
Epoch 5/250
Loss: 0.3134214822202921, Accuracy: 1.0
Epoch 6/250
Loss: 0.31338755576871336, Accuracy: 1.0
Epoch 7/250
Loss: 0.3133653262630105, Accuracy: 1.0
Epoch 8/250
Loss: 0.3133493287023157, Accuracy: 1.0
Epoch 9/250
Loss: 0.3133374445606023, Accuracy: 1.0
Epoch 10/250
Loss: 0.31332837929949164, Accuracy: 1.0
Epoch 11/250
Loss: 0.3133211003150791, Accuracy: 1.0
Epoch 12/250
Loss: 0.31331534031778574, Accuracy: 1.0
Epoch 13/250
Loss: 0.31331063946709037, Accuracy: 1.0
Epoch 14/250
Loss: 0.3133063092827797, Accuracy: 1.0
Epoch 15/250
Loss: 0.31330288108438253, Accuracy: 1.0
Epoch 16/250
Loss: 0.3132997883949429, Accuracy: 1.0
Epoch 17/250
Loss: 0.3132972097955644, Accuracy: 1.0
Epoch 18/250
Loss: 0.3132949275895953, Accuracy: 1.0
Epoch 19/250
Loss: 0.31329293525777757, Accuracy