# Vesuvius Challenge - Ink Detection: Submission

In [65]:
import wandb
from glob import glob
from os import sep
from os.path import join, abspath
import os

import cv2
import numpy as np
import pandas as pd
import torch

MODELS_DIR = join(os.pardir, 'models')
KAGGLE_INPUT_DIR = join(abspath(sep), 'kaggle', 'input')
TEST_FRAGMENTS_PATH = join(KAGGLE_INPUT_DIR, 'vesuvius-challenge-ink-detection', 'test')

## Vesuvius WandB

In [23]:
df = pd.read_csv(join(os.pardir, 'data', 'raw', 'wandb', 'wandb_export.csv'))
df.set_index('ID', inplace=True)

## Utils

In [66]:
def reconstruct_output(tiles, bboxes, fragment_id, fragment_shape, tile_size):
    reconstructed_output = torch.zeros(fragment_shape).to(device=tiles.device)
    count_map = torch.zeros(fragment_shape).to(device=tiles.device)

    for i in range(tiles.shape[0]):
        x0, y0, x1, y1 = bboxes[i]
        reconstructed_output[y0:y1, x0:x1] += tiles[i, :, :]
        count_map[y0:y1, x0:x1] += 1

    reconstructed_output /= count_map
    reconstructed_output = torch.nan_to_num(reconstructed_output, nan=0)

    mask_path = os.path.join(TEST_FRAGMENTS_PATH, fragment_id, 'mask.png')
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    padding = get_padding(mask.shape, tile_size)

    shape = reconstructed_output.shape
    x0, y0, x1, y1 = padding[1][0], padding[0][0], shape[1] - padding[1][1], shape[0] - padding[0][1]
    reconstructed_output = reconstructed_output[y0:y1, x0:x1]

    return reconstructed_output


def get_fragment_shape(fragment_dir, fragment_id, tile_size):
    mask_path = os.path.join(fragment_dir, fragment_id, 'mask.png')
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    padding = get_padding(mask.shape, tile_size)
    mask_pad = np.pad(mask, padding)

    return mask_pad.shape


def get_padding(mask_shape, tile_size, overlap=0.5):
    pad_left = int(overlap * tile_size)
    pad_up = int(overlap * tile_size)
    pad_right = int(overlap * tile_size + tile_size - mask_shape[1] % tile_size)
    pad_down = int(overlap * tile_size + tile_size - mask_shape[0] % tile_size)
    padding = [(pad_up, pad_down), (pad_left, pad_right)]

    return padding


def get_device():
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif torch.backends.mps.is_available():
        device = 'mps'

    return torch.device(device=device)

## Vesuvius Lightning

In [54]:
import cv2

import torch.nn as nn
from torch.optim import AdamW
import pytorch_lightning as pl

from src.models.losses import BCEDiceWithLogitsLoss
from src.models.metrics import F05Score
from src.models.unet3d import Unet3d
from src.models.efficienunetv2 import EfficientUNetV2_L, EfficientUNetV2_M, EfficientUNetV2_S

import segmentation_models_pytorch as smp


class LightningVesuvius(pl.LightningModule):
    def __init__(self, model_name, model_params, learning_rate, bce_weight, dice_threshold):
        super().__init__()

        # Model
        if model_name == 'UNet3D':
            self.model = Unet3d(**model_params)
        elif model_name == 'EfficientUNetV2_L':
            self.model = EfficientUNetV2_L(**model_params)
        elif model_name == 'EfficientUNetV2_M':
            self.model = EfficientUNetV2_M(**model_params)
        elif model_name == 'EfficientUNetV2_S':
            self.model = EfficientUNetV2_S(**model_params)
        elif model_name == 'efficientnet-b5':
            self.model = smp.Unet(**model_params)

        self.learning_rate = learning_rate
        self.criterion = BCEDiceWithLogitsLoss(bce_weight=bce_weight, dice_threshold=dice_threshold)
        self.metric = F05Score()
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        x = self.model(inputs)

        return x

    def training_step(self, batch, batch_idx):
        _, _, masks, images = batch
        outputs = self.forward(images)
        loss = self.criterion(outputs, masks)
        self.log('train/loss', loss, on_step=True, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        _, _, masks, images = batch
        outputs = self.forward(images)
        # outputs = torch.squeeze(outputs, dim=1)
        loss = self.criterion(outputs, masks)
        self.log('val/loss', loss, on_epoch=True)
        outputs = self.sigmoid(outputs)
        self.metric.update(masks, outputs)

        return loss

    def on_validation_epoch_end(self) -> None:
        sub_f05_threshold, sub_f05_score = self.metric.compute()

        metrics = {
            # 'val/f05_threshold': f05_threshold,
            # 'val/f05_score': f05_score,
            'val/sub_f05_threshold': sub_f05_threshold,
            'val/sub_f05_score': sub_f05_score
        }

        self.log_dict(metrics, on_epoch=True)
        # self.logger.log_image(key="val/inklabels_prediction", images=[reconstructed_pred])

        self.metric.reset()

        return metrics

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)

        return optimizer

