In [1]:
!pip install torch torchvision




In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader, Dataset
import json
import os
from PIL import Image

In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None, label_file=None):
        self.root_dir = root_dir
        self.transform = transform
        self.label_file = label_file
        self.image_files = []
        self.labels = []

        # Load label mapping if label_file is provided
        if self.label_file:
            with open(self.label_file, 'r') as f:
                self.class_to_idx = json.load(f)
                print(self.class_to_idx)

        # Create a list of labels based on class_to_idx mapping
        self.idx_to_class = {v: int(k) - 1 for k, v in self.class_to_idx.items()}  # Convert to zero-based index

        # Traverse through subfolders and collect image paths
        for class_name in os.listdir(self.root_dir):
            class_path = os.path.join(self.root_dir, class_name)
            if os.path.isdir(class_path):
                # Convert folder name to class label (1-based to 0-based)
                class_idx = int(class_name) - 1

                for file in os.listdir(class_path):
                    if file.lower().endswith(('.png', '.jpg', '.jpeg')):  # Add more valid image extensions if needed
                        self.image_files.append(os.path.join(class_path, file))
                        if label_file:
                            self.labels.append(class_idx)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        label = self.labels[idx]
        image = Image.open(img_path)

        if self.transform:
            image = self.transform(image)

        return image, label



train_dataset = CustomDataset(
    root_dir="/content/drive/MyDrive/DoanPhuongNam_HuynhHuuThinh/102 Category Flower Classification/dataset/train",
    transform=transform,
    label_file='/content/drive/MyDrive/DoanPhuongNam_HuynhHuuThinh/102 Category Flower Classification/dataset/cat_to_name.json'
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

{'21': 'fire lily', '3': 'canterbury bells', '45': 'bolero deep blue', '1': 'pink primrose', '34': 'mexican aster', '27': 'prince of wales feathers', '7': 'moon orchid', '16': 'globe-flower', '25': 'grape hyacinth', '26': 'corn poppy', '79': 'toad lily', '39': 'siam tulip', '24': 'red ginger', '67': 'spring crocus', '35': 'alpine sea holly', '32': 'garden phlox', '10': 'globe thistle', '6': 'tiger lily', '93': 'ball moss', '33': 'love in the mist', '9': 'monkshood', '102': 'blackberry lily', '14': 'spear thistle', '19': 'balloon flower', '100': 'blanket flower', '13': 'king protea', '49': 'oxeye daisy', '15': 'yellow iris', '61': 'cautleya spicata', '31': 'carnation', '64': 'silverbush', '68': 'bearded iris', '63': 'black-eyed susan', '69': 'windflower', '62': 'japanese anemone', '20': 'giant white arum lily', '38': 'great masterwort', '4': 'sweet pea', '86': 'tree mallow', '101': 'trumpet creeper', '42': 'daffodil', '22': 'pincushion flower', '2': 'hard-leaved pocket orchid', '54': 's

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = models.resnet18(pretrained=True)  # Load ResNet-18 đã được huấn luyện trước
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 102)  # Thay đổi số lớp đầu ra theo số class của bạn


model = model.to(device)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 82.6MB/s]


In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 5
for epoch in range(num_epochs):
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

Epoch [1/5], Loss: 0.6564
Epoch [2/5], Loss: 0.5884
Epoch [3/5], Loss: 0.2305
Epoch [4/5], Loss: 0.7458
Epoch [5/5], Loss: 0.1071


In [7]:
valid_dataset = CustomDataset(root_dir="/content/drive/MyDrive/DoanPhuongNam_HuynhHuuThinh/102 Category Flower Classification/dataset/valid", transform=transform,label_file='/content/drive/MyDrive/DoanPhuongNam_HuynhHuuThinh/102 Category Flower Classification/dataset/cat_to_name.json')
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)

