In [5]:
import torch
from torch import nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import models, transforms

import optuna
import matplotlib.pyplot as plt
import numpy as np
import scipy.io
import os
import random
import cv2
import copy
import time
import math
from tqdm import tqdm
from functools import partial


# Enable matplotlib's interactive mode for the rest of the session.
plt.ion()

<matplotlib.pyplot._IonContext at 0x7f445af78640>

In [4]:
def seed_all(seed):
    """Seed every random number generator this notebook relies on.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's legacy global generator and PyTorch's global generator.
    """
    seeders = (random.seed, np.random.seed, torch.random.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)
    
    
def data_map(path):
    """Build a ``{class_id: array_of_file_paths}`` mapping from a directory tree.

    Every directory under ``path`` that directly contains files is treated as
    one class. The class id is the integer before the first ``_`` in the
    directory name (e.g. ``8_Egyptian_Mau`` -> ``8``).

    Directories, files and the returned mapping are all sorted. ``os.walk``
    yields entries in arbitrary, filesystem-dependent order, and the rest of
    the notebook (un-shuffled k-fold splits, index -> sample lookup, seeded
    shuffles) is only reproducible if this ordering is stable.

    Args:
        path: Root directory to scan.

    Returns:
        dict mapping ``int`` class id to a 1-D ``np.ndarray`` of file paths,
        ordered by ascending class id. If two directories share the same
        numeric prefix, the one visited later replaces the earlier one.

    Raises:
        ValueError: If a directory that contains files has a name whose
            prefix is not an integer.
    """
    dmap = {}
    for root, dirs, files in os.walk(path):
        # Sorting in place makes the (top-down) traversal order deterministic.
        dirs.sort()
        if files:
            k = os.path.split(root)[-1]
            k = int(k.split('_')[0])
            dmap[k] = np.array([os.path.join(root, f) for f in sorted(files)])

    # Return classes in ascending numeric order rather than visit order.
    return dict(sorted(dmap.items()))


def get_mean_std(dmap):
    """Compute per-channel (R, G, B) mean and standard deviation of a dataset.

    Every image referenced by ``dmap`` is read with OpenCV, converted from
    BGR to RGB, resized to 64x64 and stacked into one ``(N, 3, 64, 64)``
    float tensor from which the statistics are taken.

    Args:
        dmap: Mapping ``class -> iterable of image paths``.

    Returns:
        Tuple ``(mean, std)`` of two tensors of shape ``(3,)``.

    Raises:
        ValueError: If ``dmap`` contains no samples.
        OSError: If ``cv2.imread`` cannot read one of the files.

    NOTE(review): no rescaling happens here, so the statistics are on whatever
    pixel scale ``cv2.imread`` returns. Confirm the dataset transforms produce
    the same scale before normalizing with these values.
    """
    resize = transforms.Resize((64, 64))
    n_samples = sum(len(v) for v in dmap.values())
    if n_samples == 0:
        raise ValueError("get_mean_std: dmap contains no samples")

    images = torch.zeros((n_samples, 3, 64, 64), dtype=torch.float32)

    idx = 0
    for samples in dmap.values():
        for path in samples:
            img = cv2.imread(path)
            if img is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise OSError(f"cv2.imread could not read image: {path}")
            # BGR -> RGB; copy() removes the negative stride torch cannot wrap.
            img = torch.tensor(img[:, :, ::-1].copy())
            img = torch.permute(img, (2, 0, 1))  # HWC -> CHW
            images[idx] = resize(img)
            idx += 1

    # Bring the channel axis to the front BEFORE flattening. Reshaping the
    # (N, 3, H, W) tensor directly to (3, -1) would cut the flat buffer into
    # three chunks that mix channels instead of grouping by channel.
    per_channel = images.permute(1, 0, 2, 3).reshape(3, -1)
    return per_channel.mean(dim=-1), per_channel.std(dim=-1)