## Vesuvius Dataset

In [77]:
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import transforms as T

class DatasetVesuvius(Dataset):
    def __init__(self, fragments, tile_size, num_slices, slices_list, start_slice, reverse_slices, selection_thr, augmentation, device, overlap):
        self.fragments = fragments
        self.tile_size = tile_size
        self.num_slices = num_slices
        self.slices_list = slices_list
        self.start_slice = start_slice
        self.reverse_slices = reverse_slices
        self.selection_thr = selection_thr
        self.augmentation = augmentation
        self.device = device

        self.overlap = overlap
        self.set_path = TEST_FRAGMENTS_PATH
        self.slices = self.make_slices()
        self.data, self.items = self.make_data()

        self.transforms = T.RandomApply(
            nn.ModuleList([
                T.RandomRotation(180),
                T.RandomPerspective(),
                T.ElasticTransform(alpha=500.0, sigma=10.0),
                T.RandomHorizontalFlip(),
                T.RandomVerticalFlip()
            ]), p=0.5
        )

    def make_slices(self):
        total_slices = 65
        slices = [i for i in range(total_slices)]

        if self.slices_list:
            slices = self.slices_list
        else:
            slices = sorted(slices[self.start_slice:self.start_slice+self.num_slices], reverse=self.reverse_slices)

        return slices

    def make_mask(self, fragment_path):
        mask_path = os.path.join(fragment_path, 'mask.png')
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        shape = (self.num_slices, mask.shape[0], mask.shape[1])
        padding = get_padding(mask.shape, self.tile_size)
        mask_pad = np.pad(mask, padding)

        return mask_pad, shape, padding

    def make_image(self, fragment_path, shape, padding):
        image = np.zeros(shape=shape, dtype=np.uint8)
        slices_files = sorted(glob(os.path.join(fragment_path, 'surface_volume/*.tif')))
        slices_path = [slices_files[i] for i in self.slices]

        print(f'\nMake image from {fragment_path}')
        for i, slice_path in tqdm(enumerate(slices_path), total=len(slices_path)):
            image[i, ...] = cv2.imread(slice_path, cv2.IMREAD_GRAYSCALE)

        padding.insert(0, (0, 0))
        image_pad = np.pad(image, padding)

        return image_pad

    def create_items(self, fragment, mask_pad):
        items = []
        overlap_size = int(self.overlap * self.tile_size)
        x_list = np.arange(0, mask_pad.shape[1] - overlap_size, overlap_size).tolist()
        y_list = np.arange(0, mask_pad.shape[0] - overlap_size, overlap_size).tolist()

        for x in x_list:
            for y in y_list:
                bbox = torch.IntTensor([x, y, x + self.tile_size, y + self.tile_size])
                x0, y0, x1, y1 = bbox
                tile = mask_pad[y0:y1, x0:x1]

                if tile.sum() / (255 * self.tile_size ** 2) >= self.selection_thr:
                    items.append({'fragment': fragment, 'bbox': bbox})

        return items

    def make_data(self):
        data = {}
        items = []

        for fragment in self.fragments:
            fragment_path = os.path.join(self.set_path, str(fragment))
            mask_pad, shape, padding = self.make_mask(fragment_path)
            image_pad = self.make_image(fragment_path, shape, padding)
            items += self.create_items(fragment, mask_pad)

            data[fragment] = {
                'image': torch.from_numpy(image_pad).to(self.device)
            }

        return data, items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        fragment, bbox = self.items[idx]['fragment'], self.items[idx]['bbox']
        x0, y0, x1, y1 = bbox
        image = self.data[fragment]['image'][:, y0:y1, x0:x1] / 255.0

        return fragment, bbox, image


## Vesuvius Prediction

Calcul weight for each kfold models

In [49]:
df_weight = df.groupby('start_slice')['val/sub_f05_score'].mean()
df_weight /= df_weight.sum()
display(df_weight)
df_weight.sum()

start_slice
0     0.153861
8     0.154643
16    0.172690
24    0.184504
32    0.180550
40    0.153751
Name: val/sub_f05_score, dtype: float64

1.0000000000000002

Get dataset parameters for a given cross validation

In [63]:
def get_dataset_parameters(start_slice, df):
    device = get_device()
    dataset_params = {
        "fragments": ['a', 'b'],
        "tile_size": df['tile_size'][0],
        "num_slices": df['num_slices'][0],
        "start_slice": start_slice,
        "reverse_slices": df['reverse_slices'][0],
        "selection_thr": 0,
        "augmentation": False,
        "device": device,
        "overlap": 0.5,
    }
    return dataset_params

