In [1]:
# !mkdir -p ~/.kaggle
# !mv kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json

# !kaggle datasets download -d trumanrase/rice-leaf-diseases -p ./data --unzip

In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, roc_curve
import matplotlib.pyplot as plt
import numpy as np

In [3]:
class Cfg:
    """Experiment configuration shared by every later cell.

    ``data_root`` must contain ``train/`` and ``val/`` ImageFolder splits (they are
    joined onto it when the loaders are built). ``num_classes`` is not defined here;
    it is attached to the ``cfg`` instance once the training split has been scanned.
    """
    data_root = "./data/rice_disease_val_test"
    img_size = 224          # images are resized to img_size x img_size
    batch_size = 32
    num_workers = 0
    epochs = 20
    lr = 1e-3
    seed = 42               # global torch seed: weight init, shuffling, random augmentation
    device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed before any dataset shuffling, random augmentation or model construction.
torch.manual_seed(Cfg.seed)
cfg = Cfg()

In [4]:
def build_transforms(img_size):
    """Create the training and validation preprocessing pipelines.

    Both pipelines resize to ``img_size`` x ``img_size``, convert to a tensor and
    normalise with the usual ImageNet per-channel mean/std. The training pipeline
    additionally applies a random horizontal flip, a random rotation within
    +/-15 degrees and colour jitter (brightness/contrast/saturation 0.2) before the
    tensor conversion.

    Returns:
        ``(train_tf, val_tf)`` as ``transforms.Compose`` objects.
    """
    size = (img_size, img_size)

    def tensor_and_normalize():
        # Fresh ToTensor/Normalize instances for each pipeline.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]

    train_steps = [
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    ]
    train_steps += tensor_and_normalize()

    val_steps = [transforms.Resize(size)] + tensor_and_normalize()

    return transforms.Compose(train_steps), transforms.Compose(val_steps)

In [5]:
def build_dataloaders(data_root, img_size, batch_size, workers=0):
    """Build train/val ``DataLoader``s from ``<data_root>/train`` and ``<data_root>/val``.

    Args:
        data_root: directory containing the ``train`` and ``val`` ImageFolder splits.
        img_size: side length images are resized to (see ``build_transforms``).
        batch_size: mini-batch size for both loaders.
        workers: ``num_workers`` for both loaders.

    Returns:
        ``(train_loader, val_loader, class_names)``; ``class_names`` is the training
        split's class list, where list position == label id.

    Raises:
        ValueError: if the two splits do not share an identical class -> label mapping.
    """
    train_tf, val_tf = build_transforms(img_size)
    train_ds = datasets.ImageFolder(os.path.join(data_root, "train"), transform=train_tf)
    val_ds = datasets.ImageFolder(os.path.join(data_root, "val"), transform=val_tf)

    # ImageFolder assigns label ids from the class folders found in *each* split, so a
    # class missing from val/ would silently shift every later label id.
    if val_ds.class_to_idx != train_ds.class_to_idx:
        raise ValueError(
            f"train/val class mappings differ: {train_ds.class_to_idx} vs {val_ds.class_to_idx}"
        )

    # The classifier head uses BatchNorm1d, which cannot run a training step on a single
    # sample; drop the last training batch only when it would contain exactly one.
    drop_singleton_batch = len(train_ds) % batch_size == 1

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=workers, drop_last=drop_singleton_batch)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=workers)
    return train_loader, val_loader, train_ds.classes

train_loader, val_loader, class_names = build_dataloaders(cfg.data_root, cfg.img_size, cfg.batch_size, cfg.num_workers)
cfg.num_classes = len(class_names)
print(f"Classes : {class_names}")

Classes : ['bacterial_leaf_blight', 'bacterial_leaf_streak', 'bacterial_panicle_blight', 'blast', 'brown_spot', 'dead_heart', 'downy_mildew', 'hispa', 'normal', 'rice_sheath_blight', 'smut', 'tungro']


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import mobilenet_v3_small, mobilenet_v3_large
import math

