In [94]:
import re
import torch
from torch.utils.data import Dataset

import torchvision
from torchvision import datasets, models, transforms

### Attributes description

- a1 - face containing flag: (1-with face, 0-without face),
- a2 - image number in current class (person) beginning from 0,
- a3 - class (person) number beginning from 0,
- a4 - sex (0 - woman, 1 - man)
- a5 - race (0- white, 1 - negro, 2 - indian, ...)
- a6 - age (0 - baby, 1 - young, 2 - middle-age, 3 - old) 
- a7 - binokulars (0 - without, 1 - transparent, 2 - dark)
- a8 - emotional expression (not state!) (0 - sad, 1 - neutral, 2 - happy)

In [95]:
IMAGE_SIZE = 24

dataset_dict = {
  "pixel_values": list(),
  "face": list(),
  "image_number": list(),
  "class": list(),
  "sex": list(),
  "race": list(),
  "age": list(),
  "binoculars": list(),
  "emotion": list(),
}

In [96]:
def append_attribute(attributes, index, attr_name):
  dataset_dict[attr_name].append(int(attributes[index]))

def read_file(filename):
  with open(filename, 'r') as file:
    lines = file.readlines()
    
    # Remove newLines
    for i, line in enumerate(lines):
      lines[i] = line.replace('\n', '')
    
    # We assume these are integers
    EXAMPLES_NR = int(lines[0])
    PIXELS_NR = int(lines[1])
    
    examples_raw = lines[2:EXAMPLES_NR]
    
    for i, example_raw in enumerate(examples_raw):
      # Split by spaces (treats multiple as one)
      tokens = re.split('\s+', example_raw)
      
      pixel_values = tokens[0:PIXELS_NR]
      attributes = tokens[PIXELS_NR:]       
        
      for j, el in enumerate(pixel_values):
        pixel_values[j] = float(el)
        
      pixel_values = torch.Tensor(pixel_values).reshape([IMAGE_SIZE, IMAGE_SIZE])
      
      # Add everything to the dataset dictionary
      dataset_dict["pixel_values"].append(pixel_values)

      
      append_attribute(attributes, 0, "face")
      append_attribute(attributes, 1, "image_number")
      append_attribute(attributes, 2, "class")
      append_attribute(attributes, 3, "sex")
      append_attribute(attributes, 4, "race")
      append_attribute(attributes, 5, "age")
      append_attribute(attributes, 6, "binoculars")
      append_attribute(attributes, 7, "emotion")
    

### Building Dataset

In [97]:
class DataFromDict(Dataset):
  def __init__(self, input_dict):
    self.input_dict = input_dict
    self.input_keys = list(input_dict.keys())
    print(self.input_keys)

  def __len__(self):
    return len(self.input_keys)

  def __getitem__(self, idx):
    # print(self.input_dict['age'][idx])
    pixel_values = self.input_dict['pixel_values'][idx]
    face = self.input_dict['face'][idx]
    image_number = self.input_dict['image_number'][idx]
    sex = self.input_dict['sex'][idx]
    race = self.input_dict['race'][idx]
    age = self.input_dict['age'][idx]
    binoculars = self.input_dict['binoculars'][idx]
    emotion = self.input_dict['emotion'][idx]
    label = self.input_dict['class'][idx]
    
    # return pixel_values
    return pixel_values, face, image_number, sex, race, age, binoculars, emotion, label

In [98]:
read_file('./data/x24x24.txt')

In [99]:
dataset = DataFromDict(dataset_dict)

['pixel_values', 'face', 'image_number', 'class', 'sex', 'race', 'age', 'binoculars', 'emotion']


In [100]:
# next(iter(dataset))

### Define dataloaders

In [102]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

In [103]:
for i, el in enumerate(train_dataloader):
  print(i)
  
  if i < 5:
    print(el)

0
[tensor([[[-0.0721,  0.6407,  0.5921,  ...,  0.4187,  0.4654,  0.5425],
         [ 0.2845,  0.4655,  0.5008,  ...,  0.4860,  0.4353,  0.4294],
         [ 0.2460,  0.4757,  0.6555,  ...,  0.1881,  0.4122,  0.3811],
         ...,
         [ 0.3619,  0.5048,  0.0828,  ...,  0.2913,  0.2736,  0.1899],
         [ 0.5527,  0.7342,  0.6612,  ...,  0.2455,  0.2476,  0.3162],
         [ 0.7170,  0.6697,  0.7091,  ...,  0.3148,  0.1952,  0.3329]],

        [[ 1.7617,  1.1200,  0.6426,  ...,  0.2532,  0.2933,  0.2714],
         [ 1.2360,  0.8174,  0.7113,  ...,  0.3278,  0.3507,  0.3192],
         [ 0.9729,  0.8133,  0.7463,  ...,  0.2771,  0.3234,  0.3331],
         ...,
         [ 0.4834,  1.1649,  0.9740,  ...,  0.3002,  0.2615,  0.2853],
         [ 0.1675,  1.5068,  0.9740,  ...,  0.2870,  0.2554,  0.2547],
         [ 0.1522,  1.4989,  0.7292,  ...,  0.2906,  0.2328,  0.2409]],

        [[ 0.6180,  0.5224,  0.6187,  ...,  0.3010,  0.3646,  0.3789],
         [ 0.4910,  0.4919,  0.5625,  ...,