In [3]:
import os
import hashlib
from collections import defaultdict

# Thư mục dataset
# data_root = "/kaggle/working/data"
data_root = "/kaggle/input/full-data-snake/kaggle/working/data"

def file_hash(path):
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            h.update(chunk)
    return h.hexdigest()

# 1. Tạo hash cho train
train_hashes = {}
for root, dirs, files in os.walk(os.path.join(data_root, "train")):
    for f in files:
        if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")):
            path = os.path.join(root, f)
            h = file_hash(path)
            train_hashes[h] = path  # chỉ giữ 1 path đại diện trong train

# 2. So sánh val/test với train
leak_dict = defaultdict(list)
for split in ["val", "test"]:
    split_dir = os.path.join(data_root, split)
    for root, dirs, files in os.walk(split_dir):
        for f in files:
            if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")):
                path = os.path.join(root, f)
                h = file_hash(path)
                if h in train_hashes:  # nếu hash trùng với ảnh train
                    leak_dict[h].append((train_hashes[h], path))

# 3. Báo cáo leak
print(f"🔎 Tổng số nhóm leak (val/test vs train): {len(leak_dict)}\n")

for h, pairs in leak_dict.items():
    print(f"--- Hash: {h} ---")
    for p in pairs:
        print(f"Train: {p[0]}")
        print(f"Leak : {p[1]}")
        print()

# 4. Xoá ảnh leak trong val/test
delete_count = 0
for h, pairs in leak_dict.items():
    for train_path, leak_path in pairs:
        try:
            os.remove(leak_path)
            delete_count += 1
            print(f"🗑️ Deleted leak: {leak_path}")
        except Exception as e:
            print(f"❌ Error deleting {leak_path}: {e}")

print(f"\n✅ Đã xoá {delete_count} ảnh leak khỏi val/test.\n")

# 5. Check lại sau khi xoá
leak_dict_check = defaultdict(list)
for split in ["val", "test"]:
    split_dir = os.path.join(data_root, split)
    for root, dirs, files in os.walk(split_dir):
        for f in files:
            if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp")):
                path = os.path.join(root, f)
                h = file_hash(path)
                if h in train_hashes:
                    leak_dict_check[h].append((train_hashes[h], path))

print(f"🔎 Sau khi xoá, số nhóm leak còn lại: {len(leak_dict_check)}")


🔎 Tổng số nhóm leak (val/test vs train): 0


✅ Đã xoá 0 ảnh leak khỏi val/test.

🔎 Sau khi xoá, số nhóm leak còn lại: 0


In [5]:
import os

# data_dir = "/kaggle/working/data/train"
data_dir = "/kaggle/input/full-data-snake/kaggle/working/data/train"

folders = [f for f in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, f))]
print(f"Số thư mục trong '{data_dir}': {len(folders)}")
print("Danh sách thư mục:", folders)


Số thư mục trong '/kaggle/input/full-data-snake/kaggle/working/data/train': 124
Danh sách thư mục: ['Elaphe_taeniura', 'Oreocryptophis_porphyraceus', 'Gonyosoma_oxycephalum', 'Hebius_khasiensis', 'Cerberus_schneiderii', 'Ptyas_dhumnades', 'Laticauda_colubrina', 'Oligodon_albocinctus', 'Indotyphlops_braminus', 'Calloselasma_rhodostoma', 'Pareas_carinatus', 'Boiga_multomaculata', 'Acrochordus_granulatus', 'Calliophis_maculiceps', 'Bungarus_candidus', 'Sinomicrurus_annularis', 'Gonyosoma_boulengeri', 'Trimerodytes_percarinatus', 'Boiga_kraepelini', 'Ahaetulla_rufusoculara', 'Subsessor_bocourti', 'Ovophis_monticola', 'Myrrophis_chinensis', 'Cylindrophis_jodiae', 'Bungarus_fasciatus', 'Chrysopelea_ornata', 'Pareas_berdmorei', 'Oligodon_cinereus', 'Deinagkistrodon_acutus', 'Blue-lipped_sea_krait', 'Lycodon_futsingensis', 'Protobothrops_maolanensis', 'Naja_siamensis', 'Lycodon_rufozonatus', 'Coelognathus_radiatus', 'Ovophis_tonkinensis', 'Dendrelaphis_subocularis', 'Plagiopholis_nuchalis', 'B

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import time
import copy
import os
from tqdm import tqdm 

# =========================
# 1. Config
# =========================
# data_dir = "/kaggle/working/data"   # Cấu trúc thư mục: data/train, data/val
data_dir = "/kaggle/input/full-data-snake/kaggle/working/data"  # Cấu trúc thư mục: data/train, data/val
num_classes = 124
batch_size = 32
num_epochs = 25
lr = 1e-4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =========================
# 2. Data Augmentation
# =========================
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])


