In [3]:
import os
import re
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ImageDataset(Dataset):
    def __init__(self, txt_loc, transform=None):
        print("Initializing image dataset.")
        self.image_paths = self.find_full_paths(txt_loc)
        self.transform = transform
        print(f"Total images found: {len(self.image_paths)}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        img_path = self.image_paths[index]
        try:
            image = Image.open(img_path)
            if self.transform:
                image = self.transform(image)
            print(f"Loaded image {index + 1}/{len(self.image_paths)}: {img_path}")
            return image
        except IOError:
            print(f"Error loading image: {img_path}")
            return None  # Or handle error differently

    def find_file_by_suffix(self, directory, filename_suffix):
        for root, dirs, files in os.walk(directory):
            for filename in files:
                if filename.endswith(filename_suffix):
                    return os.path.join(root, filename)
        return None

    def find_full_paths(self, txt_loc):
        data_dirs = os.listdir("color/")
        final_paths = []
        with open(txt_loc, 'r') as infile:
            lines = [line.strip() for line in infile.readlines()]
        
        for line in lines:
            parts = re.split(r'(?<=\.JPG)\s*', line)
            filename_suffix = parts[0]
            file_location = f'color/{data_dirs[int(parts[1]) - 1]}/'
            full_path = self.find_file_by_suffix(file_location, filename_suffix)
            if full_path:
                final_paths.append(full_path)
            else:
                print(f"File not found: {filename_suffix} in {file_location}")
        return final_paths

# Example use case
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = ImageDataset(txt_loc='./train.txt', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)

print(len(train_dataset))

test_dataset = ImageDataset(txt_loc='./test.txt', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True, num_workers=4)

print(len(test_dataset))


Initializing image dataset.
File not found: UF.GRC_BS_Lab Leaf 0381.JPG in color/Tomato___Bacterial_spot/
Total images found: 8498
8498


In [2]:
# transform = transforms.Compose([
#     transforms.CenterCrop(227),
#     transforms.ToTensor(),             
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])


In [3]:
from torchvision import models

model = models.alexnet(pretrained=True)
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 38)

In [4]:
import torch.optim as optim

model = model.to(device)

criterion = torch.nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()

for i, (inputs, labels) in enumerate(train_loader):
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()

    outputs = model(inputs)
    loss = criterion(outputs, labels)

    loss.backward()
    optimizer.step()

    if (i + 1) % 10 == 0:
        print(f'Batch {i + 1}, Loss: {loss.item()}')

torch.save(model.state_dict(), 'alexnet_plantvillage.pth')
print('Training complete and model saved')

  img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))


Batch 10, Loss: 2.7607421875
Batch 20, Loss: 3.140488624572754


KeyboardInterrupt: 

In [5]:
def evaluate_model():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in validation_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy: {100 * correct / total}%')

evaluate_model()


Accuracy: 12.807292146211214%