get_dataset_parameters(16, df)

{'fragments': ['a', 'b'],
 'tile_size': 256,
 'num_slices': 16,
 'start_slice': 16,
 'reverse_slices': False,
 'selection_thr': 0,
 'augmentation': False,
 'device': device(type='mps'),
 'overlap': 0.5}

In [39]:
def find_checkpoint(id):
    return glob(join(os.pardir, MODELS_DIR, f'*{id}*.ckpt'))[0]

find_checkpoint('u21mnb2m')

'../models/vibrant-sweep-17-u21mnb2m-32-16-3.ckpt'

Get lightning model parameters

In [45]:
def get_lightning_parameters(id, df):
    lightning_params = {
        "model_name":  'efficientnet-b5',
        "model_params":  {
            "in_channels": df.loc[id, 'num_slices'],
            "encoder_weights": df.loc[id, 'encoder_weights'],
            "classes": 1
        },
        "learning_rate":  df.loc[id, 'learning_rate'],
        "bce_weight":  df.loc[id, 'bce_weight'],
        "dice_threshold":  df.loc[id, 'dice_threshold'],    
    }

    return lightning_params

get_lightning_parameters('u21mnb2m', df)

{'model_name': 'efficientnet-b5',
 'model_params': {'in_channels': 16,
  'encoder_weights': 'imagenet',
  'classes': 1},
 'learning_rate': 0.0001,
 'bce_weight': 0.5,
 'dice_threshold': 0.5}

In [None]:
def get_blank_masks(dataset: DatasetVesuvius):
    blank_masks = {fragment_id: {} for fragment_id in dataset.fragments}
    
    for fragment_id in dataset.fragments:
        inklabels_shape = get_fragment_shape(dataset.set_path, fragment_id, dataset.tile_size)
        blank_masks[fragment_id]['inklabels'] = torch.zeros(inklabels_shape).to(device=dataset.device)
        blank_masks[fragment_id]['count_map'] = torch.zeros(inklabels_shape).to(device=dataset.device)
    
    return blank_masks

In [None]:
def add_inklabels(masks, fragment_id, bbox, inklabels):
    x0, y0, x1, y1 = bbox
    masks[fragment_id]['inklabels'][y0:y1, x0:x1] += inklabels
    masks[fragment_id]['count_map'][y0:y1, x0:x1] += 1
    return masks

In [None]:
def make_sub_prediction(dataset: DatasetVesuvius, model, weight):
    masks = get_blank_masks(dataset.fragments)
    sigmoid = torch.nn.Sigmoid()
    model.eval()
    
    for fragment_id, bbox, image in dataset:
        sub_inklabels = model(image)
        sub_inklabels = sigmoid(sub_inklabels)
        masks = add_inklabels(masks, fragment_id, bbox, sub_inklabels)
    
    inklabels = {}
    for fragment_id in dataset.fragments:
        inklabels[fragment_id] = masks[fragment_id]['inklabels'] / masks[fragment_id]['count_map']
        inklabels[fragment_id] = torch.nan_to_num(inklabels[fragment_id], nan=0)
        inklabels[fragment_id] /= weight
    
    return inklabels

In [None]:
def get_blank_inklabels(fragments, fragment_path):
    blank_inklabels = {}
    
    for fragment_id in fragments:
        blank_inklabels[fragment_id] = cv2.imread(join(os.pardir, fragment_path, fragment_id, 'mask.png'))
    
    return blank_inklabels

masks = get_blank_inklabels(['a', 'b'], TEST_FRAGMENTS_PATH)
masks['']

In [46]:
model = None
model_temoins = None

inklabels_shape = {get_fragment_shape(TEST_FRAGMENTS_PATH, fragment_id, df['tile_size'][0])}
inklabels = {fragment_id: torch.zeros(inklabels_shape).to(device=dataset.device)}

for start_slice in df_weight.items():
    dataset_params = get_dataset_parameters(start_slice, df)
    dataset = DatasetVesuvius(**dataset_params)
    
    for run_id in df.index.to_numpy():
        torch.cuda.empty_cache()
        
        lightning_params = get_lightning_parameters(run_id, df)
        model_ligthning = LightningVesuvius(**lightning_params)
        
        model_path = find_checkpoint(run_id)
        model_ligthning.load_from_checkpoint(
            model_path, 
            map_location='cpu',
            **lightning_params
        )
        model_pytorch = model_ligthning.model
        
        sub_inklabels = make_sub_prediction(dataset, model_pytorch)
        
        for fragment_id in dataset.fragments:
            inklabels[fragment_id] += sub_inklabels[fragment_id]
        break
    break