In [44]:
import pickle, glob
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Compose, Normalize, PILToTensor

from rgb_stacking.utils.get_data import KEYS

In [97]:
class CustomDataset(Dataset):
    """Map-style dataset over ``(fl, fr, bl, pose)`` examples.

    Args:
        examples: indexable sequence whose items unpack as ``(fl, fr, bl, pose)``.
            The three images are passed through ``img_transform`` and combined
            with ``torch.stack`` (so they must share a shape after the
            transform); ``pose`` must be a numpy array, since it is converted
            with ``torch.from_numpy``.
        img_transform: callable applied to each of the three images. ``None``
            (the default) means the images are used unchanged.
        target_transform: stored as ``self.target_transform`` but NOT applied
            by ``__getitem__``. NOTE(review): kept only for interface
            compatibility; labels are always ``torch.from_numpy(pose)``.
    """

    def __init__(self, examples, img_transform=None, target_transform=None):
        self.examples = examples
        self.transform = img_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return ``(images, label)`` for example ``idx``.

        ``images`` stacks the transformed fl, fr and bl views along a new
        leading dimension of size 3; ``label`` shares memory with ``pose``.
        """
        fl, fr, bl = self.examples[idx][:3]
        pose = self.examples[idx][3]
        # The constructor defaults img_transform to None; treat that as identity
        # instead of failing with "'NoneType' object is not callable".
        transform = self.transform if self.transform is not None else (lambda img: img)
        images = torch.stack([transform(img) for img in (fl, fr, bl)])
        label = torch.from_numpy(pose)
        return images, label

In [69]:
def load_data(data_dir='../data'):
    """Load every ``[fl, fr, bl, pose]`` example found under ``data_dir``.

    Each ``*.csv`` file directly in ``data_dir`` is read with pandas. The text
    after the last ``_`` in the file's stem is used as its rank, and for every
    value in the CSV's ``id`` column the three images
    ``images/IMG_{pov}_{id}_{rank}.png`` (``pov`` in ``fl``, ``fr``, ``bl``)
    are loaded.

    Args:
        data_dir: directory holding the CSV files and the ``images/``
            sub-directory. Defaults to ``'../data'``, the previously
            hard-coded location.

    Returns:
        list of ``[fl, fr, bl, pose]`` lists. Each image is a float64
        ``torch.Tensor`` built from ``np.array(PIL image)``; ``pose`` is a
        float64 numpy array of the CSV columns named in ``KEYS`` (same order).
        Returns ``[]`` when no CSV files are found.
    """
    data_dir = Path(data_dir)
    examples = []
    # sorted(): glob order is filesystem-dependent, and the downstream
    # train/valid/test split is positional, so an unsorted listing makes the
    # split differ between machines.
    for csv_path in sorted(data_dir.glob('*.csv')):
        df = pd.read_csv(csv_path)
        rank = csv_path.stem.split('_')[-1]
        # list(KEYS) makes pandas select several columns even if KEYS is a tuple.
        poses = df[list(KEYS)].to_numpy(dtype=float)
        for example_id, pose in zip(df['id'], poses):
            views = []
            for pov in ('fl', 'fr', 'bl'):
                img_path = data_dir / 'images' / f'IMG_{pov}_{example_id}_{rank}.png'
                # Context manager closes the file handle deterministically.
                with Image.open(img_path) as img:
                    views.append(torch.from_numpy(np.array(img).astype(float)))
            # Copy the row so each pose is an independent, contiguous array.
            examples.append([*views, np.array(pose, dtype=float)])
    return examples


In [114]:
def view(batch, label):
    """Display the three camera views of one example and print its pose.

    Args:
        batch: exactly three image tensors ``(fl, fr, bl)``. Each is moved to
            CPU, clamped to [0, 255], cast to uint8 and passed to
            ``PIL.Image.fromarray`` (so presumably shaped ``(H, W)`` or
            ``(H, W, C)`` — TODO confirm against the loader's output).
        label: indexable pose values, one per entry of ``KEYS`` (same order).

    Returns:
        ``(fl, fr, bl)`` as ``PIL.Image.Image`` objects.
    """
    fl, fr, bl = batch  # unpacking enforces exactly three views

    def _to_pil(tensor):
        # Casting floats outside [0, 255] to uint8 does not saturate (it is
        # undefined / wraps), which scrambles e.g. normalized images; clamp first.
        pixels = tensor.detach().cpu().clamp(0, 255).to(torch.uint8).numpy()
        return Image.fromarray(pixels)

    fl, fr, bl = _to_pil(fl), _to_pil(fr), _to_pil(bl)

    fl.show('fl')
    fr.show('fr')
    bl.show('bl')

    # float() prints plain numbers instead of "tensor(..., dtype=...)" reprs.
    print(','.join(f'{k}={float(label[i])}' for i, k in enumerate(KEYS)))
    return fl, fr, bl

In [71]:
# Load every example from disk (decodes three PNGs per CSV row, so this is slow).
examples = load_data()
# Fail fast with a clear message rather than a StopIteration from
# next(iter(train_dataloader)) further down.
assert len(examples) > 0, "load_data() found no examples; check the data directory"

In [72]:
# Split points as fractions of the dataset: first 70% train, next 20% valid,
# last 10% test.
TRAIN_FRAC = 0.7
VALID_END_FRAC = 0.9  # cumulative: the validation slice ends at 90% of the data
sz = len(examples)
# NOTE: valid_sz is the END index of the validation slice, not its size.
# NOTE(review): examples keep load_data()'s order (CSV rows, file by file, no
# shuffle), so each split is a contiguous run of files — confirm this is intended.
train_sz, valid_sz = int(TRAIN_FRAC * sz), int(VALID_END_FRAC * sz)

In [73]:
# Positional split into three contiguous, non-overlapping slices.
train_dt = examples[:train_sz]
valid_dt = examples[train_sz:valid_sz]
test_dt = examples[valid_sz:]

In [74]:
# NOTE(review): 0.1307 / 0.3081 are the MNIST mean/std for pixels in [0, 1].
# The images built by load_data() come from np.array(PIL image), presumably raw
# 0-255 values, so these statistics almost certainly don't fit this data —
# recompute them from train_dt (or rescale to [0, 1] first). TODO confirm intent.
img_transform = Normalize(0.1307, 0.3081)
# ToTensor only accepts 2-D/3-D images and would raise on the 1-D pose vectors;
# CustomDataset builds labels with torch.from_numpy and never calls this anyway.
target_transform = None

In [100]:
# One CustomDataset per split; all three share the same transforms.
train_ds, valid_ds, test_ds = (
    CustomDataset(split, img_transform, target_transform)
    for split in (train_dt, valid_dt, test_dt)
)

In [101]:
BATCH_SIZE = 64
SEED = 0  # fixes the training shuffle order so re-runs see the same batches

train_dataloader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
# Evaluation loaders iterate in a fixed order. The test loader previously
# shuffled, which made evaluation order-dependent and was inconsistent with
# the validation loader.
valid_dataloader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [102]:
# Pull a single batch from the training loader for a quick visual sanity check.
first_batch = next(iter(train_dataloader))
train_features, train_labels = first_batch

In [118]:
# Show the three camera views and print the pose of the batch's first example.
sample_views = train_features[0]
fl, fr, bl = view(tuple(sample_views[i] for i in range(3)), train_labels[0])

rX=0.72289574,rY=-0.104483284,rZ=0.07934863,rQ1=0.11600672,rQ2=0.77320856,rQ3=-0.5900871,rQ4=-0.20121671,bX=0.70034885,bY=-0.013099843,bZ=0.07502708,bQ1=0.72762287,bQ2=-0.5715084,bQ3=0.34906504,bQ4=-0.14864959,gX=0.5397142,gY=-0.03359359,gZ=0.07382866,gQ1=1.5475927e-05,gQ2=0.99348843,gQ3=0.113933064,gQ4=-7.63775e-05
