In [None]:
# Defining transformations
# Per-channel (R, G, B) mean and standard deviation were computed separately on this
# dataset. Keeping them (and the input size) in one place guarantees the train and
# validation pipelines resize and normalize identically.
IMAGE_SIZE = (299, 299)
NORM_MEAN = [0.4478926, 0.41914284, 0.36154622]
NORM_STD = [0.24954137, 0.23996224, 0.23252055]

# Training pipeline: resize, then random rotation / horizontal / vertical flips
# for augmentation, then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomRotation(25),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Validation pipeline: deterministic resize + normalization only (no augmentation).
valid_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

In [None]:
# Extends class Dataset
# Problem: applying a transformation when loading the dataset would apply the same
# transformation to both the train and validation splits. This wrapper instead lets
# different transformations be applied on top of one shared underlying dataset.
class MyLazyDataset(Dataset):
    """Wrap an indexable ``(sample, target)`` dataset and transform samples on access.

    Args:
        dataset: Object supporting ``len(dataset)`` and ``dataset[i]`` returning a
            sequence whose first two items are ``(sample, target)``.
        transform: Optional callable applied to each sample when it is fetched.
            Targets are returned unchanged.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(transform(sample), target)`` for position ``index``."""
        # Fetch the item once: for an image dataset every lookup reads/decodes the file.
        item = self.dataset[index]
        x, y = item[0], item[1]
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        """Number of items in the wrapped dataset."""
        return len(self.dataset)

In [None]:
# Load every image under data_dir with torchvision's ImageFolder.
# NOTE(review): absolute Colab path — adjust when running outside Colab.
data_dir = '/content/sample_data/train_data/train'
dataset = datasets.ImageFolder(root=data_dir)
dataset_size = len(dataset)
print(f'Length of dataset {dataset_size}')

In [None]:
# Two views over the same underlying ImageFolder: one applies the augmenting
# training pipeline, the other the deterministic validation pipeline.
traindataset = MyLazyDataset(dataset, transform=train_transform)
valdataset = MyLazyDataset(dataset, transform=valid_transform)

In [None]:
# One stratified 80/20 train/validation split, seeded for reproducibility.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
indices = list(range(len(dataset)))
# ImageFolder already stores each sample's class index in `.targets`; reading it
# avoids loading and decoding every image from disk just to collect labels.
y_train0 = list(dataset.targets)

In [None]:
# n_splits=1 yields exactly one (train, valid) pair, so take it directly instead of
# relying on loop variables leaking out of a for-loop.
train_index, valid_index = next(sss.split(indices, y_train0))
print("train:", train_index, "val:", valid_index)
print(len(train_index), len(valid_index))

In [None]:
# Restrict each view to its side of the stratified split; Subset maps positions
# 0..len-1 onto the corresponding indices of the wrapped dataset.
train_dataset = Subset(dataset=traindataset, indices=train_index)
valid_dataset = Subset(dataset=valdataset, indices=valid_index)

In [None]:
# Count of each target class in train and valid.
# Labels are looked up by index in ImageFolder's `targets` list. Iterating the
# Subsets instead would load (and augment) every image just to read its label.
y_train = [dataset.targets[i] for i in train_index]
y_valid = [dataset.targets[i] for i in valid_index]

counter_train = collections.Counter(y_train)
counter_val = collections.Counter(y_valid)
# (class index, count) pairs ordered by class index.
sorted_train_counter = sorted(counter_train.items())
sorted_valid_counter = sorted(counter_val.items())

In [None]:
# (class index, count) pairs for each split, one split per line.
print(f'Train : {sorted_train_counter}', f'Valid : {sorted_valid_counter}', sep='\n')

In [None]:
# Training-set class distribution; the trailing `;` suppresses the return-value repr.
fig, ax = plt.subplots()
ax.bar(*zip(*sorted_train_counter))
ax.set(title='Training images per class', xlabel='Class index', ylabel='Number of images');

In [None]:
# Balanced class weights, w_c = total / (n_classes * count_c): rarer classes get
# proportionally larger weights. The list is ordered by class index.
total = len(train_dataset)
num_classes = len(sorted_train_counter)
class_ids = [label for label, _ in sorted_train_counter]
# If these are used as per-class loss weights (as the next cell suggests), weights[i]
# must belong to class i. A class missing from the training split would silently
# shift every later weight onto the wrong class, so fail loudly instead.
assert class_ids == list(range(num_classes)), f'classes missing from training split: {class_ids}'
weights = [total / (num_classes * count) for _, count in sorted_train_counter]
print(weights)

In [None]:
# Due to imbalance in the data set (notably class 4) we assign weights to the classes.
# Created directly on the target device; presumably consumed by a weighted loss later.
weights_tensor = torch.tensor(weights, device=device)

In [None]:
# Seed torch's global RNG so the random samplers below shuffle reproducibly.
torch.manual_seed(0)

# Positions *within* each Subset (0 .. len-1); the Subsets already map these onto the
# stratified split indices. Plain assignment on purpose: the global `indices`
# (positions over the full dataset, used for the split) must not be overwritten here.
train_idx = list(range(len(train_dataset)))
valid_idx = list(range(len(valid_dataset)))

# Define samplers for obtaining training and validation batches.
# (Shuffling the validation set is harmless but not required for evaluation.)
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

In [None]:
batch_size = 16

# Ordering comes from the samplers, so `shuffle` stays at its default; DataLoader
# rejects a custom sampler combined with shuffle=True.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, sampler=valid_sampler)

In [None]:
# Shape of training batch.
# Use the built-in next(): DataLoader iterators no longer provide a `.next()` method
# in current PyTorch releases.
dataiter = iter(train_loader)
train_images, train_labels = next(dataiter)

print('Training data:')
print('X shape:', train_images.shape)
print('y shape:', train_labels.shape)

In [None]:
# Shape of validation batch.
# Use the built-in next(): DataLoader iterators no longer provide a `.next()` method
# in current PyTorch releases.
dataiter = iter(valid_loader)
valid_images, valid_labels = next(dataiter)

print('Validation data:')
print('X shape:', valid_images.shape)
print('y shape:', valid_labels.shape)

In [None]:
# Sample train images.
imgshow(torchvision.utils.make_grid(train_images, padding=1, pad_value=0.85))
# `.tolist()` turns the label tensor into plain ints, so each entry prints as e.g.
# "    3" rather than "tensor(3)". Iterating the batch itself (instead of
# range(batch_size)) also copes with a batch shorter than batch_size.
print(' '.join('%5s' % label for label in train_labels.tolist()))

In [None]:
# Sample valid images.
imgshow(torchvision.utils.make_grid(valid_images, padding=1, pad_value=0.85))
# `.tolist()` turns the label tensor into plain ints, so each entry prints as e.g.
# "    3" rather than "tensor(3)". Iterating the batch itself (instead of
# range(batch_size)) also copes with a batch shorter than batch_size.
print(' '.join('%5s' % label for label in valid_labels.tolist()))