In [None]:
import copy
import os

import gdown
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
)
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm

In [None]:
# Fetch the dataset archive from Google Drive. `fuzzy=True` asks gdown to pull
# the file id out of the "view?usp=sharing" page URL. The download is skipped
# when the archive already exists so that Restart & Run All does not re-fetch it.
url = "https://drive.google.com/file/d/1HgJT06Y3QtlejY9cPKWYnCi9wfr0UHId/view?usp=sharing"
output = "/kaggle/working/dataset.zip"
if os.path.exists(output):
    print(f"Archive already present, skipping download: {output}")
else:
    gdown.download(url=url, output=output, fuzzy=True)

In [None]:
from zipfile import ZipFile

# Unpack the downloaded archive into the current working directory.
# NOTE(review): the dataset paths used below (/kaggle/working/datasets/...)
# assume the working directory is /kaggle/working — confirm outside Kaggle.
archive_path = "/kaggle/working/dataset.zip"
with ZipFile(archive_path, mode='r') as archive:
    archive.extractall(path='./')

# Preprocessing Data

In [None]:
# Per-channel normalization statistics (the standard ImageNet mean/std, matching
# the ImageNet-pretrained backbone used by the model).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Training-time augmentation: random crop, affine warp, perspective, flips and
# blur, followed by tensor conversion and normalization.
# NOTE(review): GaussianBlur is not wrapped in RandomApply, so every training
# image is blurred (with a random sigma) — confirm this is intended.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
        transforms.RandomAffine(degrees=30, translate=(0.1, 0.1), shear=15),
        transforms.RandomPerspective(distortion_scale=0.3, p=0.3),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)


def _make_eval_transform():
    """Build the deterministic validation/test pipeline: resize, to-tensor, normalize."""
    return transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )


# Validation and test use the same (non-random) preprocessing.
val_transform = _make_eval_transform()
test_transform = _make_eval_transform()

train_dir = '/kaggle/working/datasets/Images/Train'
val_dir = '/kaggle/working/datasets/Images/Validation'
test_dir = '/kaggle/working/datasets/Images/Test'

# ImageFolder assigns one class per sub-directory of each split.
train_dataset = ImageFolder(root=train_dir, transform=train_transform)
val_dataset = ImageFolder(root=val_dir, transform=val_transform)
test_dataset = ImageFolder(root=test_dir, transform=test_transform)