class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Global-average-pools each channel and passes the pooled vector through a
    bottleneck MLP (``channels -> hidden -> channels``, no biases) ending in a
    sigmoid. The input channels are then rescaled by the resulting weights.

    Args:
        channels: number of input (and output) channels ``C``.
        reduction: bottleneck ratio; ``hidden = max(1, channels // reduction)``.
            The clamp keeps the gate usable when ``channels < reduction``, where plain
            integer division would produce a zero-width layer.
    """

    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        hidden = max(1, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Gate ``x`` of shape ``(B, C, H, W)`` channel-wise; the output has the same shape."""
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        # (B, C, 1, 1) weights broadcast over the spatial dimensions.
        return x * y

class CBAM(nn.Module):
    """Channel attention followed by spatial attention.

    The channel gate is an ``SEBlock``. The spatial gate builds a 2-channel
    descriptor from the per-pixel mean and max across channels, convolves it to one
    channel (``kernel_size`` x ``kernel_size``, padding ``kernel_size // 2``, no bias),
    applies BatchNorm and a sigmoid, and multiplies that map into every channel.
    With an odd ``kernel_size`` the output shape equals the input shape.
    """

    def __init__(self, channels, reduction=16, kernel_size=7):
        super(CBAM, self).__init__()
        pad = kernel_size // 2
        self.channel_attention = SEBlock(channels, reduction)
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=kernel_size, padding=pad, bias=False),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Apply the channel gate, then the spatial gate, to ``x``."""
        gated = self.channel_attention(x)
        descriptor = torch.cat(
            (
                gated.mean(dim=1, keepdim=True),
                torch.max(gated, dim=1, keepdim=True).values,
            ),
            dim=1,
        )
        return gated * self.spatial_attention(descriptor)

class MultiScaleFeatureExtractor(nn.Module):
    """Four parallel conv branches with growing receptive field, fused by a 1x1 conv.

    Each branch is Conv -> BatchNorm -> ReLU with ``out_channels // 4`` filters:

    * ``scale1``: 1x1 conv
    * ``scale2``: 3x3 conv, dilation 1
    * ``scale3``: 3x3 conv, dilation 2
    * ``scale4``: 3x3 conv, dilation 4

    The 3x3 branches use padding equal to their dilation, so every branch keeps the
    input's spatial size and the outputs can be concatenated along channels.
    ``fusion`` (1x1 Conv -> BatchNorm -> ReLU) maps the concatenation to ``out_channels``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the fused output. If it is not a multiple of 4, the
            fusion conv takes the actual concatenated width ``4 * (out_channels // 4)``.

    Raises:
        ValueError: if ``out_channels < 4``, since the branches would have no filters.
    """

    def __init__(self, in_channels, out_channels):
        super(MultiScaleFeatureExtractor, self).__init__()
        if out_channels < 4:
            raise ValueError(f"out_channels must be >= 4, got {out_channels}")
        branch_channels = out_channels // 4

        def conv_bn_relu(cin, cout, **conv_kwargs):
            # Conv -> BN -> ReLU unit shared by every branch and the fusion layer.
            return nn.Sequential(
                nn.Conv2d(cin, cout, **conv_kwargs),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True),
            )

        self.scale1 = conv_bn_relu(in_channels, branch_channels, kernel_size=1)
        self.scale2 = conv_bn_relu(in_channels, branch_channels, kernel_size=3, padding=1, dilation=1)
        self.scale3 = conv_bn_relu(in_channels, branch_channels, kernel_size=3, padding=2, dilation=2)
        self.scale4 = conv_bn_relu(in_channels, branch_channels, kernel_size=3, padding=4, dilation=4)
        # The concatenation is 4 * branch_channels wide, which is smaller than
        # out_channels when out_channels is not divisible by 4; size the input to match.
        self.fusion = conv_bn_relu(4 * branch_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Map ``x`` of shape ``(B, in_channels, H, W)`` to ``(B, out_channels, H, W)``."""
        branches = (self.scale1, self.scale2, self.scale3, self.scale4)
        multi_scale = torch.cat([branch(x) for branch in branches], dim=1)
        return self.fusion(multi_scale)

class HybridPlantDiseaseModel(nn.Module):
    """MobileNetV3 backbone with multi-scale, attention-refined feature fusion.

    Three intermediate maps are tapped from ``backbone.features`` at ``_TAP_INDICES``
    (2, 5, 11). Each map is processed in four steps:

    1. a ``MultiScaleFeatureExtractor`` (to 128 / 256 / 512 channels);
    2. a ``CBAM`` block;
    3. global average pooling;
    4. concatenation with the other two pooled vectors (896 features).

    The result then goes through an MLP head (Linear/BatchNorm1d/ReLU/Dropout, twice)
    and a final linear classifier.

    Args:
        num_classes: number of output logits.
        backbone: ``'small'`` selects MobileNetV3-Small; any other value selects
            MobileNetV3-Large.

    Notes:
        * ``backbone_channels[1:4]`` are assumed to be the channel widths of the tapped
          stages for the chosen backbone. TODO: re-check if ``_TAP_INDICES`` changes.
        * The head uses ``BatchNorm1d``, so training-mode forward passes need at least
          2 samples per batch.
    """

    # Indices into ``backbone.features`` whose outputs feed the three fusion branches.
    _TAP_INDICES = (2, 5, 11)

    def __init__(self, num_classes=10, backbone='small'):
        super(HybridPlantDiseaseModel, self).__init__()

        # NOTE(review): ``pretrained=True`` is deprecated in torchvision >= 0.13 in favour
        # of ``weights=...``; kept for compatibility with older torchvision versions.
        if backbone == 'small':
            self.backbone = mobilenet_v3_small(pretrained=True)
            backbone_channels = [16, 24, 40, 96, 576]
        else:
            self.backbone = mobilenet_v3_large(pretrained=True)
            backbone_channels = [16, 24, 40, 112, 960]

        # Only ``backbone.features`` is used in forward; drop the stock classifier head.
        self.backbone.classifier = nn.Identity()

        self.ms_extractor_1 = MultiScaleFeatureExtractor(backbone_channels[1], 128)
        self.ms_extractor_2 = MultiScaleFeatureExtractor(backbone_channels[2], 256)
        self.ms_extractor_3 = MultiScaleFeatureExtractor(backbone_channels[3], 512)

        self.attention_1 = CBAM(128)
        self.attention_2 = CBAM(256)
        self.attention_3 = CBAM(512)

        # One parameter-free pool shared by all three branches.
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        feature_dim = 128 + 256 + 512

        self.feature_fusion = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2)
        )

        self.classifier = nn.Sequential(
            nn.Linear(256, num_classes)
        )

    def _extract_backbone_features(self, x):
        """Return ``[stem, tap_2, tap_5, tap_11]`` feature maps from ``backbone.features``.

        Element 0 is the output of ``features[0]``; the rest are the outputs of the
        stages in ``_TAP_INDICES``, in order. Iteration stops after the deepest tapped
        stage: later stages never reach the classifier, so running them only costs
        compute.
        """
        stages = self.backbone.features
        features = [stages[0](x)]
        x = features[0]

        last = min(max(self._TAP_INDICES), len(stages) - 1)
        for i in range(1, last + 1):
            x = stages[i](x)
            if i in self._TAP_INDICES:
                features.append(x)

        return features

    def forward(self, x):
        """Map an image batch ``x`` to class logits of shape ``(B, num_classes)``."""
        backbone_features = self._extract_backbone_features(x)

        branches = (
            (self.ms_extractor_1, self.attention_1, backbone_features[1]),
            (self.ms_extractor_2, self.attention_2, backbone_features[2]),
            (self.ms_extractor_3, self.attention_3, backbone_features[3]),
        )
        pooled = [
            torch.flatten(self.global_pool(attention(extractor(feature_map))), 1)
            for extractor, attention, feature_map in branches
        ]

        fused_features = self.feature_fusion(torch.cat(pooled, dim=1))
        return self.classifier(fused_features)

