In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn

import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torch.utils.data.dataset import Dataset, Subset
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import cv2
import copy


from torch._utils import _accumulate


In [None]:
class ObjectsDataset(Dataset):
    """Map-style dataset that reads pre-extracted object frames from disk.

    Each item is read with OpenCV, scaled from [0, 255] to [0, 1] as float32,
    and returned together with its integer class label.

    Args:
        dataset_dir: Dataset root directory (stored only; ``frames`` are read
            as given, so they must already be usable paths).
        frames: Sequence of image file paths, one per sample.
        labels_index: Sequence of integer class labels aligned with ``frames``.
        label_name: Class names (stored only).
        image_type: Which preprocessed variant the frames are (original, gray
            image or color-opponency). These images are expected to already
            exist on disk, produced by offline preprocessing.
        rgb: If True, read a colour image and return a (3, H, W) RGB tensor;
            otherwise read a single-channel grayscale image and return it as
            an (H, W) tensor (no channel axis is added).
        transform: Optional callable applied to the image tensor before it is
            returned.
    """

    def __init__(self, dataset_dir, frames, labels_index, label_name, image_type, rgb=True, transform=None):
        self.labels = labels_index
        self.dataset_dir = dataset_dir
        # This attribute used to hold the torchvision ``transforms`` module by
        # mistake; it now holds the optional per-sample callable (or None).
        self.transforms = transform
        self.image_type = image_type
        self.frames = frames
        self.label_name = label_name
        self.rgb = rgb

    def __len__(self):
        return len(self.frames)

    def _load_image(self, frame_path):
        """Read one frame and return it as a float32 tensor with values in [0, 1]."""
        if self.rgb:
            image = cv2.imread(frame_path)
        else:
            image = cv2.imread(frame_path, 0)  # flag 0 -> single-channel grayscale
        # cv2.imread reports a missing/unreadable file by returning None instead
        # of raising; fail here with a clear message rather than later.
        if image is None:
            raise FileNotFoundError('Could not read image file: {}'.format(frame_path))
        if self.rgb:
            # OpenCV decodes colour images in BGR channel order.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy((image / 255.0).astype('float32'))
        # HWC -> CHW for colour; grayscale stays (H, W).
        # NOTE(review): callers needing (1, H, W) for grayscale must add the axis.
        return tensor.permute(2, 0, 1) if self.rgb else tensor

    def __getitem__(self, index):
        image = self._load_image(self.frames[index])
        if self.transforms is not None:
            image = self.transforms(image)
        # torch.long is always int64, as required by e.g. CrossEntropyLoss
        # (numpy's 'long' is int32 on some platforms).
        label = torch.tensor(self.labels[index], dtype=torch.long)
        return image, label

In [None]:
import os
import re

def natural_key(string_):
    """Return a sort key that orders embedded integers numerically.

    The input is split on runs of digits. Each digit run is converted to
    ``int`` and the text between runs is kept as ``str``, so
    ``'img10.png'`` becomes ``['img', 10, '.png']`` and ``'f2'`` sorts
    before ``'f10'``.
    Background: http://www.codinghorror.com/blog/archives/001018.html
    """
    def _convert(chunk):
        # Digit runs compare as numbers; everything else compares as text.
        if chunk.isdigit():
            return int(chunk)
        return chunk

    pieces = re.split(r'(\d+)', string_)
    return list(map(_convert, pieces))

def get_frames(train_test_path, obj, image_type, skip, min_count):
    """Collect every ``skip``-th frame of one object class, in natural order.

    Args:
        train_test_path: Split directory (train or test) holding one
            sub-directory per object class. If the path contains
            ``'RODframes'`` (Alvaro's dataset), images are looked up under an
            extra ``light`` directory.
        obj: Class (object) name, i.e. the sub-directory to read.
        image_type: Name of the preprocessed image-variant sub-directory.
        skip: Keep one frame out of every ``skip`` (must be >= 1).
        min_count: Running minimum of frames per class seen so far.

    Returns:
        ``(frames, frame_indices, min_count)``:
        - the selected frame paths in natural order (``frame2`` before ``frame10``);
        - each selected frame's position in that naturally sorted directory listing;
        - ``min(min_count, len(frames))``.

    Raises:
        ValueError: if ``skip`` is smaller than 1.
        OSError: if the class directory cannot be listed.
    """
    if skip < 1:
        raise ValueError('skip must be a positive integer, got {}'.format(skip))

    if 'RODframes' in train_test_path:
        # Alvaro's dataset nests the images under an extra 'light' directory.
        object_path = os.path.join(train_test_path, obj, 'light', image_type, '')
    else:
        # Personal dataset.
        object_path = os.path.join(train_test_path, obj, image_type, '')

    # os.listdir returns entries in arbitrary order, so sort *before*
    # subsampling; otherwise which frames are kept (and their indices) would
    # depend on the filesystem instead of the frame sequence.
    all_frames = sorted(os.listdir(object_path), key=natural_key)
    labels_index = list(range(0, len(all_frames), skip))
    object_frames = [os.path.join(object_path, all_frames[i]) for i in labels_index]

    min_count = min(min_count, len(object_frames))
    print('Number of images for class ', obj, ': ', len(object_frames))

    return object_frames, labels_index, min_count

