In [None]:
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

from data import get_data_paths, celeb2mask, CustomDataset

# Silence every warning category for the rest of the session.
# NOTE(review): this also hides library warnings that flag real problems (e.g. dataloader
# misconfiguration); consider narrowing to specific categories/modules.
from warnings import filterwarnings
filterwarnings("ignore")

# Automatically reload edited local modules (e.g. `data`, `regseg`) before running code.
%reload_ext autoreload
%autoreload 2

In [None]:
# Shared experiment settings: the training budget, and the metrics table that every
# train_loop call below receives and returns (it is threaded through each run).
epochs, test_metrics = 100, pd.DataFrame()

## All Classes

In [None]:
%%time
# get data paths
data_paths = get_data_paths(celeb_img_path="../data/dataset_celebs/imgs_256/*.jpg",
                            celeb_mask_path="../data/dataset_celebs/masks_256/")
# data_paths = [i for i in data_paths if i[0] == "community_dataset"]
# data_paths = [i for i in data_paths if i[0] != "community_dataset"]
# np.random.shuffle(data_paths)
# data_paths = data_paths[:100]

# split data
train_paths, test_paths = train_test_split(data_paths, test_size=0.2, random_state=42)
test_paths, val_paths = train_test_split(test_paths, test_size=0.5, random_state=42)

print(f"Train: {len(train_paths)}, Val: {len(val_paths)}, Test: {len(test_paths)}")

# Datasets
train_dataset = CustomDataset(train_paths, num_classes="all")
val_dataset = CustomDataset(val_paths, num_classes="all")
test_dataset = CustomDataset(test_paths, num_classes="all")

# Dataloaders
torch.manual_seed(42)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, drop_last=True)

del train_dataset, val_dataset, test_dataset

### RegSeg

In [None]:
from regseg import RegSeg, train_loop
import pytorch_lightning as pl

In [None]:
# Final RegSeg run: class-weighted BCE loss.
params = dict(
    loss_name="BCEWeighted",
    optimizer_name="Adam",
    weight=[0.0403, 1.8895, 4.3349, 0.9494, 0.0718, 2.6287, 7.0844, 0.0643, 1.7070, 5.7402, 3.7832],
    lr=1e-3,
    weight_decay=0,
    automatic_optimization=True,
    scheduler_type="Plateau",
    scheduler_patience=3,
    num_classes=11,
)

# The logger and the checkpoint callback share a single run name.
logger_name = callback_name = "regseg_final_16"
logger_save_path = "../data/logs/regseg/"
callback_save_path = "../data/models/regseg/"

# Train; train_loop receives the metrics table and returns the updated one.
test_metrics = train_loop(
    params, test_metrics,
    logger_name, logger_save_path,
    callback_name, callback_save_path,
    epochs,
    train_dataloader, val_dataloader, test_dataloader,
)

In [None]:
# Display the metrics table returned by train_loop after the run above.
test_metrics

In [None]:
# Final RegSeg run with the BiasLoss objective (same class weights as the BCE run).
params = dict(
    loss_name="BiasLoss",
    optimizer_name="Adam",
    weight=[0.0403, 1.8895, 4.3349, 0.9494, 0.0718, 2.6287, 7.0844, 0.0643, 1.7070, 5.7402, 3.7832],
    lr=1e-3,
    weight_decay=0,
    automatic_optimization=True,
    scheduler_type="Plateau",
    scheduler_patience=3,
    num_classes=11,
)

# The logger and the checkpoint callback share a single run name.
logger_name = callback_name = "regseg_final_16_bias"
logger_save_path = "../data/logs/regseg/"
callback_save_path = "../data/models/regseg/"

# Train; train_loop receives the metrics table and returns the updated one.
test_metrics = train_loop(
    params, test_metrics,
    logger_name, logger_save_path,
    callback_name, callback_save_path,
    epochs,
    train_dataloader, val_dataloader, test_dataloader,
)

In [None]:
# Display the metrics table returned by train_loop after the BiasLoss run above.
test_metrics

In [None]:
# Final RegSeg run with the FocalLoss objective (same class weights as the BCE run).
params = dict(
    loss_name="FocalLoss",
    optimizer_name="Adam",
    weight=[0.0403, 1.8895, 4.3349, 0.9494, 0.0718, 2.6287, 7.0844, 0.0643, 1.7070, 5.7402, 3.7832],
    lr=1e-3,
    weight_decay=0,
    automatic_optimization=True,
    scheduler_type="Plateau",
    scheduler_patience=3,
    num_classes=11,
)

# The logger and the checkpoint callback share a single run name.
logger_name = callback_name = "regseg_final_16_focal"
logger_save_path = "../data/logs/regseg/"
callback_save_path = "../data/models/regseg/"

# Train; train_loop receives the metrics table and returns the updated one.
test_metrics = train_loop(
    params, test_metrics,
    logger_name, logger_save_path,
    callback_name, callback_save_path,
    epochs,
    train_dataloader, val_dataloader, test_dataloader,
)

In [None]:
# Display the metrics table returned by train_loop after the FocalLoss run above.
test_metrics

