In [1]:
# Fetch pyvips and its dependencies as conda package archives into the current
# directory; nothing is installed here (--download-only).
!conda config --add pkgs_dirs ./ # Set the location where conda package will be downloaded
!conda install --download-only -y pyvips # Download pyvips and dependencies

Collecting package metadata (current_repodata.json): - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - 

In [2]:
# Install every conda package archive (*.tar.bz2) found in the current directory
# (presumably the ones downloaded by the previous cell - verify).
!conda install *.tar.bz2 


Downloading and Extracting Packages
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
###########################################################

In [3]:
# Exploratory cell, kept disabled: counts how many training images each
# center_id contributes (result is stored in `res`).
# import pandas as pd
# from collections import Counter


# data = pd.read_csv('/kaggle/input/mayo-clinic-strip-ai/train.csv')
# res = dict(Counter(data['center_id'].tolist()))
# print(res)

In [4]:
# Exploratory cell, kept disabled: prints the total number of images for each
# group of centers, using the per-center counts `res` from the previous cell.
# The same grouping is used as the folds of the cross-validation loop below.
# groups = [(11,), (4,), (7,), (1, 5,), (10, 3), (6, 2, 8, 9,)]
# for gr in groups:
#     print(gr, sum(res[i] for i in gr))

In [5]:
# Image ids that DataPreparation._filter_bad_images drops from both the train
# and the "other" metadata.
BAD_IMAGE_IDS = [
    '5adc4c_0', '7b9aaa_0', 'bb06a5_0', 'e26a04_0', '280c26_0',
    '4ae44b_0', '53e66f_0', '7c2c2f_0', '74a450_1',
]

# --- Crop extraction (consumed by DataPreparation) ---
SCALE_FACTOR = 24          # slides are resized by 1 / SCALE_FACTOR before cropping
BLOCK_SIZE = 28            # side of one texture block, in pixels of the resized slide
BLOCKS_PER_CROP = 8        # a crop spans BLOCKS_PER_CROP x BLOCKS_PER_CROP blocks
CROP_SIZE = BLOCK_SIZE * BLOCKS_PER_CROP   # crop side in pixels (28 * 8 = 224)
BLOCK_THR = 90             # a block counts as textured when its mean squared row difference exceeds this
CROP_THR = 0.6             # starting fraction of textured blocks a crop must exceed (relaxed in 0.1 steps)
IMAGES_PER_SAMPLE = 4      # minimum number of crops an image must yield at a given threshold
MAX_CROPS_PER_IMAGE = 20   # at most this many crops are sampled per image

# --- Training ---
EPOCHS_NUM = 5             # epochs per training run

In [6]:
import gc
import os
from time import time
from typing import List, Tuple

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

import pyvips
import cv2