{'21': 'fire lily', '3': 'canterbury bells', '45': 'bolero deep blue', '1': 'pink primrose', '34': 'mexican aster', '27': 'prince of wales feathers', '7': 'moon orchid', '16': 'globe-flower', '25': 'grape hyacinth', '26': 'corn poppy', '79': 'toad lily', '39': 'siam tulip', '24': 'red ginger', '67': 'spring crocus', '35': 'alpine sea holly', '32': 'garden phlox', '10': 'globe thistle', '6': 'tiger lily', '93': 'ball moss', '33': 'love in the mist', '9': 'monkshood', '102': 'blackberry lily', '14': 'spear thistle', '19': 'balloon flower', '100': 'blanket flower', '13': 'king protea', '49': 'oxeye daisy', '15': 'yellow iris', '61': 'cautleya spicata', '31': 'carnation', '64': 'silverbush', '68': 'bearded iris', '63': 'black-eyed susan', '69': 'windflower', '62': 'japanese anemone', '20': 'giant white arum lily', '38': 'great masterwort', '4': 'sweet pea', '86': 'tree mallow', '101': 'trumpet creeper', '42': 'daffodil', '22': 'pincushion flower', '2': 'hard-leaved pocket orchid', '54': 's

In [8]:
import torch
import torch.nn.functional as F

model.eval()  # Set the model to evaluation mode

correct = 0
total = 0
total_loss = 0

# Define the loss function (CrossEntropyLoss for classification)
criterion = torch.nn.CrossEntropyLoss()

# Get the image files and labels from the dataset
image_files = valid_loader.dataset.image_files
labels = valid_loader.dataset.labels  # Assuming labels are provided for evaluation

with torch.no_grad():  # No need to calculate gradients
    for idx, (images, true_labels) in enumerate(valid_loader):  # Now you get both images and labels
        images = images.to(device)
        true_labels = true_labels.to(device)
        outputs = model(images)

        # Calculate loss
        loss = criterion(outputs, true_labels)  # Compare with true labels

        total_loss += loss.item()  # Accumulate the loss

        _, predicted = torch.max(outputs, 1)  # Get the class with the highest probability

        # Track the number of correct predictions and total images
        total += true_labels.size(0)
        correct += (predicted == true_labels).sum().item()  # Compare predicted with true labels

# Calculate accuracy
accuracy = 100 * correct / total if total > 0 else 0
average_loss = total_loss / len(valid_loader) if len(valid_loader) > 0 else 0

# Display results
print(f'Accuracy on valid set: {accuracy:.2f}%')
print(f'Average Loss on valid set: {average_loss:.4f}')


Accuracy on valid set: 84.11%
Average Loss on valid set: 0.6553


In [9]:
class CustomTestDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = []

        for file in os.listdir(root_dir):
            file_path = os.path.join(root_dir, file)
            if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                self.image_files.append(file_path)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        return image  # Only return the image


In [10]:
test_dataset = CustomTestDataset(root_dir="/content/drive/MyDrive/DoanPhuongNam_HuynhHuuThinh/102 Category Flower Classification/dataset/test", transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [11]:
model.eval()  # Set the model to evaluation mode

predictions = []  # To store the predictions and corresponding image file names
image_files = test_loader.dataset.image_files  # Get image paths from the dataset

with torch.no_grad():  # No need to calculate gradients
    for idx, images in enumerate(test_loader):  # Only images, no labels in the test dataset
        images = images.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)  # Get the class with the highest probability

        # Collect predictions with corresponding image file names
        for i in range(len(predicted)):
            image_path = image_files[idx * test_loader.batch_size + i]  # Get the corresponding image path
            predicted_label = predicted[i].item()  # Get the predicted label (class index)
            predictions.append((image_path, predicted_label))

# Display the predictions
for image_path, label in predictions:  # Display first 10 predictions
    print(f'Image: {image_path} - Predicted Label: {label}')


Image: /content/drive/MyDrive/dataset/test/image_03506.jpg - Predicted Label: 29
Image: /content/drive/MyDrive/dataset/test/image_04432.jpg - Predicted Label: 89
Image: /content/drive/MyDrive/dataset/test/image_03414.jpg - Predicted Label: 66
Image: /content/drive/MyDrive/dataset/test/image_07370.jpg - Predicted Label: 93
Image: /content/drive/MyDrive/dataset/test/image_03454.jpg - Predicted Label: 22
Image: /content/drive/MyDrive/dataset/test/image_05592.jpg - Predicted Label: 31
Image: /content/drive/MyDrive/dataset/test/image_03718.jpg - Predicted Label: 52
Image: /content/drive/MyDrive/dataset/test/image_03926.jpg - Predicted Label: 50
Image: /content/drive/MyDrive/dataset/test/image_03906.jpg - Predicted Label: 16
Image: /content/drive/MyDrive/dataset/test/image_01672.jpg - Predicted Label: 81
Image: /content/drive/MyDrive/dataset/test/image_01996.jpg - Predicted Label: 79
Image: /content/drive/MyDrive/dataset/test/image_07018.jpg - Predicted Label: 30
Image: /content/drive/MyDriv