In [1]:
#Import libraries needed for data cleaning & dataset construction:
import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import warnings
from tqdm import tqdm
# Suppress every warning for the rest of the session (no category/module filter is given,
# so this applies to all warnings, including ones raised by later cells).
warnings.filterwarnings('ignore')
# Switch matplotlib to interactive mode.
plt.ion()

In [2]:
#Assign filepaths as string objects for convenience.
#The dataset root can be overridden with the CATS_AND_DOGS_DATA_DIR environment
#variable so the notebook also runs on machines other than the author's; when the
#variable is unset the original location is used.
data_root = os.environ.get('CATS_AND_DOGS_DATA_DIR',
                           '/Users/JackLi/Downloads/Cats_and_Dogs_Data/PetImages')
cat_dir = os.path.join(data_root, 'Cat')
dog_dir = os.path.join(data_root, 'Dog')

In [3]:
#Remove non-image files:
def _remove_unreadable_files(directory):
    """Delete every file in `directory` that io.imread() cannot read.

    WARNING: destructive - files are removed from disk permanently.

    Args:
        directory: path of the folder to clean.

    Returns:
        Number of files that were deleted.
    """
    removed = 0
    for file in tqdm(os.listdir(directory)):
        path = os.path.join(directory, file)
        # Only regular files are candidates; os.remove() would fail on a
        # sub-directory and abort the whole cleaning pass.
        if not os.path.isfile(path):
            continue
        try:
            io.imread(path)
        except Exception:
            # Any read failure is treated as "not an image" (deliberate best-effort).
            os.remove(path)
            removed += 1
    return removed


c = _remove_unreadable_files(cat_dir)
print('# of non-images removed from cats: ', c)

d = _remove_unreadable_files(dog_dir)
print('# of non-images removed from dogs: ', d)

100%|██████████| 12469/12469 [01:34<00:00, 132.13it/s]
  0%|          | 0/12464 [00:00<?, ?it/s]

# of non-images removed from cats:  0


100%|██████████| 12464/12464 [01:51<00:00, 111.97it/s]

# of non-images removed from dogs:  0





In [4]:
#Remove grayscale images (and anything else that is not a 3-channel H x W x 3 array):
def _remove_non_rgb_images(directory):
    """Delete every image in `directory` that does not load as an (H, W, 3) array.

    WARNING: destructive - files are removed from disk permanently.

    Args:
        directory: path of the folder to clean.

    Returns:
        Number of files that were deleted.
    """
    removed = 0
    for file in tqdm(os.listdir(directory)):
        path = os.path.join(directory, file)
        image = io.imread(path)
        # A grayscale image loads as a 2-D array, so shape[2] does not exist;
        # check the number of dimensions before looking at the channel count.
        if image.ndim != 3 or image.shape[2] != 3:
            os.remove(path)
            removed += 1
    return removed


g = _remove_non_rgb_images(cat_dir) + _remove_non_rgb_images(dog_dir)

print('# of grayscale images removed from both: ', g)

100%|██████████| 12469/12469 [01:27<00:00, 141.91it/s]
100%|██████████| 12464/12464 [01:00<00:00, 205.56it/s]

# of grayscale images removed from both:  0





