#start


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define CNN
class TomatoCNN(nn.Module):
    """Small two-stage convolutional classifier for tomato crop images.

    Two (conv 3x3 -> ReLU -> 2x2 max-pool) stages turn a 3-channel image into a
    32-channel feature map. That map is pooled to a fixed FEATURE_GRID x
    FEATURE_GRID grid, flattened, and classified by a 128-unit hidden layer
    with dropout.

    Args:
        num_classes: Number of output logits (one per class).
    """

    # Spatial size of the feature map handed to the classifier. A 224x224
    # input reaches exactly this size after the two 2x2 max-pools, so for that
    # input the adaptive pool below is an identity and results/checkpoints are
    # unchanged; other input sizes are resampled to this grid instead of
    # crashing the first Linear layer.
    FEATURE_GRID = 56

    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # Parameter-free, so state_dict keys stay exactly as before.
        self.pool = nn.AdaptiveAvgPool2d((self.FEATURE_GRID, self.FEATURE_GRID))
        self.classifier = nn.Sequential(
            nn.Linear(32 * self.FEATURE_GRID * self.FEATURE_GRID, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        x is expected to be a (batch, 3, H, W) image tensor; H and W must be at
        least 4 so that both 2x2 max-pools produce a non-empty map.
        """
        x = self.features(x)
        x = self.pool(x)
        # flatten (unlike view) also works on non-contiguous feature maps.
        x = torch.flatten(x, 1)
        return self.classifier(x)

# Data transforms
# Both pipelines resize to IMAGE_SIZE first. Default DataLoader collation needs
# every tensor in a batch to share one shape, and TomatoCNN's classifier
# (32 * 56 * 56 inputs after two 2x2 pools) was sized for 224x224 images.
IMAGE_SIZE = (224, 224)
# Per-channel normalization statistics (the widely used ImageNet values),
# shared by the train and val pipelines so they can never drift apart.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, then light geometric/photometric augmentation.
train_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])
# Validation pipeline: deterministic, no augmentation.
val_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Load datasets
output_dir = '/content/drive/MyDrive/Tomato_dataset/cnn_crops'


def _load_ripeness_splits(ripeness):
    """Build the train/val ImageFolder pair for one ripeness group.

    Reads ``{output_dir}/{ripeness}/train`` with ``train_transform`` and
    ``{output_dir}/{ripeness}/val`` with ``val_transform``.

    Returns:
        dict with keys 'train' and 'val' mapping to the two ImageFolders.

    Raises:
        ValueError: if any class folder in val is assigned a different label
            index than the same folder in train. ImageFolder numbers classes
            from each root's own sorted folder names, so a class folder missing
            from one split shifts the indices of later classes and would
            silently corrupt validation accuracy.
    """
    # Imported here rather than through the module-level ``datasets`` name,
    # which is rebound to a dict below.
    from torchvision.datasets import ImageFolder

    train_set = ImageFolder(f'{output_dir}/{ripeness}/train', transform=train_transform)
    val_set = ImageFolder(f'{output_dir}/{ripeness}/val', transform=val_transform)
    misaligned = sorted(
        name for name, idx in val_set.class_to_idx.items()
        if train_set.class_to_idx.get(name) != idx
    )
    if misaligned:
        raise ValueError(
            f'{ripeness}: val class folders {misaligned} do not map to the same '
            f'label indices as in train (train classes: {train_set.classes}, '
            f'val classes: {val_set.classes})'
        )
    return {'train': train_set, 'val': val_set}


# NOTE: this rebinds ``datasets``, shadowing the torchvision module imported at
# the top of the cell; from here on the name refers to this dict of splits.
datasets = {
    'ripe': _load_ripeness_splits('ripe'),
    'green': _load_ripeness_splits('green'),
}

# Training function
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over loader; return the mean batch loss (0.0 if empty)."""
    model.train()
    total_loss, n_batches = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    return total_loss / n_batches if n_batches else 0.0


def _evaluate_accuracy(model, loader, device):
    """Return the top-1 accuracy of model over loader.

    Raises:
        ValueError: if loader yields no samples (accuracy would be undefined).
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            _, predicted = torch.max(model(images), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError('validation loader yielded no samples')
    return correct / total


def train_cnn(model, train_loader, val_loader, save_path, epochs=20, lr=0.001, patience=6):
    """Train model with Adam + cross-entropy and checkpoint the best epoch.

    After every epoch the model is scored on val_loader. Whenever validation
    accuracy improves (the first epoch always counts as an improvement),
    ``model.state_dict()`` is written to save_path, and missing parent
    directories are created. Training stops early after ``patience``
    consecutive epochs without improvement.

    Args:
        model: nn.Module producing class logits; moved to CUDA when available.
        train_loader: iterable of (images, labels) training batches.
        val_loader: iterable of (images, labels) validation batches.
        save_path: destination file for the best-epoch state_dict.
        epochs: maximum number of training epochs.
        lr: Adam learning rate.
        patience: consecutive non-improving epochs tolerated before stopping.
            The default of 6 reproduces the previous ``counter > 5`` rule.

    Returns:
        The best validation accuracy observed (0.0 if no epoch ran).

    Raises:
        ValueError: if val_loader yields no samples.
    """
    import os

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # torch.save does not create directories; make sure the target exists.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Start from None (not 0) so the first epoch always checkpoints; with a
    # 0 start and a strict '>', a model stuck at 0% accuracy was never saved.
    best_acc = None
    epochs_without_improvement = 0
    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        acc = _evaluate_accuracy(model, val_loader, device)
        print(f'Epoch {epoch+1}, Loss: {train_loss:.4f}, Val Acc: {acc:.4f}')
        if best_acc is None or acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), save_path)
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print("Early stopping triggered")
                break
    if best_acc is None:
        best_acc = 0.0
    print(f'Best Val Acc: {best_acc:.4f}')
    return best_acc

# Train and save models
# Train one classifier per ripeness group and write each best checkpoint to
# .../models/<ripeness>_cnn.pt.
for ripeness in ('ripe', 'green'):
    ripeness_splits = datasets[ripeness]
    train_set = ripeness_splits['train']
    val_set = ripeness_splits['val']

    # Only the training batches are shuffled; validation order stays fixed.
    train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=16)

    # Output width follows the class folders discovered for the training split.
    model = TomatoCNN(num_classes=len(train_set.classes))
    save_path = f'/content/drive/MyDrive/Tomato_dataset/models/{ripeness}_cnn.pt'
    train_cnn(model, train_loader, val_loader, save_path)
    print(f"{ripeness.capitalize()} CNN trained and saved to {save_path}")