In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

In [None]:
#Standard imports
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

import cv2
import os

In [None]:
#Dataset object used to load the images
class ObjectsDataset(Dataset):
  """Map-style dataset that reads image frames from disk and pairs them with class labels.

  Parameters
  ----------
  frames: sequence of str
      Paths of the image files, one per sample.
  labels_index: sequence of int
      Class index for each entry of `frames` (same order and length).
  rgb: bool
      If True, each image is loaded in colour, converted from OpenCV's BGR channel
      order to RGB and returned as a float32 tensor of shape (3, H, W).
      If False, each image is loaded as grayscale and returned as a float32 tensor
      of shape (H, W). Note that there is no channel dimension in this case.

  Pixel values are scaled from [0, 255] to [0.0, 1.0]. Each label is returned as a
  0-dim int64 tensor.
  """
  def __init__(self, frames, labels_index, rgb=True):
    # Store the frame paths and the corresponding labels
    self.labels = labels_index
    self.frames = frames
    self.rgb = rgb

  def __len__(self):
    return len(self.frames)

  def __getitem__(self, index):
    """Return ``(image, label)`` for sample ``index``.

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read the frame (missing file or unsupported format).
    """
    frame_path = self.frames[index]
    label_index = self.labels[index]

    if self.rgb:
      image = cv2.imread(frame_path)
    else:
      image = cv2.imread(frame_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread reports failure by returning None rather than raising. Fail here
    # with the offending path instead of with an obscure error further down.
    if image is None:
      raise FileNotFoundError(f"Could not read image: {frame_path}")

    if self.rgb:
      # OpenCV loads colour images in BGR order
      image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    image = image / 255.0
    image = torch.from_numpy(image.astype('float32'))
    if self.rgb:
      # HWC -> CHW, the layout torch conv layers expect
      image = image.permute(2, 0, 1)

    # Build the label as int64 explicitly. NumPy's 'long' is int32 on some
    # platforms (e.g. Windows with NumPy < 2), which breaks losses expecting int64.
    label = torch.tensor(label_index, dtype=torch.long)
    return (image, label)

In [None]:
def _list_frames(directory, skip):
    """Return the full paths of every `skip`-th file in `directory`, in sorted filename order."""
    # os.listdir returns entries in arbitrary, filesystem-dependent order; sort first
    # so that the [::skip] subsample selects the same frames on every machine/run.
    names = sorted(os.listdir(directory))[::skip]
    return [os.path.join(directory, name) for name in names]


def get_dataloader(image_type='cortical', personal=True, skip=8, batch_size=8, shuffle=True, rgb=True):
    """
    Build train and test DataLoaders over one of the two frame datasets on Google Drive.

    Parameters
    ----------
    image_type: str
        Name of the per-class sub-folder holding the frames (the last directory
        before the image files). Folder names listed for this project are
        'original', 'original_corticalimages', 'original_fixationcrop' and
        'original_retinalimages', yet the default is 'cortical'.
        NOTE(review): confirm which folder names exist for each dataset.
    personal: bool
        If truthy, use my personal dataset (11 classes, including 'Background').
        Otherwise use Alvaro's dataset (20 classes), whose frames live under an
        extra 'light' folder inside each class folder.
    skip: int
        Subsampling stride, must be >= 1. Within each class folder the frames are
        sorted by filename and every `skip`-th one is kept (indices 0, skip,
        2*skip, ...). For example, skip=8 keeps one frame out of every 8.
    batch_size: int
        The batch size of both dataloaders.
    shuffle: bool
        Whether both dataloaders shuffle their samples.
    rgb: bool
        Whether to load RGB images (True) or grayscale images (False).

    Returns
    -------
    dict
        ``{'train': DataLoader, 'test': DataLoader}``. Labels are the index of the
        class in the dataset's class list.

    Raises
    ------
    ValueError
        If `skip` is smaller than 1.
    FileNotFoundError
        If a class/image_type folder does not exist (raised by os.listdir).
    """
    if skip < 1:
        raise ValueError(f"skip must be a positive integer, got {skip!r}")

    if personal:
        # Path to my dataset
        train_root = '/content/drive/My Drive/Personal Dataset/Frames/train/'
        test_root = '/content/drive/My Drive/Personal Dataset/Frames/test/'
        classes = ['Background', 'Charger', 'Coin', 'Gargoyle', 'Glasses',
                   'Jellyfish', 'Key', 'Laptop', 'Pens', 'Remote', 'Wallet']
    else:
        # Path to Alvaro's dataset
        train_root = '/content/drive/My Drive/RODframes2/RODframes/train/'
        test_root = '/content/drive/My Drive/RODframes2/RODframes/test/'
        classes = ['bag','beer','book','case','coffee','cup','deodorant','eraser',
                   'hole','mouse','mug','sleep','speaker','spray','stapler','tape',
                   'tea','tissues','umbrella','watch']

    split_roots = {'train': train_root, 'test': test_root}
    frames = {split: [] for split in split_roots}
    labels = {split: [] for split in split_roots}

    # The label of each frame is the index of its class in `classes`
    for label, obj in enumerate(classes):
        for split, root in split_roots.items():
            if personal:
                obj_dir = os.path.join(root, obj, image_type)
            else:
                # Alvaro's dataset has an extra 'light' level per class
                obj_dir = os.path.join(root, obj, 'light', image_type)
            obj_frame_paths = _list_frames(obj_dir, skip)
            frames[split] += obj_frame_paths
            labels[split] += [label] * len(obj_frame_paths)

    loaders = {}
    for split in ('train', 'test'):
        dataset = ObjectsDataset(frames[split], labels[split], rgb)
        # NOTE(review): drop_last=True also applies to the test split, so up to
        # batch_size - 1 test frames are never evaluated. Kept for compatibility.
        loaders[split] = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                    num_workers=0, drop_last=True)
    return loaders
    
