In [None]:
# Standard library
import os
import random

# Third-party
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms as T

# Images (the package name is case-sensitive: `PIL`)
import PIL
from PIL import Image


# GPU

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

Dataset and Dataloader

In [None]:
# Dataset locations (raw strings so Windows backslashes need no escaping).
# They are left empty here and must be filled in before running the notebook.
PATH = r""              # project root
TRAIN_PATH = r""        # directory with the training images
TRAIN_MASKS_PATH = r""  # directory with the training masks
TEST_PATH = r""         # directory with the test images

Dataset

In [None]:
class Car_Dataset(Dataset):
    '''
    Dataset of images and, optionally, their masks, read from directories.

    File names are collected per directory in sorted order and the i-th image is
    paired with the i-th mask, so image and mask file names must sort the same way.
    '''

    def __init__(self, data_paths_list, masks_paths_list=None, img_transforms=None, mask_transforms=None):
        '''
        data_paths_list - directory, or iterable of directories, holding the images
        masks_paths_list - directory, or iterable of directories, holding the masks;
                           None builds an image-only dataset
        img_transforms - callable applied to every PIL image (default: T.ToTensor())
        mask_transforms - callable applied to every PIL mask (default: T.ToTensor())
        '''
        self.train_data = data_paths_list # X
        self.train_masks = masks_paths_list # y

        self.img_transforms = img_transforms
        self.mask_transforms = mask_transforms

        # `images` / `masks` keep the bare file names; the private lists keep the
        # full path of each file so items from several directories can be opened.
        self.images, self._image_files = self._scan(data_paths_list)
        self.masks, self._mask_files = self._scan(masks_paths_list)

    @staticmethod
    def _scan(paths):
        '''Return (file names, full file paths) for the entries of the given directories.'''
        if paths is None:
            return [], []
        # A single path must be wrapped; iterating a string would yield characters.
        if isinstance(paths, (str, bytes, os.PathLike)):
            paths = [paths]
        names, files = [], []
        for directory in paths:
            # Sorted so images and masks line up index by index
            for name in sorted(os.listdir(directory)):
                names.append(name)
                files.append(os.path.join(directory, name))
        return names, files

    def __len__(self):
        # When masks are present, make sure there is exactly one per image
        if self.train_masks is not None:
            assert len(self.images) == len(self.masks), f'El numero de imagenes en {self.train_masks} y {self.train_data} no coincide'
        return len(self.images)

    def __getitem__(self, idx):
        '''
        Return the element at position idx.

        Returns the transformed image alone when the dataset has no masks,
        otherwise the tuple (image, mask) with the mask scaled by its maximum.
        '''
        to_tensor = T.ToTensor() # Default transformation

        img_transform = self.img_transforms if self.img_transforms is not None else to_tensor
        # The context manager closes the underlying file once the transform has read it
        with Image.open(self._image_files[idx]) as pil_img:
            img = img_transform(pil_img)

        if self.train_masks is None:
            return img

        mask_transform = self.mask_transforms if self.mask_transforms is not None else to_tensor
        with Image.open(self._mask_files[idx]) as pil_mask:
            mask = mask_transform(pil_mask)

        # Scale the mask so its maximum becomes 1; an all-zero mask is left as is
        # because dividing by zero would fill it with NaNs.
        mask_max = mask.max().item()
        if mask_max > 0:
            mask = mask / mask_max

        return img, mask

# Transformaciones

In [None]:
# Preprocessing pipeline: resize to 224x224, then convert the PIL image to a tensor
transform_data = T.Compose(
    [
        T.Resize([224, 224]),
        T.ToTensor(),
    ]
)

# Dataloaders

In [None]:
# Complete labelled dataset; images and masks go through the same pipeline
full_dataset = Car_Dataset(
    data_paths_list=TRAIN_PATH,
    masks_paths_list=TRAIN_MASKS_PATH,
    img_transforms=transform_data,
    mask_transforms=transform_data,
)

In [None]:
# Number of samples in the dataset (shown as the cell output)
len(full_dataset)

In [None]:
BATCH_SIZE = 32

# 80 % of the samples go to training; validation receives the remainder
dataset_len = len(full_dataset)
TRAIN_SIZE = int(dataset_len * 0.8)
VAL_SIZE = dataset_len - TRAIN_SIZE
print(TRAIN_SIZE, VAL_SIZE)

In [None]:
# Seed the split so the train/validation partition is the same on every run.
# With an unseeded split, samples move between the two sets after each kernel
# restart and validation results are not comparable between runs.
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    full_dataset,
    [TRAIN_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(len(train_dataset), len(val_dataset))

In [None]:
# Reshuffle the training samples at every epoch
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# NOTE(review): shuffling is not required for validation; consider shuffle=False
# so validation batches are visited in a fixed order.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Pull a single batch to sanity-check the tensor shapes.
# Despite its name, `masks_paths_list` receives the second element of the batch
# (the masks yielded by the dataset), not file paths.
first_batch = next(iter(train_loader))
imgs, masks_paths_list = first_batch
print(imgs.shape, masks_paths_list.shape)

# Let's look at the data

In [None]:
# Walk through every training batch and report its index and tensor shapes
for i, batch in enumerate(train_loader):
    x, y = batch
    print(i, x.shape, y.shape)

In [None]:
def plot_mini_batch(imgs, masks):
    """Plot every image of a mini-batch with its mask overlaid.

    imgs  - batch of image tensors; imgs[i] is laid out as (channel, height, width)
    masks - batch of mask tensors with the same layout, one per image

    The grid has 8 columns and as many rows as the batch needs, so a final batch
    smaller than BATCH_SIZE (or a different batch size) is plotted without errors.
    """
    n_items = min(len(imgs), len(masks))
    if n_items == 0:
        # Nothing to draw for an empty batch
        return

    n_cols = 8
    n_rows = (n_items + n_cols - 1) // n_cols  # ceiling division
    # 2.5 inches per row keeps the 20x10 figure of a full 4x8 grid
    plt.figure(figsize=(20, 2.5 * n_rows))
    for i in range(n_items):
        # Rows, columns and 1-based position
        plt.subplot(n_rows, n_cols, i + 1)
        # Move the channels to the last position: (height, width, channel)
        img = imgs[i, ...].permute(1, 2, 0).numpy()
        mask = masks[i, ...].permute(1, 2, 0).numpy()
        plt.imshow(img)
        # Draw the mask semi-transparently on top so both stay visible
        plt.imshow(mask, alpha=0.5)
        # Hide the axes for a cleaner view
        plt.axis('off')
    # Avoid overlaps between the sub-figures
    plt.tight_layout()
    plt.show()


# Fetch one batch and display it; `masks_paths_list` holds the batch of masks
imgs, masks_paths_list = next(iter(train_loader))
plot_mini_batch(imgs, masks_paths_list)