<a href="https://colab.research.google.com/github/AnandDaksh/Crop-Disease-Classification/blob/main/(ConvNeXt)_Tomato_Plant_Village_Dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
pip install timm

Collecting timm
  Downloading timm-0.9.16-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.9.16


In [1]:
# Attach Google Drive to the Colab filesystem; the dataset folders read later
# in this notebook live under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [5]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns
from timm.models.convnext import convnext_small  # Import the ConvNeXt model

# Define transformations: convert to a float tensor in [0, 1], then shift and
# scale every channel with mean 0.5 / std 0.5 so values land in [-1, 1].
# NOTE(review): no resize is applied, so batching assumes every image on disk
# already has the same height and width - confirm against the dataset.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load the Tomato dataset. ImageFolder expects one sub-directory per class and
# assigns integer labels by the sorted order of those directory names.
train_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/Plant Village Datasets/Tomato/Train', transform=transform)
test_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/Plant Village Datasets/Tomato/Test', transform=transform)

# The integer labels only mean the same thing in both splits when both splits
# contain exactly the same class folders; otherwise test accuracy is meaningless.
if test_dataset.classes != train_dataset.classes:
    raise ValueError(
        "Train and Test folders do not contain the same class directories: "
        f"train={train_dataset.classes}, test={test_dataset.classes}"
    )

# Define data loaders (only the training data is shuffled)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

# Define classes. Take the names from the dataset itself instead of a
# hand-typed list: the labels yielded by the loaders index into exactly this
# list, so its length and order always match the folders on disk.
classes = train_dataset.classes

# Define function to plot sample images
def plot_sample(X, y, index, mean=0.5, std=0.5):
    """Display one image of a batch with its class name as the x-axis label.

    Args:
        X: Batch of images as an array indexed by sample; each image is
            channels-first (it is transposed with (1, 2, 0) before display).
        y: Integer class indices, one per image in ``X``; each is used to
            look up a name in the module-level ``classes`` list.
        index: Position within the batch of the image to show.
        mean: Mean that was subtracted when the images were normalized.
        std: Standard deviation the images were divided by when normalized.
    """
    # The input pipeline normalizes pixels with mean 0.5 / std 0.5, giving
    # values in [-1, 1]. imshow only accepts floats in [0, 1] and clips the
    # rest, so undo the normalization first to show the true colors.
    image = np.transpose(X[index], (1, 2, 0)) * std + mean
    # Guard against tiny numerical overshoot outside the displayable range.
    image = np.clip(image, 0.0, 1.0)

    plt.figure(figsize=(15, 2))
    plt.imshow(image)
    plt.xlabel(classes[y[index]])
    plt.show()

# Preview a handful of images taken from a single training batch.
images, labels = next(iter(train_loader))
batch_images = images.numpy()
batch_labels = labels.numpy()
for sample_index in (5, 2, 0, 8):
    plot_sample(batch_images, batch_labels, sample_index)

# Run on the GPU when the runtime provides one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create ConvNeXt model with randomly initialized weights and one output per class.
model = convnext_small(pretrained=False, num_classes=len(classes))  # Adjust num_classes accordingly
# Move the model before the optimizer is built so the optimizer holds
# references to the parameters that are actually trained.
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train the model
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        # Inputs and targets must live on the same device as the model.
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight it by the batch size so the
        # epoch figure below is a true per-sample average.
        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# Test the model
model.eval()
correct = 0
total = 0
predictions = []
true_labels = []
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(device))
        # Predicted class = index of the largest logit; bring it back to the
        # CPU so it can be compared with the (CPU) labels from the loader.
        predicted = outputs.argmax(dim=1).cpu()
        predictions.extend(predicted.tolist())
        true_labels.extend(labels.tolist())
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Accuracy on test set: {accuracy:.2%}")

# Plot confusion matrix
# Pass the full list of class indices so the matrix always has one row and one
# column per class, even when a class never occurs in the test labels or
# predictions; this keeps the tick labels below aligned with the cells.
class_indices = np.arange(len(classes))
conf_matrix = confusion_matrix(true_labels, predictions, labels=class_indices)
accuracy_percentage = accuracy * 100

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=classes,  # show class names instead of bare integer indices
    yticklabels=classes,
    ax=ax,
)
ax.set_xlabel(f'Predicted labels\nAccuracy: {accuracy_percentage:.2f}%')
ax.set_ylabel(f'True labels\nAccuracy: {accuracy_percentage:.2f}%')
ax.set_title(f'Confusion Matrix\nAccuracy: {accuracy_percentage:.2f}%')
fig.tight_layout()  # keep the long class names inside the figure
plt.show()


FileNotFoundError: Found no valid file for the classes Tomato_Leaf_Mold, Tomato_Septoria_leaf_spot, Tomato_Spider_mites_Two_spotted_spider_mite, Tomato__Target_Spot, Tomato__Tomato_YellowLeaf__Curl_Virus, Tomato__Tomato_mosaic_virus. Supported extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp