In [1]:
import os
import pickle

import numpy as np
import seaborn as sns
import torch
import torchvision
from matplotlib import pyplot as plt

from app.models import ONet
from app.utils import seed_everything
from app.utils_dataset import get_data, get_data_loaders
from app.utils_model import train, model_inference, test_dlib
from app.utils_plot import plot_ced_auc_test
from app.settings import RANDOM_STATE, IMAGE_DIRS_300W_TRAIN, IMAGE_DIRS_300W_TEST, IMAGE_DIRS_MENPO_TRAIN, IMAGE_DIRS_MENPO_TEST, DEVICE

In [None]:
# Seed through the project helper first, before any data is loaded.
seed_everything(RANDOM_STATE)

# Plot appearance shared by every figure in the notebook.
sns.set_style('whitegrid')
plt.rcParams['font.size'] = 15

# Train / validation / test splits built from the Menpo directories,
# followed by the loaders created from those splits.
train_df, val_df, test_df = get_data(
    IMAGE_DIRS_MENPO_TRAIN,
    IMAGE_DIRS_MENPO_TEST,
    use_val_dataset=True,
)
train_loader, val_loader, test_loader = get_data_loaders(
    train_df,
    val_df,
    test_df,
)

In [None]:
def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalized, channel-first image tensor with matplotlib.

    The tensor is copied to host memory, transposed from (C, H, W) to
    (H, W, C), de-normalized as ``std * inp + mean``, clipped to ``[0, 1]``
    and drawn on a new 15x12 figure.

    Args:
        inp: 3-D ``torch.Tensor`` in (channels, height, width) order, e.g. the
            output of ``torchvision.utils.make_grid``. It may require grad or
            live on a non-CPU device.
        title: Optional figure title; nothing is set when ``None``.
        mean: Per-channel mean used to undo the normalization. The default
            equals the ImageNet statistics commonly used with torchvision
            (presumably what the dataset transform applies -- confirm).
        std: Per-channel standard deviation used to undo the normalization;
            same caveat as ``mean``.

    Returns:
        None. The image is shown as a side effect.
    """
    # detach()/cpu() are no-ops for a plain CPU tensor, but without them
    # .numpy() raises for tensors that require grad or sit on an accelerator.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    img = np.asarray(std) * img + np.asarray(mean)
    # De-normalization can overshoot slightly; matplotlib expects floats in [0, 1].
    img = np.clip(img, 0, 1)
    plt.figure(figsize=(15, 12))
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    # Run the event loop briefly so the figure gets drawn.
    plt.pause(0.001)


# Take one batch (images plus landmark annotations and face rectangles)
# from the training loader.
image, real_landmarks, resized_landmarks, rect = next(iter(train_loader))

# TODO: overlay `rect` and `real_landmarks` on the images; an earlier cv2-based
# sketch for this was left disabled and has been dropped from the cell.

# Tile the images of the batch into a single grid and display it.
# `torchvision` must be imported at the top of the notebook for this call.
out = torchvision.utils.make_grid(image)

imshow(out)

Обучение модели (пока что возьмём O-Net)

In [None]:
# Build the O-Net model and move it to the configured device.
model = ONet()
model = model.to(DEVICE)

num_epochs = 500

# The cost function is mean squared error.
criterion = torch.nn.MSELoss()

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.001,
    # weight_decay=2e-05  # not passed, so the optimizer default applies
)

# Cosine annealing of the learning rate with T_max equal to the epoch count
# (assumes `train` steps the scheduler once per epoch -- TODO confirm).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# Run identifier; use '300W_dataset_model_O-Net' when training on 300W.
NAME = 'Menpo_dataset_model_O-Net'

# Create the checkpoint directory *before* the long training run so that a
# missing or unwritable location fails immediately, not after all epochs.
CHECKPOINT_PATH = f'/model_weights/checkpoints_{NAME}.pth'
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

(
    train_losses,
    val_losses,
    common_train_auc_0_08,
    common_val_auc_0_08,
    common_train_RMSE,
    common_val_RMSE,
) = train(
    model, optimizer, scheduler, criterion, train_loader, val_loader, num_epochs, NAME
)

# Everything worth restoring later: model/optimizer/scheduler state plus the
# last recorded value of each tracked metric.
state = {
    'state_dict': model.state_dict(),
    'optimizer' : optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'loss': train_losses[-1],
    'common_train_auc_0_08': common_train_auc_0_08[-1],
    'common_val_auc_0_08': common_val_auc_0_08[-1],
    'common_train_RMSE': common_train_RMSE[-1],
    'common_val_RMSE': common_val_RMSE[-1],
}
torch.save(state, CHECKPOINT_PATH)


Получим результаты для тестовой части

In [None]:
# Evaluate the trained model on the current test loader.
RMSE, test_auc_0_08, pred_landmarks_model = model_inference(model, test_loader, tqdm_desc='model_inference')

# Where the metrics and predictions are stored.
# NOTE(review): the file name says 300W, while the loaders created near the top
# of this notebook come from the Menpo directories -- confirm which dataset
# `test_loader` holds when this cell runs.
SAVE_PATH = '/results/predictions_300W_ONet.pkl'

data_result = {'RMSE': RMSE, 'test_auc_0_08': test_auc_0_08, 'pred_landmarks_model': pred_landmarks_model}

# Make sure the target directory exists before writing.
os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)

# Mode 'wb' truncates an existing file, so it does not need to be removed first.
with open(SAVE_PATH, 'wb') as fp:
    pickle.dump(data_result, fp)


Построим график CED для test_300W

In [None]:
# The in-memory results are reused here; to reload them from disk instead, run:
# with open(SAVE_PATH, 'rb') as fp:
#     data_result = pickle.load(fp)
RMSE_300W_test_ONet, auc_0_08_300W_test_ONet = (
    data_result['RMSE'],
    data_result['test_auc_0_08'],
)

# Draw the CED curve with its AUC for the stored test metrics.
plot_ced_auc_test(
    RMSE_300W_test_ONet,
    auc_0_08_300W_test_ONet,
    image_data='300W',
)

Результаты модели DLIB на тесте.

In [None]:
# Disabled DLIB baseline evaluation, kept for reference.
# NOTE(review): the pickle path below is an absolute path inside a user's home
# directory -- make it configurable before re-enabling this cell.
# ced_auc_0_08_300W_dlib, RMSE_300W_test_dlib, pred_landmarks_300W_dlib  = test_dlib(datasets_path = '/home/ann/projects/vision_lab/df_300W_test.pickle')

In [None]:
# Rebuild the splits and loaders from the 300W directories.
# Note: this rebinds the same names that were used for the Menpo data.
train_df, val_df, test_df = get_data(
    IMAGE_DIRS_300W_TRAIN,
    IMAGE_DIRS_300W_TEST,
    use_val_dataset=True,
)
train_loader, val_loader, test_loader = get_data_loaders(
    train_df,
    val_df,
    test_df,
)