In [1]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch.optim as optim
from torchvision.transforms import v2
from torchvision.io import read_image
import os
import glob
import random

In [2]:
BATCH_SIZE = 512
EPOCHS = 20
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Global seed so weight initialisation and DataLoader shuffling are
# reproducible under Restart & Run All (the train/test split uses its own
# seeded generator).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# glob returns paths in arbitrary, filesystem-dependent order; sort them so the
# seeded random_split assigns the same files to train/test on every machine.
image_dirs = sorted(glob.glob("./data/Rice/*jpg"))

In [4]:
def str_to_int(labels):
    """Replace rice-variety names in ``labels`` with integer class ids, in place.

    Mapping: 'Arborio' -> 0, 'Basmati'/'basmati' -> 1, 'Ipsala' -> 2,
    'Jasmine' -> 3, 'Karacadag' -> 4. Entries not in the mapping are left
    unchanged, and a single warning listing them is printed.

    Args:
        labels: mutable sequence (e.g. a list) of label strings; modified in place.

    Returns:
        The same ``labels`` object, for convenience; callers that rely only on
        the in-place mutation are unaffected.
    """
    name_to_id = {
        'Arborio': 0,
        'Basmati': 1,
        'basmati': 1,  # some file names are mislabeled with a lowercase 'b'
        'Ipsala': 2,
        'Jasmine': 3,
        'Karacadag': 4,
    }
    valid_ids = set(name_to_id.values())

    for index, label in enumerate(labels):
        # Unknown (or already-converted) entries are kept as they are.
        labels[index] = name_to_id.get(label, label)

    # Report unknown labels once, naming them, rather than once per bad sample.
    unknown = sorted({str(label) for label in labels if label not in valid_ids})
    if unknown:
        print(f"Some labels are not correct: {unknown}")
    return labels

In [5]:
# Reproducible 80/20 train/test split of the image paths (dedicated seeded generator)
generator = torch.Generator().manual_seed(42)
train_test = random_split(image_dirs, lengths=[0.8, 0.2], generator=generator)
train_dir, test_dir = train_test

In [6]:
# The class name is the first space-separated token of each file's base name
train_labels = [os.path.basename(p).split(" ")[0] for p in train_dir]
test_labels = [os.path.basename(p).split(" ")[0] for p in test_dir]

# Convert the names to integer class ids (str_to_int mutates the lists in place)
str_to_int(train_labels)
str_to_int(test_labels)

In [7]:
class RiceImageDataset(Dataset):
    """Map-style dataset that reads one image file per sample and pairs it with a label.

    Args:
        img_dir: indexable sequence of image file paths, one per sample (despite
            the name, not a directory), e.g. a ``random_split`` subset of paths.
        labels: sequence of labels aligned index-by-index with ``img_dir``.
        transform: applied to each decoded image when truthy. The default resizes
            to 25x25 and converts to float32; ``v2.ToDtype`` is used without
            ``scale=True``, so pixel values keep their original range
            (0-255 for 8-bit images) rather than being scaled to [0, 1].
        target_transform: optional callable applied to each label when truthy.

    Raises:
        ValueError: if ``img_dir`` and ``labels`` have different lengths.
    """

    def __init__(self, img_dir, labels,
                 transform=v2.Compose([v2.Resize((25, 25), antialias=True), v2.ToImage(), v2.ToDtype(torch.float32)]),
                 target_transform=None):
        # __len__ trusts labels, so a mismatch would otherwise surface later as an
        # IndexError (or silently unused images) deep inside a DataLoader.
        if len(img_dir) != len(labels):
            raise ValueError(
                f"img_dir has {len(img_dir)} paths but labels has {len(labels)} entries")
        self.img_labels = labels
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, decoding the image from disk."""
        img_path = self.img_dir[idx]
        image = read_image(img_path)
        label = self.img_labels[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [8]:
# Both splits use the dataset's default transform (25x25 resize, float32)
train_dataset, test_dataset = (
    RiceImageDataset(train_dir, train_labels),
    RiceImageDataset(test_dir, test_labels),
)

In [9]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation metrics do not depend on sample order, so read the test set
# sequentially: deterministic, and no per-epoch permutation drawn from the
# global RNG.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [10]:
# Sanity-check the pipeline by displaying the first (image, label) training pair
first_image, first_label = train_dataset[0]
(first_image, first_label)

(Image([[[0., 0., 0.,  ..., 0., 0., 0.],
         [1., 1., 0.,  ..., 0., 0., 0.],
         [2., 2., 1.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 2., 0., 0.]],
 
        [[0., 0., 0.,  ..., 0., 0., 0.],
         [1., 1., 0.,  ..., 0., 0., 0.],
         [2., 2., 1.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 2., 0., 0.]],
 
        [[0., 0., 0.,  ..., 0., 0., 0.],
         [1., 1., 0.,  ..., 0., 0., 0.],
         [2., 2., 1.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 2., 0., 0.]]], ),
 2)

In [11]:
class ConvNet(nn.Module):
    """LeNet-style CNN producing log-probabilities over 5 classes.

    Two (conv -> ReLU -> 2x2 average pool) stages feed a fully connected head
    (120 -> 84 -> 5). All conv/linear layers are ``Lazy*`` variants, so their
    input sizes are inferred on the first forward pass.
    """

    def __init__(self):
        super().__init__()
        layers = [
            # feature extractor
            nn.LazyConv2d(6, kernel_size=5, padding=2), nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.LazyConv2d(16, kernel_size=5), nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            # classifier head
            nn.Flatten(),
            nn.LazyLinear(120), nn.ReLU(),
            nn.LazyLinear(84), nn.ReLU(),
            nn.LazyLinear(5),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return log-softmax class scores (normalised over dim 1) for batch ``x``."""
        logits = self.net(x)
        return F.log_softmax(logits, dim=1)