image_datasets = {
    'train': datasets.ImageFolder(os.path.join(data_dir, 'train'), train_transform),
    'val': datasets.ImageFolder(os.path.join(data_dir, 'val'), val_transform)
}

dataloaders = {
    x: DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4)
    for x in ['train', 'val']
}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

In [7]:
# =========================
# 3. Model
# =========================
# model = models.swin_t(weights="IMAGENET1K_V1")  # pretrained Swin-Tiny
# model.head = nn.Linear(model.head.in_features, num_classes)
# model = model.to(device)

weights = models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1
model = models.convnext_tiny(weights=weights)
# Sửa lớp phân loại
model.classifier[2] = nn.Linear(model.classifier[2].in_features, num_classes)
model = model.to(device)

# =========================
# 4. Loss, Optimizer, Scheduler
# =========================
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# =========================
# 5. Training Loop
# =========================
def train_model(model, criterion, optimizer, scheduler, num_epochs=10):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    # Thêm history để vẽ sau này
    history = {
        "train_loss": [],
        "val_loss": [],
        "train_acc": [],
        "val_acc": []
    }
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 20)
        
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            
            running_loss = 0.0
            running_corrects = 0
            
            # for inputs, labels in dataloaders[phase]:
            for inputs, labels in tqdm(dataloaders[phase], desc=f"{phase} Epoch {epoch+1}"):
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            
            if phase == 'train':
                scheduler.step()
            
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double().item() / dataset_sizes[phase]
            
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            
            # Lưu history
            history[f"{phase}_loss"].append(epoch_loss)
            history[f"{phase}_acc"].append(epoch_acc)
            
            # Save best
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        
        # ================================
        # 🔹 Lưu model sau mỗi 5 epoch
        # ================================
        if (epoch + 1) % 5 == 0:
            save_path = f"/kaggle/working/convnext_tiny_epoch{epoch+1}.pth"
            torch.save(model.state_dict(), save_path)
            print(f"✅ Saved model at {save_path}")
        
        print()
    
    print(f'Best val Acc: {best_acc:.4f}')
    model.load_state_dict(best_model_wts)
    return model, history

# =========================
# 6. Run Training
# =========================
model, history = train_model(model, criterion, optimizer, scheduler, num_epochs=num_epochs)

# Save final best model
torch.save(model.state_dict(), "/kaggle/working/convnext_tiny_best.pth")



Downloading: "https://download.pytorch.org/models/convnext_tiny-983f1562.pth" to /root/.cache/torch/hub/checkpoints/convnext_tiny-983f1562.pth
100%|██████████| 109M/109M [00:00<00:00, 207MB/s] 


Epoch 1/25
--------------------


train Epoch 1: 100%|██████████| 676/676 [05:38<00:00,  2.00it/s]


train Loss: 2.6553 Acc: 0.4953


val Epoch 1: 100%|██████████| 143/143 [00:20<00:00,  7.05it/s]


val Loss: 1.6359 Acc: 0.7492

Epoch 2/25
--------------------