### Confusion matrix

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, f1_score, accuracy_score
import seaborn as sns

In [None]:
# Evaluation config for the "regseg_final_16" checkpoint (class-weighted BCE).
# NOTE(review): these loss weights differ from the ones in the training cell that uses the
# same run name; confirm which set the checkpoint was actually trained with.
params = dict(
    loss_name="BCEWeighted",
    optimizer_name="Adam",
    weight=[0.0423, 1.2298, 4.5671, 1.4371, 0.4330, 2.7586, 7.4159, 0.0678, 0.0874, 13.7053, 4.0711],
    lr=1e-3,
    weight_decay=0,
    automatic_optimization=True,
    scheduler_type="Plateau",
    scheduler_patience=3,
    num_classes=11,
)

# Where the run's logs and checkpoint live; logger and callback share one run name.
logger_name = callback_name = "regseg_final_16"
logger_save_path = "../data/logs/regseg/"
callback_save_path = "../data/models/regseg/"

In [None]:
# Restore the trained model for `callback_name` and switch it to inference mode.
# `load_from_checkpoint` is a classmethod that always builds a *new* instance, so first
# constructing RegSeg(...) and calling it on that instance only wasted a model build;
# the resulting model is the same when it is called on the class directly.
# NOTE(review): constructor arguments come from the hyperparameters stored in the checkpoint
# (assumes RegSeg calls save_hyperparameters — confirm); `params` is not applied here.
model = RegSeg.load_from_checkpoint(f"{callback_save_path}/{callback_name}.ckpt")
_ = model.eval()

In [None]:
# Run the loaded model over the test set and accumulate an 11x11 pixel confusion matrix.
#
# For each class channel i, predictions are binarised (sigmoid > 0.5); positive pixels are
# labelled i and negative pixels 0, then all channels are concatenated. Rows of the matrix
# are the true label (from `mask`), columns the predicted label.
# NOTE(review): since every channel's negatives map to label 0, row/column 0 mixes true
# background with per-channel negatives. If masks are one-hot, an argmax-based matrix
# (mask.argmax(1) vs logits.argmax(1)) would give a cleaner multiclass view.
NUM_CLASSES = 11
LABELS = list(range(NUM_CLASSES))

score_check = []   # per-image micro-F1 on channel 1 (debug sanity check)
score_check2 = []  # per-image accuracy on channel 1 (debug sanity check)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
cfm = np.zeros((NUM_CLASSES, NUM_CLASSES))

# Inference only: skip building the autograd graph.
with torch.no_grad():
    for x, mask in tqdm(test_dataloader):
        # get preds
        logits_mask = model(x.to(device)).to("cpu")
        pred_mask = logits_mask.sigmoid() > 0.5

        new_preds = []
        new_mask = []
        # Only the first sample of each batch is used; test_dataloader has batch_size=1.
        for i in range(NUM_CLASSES):
            bin_pred = pred_mask[0, i].reshape(-1).long().numpy()
            bin_true = mask[0, i].reshape(-1).long().numpy()

            # Label positive pixels with class i via boolean selection. Indexing an array
            # with its own 0/1 integer values would treat them as positions and only touch
            # elements 0 and 1.
            temp_pred = np.where(bin_pred > 0, i, 0)
            temp_mask = np.where(bin_true > 0, i, 0)

            new_preds.append(temp_pred)
            new_mask.append(temp_mask)

            if i == 1:
                # sklearn metrics take (y_true, y_pred).
                score_check.append(f1_score(temp_mask, temp_pred, average="micro"))
                score_check2.append(accuracy_score(temp_mask, temp_pred))

        # Pin the label set so every image yields an 11x11 matrix even if some classes are absent.
        cfm += confusion_matrix(np.concatenate(new_mask), np.concatenate(new_preds), labels=LABELS)

In [None]:
# Row-normalise the accumulated confusion matrix: each row (true class) sums to 1, so a cell
# is the fraction of that class's pixels that were assigned to the predicted class.
row_totals = cfm.sum(axis=1, keepdims=True)
# Rows with no pixels would divide by zero (NaN + RuntimeWarning); leave them at 0 instead.
cfm_final = np.divide(cfm, row_totals, out=np.zeros_like(cfm, dtype=float), where=row_totals > 0)

# Class name -> mask RGB colour.
# NOTE(review): the tick labels assume this dict's insertion order matches the model's
# output-channel order — confirm against data.celeb2mask.
CLASSES = {"background": [0, 0, 0], "lips": [255, 0, 0], "eye": [0, 255, 0], "nose": [0, 0, 255],
           "hair": [255, 255, 0], "eyebrows": [255, 0, 255], "teeth": [255, 255, 255], "face": [128, 128, 128],
           "ears": [0, 255, 255], "glasses": [0, 128, 128], "beard": [255, 192, 192]}
class_names = list(CLASSES)

fig, ax = plt.subplots(figsize=(16, 14))
sns.heatmap(cfm_final, annot=True, annot_kws={"size": 11},  # annotation font size
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set(title="RegSeg test-set confusion matrix (row-normalised)",
       xlabel="Predicted class", ylabel="True class")

plt.show()