# Shared loader settings; only the training split is shuffled.
_loader_kwargs = dict(batch_size=32, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [None]:
import torchvision
def denormalize(tensor, mean, std):
    """Undo `transforms.Normalize` channel by channel, in place.

    Channel ``c`` (indexed along the first dimension) becomes
    ``tensor[c] * std[c] + mean[c]``; only as many channels as the shortest of
    ``tensor``, ``mean`` and ``std`` are processed. The input tensor itself is
    modified and returned, so pass a clone if the original must be preserved.
    """
    n_channels = min(len(tensor), len(mean), len(std))
    for c in range(n_channels):
        tensor[c].mul_(std[c]).add_(mean[c])
    return tensor
# Visual sanity check of the augmentation pipeline: show one training batch.
data_iter = iter(train_loader)
images, labels = next(data_iter)
denorm_images = torch.stack([denormalize(img.clone(), mean, std) for img in images])
# Undoing the normalization can leave values marginally outside [0, 1] due to
# floating-point error; clamp so imshow displays them without clipping warnings.
denorm_images = denorm_images.clamp(0.0, 1.0)

# Arrange the batch into a grid and display it.
img_grid = torchvision.utils.make_grid(denorm_images, nrow=8, padding=2, normalize=False)
fig, ax = plt.subplots(figsize=(12, 6))
ax.imshow(img_grid.permute(1, 2, 0).numpy())  # (C, H, W) -> (H, W, C) for matplotlib
ax.set_title("Augmented training batch")
ax.axis('off')
plt.show()

# CNN Model

In [None]:
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gating.

    Each channel is global-average-pooled; the pooled vector goes through a
    bottleneck MLP (``channels -> channels // reduction -> channels``) ending in
    a sigmoid, and every input channel is rescaled by its resulting gate.
    ``forward`` requires a 4-D input ``(batch, channels, H, W)``.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Unpacking four dims rejects inputs that are not (B, C, H, W).
        batch, channels, _height, _width = x.size()
        squeezed = self.pool(x).view(batch, channels)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gates.expand_as(x)

class CrossStageAttention(nn.Module):
    """Multi-head self-attention over the tokens of several feature maps.

    Every input map ``(B, C, H_i, W_i)`` is flattened into ``H_i * W_i`` tokens
    of width ``C``. Tokens from all maps are concatenated, attend to each other
    with ``num_heads`` heads, are mean-pooled over tokens and finally projected
    to ``out_dim`` features.

    Args:
        channels: token width ``C``; every input map must have this many channels.
        num_heads: number of attention heads; must divide ``channels``.
        out_dim: size of the returned feature vector (default 64).

    Raises:
        ValueError: if ``channels`` is not divisible by ``num_heads``.

    ``forward(x_list)`` returns a tensor of shape ``(B, out_dim)``.
    """

    def __init__(self, channels, num_heads=4, out_dim=64):
        super().__init__()
        if channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.qkv = nn.Linear(channels, channels * 3)
        self.proj = nn.Linear(channels, out_dim)
        # Scaled dot-product attention divides by sqrt(d_k), d_k = per-head width.
        self.scale = self.head_dim ** -0.5

    def forward(self, x_list):
        B = x_list[0].shape[0]
        # (B, C, H, W) -> (B, H*W, C); flatten() also handles non-contiguous maps.
        tokens = torch.cat([f.flatten(2).transpose(1, 2) for f in x_list], dim=1)
        N, C = tokens.shape[1], tokens.shape[2]

        # Split the fused projection into q, k, v and each of those into heads:
        # (B, N, C) -> (B, num_heads, N, head_dim).
        q, k, v = (
            t.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
            for t in self.qkv(tokens).chunk(3, dim=-1)
        )
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        # Merge heads back: (B, num_heads, N, head_dim) -> (B, N, C).
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out.mean(dim=1))

class DynamicFeatureReducer(nn.Module):
    """Project a feature map to a fixed channel count and a fixed spatial size.

    A 1x1 convolution maps ``in_channels`` to ``out_channels``, an ``SEBlock``
    re-weights the channels, and adaptive average pooling resizes the map to
    ``output_size``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels after the 1x1 projection (default 128).
        output_size: spatial size produced by the pooling (default (7, 7)).
    """

    def __init__(self, in_channels, out_channels=128, output_size=(7, 7)):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.se = SEBlock(out_channels)
        self.pool = nn.AdaptiveAvgPool2d(output_size)  # reduce spatial size

    def forward(self, x):
        x = self.se(self.conv(x))
        return self.pool(x)

# Main model
class CassavaDiseaseModel(nn.Module):
    """ConvNeXtV2 backbone with a cross-stage attention head for classification.

    Pipeline: the timm backbone (``features_only=True``) returns per-stage
    feature maps. The maps at ``stage_indices`` are each reduced by a
    ``DynamicFeatureReducer``. ``CrossStageAttention`` fuses their tokens into a
    64-d vector, and a small classifier head produces ``num_classes`` logits,
    which are divided by a learnable temperature.

    Args:
        num_classes: number of output logits.
        backbone_name: timm model name of the backbone (defaults to the
            ConvNeXtV2-Base checkpoint used originally).
        pretrained: whether timm should load pretrained backbone weights.
    """

    def __init__(self, num_classes,
                 backbone_name='convnextv2_base.fcmae_ft_in22k_in1k', pretrained=True):
        super().__init__()
        # 1. ConvNeXtV2 backbone that returns intermediate feature maps.
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, features_only=True)
        self.backbone.set_grad_checkpointing(True)  # trade recomputation for lower VRAM use

        # 2. Freeze the lower layers: only the last stage and normalization
        # parameters stay trainable. The features_only wrapper may register the
        # last stage as ``stages_3`` (dots flattened to underscores) rather than
        # ``stages.3``, so both spellings are matched.
        trainable_tags = ('stages.3', 'stages_3', 'norm')
        for name, param in self.backbone.named_parameters():
            param.requires_grad = any(tag in name for tag in trainable_tags)

        self.stage_indices = [2, 3]
        self.feature_channels = [self.backbone.feature_info[i]['num_chs'] for i in self.stage_indices]

        self.reducers = nn.ModuleList([
            DynamicFeatureReducer(c) for c in self.feature_channels
        ])

        self.cross_attention = CrossStageAttention(channels=128)

        # Classifier head. The 0..3 index layout is unchanged so state_dict keys
        # stay the same; batch_first=True lets forward() pass (B, 1, 128).
        self.classifier = nn.Sequential(
            nn.Linear(64, 128),
            nn.GELU(),
            nn.TransformerEncoderLayer(d_model=128, nhead=2, dim_feedforward=128, batch_first=True),
            nn.Linear(128, num_classes)
        )

        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, x):
        features = self.backbone(x)
        reduced = [self.reducers[i](features[idx]) for i, idx in enumerate(self.stage_indices)]
        x = self.cross_attention(reduced)

        proj_in, act, encoder, proj_out = self.classifier
        x = act(proj_in(x))
        # A 2-D (B, 128) tensor would be read by TransformerEncoderLayer as one
        # unbatched sequence of length B, so samples in a batch would attend to
        # each other and predictions would depend on batch composition. Give
        # every sample its own length-1 sequence instead.
        x = encoder(x.unsqueeze(1)).squeeze(1)
        x = proj_out(x)
        return x / self.temperature