train Epoch 2: 100%|██████████| 676/676 [05:44<00:00,  1.96it/s]


train Loss: 1.7071 Acc: 0.7362


val Epoch 2: 100%|██████████| 143/143 [00:20<00:00,  7.07it/s]


val Loss: 1.3729 Acc: 0.8214

Epoch 3/25
--------------------


train Epoch 3: 100%|██████████| 676/676 [05:45<00:00,  1.96it/s]


train Loss: 1.4797 Acc: 0.8028


val Epoch 3: 100%|██████████| 143/143 [00:20<00:00,  7.06it/s]


val Loss: 1.3149 Acc: 0.8425

Epoch 4/25
--------------------


train Epoch 4: 100%|██████████| 676/676 [05:45<00:00,  1.96it/s]


train Loss: 1.3750 Acc: 0.8331


val Epoch 4: 100%|██████████| 143/143 [00:20<00:00,  7.11it/s]


val Loss: 1.2309 Acc: 0.8699

Epoch 5/25
--------------------


train Epoch 5: 100%|██████████| 676/676 [05:44<00:00,  1.96it/s]


train Loss: 1.2735 Acc: 0.8666


val Epoch 5: 100%|██████████| 143/143 [00:20<00:00,  7.07it/s]


val Loss: 1.2228 Acc: 0.8692
✅ Saved model at /kaggle/working/convnext_tiny_epoch5.pth

Epoch 6/25
--------------------


train Epoch 6: 100%|██████████| 676/676 [05:45<00:00,  1.96it/s]


train Loss: 1.2243 Acc: 0.8822


val Epoch 6: 100%|██████████| 143/143 [00:20<00:00,  7.07it/s]


val Loss: 1.1662 Acc: 0.8892

Epoch 7/25
--------------------


train Epoch 7: 100%|██████████| 676/676 [05:45<00:00,  1.96it/s]


train Loss: 1.1821 Acc: 0.8939


val Epoch 7: 100%|██████████| 143/143 [00:20<00:00,  7.05it/s]


val Loss: 1.1431 Acc: 0.8982

Epoch 8/25
--------------------


train Epoch 8: 100%|██████████| 676/676 [05:45<00:00,  1.96it/s]


train Loss: 1.1507 Acc: 0.9048


val Epoch 8: 100%|██████████| 143/143 [00:20<00:00,  7.06it/s]


val Loss: 1.1492 Acc: 0.8927

Epoch 9/25
--------------------


train Epoch 9: 100%|██████████| 676/676 [05:45<00:00,  1.96it/s]


train Loss: 1.1253 Acc: 0.9113


val Epoch 9: 100%|██████████| 143/143 [00:20<00:00,  7.07it/s]


val Loss: 1.1429 Acc: 0.8958

Epoch 10/25
--------------------


train Epoch 10: 100%|██████████| 676/676 [05:45<00:00,  1.96it/s]


train Loss: 1.0940 Acc: 0.9227


val Epoch 10: 100%|██████████| 143/143 [00:20<00:00,  7.11it/s]


val Loss: 1.1230 Acc: 0.8995
✅ Saved model at /kaggle/working/convnext_tiny_epoch10.pth

Epoch 11/25
--------------------


train Epoch 11: 100%|██████████| 676/676 [05:44<00:00,  1.96it/s]


train Loss: 1.0721 Acc: 0.9273


val Epoch 11: 100%|██████████| 143/143 [00:20<00:00,  7.03it/s]


val Loss: 1.1260 Acc: 0.9028

Epoch 12/25
--------------------


train Epoch 12: 100%|██████████| 676/676 [05:45<00:00,  1.95it/s]


train Loss: 1.0495 Acc: 0.9342


val Epoch 12: 100%|██████████| 143/143 [00:20<00:00,  7.04it/s]


val Loss: 1.1038 Acc: 0.9068

Epoch 13/25
--------------------


train Epoch 13: 100%|██████████| 676/676 [05:44<00:00,  1.96it/s]


