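## Data

Celeb image/mask pairs are shuffled, capped at 20,000 samples and split 80/10/10 into train, validation and test sets.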

In [None]:
from data import get_data_paths, celeb2mask, CustomDataset
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import time
import torch

from warnings import filterwarnings
filterwarnings("ignore")

%reload_ext autoreload
%autoreload 2

In [None]:
%%time
# get data paths
data_paths = get_data_paths(celeb_img_path="../data/dataset_celebs/imgs_256/*.jpg",
                            celeb_mask_path="../data/dataset_celebs/masks_256/")
# data_paths = get_data_paths()
# data_paths = [i for i in data_paths if i[0] == "community_dataset"]
# data_paths = [i for i in data_paths if i[0] != "community_dataset"]
np.random.seed(42)
np.random.shuffle(data_paths)
data_paths = data_paths[:20000]

# split data
train_paths, test_paths = train_test_split(data_paths, test_size=0.2, random_state=42)
test_paths, val_paths = train_test_split(test_paths, test_size=0.5, random_state=42)

print(f"Train: {len(train_paths)}, Val: {len(val_paths)}, Test: {len(test_paths)}")

# Datasets
train_dataset = CustomDataset(train_paths)
val_dataset = CustomDataset(val_paths)
test_dataset = CustomDataset(test_paths)

# Dataloaders
torch.manual_seed(42)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, drop_last=True)

del train_dataset, val_dataset, test_dataset

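## DDRNet-23-slim

Three runs compare per-class loss weighting (equal vs. proportional) and the loss function (weighted BCE vs. focal loss). Test results are collected in `test_metrics`.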

In [None]:
from ddrnet23 import DualResNet, BasicBlock, train_loop
import pytorch_lightning as pl

In [None]:
# Experiment-wide settings shared by the DDRNet-23-slim runs below.
model_type = "ddrnet_23_slim"   # forwarded to train_loop as `model_type`
epochs = 20                     # number of epochs forwarded to train_loop
# NOTE(review): `use_pretrained` is never passed to train_loop in this notebook,
# so changing it currently has no effect.
use_pretrained = False

# Running results table: every train_loop call receives it and returns the frame that replaces it.
test_metrics = pd.DataFrame()

In [None]:
# adam + equal weights
params = {"loss_name":"BCEWeighted", "optimizer_name":"Adam", "weight":[1,1,1,1,1,1,1,1,1,1,1],
          "lr": 1e-3, "weight_decay":0, "automatic_optimization":True, "scheduler_type":"Plateau",
         "scheduler_patience":3, "num_classes":11}

# callbacks utils
logger_name = f"ddrnet_23_slim_base"
logger_save_path = "../data/logs/ddrnet_23_slim/"
callback_name = f"ddrnet_23_slim_base"
callback_save_path = "../data/models/ddrnet_23_slim/"
# train loop
test_metrics = train_loop(params, test_metrics, logger_name, logger_save_path, callback_name,
                          callback_save_path, epochs, train_dataloader, val_dataloader, test_dataloader,
                         model_type=model_type)

In [None]:
# adam + proportional weights
params = {"loss_name":"BCEWeighted", "optimizer_name":"Adam",
          "weight":[0.0403, 1.8895, 4.3349, 0.9494, 0.0718, 2.6287, 7.0844, 0.0643, 1.7070, 5.7402, 3.7832],
          "lr": 1e-3, "weight_decay":0, "automatic_optimization":True, "scheduler_type":"Plateau",
         "scheduler_patience":3, "num_classes":11}

# callbacks utils
logger_name = f"ddrnet_23_slim_weights"
logger_save_path = "../data/logs/ddrnet_23_slim/"
callback_name = f"ddrnet_23_slim_weights"
callback_save_path = "../data/models/ddrnet_23_slim/"
# train loop
test_metrics = train_loop(params, test_metrics, logger_name, logger_save_path, callback_name,
                          callback_save_path, epochs, train_dataloader, val_dataloader, test_dataloader,
                         model_type=model_type)

In [None]:
# focalloss + proportional weights
params = {"loss_name":"FocalLoss", "optimizer_name":"Adam",
          "weight":[0.0403, 1.8895, 4.3349, 0.9494, 0.0718, 2.6287, 7.0844, 0.0643, 1.7070, 5.7402, 3.7832],
          "lr": 1e-3, "weight_decay":0, "automatic_optimization":True, "scheduler_type":"Plateau",
         "scheduler_patience":3, "num_classes":11}

# callbacks utils
logger_name = f"ddrnet_23_slim_weights_focalloss"
logger_save_path = "../data/logs/ddrnet_23_slim/"
callback_name = f"ddrnet_23_slim_weights_focalloss"
callback_save_path = "../data/models/ddrnet_23_slim/"
# train loop
test_metrics = train_loop(params, test_metrics, logger_name, logger_save_path, callback_name,
                          callback_save_path, epochs, train_dataloader, val_dataloader, test_dataloader,
                         model_type=model_type)

In [None]:
# Final comparison table: the DataFrame returned by the most recent train_loop call
# (each experiment passes the previous frame in and stores the returned one).
test_metrics