In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
from torchvision import datasets, models, transforms

import numpy as np
import matplotlib.pyplot as plt
import os
import time
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import pandas as pd
from collections import Counter
import pickle

# Fix the global RNG seeds of PyTorch and NumPy (both set to 42).
# The scikit-learn splits further down pass their own random_state=42 explicitly.
torch.manual_seed(42)
np.random.seed(42)

In [2]:
# --- Paths -------------------------------------------------------------------
# Root folder of the image dataset (one sub-directory per class) and the
# artefacts written by this notebook.
DATA_DIR = '/kaggle/input/nutrientdeficiencysymptomsinrice/rice_plant_lacks_nutrients'
MODEL_SAVE_PATH = 'rice_mobilenet_from_scratch.pth'
CLASS_MAPPING_PATH = 'class_mapping.pkl'

# --- Model input / output ----------------------------------------------------
IMG_SIZE = 224      # side length (pixels) of the square crops fed to the network
NUM_CLASSES = 3     # size of the final classification layer

# --- Optimisation ------------------------------------------------------------
BATCH_SIZE = 64
EPOCHS = 50
LEARNING_RATE = 1e-3
EARLY_STOPPING_PATIENCE = 10    # epochs without validation-accuracy improvement before stopping
LR_SCHEDULER_PATIENCE = 5       # patience handed to ReduceLROnPlateau

In [None]:
class RiceMobileNetV2(nn.Module):
    """torchvision MobileNetV2 whose classifier is replaced by a small MLP head.

    The head is Dropout(0.3) -> Linear(in, 128) -> ReLU -> BatchNorm1d(128)
    -> Dropout(0.2) -> Linear(128, 64) -> ReLU -> Linear(64, num_classes).

    Args:
        num_classes: Width of the final linear layer (number of output logits).
        from_scratch: If truthy the backbone is built with ``weights=None``;
            otherwise the ``IMAGENET1K_V1`` weights are requested.
    """

    def __init__(self, num_classes=3, from_scratch=True):
        super().__init__()

        # Report and pick the backbone initialisation.
        if not from_scratch:
            print("Initializing model with pre-trained ImageNet weights.")
            backbone_weights = models.MobileNet_V2_Weights.IMAGENET1K_V1
        else:
            print("Initializing model with random weights (training from scratch).")
            backbone_weights = None

        # The attribute name is part of the saved state_dict keys; keep it stable.
        self.mobilenet_v2 = models.mobilenet_v2(weights=backbone_weights)

        # Input width of the stock classifier's Linear layer (index 1 of the Sequential).
        head_in = self.mobilenet_v2.classifier[1].in_features

        # Layers are instantiated in head order so parameter initialisation
        # consumes the RNG in the same sequence as a literal nn.Sequential(...).
        head_layers = [
            nn.Dropout(p=0.3),
            nn.Linear(head_in, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(p=0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        ]
        self.mobilenet_v2.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run ``x`` through the backbone and custom head, returning the logits."""
        logits = self.mobilenet_v2(x)
        return logits

In [None]:
# Preprocessing pipelines. Training images get random augmentation; validation and
# test images get a deterministic resize + centre crop. Both pipelines finish with
# Normalize(mean=0.5, std=0.5), which maps ToTensor's [0, 1] range to [-1, 1].
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(IMG_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # Scale to [-1, 1]
    ]),
    'val': transforms.Compose([
        transforms.Resize(IMG_SIZE + 32),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ]),
}

print("Loading and splitting data...")
# Two views over the same directory: `full_dataset` applies the training
# augmentation, `eval_dataset` the deterministic evaluation pipeline. Giving each
# view its transform at construction time avoids mutating a dataset that several
# Subsets share (which would silently put training augmentation on the
# validation/test subsets).
full_dataset = datasets.ImageFolder(DATA_DIR, transform=data_transforms['train'])
eval_dataset = datasets.ImageFolder(DATA_DIR, transform=data_transforms['val'])
# The index-based splits below are only valid if both views list the same files
# in the same order.
assert eval_dataset.samples == full_dataset.samples, \
    "Train and eval dataset views enumerate different files"

class_names = full_dataset.classes
print(f"Classes found: {class_names}")
# Fail early with a clear message instead of a shape error deep inside training.
assert len(class_names) == NUM_CLASSES, \
    f"NUM_CLASSES={NUM_CLASSES} but {len(class_names)} class folders were found"

# Stratified 70 / 15 / 15 split: first carve out the test set, then split the
# remainder into train and validation.
train_val_indices, test_indices = train_test_split(
    list(range(len(full_dataset.targets))),
    test_size=0.15,
    stratify=full_dataset.targets,
    random_state=42
)
train_indices, val_indices = train_test_split(
    train_val_indices,
    test_size=0.176, # Approx. 15% of the original dataset (0.15 / 0.85)
    stratify=[full_dataset.targets[i] for i in train_val_indices],
    random_state=42
)
assert len(train_indices) + len(val_indices) + len(test_indices) == len(full_dataset)

train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
val_dataset = torch.utils.data.Subset(eval_dataset, val_indices)
test_dataset = torch.utils.data.Subset(eval_dataset, test_indices)

# Aliases kept so existing references to the "*_transformed" names keep working.
val_dataset_transformed = val_dataset
test_dataset_transformed = test_dataset

dataloaders = {
    'train': torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2),
    'val': torch.utils.data.DataLoader(val_dataset_transformed, batch_size=BATCH_SIZE, shuffle=False, num_workers=2),
    'test': torch.utils.data.DataLoader(test_dataset_transformed, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
}

dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset), 'test': len(test_dataset)}
print(f"Dataset sizes: {dataset_sizes}")

In [None]:
print("Calculating class weights for handling imbalance...")

# Class frequencies are measured on the training split only.
train_labels = [full_dataset.targets[sample_idx] for sample_idx in train_indices]
class_counts = Counter(train_labels)
total_samples = len(train_labels)

# Inverse-frequency weight for every class index: total / count.
inverse_frequencies = []
for class_idx in range(len(class_names)):
    inverse_frequencies.append(total_samples / class_counts[class_idx])

# Use the first CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class_weights = torch.tensor(inverse_frequencies, dtype=torch.float32).to(device)

print(f"Calculated class weights: {class_weights}")

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` with per-epoch validation, LR scheduling and early stopping.

    Besides its arguments the function reads these notebook-level globals:
    ``dataloaders`` (its ``'train'`` and ``'val'`` entries), ``device``,
    ``MODEL_SAVE_PATH`` and ``EARLY_STOPPING_PATIENCE``.

    Args:
        model: Network to optimise; it is moved to ``device``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Scheduler stepped once per epoch with the validation loss
            (i.e. ``scheduler.step(val_loss)``).
        num_epochs: Upper bound on the number of epochs.

    Returns:
        ``(model, history)``. ``model`` carries the weights of the epoch with the
        best validation accuracy when a checkpoint was written during this call,
        otherwise its final weights. ``history`` maps ``'train_loss'``,
        ``'val_loss'``, ``'train_acc'`` and ``'val_acc'`` to lists with one
        Python float per completed epoch.
    """
    since = time.time()

    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    best_acc = 0.0
    epochs_no_improve = 0
    # Tracks whether THIS call wrote MODEL_SAVE_PATH, so a stale file left by an
    # earlier run is never loaded by mistake.
    checkpoint_saved = False

    model.to(device)
    print(f"Training on device: {device}")

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            samples_seen = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Gradients are only tracked while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # criterion returns a per-batch mean; re-weight by batch size so
                # the epoch figure is an average over samples.
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                running_corrects += (preds == labels).sum().item()
                samples_seen += batch_size

            # Average over the samples actually iterated; guard an empty loader.
            denominator = max(samples_seen, 1)
            epoch_loss = running_loss / denominator
            epoch_acc = running_corrects / denominator

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if is_train:
                history['train_loss'].append(epoch_loss)
                history['train_acc'].append(epoch_acc)
            else:
                history['val_loss'].append(epoch_loss)
                history['val_acc'].append(epoch_acc)

                # Step the plateau scheduler and report any LR change explicitly.
                lr_before = optimizer.param_groups[0]['lr']
                scheduler.step(epoch_loss)
                lr_after = optimizer.param_groups[0]['lr']
                if lr_after != lr_before:
                    print(f"Learning rate changed: {lr_before:.2e} --> {lr_after:.2e}")

                if epoch_acc > best_acc:
                    print(f"Validation accuracy improved ({best_acc:.4f} --> {epoch_acc:.4f}). Saving model...")
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), MODEL_SAVE_PATH)
                    checkpoint_saved = True
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1

        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"\nEarly stopping triggered after {epoch+1} epochs.")
            break
        print()

    time_elapsed = time.time() - since
    print(f'\nTraining complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    if checkpoint_saved:
        # Restore the best-performing weights onto the training device.
        model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))
    else:
        print("Warning: validation accuracy never improved, so no checkpoint was written; "
              "returning the model with its final weights.")
    return model, history