class DataPreparation:
    """Loads competition metadata and cuts resized slides into textured crops.

    The metadata is stored as lists of ``(image_id, label, center_id)`` tuples
    in ``self.train`` and ``self.other``. ``prepare_crops`` reads each slide,
    resizes it by ``1 / SCALE_FACTOR``, finds CROP_SIZE x CROP_SIZE windows that
    contain enough "textured" blocks and returns them as PIL images.
    """

    def __init__(self, visualize: bool = False, seed: int = 42):
        """Read train.csv / other.csv and drop the ids listed in BAD_IMAGE_IDS.

        Args:
            visualize: when True, block and crop outlines are drawn (in place)
                on the image arrays while crop positions are generated.
            seed: seed passed to ``np.random.seed`` at the start of
                ``prepare_crops``.
        """
        self.visualize = visualize
        self.seed = seed

        train_metadata = pd.read_csv('/kaggle/input/mayo-clinic-strip-ai/train.csv')
        train_metadata = list(zip(
            train_metadata['image_id'].tolist(),
            train_metadata['label'].tolist(),
            train_metadata['center_id'].tolist(),
        ))
        self.train = self._filter_bad_images(train_metadata)
        self.all_center_ids = sorted({center_id for _, _, center_id in self.train})

        # Only rows labelled 'Other' are kept; they are stored with label 'LAA'
        # and the placeholder center id -1.
        other_metadata = pd.read_csv('/kaggle/input/mayo-clinic-strip-ai/other.csv').query('label == \'Other\'')
        other_metadata = list(zip(
            other_metadata['image_id'].tolist(),
            ['LAA' for _ in range(other_metadata.shape[0])],
            [-1 for _ in range(other_metadata.shape[0])],
        ))
        self.other = self._filter_bad_images(other_metadata)

    @staticmethod
    def _filter_bad_images(data: List[Tuple]) -> List[Tuple]:
        """Return the ``(image_id, label, center_id)`` rows whose id is not in BAD_IMAGE_IDS."""
        return [
            (image_id, label, center_id)
            for image_id, label, center_id in data
            if image_id not in BAD_IMAGE_IDS
        ]

    @staticmethod
    def _add_rect_to_numpy(image: np.ndarray, x: int, y: int, size: int, thickness: int) -> None:
        """Draw a black square outline on ``image`` in place.

        ``x`` indexes axis 0 (rows) and ``y`` axis 1 (columns); the square has
        side ``size`` and its edges are ``thickness`` pixels wide.
        """
        image[x:x + size, y:y + thickness] = (0, 0, 0)
        image[x:x + thickness, y:y + size] = (0, 0, 0)
        image[x:x + size, y + size:y + size + thickness] = (0, 0, 0)
        image[x + size:x + size + thickness, y:y + size] = (0, 0, 0)

    @staticmethod
    def _get_blocks_map(image: np.ndarray) -> np.ndarray:
        """Mark which BLOCK_SIZE x BLOCK_SIZE blocks of ``image`` are textured.

        For every pixel the squared difference to the pixel one row below is
        summed over the channel axis. A block is marked 1.0 when that quantity,
        summed over the block and divided by BLOCK_SIZE ** 2, exceeds BLOCK_THR.

        Args:
            image: array indexed as (row, column, channel).

        Returns:
            Float array of shape (ceil(H / BLOCK_SIZE), ceil(W / BLOCK_SIZE))
            holding 1.0 for textured blocks and 0.0 otherwise.
        """
        # Fixed-width integer images (e.g. uint8) wrap around silently on
        # subtraction and squaring, which corrupts every difference >= 16.
        # Widen them first so the squared differences are exact.
        if np.issubdtype(image.dtype, np.integer):
            image = image.astype(np.int32 if image.dtype.itemsize == 1 else np.int64)

        pixels_diff = np.sum((image[:-1, :, :] - image[1:, :, :]) ** 2, axis=2)
        # 2-D prefix sums: any block sum becomes four look-ups below.
        pixels_diff = np.cumsum(np.cumsum(pixels_diff, axis=0), axis=1)
        blocks_map = np.zeros((
            (image.shape[0] + BLOCK_SIZE - 1) // BLOCK_SIZE,
            (image.shape[1] + BLOCK_SIZE - 1) // BLOCK_SIZE,
        ))
        for x in range(0, pixels_diff.shape[0], BLOCK_SIZE):
            for y in range(0, pixels_diff.shape[1], BLOCK_SIZE):
                nx = min(x + BLOCK_SIZE, pixels_diff.shape[0])
                ny = min(y + BLOCK_SIZE, pixels_diff.shape[1])
                # Inclusion-exclusion on the prefix-sum table.
                block_sum = int(pixels_diff[nx - 1, ny - 1])
                if x:
                    block_sum -= int(pixels_diff[x - 1, ny - 1])
                if y:
                    block_sum -= int(pixels_diff[nx - 1, y - 1])
                if x and y:
                    block_sum += int(pixels_diff[x - 1, y - 1])
                # Partial edge blocks are still normalised by the full block area.
                blocks_map[x // BLOCK_SIZE][y // BLOCK_SIZE] = \
                    (block_sum / BLOCK_SIZE / BLOCK_SIZE) > BLOCK_THR
        return blocks_map

    def _generate_crops_positions(
            self,
            image: np.ndarray,
            crop_thr: float,
    ) -> List[Tuple[int, int]]:
        """Return the top-left corners of all crops with enough textured blocks.

        Candidate crops are CROP_SIZE x CROP_SIZE windows placed on the block
        grid (stride BLOCK_SIZE). A window is kept when the number of textured
        blocks it covers exceeds ``BLOCKS_PER_CROP ** 2 * crop_thr``.

        When ``self.visualize`` is set, block and crop outlines are drawn onto
        ``image`` in place, so the array is modified for the caller as well.
        """
        blocks_map = self._get_blocks_map(image)

        if self.visualize:
            for i in range(blocks_map.shape[0]):
                for j in range(blocks_map.shape[1]):
                    if blocks_map[i][j]:
                        self._add_rect_to_numpy(
                            image,
                            i * BLOCK_SIZE,
                            j * BLOCK_SIZE,
                            BLOCK_SIZE,
                            1,
                        )

        good_crops_starts = []
        for x in range(0, image.shape[0] - CROP_SIZE + 1, BLOCK_SIZE):
            for y in range(0, image.shape[1] - CROP_SIZE + 1, BLOCK_SIZE):
                _x, _y = x // BLOCK_SIZE, y // BLOCK_SIZE
                crop_sum = blocks_map[_x:_x + BLOCKS_PER_CROP, _y:_y + BLOCKS_PER_CROP].sum()
                if crop_sum > BLOCKS_PER_CROP * BLOCKS_PER_CROP * crop_thr:
                    good_crops_starts.append((x, y))

        if self.visualize:
            for x, y in good_crops_starts:
                self._add_rect_to_numpy(image, x, y, CROP_SIZE, 1)

        return good_crops_starts

    @staticmethod
    def _process_crop(crop: np.ndarray) -> np.ndarray:
        """Per-crop processing hook; currently returns the crop unchanged."""
        return crop

    def _create_crops(
        self,
        image: np.ndarray,
        crops_starts: List[Tuple[int]],
    ) -> List['Image.Image']:
        """Cut a CROP_SIZE x CROP_SIZE PIL image out of ``image`` for every (x, y) start."""
        return [
            Image.fromarray(
                self._process_crop(
                    image[x:x + CROP_SIZE, y:y + CROP_SIZE],
                )
            )
            for x, y in crops_starts
        ]

    @staticmethod
    def _get_unique_crops(crop_starts: List[Tuple[int, int]], order) -> List[Tuple[int, int]]:
        """Greedily drop crops that overlap an already kept crop by more than half.

        Args:
            crop_starts: top-left corners of candidate crops.
            order: sort key deciding in which order the candidates are visited;
                earlier candidates win over later, heavily overlapping ones.

        Returns:
            The kept corners, in visiting order.
        """
        def inter_size_1d(a: int, b: int, c: int, d: int) -> int:
            # Length of the intersection of the intervals [a, b) and [c, d).
            return max(0, min(b, d) - max(a, c))

        def inter_size_2d(crop_start_1: Tuple[int, int], crop_start_2: Tuple[int, int]) -> int:
            # Intersection area of two CROP_SIZE squares given by their corners.
            return inter_size_1d(
                crop_start_1[0], crop_start_1[0] + CROP_SIZE,
                crop_start_2[0], crop_start_2[0] + CROP_SIZE,
            ) * inter_size_1d(
                crop_start_1[1], crop_start_1[1] + CROP_SIZE,
                crop_start_2[1], crop_start_2[1] + CROP_SIZE,
            )

        crop_starts_sorted = sorted(crop_starts, key=order)
        final_crop_starts = []
        for crop_start in crop_starts_sorted:
            if any(
                    inter_size_2d(crop_start, crop_start_prev) > CROP_SIZE * CROP_SIZE // 2
                    for crop_start_prev in final_crop_starts
            ):
                continue
            final_crop_starts.append(crop_start)
        return final_crop_starts

    @staticmethod
    def _read_and_resize_image(image_id: str, base_image_path: str) -> np.ndarray:
        """Load ``<base_image_path>/<image_id>.tif`` with pyvips and resize it by 1 / SCALE_FACTOR."""
        image_path = os.path.join(base_image_path, f'{image_id}.tif')
        image = pyvips.Image.new_from_file(image_path, access='sequential')
        return image.resize(1.0 / SCALE_FACTOR).numpy()

    def prepare_crops(
            self,
            image_ids: List[str],
            base_image_path: str,
    ) -> Tuple[List[List['Image.Image']], List[np.ndarray]]:
        """Extract up to MAX_CROPS_PER_IMAGE crops for every image id.

        For each image the crop threshold starts at CROP_THR and is lowered in
        steps of 0.1 until at least IMAGES_PER_SAMPLE non-redundant crops are
        found. The kept crops are then sampled without replacement using
        NumPy's global RNG, which is seeded with ``self.seed`` on entry.

        Args:
            image_ids: ids of the slides to process (file stem of the .tif).
            base_image_path: directory containing the .tif files.

        Returns:
            ``(image_crops, image_crops_indices)``: two lists aligned with
            ``image_ids``. The first holds the PIL crops of each image, the
            second the sampled (x, y) start positions as an array. Both entries
            are empty lists for images where no threshold produced enough crops.
        """
        np.random.seed(self.seed)
        image_crops = []
        image_crops_indices = []
        for image_id in tqdm(image_ids):
            start_time = time()
            image = self._read_and_resize_image(image_id, base_image_path)
            gc.collect()
            print(f'Rescaling done in {time() - start_time} seconds. Image shape is {image.shape}')
            found_flag = False
            for crop_thr in np.arange(CROP_THR, -0.1, -0.1):
                good_crops_starts = self._generate_crops_positions(image, crop_thr)
                if len(good_crops_starts) < IMAGES_PER_SAMPLE:
                    print('Bad image', image_id, 'crop_thr', crop_thr, 'only', len(good_crops_starts))
                    continue

                # De-duplicate heavily overlapping crops twice, scanning from
                # opposite corners, and keep the union of both passes.
                good_crops_starts_unique = []
                for order in [
                    lambda x: (x[0], x[1]),
                    lambda x: (-x[0], -x[1]),
                ]:
                    good_crops_starts_unique.extend(self._get_unique_crops(good_crops_starts, order))
                good_crops_starts_unique = list(set(good_crops_starts_unique))

                if len(good_crops_starts_unique) < IMAGES_PER_SAMPLE:
                    print('Bad image', image_id, 'crop_thr', crop_thr, 'only', len(good_crops_starts_unique))
                    continue

                good_crops_starts_sample_ids = np.random.choice(
                    list(range(len(good_crops_starts_unique))),
                    min(len(good_crops_starts_unique), MAX_CROPS_PER_IMAGE),
                    replace=False,
                )
                good_crops_starts_sample = np.array(good_crops_starts_unique)[good_crops_starts_sample_ids]
                image_crops_indices.append(good_crops_starts_sample)
                image_crops.append(self._create_crops(image, good_crops_starts_sample))
                found_flag = True
                break
            if not found_flag:
                # Keep the outputs aligned with image_ids even when nothing was found.
                image_crops_indices.append([])
                image_crops.append([])
                print('No crops was found')
            print(f'Done {image_id} in {time() - start_time} seconds')
        gc.collect()
        return image_crops, image_crops_indices

    def process_train(
            self
    ) -> Tuple[List[List['Image.Image']], List[np.ndarray]]:
        """Run ``prepare_crops`` over the images in ``self.train``."""
        return self.prepare_crops(
            [image_id for image_id, _, _ in self.train],
            '/kaggle/input/mayo-clinic-strip-ai/train/',
        )

    def process_other(
            self
    ) -> Tuple[List[List['Image.Image']], List[np.ndarray]]:
        """Run ``prepare_crops`` over the images in ``self.other``."""
        return self.prepare_crops(
            [image_id for image_id, _, _ in self.other],
            '/kaggle/input/mayo-clinic-strip-ai/other/',
        )

In [7]:
import random
from collections import defaultdict
from typing import List

import numpy as np
import torch
from PIL import Image


class ClotImageDataset(torch.utils.data.Dataset):
    """Dataset that serves one crop per item, cycling through each image's crops.

    Labels are encoded as 1.0 for 'CE' and 0.0 for anything else.

    * Train mode (``is_test=False``): images without crops are skipped and each
      label group is oversampled (by cycling through its shuffled members) to
      four times the size of the largest group.
    * Test mode (``is_test=True``): every image appears 20 times, and the
      NumPy / ``random`` / torch RNGs are reseeded per item so that random
      transformations are repeatable.

    Each item is ``(transformed crop, label tensor, image id)``.
    """

    def __init__(
            self,
            image_ids: List[str],
            labels: List[str],
            image_crops: List[List[np.ndarray]],
            seed: int,
            is_test: bool,
            transformations,
    ):
        self.image_ids = image_ids
        self.labels = [float(label == 'CE') for label in labels]
        self.image_crops = image_crops
        self.seed = seed
        self.is_test = is_test
        self.transformations = transformations

        if self.is_test:
            # 20 passes over all images, in order.
            self.sample_ids = list(range(len(self.image_ids))) * 20
        else:
            np.random.seed(self.seed)

            # Positions of images that actually have crops, grouped by label.
            members_by_label = {}
            for position, (target, crops) in enumerate(zip(self.labels, self.image_crops)):
                if len(crops) == 0:
                    continue
                members_by_label.setdefault(target, []).append(position)

            quota = 4 * max(map(len, members_by_label.values()))

            self.sample_ids = []
            for members in members_by_label.values():
                np.random.shuffle(members)
                # Every label contributes exactly `quota` entries by cycling
                # through its shuffled members.
                full_rounds, remainder = divmod(quota, len(members))
                self.sample_ids.extend(members * full_rounds)
                self.sample_ids.extend(members[:remainder])

        # For every dataset item, which crop of its image to serve: a
        # per-image round-robin counter.
        next_crop = {}
        self.image_index_ids = []
        for sample_id in self.sample_ids:
            current = next_crop.get(sample_id, 0)
            self.image_index_ids.append(current)
            crops_total = len(self.image_crops[sample_id])
            if crops_total:
                next_crop[sample_id] = (current + 1) % crops_total

    def __len__(self):
        return len(self.sample_ids)

    def __getitem__(self, idx):
        if self.is_test:
            # Make random transformations reproducible per item.
            for reseed in (np.random.seed, random.seed, torch.manual_seed):
                reseed(self.seed + idx)

        sample_id = self.sample_ids[idx]
        crop_position = self.image_index_ids[idx]
        crops = self.image_crops[sample_id]

        if len(crops) == 0:
            # Image without crops: serve an all-black placeholder.
            picture = Image.fromarray(np.zeros((224, 224, 3)).astype(np.uint8))
        else:
            picture = crops[crop_position]

        return (
            self.transformations(picture),
            torch.tensor(self.labels[sample_id]),
            self.image_ids[sample_id],
        )


def get_loader(
        image_ids: List[str],
        labels: List[str],
        image_crops: List[List[np.ndarray]],
        seed: int,
        is_test: bool,
        transformations,
        shuffle: bool,
        batch_size: int,
        num_workers: int
):
    """Build a ClotImageDataset from the arguments and wrap it in a DataLoader.

    The first six arguments are forwarded to ``ClotImageDataset``; ``shuffle``,
    ``batch_size`` and ``num_workers`` are forwarded to
    ``torch.utils.data.DataLoader``.
    """
    return torch.utils.data.DataLoader(
        ClotImageDataset(
            image_ids=image_ids,
            labels=labels,
            image_crops=image_crops,
            seed=seed,
            is_test=is_test,
            transformations=transformations,
        ),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

In [8]:
from collections import defaultdict

import numpy as np


def get_target_metric(y_true, y_pred, image_ids):
    """Compute the weighted log loss on per-patient averaged predictions.

    The patient id is the part of an image id before the first underscore.
    Targets and predictions are averaged per patient, the averaged prediction
    ``p`` is expanded to the two-class row ``[1 - p, p]`` and the result is
    scored with ``_weighted_mc_log_loss``.
    """
    patients = [image_id.partition('_')[0] for image_id in image_ids]

    # patient -> (targets, predictions), in order of first appearance.
    per_patient = {}
    for target, score, patient in zip(y_true, y_pred, patients):
        targets, scores = per_patient.setdefault(patient, ([], []))
        targets.append(target)
        scores.append(score)

    mean_targets = [np.mean(targets) for targets, _ in per_patient.values()]
    mean_scores = [np.mean(scores).tolist() for _, scores in per_patient.values()]

    class_probabilities = np.array([[1 - p, p] for p in mean_scores])
    return _weighted_mc_log_loss(mean_targets, class_probabilities)


def _weighted_mc_log_loss(y_true, y_pred, epsilon=1e-15):
    class_cnt = [sum(int(val == cl) for val in y_true) for cl in range(2)]
    w = [0.5 for _ in range(2)]
    return -sum(
        w[cl] * sum(
            (y == cl) / class_cnt[cl] * np.log(max(min(y_hat, 1 - epsilon), epsilon))
            for y, y_hat in zip(y_true, y_pred[:, cl])
        )
        for cl in range(2)
    ) / sum(w[cl] for cl in range(2))

In [9]:
import torch
import torch.nn as nn
# import pretrainedmodels as pm
from torchvision import models


class ClotModelMIL(nn.Module):
    """Multi-crop classifier: a shared ResNet-18 encoder with one pooled head.

    All crops of a sample are encoded independently, their feature maps are
    stacked along the height axis and a single global-average-pool + linear +
    sigmoid head produces one probability per sample.
    """

    def __init__(self, num_crops=None):
        """
        Args:
            num_crops: number of crops per sample. When None, it is taken from
                the second dimension of the input in ``forward``.
        """
        super().__init__()
        self.num_crops = num_crops

        base_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Drop the last two children of the torchvision model and keep the rest
        # as the encoder.
        self.model = nn.Sequential(*list(base_model.children())[:-2])
        in_features_cnt = list(base_model.children())[-1].in_features
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Flatten(),
            nn.Linear(in_features_cnt, 1),
            nn.Sigmoid(),
        )

    def freeze_encoder(self, flag):
        """Disable (``flag=True``) or enable (``flag=False``) gradients of the encoder."""
        for param in self.model.parameters():
            param.requires_grad = not flag

    def forward(self, x):
        """Return one probability per sample.

        Args:
            x: tensor of shape (batch, crops, channels, W, H).

        Returns:
            Tensor of shape (batch, 1).
        """
        # x: bs x N x C x W x W
        bs, crops_in_input, ch, w, h = x.shape
        # The default num_crops=None previously made `bs * self.num_crops`
        # fail; fall back to the crop count carried by the input itself.
        num_crops = self.num_crops if self.num_crops is not None else crops_in_input
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(bs * num_crops, ch, w, h)  # x: N bs x C x W x W
        x = self.model(x)  # x: N bs x C' x W' x W'

        # Concat and pool
        _, ch2, w2, h2 = x.shape
        x = x \
            .view(-1, num_crops, ch2, w2, h2) \
            .permute(0, 2, 1, 3, 4) \
            .contiguous() \
            .view(bs, ch2, num_crops * w2, h2)  # x: bs x C' x N W'' x W''
        return self.head(x)

    def save(self, model_path):
        """Write the model's state dict to ``model_path``."""
        weights = self.state_dict()
        torch.save(weights, model_path)

    def load(self, model_path):
        """Load a state dict written by ``save`` (mapped to CPU) into this model."""
        weights = torch.load(model_path, map_location='cpu')
        self.load_state_dict(weights)


class ClotModelSingle(nn.Module):
    """Single-crop classifier: an ImageNet-pretrained encoder with a sigmoid head.

    Supported ``encoder_model`` values: ``'effnet_b0'``, ``'resnet18'`` and
    ``'regnet_x_1_6gf'``.
    """

    def __init__(self, encoder_model):
        """
        Args:
            encoder_model: name of the torchvision backbone to use.

        Raises:
            ValueError: if ``encoder_model`` is not one of the supported names.
        """
        super().__init__()

        if encoder_model == 'effnet_b0':
            # Same `weights=` API as the other branches instead of the
            # deprecated `pretrained=True`; IMAGENET1K_V1 is the checkpoint
            # that `pretrained=True` resolves to.
            base_model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
            self.model = base_model.features
            in_features_cnt = base_model.classifier[1].in_features
        elif encoder_model == 'resnet18':
            base_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
            self.model = nn.Sequential(*list(base_model.children())[:-2])
            in_features_cnt = list(base_model.children())[-1].in_features
        elif encoder_model == 'regnet_x_1_6gf':
            base_model = models.regnet_x_1_6gf(weights=models.RegNet_X_1_6GF_Weights.IMAGENET1K_V2)
            self.model = nn.Sequential(base_model.stem, base_model.trunk_output)
            in_features_cnt = base_model.fc.in_features
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError('Incorrect encoder name')

        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Flatten(),
            nn.Linear(in_features_cnt, 1),
            nn.Sigmoid(),
        )

    def freeze_encoder(self, flag):
        """Disable (``flag=True``) or enable (``flag=False``) gradients of the encoder."""
        for param in self.model.parameters():
            param.requires_grad = not flag

    def forward(self, x):
        """Encode ``x`` and return the head output (one sigmoid value per sample)."""
        return self.head(self.model(x))

    def save(self, model_path):
        """Write the model's state dict to ``model_path``."""
        weights = self.state_dict()
        torch.save(weights, model_path)

    def load(self, model_path):
        """Load a state dict written by ``save`` (mapped to CPU) into this model."""
        weights = torch.load(model_path, map_location='cpu')
        self.load_state_dict(weights)

In [10]:
import os


# Directory the cross-validation loop saves its model checkpoints to.
os.makedirs('/kaggle/working/models/', exist_ok=True)

In [11]:
from __future__ import print_function, division

import os
import pickle
import sys
from collections import Counter

import cv2
import numpy as np
import ssl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from PIL import Image
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
from torchvision import transforms


