# Import Libraries

In [None]:
import timm
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
import numpy as np
import seaborn as sns
import shutil
import matplotlib.pyplot as plt
import os 
import gdown
from tqdm import tqdm
import copy
from sklearn.metrics import confusion_matrix
import seaborn as sns
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score

In [None]:
# Fetch the dataset archive from Google Drive into the Kaggle working directory.
# url1 is a share-page link (".../view?usp=sharing"), presumably why fuzzy=True is passed.
url1 = "https://drive.google.com/file/d/19u7gzEMboCQHjU6umzsrym5ax9O2JvhM/view?usp=sharing"
output1 = "/kaggle/working/dataset.zip"
gdown.download(output=output1, url=url1, fuzzy=True)

In [None]:
from zipfile import ZipFile

# Extract into /kaggle/working explicitly (not the cwd-relative './') so the
# absolute dataset paths used later (/kaggle/working/datasets_and_pseudo_data/...)
# resolve regardless of the current working directory.
with ZipFile("/kaggle/working/dataset.zip", 'r') as zObject:
    zObject.extractall(path='/kaggle/working')

# Data preprocessing

In [None]:
# ImageNet channel statistics used by Normalize (and to undo it for display).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

IMG_SIZE = (448, 448)  # input resolution of the *_448 ViT backbone used below
BATCH_SIZE = 32
NUM_WORKERS = 4

# Training-time augmentation. GaussianBlur is not wrapped in RandomApply, so every
# training image is blurred (sigma drawn from the (0.1, 2.0) range).
train_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.RandomAffine(degrees=30, translate=(0.1, 0.1), shear=15),
    transforms.RandomPerspective(distortion_scale=0.3, p=0.3),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Transform for validation and test: deterministic resize + normalize, no
# augmentation. Both splits share the single pipeline.
val_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
test_transform = val_transform

DATA_ROOT = '/kaggle/working/datasets_and_pseudo_data'
train_dir = os.path.join(DATA_ROOT, 'Train')
val_dir = os.path.join(DATA_ROOT, 'Validation')
test_dir = os.path.join(DATA_ROOT, 'Test')

train_dataset = ImageFolder(root=train_dir, transform=train_transform)
val_dataset = ImageFolder(root=val_dir, transform=val_transform)
test_dataset = ImageFolder(root=test_dir, transform=test_transform)

# Shared loader settings; pin host memory only when a CUDA device is present.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns

def count_images_per_class(dataset, dataset_name):
    """Print and return the number of images in each class of ``dataset``.

    Args:
        dataset: an ``ImageFolder``-like object exposing ``classes`` (class
            names indexed by class id) and ``targets`` (class id per sample).
        dataset_name: label used in the printed header.

    Returns:
        ``(counts, class_names)``: ``counts`` maps every class name, in
        class-id order, to its image count (0 for classes with no images);
        ``class_names`` is ``dataset.classes``.
    """
    class_names = dataset.classes
    tally = Counter(dataset.targets)
    # Iterate the class ids rather than the Counter, so empty classes are
    # reported as 0 and the order always follows dataset.classes.
    counts = {name: tally.get(idx, 0) for idx, name in enumerate(class_names)}

    print(f"\n{dataset_name} Class Distribution:")
    for class_name, count in counts.items():
        print(f"Class {class_name}: {count} images")

    return counts, class_names

# Tally each split once; the (counts, class_names) pairs feed the plotting cell below.
(
    (train_counts, train_class_names),
    (val_counts, val_class_names),
    (test_counts, test_class_names),
) = [
    count_images_per_class(split_dataset, split_label)
    for split_dataset, split_label in (
        (train_dataset, "Train"),
        (val_dataset, "Validation"),
        (test_dataset, "Test"),
    )
]

In [None]:
# Plot class
def plot_class_distribution(counts, class_names, dataset_name):
    """Bar-plot image counts per class and save the figure under /kaggle/working.

    Args:
        counts: mapping class name -> image count.
        class_names: class names defining the bar order; a name missing from
            ``counts`` is drawn with height 0.
        dataset_name: used in the title and, lower-cased, in the file name
            ``/kaggle/working/<name>_class_distribution.png``.
    """
    # class_names was previously unused; it now fixes the x-axis order and
    # ensures classes with no images still appear.
    heights = [counts.get(name, 0) for name in class_names]

    plt.figure(figsize=(10, 6))
    sns.barplot(x=list(class_names), y=heights)
    plt.xticks(rotation=45)
    plt.xlabel('Class')
    plt.ylabel('Number of Images')
    plt.title(f'{dataset_name} Class Distribution')
    plt.tight_layout()
    plt.savefig(f'/kaggle/working/{dataset_name.lower()}_class_distribution.png')
    plt.show()

# One bar chart per split, in Train / Validation / Test order.
for split_counts, split_class_names, split_label in (
    (train_counts, train_class_names, "Train"),
    (val_counts, val_class_names, "Validation"),
    (test_counts, test_class_names, "Test"),
):
    plot_class_distribution(split_counts, split_class_names, split_label)

In [None]:
import torchvision
def denormalize(tensor, mean, std):
    """Undo ``transforms.Normalize(mean, std)`` in place: ``x * std + mean`` per channel.

    Accepts a single image ``(C, H, W)`` or a batch ``(N, C, H, W)``; the
    channel axis is the third from the end. Returns the same, modified tensor.
    """
    # Shape (C, 1, 1) broadcasts over H, W and any leading batch dimension, so
    # batches are handled per channel instead of pairing images with channel stats.
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    return tensor.mul_(std_t).add_(mean_t)
# Preview one augmented training batch with the normalization undone.
data_iter = iter(train_loader)
images, labels = next(data_iter)
denorm_images = torch.stack([denormalize(img.clone(), mean, std) for img in images])
# Denormalized values can fall outside [0, 1]; clamp explicitly instead of
# relying on imshow to clip them.
denorm_images = denorm_images.clamp_(0.0, 1.0)

# Arrange the batch into a grid and display it
img_grid = torchvision.utils.make_grid(denorm_images, nrow=8, padding=2, normalize=False)
plt.figure(figsize=(12, 6))
plt.imshow(np.transpose(img_grid.numpy(), (1, 2, 0)))
plt.axis('off')
plt.show()

# Model (ViT-B/32 backbone + multi-stage attention head)

In [None]:
class SEBlock(nn.Module):
    """Squeeze-and-Excitation gate for token sequences.

    Averages the input over the sequence dimension, feeds the result through a
    bottleneck MLP ending in a sigmoid, and rescales every token's channels by
    that per-sample gate.

    Args:
        channels: feature width C of the input.
        reduction: bottleneck ratio; the hidden width is
            ``max(1, channels // reduction)``.
    """

    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        # Clamp to 1 so channels < reduction does not create a zero-width layer.
        hidden = max(1, channels // reduction)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Gate ``x`` of shape (B, seq_len, C); returns a tensor of the same shape."""
        if x.dim() != 3:
            raise ValueError(
                f"SEBlock expects a 3-D (B, seq_len, C) tensor, got shape {tuple(x.shape)}")
        gate = self.fc(x.mean(dim=1))   # (B, C)
        return x * gate.unsqueeze(1)    # broadcasts over seq_len

# Cross-Stage Attention
class CrossStageAttention(nn.Module):
    """Multi-head self-attention over the concatenated tokens of several stages.

    The input sequences are concatenated along the token axis, attended
    jointly, mean-pooled over tokens and linearly projected.

    Args:
        channels: feature width of every input token sequence.
        num_heads: number of attention heads; must divide ``channels``.
        out_dim: width of the pooled output vector (default 64).

    Raises:
        ValueError: if ``channels`` is not divisible by ``num_heads``.
    """

    def __init__(self, channels, num_heads=4, out_dim=64):
        super().__init__()
        if channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.qkv = nn.Linear(channels, channels * 3)
        self.proj = nn.Linear(channels, out_dim)
        self.scale = self.head_dim ** -0.5

    def forward(self, x_list):
        """Fuse a sequence of (B, N_i, channels) tensors into a (B, out_dim) tensor."""
        x = torch.cat(list(x_list), dim=1)  # (B, sum N_i, C)
        B, N, C = x.shape

        # Split the fused projection into q, k, v, each (B, heads, N, head_dim).
        q, k, v = (
            t.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
            for t in self.qkv(x).chunk(3, dim=-1)
        )
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)

        return self.proj(out.mean(dim=1))