In [None]:
# Build the network with randomly initialised weights (no ImageNet pre-training).
model_ft = RiceMobileNetV2(num_classes=NUM_CLASSES, from_scratch=True)

# Cross-entropy with per-class weights to compensate for class imbalance.
criterion = nn.CrossEntropyLoss(weight=class_weights)

optimizer = optim.Adam(model_ft.parameters(), lr=LEARNING_RATE)

# Halve the learning rate after LR_SCHEDULER_PATIENCE epochs without a drop in the
# monitored value (mode='min'). The `verbose` flag is intentionally not passed: it
# is deprecated since PyTorch 2.2 and is not accepted by newer releases. Inspect
# `optimizer.param_groups[0]['lr']` or `scheduler.get_last_lr()` to see the
# current learning rate.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=LR_SCHEDULER_PATIENCE)

best_model, history = train_model(model_ft, criterion, optimizer, scheduler, num_epochs=EPOCHS)

In [None]:
# Side-by-side training curves: loss on the left, accuracy on the right.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

# Left panel: loss per epoch.
ax1.plot(history['train_loss'], label='Train Loss')
ax1.plot(history['val_loss'], label='Validation Loss')
ax1.set(title='Loss vs. Epochs', xlabel='Epochs', ylabel='Loss')
ax1.grid(True)
ax1.legend()

# Right panel: accuracy per epoch.
ax2.plot(history['train_acc'], label='Train Accuracy')
ax2.plot(history['val_acc'], label='Validation Accuracy')
ax2.set(title='Accuracy vs. Epochs', xlabel='Epochs', ylabel='Accuracy')
ax2.grid(True)
ax2.legend()

fig.suptitle('PyTorch Model Training History')

# Write the figure to disk before showing it.
fig.savefig('training_curves_pytorch.png')
plt.show()

In [None]:
print("\n--- Evaluating on Test Set ---")
best_model.eval()  # inference mode for dropout / batch-norm layers
y_true = []
y_pred = []

with torch.no_grad():
    for inputs, labels in dataloaders['test']:
        inputs = inputs.to(device)

        outputs = best_model(inputs)
        _, predicted = torch.max(outputs, 1)

        # Labels never need to visit the device; only predictions come back from it.
        y_true.extend(labels.tolist())
        y_pred.extend(predicted.cpu().tolist())

# Pin the label set explicitly so the report rows and the confusion-matrix shape
# always line up with `class_names`, even if some class is missing from both the
# ground truth and the predictions.
label_ids = list(range(len(class_names)))

print("\nClassification Report:")
print(classification_report(y_true, y_pred, labels=label_ids, target_names=class_names))

cm = confusion_matrix(y_true, y_pred, labels=label_ids)
df_cm = pd.DataFrame(cm, index=class_names, columns=class_names)
plt.figure(figsize=(8, 6))
sns.heatmap(df_cm, annot=True, fmt="d", cmap='Blues')
plt.title('Confusion Matrix on Test Set')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.savefig('confusion_matrix_pytorch.png')
plt.show()

In [None]:
# Index -> class-name lookup, in the order the dataset enumerated the classes.
class_mapping = dict(enumerate(class_names))

# Persist the lookup next to the model weights.
with open(CLASS_MAPPING_PATH, 'wb') as mapping_file:
    pickle.dump(class_mapping, mapping_file)

print(f"✅ Class mapping saved to: {CLASS_MAPPING_PATH}")
print("Mapping details:", class_mapping)