def get_model_info(model):
    """Print total / trainable parameter counts and the parameters' memory size in MB.

    The size uses each parameter's real element size, so it is correct for fp16/bf16
    or other non-float32 models, not only float32. Buffers (e.g. BatchNorm running
    statistics) are not counted.
    """
    params = list(model.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)
    param_bytes = sum(p.numel() * p.element_size() for p in params)

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Model size: {param_bytes / 1024 / 1024:.2f} MB")

    
model = HybridPlantDiseaseModel(num_classes=cfg.num_classes, backbone='small')
model.to(cfg.device)

get_model_info(model)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# Read the learning rate from cfg so Cfg stays the single source of truth for it.
# weight_decay=1e-4 is an untuned default; consider moving it into Cfg as well.
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=1e-4)



Total parameters: 2,352,216
Trainable parameters: 2,352,216
Model size: 8.97 MB


In [7]:
def train_model(model, loader, optimizer, criterion, device):
    """Run one training epoch over ``loader``.

    Args:
        model: network to optimise; it is switched to ``train()`` mode.
        loader: iterable of ``(inputs, labels)`` batches.
        optimizer: stepped once per batch.
        criterion: loss called as ``criterion(logits, labels)``; presumably returns a
            batch mean, since it is re-weighted by batch size below.
        device: device every batch is moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen in the epoch.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch is not over-counted.
        loss_sum += batch_loss.item() * inputs.size(0)
        n_correct += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += targets.size(0)
    return loss_sum / n_seen, n_correct / n_seen

In [8]:
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` in eval mode without gradient tracking.

    Returns:
        ``(mean_loss, accuracy, y_true, y_pred, y_score)``. The last three are NumPy
        arrays with one entry per sample. ``y_score`` is the predicted probability of
        class 1 when the model has exactly two outputs, and the top-1 (maximum)
        softmax probability otherwise.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    label_chunks, pred_chunks, score_chunks = [], [], []

    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            loss_sum += criterion(logits, targets).item() * inputs.size(0)

            probs = torch.softmax(logits, dim=1)
            top_prob, preds = probs.max(1)
            n_correct += preds.eq(targets).sum().item()
            n_seen += targets.size(0)

            batch_scores = probs[:, 1] if probs.size(1) == 2 else top_prob
            label_chunks.append(targets.cpu().numpy())
            pred_chunks.append(preds.cpu().numpy())
            score_chunks.append(batch_scores.cpu().numpy())

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return (mean_loss,
            accuracy,
            np.concatenate(label_chunks),
            np.concatenate(pred_chunks),
            np.concatenate(score_chunks))

In [9]:
def compute_metrics(y_true, y_pred, y_prob, average="macro"):
    acc  = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average=average, zero_division=0)
    rec  = recall_score(y_true, y_pred, average=average, zero_division=0)
    f1   = f1_score(y_true, y_pred, average=average, zero_division=0)

    cm = confusion_matrix(y_true, y_pred)

    specificity_per_class = []
    for i in range(len(cm)):
        tp = cm[i,i]
        fn = cm[i,:].sum() - tp
        fp = cm[:,i].sum() - tp
        tn = cm.sum() - (tp + fp + fn)
        specificity_per_class.append(tn / (tn + fp + 1e-6))
    specificity = np.mean(specificity_per_class)

    return {
        "accuracy": acc,
        "precision": prec,
        "recall (sensitivity)": rec,
        "specificity": specificity,
        "f1-score": f1,
    }


In [10]:
from tqdm.auto import tqdm

# Keep the best epoch (by validation macro-F1) instead of only the final weights.
# NOTE(review): the val split both selects the epoch and is reported below, so these
# numbers are optimistic; confirm the chosen weights on a held-out split before quoting.
best_f1, best_epoch, best_state = -1.0, None, None
history = []  # one dict per epoch, for later plotting / inspection

for epoch in tqdm(range(cfg.epochs)):
    train_loss, train_acc = train_model(model, train_loader, optimizer, criterion, cfg.device)
    val_loss, val_acc, y_true, y_pred, y_prob = evaluate(model, val_loader, criterion, cfg.device)
    metrics = compute_metrics(y_true, y_pred, y_prob, average="macro")
    history.append({"epoch": epoch + 1, "train_loss": train_loss, "train_acc": train_acc,
                    "val_loss": val_loss, "val_acc": val_acc, **metrics})

    if metrics["f1-score"] > best_f1:
        best_f1, best_epoch = metrics["f1-score"], epoch + 1
        # CPU copy so later training steps cannot overwrite the snapshot.
        best_state = {k: v.to("cpu", copy=True) for k, v in model.state_dict().items()}

    print(f"Epoch {epoch+1}/{cfg.epochs}")
    print(f" Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f" Val   Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
    print(" Metrics:", metrics)

if best_epoch is not None:
    print(f"Best val macro-F1 {best_f1:.4f} at epoch {best_epoch}; "
          f"restore with model.load_state_dict(best_state)")

  5%|▌         | 1/20 [01:16<24:17, 76.70s/it]

Epoch 1/20
 Train Loss: 1.4947 | Train Acc: 0.5902
 Val   Loss: 1.4664 | Val Acc: 0.6402
 Metrics: {'accuracy': 0.6401952085181899, 'precision': 0.701335596949097, 'recall (sensitivity)': 0.5908299621224499, 'specificity': 0.9661652734044734, 'f1-score': 0.6059201346496089}


 10%|█         | 2/20 [02:29<22:24, 74.70s/it]

Epoch 2/20
 Train Loss: 1.0815 | Train Acc: 0.7829
 Val   Loss: 1.0306 | Val Acc: 0.8119
 Metrics: {'accuracy': 0.8118899733806566, 'precision': 0.8321393953954318, 'recall (sensitivity)': 0.7670881206811476, 'specificity': 0.9823644636034526, 'f1-score': 0.777353188947458}


 15%|█▌        | 3/20 [03:42<20:54, 73.79s/it]

Epoch 3/20
 Train Loss: 0.9404 | Train Acc: 0.8395
 Val   Loss: 0.9104 | Val Acc: 0.8509
 Metrics: {'accuracy': 0.8509316770186336, 'precision': 0.8449846681543165, 'recall (sensitivity)': 0.833595578145445, 'specificity': 0.9861812935083206, 'f1-score': 0.8367252278764852}


 20%|██        | 4/20 [04:55<19:34, 73.38s/it]

Epoch 4/20
 Train Loss: 0.8455 | Train Acc: 0.8796
 Val   Loss: 0.8111 | Val Acc: 0.8949
 Metrics: {'accuracy': 0.8948535936113576, 'precision': 0.8771624720535014, 'recall (sensitivity)': 0.8872569969144002, 'specificity': 0.9903510523001011, 'f1-score': 0.8810028992253592}


 25%|██▌       | 5/20 [06:08<18:16, 73.12s/it]

Epoch 5/20
 Train Loss: 0.7905 | Train Acc: 0.9017
 Val   Loss: 0.8279 | Val Acc: 0.8824
 Metrics: {'accuracy': 0.8824312333629104, 'precision': 0.8866331710601935, 'recall (sensitivity)': 0.8684745543347253, 'specificity': 0.9889555089195833, 'f1-score': 0.8742647986721188}


 30%|███       | 6/20 [07:21<17:02, 73.04s/it]

Epoch 6/20
 Train Loss: 0.7718 | Train Acc: 0.9112
 Val   Loss: 0.8355 | Val Acc: 0.8793
 Metrics: {'accuracy': 0.8793256433007985, 'precision': 0.8831712642384607, 'recall (sensitivity)': 0.8509071884487168, 'specificity': 0.9886572666752262, 'f1-score': 0.8638747777861963}


 35%|███▌      | 7/20 [08:33<15:48, 73.00s/it]

Epoch 7/20
 Train Loss: 0.7235 | Train Acc: 0.9298
 Val   Loss: 0.7421 | Val Acc: 0.9197
 Metrics: {'accuracy': 0.919698314108252, 'precision': 0.9147534689633833, 'recall (sensitivity)': 0.906279823559179, 'specificity': 0.9925656235066391, 'f1-score': 0.9075208677924013}


 40%|████      | 8/20 [09:48<14:43, 73.66s/it]

Epoch 8/20
 Train Loss: 0.7028 | Train Acc: 0.9394
 Val   Loss: 0.7348 | Val Acc: 0.9219
 Metrics: {'accuracy': 0.9219165927240461, 'precision': 0.9235382010265365, 'recall (sensitivity)': 0.9152089083723225, 'specificity': 0.9927112853954906, 'f1-score': 0.9181948444169571}


 45%|████▌     | 9/20 [11:01<13:27, 73.40s/it]

Epoch 9/20
 Train Loss: 0.6929 | Train Acc: 0.9401
 Val   Loss: 0.7187 | Val Acc: 0.9321
 Metrics: {'accuracy': 0.9321206743566992, 'precision': 0.9317290873174922, 'recall (sensitivity)': 0.9162947625344878, 'specificity': 0.9935986912641054, 'f1-score': 0.9229282126668884}


 50%|█████     | 10/20 [12:14<12:13, 73.31s/it]

Epoch 10/20
 Train Loss: 0.6713 | Train Acc: 0.9520
 Val   Loss: 0.7409 | Val Acc: 0.9232
 Metrics: {'accuracy': 0.9232475598935226, 'precision': 0.9156431947554292, 'recall (sensitivity)': 0.9072188729899858, 'specificity': 0.9928630856407129, 'f1-score': 0.9095601130789955}


 55%|█████▌    | 11/20 [13:27<10:58, 73.16s/it]

Epoch 11/20
 Train Loss: 0.6543 | Train Acc: 0.9568
 Val   Loss: 0.7239 | Val Acc: 0.9246
 Metrics: {'accuracy': 0.9245785270629991, 'precision': 0.9281600889627418, 'recall (sensitivity)': 0.914664044785523, 'specificity': 0.9929649177850534, 'f1-score': 0.9191223701085044}


 60%|██████    | 12/20 [14:40<09:43, 72.95s/it]

Epoch 12/20
 Train Loss: 0.6533 | Train Acc: 0.9590
 Val   Loss: 0.7805 | Val Acc: 0.9064
 Metrics: {'accuracy': 0.9063886424134872, 'precision': 0.9151317188274954, 'recall (sensitivity)': 0.8968925203213834, 'specificity': 0.9911404978797401, 'f1-score': 0.9042193350194814}


 65%|██████▌   | 13/20 [15:53<08:31, 73.10s/it]

Epoch 13/20
 Train Loss: 0.6455 | Train Acc: 0.9613
 Val   Loss: 0.7317 | Val Acc: 0.9281
 Metrics: {'accuracy': 0.9281277728482697, 'precision': 0.9317610594215981, 'recall (sensitivity)': 0.9086757199035488, 'specificity': 0.9932285865968037, 'f1-score': 0.918504216515556}


 70%|███████   | 14/20 [17:06<07:18, 73.03s/it]

Epoch 14/20
 Train Loss: 0.6492 | Train Acc: 0.9590
 Val   Loss: 0.7224 | Val Acc: 0.9317
 Metrics: {'accuracy': 0.9316770186335404, 'precision': 0.9366302482788226, 'recall (sensitivity)': 0.9118990148705722, 'specificity': 0.9935496769699173, 'f1-score': 0.9229929153882813}


 75%|███████▌  | 15/20 [18:19<06:04, 72.92s/it]

Epoch 15/20
 Train Loss: 0.6309 | Train Acc: 0.9664
 Val   Loss: 0.6913 | Val Acc: 0.9450
 Metrics: {'accuracy': 0.9449866903283053, 'precision': 0.9374851132460598, 'recall (sensitivity)': 0.928944582834221, 'specificity': 0.9948512568837257, 'f1-score': 0.9327491648476626}


 80%|████████  | 16/20 [19:32<04:52, 73.18s/it]

Epoch 16/20
 Train Loss: 0.6393 | Train Acc: 0.9619
 Val   Loss: 0.7126 | Val Acc: 0.9330
 Metrics: {'accuracy': 0.9330079858030168, 'precision': 0.9218534787858448, 'recall (sensitivity)': 0.9272816237214836, 'specificity': 0.9937779796284986, 'f1-score': 0.9233105480961776}


 85%|████████▌ | 17/20 [20:46<03:39, 73.18s/it]

Epoch 17/20
 Train Loss: 0.6244 | Train Acc: 0.9671
 Val   Loss: 0.7030 | Val Acc: 0.9348
 Metrics: {'accuracy': 0.9347826086956522, 'precision': 0.930012562434479, 'recall (sensitivity)': 0.926686295983643, 'specificity': 0.9938770455269729, 'f1-score': 0.9277464665835122}


 90%|█████████ | 18/20 [22:00<02:27, 73.53s/it]

Epoch 18/20
 Train Loss: 0.6222 | Train Acc: 0.9696
 Val   Loss: 0.6875 | Val Acc: 0.9383
 Metrics: {'accuracy': 0.9383318544809228, 'precision': 0.9384025548333733, 'recall (sensitivity)': 0.9200151555973243, 'specificity': 0.9942422343966055, 'f1-score': 0.927164823757119}


 95%|█████████▌| 19/20 [23:12<01:13, 73.07s/it]

Epoch 19/20
 Train Loss: 0.6200 | Train Acc: 0.9690
 Val   Loss: 0.6853 | Val Acc: 0.9441
 Metrics: {'accuracy': 0.9440993788819876, 'precision': 0.9342609956554283, 'recall (sensitivity)': 0.9296467899851546, 'specificity': 0.9947775359830415, 'f1-score': 0.9308684607649184}


100%|██████████| 20/20 [24:25<00:00, 73.28s/it]

Epoch 20/20
 Train Loss: 0.6112 | Train Acc: 0.9733
 Val   Loss: 0.6885 | Val Acc: 0.9401
 Metrics: {'accuracy': 0.9401064773735581, 'precision': 0.9282956590286501, 'recall (sensitivity)': 0.9311637281224724, 'specificity': 0.994429929792325, 'f1-score': 0.9291670428656603}



