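## Data

Load the `community_dataset` image/mask path records, split them roughly 80/10/10 into train/validation/test sets (fixed `random_state=42`), and wrap each split in a `DataLoader`.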

In [None]:
from data import get_data_paths, celeb2mask, CustomDataset
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import time
import torch

from warnings import filterwarnings
filterwarnings("ignore")

%reload_ext autoreload
%autoreload 2

In [None]:
%%time
# get data paths
data_paths = get_data_paths(celeb_img_path="../data/dataset_celebs/imgs_256/*.jpg",
                            celeb_mask_path="../data/dataset_celebs/masks_256/")
data_paths = [i for i in data_paths if i[0] == "community_dataset"]
# data_paths = [i for i in data_paths if i[0] != "community_dataset"]
# np.random.shuffle(data_paths)
# data_paths = data_paths[:100]

# split data
train_paths, test_paths = train_test_split(data_paths, test_size=0.2, random_state=42)
test_paths, val_paths = train_test_split(test_paths, test_size=0.5, random_state=42)

print(f"Train: {len(train_paths)}, Val: {len(val_paths)}, Test: {len(test_paths)}")

# Datasets
train_dataset = CustomDataset(train_paths)
val_dataset = CustomDataset(val_paths)
test_dataset = CustomDataset(test_paths)

# Dataloaders
torch.manual_seed(0)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, drop_last=True)

del train_dataset, val_dataset, test_dataset

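## RegSeg

Train RegSeg with a weighted BCE loss under three configurations (Adam with equal class weights, Adam with proportional class weights, and AdamP with proportional class weights), collecting each run's test metrics in `test_metrics`.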

In [None]:
from regseg import RegSeg, train_loop
import pytorch_lightning as pl

In [None]:
# Settings shared by every RegSeg experiment below.
epochs: int = 15  # training epochs per run

# Results table threaded through each train_loop call: every experiment cell
# passes it in and stores the returned frame back under this name.
test_metrics: pd.DataFrame = pd.DataFrame()

In [None]:
# Experiment: Adam optimizer + equal (all-ones) class weights in the weighted BCE loss.
run_name = "regseg_base"
num_classes = 11
params = {
    "loss_name": "BCEWeighted",
    "optimizer_name": "Adam",
    "weight": [1] * num_classes,  # equal weight for every class
    "lr": 1e-3,
    "weight_decay": 0,
    "automatic_optimization": True,
    "scheduler_type": "Plateau",
    "scheduler_patience": 2,
    "num_classes": num_classes,
}
# The weight list is assumed to hold one entry per class (confirm in regseg).
assert len(params["weight"]) == params["num_classes"], "need one loss weight per class"

# callbacks utils: logs and checkpoints are both named after the run
logger_name = run_name
logger_save_path = "../data/logs/regseg/"
callback_name = run_name
callback_save_path = "../data/models/regseg/"

# train loop: returns the updated results frame, which replaces test_metrics.
# NOTE(review): re-running this cell feeds the already-updated frame back in;
# if train_loop appends rows, this run's results will appear twice.
test_metrics = train_loop(params, test_metrics, logger_name, logger_save_path, callback_name,
                          callback_save_path, epochs, train_dataloader, val_dataloader, test_dataloader)

In [None]:
# Experiment: Adam optimizer + proportional class weights in the weighted BCE loss.
run_name = "regseg_weights"
num_classes = 11
# NOTE(review): provenance of these class weights is not recorded here. TODO
# document how they were computed (presumably from per-class frequencies in
# the training masks) so the run can be reproduced.
params = {
    "loss_name": "BCEWeighted",
    "optimizer_name": "Adam",
    "weight": [0.02335332, 2.29093098, 4.60850608, 0.95243978, 0.16392607, 3.69650905,
               10.48914825, 0.05858025, 2.10530202, 4.27744651, 1.27002274],
    "lr": 1e-3,
    "weight_decay": 0,
    "automatic_optimization": True,
    "scheduler_type": "Plateau",
    "scheduler_patience": 2,
    "num_classes": num_classes,
}
# The weight list is assumed to hold one entry per class (confirm in regseg).
assert len(params["weight"]) == params["num_classes"], "need one loss weight per class"

# callbacks utils: logs and checkpoints are both named after the run
logger_name = run_name
logger_save_path = "../data/logs/regseg/"
callback_name = run_name
callback_save_path = "../data/models/regseg/"

# train loop: returns the updated results frame, which replaces test_metrics.
# NOTE(review): re-running this cell feeds the already-updated frame back in;
# if train_loop appends rows, this run's results will appear twice.
test_metrics = train_loop(params, test_metrics, logger_name, logger_save_path, callback_name,
                          callback_save_path, epochs, train_dataloader, val_dataloader, test_dataloader)

In [None]:
# Experiment: AdamP optimizer + proportional class weights in the weighted BCE loss
# (same class weights as the Adam "regseg_weights" run, so only the optimizer differs).
run_name = "regseg_weights_adamp"
num_classes = 11
# NOTE(review): provenance of these class weights is not recorded here. TODO
# document how they were computed (presumably from per-class frequencies in
# the training masks) so the run can be reproduced.
params = {
    "loss_name": "BCEWeighted",
    "optimizer_name": "AdamP",
    "weight": [0.02335332, 2.29093098, 4.60850608, 0.95243978, 0.16392607, 3.69650905,
               10.48914825, 0.05858025, 2.10530202, 4.27744651, 1.27002274],
    "lr": 1e-3,
    "weight_decay": 0,
    "automatic_optimization": True,
    "scheduler_type": "Plateau",
    "scheduler_patience": 2,
    "num_classes": num_classes,
}
# The weight list is assumed to hold one entry per class (confirm in regseg).
assert len(params["weight"]) == params["num_classes"], "need one loss weight per class"

# callbacks utils: logs and checkpoints are both named after the run
logger_name = run_name
logger_save_path = "../data/logs/regseg/"
callback_name = run_name
callback_save_path = "../data/models/regseg/"

# train loop: returns the updated results frame, which replaces test_metrics.
# NOTE(review): re-running this cell feeds the already-updated frame back in;
# if train_loop appends rows, this run's results will appear twice.
test_metrics = train_loop(params, test_metrics, logger_name, logger_save_path, callback_name,
                          callback_save_path, epochs, train_dataloader, val_dataloader, test_dataloader)

In [None]:
# Results frame returned by the last train_loop call above, after being
# threaded through every experiment cell; shown via the cell's rich display.
test_metrics