In [5]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
 #   transforms.Resize(256),             # they are already resized in processing
    transforms.CenterCrop(227),       # AlexNet input size is 227
    transforms.ToTensor(),             
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalization
])

dataset = datasets.ImageFolder('color', transform=transform)

train_size = int(0.8 * len(dataset))
validation_size = len(dataset) - train_size
train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=32, shuffle=False)


In [6]:
from torchvision import models
#pretrained alexnet model to test
model = models.alexnet(pretrained=True)

model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 38)


In [7]:
import torch.optim as optim

model = model.to(device)

criterion = torch.nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()

for i, (inputs, labels) in enumerate(train_loader):
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()

    outputs = model(inputs)
    loss = criterion(outputs, labels)

    loss.backward()
    optimizer.step()

    if (i + 1) % 10 == 0:
        print(f'Batch {i + 1}, Loss: {loss.item()}')

torch.save(model.state_dict(), 'alexnet_plantvillage.pth')
print('Training complete and model saved')

Batch 10, Loss: 4.610135555267334
Batch 20, Loss: 3.189448833465576
Batch 30, Loss: 5.520407199859619
Batch 40, Loss: 3.528944730758667
Batch 50, Loss: 3.413517475128174
Batch 60, Loss: 3.371650218963623
Batch 70, Loss: 3.4372010231018066
Batch 80, Loss: 3.3828155994415283
Batch 90, Loss: 3.4647464752197266
Batch 100, Loss: 3.5136120319366455
Batch 110, Loss: 3.2961344718933105
Batch 120, Loss: 3.467622995376587
Batch 130, Loss: 3.4193756580352783
Batch 140, Loss: 3.2898266315460205
Batch 150, Loss: 3.22263765335083
Batch 160, Loss: 3.402648448944092
Batch 170, Loss: 3.6819252967834473
Batch 180, Loss: 3.399836301803589
Batch 190, Loss: 3.2841310501098633
Batch 200, Loss: 3.5595831871032715
Batch 210, Loss: 3.2257673740386963
Batch 220, Loss: 3.3020200729370117


KeyboardInterrupt: 

In [None]:
def evaluate_model():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in validation_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy: {100 * correct / total}%')

evaluate_model()