In [12]:
model = ConvNet().to(DEVICE)
# ConvNet is built from Lazy* layers whose parameters stay uninitialised until
# the first forward pass. Do one dry pass on a real training sample (already on
# the target device) so every parameter is materialised before the optimizer
# captures model.parameters().
with torch.no_grad():
    model(train_dataset[0][0].unsqueeze(0).to(DEVICE))
optimizer = optim.Adam(model.parameters())



In [13]:
def train(model, device, train_loader, optimizer, epoch, log_interval=30):
    """Run one training epoch over ``train_loader``, printing periodic progress.

    Args:
        model: module mapping a batch to class scores; ConvNet returns
            log-probabilities, for which ``F.cross_entropy`` gives the same
            value as ``F.nll_loss`` (log_softmax is idempotent).
        device: device each (data, target) batch is moved to.
        train_loader: DataLoader yielding (data, target) batches.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        epoch: epoch number, used only in the progress message.
        log_interval: print progress after every ``log_interval`` batches
            (default 30, the previously hard-coded value).
    """
    model.train()
    n_total = len(train_loader.dataset)
    n_batches = len(train_loader)
    samples_seen = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        samples_seen += len(data)
        if (batch_idx + 1) % log_interval == 0:
            # Progress is reported *after* this batch, so both the sample count and
            # the percentage include it (the old message lagged one batch behind).
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, samples_seen, n_total,
                100. * (batch_idx + 1) / n_batches, loss.item()))

In [14]:
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [15]:
# Alternate one training epoch with a full pass over the test set
for epoch in range(1, EPOCHS + 1):
    train(model=model, device=DEVICE, train_loader=train_loader,
          optimizer=optimizer, epoch=epoch)
    test(model=model, device=DEVICE, test_loader=test_loader)


Test set: Average loss: 0.1021, Accuracy: 14479/15000 (97%)


Test set: Average loss: 0.0792, Accuracy: 14564/15000 (97%)


Test set: Average loss: 0.0741, Accuracy: 14610/15000 (97%)


Test set: Average loss: 0.0545, Accuracy: 14705/15000 (98%)


Test set: Average loss: 0.0779, Accuracy: 14562/15000 (97%)


Test set: Average loss: 0.0512, Accuracy: 14726/15000 (98%)


Test set: Average loss: 0.6464, Accuracy: 12347/15000 (82%)


Test set: Average loss: 0.0354, Accuracy: 14819/15000 (99%)


Test set: Average loss: 0.0317, Accuracy: 14831/15000 (99%)


Test set: Average loss: 0.0310, Accuracy: 14854/15000 (99%)


Test set: Average loss: 0.0304, Accuracy: 14858/15000 (99%)


Test set: Average loss: 0.0298, Accuracy: 14859/15000 (99%)


Test set: Average loss: 0.0337, Accuracy: 14825/15000 (99%)


Test set: Average loss: 0.0275, Accuracy: 14870/15000 (99%)


Test set: Average loss: 0.0236, Accuracy: 14896/15000 (99%)


Test set: Average loss: 0.0284, Accuracy: 14857/15000 (99%)


Test se