## Data

In [None]:
from data import get_data_paths, celeb2npy, CustomDataset
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np

from warnings import filterwarnings
filterwarnings("ignore")

In [None]:
%%time
# Collect the paths of every sample.
data_paths = get_data_paths()
# Optional: uncomment to work on a random 10,000-path subset instead.
# np.random.shuffle(data_paths)
# data_paths = data_paths[:10000]

# 80/10/10 split: hold out 20% of the paths, then halve that hold-out into
# test and validation parts. The fixed random_state keeps it reproducible.
train_paths, holdout_paths = train_test_split(data_paths, test_size=0.2, random_state=42)
test_paths, val_paths = train_test_split(holdout_paths, test_size=0.5, random_state=42)

print(f"Train: {len(train_paths)}, Val: {len(val_paths)}, Test: {len(test_paths)}")

# Datasets
train_dataset = CustomDataset(train_paths)
val_dataset = CustomDataset(val_paths)
test_dataset = CustomDataset(test_paths)

# Dataloaders
# Only the training loader is shuffled; evaluation loaders keep a fixed
# order so validation/test passes see the samples in the same sequence.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

## Training

In [None]:
from pretrained_models import PretrainedModel
import pytorch_lightning as pl
from shared import get_callbacks

In [None]:
# Shared experiment settings: the training cells iterate over every
# (loss, backbone) pair.
epochs = 15  # handed to pl.Trainer as max_epochs
loss_list = [
    "BCE",
    "FocalLoss",
    "DiceFocal",
]
backbones = [
    "efficientnet-b3",
    "timm-mobilenetv3_large_075",
]

### DeepLabV3+

In [None]:
# Train and evaluate the "deeplabv3plus" architecture once for each
# (loss, backbone) combination.
arch = "deeplabv3plus"

for loss in loss_list:
    for encoder_name in backbones:
        # Hyperparameters for this run; the same dict feeds both the
        # callback/logger factory and the model constructor.
        params = dict(
            lr=1e-4,
            weight_decay=0,
            loss_name=loss,
            arch=arch,
            encoder_name=encoder_name,
        )
        callbacks, logger = get_callbacks(params)
        model = PretrainedModel(**params)

        # Fit on the train loader, validating on the validation loader.
        trainer = pl.Trainer(gpus=1, max_epochs=epochs, logger=logger, callbacks=callbacks)
        trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

        # Score the fitted model on the test loader.
        trainer.test(model, test_dataloader)

        # Drop the notebook-level reference before the next run starts.
        del model

### Unet

In [None]:
# Train and evaluate the "Unet" architecture once for each
# (loss, backbone) combination.
arch = "Unet"

for loss in loss_list:
    for encoder_name in backbones:
        # Hyperparameters for this run; the same dict feeds both the
        # callback/logger factory and the model constructor.
        # (The original literal lacked the comma after "weight_decay": 0,
        # which made the whole cell a SyntaxError.)
        params = {"lr": 1e-4, "weight_decay": 0,
                  "loss_name": loss, "arch": arch, "encoder_name": encoder_name}
        callbacks, logger = get_callbacks(params)
        model = PretrainedModel(**params)

        # train model
        trainer = pl.Trainer(
            gpus=1,
            max_epochs=epochs,
            logger=logger,
            callbacks=callbacks,
        )

        trainer.fit(
            model,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader,
        )

        # evaluate on the test loader
        trainer.test(model, test_dataloader)

        # Drop the notebook-level reference before the next run starts.
        del model