In [2]:
from glob import glob
from pathlib import Path
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import DataLoader, Dataset

In [3]:
# glob returns files in filesystem-dependent order; sorting makes the file list
# (and any slice of it, e.g. the 128-image sample below) reproducible across machines.
file_locations = sorted(glob('./Captchas/*'))
# Label = file name without directory or extension. Path understands the platform's
# separators, unlike a hard-coded split('/').
captcha_names = [Path(file).stem for file in file_locations]
print(f'identified {len(file_locations)} images')

identified 113062 images


In [4]:
# sample_image = torchvision.io.read_image(file_locations[0])
# torchvision.transforms.ToPILImage()(sample_image)
# torchvision.io is not available in this installation (possibly an older torchvision release), so images are loaded with PIL instead.

In [5]:
class Captcha_Dataset(Dataset):
    """
    In-memory captcha dataset usable with torch's DataLoader.

    Attributes: X (np.ndarray) Stacked images, one per sample along axis 0
                y (List[str]) Labels parsed from the image file names, aligned with X
    """

    def __init__(self, data, labels):
        """
        Parameters: data (np.ndarray) Stacked images, one per sample along axis 0
                    labels (List[str]) One label per image, in the same order as data
        """
        self.X = data
        self.y = labels

    @classmethod
    def import_image(cls, location: str, dtype='float32') -> np.ndarray:
        """
        Import a single image.

        Parameters: location (str) Location of image
                    dtype (str or np.dtype) Element type of the returned array (default 'float32')
        Returns: (np.ndarray) Image dimensions = Captchas 40 x 150 x 3 RGB channels
        """
        # The context manager releases the file handle even if decoding fails,
        # which matters when importing ~100k images in one pass.
        with Image.open(location) as image:
            image.load()
            return np.asarray(image, dtype=dtype)

    @classmethod
    def stack_images(cls, file_locations: List[str], dtype='float32') -> np.ndarray:
        """
        Stack imageset from directory.

        Parameters: file_locations (List[str]) List of image locations
                    dtype (str or np.dtype) Element type of the stacked array (default 'float32')
        Returns: (np.ndarray) len(file_locations) x image dimensions;
                 a 0-length array of the requested dtype for an empty list
        Raises: ValueError if the images do not all share the same shape
        """
        images = [cls.import_image(location, dtype) for location in file_locations]
        if not images:
            # np.stack cannot stack nothing; keep the old "empty array" result.
            return np.empty((0,), dtype=dtype)
        # np.stack rejects mismatched shapes with a clear error on every NumPy version,
        # so a stray grayscale/RGBA file fails loudly here rather than later.
        return np.stack(images)

    @classmethod
    def read_label_names(cls, file_locations: List[str]) -> List[str]:
        """
        Extract labels from file names: the name without directory or extension.

        Parameters: file_locations (List[str]) List of image locations
        Returns: (List[str]) List of label names, aligned with file_locations
        """
        # Path handles the platform's separators (e.g. '\\' returned by glob on
        # Windows), which a hard-coded split('/') does not.
        return [Path(location).stem for location in file_locations]

    @classmethod
    def from_dir(cls, file_locations: List[str], dtype='float32'):
        """
        Instantiate from only a list of files.

        Parameters: file_locations (List[str]) List of image locations
                    dtype (str or np.dtype) Element type for the stored images (default 'float32').
                        'uint8' uses a quarter of the memory and makes transform()
                        yield tensors scaled to [0, 1].
        Returns: (Captcha_Dataset) object
        """
        return cls(
            cls.stack_images(file_locations, dtype),
            cls.read_label_names(file_locations),
        )

    def transform(self, image):
        """
        Dataset transform for loading: H x W x C array -> C x H x W float tensor.

        NOTE: ToTensor rescales to [0, 1] only for uint8 (or PIL) input. Float32 arrays
        (the default dtype here) pass through unscaled, so pixel values stay in [0, 255].
        Build the dataset with dtype='uint8' to get [0, 1] tensors.

        The earlier attempt, T.Compose([T.ToPILImage(), T.Resize([40, 150]), T.ToTensor()]),
        failed because ToPILImage only accepts 3-channel ndarrays of dtype uint8; float32
        raises a TypeError. The Resize was a no-op anyway, since the sampled images are
        already 40 x 150.
        """
        return T.ToTensor()(image)

    def __getitem__(self, index):
        """Select one sample (tensor, label). DataLoader accesses samples through this function."""
        return self.transform(self.X[index]), self.y[index]

    def __len__(self):
        """Number of samples; also needed for DataLoader."""
        return len(self.X)

In [6]:
SAMPLE_SIZE = 128  # small subset to validate the pipeline before loading every image
sample = Captcha_Dataset.from_dir(file_locations[:SAMPLE_SIZE])

# Report what was actually loaded instead of hard-coding the numbers in the message.
n_images, height, width, channels = sample.X.shape
print(
    f'{sample.X.shape}\n'
    f'Sample of {n_images} images in format {height}px x {width}px x {channels} RGB channels, '
    f'of type {sample.X.dtype.type}'
)

(128, 40, 150, 3)
Sample of 128 images in format 40px x 150px x 3 RGB channels, of type <class 'numpy.float32'>


In [7]:
BATCH_SIZE = 64  # samples fetched per batch
SEED = 0  # fixes the shuffle order so re-running the notebook draws the same batches

# NOTE(review): with num_workers > 0 under the 'spawn' start method (default on Windows
# and macOS), each worker must re-import the dataset class; a class defined inside a
# notebook generally cannot be found there. Drop to num_workers=0 on those platforms.
dl = DataLoader(
    sample,
    BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    generator=torch.Generator().manual_seed(SEED),
)

In [8]:
# Draw one shuffled batch to sanity-check tensor shapes and label count.
dataiter = iter(dl)
images, labels = next(dataiter)
assert len(images) == len(labels), 'image and label batch sizes disagree'

In [12]:
# Tensor.size() returns the same torch.Size as .shape: (batch, channels, height, width).
images.size()

torch.Size([64, 3, 40, 150])

In [14]:
# Number of labels in the batch; should equal the batch size passed to the DataLoader.
len(labels)

64