# Dynamic Feature Reducer
class DynamicFeatureReducer(nn.Module):
    """Project token features to ``out_dim`` channels, then LayerNorm and SE-gate them.

    Args:
        in_channels: input feature width.
        out_dim: output feature width (default 128).
    """

    def __init__(self, in_channels, out_dim=128):
        super().__init__()
        self.proj = nn.Linear(in_channels, out_dim)
        self.se = SEBlock(out_dim)
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x):
        """Map (B, seq_len, in_channels) to (B, seq_len, out_dim): project -> norm -> SE gate."""
        x = self.proj(x)
        x = self.norm(x)
        return self.se(x)

# Main Model 
class vit_base_patch32_model(nn.Module):
    """ViT-B/32 CLIP backbone (448 px) with a multi-stage fusion head.

    The token sequences produced by the backbone blocks listed in
    ``feature_layers`` are each reduced by a ``DynamicFeatureReducer``, fused by
    ``CrossStageAttention`` into a 64-d vector, and classified by a small MLP.
    Logits are divided by a learnable scalar temperature.

    Args:
        num_classes: number of output classes.
    """

    def __init__(self, num_classes):
        super().__init__()

        # num_classes=0 builds the backbone without a classification head.
        self.backbone = timm.create_model(
            'vit_base_patch32_clip_448.laion2b_ft_in12k_in1k',
            pretrained=True,
            num_classes=0,
        )
        # (This `if` was over-indented in the original, an IndentationError.)
        if hasattr(self.backbone, 'set_grad_checkpointing'):
            self.backbone.set_grad_checkpointing(True)

        # Freeze the backbone except parameters whose name contains one of these keys.
        # NOTE(review): matching is by substring, so 'norm' also unfreezes every
        # block's norm1/norm2 (and norm_pre), not only the final norm. Kept as-is
        # to preserve training behaviour — confirm this is intended.
        trainable_keys = ('blocks.10', 'blocks.11', 'norm', 'head')
        for name, param in self.backbone.named_parameters():
            param.requires_grad = any(key in name for key in trainable_keys)

        # Token width of the backbone (768 for ViT-B); read from timm when exposed.
        self.hidden_dim = getattr(self.backbone, 'num_features', 768)
        self.feature_layers = [9, 11]
        self.reducers = nn.ModuleList([
            DynamicFeatureReducer(self.hidden_dim) for _ in self.feature_layers
        ])

        self.cross_attention = CrossStageAttention(channels=128)
        # classifier[:3] yields the penultimate features, classifier[3:] the logits.
        self.classifier = nn.Sequential(
            nn.Linear(64, 128),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes),
        )

        self.temperature = nn.Parameter(torch.ones(1))

    def forward_features(self, x):
        """Run the backbone embedding and blocks on images ``x``.

        Returns:
            dict mapping each index in ``self.feature_layers`` to the token
            sequence output by that block.
        """
        from torch.utils.checkpoint import checkpoint

        x = self.backbone.patch_embed(x)
        x = self.backbone._pos_embed(x)
        x = self.backbone.patch_drop(x)
        x = self.backbone.norm_pre(x)

        # timm applies gradient checkpointing inside the backbone's own
        # forward_features, which this manual loop bypasses; honour the flag
        # enabled in __init__ here, during training only.
        use_checkpoint = (
            getattr(self.backbone, 'grad_checkpointing', False)
            and self.training
            and torch.is_grad_enabled()
        )

        intermediate_features = {}
        last_needed = max(self.feature_layers)
        for i, block in enumerate(self.backbone.blocks):
            if use_checkpoint:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
            if i in self.feature_layers:
                intermediate_features[i] = x
            if i >= last_needed:
                # Later blocks would not contribute to any collected feature.
                break

        return intermediate_features

    def forward(self, x, return_features=False):
        """Classify images ``x``.

        Args:
            x: image batch accepted by the backbone's ``patch_embed``.
            return_features: if True, return the 128-d penultimate features
                (output of ``classifier[:3]``) instead of logits.

        Returns:
            Temperature-scaled logits of shape (B, num_classes), or the features.
        """
        intermediate_features = self.forward_features(x)

        reduced_features = [
            reducer(intermediate_features[layer_idx])
            for reducer, layer_idx in zip(self.reducers, self.feature_layers)
        ]

        x = self.cross_attention(reduced_features)

        features = self.classifier[:3](x)       # Linear -> GELU -> Dropout
        if return_features:
            return features

        logits = self.classifier[3:](features)  # final Linear
        return logits / self.temperature

