In [14]:
# !pip install torch torchvision torchaudio pytorch-lightning
# !pip install segmentation-models-pytorch

In [15]:
import pytorch_lightning as pl
# from torch.nn import segmentation_models_pytorch
# import segmentation_models_pytorch as smp
from segmentation_models_pytorch import Unet
from torch.utils.data import Dataset
import os
import numpy as np

# STEPS
# unet -> segmentacion por zonas
# modelo que le aplica la mascara al a input image
# style transfer, aplicarles el estilo
# combinar 

In [16]:
drive_path = '/content/drive/MyDrive/dl-project/face-segmentation'
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms

class LoadData(Dataset):
  def __init__(self, path):
    super().__init__()

    self.path = path
    self.images = os.listdir(path + '/img')
    self.masks = os.listdir(path + '/masks')

  def __len__(self):
    return len(self.images)

  def __getitem__(self, index):
    image = Image.open(self.path + '/img/' + self.images[index])
    mask = Image.open(self.path + '/masks/' + self.masks[index])
    
    image = image.convert('RGB')
    mask = mask.convert('RGB')
    # reshape images and masks
    transform1 = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    image = transform1(image)
    mask = transform1(mask)

    # image = np.transpose(image, (2, 0, 1))
    # mask = np.transpose(image, (2, 0, 1))

    return image, mask


data = LoadData(drive_path)
img, mask = data[0]
# plt.imshow(img)
# plt.imshow(mask)

In [17]:
from torch.utils.data import DataLoader

class DataModule(pl.LightningDataModule):
  def __init__(self, data, batch_size):
    super().__init__()

    self.data = data
    self.batch_size = batch_size

    self.training_data = None
    self.test_data = None
    self.validation_data = None
  
  def setup(self, stage=None):
    # poner un if para test, validation o training

    train = 0.7
    validation = 0.2
    # validation = 0.1

    train = int(len(self.data) * train)
    validation = int(len(self.data) * validation)
    test = int(len(self.data) - train - validation)

    print('train', train)
    print('test', test)
    print('validation', validation)

    print('len data', len(self.data))

    self.training_data, self.validation_data, self.test_data = torch.utils.data.random_split(self.data, [train, validation, test])

  def train_dataloader(self):
    return DataLoader(self.training_data, self.batch_size, shuffle=False)

  def test_dataloader(self):
    return DataLoader(self.test_data, self.batch_size, shuffle=False)

  def val_dataloader(self):
    return DataLoader(self.validation_data, self.batch_size, shuffle=False)
  



In [18]:
from lightning_fabric.strategies.strategy import Module
import torch
from torchmetrics.classification import MulticlassF1Score

DATASET_CLASSES = 17

class ModelUnet(pl.LightningModule):
  # def __init__(self, lr=1e-3):
  #   super(ModelUnet, self).__init__()

  #   self.lr = lr
  #   self.save_hyperparameters()
  def __init__(self,  lr=1e-3, save_every_n_epoch=100,**kwargs):
    super(ModelUnet, self).__init__(**kwargs)

    self.lr = lr
    self.save_hyperparameters()
    self.loss = MulticlassF1Score(num_classes=DATASET_CLASSES)

    # channels RGB, 3
    self.unet = Unet(in_channels=3, classes=DATASET_CLASSES)
  
  def forward(self, x):
    return self.unet(x)

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
    return optimizer

  def training_step(self, batch, idx):
    # x -> input, y -> output (images)
    x, y = batch
    out = self(x)
    loss = self.loss(out, y)
    # plot
    self.log('training_loss', loss)
    return loss

  def validation_step(self, batch, idx):
    # x -> input, y -> output (images)
    x, y = batch
    out = self(x)
    loss = self.loss(out, y)
    # plot
    self.log('validation_loss', loss)
    return loss

  def test_step(self, batch, idx):
    # x -> input, y -> output (images)
    x, y = batch
    out = self(x)
    loss = self.loss(out, y)
    # plot
    self.log('test_loss', loss)
    return loss

In [19]:
data = LoadData(drive_path)
data = [x for x in data]
print(len(data))

20


In [20]:
data_modl = DataModule(data, batch_size=8)
data_modl

<__main__.DataModule at 0x7f2e78470460>

In [21]:
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(dirpath=drive_path, filename='{epoch}-{validation_loss:.2f}', save_top_k=1, every_n_epochs=3)

In [22]:
model = ModelUnet()
trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=100, log_every_n_steps=1, callbacks=[checkpoint_callback])

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [23]:
trainer.fit(model, data_modl)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name | Type              | Params
-------------------------------------------
0 | loss | MulticlassF1Score | 0     
1 | unet | Unet              | 24.4 M
-------------------------------------------
24.4 M    Trainable params
0         Non-trainable params
24.4 M    Total params
97.755    Total estimated model params size (MB)


train 14
test 2
validation 4
len data 20


Sanity Checking: 0it [00:00, ?it/s]

ValueError: ignored