# **Definizione dei Dataset**
Creiamo due classi distinte per gestire i due dataset utilizzati dal modello: il dataset IAM e il nostro dataset di immagini etichettate (in base al fatto che rappresentino o meno la scrittura di un soggetto disgrafico).


In [2]:
import os
import random
from PIL import Image
import torch
# from tacobox import Taco
from torch.utils.data import Dataset
from torchvision.transforms import Compose, ConvertImageDtype, Pad, Resize, PILToTensor
%run utils.ipynb
%run path.ipynb

### **1. Dataset IAM**
Questa classe gestisce le immagini del dataset IAM per il pretraining dell'architettura, selezionando le immagini del set richiesto (train, test o validation). La classe costruisce inoltre le triplette (anchor, positive, negative) necessarie al calcolo della funzione di loss definita come 'Triplet loss'.

In [3]:
class IAMDL(Dataset):
    """Dataset of IAM sentence image paths, with helpers to build triplets.

    Indexing the dataset returns an image *path* (a string); pixels are only
    loaded when a triplet is built through ``get_triplet`` or
    ``batch_triplets``.
    """

    # (height, width) every image is resized to after padding.
    TARGET_SIZE = (128, 1024)

    # Upper bound on rejection-sampling attempts for the negative folder, so a
    # degenerate directory layout raises instead of hanging forever.
    _MAX_NEGATIVE_TRIES = 1000

    def __init__(self, model_set: str, path: str, device: str = 'cpu'):
        """Collect the sample paths that belong to the requested split.

        Args:
            model_set: One of ``'test'``, ``'train'`` or ``'validation'``.
            path: Root folder containing ``sentences.txt``, ``sentences/`` and
                ``LWRT/<model_set>.uttlist``.
            device: Device the batched triplets are moved to
                (defaults to ``'cpu'``, the previously hard-coded value).
        """
        assert model_set in ('test', 'train', 'validation'), \
            f"unknown model_set: {model_set!r}"
        self.path = path
        self.set = path + '/LWRT/' + model_set + '.uttlist'
        # Every sample listed in sentences.txt, regardless of the split.
        self.samples = []
        # Only the samples whose folder id is listed in the split file.
        self.set_samples = self.__get_set_samples()
        self.max_width, self.max_height = get_IAM_statistics()
        self.device = device

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.set_samples)

    def __getitem__(self, index: int):
        """Return the image path stored at ``index`` in the selected split."""
        return self.set_samples[index]

    def __get_set_samples(self):
        """Parse ``sentences.txt`` and keep the samples of the current split.

        Side effect: appends every parsed image path to ``self.samples``.

        Returns:
            The list of image paths whose folder id appears in ``self.set``.
        """
        with open(self.path + '/sentences.txt') as sentences_file:
            for line in sentences_file:
                # Lines starting with '#' are header comments.
                if not line or line[0] == '#':
                    continue
                fields = line.strip().split(' ')
                assert len(fields) >= 9, f"malformed line: {line!r}"
                sentence_id = fields[0]
                id_parts = sentence_id.split('-')
                # <path>/sentences/<p0>/<p0>-<p1>/<sentence_id>.png
                file_name = (
                    self.path + '/sentences/' + id_parts[0] + '/'
                    + id_parts[0] + '-' + id_parts[1] + '/'
                    + sentence_id + '.png'
                )
                self.samples.append(file_name)

        # A set gives O(1) membership tests in the filter below.
        with open(self.set) as split_file:
            folders = {row.strip("\n") for row in split_file}

        set_samples = []
        for sample in self.samples:
            # File name without the 4-character '.png' extension.
            stem = sample.split("/")[-1][:-4].strip(" ")
            # Drop the last two dash-separated fields to get the folder id.
            folder = "-".join(stem.split("-")[:-2])
            if folder in folders:
                set_samples.append(sample)
        return set_samples

    def _load_image(self, image_path):
        """Load one image, pad it to the dataset maximum and resize it.

        The image is padded on the right/bottom with value 1.0 up to
        ``(self.max_width, self.max_height)`` and then resized to
        ``TARGET_SIZE``.
        """
        # The context manager closes the underlying file once the pixels
        # have been copied into the tensor.
        with Image.open(image_path) as img:
            width, height = img.size
            transform = Compose([
                PILToTensor(),
                ConvertImageDtype(torch.float),
                Pad((0, 0, self.max_width - width, self.max_height - height), fill=1.),
                Resize(self.TARGET_SIZE)
            ])
            return transform(img)

    def get_triplet(self, sample):
        """Build an (anchor, positive, negative) triplet for ``sample``.

        The positive is a different image from the anchor's folder; the
        negative is a random image from a different folder under
        ``<path>/sentences``.

        Args:
            sample: Path of the anchor image, using '/' as separator.

        Returns:
            Three image tensors: anchor, positive, negative.

        Raises:
            ValueError: If the anchor's folder contains no other image.
            RuntimeError: If no folder different from the anchor's could be
                drawn after ``_MAX_NEGATIVE_TRIES`` attempts.
        """
        pos_aut = '/'.join(sample.split("/")[:-1])
        anc_img = sample.split("/")[-1]

        # Choosing among the non-anchor files is equivalent to re-drawing
        # until the pick differs from the anchor, but cannot loop forever
        # when the anchor is the only file in its folder.
        candidates = [name for name in os.listdir(pos_aut) if name != anc_img]
        if not candidates:
            raise ValueError(f"no positive image available besides the anchor in {pos_aut!r}")
        pos_img = random.choice(candidates)

        # NOTE(review): folders are compared by path only; whether two distinct
        # folders can belong to the same writer is not visible here - confirm.
        sentences_root = self.path + '/sentences'
        top_level = os.listdir(sentences_root)
        for _ in range(self._MAX_NEGATIVE_TRIES):
            group_dir = os.path.join(sentences_root, random.choice(top_level))
            neg_aut = os.path.join(group_dir, random.choice(os.listdir(group_dir)))
            if neg_aut != pos_aut:
                break
        else:
            raise RuntimeError(f"could not find a negative folder different from {pos_aut!r}")
        neg_img = random.choice(os.listdir(neg_aut))

        anchor = self._load_image(os.path.join(pos_aut, anc_img))
        positive = self._load_image(os.path.join(pos_aut, pos_img))
        negative = self._load_image(os.path.join(neg_aut, neg_img))

        return anchor, positive, negative

    def batch_triplets(self, samples):
        """Build one triplet per sample path and stack them into batches.

        Args:
            samples: Sequence of anchor image paths.

        Returns:
            Three tensors (anchors, positives, negatives), each of shape
            ``(len(samples), 1, *TARGET_SIZE)``, moved to ``self.device``.
        """
        batch_size = len(samples)
        # A single channel is allocated, so loaded images are assumed to be
        # single-channel.
        shape = (batch_size, 1, *self.TARGET_SIZE)
        anchors = torch.empty(size=shape)
        positives = torch.empty(size=shape)
        negatives = torch.empty(size=shape)

        for batch, sample in enumerate(samples):
            anchors[batch], positives[batch], negatives[batch] = self.get_triplet(sample)

        return anchors.to(self.device), positives.to(self.device), negatives.to(self.device)

