In [None]:
# Importing the Libraries

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models
import kagglehub
from tqdm import tqdm

In [None]:
# Set device
# Prefer a CUDA GPU when one is present, then Apple's MPS backend, and fall
# back to the CPU otherwise, so the notebook picks up whatever accelerator
# the host offers.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device:", device)

In [None]:
# Download data
# Fetch the dataset through kagglehub and keep the returned location under
# both names (`data_dir` and `path`) that this notebook defines.
data_dir = kagglehub.dataset_download("nafishamoin/new-bangladeshi-crop-disease")
path = data_dir

In [None]:
# Data augmentation and normalization
IMAGE_SIZE = 224  # side length (pixels) of the square images produced by both pipelines
# Per-channel statistics passed to Normalize in both pipelines (these are the
# values commonly quoted as ImageNet channel means / standard deviations).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: random geometric and colour perturbations, then tensor
# conversion and normalization.
train_augmentations = [
    transforms.RandomResizedCrop(size=IMAGE_SIZE, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.2),
    transforms.RandomRotation(degrees=15),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.3, hue=0.1),
]
train_transforms = transforms.Compose(
    train_augmentations
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
)

# Validation pipeline: no random steps, just a fixed resize followed by the
# same tensor conversion and normalization as training.
val_transforms = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
)


In [None]:
# Load dataset
# Build one ImageFolder over the downloaded directory; every sample drawn from
# it goes through the training transform pipeline.
dataset = datasets.ImageFolder(
    transform=train_transforms,
    root=path,
)

In [None]:
# Split into train and validation sets
SPLIT_SEED = 42  # fixed seed so Restart & Run All reproduces the same split

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_split = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=split_generator
)

# Give the validation subset its own transform.
# Both subsets returned by random_split wrap the SAME `dataset` object, so
# assigning `val_dataset.dataset.transform = val_transforms` would also replace
# the training transform and silently disable augmentation. Instead, open a
# second ImageFolder over the same directory with `val_transforms` and reuse
# the validation indices on it.
val_source = datasets.ImageFolder(root=path, transform=val_transforms)
# The index reuse is only valid if both ImageFolder scans list the same samples.
assert len(val_source) == len(dataset), "dataset directory changed between scans"
val_dataset = torch.utils.data.Subset(val_source, val_split.indices)

In [None]:
# Batch size and worker count are shared by both loaders; only the training
# loader shuffles its samples.
loader_options = {"batch_size": 32, "num_workers": 4}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_loader = DataLoader(val_dataset, shuffle=False, **loader_options)