In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from PIL import Image
import pandas as pd
import numpy as np
from pathlib import Path
import sys
import time
from sklearn.model_selection import train_test_split

In [2]:
# Re-usable Components (from previous tasks)

class PestCNN(nn.Module):
    """3-layer convolutional neural network (Architecture C).

    The network stacks three stages of 3x3 convolution (padding 1) -> ReLU ->
    2x2 max-pooling, growing the channels 3 -> 32 -> 64 -> 128, followed by a
    128-unit fully connected layer, dropout (p=0.5) and a linear classifier.

    ``fc1`` expects ``128 * 28 * 28`` flattened features. Because each of the
    three pooling steps halves the spatial size, this corresponds to
    224x224 inputs (224 / 2**3 = 28).

    Args:
        num_classes: Number of output logits. Defaults to 17, the value that
            was previously hard-coded.
    """
    def __init__(self, num_classes=17):
        super().__init__()
        # Attribute names are kept stable so previously saved state_dicts
        # continue to load.
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 28 * 28, 128)
        self.dropout = nn.Dropout(p=0.5)
        self.output = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes).

        No softmax is applied here; the result is meant to be fed to a loss
        such as ``nn.CrossEntropyLoss`` that works on logits.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = torch.flatten(x, 1)  # Keep the batch dimension, flatten the rest
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.output(x)  # Logits

class JutePestDataset(Dataset):
    """Custom Dataset for loading jute pest images on-the-fly.

    Args:
        df: DataFrame with one row per image, holding the image path in a
            'filepath' column and the class name in a 'label' column.
        transform: Optional callable applied to the loaded PIL image.
        class_to_idx: Mapping from class name to integer index. When omitted,
            it is built from the sorted unique labels found in ``df``. Pass an
            explicit, shared mapping when creating several datasets from
            different splits so that their indices agree.
    """
    def __init__(self, df, transform=None, class_to_idx=None):
        self.df = df
        self.transform = transform
        if class_to_idx is None and 'label' in df:
            # Without a mapping every __getitem__ call would fail, so derive
            # a deterministic one from the labels that are present.
            class_to_idx = {
                name: i for i, name in enumerate(sorted(df['label'].unique()))
            }
        self.class_to_idx = class_to_idx

    def __len__(self):
        """Return the number of rows (images) in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        ``image`` is the RGB image after ``transform`` (if any) and ``label``
        is the integer index looked up in ``class_to_idx``. If the image
        cannot be loaded, a warning is printed and a blank 224x224 RGB image
        is used instead so that a single bad file does not abort iteration.
        """
        row = self.df.iloc[idx]  # positional lookup, done once per item
        img_path = row['filepath']
        label_str = row['label']

        try:
            # The context manager closes the file handle deterministically;
            # convert() returns an independent, fully loaded copy.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            print(f"Warning: Error loading {img_path}: {e}. Using a dummy image.")
            image = Image.new('RGB', (224, 224))

        label = self.class_to_idx[label_str]

        if self.transform:
            image = self.transform(image)

        return image, label

def load_data_from_folders(dataset_dir, random_state=None):
    """Scan a directory of per-class sub-folders and collect image paths/labels.

    Every immediate sub-directory of ``dataset_dir`` is treated as one class
    (its name becomes the label) and is searched recursively for files ending
    in .jpg, .jpeg or .png (case-insensitive). Files located directly in
    ``dataset_dir`` are ignored.

    Args:
        dataset_dir: Root directory of the dataset (``Path`` or ``str``).
        random_state: Seed used to shuffle the rows. When ``None``, the
            module-level ``RANDOM_SEED`` is used if it is defined, otherwise 42.

    Returns:
        A shuffled DataFrame with columns 'filepath' and 'label'. If the
        directory does not exist or contains no images, an error is printed to
        stderr and an empty DataFrame (with the same columns) is returned.
    """
    image_extensions = {'.jpg', '.jpeg', '.png'}
    columns = ['filepath', 'label']

    dataset_dir = Path(dataset_dir)
    # is_dir() also rejects a path that points at a regular file, which would
    # otherwise make iterdir() raise.
    if not dataset_dir.is_dir():
        print(f"Error: Dataset path not found: {dataset_dir}", file=sys.stderr)
        return pd.DataFrame(columns=columns)

    # Directory listing order is filesystem-dependent; sorting before the
    # seeded shuffle makes the final row order reproducible across machines.
    records = []
    for class_dir in sorted(p for p in dataset_dir.iterdir() if p.is_dir()):
        for img_path in sorted(class_dir.rglob('*')):
            if img_path.is_file() and img_path.suffix.lower() in image_extensions:
                records.append((str(img_path), class_dir.name))

    if not records:
        print(f"Error: No images found in {dataset_dir}.", file=sys.stderr)
        return pd.DataFrame(columns=columns)

    if random_state is None:
        # Fall back to the notebook-level seed when the caller sets one.
        random_state = globals().get('RANDOM_SEED', 42)

    df = pd.DataFrame(records, columns=columns)
    return df.sample(frac=1, random_state=random_state).reset_index(drop=True)