### **2. Dataset Dysgraphia**
Questa classe, invece, gestisce le immagini del nostro dataset etichettato per il training e il test dell'architettura, suddividendo anch'essa le immagini in train, test e validation set. Inoltre, consideriamo due casi: dataset senza augmentation e dataset con augmentation (tecniche tradizionali o taco augmentation, più info nella relazione).

In [4]:
class DysgraphiaDL(Dataset):
    """Labelled dysgraphia image dataset read from a split file.

    Each item is ``(image, label, placeholder)`` where ``label`` is 0 when the
    second '/'-separated component of the sample path contains
    ``'No_Dysgraphic'`` and 1 otherwise.
    """

    # (height, width) every image is resized to after padding.
    TARGET_SIZE = (192, 512)

    def __init__(self, aug :  str, model_set : str, device : str):
        """Create the split files and load the sample list of one split.

        Args:
            aug: Augmentation tag; used in the split file name
                ``<model_set>_<aug>.txt`` and passed to the helper functions.
            model_set: One of ``'train'``, ``'val'`` or ``'test'``.
            device: Device the returned tensors are moved to.
        """
        assert model_set in ('train', 'val', 'test'), \
            f"unknown model_set: {model_set!r}"
        create_simple_splits(DYSG, aug)

        self.BASE = DYSG
        self.set = model_set + f'_{aug}.txt'
        self.set_samples = self.__set_samples()
        self.max_width, self.max_height = get_base_statistics(aug)
        self.device = device

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.set_samples)

    @staticmethod
    def _label_of(sample_path):
        """Return 0 for a 'No_Dysgraphic' sample path, 1 otherwise.

        The class is read from the second '/'-separated path component.
        """
        return 0 if 'No_Dysgraphic' in sample_path.split("/")[1] else 1

    def __getitem__(self, index : int):
        """Load, pad and resize the image at ``index`` and return it with its label.

        Returns:
            A tuple ``(image, label, placeholder)``: the grayscale image tensor
            and the scalar label tensor, both on ``self.device``, plus an
            uninitialised 1-element tensor kept as the third slot.
        """
        sample_path = self.set_samples[index]
        # convert() loads the pixels into a new image, so the file handle can
        # be released immediately afterwards.
        with Image.open(sample_path) as raw:
            img = raw.convert('L')
        width, height = img.size
        transform = Compose([
            PILToTensor(),
            ConvertImageDtype(torch.float),
            Pad((0, 0, self.max_width - width, self.max_height - height), fill=1.),
            Resize(self.TARGET_SIZE)
        ])
        img = transform(img)

        label = torch.tensor(self._label_of(sample_path))

        return img.to(self.device), label.to(self.device), torch.empty((1))

    def __set_samples(self):
        """Read the sample paths listed in the split file, one per line.

        Blank lines are skipped so they do not become empty sample paths.
        """
        with open(SPLIT + "/" + self.set) as split_file:
            rows = (line.replace("\n", "") for line in split_file)
            return [row for row in rows if row]

    def get_binary_weights(self):
        """Compute per-class weights inversely proportional to class frequency.

        Labels are derived from the sample path exactly as in ``__getitem__``.
        The previous implementation looked them up through ``self.labels_csv``
        and ``self.labels``, which this class never defines.

        Returns:
            A tensor ``[min(n0, n1) / n0, min(n0, n1) / n1]`` on
            ``self.device``, where ``n0`` and ``n1`` are the sample counts of
            class 0 and class 1.

        Raises:
            ValueError: If one of the two classes has no samples, since the
                weights would require a division by zero.
        """
        counter = [0, 0]
        for sample in self.set_samples:
            counter[self._label_of(sample)] += 1

        print(f"Samples per class: {counter}")
        if 0 in counter:
            raise ValueError(f"cannot compute class weights with an empty class: {counter}")

        smallest = min(counter)
        weights = [smallest / counter[0], smallest / counter[1]]
        print(f"Values: {weights}")

        return torch.tensor(weights).to(self.device)
    
    # def preprocess(self, img, augment=True):
    #     if augment:
    #         img = self.apply_taco_augmentations(img)
            
    #     # scaling image [0, 1]
    #     img = img/255
    #     img = img.swapaxes(-2,-1)[...,::-1]
    #     target = np.ones((config.INPUT_WIDTH, config.INPUT_HEIGHT))
    #     new_x = config.INPUT_WIDTH/img.shape[0]
    #     new_y = config.INPUT_HEIGHT/img.shape[1]
    #     min_xy = min(new_x, new_y)
    #     new_x = int(img.shape[0]*min_xy)
    #     new_y = int(img.shape[1]*min_xy)
    #     img2 = cv2.resize(img, (new_y,new_x))
    #     target[:new_x,:new_y] = img2
    #     return 1 - (target)
    
    # def apply_taco_augmentations(self, input_img):
    #     random_value = random.random()
    #     if random_value <= config.TACO_AUGMENTAION_FRACTION:
    #         augmented_img = self.mytaco.apply_vertical_taco(
    #             input_img, 
    #             corruption_type='random'
    #         )
    #     else:
    #         augmented_img = input_img
    #     return augmented_img