train Loss: 1.0298 Acc: 0.9405


val Epoch 13: 100%|██████████| 143/143 [00:20<00:00,  7.06it/s]


val Loss: 1.1030 Acc: 0.9072

Epoch 14/25
--------------------


train Epoch 14: 100%|██████████| 676/676 [05:45<00:00,  1.96it/s]


train Loss: 1.0215 Acc: 0.9430


val Epoch 14: 100%|██████████| 143/143 [00:20<00:00,  7.04it/s]


val Loss: 1.0967 Acc: 0.9087

Epoch 15/25
--------------------


train Epoch 15: 100%|██████████| 676/676 [05:44<00:00,  1.96it/s]


train Loss: 1.0092 Acc: 0.9464


val Epoch 15: 100%|██████████| 143/143 [00:20<00:00,  7.06it/s]


val Loss: 1.1075 Acc: 0.9074
✅ Saved model at /kaggle/working/convnext_tiny_epoch15.pth

Epoch 16/25
--------------------


train Epoch 16: 100%|██████████| 676/676 [05:44<00:00,  1.96it/s]


train Loss: 1.0032 Acc: 0.9469


val Epoch 16: 100%|██████████| 143/143 [00:20<00:00,  7.06it/s]


val Loss: 1.0842 Acc: 0.9151

Epoch 17/25
--------------------


train Epoch 17: 100%|██████████| 676/676 [05:45<00:00,  1.96it/s]


train Loss: 0.9902 Acc: 0.9517


val Epoch 17: 100%|██████████| 143/143 [00:20<00:00,  7.05it/s]


val Loss: 1.0725 Acc: 0.9212

Epoch 18/25
--------------------


train Epoch 18: 100%|██████████| 676/676 [05:44<00:00,  1.96it/s]


train Loss: 0.9827 Acc: 0.9542


val Epoch 18: 100%|██████████| 143/143 [00:20<00:00,  7.04it/s]


val Loss: 1.0644 Acc: 0.9217

Epoch 19/25
--------------------


train Epoch 19: 100%|██████████| 676/676 [05:45<00:00,  1.96it/s]


train Loss: 0.9740 Acc: 0.9567


val Epoch 19: 100%|██████████| 143/143 [00:20<00:00,  7.05it/s]


val Loss: 1.0644 Acc: 0.9247

Epoch 20/25
--------------------


train Epoch 20: 100%|██████████| 676/676 [05:45<00:00,  1.95it/s]


train Loss: 0.9596 Acc: 0.9599


val Epoch 20: 100%|██████████| 143/143 [00:20<00:00,  7.07it/s]


val Loss: 1.0610 Acc: 0.9252
✅ Saved model at /kaggle/working/convnext_tiny_epoch20.pth

Epoch 21/25
--------------------


train Epoch 21: 100%|██████████| 676/676 [05:44<00:00,  1.96it/s]


train Loss: 0.9594 Acc: 0.9601


val Epoch 21: 100%|██████████| 143/143 [00:20<00:00,  7.04it/s]


val Loss: 1.0589 Acc: 0.9267

Epoch 22/25
--------------------


train Epoch 22: 100%|██████████| 676/676 [05:44<00:00,  1.96it/s]


train Loss: 0.9585 Acc: 0.9599


val Epoch 22: 100%|██████████| 143/143 [00:20<00:00,  7.05it/s]


val Loss: 1.0520 Acc: 0.9289

Epoch 23/25
--------------------


train Epoch 23: 100%|██████████| 676/676 [05:45<00:00,  1.96it/s]


train Loss: 0.9554 Acc: 0.9602


val Epoch 23: 100%|██████████| 143/143 [00:20<00:00,  7.03it/s]


val Loss: 1.0479 Acc: 0.9298

Epoch 24/25
--------------------


train Epoch 24: 100%|██████████| 676/676 [05:45<00:00,  1.96it/s]


train Loss: 0.9485 Acc: 0.9632