def _interleave_balanced(frames_per_class, max_per_class):
    """Interleave an equal number of frames from every class.

    Takes the first ``n = min(max_per_class, shortest class length)`` frames of
    each class and emits them round-robin (frame 0 of every class, then frame
    1 of every class, ...). Each frame is labelled with its class's position in
    ``frames_per_class``.

    Returns:
        ``(frames, labels)`` as two aligned lists.
    """
    per_class = min([max_per_class] + [len(class_frames) for class_frames in frames_per_class])
    frames, labels = [], []
    for i in range(per_class):
        for label, class_frames in enumerate(frames_per_class):
            frames.append(class_frames[i])
            labels.append(label)
    return frames, labels


def get_dataloader(image_type='original', dataset_type='Alvaro', get_test=False, skip=6, batch_size=64, shuffle=True, rgb=True, ignore=(), root_path=None):
    """Build class-balanced DataLoaders for the object-recognition datasets.

    Every loaded class contributes the same number of frames: its smallest
    per-class count, capped at 7200. Frames are interleaved class by class.
    Each frame is labelled with its class's index among the classes that
    loaded successfully.

    Args:
        image_type: Preprocessed image-variant sub-directory to read.
        dataset_type: ``'Alvaro'`` for the RODframes dataset, anything else for
            the personal dataset.
        get_test: If True, also load the ``test/`` split and divide it 70/30
            (seeded with 42) into ``'test'`` and ``'validation'``.
        skip: Keep one frame out of every ``skip`` (see ``get_frames``).
        batch_size: Batch size for every DataLoader.
        shuffle: Shuffle flag for every DataLoader.
        rgb: Load RGB (True) or grayscale (False) images.
        ignore: Class names to exclude.
        root_path: Dataset root directory. Defaults to the Google Drive
            location for ``dataset_type``. ``get_frames`` decides the
            directory layout from whether the path contains ``'RODframes'``.

    Returns:
        dict mapping split name to DataLoader. It always contains ``'train'``;
        ``'test'`` and ``'validation'`` are present only when ``get_test`` is True.

    Raises:
        ValueError: if ``ignore`` names a class that does not exist.
        RuntimeError: if no class could be loaded.
    """
    max_per_class = 7200  # upper bound on frames kept per class

    if dataset_type == 'Alvaro':
        default_root = '/content/drive/My Drive/RODframes2/RODframes/'
        classes = ['bag', 'beer', 'book', 'case', 'coffee', 'cup', 'deodorant', 'eraser',
                   'hole', 'mouse', 'mug', 'sleep', 'speaker', 'spray', 'stapler', 'tape',
                   'tea', 'tissues', 'umbrella', 'watch']
    else:
        default_root = '/content/drive/My Drive/Personal Dataset/Frames/'
        classes = ['Background', 'Charger', 'Coin', 'Earphones', 'Gargoyle', 'Glasses', 'Jellyfish',
                   'Key', 'Laptop', 'Pens', 'Remote', 'Wallet']
    if root_path is None:
        root_path = default_root

    unknown = [name for name in ignore if name not in classes]
    if unknown:
        raise ValueError('Unknown class name(s) in ignore: {}'.format(unknown))
    classes = [name for name in classes if name not in ignore]

    train_path = os.path.join(root_path, 'train/')
    test_path = os.path.join(root_path, 'test/')

    loaded_classes = []
    train_frames_per_class = []
    test_frames_per_class = []
    errors = []

    for index, obj in enumerate(classes):
        print('Iterating over {}'.format(obj))
        try:
            train_frames, _, _ = get_frames(train_path, obj, image_type, skip, max_per_class)
            test_frames = None
            if get_test:
                test_frames, _, _ = get_frames(test_path, obj, image_type, skip, max_per_class)
        except OSError as err:
            print("OS error for object {}: {}".format(obj, err))
            errors.append((index, obj))
            continue
        # Record a class only once every requested split loaded, so the
        # per-class lists stay aligned with loaded_classes (and thus labels).
        loaded_classes.append(obj)
        train_frames_per_class.append(train_frames)
        if get_test:
            test_frames_per_class.append(test_frames)

    print(errors)
    print(len(loaded_classes))
    if not loaded_classes:
        raise RuntimeError('No class could be loaded from {}'.format(root_path))

    train_frames, train_labels = _interleave_balanced(train_frames_per_class, max_per_class)
    print(len(train_frames))

    dataset_dict = {
        'train': ObjectsDataset(root_path, train_frames, train_labels, loaded_classes, image_type, rgb),
    }

    if get_test:
        test_frames, test_labels = _interleave_balanced(test_frames_per_class, max_per_class)
        full_test = ObjectsDataset(root_path, test_frames, test_labels, loaded_classes, image_type, rgb)
        test_size = int(0.7 * len(full_test))
        validation_size = len(full_test) - test_size
        dataset_dict['test'], dataset_dict['validation'] = torch.utils.data.random_split(
            full_test, [test_size, validation_size],
            generator=torch.Generator().manual_seed(42))

    # NOTE(review): shuffle and drop_last=True also apply to the evaluation
    # splits, so up to batch_size - 1 test/validation samples are skipped.
    dataloader = {
        split: torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=shuffle,
                                           num_workers=0, drop_last=True)
        for split, ds in dataset_dict.items()
    }

    dataset_sizes = {split: len(ds) for split, ds in dataset_dict.items()}
    print('Dataset size is ', dataset_sizes)

    return dataloader
    