def kfold_splitter(dmap, *, k, shuffle=False):
    """Yield ``k`` stratified train/test splits of a class -> samples mapping.

    Each class is split independently into ``k`` contiguous folds whose sizes
    differ by at most one (the first ``len(samples) % k`` folds get the extra
    sample). Fold ``i`` is the test portion of split ``i`` and the remaining
    samples form the training portion. Rounding the fold size up instead
    could leave the last folds empty (e.g. 5 samples with k=4).

    Args:
        dmap: Mapping ``class -> array-like of samples``. It is not modified.
        k: Number of folds, must be >= 1. A class with fewer than ``k``
            samples necessarily has an empty test part in some splits.
        shuffle: If true, each class is shuffled (on a copy) with NumPy's
            global RNG before splitting.

    Yields:
        ``(train_folds, test_fold)``: two dicts with the same keys as
        ``dmap``, holding freshly allocated arrays.

    Raises:
        ValueError: If ``k`` is smaller than 1 (raised on first iteration,
            since this is a generator).
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    prepared = {}
    for cls, samples in dmap.items():
        temp = np.array(samples)  # always a copy, so the caller's data is untouched
        if shuffle:
            np.random.shuffle(temp)
        prepared[cls] = temp

    for fold in range(k):
        train_folds = {}
        test_fold = {}

        for cls, samples in prepared.items():
            base, extra = divmod(len(samples), k)
            # Folds [0, extra) hold base + 1 samples, the rest hold base.
            start = fold * base + min(fold, extra)
            stop = start + base + (1 if fold < extra else 0)

            test_fold[cls] = samples[start:stop].copy()
            train_folds[cls] = np.concatenate([samples[:start], samples[stop:]])

        yield train_folds, test_fold

In [195]:
# Build the class-id -> file-paths mapping for the images under data/cifar-100_5/.
a = data_map('data/cifar-100_5/')

In [196]:
# Create the 5-fold split generator over `a`; nothing is shuffled or split
# until `kf` is iterated. Shuffling draws from NumPy's global RNG, so the
# global seed must be set beforehand for reproducible folds.
kf = kfold_splitter(a, k=5, shuffle=True)

In [3]:
class ImageDataSet(Dataset):
    """Map-style dataset over a ``{class: array_of_image_paths}`` mapping.

    A flat index enumerates the samples class by class, in the iteration
    order of the mapping. Images are read with OpenCV (converted BGR -> RGB),
    passed through ``data_transforms``, normalized with ``mean``/``std`` and
    moved to ``device`` together with the label.

    Args:
        data_map: Mapping ``class -> array-like of image paths``.
        mean: Normalization mean (scalar or tensor).
        std: Normalization standard deviation (scalar or tensor).
        data_transforms: Callable applied to the raw RGB ``np.ndarray``.
            Presumably it ends with a to-tensor step producing a ``(C, H, W)``
            tensor; verify against the caller.
        preload: If true, read every image into RAM at construction time.
        device: Device for the returned tensors. ``None`` (default) selects
            ``'cuda'`` when available, otherwise ``'cpu'``.
            NOTE(review): PyTorch advises against returning CUDA tensors from
            a dataset used with ``DataLoader(num_workers > 0)``; pass
            ``device='cpu'`` in that setup and move batches in the train loop.
    """

    def __init__(self, data_map, mean, std, data_transforms, preload=False, device=None):
        self.data = data_map
        self.transform = data_transforms
        self._len = sum(len(v) for v in self.data.values())
        self.classes = sorted(self.data.keys())
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.cache = None  # idx -> (raw RGB image, label) once load_all() has run
        self.mean = mean
        self.std = std

        if preload:
            self.load_all()

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``; negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        idx = self._check_index(idx)
        if self.cache is None:
            image, label = self.load_item_from_disk(idx)
        else:
            image, label = self.load_item_from_ram(idx)

        # Augment, to-tensor
        image = self.transform(image)
        image = self.normalize(image)
        label = torch.tensor(label, device=self.device)

        image = image.to(self.device)
        return image, label

    def _check_index(self, idx):
        """Resolve a possibly negative index to ``[0, len)`` or raise IndexError."""
        position = idx + self._len if idx < 0 else idx
        if not 0 <= position < self._len:
            raise IndexError(f"Index {idx:,} is out of range for dataset with length {self._len:,}")
        return position

    def load_all(self):
        """Read every sample from disk into ``self.cache`` (raw, untransformed)."""
        self.cache = {}
        for idx in tqdm(range(self._len), desc='Loading data', ncols=100):
            item = self.load_item_from_disk(idx)
            self.cache[idx] = item

    def load_item_from_disk(self, idx):
        """Read sample ``idx`` from disk and return ``(rgb_image, class_key)``.

        Raises:
            IndexError: If ``idx`` is out of range.
            OSError: If ``cv2.imread`` cannot read the file.
        """
        offset = self._check_index(idx)
        # Walk the classes, subtracting each class size until the offset
        # falls inside the current class.
        for cls, samples in self.data.items():
            if offset < len(samples):
                path, label = samples[offset], cls
                break
            offset -= len(samples)
        else:
            # Only reachable if self.data shrank after construction.
            raise IndexError(f"Index {idx:,} is out of range for dataset with length {self._len:,}")

        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise OSError(f"cv2.imread could not read image: {path}")
        # BGR -> RGB; copy() yields a contiguous array with positive strides.
        return image[:, :, ::-1].copy(), label

    def load_item_from_ram(self, idx):
        """Return a copy of the cached image and its label.

        The copy keeps the cached image itself from being modified by callers.
        """
        temp, label = self.cache[idx]
        image = temp.copy()
        return image, label

    def normalize(self, image):
        """Standardize ``image`` as ``(image - mean) / (std + 1e-7)``.

        A 1-D ``(C,)`` mean/std is applied per channel to a ``(C, H, W)``
        tensor. This only happens where plain broadcasting would fail, so
        inputs that already broadcast behave exactly as before.
        """
        mean, std = self.mean, self.std
        if isinstance(image, torch.Tensor) and image.ndim == 3:
            mean = self._per_channel(mean, image)
            std = self._per_channel(std, image)
        return (image - mean) / (std + 1e-7)

    @staticmethod
    def _per_channel(stat, image):
        """Reshape a ``(C,)`` statistic to ``(C, 1, 1)`` when it matches the channel axis."""
        if (
            isinstance(stat, torch.Tensor)
            and stat.ndim == 1
            and stat.shape[0] != 1
            and stat.shape[0] == image.shape[0]
            and image.shape[-1] not in (1, stat.shape[0])
        ):
            return stat.reshape(-1, 1, 1)
        return stat

In [None]:
def optimize(trial, dmap, transforms, scorer):
    # Draft of an Optuna objective: runs 3-fold cross-validation over `dmap`
    # and returns the mean of the per-fold scores.
    # NOTE: this cell is unfinished -- the `?` placeholders and `model = ...`
    # below still have to be replaced (presumably with values suggested by
    # `trial`) before it is valid, runnable Python.
    # NOTE: the `transforms` parameter shadows the module-level `transforms`
    # name inside this function only; it is indexed with 'train' and 'test'.
    
    # optuna suggestions..
    
    scores = []
    
    for idx, (train, test) in enumerate(kfold_splitter(dmap, k=3)):
        
        # Normalization statistics come from the training fold only and are
        # reused for the test fold.
        mean, std = get_mean_std(train)
        
        # preload=True makes each dataset read all of its images in __init__.
        train_dset = ImageDataSet(train, mean, std, data_transforms=transforms['train'], preload=True)
        test_dset = ImageDataSet(test, mean, std, data_transforms=transforms['test'], preload=True)
        
        # NOTE(review): ImageDataSet.__getitem__ moves tensors to its device
        # (CUDA when available); confirm that works with num_workers=8.
        train_loader = DataLoader(train_dset, batch_size=?, shuffle=True, num_workers=8)
        test_loader = DataLoader(test_dset, batch_size=?, num_workers=8)
        
        model = ...  # get model
        
        fold_score = fit_and_evaluate(model, train_loader, test_loader, scorer, epochs=?)       
        scores.append(fold_score)
    
    # Average of the fold scores is the value reported for this trial.
    return np.mean(scores)


def fit_and_evaluate(model, train, test, scorer, epochs):
    """Placeholder for the train-then-score routine; not implemented yet.

    Currently does nothing and returns ``None``.

    TODO: implement. ``optimize`` appends this function's return value to a
    list and averages it with ``np.mean``, so presumably it should return a
    numeric score per fold.
    """
    pass

In [10]:
# Scratch arithmetic (evaluates to 60000). The meaning of the individual
# factors is not recorded -- TODO: name them as constants or remove this cell.
20 * 10 * 3 * 50 * 2

60000

In [7]:
# Collect (and print) every file under 'data' that OpenCV fails to read.
bad_data = []

for folder, _subdirs, filenames in os.walk('data'):
    for filename in filenames:
        candidate = os.path.join(folder, filename)
        # cv2.imread returns None instead of raising when it cannot read a file.
        if cv2.imread(candidate) is not None:
            continue
        print(candidate)
        bad_data.append(candidate)

data/vgg-cats/0_Abyssinian/34.jpg
data/vgg-cats/8_Egyptian_Mau/145.jpg
data/vgg-cats/8_Egyptian_Mau/191.jpg
data/vgg-cats/8_Egyptian_Mau/139.jpg
data/vgg-cats/8_Egyptian_Mau/177.jpg
data/vgg-cats/8_Egyptian_Mau/167.jpg