val Epoch 24: 100%|██████████| 143/143 [00:20<00:00,  7.00it/s]


val Loss: 1.0486 Acc: 0.9285

Epoch 25/25
--------------------


train Epoch 25: 100%|██████████| 676/676 [05:45<00:00,  1.96it/s]


train Loss: 0.9458 Acc: 0.9632


val Epoch 25: 100%|██████████| 143/143 [00:20<00:00,  7.05it/s]


val Loss: 1.0482 Acc: 0.9291
✅ Saved model at /kaggle/working/convnext_tiny_epoch25.pth

Best val Acc: 0.9298


In [8]:
!zip -j /kaggle/working/convnext_tiny_best.zip /kaggle/working/convnext_tiny_best.pth


  adding: convnext_tiny_best.pth (deflated 7%)


In [7]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import numpy as np
from torchvision import models
import torch.nn as nn
import torch


# =========================
# 7. Evaluation on Test Set
# =========================
# Load lại test dataset
test_dataset = datasets.ImageFolder(os.path.join(data_dir, 'test'), val_transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Lấy mapping index -> class name
class_names = test_dataset.classes
num_classes = len(class_names)

weights = None
model = models.convnext_tiny(weights=weights)
model.classifier[2] = nn.Linear(model.classifier[2].in_features, num_classes)
model = model.to(device)
# Load best model đã save
model.load_state_dict(torch.load("/kaggle/input/model-best/convnext_tiny_best.pth"))
model.eval()

all_labels = []
all_preds = []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())

# =========================
# 8. Tính accuracy
# =========================
overall_acc = accuracy_score(all_labels, all_preds)
print(f"\nTest Accuracy: {overall_acc:.4f}")

# =========================
# 9. Accuracy từng class
# =========================
all_labels = np.array(all_labels)
all_preds = np.array(all_preds)

print("\nAccuracy theo từng class:")
for i, class_name in enumerate(class_names):
    idx = (all_labels == i)
    if np.sum(idx) == 0:
        acc = 0.0
    else:
        acc = np.mean(all_preds[idx] == all_labels[idx])
    correct = np.sum(all_preds[idx] == all_labels[idx])
    total = np.sum(idx)
    print(f"{class_name:25s}  Acc: {acc:.4f}  ({correct}/{total})")

# =========================
# 10. Classification Report (optional)
# =========================
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, target_names=class_names, digits=4))



Test Accuracy: 0.9213

