# In [None]:
# Importing necessary libraries
import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_recall_fscore_support,
)
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torchvision.datasets import ImageFolder

# Dataset locations: one sub-folder per split under a shared root
# (the path looks like a Colab Google Drive mount).
_data_root = '/content/drive/MyDrive/CropDiseaseDetection'
train_dir = f'{_data_root}/Train'
valid_dir = f'{_data_root}/Valid'
test_dir = f'{_data_root}/Test'

# Preprocessing.
# Per-channel mean/std used by Normalize (the standard ImageNet statistics,
# matching the ImageNet-pretrained backbone loaded below).
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]


def _eval_pipeline():
    """Build the deterministic resize -> tensor -> normalize pipeline."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ])


# Training pipeline: same resize/normalize as evaluation, plus random
# flips, rotations, crops and color jitter for augmentation.
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.RandomResizedCrop(224),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

# Validation and test use independent but identical deterministic pipelines.
valid_transforms = _eval_pipeline()
test_transforms = _eval_pipeline()


# Datasets and loaders. Each ImageFolder builds its own `classes` list from
# its split directory; only the training loader is shuffled.
train_dataset = ImageFolder(train_dir, transform=train_transforms)
valid_dataset = ImageFolder(valid_dir, transform=valid_transforms)
test_dataset  = ImageFolder(test_dir, transform=test_transforms)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False)

class_names = train_dataset.classes
print(f"Classes: {class_names}")
class_names1 = test_dataset.classes
print(f"Classes 2:{class_names1}")

# Evaluation labels validation/test predictions with the *train* class names,
# so the splits must expose the same class list or indices silently mismatch.
for split_name, split_dataset in (("valid", valid_dataset), ("test", test_dataset)):
    if split_dataset.classes != class_names:
        print(f"WARNING: {split_name} classes {split_dataset.classes} differ from "
              f"train classes {class_names}; label indices will not line up.")

# Model: pretrained ResNet-18 with all backbone weights frozen and a new
# trainable classification head sized to the training classes.
# `device` was previously never defined, yet it is used by training and
# evaluation; pick the GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of
# `weights=models.ResNet18_Weights.DEFAULT`.
model = models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False

num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_ftrs, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, len(class_names))
)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
# Only the new head is optimized; the backbone parameters are frozen above.
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

# Train Function
def _run_phase(model, dataloader, criterion, optimizer, device, train):
    """Run one full pass over ``dataloader``; return ``(mean_loss, accuracy)``.

    Gradients are computed and ``optimizer`` is stepped only when ``train``
    is true. An empty dataset yields ``(0.0, 0.0)`` instead of dividing by 0.
    """
    model.train(train)  # train(False) is equivalent to eval()
    running_loss, running_corrects = 0.0, 0

    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        if train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if train:
                loss.backward()
                optimizer.step()

        _, preds = torch.max(outputs, 1)
        # criterion returns a batch mean; weight by batch size for an exact epoch mean.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()

    n_samples = len(dataloader.dataset)
    if n_samples == 0:
        return 0.0, 0.0
    return running_loss / n_samples, running_corrects / n_samples


def train_model(model, train_loader, valid_loader, criterion, optimizer, num_epochs=100):
    """Train ``model`` and restore the weights with the best validation accuracy.

    Args:
        model: module to train; batches are moved to the device of its parameters.
        train_loader: loader for the training split (one optimizer step per batch).
        valid_loader: loader for the validation split (no gradient updates).
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the trainable parameters.
        num_epochs: number of train+validate epochs.

    Returns:
        ``(model, train_losses, valid_losses, train_accs, valid_accs)``: the same
        module with its best-validation weights loaded, plus per-epoch float
        histories for each phase.
    """
    since = time.time()
    # Follow the model's own placement rather than relying on a global `device`.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    train_losses, valid_losses = [], []
    train_accs, valid_accs = [], []
    # state_dict() returns references to the live parameter tensors, which
    # optimizer.step() updates in place; deep-copy so the snapshot stays frozen.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print('-' * 20)

        for phase, dataloader in (('train', train_loader), ('valid', valid_loader)):
            is_train = phase == 'train'
            epoch_loss, epoch_acc = _run_phase(
                model, dataloader, criterion, optimizer, device, is_train)

            print(f"{phase.capitalize()} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
            if is_train:
                train_losses.append(epoch_loss)
                train_accs.append(epoch_acc)
            else:
                valid_losses.append(epoch_loss)
                valid_accs.append(epoch_acc)
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print(f"\nTraining complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
    print(f"Best Validation Accuracy: {best_acc:.4f}")
    model.load_state_dict(best_model_wts)
    return model, train_losses, valid_losses, train_accs, valid_accs

# Train for 100 epochs; train_model returns the model plus per-epoch
# loss/accuracy histories for the train and validation phases.
(model, train_losses, valid_losses, train_accs, valid_accs) = train_model(
    model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=100,
)




# Evaluation
def evaluate(model, test_loader, target_names=None):
    """Predict over ``test_loader`` and print a classification report and confusion matrix.

    Args:
        model: trained classifier; batches are moved to the device of its parameters.
        test_loader: loader yielding ``(inputs, labels)`` batches.
        target_names: display names for class indices ``0..K-1``. Defaults to
            the module-level ``class_names`` (the training-set classes).

    Returns:
        ``(preds, targets, cm)``: predicted and true class indices (lists), and a
        ``K x K`` confusion matrix whose row/column order is ``0..K-1``.
    """
    if target_names is None:
        target_names = class_names

    # Follow the model's own placement rather than relying on a global `device`.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    preds, targets = [], []

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))
            _, predictions = torch.max(outputs, 1)
            preds.extend(predictions.cpu().numpy())
            targets.extend(labels.cpu().numpy())

    # Pin the label set to every known class. Without it, sklearn infers labels
    # from the data: classification_report raises when a class is missing from
    # both targets and preds, and the confusion matrix shrinks/misaligns.
    label_ids = list(range(len(target_names)))

    print("\nClassification Report:")
    print(classification_report(targets, preds, labels=label_ids,
                                target_names=target_names))
    cm = confusion_matrix(targets, preds, labels=label_ids)
    print("\nConfusion Matrix:")
    print(cm)
    return preds, targets, cm

# Evaluate the trained model on the test set.
preds, targets, cm = evaluate(model, test_loader)
accuracy = np.mean(np.array(preds) == np.array(targets))
print(f"\nAccuracy: {accuracy:.4f}")

# Fix the label order to every training class, so the matrix and per-class
# scores line up with class_names even if a class never appears in the data.
label_ids = list(range(len(class_names)))
n_classes = len(label_ids)
cm = confusion_matrix(targets, preds, labels=label_ids)

# Confusion Matrix Heatmap (plain matplotlib; the original called seaborn's
# `sns`, which this file never imports).
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(cm, cmap='Blues')
fig.colorbar(im, ax=ax)
ax.set_xticks(np.arange(n_classes))
ax.set_yticks(np.arange(n_classes))
ax.set_xticklabels(class_names, rotation=45, ha='right')
ax.set_yticklabels(class_names)
# Annotate each cell with its count; white text on dark cells for contrast.
threshold = cm.max() / 2 if cm.size else 0
for i in range(n_classes):
    for j in range(n_classes):
        ax.text(j, i, str(cm[i, j]), ha='center', va='center',
                color='white' if cm[i, j] > threshold else 'black')
ax.set_title("Confusion Matrix")
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
fig.tight_layout()
plt.show()

# Precision, Recall, F1 Plot (one group of bars per class, in label_ids order).
precision, recall, f1, _ = precision_recall_fscore_support(
    targets, preds, labels=label_ids, average=None)
x = np.arange(n_classes)
width = 0.20

plt.figure(figsize=(10,6))
plt.bar(x - width, precision, width, label='Precision')
plt.bar(x, recall, width, label='Recall')
plt.bar(x + width, f1, width, label='F1 Score')
plt.xticks(x, class_names, rotation=45)
plt.ylabel('Score')
plt.title('Precision, Recall, F1 Score per Class')
plt.legend()
plt.tight_layout()
plt.show()