In [None]:
# Main Training Script
# Trains PestCNN for NUM_EPOCHS epochs on a stratified 70/15/15 split of the
# images under DATASET_PATH and reports train loss, validation loss and
# validation accuracy after every epoch.
if __name__ == '__main__':

    # Configuration
    DATASET_PATH = Path("Jute_Pest_Dataset/train")
    RANDOM_SEED = 42

    # Design Choices
    BATCH_SIZE = 32
    LEARNING_RATE = 1e-3
    NUM_EPOCHS = 3   # Initial test
    NUM_WORKERS = 4  # DataLoader worker processes (0 = load in the main process)

    # Seed the global RNGs so weight initialisation, shuffling and the random
    # augmentations are repeatable between runs.
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)

    # Setup device (use GPU if available, else CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Data Loading and Splitting
    full_df = load_data_from_folders(DATASET_PATH)

    if full_df.empty:
        sys.exit("Data loading failed. Exiting.")

    # Create the class-to-integer mapping from the *full* dataset so that all
    # splits share identical label indices.
    unique_classes = sorted(full_df['label'].unique())
    class_to_idx = {cls_name: i for i, cls_name in enumerate(unique_classes)}

    # Split the DataFrame: 70% train, then the remaining 30% is halved into
    # validation and test (15% each), stratified by label at both steps.
    X = full_df['filepath']
    y = full_df['label']

    X_train, X_temp, y_train, y_temp = train_test_split(
        X, y, train_size=0.70, stratify=y, random_state=RANDOM_SEED
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=0.50, stratify=y_temp, random_state=RANDOM_SEED
    )

    # Create final DataFrames
    train_df = pd.DataFrame({'filepath': X_train, 'label': y_train})
    val_df = pd.DataFrame({'filepath': X_val, 'label': y_val})
    test_df = pd.DataFrame({'filepath': X_test, 'label': y_test})  # held out; not used in this run

    print(f"Data split: {len(train_df)} train, {len(val_df)} val, {len(test_df)} test")

    # Preprocessing
    # Load the pre-calculated stats
    try:
        # NOTE: torch.load unpickles the file; only load statistics files that
        # were produced by a trusted earlier step. map_location keeps the
        # tensors on the CPU even if they were saved from a GPU.
        TRAIN_MEAN = torch.load("train_mean.pt", map_location="cpu")
        TRAIN_STD = torch.load("train_std.pt", map_location="cpu")
        print("Loaded pre-calculated normalization statistics.")
    except FileNotFoundError:
        print("Warning: Statistics files not found. Using ImageNet defaults.")
        # Fallback to ImageNet stats if not calculated
        TRAIN_MEAN = torch.tensor([0.485, 0.456, 0.406])
        TRAIN_STD = torch.tensor([0.229, 0.224, 0.225])

    # Training uses random crop/flip augmentation; validation uses a
    # deterministic resize + centre crop. Both end at 224x224.
    data_transforms = {
        'train': T.Compose([
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(TRAIN_MEAN, TRAIN_STD)
        ]),
        'val': T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(TRAIN_MEAN, TRAIN_STD)
        ])
    }

    # Component 1: DataLoaders
    train_dataset = JutePestDataset(
        train_df, transform=data_transforms['train'], class_to_idx=class_to_idx
    )
    val_dataset = JutePestDataset(
        val_df, transform=data_transforms['val'], class_to_idx=class_to_idx
    )

    # Pinned host memory only helps host-to-GPU copies, so skip it on CPU.
    use_pinned_memory = device.type == "cuda"

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=use_pinned_memory
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=use_pinned_memory
    )

    # Component 2: Model, Loss, Optimizer
    model = PestCNN().to(device)

    # Fail early with a readable message instead of an out-of-range target
    # error inside the loss when the data has more classes than the model.
    if len(class_to_idx) > model.output.out_features:
        sys.exit(
            f"Found {len(class_to_idx)} classes but the model only has "
            f"{model.output.out_features} outputs. Exiting."
        )

    # Loss function (combines LogSoftmax and NLLLoss)
    criterion = nn.CrossEntropyLoss()

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    print("\nStarting training...")

    # Component 3 & 4: Training & Validation Loop

    # Lists to store per-epoch metrics
    train_loss_history = []
    val_loss_history = []
    val_acc_history = []

    for epoch in range(NUM_EPOCHS):
        epoch_start_time = time.perf_counter()

        # Training Phase
        model.train()  # Enables dropout
        running_train_loss = 0.0

        for inputs, labels in train_loader:
            # Move data to the correct device
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The criterion returns the batch mean; weight it by the batch
            # size so the epoch average is correct when the last batch is
            # smaller.
            running_train_loss += loss.item() * inputs.size(0)

        epoch_train_loss = running_train_loss / len(train_loader.dataset)
        train_loss_history.append(epoch_train_loss)

        # Validation Phase
        model.eval()  # Disables dropout
        running_val_loss = 0.0
        correct_preds = 0  # plain Python int; no tensor kept on the device

        with torch.no_grad():  # Disable gradient calculation
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)

                loss = criterion(outputs, labels)
                running_val_loss += loss.item() * inputs.size(0)

                # Predicted class = index of the largest logit
                _, preds = torch.max(outputs, 1)
                correct_preds += (preds == labels).sum().item()

        num_val_samples = len(val_loader.dataset)
        epoch_val_loss = running_val_loss / num_val_samples
        epoch_val_acc = correct_preds / num_val_samples

        val_loss_history.append(epoch_val_loss)
        val_acc_history.append(epoch_val_acc)  # Already a Python float

        # Print Epoch Results
        epoch_end_time = time.perf_counter()
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | "
              f"Time: {(epoch_end_time - epoch_start_time):.2f}s | "
              f"Train Loss: {epoch_train_loss:.4f} | "
              f"Val Loss: {epoch_val_loss:.4f} | "
              f"Val Acc: {epoch_val_acc:.4f}")

    print(f"\nFinished {NUM_EPOCHS}-epoch test run.")