Accuracy theo từng class:
Achalinus_rufescens        Acc: 1.0000  (13/13)
Achalinus_spinalis         Acc: 0.9286  (13/14)
Acrochordus_granulatus     Acc: 0.9535  (41/43)
Acrochordus_javanicus      Acc: 0.8750  (7/8)
Ahaetulla_fusca            Acc: 0.8148  (22/27)
Ahaetulla_prasina          Acc: 0.9500  (19/20)
Ahaetulla_rufusoculara     Acc: 1.0000  (9/9)
Amphiesma_stolatum         Acc: 0.9474  (72/76)
Blue-lipped_sea_krait      Acc: 0.9756  (80/82)
Boiga_cyanea               Acc: 0.9747  (77/79)
Boiga_guangxiensis         Acc: 0.6250  (5/8)
Boiga_jaspidea             Acc: 1.0000  (41/41)
Boiga_kraepelini           Acc: 0.9104  (61/67)
Boiga_multomaculata        Acc: 0.9800  (98/100)
Boiga_siamensis            Acc: 0.9355  (29/31)
Bungarus_candidus          Acc: 0.9296  (66/71)
Bungarus_fasciatus         Acc: 1.0000  (86/86)
Calamaria_septentrionalis  Acc: 0.8182  (18/22)
Calliophis_maculiceps      Acc: 0.9444  (17/18)
Calloselasma_rhodostoma    Acc: 0.9510  (97

In [None]:
import matplotlib.pyplot as plt

epochs = range(1, num_epochs+1)

plt.figure(figsize=(12,5))

# Loss
plt.subplot(1,2,1)
plt.plot(epochs, history["train_loss"], 'o-', label="Train Loss")
plt.plot(epochs, history["val_loss"], 's-', label="Val Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Train vs Val Loss")
plt.legend()
plt.grid(True)

# Accuracy
plt.subplot(1,2,2)
plt.plot(epochs, history["train_acc"], 'o-', label="Train Acc")
plt.plot(epochs, history["val_acc"], 's-', label="Val Acc")
plt.axhline(y=overall_acc, color='r', linestyle='--', label=f"Test Acc = {overall_acc:.4f}")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Train vs Val Accuracy")
plt.legend()
plt.grid(True)

plt.show()


In [8]:
print("\n⚠️ Các class dưới 70% accuracy:")
for i, class_name in enumerate(class_names):
    idx = (all_labels == i)
    if np.sum(idx) == 0:
        acc = 0.0
        correct, total = 0, 0
    else:
        acc = np.mean(all_preds[idx] == all_labels[idx])
        correct = np.sum(all_preds[idx] == all_labels[idx])
        total = np.sum(idx)

    if acc < 0.7:  # dưới 70%
        print(f"{class_name:25s}  Acc: {acc:.4f}  ({correct}/{total})")



⚠️ Các class dưới 70% accuracy:
Boiga_guangxiensis         Acc: 0.6250  (5/8)
Erpeton_tentaculatum       Acc: 0.4444  (4/9)
Hebius_khasiensis          Acc: 0.6000  (3/5)
Lycodon_chapaensis         Acc: 0.5000  (5/10)
Lycodon_truongi            Acc: 0.2500  (2/8)
Naja_fuxi                  Acc: 0.6000  (6/10)
Oligodon_cyclurus          Acc: 0.6000  (6/10)
Pareas_berdmorei           Acc: 0.6429  (18/28)
Pareas_hamptoni            Acc: 0.4286  (3/7)
Pareas_monticola           Acc: 0.4000  (2/5)
Ptyas_mucosa               Acc: 0.6667  (10/15)
Ptyas_nigromarginata       Acc: 0.6667  (4/6)
Sinomicrurus_peinani       Acc: 0.3750  (3/8)
Trimeresurus_albolabris    Acc: 0.6875  (11/16)


In [9]:
import torch
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO
from torchvision.models import convnext_tiny
import torch.nn as nn
# =====================
# Load model
# =====================
num_classes = 124
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# convnext_tiny

model = convnext_tiny(weights=None)  # không load pretrained vì đã có weight
model.classifier[2] = nn.Linear(model.classifier[2].in_features, num_classes)
model.load_state_dict(torch.load("/kaggle/input/model-best/convnext_tiny_best.pth", map_location=device))
model = model.to(device)
model.eval()

# =====================
# Transform cho ảnh test
# =====================
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # resize giống input khi train
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# =====================
# Predict từ link ảnh
# =====================
def predict_from_url(url, class_names):
    response = requests.get(url)
    img = Image.open(BytesIO(response.content)).convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)  # thêm batch dimension

    with torch.no_grad():
        outputs = model(img_tensor)
        probs = torch.softmax(outputs, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
    
    print(f"Predicted class: {class_names[pred_idx]} (prob = {probs[0][pred_idx]:.4f})")
    return class_names[pred_idx]

# =====================
# Ví dụ chạy thử
# =====================
# Lấy tên class từ ImageFolder
class_names = image_datasets['train'].classes  

url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSFuI1z0kcbSqY4aipR3sB0NoiGcbfhirWKoc4cSn7hCxP1pdY70jOmFenB74wUD6ZSYFRjzd38FCzkeogaji8JvEjQeiL_4-KeTbKVe1E"
predict_from_url(url, class_names)


Predicted class: Lycodon_chapaensis (prob = 0.9232)


'Lycodon_chapaensis'