# Loss Function

In [None]:
class LabelSmoothedCrossEntropy(nn.Module):
    """Cross-entropy with uniform label smoothing.

    Per sample the loss is
    ``(1 - smoothing) * (-log p_target) + smoothing * mean_c(-log p_c)``.

    Args:
        smoothing: weight given to the uniform-distribution term.
        reduction: 'mean', 'sum', or any other value for per-sample losses.
    """

    def __init__(self, smoothing=0.1, reduction='mean'):
        super().__init__()
        self.smoothing = smoothing
        self.reduction = reduction

    def forward(self, logits, target):
        log_p = torch.log_softmax(logits, dim=-1)
        true_class_nll = -log_p.gather(-1, target.unsqueeze(1)).squeeze(1)
        uniform_nll = -log_p.mean(dim=-1)
        per_sample = (1.0 - self.smoothing) * true_class_nll + self.smoothing * uniform_nll

        if self.reduction == 'sum':
            return per_sample.sum()
        if self.reduction == 'mean':
            return per_sample.mean()
        return per_sample

class FocalLoss(nn.Module):
    """Multi-class focal loss: ``(1 - p_t) ** gamma * CE`` per sample.

    ``p_t`` is the predicted probability of the true class, recovered from the
    per-sample cross-entropy as ``exp(-CE)``.

    Args:
        gamma: focusing exponent (0 gives plain cross-entropy).
        reduction: 'mean', 'sum', or any other value for per-sample losses.
    """

    def __init__(self, gamma=2, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        p_true = per_sample_ce.neg().exp()
        losses = (1 - p_true) ** self.gamma * per_sample_ce

        if self.reduction == 'sum':
            return losses.sum()
        if self.reduction == 'mean':
            return losses.mean()
        return losses

def combined_loss(logits, targets):
    """Weighted sum of label-smoothed cross-entropy (0.7) and focal loss (0.3).

    Both terms use mean reduction, so the result is a scalar tensor.
    """
    ls_weight, focal_weight = 0.7, 0.3
    smoothed_ce = LabelSmoothedCrossEntropy(smoothing=0.1, reduction='mean')
    focal = FocalLoss(gamma=2, reduction='mean')
    return ls_weight * smoothed_ce(logits, targets) + focal_weight * focal(logits, targets)

# Train

In [None]:
from torch.amp import GradScaler, autocast
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report

# Early Stopping
class EarlyStopping:
    """Stop training once a monitored validation metric stops improving.

    Each call records one validation result. On an improvement larger than
    ``delta``, the model's ``state_dict`` is saved to ``path`` and the patience
    counter resets. Otherwise the counter grows, and ``early_stop`` becomes True
    once it reaches ``patience``.

    Args:
        patience: consecutive non-improving calls tolerated before stopping.
        delta: minimum change that counts as an improvement.
        path: file the best ``state_dict`` is written to.
        verbose: print progress messages.
        monitor: 'val_loss' (lower is better); any other value is treated as
            'val_acc' (higher is better).
    """

    def __init__(self, patience=5, delta=0, path='checkpoint.pt', verbose=True, monitor='val_loss'):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.verbose = verbose
        self.counter = 0
        self.best_score = None  # best raw metric value seen so far
        self.early_stop = False
        self.monitor = monitor

        # improvement_check(score, best) encodes the direction for the chosen metric.
        if self.monitor == 'val_loss':
            self.val_score_min = float('inf')
            self.improvement_check = lambda score, best: score < best - self.delta
        else:  # 'val_acc'
            self.val_score_min = -float('inf')
            self.improvement_check = lambda score, best: score > best + self.delta

    def __call__(self, val_score, model):
        """Record one validation result and save ``model`` if it is the best so far."""
        # Compare the raw metric; improvement_check already knows the direction.
        # (Negating val_loss here as well inverted the check, so a rising loss
        # counted as an improvement.)
        score = val_score

        if self.best_score is None or self.improvement_check(score, self.best_score):
            self.best_score = score
            self.save_checkpoint(val_score, model)
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_score, model):
        """Write ``model.state_dict()`` to ``self.path`` and remember ``val_score``."""
        if self.verbose:
            score_label = "loss" if self.monitor == "val_loss" else "accuracy"
            print(f'Validation {score_label} improved ({self.val_score_min:.6f} --> {val_score:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_score_min = val_score

# Validation Function
def validate(model, val_loader, device, criterion=None):
    """Compute mean loss and accuracy of ``model`` over ``val_loader``.

    Runs in eval mode under ``torch.no_grad()`` and shows a tqdm progress bar
    with the running loss and accuracy.

    Args:
        model: network to evaluate.
        val_loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to.
        criterion: loss callable ``(logits, labels) -> scalar``; defaults to
            ``nn.CrossEntropyLoss()``.

    Returns:
        ``(mean_loss_per_sample, accuracy)`` with accuracy as a fraction.
    """
    loss_fn = criterion if criterion is not None else nn.CrossEntropyLoss()

    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(val_loader, desc="Validation")

    with torch.no_grad():
        for batch_images, batch_labels in progress:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            batch_loss = loss_fn(logits, batch_labels)
            loss_sum += batch_loss.item() * batch_images.size(0)

            predicted = logits.max(dim=1).indices
            n_correct += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

            progress.set_postfix(
                loss=f"{loss_sum / n_seen:.4f}",
                acc=f"{100.0 * n_correct / n_seen:.2f}%",
            )

    return loss_sum / n_seen, n_correct / n_seen

def _train_one_epoch(model, loader, optimizer, scaler, criterion, device, use_amp, desc):
    """Run one (optionally mixed-precision) training pass; return (mean loss, accuracy)."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    train_pbar = tqdm(loader, desc=desc)
    for images, labels in train_pbar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        with autocast(device.type, enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)  # use the criterion passed in

        scaler.scale(loss).backward()
        # Unscale before clipping so max_norm applies to the true gradients,
        # not to the loss-scaled ones.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        running_loss += loss.item() * images.size(0)
        total += labels.size(0)

        # Update the progress bar with running averages
        curr_loss = running_loss / total
        curr_acc = 100.0 * correct / total
        train_pbar.set_postfix(loss=f"{curr_loss:.4f}", acc=f"{curr_acc:.2f}%")

    return running_loss / total, correct / total


def train_model(model, train_loader, val_loader, device,
                num_epochs=50, patience=7, delta=0.001,
                checkpoint_path='best_model.pth', final_model_path='Cassava_convnext_large.pth',
                criterion=None):
    """Fine-tune ``model`` with AdamW, a cosine LR schedule, AMP on CUDA and early stopping.

    Early stopping monitors validation accuracy. The best weights are written to
    ``checkpoint_path``, reloaded after training, and saved again to
    ``final_model_path``.

    Args:
        model: network to train (may be wrapped in ``nn.DataParallel``).
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.
        device: ``torch.device`` or device string the batches are moved to.
        num_epochs: maximum number of epochs.
        patience: early-stopping patience in epochs.
        delta: minimum accuracy gain that counts as an improvement.
        checkpoint_path: file for the best ``state_dict`` during training.
        final_model_path: file the reloaded best ``state_dict`` is saved to.
        criterion: loss callable ``(logits, labels) -> scalar``; defaults to
            ``nn.CrossEntropyLoss()``.

    Returns:
        ``(model, history, actual_epochs)``. ``history`` holds per-epoch
        'train_loss', 'train_acc', 'val_loss' and 'val_acc' (accuracies as
        fractions); ``actual_epochs`` is the number of epochs that ran.
    """
    device = torch.device(device)
    # Mixed precision only on CUDA; elsewhere autocast and the scaler are no-ops.
    use_amp = device.type == 'cuda'

    # Frozen parameters never receive gradients (AdamW would skip them anyway),
    # so only the trainable ones are handed to the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    scaler = GradScaler(device.type, enabled=use_amp)

    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    early_stopping = EarlyStopping(patience=patience, delta=delta, path=checkpoint_path,
                                   verbose=True, monitor='val_acc')
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    for epoch in range(num_epochs):
        # Training phase: record train loss and accuracy
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, optimizer, scaler, criterion, device, use_amp,
            desc=f"Epoch {epoch+1}/{num_epochs} [Train]",
        )
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)

        # Validation phase
        val_loss, val_acc = validate(model, val_loader, device, criterion)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

        # Apply early stopping
        early_stopping(val_acc, model)
        if early_stopping.early_stop:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

        scheduler.step()

    # Load the best model (a checkpoint exists only if at least one epoch ran)
    if early_stopping.best_score is not None:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

    # Save the final model
    torch.save(model.state_dict(), final_model_path)
    print(f"Final model saved to {final_model_path}")

    actual_epochs = len(history['val_loss'])

    return model, history, actual_epochs

In [None]:
def plot_training_history(history, actual_epochs):
    """Plot loss and accuracy curves side by side and save them to 'training_history.png'.

    Args:
        history: dict with 'train_loss', 'val_loss', 'train_acc' and 'val_acc'
            lists, one entry per completed epoch.
        actual_epochs: number of epochs to plot (x-axis runs 1..actual_epochs).
    """
    x_axis = range(1, actual_epochs + 1)
    # (train key, train label, val key, val label, title, y-axis label) per subplot
    panels = (
        ('train_loss', 'Training Loss', 'val_loss', 'Validation Loss',
         'Training and Validation Loss', 'Loss'),
        ('train_acc', 'Training Accuracy', 'val_acc', 'Validation Accuracy',
         'Training and Validation Accuracy', 'Accuracy'),
    )

    plt.figure(figsize=(12, 5))
    for panel_idx, (train_key, train_label, val_key, val_label, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, panel_idx)
        plt.plot(x_axis, history[train_key], 'b-', label=train_label)
        plt.plot(x_axis, history[val_key], 'r-', label=val_label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.show()
    plt.close()
    print(f"Training history graph saved as 'training_history.png' (trained for {actual_epochs} epochs)")


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Derive the class count from the training folders instead of hard-coding it.
num_classes = len(train_dataset.classes)
model = vit_base_patch32_model(num_classes=num_classes).to(device)
criterion = combined_loss

n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    print(f"Using {n_gpus} GPUs!")
    # Without device_ids, DataParallel uses every visible GPU, matching the message above.
    model = nn.DataParallel(model)

In [None]:
# Fine-tune with the combined loss; train_model's early stopping monitors validation accuracy.
model, history, actual_epochs = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    num_epochs=100,
    patience=5,
    delta=0.001,
    checkpoint_path='best_model.pth',
    final_model_path='vit_base_patch32.pth',
    criterion=criterion,
)

In [None]:
# Save the weights without the nn.DataParallel wrapper so the keys carry no
# 'module.' prefix and load directly into a bare vit_base_patch32_model.
model_to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save(model_to_save.state_dict(), 'model_state_dict.pth')

# Plot Training Graphs

In [None]:
# Curves cover only the epochs that actually ran (early stopping may end training early).
plot_training_history(history=history, actual_epochs=actual_epochs)

# Evaluate: Confusion Matrix and F1 Score

In [None]:
def final_evaluate(model, test_loader, device, class_names=None):
    """Evaluate ``model`` on ``test_loader`` and report metrics.

    Prints accuracy, macro-F1 and sklearn's classification report, saves a
    confusion-matrix heatmap to 'confusion_matrix.png', and prints per-class
    accuracy and error rates.

    Args:
        model: trained network.
        test_loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to.
        class_names: display names indexed by class id; defaults to
            ``['CBB', 'CBSD', 'CGM', 'CMD', 'Healthy']``.

    Returns:
        ``(accuracy, macro_f1, confusion_matrix)``.
    """
    if class_names is None:
        class_names = ['CBB', 'CBSD', 'CGM', 'CMD', 'Healthy']

    model.eval()
    y_true = []
    y_pred = []

    test_pbar = tqdm(test_loader, desc="Testing")
    with torch.no_grad():
        for images, labels in test_pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='macro')
    # Pin the label set so rows/columns line up with class_names even if a
    # class never occurs in y_true or y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
    print(f"Test Accuracy: {acc:.4f}")
    print(f"Test F1 Score: {f1:.4f}")
    print("Classification Report:")
    print(classification_report(y_true, y_pred, digits=4))

    # Confusion Matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig('confusion_matrix.png')
    plt.show()
    plt.close()
    print("Confusion matrix saved as 'confusion_matrix.png'")

    # Per-class breakdown reuses the predictions collected above rather than
    # running inference over test_loader a second time.
    y_true_arr = np.asarray(y_true)
    y_pred_arr = np.asarray(y_pred)
    print("\nError Analysis per Class:")
    for i, class_name in enumerate(class_names):
        is_class = y_true_arr == i
        true_count = int(is_class.sum())
        correct_count = int((is_class & (y_pred_arr == i)).sum())

        if true_count > 0:
            accuracy = correct_count / true_count * 100
            error_rate = 100 - accuracy
            print(f"{class_name}: {correct_count}/{true_count} correct ({accuracy:.1f}% accuracy, {error_rate:.1f}% error)")

    return acc, f1, cm
    

In [None]:
# Evaluate the reloaded best checkpoint on the held-out test split.
test_acc, test_f1, test_cm = final_evaluate(model=model, test_loader=test_loader, device=device)

Comparison with https://www.kaggle.com/code/pradiptadatta/cassava-leaf-disease-best-quality (reference classification report):
                                     precision    recall  f1-score   support

     Cassava Bacterial Blight (CBB)       0.71      0.65      0.68       311
     Cassava Brown Streak Disease (CBSD)  0.86      0.82      0.84       726
     Cassava Green Mottle (CGM)           0.83      0.79      0.81       632
     Cassava Mosaic Disease (CMD)         0.95      0.97      0.96      3163
                            Healthy       0.76      0.78      0.77       579

                           accuracy                           0.89      5411
                          macro avg       0.82      0.80      0.81      5411
                       weighted avg       0.89      0.89      0.89      5411

# Error Analysis

In [None]:
import os
from collections import defaultdict

def show_misclassified_by_class(model, test_loader, device, class_names, max_per_class=12, save_dir="misclassified_by_class"):
    """Show and save misclassified test images, grouped by their true class.

    Up to ``max_per_class`` misclassified images per true class are plotted in
    a grid and written to ``save_dir`` as PNGs. The printed summary counts *all*
    misclassifications of each class and which classes they were predicted as,
    not just the stored examples.

    Args:
        model: trained network (switched to eval mode here).
        test_loader: iterable of ``(images, labels)`` batches, assumed to be
            normalized with the mean/std below.
        device: device the batches are moved to.
        class_names: display names indexed by class id.
        max_per_class: maximum number of example images kept per true class.
        save_dir: directory for the saved PNGs (created if missing).
    """
    # Normalization statistics to undo before display (hoisted out of the image loop).
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    misclassified_by_class = defaultdict(lambda: {
        'images': [],
        'true_labels': [],
        'pred_labels': []
    })
    # true class id -> predicted class id -> count, over every misclassified sample.
    confusion_counts = defaultdict(lambda: defaultdict(int))

    os.makedirs(save_dir, exist_ok=True)

    # Inference mode (dropout off) regardless of the caller's state.
    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)

            # Find the misclassified samples in this batch
            wrong_indices = (preds != labels).nonzero(as_tuple=True)[0]

            for idx in wrong_indices.tolist():
                true_class = labels[idx].item()
                pred_class = preds[idx].item()
                confusion_counts[true_class][pred_class] += 1

                # Keep example images grouped by true class, up to the cap
                bucket = misclassified_by_class[true_class]
                if len(bucket['images']) < max_per_class:
                    bucket['images'].append(images[idx].cpu())
                    bucket['true_labels'].append(true_class)
                    bucket['pred_labels'].append(pred_class)

    for class_idx in sorted(confusion_counts):
        class_name = class_names[class_idx]
        data = misclassified_by_class[class_idx]
        n_images = len(data['images'])
        total_wrong = sum(confusion_counts[class_idx].values())

        if n_images > 0:
            cols = min(3, n_images)
            rows = (n_images + cols - 1) // cols

            plt.figure(figsize=(12, 4 * rows))
            plt.suptitle(f'True Class: {class_name}', fontsize=16, fontweight='bold')

            for i in range(n_images):
                image = data['images'][i]
                pred_label = class_names[data['pred_labels'][i]]

                # (C, H, W) -> (H, W, C), undo normalization, clip to the displayable range
                img_np = image.permute(1, 2, 0).numpy()
                img_np = np.clip(img_np * std + mean, 0, 1)

                plt.subplot(rows, cols, i + 1)
                plt.imshow(img_np)
                plt.title(f"Predicted: {pred_label}", fontsize=12, color='red')
                plt.axis('off')

                save_path = os.path.join(save_dir, f"{class_name}_img_{i+1}_pred_{pred_label}.png")
                plt.imsave(save_path, img_np)

            plt.tight_layout()
            plt.show()

        print(f"Class '{class_name}': Found {total_wrong} misclassified images (showing {n_images})")

        print("  Most confused with:")
        ranked = sorted(confusion_counts[class_idx].items(), key=lambda kv: kv[1], reverse=True)
        for pred_idx, count in ranked:
            print(f"    - {class_names[pred_idx]}: {count} times")
        print("-" * 50)


In [None]:
# Short display names by class id — presumably matching test_dataset.classes order; confirm.
class_names = ['CBB', 'CBSD', 'CGM', 'CMD', 'Healthy']
show_misclassified_by_class(
    model=model,
    test_loader=test_loader,
    device=device,
    class_names=class_names,
    max_per_class=12,
)