In [14]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

In [3]:
# Data augmentation transforms can be defined in this cell.

In [4]:
# Load the images as a single labelled dataset: ImageFolder merges all of the
# class folders found under `root`, and ToTensor converts each image to a tensor.
dataset = datasets.ImageFolder(
    transform=ToTensor(),
    root="F:/dataset/potato_disease",
)

In [9]:
# Bucket the samples by class label so that each class can be split into
# train/test on its own; a single random split over the whole dataset could
# leave the class proportions very skewed between the subsets.
# Label mapping used below: 0 -> early_blight, 1 -> healthy, anything else ->
# late_blight (presumably ImageFolder's folder ordering; confirm with
# dataset.class_to_idx).
early_blight = []
healthy = []
late_blight = []
for ex in range(len(dataset)):
    # Index the dataset only once per sample: each dataset[ex] lookup loads
    # and transforms the image again, so repeated lookups multiply the I/O cost.
    sample = dataset[ex]
    label = sample[1]
    if label == 0:
        early_blight.append(sample)
    elif label == 1:
        healthy.append(sample)
    else:
        late_blight.append(sample)
# Cell output: number of samples collected for each class.
len(early_blight), len(healthy), len(late_blight)

(1000, 152, 1000)

In [13]:
# Wraps a plain list of samples in a Dataset so that PyTorch utilities such as
# random_split, ConcatDataset and DataLoader can be used on it.
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset backed by an in-memory sequence of ``(x, y)`` pairs.

    Args:
        data: Indexable sequence whose items unpack into ``(x, y)``, where
            ``y`` is the class label (a Python int or an integer tensor).
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for sample ``idx`` with ``y`` as a ``torch.long`` tensor.

        ``x`` is returned exactly as stored.
        """
        x, y = self.data[idx]
        # Convert the label to a 0-dim long tensor. The legacy torch.Tensor(...)
        # constructor does not accept a dtype argument (it raises TypeError),
        # so use torch.as_tensor, which also passes existing tensors through
        # without the copy warning torch.tensor() would emit.
        y = torch.as_tensor(y, dtype=torch.long)
        return x, y

In [23]:
# Wrap each per-class list in a CustomDataset, split every class into train and
# test subsets separately (so both subsets keep the same class proportions),
# concatenate the per-class pieces, and finally build the DataLoaders.
TRAIN_FRACTION = 0.8  # share of each class that goes to the training set
BATCH_SIZE = 32
SPLIT_SEED = 42  # fixed seed so the train/test membership is the same on every run


def _split_class(samples, generator, train_fraction=TRAIN_FRACTION):
    """Wrap one class's samples in a CustomDataset and randomly split it.

    Args:
        samples: List of ``(image, label)`` samples belonging to one class.
        generator: ``torch.Generator`` driving the random permutation.
        train_fraction: Fraction of the samples assigned to the train subset;
            the size is truncated with ``int`` and the remainder goes to test.

    Returns:
        Tuple ``(class_data, train_subset, test_subset)``.
    """
    class_data = CustomDataset(samples)
    train_size = int(train_fraction * len(class_data))
    test_size = len(class_data) - train_size
    train_subset, test_subset = random_split(
        class_data, [train_size, test_size], generator=generator
    )
    return class_data, train_subset, test_subset


# One generator shared by the three splits; they are drawn in a fixed order
# (healthy, early, late), so the result is reproducible.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

healthy_data, healthy_train, healthy_test = _split_class(healthy, split_generator)
early_data, early_train, early_test = _split_class(early_blight, split_generator)
late_data, late_train, late_test = _split_class(late_blight, split_generator)

# Merge the per-class subsets into the final train and test datasets.
train_data = torch.utils.data.ConcatDataset([healthy_train, early_train, late_train])
test_data = torch.utils.data.ConcatDataset([healthy_test, early_test, late_test])

# Shuffle only the training batches; evaluation order stays fixed.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)