In [5]:
#Create dataset class & overwrite __init__, __len__, and __getitem__ functions:
class CVD(Dataset):
    """Cats-vs-dogs dataset backed by one directory of cat files and one of dog files.

    Indices 0 .. len(cat_img) - 1 address the cat files; the remaining indices
    address the dog files. Each item is a dict
    {'image': <result of io.imread>, 'classification': 1 for cat / 0 for dog},
    passed through `transform` when one was given.
    """

    def __init__(self, cat_dir, dog_dir, transform = None):
        """
        Args:
            cat_dir: directory containing the cat images.
            dog_dir: directory containing the dog images.
            transform: optional callable applied to every sample dict.
        """
        self.cat_dir = cat_dir
        self.dog_dir = dog_dir
        self.transform = transform
        # os.listdir() returns entries in arbitrary order; sort them so that the
        # index -> file mapping is the same on every run.
        self.cat_img = sorted(os.listdir(cat_dir))
        self.dog_img = sorted(os.listdir(dog_dir))

    def __len__(self):
        """Return the total number of files (cats + dogs)."""
        return len(self.cat_img) + len(self.dog_img)

    def __getitem__(self, idx):
        """Load the sample at position `idx`.

        Args:
            idx: integer index (or a one-element tensor holding one). Negative
                values count from the end of the whole dataset.

        Returns:
            The sample dict, or None if the file could not be read or transformed
            (a warning is emitted in that case).

        Raises:
            IndexError: if `idx` is outside [-len(self), len(self) - 1]. This must
                propagate so that plain `for sample in dataset` iteration stops.
        """
        if torch.is_tensor(idx):
            # The tensor method is tolist(); to_list() does not exist.
            idx = idx.tolist()

        n_cats = len(self.cat_img)
        total = n_cats + len(self.dog_img)

        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError('index out of range for dataset of size {}'.format(total))

        if idx < n_cats:
            img_path = os.path.join(self.cat_dir, self.cat_img[idx])
            classification = 1
        else:
            img_path = os.path.join(self.dog_dir, self.dog_img[idx - n_cats])
            classification = 0

        try:
            image = io.imread(img_path)
            sample = {'image': image, 'classification': classification}

            if self.transform:
                sample = self.transform(sample)

            return sample
        except Exception as e:
            # Best-effort: an unreadable sample yields None rather than an exception,
            # but say so instead of failing silently. (If warnings are globally
            # filtered with 'ignore', this message is suppressed.)
            warnings.warn('could not load {}: {!r}'.format(img_path, e))
            return None

In [6]:
#Create Rescale class to rescale images
class Rescale(object):
    """Sample transform that resizes sample['image'] with transform.resize.

    Args:
        output_size: a tuple, unpacked as (new_h, new_w) and used as given, or
            an int, in which case the shorter image side becomes `output_size`
            and the longer side is scaled by the same factor.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with the resized image and the unchanged label."""
        image, classification = sample['image'], sample['classification']

        src_h, src_w = image.shape[:2]
        size = self.output_size

        if isinstance(size, tuple):
            target_h, target_w = size
        elif isinstance(size, int):
            if src_h > src_w:
                # Width is the shorter side: pin it and scale the height.
                target_h, target_w = size * src_h / src_w, size
            else:
                # Height is the shorter (or equal) side: pin it and scale the width.
                target_h, target_w = size, size * src_w / src_h

        # The scaled side may be fractional; truncate both to whole pixels.
        target_shape = (int(target_h), int(target_w))
        resized = transform.resize(image, target_shape)

        return {'image': resized, 'classification': classification}

In [7]:
#Create ToTensor class to convert samples into tensors
class ToTensor(object):
    """Sample transform that turns sample['image'] (a numpy array) into a torch tensor."""

    def __call__(self, sample):
        """Return a new sample dict holding the image as a tensor with axis 2 moved first."""
        image, classification = sample['image'], sample['classification']

        # Reorder axes (0, 1, 2) -> (2, 0, 1); presumably H x W x C -> C x H x W,
        # assuming the image is laid out height, width, channel.
        channels_first = image.transpose((2, 0, 1))

        return {
            'image': torch.from_numpy(channels_first),
            'classification': classification,
        }

In [8]:
#Build the dataset: every sample is rescaled to 128 x 128 and then converted to a tensor
transformed_dataset = CVD(
    cat_dir,
    dog_dir,
    transform = transforms.Compose([
        Rescale((128, 128)),
        ToTensor(),
    ]),
)

In [None]:
#Sanity check: index every sample once and tally how many come back truthy
count = dict.fromkeys(('True', 'False'), 0)

for position in tqdm(range(len(transformed_dataset))):
    # A falsy result (e.g. None) is counted as dysfunctional.
    outcome = 'True' if transformed_dataset[position] else 'False'
    count[outcome] += 1

print('functional: ', count['True'])
print('dysfunctional: ', count['False'])

 48%|████▊     | 11868/24933 [05:12<05:40, 38.34it/s]

In [None]:
#Wrap the dataset in a DataLoader that yields shuffled batches of 16 samples
loaded = DataLoader(
    transformed_dataset,
    shuffle = True,
    batch_size = 16,
)