# 📦 Imports

In [1]:
import os
import cv2
import numpy as np
from sklearn.preprocessing import LabelEncoder
import torch
from torch.utils.data import Dataset, DataLoader

# 🧾 Custom Dataset Class

In [2]:
class ImageDataset(Dataset):
    def __init__(self, root_folder):
        self.images = []
        self.labels = []
        self.label_encoder = LabelEncoder()
        self._load_data(root_folder)
        self.labels = self.label_encoder.fit_transform(self.labels)

    def _load_data(self, root_folder):
        print("Loading images from:", root_folder)
        for class_name in os.listdir(root_folder):
            class_path = os.path.join(root_folder, class_name)
            if not os.path.isdir(class_path):
                continue
            for image_name in os.listdir(class_path):
                image_path = os.path.join(class_path, image_name)
                img = cv2.imread(image_path)
                if img is None:
                    continue
                try:
                    img = cv2.resize(img, (128, 128))  # Smaller size for low RAM
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    img = img / 255.0
                    self.images.append(img.astype(np.float32))
                    self.labels.append(class_name)
                except:
                    continue

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = torch.tensor(self.images[idx]).permute(2, 0, 1)  # HWC to CHW
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label

# 📁 Load dataset

In [3]:
data_path = r"C:\Users\DHRUV\Downloads\Telegram Desktop\time for cnn\split_data"
dataset = ImageDataset(data_path)

Loading images from: C:\Users\DHRUV\Downloads\Telegram Desktop\time for cnn\split_data


# 🧠 CNN Model

In [4]:
import torch.nn as nn

class MyCNN(nn.Module):
    def __init__(self):
        super(MyCNN, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 16 * 16, 128), nn.ReLU(),
            nn.Linear(128, 2)  # Output: 0 or 1
        )

    def forward(self, x):
        return self.model(x)

# ⚙️ Training Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyCNN().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
batch_size = 16
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)


# 🏋️ Train the Model

In [5]:
epochs = 10
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_value_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")



Epoch 1/10, Loss: 0.4294
Epoch 2/10, Loss: 0.2418
Epoch 3/10, Loss: 0.1579
Epoch 4/10, Loss: 0.1075
Epoch 5/10, Loss: 0.0850
Epoch 6/10, Loss: 0.0617
Epoch 7/10, Loss: 0.0589
Epoch 8/10, Loss: 0.0467
Epoch 9/10, Loss: 0.0248
Epoch 10/10, Loss: 0.0301


# 💾 Save model

In [6]:
model_path = "cnn_disease_model.pth"
torch.save(model.state_dict(), model_path)
print(f"✅ Model saved at: {model_path}")

✅ Model saved at: cnn_disease_model.pth


# 🔍 Example: Predicting one image

In [8]:
model.eval()
with torch.no_grad():
    sample, _ = dataset[-1]
    sample = sample.unsqueeze(0).to(device)
    output = model(sample)
    prediction = torch.argmax(output, dim=1).item()
    print(f"Predicted: {prediction} ({'Diseased' if prediction == 1 else 'Healthy'})")

Predicted: 1 (Diseased)