# NOTE(review): this replaces the default HTTPS context factory with the
# unverified one, i.e. TLS certificates are no longer checked for any download
# made through it in this process. Presumably a workaround for fetching the
# pretrained weights - confirm it is still needed.
ssl._create_default_https_context = ssl._create_unverified_context
# Enables cuDNN's benchmark mode for convolutions.
cudnn.benchmark = True


# Pickle files holding [image_crops, image_crops_indices] as produced by
# DataPreparation.process_train() and DataPreparation.process_other().
DUMPED_DATALOADER_PATH = '/kaggle/input/track-4-dataprep/data_loaders.pkl'
DUMPED_DATALOADER_OTHER_PATH = '/kaggle/input/track-4-dataprep/data_loaders_other.pkl'


def get_sub_data(data, image_crops, image_crops_indices, sample_ids):
    """Select the entries at ``sample_ids`` from the metadata and crop lists.

    Args:
        data: list of ``(image_id, label, center_id)`` tuples.
        image_crops: per-image crops, aligned with ``data``.
        image_crops_indices: per-image crop positions, aligned with ``data``.
        sample_ids: positions to select.

    Returns:
        A 5-tuple of lists: image ids, labels, center ids, crops and crop
        positions of the selected entries, in the order of ``sample_ids``.
    """
    # One list per metadata column: 0 = image id, 1 = label, 2 = center id.
    selected_ids, selected_labels, selected_centers = (
        [data[i][column] for i in sample_ids] for column in range(3)
    )
    selected_crops = [image_crops[i] for i in sample_ids]
    selected_positions = [image_crops_indices[i] for i in sample_ids]
    return selected_ids, selected_labels, selected_centers, selected_crops, selected_positions


# Reads train.csv / other.csv; provides the (image_id, label, center_id) lists
# `data_prep.train` and `data_prep.other` used below.
data_prep = DataPreparation()

# Disabled: the expensive crop extraction that produced the pickle files loaded
# below. Re-enable to regenerate them.
# image_crops, image_crops_indices = data_prep.process_train()
# with open(DUMPED_DATALOADER_PATH, 'wb') as file:
#     pickle.dump([image_crops, image_crops_indices], file)

# image_crops_other, image_crops_indices_other = data_prep.process_other()
# with open(DUMPED_DATALOADER_OTHER_PATH, 'wb') as file:
#     pickle.dump([image_crops_other, image_crops_indices_other], file)
    
# Load the cached crops. NOTE(review): unpickling can execute arbitrary code,
# so these files must come from a trusted source (a dump made with the code
# above).
with open(DUMPED_DATALOADER_PATH, 'rb') as file:
    image_crops, image_crops_indices = pickle.load(file)
with open(DUMPED_DATALOADER_OTHER_PATH, 'rb') as file:
    image_crops_other, image_crops_indices_other = pickle.load(file)    