# Loss Function

In [None]:
class LabelSmoothedCrossEntropy(nn.Module):
    """Cross-entropy with uniform label smoothing.

    Per sample: ``(1 - smoothing) * NLL + smoothing * mean_c(-log p_c)``.
    Expects logits ``(B, C)`` and integer class targets ``(B,)``.

    Args:
        smoothing: weight of the uniform-distribution term.
        reduction: ``'none'`` (default; per-sample losses of shape ``(B,)``),
            ``'mean'`` or ``'sum'``.

    Raises:
        ValueError: on an unknown ``reduction``.
    """

    def __init__(self, smoothing=0.1, reduction='none'):
        super().__init__()
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}")
        self.smoothing = smoothing
        self.reduction = reduction

    def forward(self, logits, target):
        log_probs = logits.log_softmax(dim=-1)
        nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        smooth_loss = -log_probs.mean(dim=-1)
        loss = (1.0 - self.smoothing) * nll + self.smoothing * smooth_loss
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

class FocalLoss(nn.Module):
    """Multi-class focal loss, mean-reduced over the batch.

    Per sample: ``alpha_y * (1 - p_t) ** gamma * CE``, where ``CE`` is the
    unweighted cross-entropy, ``p_t = exp(-CE)`` is the predicted probability of
    the true class, and ``alpha_y`` is the optional per-class ``weight`` of the
    sample's target (1 when ``weight`` is None).

    Args:
        gamma: focusing parameter; ``gamma=0`` reduces to (weighted) CE per sample.
        weight: optional 1-D tensor of per-class weights. It is stored as a plain
            attribute (not a registered buffer), so ``.to(device)`` does not move
            it — pass it on the same device as the inputs.
    """

    def __init__(self, gamma=2, weight=None):
        super().__init__()
        self.gamma = gamma
        self.weight = weight

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # p_t must come from the *unweighted* CE: exp(-weighted CE) is not the
        # true-class probability once class weights are applied.
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if self.weight is not None:
            focal_loss = focal_loss * self.weight[targets]
        return focal_loss.mean()

# Composite loss: 0.7 * label-smoothed CE + 0.3 * focal loss. The loss modules
# are built once here rather than on every call.
# NOTE(review): the `train_model(...)` call below does not receive this
# function, so training currently uses plain cross-entropy.
_label_smoothed_ce = LabelSmoothedCrossEntropy(smoothing=0.1)
_focal_loss = FocalLoss(gamma=2)


def criterion(x, y):
    """Return the scalar composite loss for logits ``x`` and class indices ``y``.

    ``LabelSmoothedCrossEntropy`` yields per-sample losses, so it is averaged
    here; without the ``.mean()`` the result is a ``(B,)`` tensor and
    ``backward()`` fails.
    """
    return 0.7 * _label_smoothed_ce(x, y).mean() + 0.3 * _focal_loss(x, y)

#  Train

In [None]:
def train_model(model, train_loader, val_loader, num_epochs, device,
                save_path='/kaggle/working/best_model.pth', criterion=None):
    """Train ``model`` with AdamW + cosine LR decay, keeping the best-validation weights.

    Each epoch does one training pass (mixed precision when ``device`` is CUDA)
    and one validation pass. Whenever validation accuracy improves, the
    state_dict is saved to ``save_path``. The best weights are loaded back into
    ``model`` before returning.

    Args:
        model: network to train, already on ``device``.
        train_loader, val_loader: iterables of ``(images, labels)``; labels may
            be class indices or one-hot rows.
        num_epochs: number of epochs (also the cosine schedule's ``T_max``).
        device: ``torch.device`` or device string batches are moved to.
        save_path: file the best state_dict is written to.
        criterion: loss callable ``(logits, labels) -> scalar tensor``;
            defaults to ``nn.CrossEntropyLoss()``.

    Returns:
        ``model`` with the best-validation weights loaded.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    # Frozen parameters never receive gradients; give only trainable ones to AdamW.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=3e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Mixed precision only applies on CUDA; with enabled=False both autocast and
    # the scaler are pass-throughs, so the same loop also runs on CPU.
    use_amp = torch.device(device).type == 'cuda'
    # Fully qualified: the bare name `GradScaler` is not imported in this notebook.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 30)

        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device, use_amp)
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")

        val_loss, val_acc = _validate_one_epoch(model, val_loader, criterion, device)
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

        # T_max=num_epochs: the cosine schedule advances once per epoch.
        scheduler.step()

        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), save_path)
            print(f"✅ Saved Best Model: {save_path} (Val Acc: {best_acc:.4f})")

    model.load_state_dict(best_model_wts)
    return model


def _as_class_indices(labels):
    """Convert one-hot label rows to class indices; pass index labels through unchanged."""
    if labels.dim() > 1 and labels.shape[1] > 1:
        return labels.argmax(dim=1)
    return labels


def _train_one_epoch(model, loader, criterion, optimizer, scaler, device, use_amp):
    """Run one optimization pass over ``loader``; return (mean loss, accuracy)."""
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for images, labels in tqdm(loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)
        labels = _as_class_indices(labels)

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        preds = outputs.argmax(dim=1)
        running_loss += loss.item() * images.size(0)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    if total == 0:
        raise ValueError("train_loader yielded no samples")
    return running_loss / total, correct / total


def _validate_one_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients; return (mean loss, accuracy)."""
    model.eval()
    running_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validation"):
            images, labels = images.to(device), labels.to(device)
            labels = _as_class_indices(labels)

            outputs = model(images)
            loss = criterion(outputs, labels)

            preds = outputs.argmax(dim=1)
            running_loss += loss.item() * images.size(0)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("val_loader yielded no samples")
    return running_loss / total, correct / total

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Take the class count from the training folders rather than hard-coding it,
# so the classifier head always matches the labels ImageFolder produces.
model = CassavaDiseaseModel(num_classes=len(train_dataset.classes)).to(device)