In [12]:
# Per-split preprocessing. The 'test' pipeline keeps the random sharpness and
# colour jitter steps (only the flips are dropped); ClotImageDataset reseeds
# the RNGs per item in test mode, so those random steps are repeatable.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        # transforms.RandomResizedCrop((224, 224), scale=(0.5, 1.0), ratio=(1.0, 1.0)),
        transforms.RandomAdjustSharpness(sharpness_factor=2, p=1.0),
        transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.5),
        transforms.ColorJitter(brightness=0.2, saturation=0.5, hue=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        # transforms.RandomResizedCrop((224, 224), scale=(0.5, 1.0), ratio=(1.0, 1.0)),
        transforms.RandomAdjustSharpness(sharpness_factor=2, p=1.0),
        transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.5),
        transforms.ColorJitter(brightness=0.2, saturation=0.5, hue=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Metadata of the train images followed by the "other" images; the crop lists
# must follow the same order because everything below indexes them by position.
train_data = data_prep.train + data_prep.other
# The crop lists are extended in place, so re-running this cell used to append
# the "other" crops a second time and silently misalign them with train_data.
# Only extend while the lists are still shorter than the combined metadata.
if len(image_crops) < len(train_data):
    image_crops += image_crops_other
    image_crops_indices += image_crops_indices_other
assert len(image_crops) == len(train_data) == len(image_crops_indices), (
    'crop lists are not aligned with train_data; reload the pickled crops'
)

all_metrics = []
all_y, all_y_hat, all_image_ids = [], [], []
#for test_center_id in {center_id for _, _, center_id in train_data if center_id != -1}:
# Leave-centers-out cross-validation: each group of center ids is held out
# once. For every group the model is trained from scratch three times and the
# predictions of the best epoch over all three runs (lowest "fixed" target
# metric on the held-out group) are kept.
# NOTE(review): the checkpoint/epoch is selected on the same held-out data the
# metric is reported on, so the reported numbers are optimistic.
for test_centers_group in [(11,), (4,), (7,), (1, 5,), (10, 3), (6, 2, 8, 9,)]:
    best_validation_metric = None
    best_y, best_y_hat, best_image_ids = [], [], []
    for iteration in range(3):
        for _ in range(3):
            print('-' * 80)
        test_centers_group_str = '.'.join(map(str, test_centers_group))
        print(f'CV with {test_centers_group_str} as test')
        # "Other" images carry center_id -1, so they always land in the train split.
        train_sample_ids = [i for i, (_, _, center_id) in enumerate(train_data) if center_id not in test_centers_group]
        test_sample_ids = [i for i, (_, _, center_id) in enumerate(train_data) if center_id in test_centers_group]

        train_image_ids, train_labels, train_center_ids, train_crops, train_crop_indices = get_sub_data(
            train_data,
            image_crops,
            image_crops_indices,
            train_sample_ids
        )
        test_image_ids, test_labels, test_center_ids, test_crops, test_crop_indices = get_sub_data(
            train_data,
            image_crops,
            image_crops_indices,
            test_sample_ids
        )
        print(f'Train/Test sizes: {len(train_labels)}/{len(test_labels)}')
        print('Train/Test label distribution:')
        print({key: value / len(train_labels) for key, value in dict(Counter(train_labels)).items()})
        print({key: value / len(test_labels) for key, value in dict(Counter(test_labels)).items()})

        dataloaders = {
            'train': get_loader(
                train_image_ids,
                train_labels,
                train_crops,
                seed=42,
                is_test=False,
                transformations=data_transforms['train'],
                shuffle=True,
                batch_size=256,
                num_workers=2,
            ),
            'test': get_loader(
                test_image_ids,
                test_labels,
                test_crops,
                seed=42,
                is_test=True,
                transformations=data_transforms['test'],
                shuffle=False,
                batch_size=256,
                num_workers=2,
            ),
        }

        # Frozen encoder: only the classification head is optimised.
        model = ClotModelSingle(encoder_model='effnet_b0').to(device)
        model.freeze_encoder(True)
        criterion = nn.BCELoss()
        optimizer = optim.Adam(model.head.parameters(), lr=0.05, weight_decay=0.0001)

        train_loss, val_loss = [], []
        for epoch in range(EPOCHS_NUM):
            # Reseeds NumPy's global RNG only; torch's RNG is not reseeded here.
            np.random.seed(epoch)

            #if epoch == 10:
            #    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)

            print('*' * 80)
            print("epoch {}/{}".format(epoch + 1, EPOCHS_NUM))

            # ---------------- training pass ----------------
            model.train()
            running_loss, running_score = 0.0, 0.0
            y_hat, y, image_ids = [], [], []
            for image, label, image_id in tqdm(dataloaders['train']):
                image = image.to(device)
                label = label.to(device)
                optimizer.zero_grad()
                # squeeze(1) drops only the output dimension; a bare squeeze()
                # would also drop the batch dimension of a size-1 batch.
                y_pred = model(image).squeeze(1)
                loss = criterion(y_pred, label)
                running_loss += loss.item()
                loss.backward()
                optimizer.step()

                y_pred = y_pred.cpu().detach().numpy().tolist()
                label = label.cpu().detach().numpy().tolist()
                y_hat.extend(y_pred)
                y.extend(label)
                image_ids.extend(image_id)

                # Number of correct predictions at the 0.5 threshold.
                running_score += sum(int(int(p > 0.5) == t) for p, t in zip(y_pred, label))

            print('Train:')
            print(Counter([int(p > 0.5) for p in y_hat]))
            print('ROC AUC metric:', roc_auc_score(y, y_hat))
            print('target metric:', get_target_metric(y, y_hat, image_ids))

            epoch_score = running_score / len(dataloaders['train'].dataset)
            epoch_loss = running_loss / len(dataloaders['train'])
            train_loss.append(epoch_loss)
            print("loss: {}, accuracy: {}".format(epoch_loss, epoch_score))

            # ---------------- evaluation pass ----------------
            with torch.no_grad():
                model.eval()
                running_loss, running_score = 0.0, 0.0
                y_hat, y, image_ids = [], [], []
                for image, label, image_id in tqdm(dataloaders['test']):
                    image = image.to(device)
                    label = label.to(device)
                    y_pred = model(image).squeeze(1)
                    loss = criterion(y_pred, label)
                    running_loss += loss.item()

                    y_pred = y_pred.cpu().detach().numpy().tolist()
                    label = label.cpu().detach().numpy().tolist()
                    y_hat.extend(y_pred)
                    y.extend(label)
                    image_ids.extend(image_id)

                    running_score += sum(int(int(p > 0.5) == t) for p, t in zip(y_pred, label))

                # Images without any crop are fed as black placeholders by the
                # dataset; replace their predictions with the neutral 0.5.
                bad_image_ids = {
                    image_id
                    for image_id, crops in zip(test_image_ids, test_crops)
                    if len(crops) == 0 
                }
                y_hat_fixed = [
                    0.5 if image_id in bad_image_ids else p
                    for p, image_id in zip(y_hat, image_ids)
                ]

                print('Validation:')
                print(Counter([int(p > 0.5) for p in y_hat_fixed]))
                print('ROC AUC metric:', roc_auc_score(y, y_hat_fixed))

                target_metric = get_target_metric(y, y_hat, image_ids)
                print('target metric:', target_metric)
                target_metric = get_target_metric(y, y_hat_fixed, image_ids)
                print('target metric fixed:', target_metric)            
                if best_validation_metric is None or target_metric < best_validation_metric:
                    best_validation_metric = target_metric
                    best_y, best_y_hat, best_image_ids = y, y_hat_fixed, image_ids
                    # NOTE(review): this pickles the whole module object (not
                    # just the state dict that ClotModelSingle.save writes).
                    torch.save(
                        model,
                        os.path.join(
                            '/kaggle/working/models',
                            f'center_id_{test_centers_group_str}_epoch_{epoch}_target_{round(target_metric, 3)}.h5',
                        ),
                    )

                epoch_score = running_score / len(dataloaders['test'].dataset)
                epoch_loss = running_loss / len(dataloaders['test'])
                val_loss.append(epoch_loss)
                print("loss: {}, accuracy: {}".format(epoch_loss, epoch_score))

    print(f'Best validation metric: {best_validation_metric}')
    all_metrics.append(best_validation_metric)
    all_y.extend(best_y)
    all_y_hat.extend(best_y_hat)
    all_image_ids.extend(best_image_ids)

# Per-fold best metrics and their mean.
print(all_metrics)
print(np.mean(all_metrics))

# Metric over the pooled best predictions of all folds.
final_metric = get_target_metric(all_y, all_y_hat, all_image_ids)

# Persist the pooled targets, predictions and image ids.
pooled_outputs = {
    '/kaggle/working/all_y.npy': all_y,
    '/kaggle/working/all_y_hat.npy': all_y_hat,
    '/kaggle/working/all_image_ids.npy': all_image_ids,
}
for output_path, output_values in pooled_outputs.items():
    np.save(output_path, output_values)

print('Full validation metric:', final_metric)

--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 11 as test
Train/Test sizes: 560/254
Train/Test label distribution:
{'CE': 0.6553571428571429, 'LAA': 0.34464285714285714}
{'CE': 0.6968503937007874, 'LAA': 0.3031496062992126}


Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-3dd342df.pth


  0%|          | 0.00/20.5M [00:00<?, ?B/s]

********************************************************************************
epoch 1/5


100%|██████████| 12/12 [00:24<00:00,  2.06s/it]


Train:
Counter({1: 1539, 0: 1309})
ROC AUC metric: 0.5870135557379119
target metric: 0.6651913418720463
loss: 0.9788656681776047, accuracy: 0.5635533707865169


100%|██████████| 20/20 [00:40<00:00,  2.05s/it]


Validation:
Counter({0: 4770, 1: 310})
ROC AUC metric: 0.5586383813926187
target metric: 1.201247539724809
target metric fixed: 1.2029961448937296
loss: 2.15089145898819, accuracy: 0.33937007874015745
********************************************************************************
epoch 2/5


100%|██████████| 12/12 [00:23<00:00,  1.96s/it]


Train:
Counter({0: 1448, 1: 1400})
ROC AUC metric: 0.6214463530488574
target metric: 0.6244971848514214
loss: 0.8448945184548696, accuracy: 0.5849719101123596


100%|██████████| 20/20 [00:39<00:00,  1.99s/it]


Validation:
Counter({0: 3525, 1: 1555})
ROC AUC metric: 0.5653267847971237
target metric: 0.7397692887468887
target metric fixed: 0.738399095036011
loss: 1.1209315001964568, accuracy: 0.4490157480314961
********************************************************************************
epoch 3/5


100%|██████████| 12/12 [00:23<00:00,  1.98s/it]


Train:
Counter({0: 1477, 1: 1371})
ROC AUC metric: 0.6540175048920591
target metric: 0.599741546611679
loss: 0.7511040419340134, accuracy: 0.6042837078651685


100%|██████████| 20/20 [00:40<00:00,  2.00s/it]


Validation:
Counter({1: 3195, 0: 1885})
ROC AUC metric: 0.5854558294812532
target metric: 0.6710026042326529
target metric fixed: 0.6722800301918153
loss: 0.7378010094165802, accuracy: 0.6045275590551181
********************************************************************************
epoch 4/5


100%|██████████| 12/12 [00:23<00:00,  1.95s/it]


Train:
Counter({0: 1427, 1: 1421})
ROC AUC metric: 0.6672324753818962
target metric: 0.5909698286739705
loss: 0.7235569705565771, accuracy: 0.6211376404494382


100%|██████████| 20/20 [00:39<00:00,  1.99s/it]


Validation:
Counter({0: 3342, 1: 1738})
ROC AUC metric: 0.5858364516839093
target metric: 0.7838443211811112
target metric fixed: 0.6854558641626494
loss: 1.281960552930832, accuracy: 0.481496062992126
********************************************************************************
epoch 5/5


100%|██████████| 12/12 [00:23<00:00,  1.95s/it]


Train:
Counter({1: 1446, 0: 1402})
ROC AUC metric: 0.6506433649476077
target metric: 0.5975244657995287
loss: 0.7622215896844864, accuracy: 0.6060393258426966


100%|██████████| 20/20 [00:39<00:00,  1.99s/it]


Validation:
Counter({1: 2682, 0: 2398})
ROC AUC metric: 0.594450069704307
target metric: 0.7499131824586767
target metric fixed: 0.6515247254402152
loss: 1.1916424214839936, accuracy: 0.5641732283464567
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 11 as test
Train/Test sizes: 560/254
Train/Test label distribution:
{'CE': 0.6553571428571429, 'LAA': 0.34464285714285714}
{'CE': 0.6968503937007874, 'LAA': 0.3031496062992126}
********************************************************************************
epoch 1/5


100%|██████████| 12/12 [00:22<00:00,  1.90s/it]


Train:
Counter({0: 1536, 1: 1312})
ROC AUC metric: 0.5472152742709254
target metric: 0.6909690593546947
loss: 1.1445354570945103, accuracy: 0.5308988764044944


100%|██████████| 20/20 [00:39<00:00,  1.96s/it]


Validation:
Counter({1: 2952, 0: 2128})
ROC AUC metric: 0.6186665382639958
target metric: 0.6567562418934779
target metric fixed: 0.6567746030061511
loss: 0.7906636893749237, accuracy: 0.6019685039370078
********************************************************************************
epoch 2/5


100%|██████████| 12/12 [00:23<00:00,  1.92s/it]


Train:
Counter({0: 1443, 1: 1405})
ROC AUC metric: 0.6433506462252241
target metric: 0.6175791898610257
loss: 0.7448299129803976, accuracy: 0.5979634831460674


100%|██████████| 20/20 [00:38<00:00,  1.94s/it]


Validation:
Counter({1: 4276, 0: 804})
ROC AUC metric: 0.5996011262748551
target metric: 0.7633474301121308
target metric fixed: 0.7587721892970604
loss: 0.6816166520118714, accuracy: 0.6665354330708662
********************************************************************************
epoch 3/5


100%|██████████| 12/12 [00:22<00:00,  1.92s/it]


Train:
Counter({1: 1550, 0: 1298})
ROC AUC metric: 0.6089548352480747
target metric: 0.6327634231996833
loss: 0.8180160820484161, accuracy: 0.5786516853932584


100%|██████████| 20/20 [00:38<00:00,  1.92s/it]


Validation:
Counter({0: 2638, 1: 2442})
ROC AUC metric: 0.605682001614205
target metric: 0.641959281094876
target metric fixed: 0.642027776187675
loss: 0.781318411231041, accuracy: 0.5590551181102362
********************************************************************************
epoch 4/5


100%|██████████| 12/12 [00:23<00:00,  1.93s/it]


Train:
Counter({1: 1481, 0: 1367})
ROC AUC metric: 0.6653604737406894
target metric: 0.5962187758771414
loss: 0.6928450167179108, accuracy: 0.6148174157303371


100%|██████████| 20/20 [00:38<00:00,  1.93s/it]


Validation:
Counter({1: 3242, 0: 1838})
ROC AUC metric: 0.5843651405092083
target metric: 0.6754216541030547
target metric fixed: 0.6774366168372405
loss: 0.7122344642877578, accuracy: 0.6090551181102363
********************************************************************************
epoch 5/5


100%|██████████| 12/12 [00:22<00:00,  1.91s/it]


Train:
Counter({1: 1454, 0: 1394})
ROC AUC metric: 0.6810416929680596
target metric: 0.5853998470110247
loss: 0.6730612317721049, accuracy: 0.6306179775280899


100%|██████████| 20/20 [00:38<00:00,  1.92s/it]


Validation:
Counter({0: 3208, 1: 1872})
ROC AUC metric: 0.5905135189669088
target metric: 0.7828931347868961
target metric fixed: 0.6845046777684347
loss: 1.3187997758388519, accuracy: 0.502755905511811
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 11 as test
Train/Test sizes: 560/254
Train/Test label distribution:
{'CE': 0.6553571428571429, 'LAA': 0.34464285714285714}
{'CE': 0.6968503937007874, 'LAA': 0.3031496062992126}
********************************************************************************
epoch 1/5


100%|██████████| 12/12 [00:23<00:00,  1.92s/it]


Train:
Counter({0: 1566, 1: 1282})
ROC AUC metric: 0.5427044210011363
target metric: 0.7131366388113037
loss: 1.2669840554396312, accuracy: 0.5273876404494382


100%|██████████| 20/20 [00:38<00:00,  1.91s/it]


Validation:
Counter({1: 4717, 0: 363})
ROC AUC metric: 0.5839495744368626
target metric: 1.0435307110662049
target metric fixed: 1.043303622934821
loss: 0.8934591799974442, accuracy: 0.6852362204724409
********************************************************************************
epoch 2/5


100%|██████████| 12/12 [00:22<00:00,  1.90s/it]


Train:
Counter({1: 1681, 0: 1167})
ROC AUC metric: 0.589617393637167
target metric: 0.6640563776665002
loss: 1.0098740508159, accuracy: 0.5677668539325843


100%|██████████| 20/20 [00:38<00:00,  1.92s/it]


Validation:
Counter({0: 4406, 1: 674})
ROC AUC metric: 0.567679580306699
target metric: 0.9117940827412848
target metric fixed: 0.908348439490772
loss: 1.4949710667133331, accuracy: 0.3799212598425197
********************************************************************************
epoch 3/5


100%|██████████| 12/12 [00:22<00:00,  1.89s/it]


Train:
Counter({0: 1552, 1: 1296})
ROC AUC metric: 0.6569692609045574
target metric: 0.5868872671227351
loss: 0.7669109056393305, accuracy: 0.6228932584269663


100%|██████████| 20/20 [00:38<00:00,  1.91s/it]


Validation:
Counter({1: 3518, 0: 1562})
ROC AUC metric: 0.5911363636363636
target metric: 0.6789815429344419
target metric fixed: 0.6799486028321804
loss: 0.7039013296365738, accuracy: 0.6248031496062992
********************************************************************************
epoch 4/5


100%|██████████| 12/12 [00:22<00:00,  1.88s/it]


Train:
Counter({0: 1506, 1: 1342})
ROC AUC metric: 0.656901452625931
target metric: 0.5915047123115773
loss: 0.7147474040587743, accuracy: 0.6179775280898876


100%|██████████| 20/20 [00:38<00:00,  1.92s/it]


Validation:
Counter({1: 3932, 0: 1148})
ROC AUC metric: 0.6015857729840781
target metric: 0.7030086775127492
target metric fixed: 0.7050236402469351
loss: 0.680913308262825, accuracy: 0.6535433070866141
********************************************************************************
epoch 5/5


100%|██████████| 12/12 [00:22<00:00,  1.86s/it]


Train:
Counter({1: 1581, 0: 1267})
ROC AUC metric: 0.6566785483209192
target metric: 0.5885279034238694
loss: 0.7527782569328944, accuracy: 0.6211376404494382


100%|██████████| 20/20 [00:37<00:00,  1.89s/it]


Validation:
Counter({0: 3078, 1: 2002})
ROC AUC metric: 0.585250568640399
target metric: 0.770229092190639
target metric fixed: 0.6718406351721772
loss: 1.2946497619152069, accuracy: 0.5066929133858268
Best validation metric: 0.642027776187675
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 4 as test
Train/Test sizes: 700/114
Train/Test label distribution:
{'CE': 0.6514285714285715, 'LAA': 0.3485714285714286}
{'CE': 0.7719298245614035, 'LAA': 0.22807017543859648}
********************************************************************************
epoch 1/5


100%|██████████| 15/15 [00:28<00:00,  1.90s/it]


Train:
Counter({1: 1860, 0: 1740})
ROC AUC metric: 0.5584277777777777
target metric: 0.6793750802419605
loss: 1.159578029314677, accuracy: 0.5394444444444444


100%|██████████| 9/9 [00:17<00:00,  1.97s/it]


Validation:
Counter({0: 1492, 1: 788})
ROC AUC metric: 0.5984036276223776
target metric: 0.6755418721733735
target metric fixed: 0.6537235254129459
loss: 0.9569532937473721, accuracy: 0.49473684210526314
********************************************************************************
epoch 2/5


100%|██████████| 15/15 [00:27<00:00,  1.86s/it]


Train:
Counter({0: 1876, 1: 1724})
ROC AUC metric: 0.6059682098765432
target metric: 0.6401585038007558
loss: 0.7991904020309448, accuracy: 0.5766666666666667


100%|██████████| 9/9 [00:18<00:00,  2.02s/it]


Validation:
Counter({0: 1520, 1: 760})
ROC AUC metric: 0.5486746066433567
target metric: 0.6721038315590078
target metric fixed: 0.6789027391424745
loss: 0.9474267496003045, accuracy: 0.4631578947368421
********************************************************************************
epoch 3/5


100%|██████████| 15/15 [00:27<00:00,  1.84s/it]


Train:
Counter({0: 1828, 1: 1772})
ROC AUC metric: 0.6442716049382715
target metric: 0.6063843034203182
loss: 0.730549430847168, accuracy: 0.6033333333333334


100%|██████████| 9/9 [00:17<00:00,  1.96s/it]


Validation:
Counter({0: 1668, 1: 612})
ROC AUC metric: 0.5449278846153847
target metric: 0.7588417799122049
target metric fixed: 0.701953795939231
loss: 1.0620374547110663, accuracy: 0.4219298245614035
********************************************************************************
epoch 4/5


100%|██████████| 15/15 [00:27<00:00,  1.86s/it]


Train:
Counter({0: 1849, 1: 1751})
ROC AUC metric: 0.6342703703703704
target metric: 0.6144805865412837
loss: 0.737367820739746, accuracy: 0.5902777777777778


100%|██████████| 9/9 [00:18<00:00,  2.01s/it]


Validation:
Counter({0: 1574, 1: 706})
ROC AUC metric: 0.5397738199300699
target metric: 0.7773516490661811
target metric fixed: 0.696475137925803
loss: 1.0539565881093342, accuracy: 0.44473684210526315
********************************************************************************
epoch 5/5


100%|██████████| 15/15 [00:27<00:00,  1.84s/it]


Train:
Counter({0: 1825, 1: 1775})
ROC AUC metric: 0.6447706790123456
target metric: 0.6006860024252804
loss: 0.721319603919983, accuracy: 0.5997222222222223


100%|██████████| 9/9 [00:17<00:00,  1.92s/it]


Validation:
Counter({0: 1191, 1: 1089})
ROC AUC metric: 0.5562423513986015
target metric: 0.6785143140261762
target metric fixed: 0.6773014935823705
loss: 0.850490947564443, accuracy: 0.5469298245614035
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 4 as test
Train/Test sizes: 700/114
Train/Test label distribution:
{'CE': 0.6514285714285715, 'LAA': 0.3485714285714286}
{'CE': 0.7719298245614035, 'LAA': 0.22807017543859648}
********************************************************************************
epoch 1/5


100%|██████████| 15/15 [00:27<00:00,  1.86s/it]


Train:
Counter({0: 1858, 1: 1742})
ROC AUC metric: 0.5783558641975309
target metric: 0.6640952256100674
loss: 0.9414571007092793, accuracy: 0.5633333333333334


100%|██████████| 9/9 [00:17<00:00,  1.93s/it]


Validation:
Counter({1: 1662, 0: 618})
ROC AUC metric: 0.6165329982517483
target metric: 0.7157112657271836
target metric fixed: 0.7209589527112826
loss: 0.597107277976142, accuracy: 0.6929824561403509
********************************************************************************
epoch 2/5


100%|██████████| 15/15 [00:27<00:00,  1.84s/it]


Train:
Counter({1: 1828, 0: 1772})
ROC AUC metric: 0.6157123456790123
target metric: 0.6267481188962234
loss: 0.7771254221598307, accuracy: 0.5822222222222222


100%|██████████| 9/9 [00:17<00:00,  1.95s/it]


Validation:
Counter({0: 1285, 1: 995})
ROC AUC metric: 0.5752376529720279
target metric: 0.6479090952723865
target metric fixed: 0.6537566673974857
loss: 0.8517755203776889, accuracy: 0.5346491228070176
********************************************************************************
epoch 3/5


100%|██████████| 15/15 [00:28<00:00,  1.88s/it]


Train:
Counter({1: 1837, 0: 1763})
ROC AUC metric: 0.6448682098765431
target metric: 0.6038019101585307
loss: 0.7560501257578532, accuracy: 0.6058333333333333


100%|██████████| 9/9 [00:16<00:00,  1.88s/it]


Validation:
Counter({0: 1329, 1: 951})
ROC AUC metric: 0.6222618006993007
target metric: 0.7938450627121605
target metric fixed: 0.6216517945334183
loss: 0.9388061232037015, accuracy: 0.5390350877192982
********************************************************************************
epoch 4/5


100%|██████████| 15/15 [00:27<00:00,  1.85s/it]


Train:
Counter({0: 1811, 1: 1789})
ROC AUC metric: 0.6409771604938272
target metric: 0.6007863048497761
loss: 0.7775380849838257, accuracy: 0.6091666666666666


100%|██████████| 9/9 [00:18<00:00,  2.01s/it]


Validation:
Counter({0: 1256, 1: 1024})
ROC AUC metric: 0.5802633304195804
target metric: 0.7498812447342434
target metric fixed: 0.6412759881590967
loss: 0.8829069336255392, accuracy: 0.5385964912280702
********************************************************************************
epoch 5/5


100%|██████████| 15/15 [00:27<00:00,  1.83s/it]


Train:
Counter({1: 1826, 0: 1774})
ROC AUC metric: 0.639512962962963
target metric: 0.6026139086208124
loss: 0.7982044657071431, accuracy: 0.5972222222222222


100%|██████████| 9/9 [00:17<00:00,  1.89s/it]


Validation:
Counter({0: 1575, 1: 705})
ROC AUC metric: 0.5653452797202797
target metric: 0.7177344497163027
target metric fixed: 0.6830508538496052
loss: 1.122702916463216, accuracy: 0.45921052631578946
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 4 as test
Train/Test sizes: 700/114
Train/Test label distribution:
{'CE': 0.6514285714285715, 'LAA': 0.3485714285714286}
{'CE': 0.7719298245614035, 'LAA': 0.22807017543859648}
********************************************************************************
epoch 1/5


100%|██████████| 15/15 [00:27<00:00,  1.83s/it]


Train:
Counter({1: 1896, 0: 1704})
ROC AUC metric: 0.5493404320987655
target metric: 0.6980606715781343
loss: 1.19382088581721, accuracy: 0.5333333333333333


100%|██████████| 9/9 [00:17<00:00,  1.97s/it]


Validation:
Counter({0: 1755, 1: 525})
ROC AUC metric: 0.5587500000000001
target metric: 0.7497333485624473
target metric fixed: 0.7131655119083328
loss: 1.1332092947430081, accuracy: 0.4074561403508772
********************************************************************************
epoch 2/5


100%|██████████| 15/15 [00:27<00:00,  1.83s/it]


Train:
Counter({0: 1914, 1: 1686})
ROC AUC metric: 0.6019324074074075
target metric: 0.6436469848761192
loss: 0.8414316177368164, accuracy: 0.5733333333333334


100%|██████████| 9/9 [00:18<00:00,  2.01s/it]


Validation:
Counter({1: 1397, 0: 883})
ROC AUC metric: 0.6180026223776224
target metric: 0.6585403580545193
target metric fixed: 0.6444434790026806
loss: 0.7063594791624281, accuracy: 0.6346491228070176
********************************************************************************
epoch 3/5


100%|██████████| 15/15 [00:28<00:00,  1.87s/it]


Train:
Counter({1: 1828, 0: 1772})
ROC AUC metric: 0.6631132716049383
target metric: 0.5942208784293275
loss: 0.6815143545468648, accuracy: 0.6216666666666667


100%|██████████| 9/9 [00:16<00:00,  1.89s/it]


Validation:
Counter({1: 1422, 0: 858})
ROC AUC metric: 0.6285041520979021
target metric: 0.6286196476089742
target metric fixed: 0.6365418204981428
loss: 0.6308316588401794, accuracy: 0.6517543859649123
********************************************************************************
epoch 4/5


100%|██████████| 15/15 [00:27<00:00,  1.86s/it]


Train:
Counter({1: 1813, 0: 1787})
ROC AUC metric: 0.6389756172839507
target metric: 0.6056321588027354
loss: 0.733279804388682, accuracy: 0.5975


100%|██████████| 9/9 [00:17<00:00,  1.95s/it]


Validation:
Counter({0: 1419, 1: 861})
ROC AUC metric: 0.5992766608391609
target metric: 0.7114994409576965
target metric fixed: 0.6369773684461013
loss: 0.903089165687561, accuracy: 0.5083333333333333
********************************************************************************
epoch 5/5


100%|██████████| 15/15 [00:27<00:00,  1.86s/it]


Train:
Counter({0: 1815, 1: 1785})
ROC AUC metric: 0.6674682098765433
target metric: 0.5811118731988483
loss: 0.7212333599726359, accuracy: 0.6225


100%|██████████| 9/9 [00:16<00:00,  1.87s/it]


Validation:
Counter({0: 1154, 1: 1126})
ROC AUC metric: 0.6242679195804197
target metric: 0.7408023887535709
target metric fixed: 0.6156763463174704
loss: 0.8037275539504157, accuracy: 0.5789473684210527
Best validation metric: 0.6156763463174704
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 7 as test
Train/Test sizes: 716/98
Train/Test label distribution:
{'CE': 0.6634078212290503, 'LAA': 0.3365921787709497}
{'CE': 0.7040816326530612, 'LAA': 0.29591836734693877}
********************************************************************************
epoch 1/5


100%|██████████| 15/15 [00:29<00:00,  1.95s/it]


Train:
Counter({0: 1905, 1: 1815})
ROC AUC metric: 0.5878678170886807
target metric: 0.6591471316166491
loss: 0.9548075675964356, accuracy: 0.5631720430107527


100%|██████████| 8/8 [00:16<00:00,  2.02s/it]


Validation:
Counter({0: 1346, 1: 614})
ROC AUC metric: 0.6410663418290854
target metric: 0.6411611600085682
target metric fixed: 0.637639828612466
loss: 0.8713346645236015, accuracy: 0.5010204081632653
********************************************************************************
epoch 2/5


100%|██████████| 15/15 [00:28<00:00,  1.93s/it]


Train:
Counter({1: 1915, 0: 1805})
ROC AUC metric: 0.6364909238062203
target metric: 0.615676282675611
loss: 0.7267415523529053, accuracy: 0.6008064516129032


100%|██████████| 8/8 [00:15<00:00,  1.96s/it]


Validation:
Counter({0: 1054, 1: 906})
ROC AUC metric: 0.6160107446276861
target metric: 0.6073899911732781
target metric fixed: 0.6043952312811802
loss: 0.7726947665214539, accuracy: 0.5520408163265306
********************************************************************************
epoch 3/5


100%|██████████| 15/15 [00:28<00:00,  1.91s/it]


Train:
Counter({0: 1942, 1: 1778})
ROC AUC metric: 0.6444684356573015
target metric: 0.6092479456588696
loss: 0.7040793816248576, accuracy: 0.6026881720430107


100%|██████████| 8/8 [00:16<00:00,  2.02s/it]


Validation:
Counter({1: 1125, 0: 835})
ROC AUC metric: 0.6379097951024488
target metric: 0.6102848386704557
target metric fixed: 0.6009506212225616
loss: 0.6873675286769867, accuracy: 0.6147959183673469
********************************************************************************
epoch 4/5


100%|██████████| 15/15 [00:28<00:00,  1.87s/it]


Train:
Counter({1: 1908, 0: 1812})
ROC AUC metric: 0.6664595617990519
target metric: 0.59223517019308
loss: 0.6827585577964783, accuracy: 0.6209677419354839


100%|██████████| 8/8 [00:16<00:00,  2.02s/it]


Validation:
Counter({0: 1003, 1: 957})
ROC AUC metric: 0.6273475762118941
target metric: 0.6088339666322513
target metric fixed: 0.6103425328918883
loss: 0.7699500024318695, accuracy: 0.585204081632653
********************************************************************************
epoch 5/5


100%|██████████| 15/15 [00:28<00:00,  1.89s/it]


Train:
Counter({0: 1905, 1: 1815})
ROC AUC metric: 0.6673375534743902
target metric: 0.5902695722666712
loss: 0.6813902934392293, accuracy: 0.6228494623655914


100%|██████████| 8/8 [00:15<00:00,  1.95s/it]


Validation:
Counter({1: 1206, 0: 754})
ROC AUC metric: 0.5990104947526237
target metric: 0.6336751054961338
target metric fixed: 0.6345957250743317
loss: 0.790761724114418, accuracy: 0.6142857142857143
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 7 as test
Train/Test sizes: 716/98
Train/Test label distribution:
{'CE': 0.6634078212290503, 'LAA': 0.3365921787709497}
{'CE': 0.7040816326530612, 'LAA': 0.29591836734693877}
********************************************************************************
epoch 1/5


100%|██████████| 15/15 [00:28<00:00,  1.89s/it]


Train:
Counter({1: 1921, 0: 1799})
ROC AUC metric: 0.5603892068447219
target metric: 0.6721864954273137
loss: 1.02458651860555, accuracy: 0.5491935483870968


100%|██████████| 8/8 [00:15<00:00,  1.97s/it]


Validation:
Counter({0: 1403, 1: 557})
ROC AUC metric: 0.6482933533233384
target metric: 0.6757040771986755
target metric fixed: 0.6768243413574113
loss: 0.9852254018187523, accuracy: 0.49744897959183676
********************************************************************************
epoch 2/5


100%|██████████| 15/15 [00:28<00:00,  1.89s/it]


Train:
Counter({0: 1867, 1: 1853})
ROC AUC metric: 0.6364623077812464
target metric: 0.6233046048744704
loss: 0.713063649336497, accuracy: 0.5991935483870968


100%|██████████| 8/8 [00:15<00:00,  1.99s/it]


Validation:
Counter({0: 1101, 1: 859})
ROC AUC metric: 0.6312243878060969
target metric: 0.6338524723826094
target metric fixed: 0.6223371045268427
loss: 0.7686096951365471, accuracy: 0.551530612244898
********************************************************************************
epoch 3/5


100%|██████████| 15/15 [00:27<00:00,  1.86s/it]


Train:
Counter({0: 1903, 1: 1817})
ROC AUC metric: 0.6405934212047636
target metric: 0.610435127442264
loss: 0.7183332244555155, accuracy: 0.5943548387096774


100%|██████████| 8/8 [00:15<00:00,  1.99s/it]


Validation:
Counter({1: 1300, 0: 660})
ROC AUC metric: 0.6474812593703148
target metric: 0.6070295289023606
target metric fixed: 0.6087279330443925
loss: 0.646956242620945, accuracy: 0.6642857142857143
********************************************************************************
epoch 4/5


100%|██████████| 15/15 [00:27<00:00,  1.86s/it]


Train:
Counter({1: 1993, 0: 1727})
ROC AUC metric: 0.6552010348017111
target metric: 0.5950377856960894
loss: 0.7141734639803569, accuracy: 0.6158602150537634


100%|██████████| 8/8 [00:15<00:00,  1.97s/it]


Validation:
Counter({0: 1364, 1: 596})
ROC AUC metric: 0.6251174412793603
target metric: 0.6549932382339746
target metric fixed: 0.6585940372655766
loss: 0.9373350813984871, accuracy: 0.49081632653061225
********************************************************************************
epoch 5/5


100%|██████████| 15/15 [00:28<00:00,  1.89s/it]


Train:
Counter({0: 1937, 1: 1783})
ROC AUC metric: 0.6487420511041739
target metric: 0.5920474066773795
loss: 0.7317883729934692, accuracy: 0.6024193548387097


100%|██████████| 8/8 [00:15<00:00,  1.97s/it]


Validation:
Counter({1: 1480, 0: 480})
ROC AUC metric: 0.6542341329335333
target metric: 0.6479550733596533
target metric fixed: 0.6486915116279665
loss: 0.6267068162560463, accuracy: 0.6887755102040817
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 7 as test
Train/Test sizes: 716/98
Train/Test label distribution:
{'CE': 0.6634078212290503, 'LAA': 0.3365921787709497}
{'CE': 0.7040816326530612, 'LAA': 0.29591836734693877}
********************************************************************************
epoch 1/5


100%|██████████| 15/15 [00:28<00:00,  1.89s/it]


Train:
Counter({0: 1950, 1: 1770})
ROC AUC metric: 0.5732746560295988
target metric: 0.6777829463534539
loss: 1.1791819254557292, accuracy: 0.5586021505376344


100%|██████████| 8/8 [00:15<00:00,  1.97s/it]


Validation:
Counter({1: 1425, 0: 535})
ROC AUC metric: 0.6420664667666167
target metric: 0.6609815674137965
target metric fixed: 0.6614155031818101
loss: 0.6795627176761627, accuracy: 0.6780612244897959
********************************************************************************
epoch 2/5


100%|██████████| 15/15 [00:28<00:00,  1.89s/it]


Train:
Counter({1: 1931, 0: 1789})
ROC AUC metric: 0.6100624349635796
target metric: 0.6309512140257152
loss: 0.8061712304751079, accuracy: 0.5879032258064516


100%|██████████| 8/8 [00:15<00:00,  1.92s/it]


Validation:
Counter({1: 1403, 0: 557})
ROC AUC metric: 0.6478373313343329
target metric: 0.6324525860694197
target metric fixed: 0.6326474021672652
loss: 0.6254339814186096, accuracy: 0.6709183673469388
********************************************************************************
epoch 3/5


100%|██████████| 15/15 [00:27<00:00,  1.86s/it]


Train:
Counter({1: 1910, 0: 1810})
ROC AUC metric: 0.6437731240605851
target metric: 0.6118981004981213
loss: 0.7075414856274923, accuracy: 0.5973118279569892


100%|██████████| 8/8 [00:15<00:00,  1.96s/it]


Validation:
Counter({0: 1470, 1: 490})
ROC AUC metric: 0.6490604697651174
target metric: 0.6916833512876235
target metric fixed: 0.654745685460075
loss: 0.9433862641453743, accuracy: 0.4642857142857143
********************************************************************************
epoch 4/5


100%|██████████| 15/15 [00:28<00:00,  1.88s/it]


Train:
Counter({0: 1943, 1: 1777})
ROC AUC metric: 0.6648430454387793
target metric: 0.592358595861028
loss: 0.6780757109324137, accuracy: 0.6174731182795699


100%|██████████| 8/8 [00:16<00:00,  2.01s/it]


Validation:
Counter({1: 1090, 0: 870})
ROC AUC metric: 0.6809770114942529
target metric: 0.5784363205627394
target metric fixed: 0.5805691482270264
loss: 0.6569951102137566, accuracy: 0.65
********************************************************************************
epoch 5/5


100%|██████████| 15/15 [00:28<00:00,  1.91s/it]


Train:
Counter({1: 1879, 0: 1841})
ROC AUC metric: 0.6816704243265118
target metric: 0.5810690289837572
loss: 0.6725374738375346, accuracy: 0.635752688172043


100%|██████████| 8/8 [00:15<00:00,  1.92s/it]


Validation:
Counter({1: 1087, 0: 873})
ROC AUC metric: 0.6381634182908545
target metric: 0.5902500548116256
target metric fixed: 0.5906903446378811
loss: 0.7048057615756989, accuracy: 0.6076530612244898
Best validation metric: 0.5805691482270264
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 1.5 as test
Train/Test sizes: 722/92
Train/Test label distribution:
{'CE': 0.6523545706371191, 'LAA': 0.3476454293628809}
{'LAA': 0.20652173913043478, 'CE': 0.7934782608695652}
********************************************************************************
epoch 1/5


100%|██████████| 15/15 [00:28<00:00,  1.87s/it]


Train:
Counter({1: 1887, 0: 1793})
ROC AUC metric: 0.568953361294896
target metric: 0.6715948523759047
loss: 1.0114059726397195, accuracy: 0.5513586956521739


100%|██████████| 8/8 [00:15<00:00,  1.95s/it]


Validation:
Counter({1: 1364, 0: 476})
ROC AUC metric: 0.5632281903388608
target metric: 0.7678928504861998
target metric fixed: 0.7688262877941683
loss: 0.663352943956852, accuracy: 0.6739130434782609
********************************************************************************
epoch 2/5


100%|██████████| 15/15 [00:28<00:00,  1.90s/it]


Train:
Counter({1: 1946, 0: 1734})
ROC AUC metric: 0.6312550212665407
target metric: 0.6218448232193741
loss: 0.7777697801589966, accuracy: 0.5940217391304348


100%|██████████| 8/8 [00:14<00:00,  1.78s/it]


Validation:
Counter({0: 1105, 1: 735})
ROC AUC metric: 0.5930371304974765
target metric: 0.6787567390811532
target metric fixed: 0.6795469258159128
loss: 0.8983648344874382, accuracy: 0.48206521739130437
********************************************************************************
epoch 3/5


100%|██████████| 15/15 [00:28<00:00,  1.88s/it]


Train:
Counter({0: 1857, 1: 1823})
ROC AUC metric: 0.6622040406427221
target metric: 0.5926749626101657
loss: 0.719684112071991, accuracy: 0.6149456521739131


100%|██████████| 8/8 [00:14<00:00,  1.79s/it]


Validation:
Counter({0: 1416, 1: 424})
ROC AUC metric: 0.5916510454217736
target metric: 0.7440596850049729
target metric fixed: 0.7433545886457621
loss: 1.1425919979810715, accuracy: 0.3760869565217391
********************************************************************************
epoch 4/5


100%|██████████| 15/15 [00:28<00:00,  1.89s/it]


Train:
Counter({1: 1844, 0: 1836})
ROC AUC metric: 0.6650280600189036
target metric: 0.5914344732034567
loss: 0.6838773250579834, accuracy: 0.6195652173913043


100%|██████████| 8/8 [00:15<00:00,  1.90s/it]


Validation:
Counter({0: 1213, 1: 627})
ROC AUC metric: 0.5885346070656092
target metric: 0.6845915282090881
target metric fixed: 0.6899559863975033
loss: 0.9485280364751816, accuracy: 0.4483695652173913
********************************************************************************
epoch 5/5


100%|██████████| 15/15 [00:27<00:00,  1.86s/it]


Train:
Counter({0: 1924, 1: 1756})
ROC AUC metric: 0.6660580103969754
target metric: 0.5957309499361751
loss: 0.6838052074114481, accuracy: 0.6255434782608695


100%|██████████| 8/8 [00:13<00:00,  1.75s/it]


Validation:
Counter({1: 1162, 0: 678})
ROC AUC metric: 0.5838923936553714
target metric: 0.6743397410115937
target metric fixed: 0.679288418154272
loss: 0.6876282542943954, accuracy: 0.6076086956521739
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 1.5 as test
Train/Test sizes: 722/92
Train/Test label distribution:
{'CE': 0.6523545706371191, 'LAA': 0.3476454293628809}
{'LAA': 0.20652173913043478, 'CE': 0.7934782608695652}
********************************************************************************
epoch 1/5


100%|██████████| 15/15 [00:28<00:00,  1.88s/it]


Train:
Counter({0: 2012, 1: 1668})
ROC AUC metric: 0.5226696892722117
target metric: 0.7500771698430564
loss: 1.6011921842892964, accuracy: 0.5125


100%|██████████| 8/8 [00:13<00:00,  1.75s/it]


Validation:
Counter({1: 1348, 0: 492})
ROC AUC metric: 0.6173071377072818
target metric: 0.7255719195946715
target metric fixed: 0.7248813478933246
loss: 0.6182315871119499, accuracy: 0.7010869565217391
********************************************************************************
epoch 2/5


100%|██████████| 15/15 [00:28<00:00,  1.88s/it]


Train:
Counter({1: 2029, 0: 1651})
ROC AUC metric: 0.6279900165406427
target metric: 0.6143163999112883
loss: 0.8146565794944763, accuracy: 0.597554347826087


100%|██████████| 8/8 [00:15<00:00,  1.91s/it]


Validation:
Counter({0: 1355, 1: 485})
ROC AUC metric: 0.5932065609228552
target metric: 0.7358192089670565
target metric fixed: 0.7332982634193379
loss: 1.114856243133545, accuracy: 0.38206521739130433
********************************************************************************
epoch 3/5


100%|██████████| 15/15 [00:28<00:00,  1.87s/it]


Train:
Counter({1: 1854, 0: 1826})
ROC AUC metric: 0.657116906899811
target metric: 0.5973946224957194
loss: 0.7075069824854533, accuracy: 0.6081521739130434


100%|██████████| 8/8 [00:14<00:00,  1.77s/it]


Validation:
Counter({0: 1002, 1: 838})
ROC AUC metric: 0.5817916366258111
target metric: 0.6788425644724883
target metric fixed: 0.6814682530171239
loss: 0.8365016207098961, accuracy: 0.5260869565217391
********************************************************************************
epoch 4/5


100%|██████████| 15/15 [00:28<00:00,  1.88s/it]


Train:
Counter({0: 1841, 1: 1839})
ROC AUC metric: 0.6688198546786389
target metric: 0.593643010133144
loss: 0.6692574779192607, accuracy: 0.6100543478260869


100%|██████████| 8/8 [00:13<00:00,  1.74s/it]


Validation:
Counter({1: 1085, 0: 755})
ROC AUC metric: 0.5720304614275415
target metric: 0.679172944638433
target metric fixed: 0.6849491711430993
loss: 0.718262791633606, accuracy: 0.6027173913043479
********************************************************************************
epoch 5/5


100%|██████████| 15/15 [00:27<00:00,  1.86s/it]


Train:
Counter({1: 1883, 0: 1797})
ROC AUC metric: 0.6887823428638942
target metric: 0.5744601343712443
loss: 0.6583022793134053, accuracy: 0.6421195652173913


100%|██████████| 8/8 [00:15<00:00,  1.89s/it]


Validation:
Counter({0: 1216, 1: 624})
ROC AUC metric: 0.5958056957462148
target metric: 0.69547990771807
target metric fixed: 0.6946462736636132
loss: 0.977883979678154, accuracy: 0.44021739130434784
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 1.5 as test
Train/Test sizes: 722/92
Train/Test label distribution:
{'CE': 0.6523545706371191, 'LAA': 0.3476454293628809}
{'LAA': 0.20652173913043478, 'CE': 0.7934782608695652}
********************************************************************************
epoch 1/5


100%|██████████| 15/15 [00:28<00:00,  1.89s/it]


Train:
Counter({1: 1935, 0: 1745})
ROC AUC metric: 0.5555273806710775
target metric: 0.6834182605766301
loss: 1.0651598970095317, accuracy: 0.5339673913043478


100%|██████████| 8/8 [00:14<00:00,  1.79s/it]


Validation:
Counter({0: 1248, 1: 592})
ROC AUC metric: 0.6085472242249459
target metric: 0.7086139939252095
target metric fixed: 0.7088755428206717
loss: 1.0847695171833038, accuracy: 0.45108695652173914
********************************************************************************
epoch 2/5


100%|██████████| 15/15 [00:28<00:00,  1.87s/it]


Train:
Counter({0: 1926, 1: 1754})
ROC AUC metric: 0.6223570415879016
target metric: 0.6226480313568947
loss: 0.7610992471377055, accuracy: 0.5864130434782608


100%|██████████| 8/8 [00:14<00:00,  1.83s/it]


Validation:
Counter({1: 1552, 0: 288})
ROC AUC metric: 0.555329848594088
target metric: 0.8439915645681815
target metric fixed: 0.8388667261912488
loss: 0.6166034415364265, accuracy: 0.7130434782608696
********************************************************************************
epoch 3/5


100%|██████████| 15/15 [00:27<00:00,  1.85s/it]


Train:
Counter({1: 1906, 0: 1774})
ROC AUC metric: 0.6384564035916824
target metric: 0.6124297736248485
loss: 0.7678045630455017, accuracy: 0.5978260869565217


100%|██████████| 8/8 [00:14<00:00,  1.87s/it]


Validation:
Counter({1: 1145, 0: 695})
ROC AUC metric: 0.5643150684931506
target metric: 0.7059192117971007
target metric fixed: 0.7066258854783821
loss: 0.7245044708251953, accuracy: 0.6092391304347826
********************************************************************************
epoch 4/5


100%|██████████| 15/15 [00:28<00:00,  1.88s/it]


Train:
Counter({0: 1902, 1: 1778})
ROC AUC metric: 0.6638628308128545
target metric: 0.5929000127246573
loss: 0.6972147822380066, accuracy: 0.6070652173913044


100%|██████████| 8/8 [00:14<00:00,  1.76s/it]


Validation:
Counter({1: 1098, 0: 742})
ROC AUC metric: 0.6020872386445566
target metric: 0.9389078316764174
target metric fixed: 0.656860921556828
loss: 1.4003384560346603, accuracy: 0.6021739130434782
********************************************************************************
epoch 5/5


100%|██████████| 15/15 [00:28<00:00,  1.88s/it]


Train:
Counter({1: 1841, 0: 1839})
ROC AUC metric: 0.6789809782608696
target metric: 0.5833667242191524
loss: 0.6682579120000204, accuracy: 0.626358695652174


100%|██████████| 8/8 [00:14<00:00,  1.82s/it]


Validation:
Counter({1: 972, 0: 868})
ROC AUC metric: 0.6063031723143475
target metric: 0.6453045326795901
target metric fixed: 0.6510807591842562
loss: 0.7860750034451485, accuracy: 0.5836956521739131
Best validation metric: 0.6510807591842562
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 10.3 as test
Train/Test sizes: 722/92
Train/Test label distribution:
{'CE': 0.6717451523545707, 'LAA': 0.32825484764542934}
{'CE': 0.6413043478260869, 'LAA': 0.358695652173913}
********************************************************************************
epoch 1/5


100%|██████████| 15/15 [00:29<00:00,  1.96s/it]


Train:
Counter({0: 1916, 1: 1876})
ROC AUC metric: 0.574134809681497
target metric: 0.6777291607479424
loss: 1.099273947874705, accuracy: 0.560126582278481


100%|██████████| 8/8 [00:14<00:00,  1.78s/it]


Validation:
Counter({1: 1328, 0: 512})
ROC AUC metric: 0.5336093990755008
target metric: 0.7972837511861172
target metric fixed: 0.7917227332908171
loss: 0.8877716884016991, accuracy: 0.575
********************************************************************************
epoch 2/5


100%|██████████| 15/15 [00:29<00:00,  1.93s/it]


Train:
Counter({1: 1979, 0: 1813})
ROC AUC metric: 0.6357426916982678
target metric: 0.6121931855529332
loss: 0.7398751775423685, accuracy: 0.6062763713080169


100%|██████████| 8/8 [00:14<00:00,  1.76s/it]


Validation:
Counter({0: 1316, 1: 524})
ROC AUC metric: 0.5091140215716488
target metric: 0.7164140457291
target metric fixed: 0.7173846077920887
loss: 0.9394715800881386, accuracy: 0.45217391304347826
********************************************************************************
epoch 3/5


100%|██████████| 15/15 [00:29<00:00,  1.96s/it]


Train:
Counter({0: 1976, 1: 1816})
ROC AUC metric: 0.671838836813901
target metric: 0.5864967453529168
loss: 0.6868024309476216, accuracy: 0.6271097046413502


100%|██████████| 8/8 [00:14<00:00,  1.85s/it]


Validation:
Counter({0: 1115, 1: 725})
ROC AUC metric: 0.5298382126348228
target metric: 0.6826151088820909
target metric fixed: 0.6792535391844639
loss: 0.8491982817649841, accuracy: 0.4951086956521739
********************************************************************************
epoch 4/5


100%|██████████| 15/15 [00:29<00:00,  1.97s/it]


Train:
Counter({1: 1945, 0: 1847})
ROC AUC metric: 0.6768157257562001
target metric: 0.5838704276064419
loss: 0.6750486294428507, accuracy: 0.622626582278481


100%|██████████| 8/8 [00:14<00:00,  1.87s/it]


Validation:
Counter({0: 1179, 1: 661})
ROC AUC metric: 0.5219722650231126
target metric: 0.7098749348994322
target metric fixed: 0.6983379161358578
loss: 0.9020295664668083, accuracy: 0.46793478260869564
********************************************************************************
epoch 5/5


100%|██████████| 15/15 [00:28<00:00,  1.88s/it]


Train:
Counter({0: 1951, 1: 1841})
ROC AUC metric: 0.6806718341077819
target metric: 0.577610711823312
loss: 0.6727490107218425, accuracy: 0.6294831223628692


100%|██████████| 8/8 [00:14<00:00,  1.82s/it]


Validation:
Counter({0: 1047, 1: 793})
ROC AUC metric: 0.5423394966615306
target metric: 1.0548509883903388
target metric fixed: 0.6702415654999896
loss: 1.5380618274211884, accuracy: 0.4951086956521739
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 10.3 as test
Train/Test sizes: 722/92
Train/Test label distribution:
{'CE': 0.6717451523545707, 'LAA': 0.32825484764542934}
{'CE': 0.6413043478260869, 'LAA': 0.358695652173913}
********************************************************************************
epoch 1/5


100%|██████████| 15/15 [00:28<00:00,  1.93s/it]


Train:
Counter({1: 1917, 0: 1875})
ROC AUC metric: 0.5737133694742651
target metric: 0.663642563370795
loss: 1.0654353698094685, accuracy: 0.5545886075949367


100%|██████████| 8/8 [00:14<00:00,  1.86s/it]


Validation:
Counter({0: 1736, 1: 104})
ROC AUC metric: 0.5138507960965588
target metric: 1.1283668045706674
target metric fixed: 1.127194355629936
loss: 1.8232195675373077, accuracy: 0.3717391304347826
********************************************************************************
epoch 2/5


100%|██████████| 15/15 [00:29<00:00,  1.96s/it]


Train:
Counter({0: 1966, 1: 1826})
ROC AUC metric: 0.6373138430450961
target metric: 0.608419160122428
loss: 0.7868762771288554, accuracy: 0.6028481012658228


100%|██████████| 8/8 [00:14<00:00,  1.84s/it]


Validation:
Counter({0: 984, 1: 856})
ROC AUC metric: 0.5381753980482794
target metric: 0.667172831966329
target metric fixed: 0.667521449780279
loss: 0.823977567255497, accuracy: 0.5152173913043478
********************************************************************************
epoch 3/5


100%|██████████| 15/15 [00:29<00:00,  1.96s/it]


Train:
Counter({1: 1920, 0: 1872})
ROC AUC metric: 0.692357550428172
target metric: 0.5744680798509438
loss: 0.6623789151509603, accuracy: 0.6413502109704642


100%|██████████| 8/8 [00:14<00:00,  1.86s/it]


Validation:
Counter({1: 1250, 0: 590})
ROC AUC metric: 0.5149820236260914
target metric: 0.7192768476395673
target metric fixed: 0.7178795864789184
loss: 0.7918144017457962, accuracy: 0.5532608695652174
********************************************************************************
epoch 4/5


100%|██████████| 15/15 [00:29<00:00,  1.96s/it]


Train:
Counter({0: 1912, 1: 1880})
ROC AUC metric: 0.6889550118837794
target metric: 0.5755325376642483
loss: 0.6591034889221191, accuracy: 0.6360759493670886


100%|██████████| 8/8 [00:14<00:00,  1.78s/it]


Validation:
Counter({1: 1045, 0: 795})
ROC AUC metric: 0.5572714432460195
target metric: 0.6540364675963591
target metric fixed: 0.65766523668334
loss: 0.7704062312841415, accuracy: 0.5635869565217392
********************************************************************************
epoch 5/5


100%|██████████| 15/15 [00:29<00:00,  1.96s/it]


Train:
Counter({1: 1923, 0: 1869})
ROC AUC metric: 0.6943490292688138
target metric: 0.5755947469129727
loss: 0.654215141137441, accuracy: 0.6384493670886076


100%|██████████| 8/8 [00:14<00:00,  1.86s/it]


Validation:
Counter({0: 1132, 1: 708})
ROC AUC metric: 0.5022720852593734
target metric: 1.0943772475416442
target metric fixed: 0.7097678246512951
loss: 1.7473774552345276, accuracy: 0.45652173913043476
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 10.3 as test
Train/Test sizes: 722/92
Train/Test label distribution:
{'CE': 0.6717451523545707, 'LAA': 0.32825484764542934}
{'CE': 0.6413043478260869, 'LAA': 0.358695652173913}
********************************************************************************
epoch 1/5


100%|██████████| 15/15 [00:28<00:00,  1.93s/it]


Train:
Counter({1: 1938, 0: 1854})
ROC AUC metric: 0.5684499568267194
target metric: 0.6734447789757338
loss: 1.0827801942825317, accuracy: 0.5522151898734177


100%|██████████| 8/8 [00:14<00:00,  1.84s/it]


Validation:
Counter({0: 1560, 1: 280})
ROC AUC metric: 0.5242295839753467
target metric: 0.8434772249444329
target metric fixed: 0.8462037253021811
loss: 1.3192589730024338, accuracy: 0.43369565217391304
********************************************************************************
epoch 2/5


100%|██████████| 15/15 [00:29<00:00,  1.96s/it]


Train:
Counter({0: 1957, 1: 1835})
ROC AUC metric: 0.6325172136765833
target metric: 0.6031757357983187
loss: 0.7962716976801555, accuracy: 0.5925632911392406


100%|██████████| 8/8 [00:14<00:00,  1.77s/it]


Validation:
Counter({0: 1539, 1: 301})
ROC AUC metric: 0.5576823317925013
target metric: 0.7709624031396901
target metric fixed: 0.7744646078769178
loss: 1.121067002415657, accuracy: 0.4331521739130435
********************************************************************************
epoch 3/5


100%|██████████| 15/15 [00:29<00:00,  1.98s/it]


Train:
Counter({0: 2023, 1: 1769})
ROC AUC metric: 0.6714079385426124
target metric: 0.5881679043302461
loss: 0.6902357260386149, accuracy: 0.6284282700421941


100%|██████████| 8/8 [00:15<00:00,  1.89s/it]


Validation:
Counter({1: 1020, 0: 820})
ROC AUC metric: 0.5546199280945043
target metric: 0.665172244984841
target metric fixed: 0.6715213113613285
loss: 0.7618812918663025, accuracy: 0.558695652173913
********************************************************************************
epoch 4/5


100%|██████████| 15/15 [00:29<00:00,  1.94s/it]


Train:
Counter({1: 1992, 0: 1800})
ROC AUC metric: 0.6824730389538713
target metric: 0.5795638030271579
loss: 0.673924724260966, accuracy: 0.6292194092827004


100%|██████████| 8/8 [00:14<00:00,  1.79s/it]


Validation:
Counter({0: 1098, 1: 742})
ROC AUC metric: 0.5301335387776065
target metric: 0.7075860261112243
target metric fixed: 0.6838535605592679
loss: 0.9157798960804939, accuracy: 0.483695652173913
********************************************************************************
epoch 5/5


100%|██████████| 15/15 [00:29<00:00,  1.96s/it]


Train:
Counter({0: 1913, 1: 1879})
ROC AUC metric: 0.6777515177411026
target metric: 0.5842282103454078
loss: 0.6762037555376689, accuracy: 0.6284282700421941


100%|██████████| 8/8 [00:14<00:00,  1.76s/it]


Validation:
Counter({1: 1334, 0: 506})
ROC AUC metric: 0.5525757575757577
target metric: 0.6923016172990543
target metric fixed: 0.7001782898054173
loss: 0.7498548701405525, accuracy: 0.6054347826086957
Best validation metric: 0.65766523668334
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 6.2.8.9 as test
Train/Test sizes: 715/99
Train/Test label distribution:
{'CE': 0.6517482517482518, 'LAA': 0.34825174825174826}
{'LAA': 0.21212121212121213, 'CE': 0.7878787878787878}
********************************************************************************
epoch 1/5


100%|██████████| 15/15 [00:28<00:00,  1.88s/it]


Train:
Counter({0: 1960, 1: 1680})
ROC AUC metric: 0.5356572273879966
target metric: 0.7270080517024415
loss: 1.4810306390126546, accuracy: 0.5203296703296704


100%|██████████| 8/8 [00:16<00:00,  2.03s/it]


Validation:
Counter({1: 1544, 0: 436})
ROC AUC metric: 0.5877258852258852
target metric: 0.7763516629091461
target metric fixed: 0.7772059518385557
loss: 0.672618456184864, accuracy: 0.7222222222222222
********************************************************************************
epoch 2/5


100%|██████████| 15/15 [00:27<00:00,  1.86s/it]


Train:
Counter({1: 1887, 0: 1753})
ROC AUC metric: 0.6313488709093105
target metric: 0.6117120761273944
loss: 0.8109824021657308, accuracy: 0.5947802197802198


100%|██████████| 8/8 [00:14<00:00,  1.81s/it]


Validation:
Counter({1: 1362, 0: 618})
ROC AUC metric: 0.534284188034188
target metric: 0.7401066871208112
target metric fixed: 0.7382393543983466
loss: 0.6872045621275902, accuracy: 0.6212121212121212
********************************************************************************
epoch 3/5


100%|██████████| 15/15 [00:27<00:00,  1.87s/it]


Train:
Counter({1: 1836, 0: 1804})
ROC AUC metric: 0.6561852433281006
target metric: 0.5983479099002602
loss: 0.6917358915011088, accuracy: 0.6148351648351649


100%|██████████| 8/8 [00:15<00:00,  1.97s/it]


Validation:
Counter({1: 1106, 0: 874})
ROC AUC metric: 0.5582539682539682
target metric: 0.6749350752620682
target metric fixed: 0.6795813190860658
loss: 0.736482210457325, accuracy: 0.5696969696969697
********************************************************************************
epoch 4/5


100%|██████████| 15/15 [00:27<00:00,  1.85s/it]


Train:
Counter({0: 1895, 1: 1745})
ROC AUC metric: 0.6704815239705351
target metric: 0.5945885757478087
loss: 0.6867664615313213, accuracy: 0.6222527472527473


100%|██████████| 8/8 [00:16<00:00,  2.02s/it]


Validation:
Counter({1: 1312, 0: 668})
ROC AUC metric: 0.5683745421245421
target metric: 0.6800604704302702
target metric fixed: 0.6848230912544417
loss: 0.661997601389885, accuracy: 0.6313131313131313
********************************************************************************
epoch 5/5


100%|██████████| 15/15 [00:29<00:00,  1.95s/it]


Train:
Counter({0: 1845, 1: 1795})
ROC AUC metric: 0.6699683009298394
target metric: 0.5923449886835395
loss: 0.6687577525774638, accuracy: 0.6244505494505495


100%|██████████| 8/8 [00:14<00:00,  1.85s/it]


Validation:
Counter({1: 1291, 0: 689})
ROC AUC metric: 0.5648656898656899
target metric: 0.6845542414361947
target metric fixed: 0.6833626990419155
loss: 0.667955256998539, accuracy: 0.6196969696969697
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 6.2.8.9 as test
Train/Test sizes: 715/99
Train/Test label distribution:
{'CE': 0.6517482517482518, 'LAA': 0.34825174825174826}
{'LAA': 0.21212121212121213, 'CE': 0.7878787878787878}
********************************************************************************
epoch 1/5


100%|██████████| 15/15 [00:28<00:00,  1.89s/it]


Train:
Counter({1: 1875, 0: 1765})
ROC AUC metric: 0.5563432858350441
target metric: 0.6901448485561213
loss: 1.08365908463796, accuracy: 0.542032967032967


100%|██████████| 8/8 [00:15<00:00,  1.97s/it]


Validation:
Counter({0: 1537, 1: 443})
ROC AUC metric: 0.5609401709401709
target metric: 0.7749011628380417
target metric fixed: 0.7726718693665133
loss: 1.1942193061113358, accuracy: 0.347979797979798
********************************************************************************
epoch 2/5


100%|██████████| 15/15 [00:27<00:00,  1.86s/it]


Train:
Counter({0: 1863, 1: 1777})
ROC AUC metric: 0.6325105663567202
target metric: 0.6118202667071742
loss: 0.7282506863276164, accuracy: 0.592032967032967


100%|██████████| 8/8 [00:15<00:00,  1.90s/it]


Validation:
Counter({1: 1181, 0: 799})
ROC AUC metric: 0.530813492063492
target metric: 0.7053122553622763
target metric fixed: 0.7073096541763036
loss: 0.7494534254074097, accuracy: 0.5823232323232324
********************************************************************************
epoch 3/5


100%|██████████| 15/15 [00:28<00:00,  1.89s/it]


Train:
Counter({1: 1831, 0: 1809})
ROC AUC metric: 0.65460844100954
target metric: 0.6001658492511568
loss: 0.6997409224510193, accuracy: 0.6101648351648352


100%|██████████| 8/8 [00:14<00:00,  1.86s/it]


Validation:
Counter({0: 1334, 1: 646})
ROC AUC metric: 0.5253144078144079
target metric: 0.7150859226472785
target metric fixed: 0.7184439981710412
loss: 0.9939270839095116, accuracy: 0.4212121212121212
********************************************************************************
epoch 4/5


100%|██████████| 15/15 [00:28<00:00,  1.89s/it]


Train:
Counter({1: 1829, 0: 1811})
ROC AUC metric: 0.6244867769593043
target metric: 0.61640890903487
loss: 0.7727385640144349, accuracy: 0.5898351648351648


100%|██████████| 8/8 [00:16<00:00,  2.03s/it]


Validation:
Counter({0: 1059, 1: 921})
ROC AUC metric: 0.5800778388278389
target metric: 0.6682434441300469
target metric fixed: 0.6727813062549651
loss: 0.8348987624049187, accuracy: 0.5267676767676768
********************************************************************************
epoch 5/5


100%|██████████| 15/15 [00:28<00:00,  1.87s/it]


Train:
Counter({0: 1928, 1: 1712})
ROC AUC metric: 0.6463171718391498
target metric: 0.5983215914214202
loss: 0.7745607376098633, accuracy: 0.6104395604395605


100%|██████████| 8/8 [00:14<00:00,  1.87s/it]


Validation:
Counter({1: 1136, 0: 844})
ROC AUC metric: 0.5561584249084248
target metric: 0.6773474112085467
target metric fixed: 0.6814329756540647
loss: 0.7680813670158386, accuracy: 0.5757575757575758
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
CV with 6.2.8.9 as test
Train/Test sizes: 715/99
Train/Test label distribution:
{'CE': 0.6517482517482518, 'LAA': 0.34825174825174826}
{'LAA': 0.21212121212121213, 'CE': 0.7878787878787878}
********************************************************************************
epoch 1/5


100%|██████████| 15/15 [00:28<00:00,  1.87s/it]


Train:
Counter({0: 1883, 1: 1757})
ROC AUC metric: 0.5916552952541962
target metric: 0.6463978778941835
loss: 0.9541222929954529, accuracy: 0.5673076923076923


100%|██████████| 8/8 [00:15<00:00,  1.96s/it]


Validation:
Counter({0: 1268, 1: 712})
ROC AUC metric: 0.5603983516483516
target metric: 0.714473762245718
target metric fixed: 0.713402596938832
loss: 1.0844343453645706, accuracy: 0.4484848484848485
********************************************************************************
epoch 2/5


100%|██████████| 15/15 [00:27<00:00,  1.86s/it]


Train:
Counter({1: 1896, 0: 1744})
ROC AUC metric: 0.6419813428329912
target metric: 0.6067371490637774
loss: 0.7413584629694621, accuracy: 0.6076923076923076


100%|██████████| 8/8 [00:15<00:00,  1.98s/it]


Validation:
Counter({1: 1131, 0: 849})
ROC AUC metric: 0.6048687423687424
target metric: 0.6430706157944789
target metric fixed: 0.6384588541986868
loss: 0.736591286957264, accuracy: 0.5954545454545455
********************************************************************************
epoch 3/5


100%|██████████| 15/15 [00:28<00:00,  1.87s/it]


Train:
Counter({0: 1855, 1: 1785})
ROC AUC metric: 0.6614483456104335
target metric: 0.5903627963835403
loss: 0.7001996994018554, accuracy: 0.6145604395604396


100%|██████████| 8/8 [00:14<00:00,  1.86s/it]


Validation:
Counter({0: 1147, 1: 833})
ROC AUC metric: 0.5459996947496948
target metric: 0.7079361131188633
target metric fixed: 0.7068193713963697
loss: 0.9407919272780418, accuracy: 0.4772727272727273
********************************************************************************
epoch 4/5


100%|██████████| 15/15 [00:27<00:00,  1.85s/it]


Train:
Counter({1: 1828, 0: 1812})
ROC AUC metric: 0.6644209636517329
target metric: 0.5934062566306606
loss: 0.7060884277025858, accuracy: 0.6175824175824176


100%|██████████| 8/8 [00:15<00:00,  1.95s/it]


Validation:
Counter({0: 1238, 1: 742})
ROC AUC metric: 0.5118406593406594
target metric: 0.7343959576593952
target metric fixed: 0.7381227353187751
loss: 0.9951399564743042, accuracy: 0.44343434343434346
********************************************************************************
epoch 5/5


100%|██████████| 15/15 [00:27<00:00,  1.86s/it]


Train:
Counter({0: 1821, 1: 1819})
ROC AUC metric: 0.6591725033208549
target metric: 0.5841782871692399
loss: 0.7223800778388977, accuracy: 0.6063186813186813


100%|██████████| 8/8 [00:15<00:00,  1.88s/it]


Validation:
Counter({1: 1261, 0: 719})
ROC AUC metric: 0.5915506715506715
target metric: 0.6558949888155032
target metric fixed: 0.6613030546518286
loss: 0.7192556038498878, accuracy: 0.6388888888888888
Best validation metric: 0.6384588541986868
[0.642027776187675, 0.6156763463174704, 0.5805691482270264, 0.6510807591842562, 0.65766523668334, 0.6384588541986868]
0.6309130201330758
Full validation metric: 0.6356833352527698


In [13]:
# [0.653830656159886, 0.6238455839150563, 0.605778830856576, 0.6647636130357559, 0.6544317172083116, 0.6532475492188892]
# 0.6426496583990792
# Full validation metric: 0.6421152903205289


# [0.6455449287371509, 0.6218588657257058, 0.5979860325098842, 0.6420611577083772, 0.675211714053326, 0.6503222829996482]
# 0.6388308302890154
# Full validation metric: 0.6404691832962497

#64 5e3
# [0.6476080393945753, 0.6079157320890402, 0.6050800143556536, 0.6378247315988896, 0.6661450085369098, 0.6504032066047876]
# 0.635829455429976
# Full validation metric: 0.6342736651767453

#64 1e3
# Best validation metric: 0.6557763951171189
# [0.651014691249416, 0.6347137845717545, 0.6147587091636446, 0.6665997338703076, 0.6587241317996946, 0.6557763951171189]
# 0.6469312409619894
# Full validation metric: 0.6467345027007692

In [14]:
import os


# Collect the names of all saved checkpoints (files ending in '.h5') from the
# models directory, in whatever order the filesystem returns them.
model_files = []
for file_name in os.listdir('/kaggle/working/models'):
    if file_name.endswith('.h5'):
        model_files.append(file_name)
print(model_files[:5])

['center_id_6.2.8.9_epoch_3_target_0.673.h5', 'center_id_7_epoch_0_target_0.638.h5', 'center_id_6.2.8.9_epoch_2_target_0.68.h5', 'center_id_4_epoch_4_target_0.616.h5', 'center_id_1.5_epoch_4_target_0.679.h5']


In [15]:
from collections import defaultdict


# Group the saved checkpoints by center id, using the metadata encoded in each
# file name:
#   center_id_<center>_epoch_<epoch>_target_<metric>.h5
centers_to_models = defaultdict(list)
for model_file in model_files:
    # Split once and reuse the pieces. The center id may contain dots
    # (e.g. '6.2.8.9'); the parsing relies on it containing no underscores.
    parts = model_file.split('_')
    center_id = parts[2]
    epoch = int(parts[4])
    metric = float(parts[6][:-3])  # drop the trailing '.h5'

    centers_to_models[center_id].append((metric, epoch, model_file))

# For every center keep the checkpoint with the lowest target metric, breaking
# ties by the lowest epoch number. min() returns the first minimal item it
# encounters, so any remaining ties resolve in listing order.
# NOTE: the metric comes from the file name, where it appears rounded
# (e.g. '0.68'), so the selection is only as precise as that rounding.
center_id_to_best_model_file_name = {
    center_id: min(candidates, key=lambda candidate: (candidate[0], candidate[1]))[2]
    for center_id, candidates in centers_to_models.items()
}
print(center_id_to_best_model_file_name)

{'6.2.8.9': 'center_id_6.2.8.9_epoch_1_target_0.638.h5', '7': 'center_id_7_epoch_3_target_0.581.h5', '4': 'center_id_4_epoch_4_target_0.616.h5', '1.5': 'center_id_1.5_epoch_4_target_0.651.h5', '10.3': 'center_id_10.3_epoch_3_target_0.658.h5', '11': 'center_id_11_epoch_2_target_0.642.h5'}


In [16]:
# Flatten the per-center mapping into the list of checkpoint file names to keep.
good_models = [*center_id_to_best_model_file_name.values()]
print(good_models)

['center_id_6.2.8.9_epoch_1_target_0.638.h5', 'center_id_7_epoch_3_target_0.581.h5', 'center_id_4_epoch_4_target_0.616.h5', 'center_id_1.5_epoch_4_target_0.651.h5', 'center_id_10.3_epoch_3_target_0.658.h5', 'center_id_11_epoch_2_target_0.642.h5']


In [17]:
# Delete every saved checkpoint that was not selected as the best for its center.
models_to_keep = set(good_models)  # set gives O(1) membership checks
for model_file in model_files:
    if model_file in models_to_keep:
        continue
    try:
        os.remove(os.path.join('/kaggle/working/models', model_file))
    except FileNotFoundError:
        # `model_files` is an in-memory listing taken earlier; if this cell is
        # re-run the file is already gone, which is the desired end state.
        pass