In [None]:
# Train; the best-validation checkpoint is written to `save_path` and its
# weights are loaded back into `model`.
num_epochs = 100
save_path = '/kaggle/working/best_model.pth'
model = train_model(
    model,
    train_loader,
    val_loader,
    num_epochs,
    device,
    save_path=save_path,
)

In [None]:
# Persist the weights `model` holds after training (train_model loads the
# best-validation weights back before returning).
FINAL_MODEL_PATH = '/kaggle/working/final_convnextv2_model.pth'
torch.save(model.state_dict(), FINAL_MODEL_PATH)
print("Final model saved to /kaggle/working/final_convnextv2_model.pth")

# Evaluate: Confusion Matrix and F1 Score

In [None]:
# Model evaluation helper
def evaluate_model(model, test_loader, device=None, class_names=None):
    """Print the macro F1 score of ``model`` on ``test_loader`` and plot a confusion matrix.

    Args:
        model: trained classifier returning per-class logits.
        test_loader: iterable of ``(images, labels)`` batches with integer labels.
        device: where to run inference; defaults to the device of the model's
            first parameter.
        class_names: tick labels for the heatmap; defaults to
            ``test_loader.dataset.classes`` when present, otherwise seaborn's
            automatic labels.

    Returns:
        ``(test_f1, cm)``: the macro-averaged F1 score and the confusion matrix
        (rows = true class, columns = predicted class).
    """
    if device is None:
        device = next(model.parameters()).device
    if class_names is None:
        class_names = getattr(getattr(test_loader, 'dataset', None), 'classes', 'auto')

    model.eval()
    test_preds = []
    test_true = []
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device))
            test_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            test_true.extend(labels.cpu().numpy())

    test_f1 = f1_score(test_true, test_preds, average='macro')
    print(f'Test F1 Score: {test_f1:.4f}')

    # Build and plot the confusion matrix.
    cm = confusion_matrix(test_true, test_preds)
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')
    plt.show()

    return test_f1, cm

In [None]:
# Load the saved final weights and evaluate on the test set.
# map_location lets a checkpoint written from a GPU session load on whatever
# device this session uses.
model.load_state_dict(torch.load("/kaggle/working/final_convnextv2_model.pth", map_location=device))
test_f1, cm = evaluate_model(model, test_loader)

In [None]:
# Collect test-set predictions for overall accuracy and a per-class report.
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(device))
        predicted = outputs.argmax(dim=1)

        y_true.extend(labels.cpu().numpy())
        y_pred.extend(predicted.cpu().numpy())

# Accuracy
acc = accuracy_score(y_true, y_pred)
print(f"\nAccuracy: {acc:.4f}")

# Classification report
print("\n Classification Report:")
print(classification_report(y_true, y_pred, target_names=test_dataset.classes, digits=4))

Compare with https://www.kaggle.com/code/pradiptadatta/cassava-leaf-disease-best-quality:

```text
                                     precision    recall  f1-score   support

     Cassava Bacterial Blight (CBB)       0.71      0.65      0.68       311
     Cassava Brown Streak Disease (CBSD)  0.86      0.82      0.84       726
     Cassava Green Mottle (CGM)           0.83      0.79      0.81       632
     Cassava Mosaic Disease (CMD)         0.95      0.97      0.96      3163
                            Healthy       0.76      0.78      0.77       579

                           accuracy                           0.89      5411
                          macro avg       0.82      0.80      0.81      5411
                       weighted avg       0.89      0.89      0.89      5411
```