In [None]:
import pandas as pd
import os
from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
import cv2
import warnings
import glob
from torch.utils.data import Dataset, DataLoader

from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

import gc
import shutil
import zipfile
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.cuda import amp
from torchinfo import summary
from torchmetrics import MeanMetric, MulticlassAccuracy, AveragePrecision

import albumentations as A
from albumentations.pytorch import ToTensorV2

from PIL import Image
from tqdm.notebook import tqdm
import time
import base64
import typing as t
import zlib

In [None]:

# Stitch every whole-slide image (WSI) back together from its 512x512 test
# tiles and save it as big_ones/<source_wsi>.jpg, so later cells can crop
# context windows larger than a single tile.
os.makedirs("big_ones", exist_ok=True)  # exist_ok: re-running the cell must not fail

# Load the tile metadata from the CSV file.
tile_meta_path = "/kaggle/input/hubmap-hacking-the-human-vasculature/tile_meta.csv"
tile_metadata = pd.read_csv(tile_meta_path)

# Group by a scalar key: with a one-element list key, pandas >= 2 yields tuple
# keys such as (1,), producing file names like "(1,).jpg" that the dataset's
# lookup by str(source_wsi) can never find.
for wsi, group in tile_metadata.groupby('source_wsi'):
    # Shift tile coordinates so the top-left tile of this WSI sits at (0, 0).
    group = group.sort_values(by=['i', 'j'])
    group['i'] = group['i'] - group['i'].min()
    group['j'] = group['j'] - group['j'].min()
    max_x = group['i'].max()
    max_y = group['j'].max()

    # uint8 canvas: the tiles are read as 8-bit images, and this needs 1/8 of
    # the memory of NumPy's float64 default.
    big_one = np.zeros((max_y + 512, max_x + 512, 3), dtype=np.uint8)

    # 'i' is the horizontal (column) offset, 'j' the vertical (row) offset.
    for _, tile in tqdm(group.iterrows(), total=len(group)):
        path = '/kaggle/input/hubmap-hacking-the-human-vasculature/test/' + tile['id'] + '.tif'

        # Only tiles present in the test split are pasted.
        if not os.path.isfile(path):
            continue
        img = cv2.imread(path)
        if img is None:  # unreadable file: leave that area black
            continue
        x, y = tile['i'], tile['j']
        big_one[y:y + 512, x:x + 512, :] = img

    # Skip WSIs that have no test tiles at all (canvas still entirely black).
    if big_one.any():
        large_image_path = os.path.join("big_ones", str(wsi) + '.jpg')
        cv2.imwrite(large_image_path, big_one)

# Independent copy of the metadata for further processing (same content as
# re-reading the CSV).
meta = tile_metadata.copy()

In [None]:
def preprocess_tile_data(data):
    """
    Compute, for every tile of one WSI, the crop window used for inference.

    Each tile gets up to 512 px of context on every side, i.e. a crop of up
    to 1536x1536 px of the stitched WSI. Near the borders the window is
    shifted inwards so it stays on the stitched canvas, and the tile's offset
    inside the crop is adjusted accordingly.

    Args:
        data (pd.DataFrame): Tile metadata of a single WSI with integer
            columns 'i' (horizontal) and 'j' (vertical). Not modified.

    Returns:
        pd.DataFrame: A copy of ``data`` where 'i'/'j' are relative to the
        smallest coordinate, plus 'i_end'/'j_end' (largest relative
        coordinate), 'big_i_start'/'big_i_end' and 'big_j_start'/'big_j_end'
        (crop window in the stitched image) and 'i_start'/'j_start' (tile
        offset inside that crop).
    """
    # Work on a copy: pandas does not support mutating the object passed to
    # GroupBy.apply.
    data = data.copy()

    # Coordinates relative to the WSI's top-left tile.
    data['i'] = data['i'] - data['i'].min()
    data['j'] = data['j'] - data['j'].min()

    # Default window: 512 px of context before and after the tile.
    data['i_end'] = data['i'].max()
    data['j_end'] = data['j'].max()
    data['big_i_start'] = data['i'] - 512
    data['big_j_start'] = data['j'] - 512
    data['i_start'] = 512
    data['j_start'] = 512
    data['big_i_end'] = data['i'] + 1024
    data['big_j_end'] = data['j'] + 1024

    for axis in ('i', 'j'):
        big_start, big_end = f'big_{axis}_start', f'big_{axis}_end'
        offset, last = f'{axis}_start', f'{axis}_end'

        # Tiles at the leading border: slide the window forward by 512 px.
        # Each mask is computed once, before the columns it reads change.
        leading = data[big_start] < 0
        data.loc[leading, big_end] += 512
        data.loc[leading, offset] -= 512
        data.loc[leading, big_start] += 512

        # Tiles at the trailing border (the stitched canvas ends at
        # last + 512): slide the window back by 512 px.
        trailing = (data[big_end] - data[last]) > 512
        data.loc[trailing, big_start] -= 512
        data.loc[trailing, offset] += 512
        data.loc[trailing, big_end] -= 512

        # WSIs spanning fewer than three tiles can still end up with a
        # negative start, which NumPy slicing would wrap around. Clamp the
        # window to the canvas and move the tile offset by the same amount.
        data[offset] = data[offset] + data[big_start].clip(upper=0)
        data[big_start] = data[big_start].clip(lower=0)

    return data

# Re-load the tile metadata and compute each tile's crop window, one WSI at a
# time.
tile_meta_path = "/kaggle/input/hubmap-hacking-the-human-vasculature/tile_meta.csv"
tile_metadata = pd.read_csv(tile_meta_path)
df = tile_metadata.groupby('source_wsi').apply(preprocess_tile_data)

# Columns kept for the dataset, and the names they are exposed under
# (position k of one list corresponds to position k of the other).
selected_columns = [
    'source_wsi', 'id', 'big_i_start', 'big_j_start', 'i_start', 'j_start',
    'big_i_end', 'big_j_end', 'i', 'j', 'i_end', 'j_end'
]
renamed_columns = [
    'source_wsi', 'id', 'big_image_i_start', 'big_image_j_start', 'image_i_start', 'image_j_start',
    'big_image_i_end', 'big_image_j_end', 'i', 'j', 'i_end', 'j_end'
]
df = df[selected_columns].set_axis(renamed_columns, axis=1)


In [None]:

class DroneDataset(Dataset):
    """
    Inference dataset over the test tile images.

    Tiles listed in ``df`` are returned with their surrounding context,
    cropped from the stitched WSI images in ``/kaggle/working/big_ones``.
    Tiles missing from ``df`` are read on their own.
    """

    def __init__(self, df, processor, preds):
        """
        Initialize the dataset and load every stitched WSI image into memory.

        Args:
            df (pd.DataFrame): Per-tile crop windows ('id', 'source_wsi',
                'big_image_{i,j}_{start,end}', 'image_{i,j}_start', ...).
            processor (SegformerImageProcessor): Processor used to rescale
                and normalize images.
            preds (list): Paths of the tile images; one item per path.
        """
        self.df = df
        self.path = preds
        self.processor = processor
        # Stitched WSIs keyed by file name without extension (str(source_wsi)).
        self.big_one = {}
        for path in glob.glob("/kaggle/working/big_ones/*"):
            ids = os.path.basename(path).split(".")[0]
            # Read the globbed absolute path, so loading does not depend on
            # the current working directory.
            self.big_one[ids] = cv2.imread(path)

    def __len__(self):
        """Return the number of tile image paths."""
        return len(self.path)

    def __getitem__(self, index):
        """
        Load and preprocess one tile, with context when available.

        Args:
            index (int): Index into the tile paths.

        Returns:
            tuple: ``(ids, image, x1, y1)``: the tile id, the preprocessed
            image tensor (batch axis squeezed out), and the tile's row
            (``x1``) and column (``y1``) offset inside that image. Both
            offsets are 0 when the tile is not in ``df``.
        """
        path = self.path[index]
        ids = os.path.basename(path).split(".")[0]
        info = self.df.loc[self.df['id'] == ids]

        x1, y1 = 0, 0
        if info.shape[0] > 0:
            window = info[["source_wsi", "big_image_i_start", "big_image_j_start",
                           "big_image_i_end", "big_image_j_end",
                           "image_i_start", "image_j_start"]].values[0]
            wsi, col_start, row_start, col_end, row_end, col_offset, row_offset = window
            # The stitched canvas is indexed [row, col]: 'j' coordinates are
            # rows and 'i' coordinates are columns.
            image = self.big_one[str(wsi)][row_start:row_end, col_start:col_end, :]
            x1, y1 = row_offset, col_offset
        else:
            image = cv2.imread(path)

        encoded_inputs = self.processor.preprocess(
                images=image,
                return_tensors="pt", # Return pytorch tensor.
                do_rescale=True,
                rescale_factor=1.0 / 255,
                do_normalize=True,
                image_mean=(0.485, 0.456, 0.406),
                image_std=(0.229, 0.224, 0.225),
                resample=2,
            )
        image = encoded_inputs["pixel_values"].squeeze_()
        return ids, image, x1, y1

# Rescaling and normalization are requested per call inside
# DroneDataset.__getitem__, so the processor is created with every step off.
processor = SegformerImageProcessor(do_resize=False, do_rescale=False, do_normalize=False)

# One dataset item per test tile image.
dataset = DroneDataset(
    df,
    processor,
    glob.glob("/kaggle/input/hubmap-hacking-the-human-vasculature/test/*"),
)

# Show which stitched WSIs were loaded.
keys = dataset.big_one.keys()
print(keys)

In [None]:

def get_default_device():
    """
    Pick the device PyTorch should run on.

    Returns:
        tuple: ``(device, gpu_available)``. ``device`` is ``cuda`` when a GPU
        is visible, otherwise ``cpu``; ``gpu_available`` is the matching flag.
    """
    if torch.cuda.is_available():
        return torch.device('cuda'), True
    return torch.device('cpu'), False

def seed_everything(seed_value):
    """
    Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) random
    number generators, and make cuDNN use deterministic algorithms.

    Args:
        seed_value (int): Seed value.
    """
    # ``random`` is not imported at the top of the notebook, so the call below
    # would otherwise raise NameError.
    import random

    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    # Trade cuDNN autotuning speed for reproducible results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Define data configurations

@dataclass
class DatasetConfig:
    """
    Data configuration, read as class attributes (e.g. ``DatasetConfig.IMG_WIDTH``).

    The image lists are tuples: ``@dataclass`` rejects mutable (list) defaults
    with a ValueError at class-creation time. Tuples also keep the values
    readable as class attributes.
    """
    NUM_CLASSES: int = 4
    IMG_WIDTH: int = 512
    IMG_HEIGHT: int = 512
    DATA_TRAIN_IMAGES: tuple = (
        "/kaggle/input/final-dataset/big_ones/1.jpg",
        "/kaggle/input/final-dataset/big_ones/2.jpg",
        "/kaggle/input/final-dataset/big_ones/3.jpg",
    )
    DATA_VALID_IMAGES: tuple = (
        '/kaggle/input/final-dataset/big_ones/4.jpg',
    )
    # Per-channel normalization statistics.
    MEAN: tuple = (0.485, 0.456, 0.406)
    STD: tuple = (0.229, 0.224, 0.225)
    BACKGROUND_CLS_ID: int = 0

# Define training configuration

@dataclass
class TrainingConfig:
    """
    Training hyper-parameters, read as class attributes
    (e.g. ``TrainingConfig.LEARNING_RATE``).
    """
    BATCH_SIZE: int = 1
    NUM_EPOCHS: int = 1
    LEARNING_RATE: float = 0.0001   # AdamW learning rate
    NUM_WORKERS: int = 1            # passed to get_dataloader as num_workers
    WEIGHT_DECAY: float = 0.0001    # AdamW weight decay

# Define inference configuration

@dataclass
class InferenceConfig:
    """Inference settings, read as class attributes (not read by the code shown here)."""
    BATCH_SIZE: int = 1
    NUM_BATCHES: int = 1

# Define model configuration

@dataclass
class ModelConfig:
    """
    Model settings. ``MODEL_NAME`` is only recorded in the hyper-parameter
    dict; ``get_model()`` loads weights from a local directory instead.
    """
    MODEL_NAME: str = "nvidia/segformer-b3-finetuned-ade-512-512"

# Define custom functions for image and mask processing

def get_bounding_box(ground_truth_map):
    """
    Return the tight bounding box of the positive pixels of a 2-D mask.

    Args:
        ground_truth_map (np.ndarray): 2-D mask; pixels > 0 are foreground.

    Returns:
        list: ``[x_min, y_min, x_max, y_max]``, where x is the column and
        y the row index (inclusive).

    Raises:
        ValueError: If the mask has no positive pixel.
    """
    rows, cols = np.nonzero(ground_truth_map > 0)
    return [cols.min(), rows.min(), cols.max(), rows.max()]

def rle_decode(mask_rle, shape, color=1):
    """
    Decode a run-length encoded mask.

    The string holds whitespace-separated ``start length`` pairs with 1-based
    starts over the flattened mask; the flat buffer is reshaped row-major.

    Args:
        mask_rle (str): Run-length encoded mask.
        shape (tuple): Output shape, ``(H, W)`` or ``(H, W, C)``. For three
            dimensions each run fills whole ``C``-wide pixel rows.
        color (int, optional): Value written into the runs. Defaults to 1.

    Returns:
        np.ndarray: float32 array of the requested shape.
    """
    tokens = mask_rle.split()
    run_starts = [int(tok) - 1 for tok in tokens[0::2]]
    run_lengths = [int(tok) for tok in tokens[1::2]]

    pixel_count = shape[0] * shape[1]
    flat_shape = (pixel_count, shape[2]) if len(shape) == 3 else pixel_count
    canvas = np.zeros(flat_shape, dtype=np.float32)
    for begin, length in zip(run_starts, run_lengths):
        canvas[begin:begin + length] = color
    return canvas.reshape(shape)

def get_box(a_mask):
    """
    Return the bounding box of the truthy entries of a mask.

    Only the first two axes are used (rows, then columns), so masks with
    extra trailing axes are accepted.

    Args:
        a_mask (np.ndarray): Mask with at least two dimensions.

    Returns:
        list: ``[x_min, y_min, x_max, y_max]`` (x = column, y = row).
    """
    coords = np.nonzero(a_mask)
    rows, cols = coords[0], coords[1]
    return [cols.min(), rows.min(), cols.max(), rows.max()]


# Class ID -> RGB colour used to render segmentation masks.
id2color = {
    0: (0, 0, 0),
    1: (255, 0, 0),
    2: (0, 255, 0),
    3: (0, 0, 255),
}

# Keep the configured number of classes in sync with the colour map.
DatasetConfig.NUM_CLASSES = len(id2color)

# Inverse mapping (RGB tuple -> class ID), used to turn an RGB mask into a
# single-channel class-ID mask.
rev_id2color = {rgb: class_id for class_id, rgb in id2color.items()}
def rgb_to_grayscale(rgb_arr, color_map=rev_id2color, background_cls_id=0):
    """
    Convert an RGB mask into a 2-D array of class IDs.

    Args:
        rgb_arr (np.ndarray): Mask whose last axis holds the 3 colour
            channels, presumably (H, W, 3).
        color_map (dict): RGB tuple -> class ID.
        background_cls_id (int): ID used for colours missing from ``color_map``.

    Returns:
        np.ndarray: Array of class IDs with shape ``rgb_arr.shape[:2]``.
    """
    pixels = rgb_arr.reshape((-1, 3))
    # Look every distinct colour up once, then scatter the IDs back to all
    # pixels through the inverse index returned by np.unique.
    palette, pixel_to_palette = np.unique(pixels, axis=0, return_inverse=True)
    palette_ids = np.array(
        [color_map.get(tuple(colour), background_cls_id) for colour in palette]
    )
    class_ids = palette_ids[pixel_to_palette]
    return class_ids.reshape(rgb_arr.shape[:2])
def num_to_rgb(num_arr, color_map=id2color):
    """
    Render a class-ID mask as an RGB image.

    Args:
        num_arr (np.ndarray): Class-ID mask; singleton axes are squeezed and
            the output takes its first two dimensions.
        color_map (dict): Class ID -> RGB colour (0..255).

    Returns:
        np.ndarray: float32 array of shape ``num_arr.shape[:2] + (3,)`` with
        values in [0.0, 1.0].
    """
    labels = np.squeeze(num_arr)
    rgb = np.zeros(num_arr.shape[:2] + (3,))
    for class_id, colour in color_map.items():
        rgb[labels == class_id] = colour
    return np.float32(rgb) / 255.0
def image_overlay(image, segmented_image):
    """
    Blend a colour segmentation map on top of an image.

    Computes ``image * 1.0 + segmented_image * 0.7`` with ``cv2.addWeighted``
    and clips the result to [0.0, 1.0].

    Args:
        image (np.ndarray): RGB image in a dtype accepted by ``cv2.cvtColor``.
        segmented_image (np.ndarray): RGB segmentation map, same shape/dtype.

    Returns:
        np.ndarray: Blended RGB image clipped to [0.0, 1.0].
    """
    image_weight, mask_weight, offset = 1.0, 0.7, 0.0

    mask_bgr = cv2.cvtColor(segmented_image, cv2.COLOR_RGB2BGR)
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    blended = cv2.addWeighted(image_bgr, image_weight, mask_bgr, mask_weight, offset, image_bgr)
    return np.clip(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB), 0.0, 1.0)
def display_image_and_mask(*, images, masks, color_mask=False, color_map=id2color):
    """
    For each sample, show the image, its class-ID mask, the coloured mask and
    the overlay side by side.

    Args:
        images: Images indexed along the first axis (must expose ``.shape``).
        masks: Class-ID masks matching ``images``.
        color_mask (bool): Unused; kept for call compatibility.
        color_map (dict): Class ID -> RGB colour used for the coloured mask.
    """
    titles = ["GT Image", "GT Mask", "Color Mask", "Overlayed Mask"]

    for idx in range(images.shape[0]):
        image = images[idx]
        grayscale_gt_mask = masks[idx]

        # RGB rendering of the class-ID mask and its blend with the image.
        rgb_gt_mask = num_to_rgb(grayscale_gt_mask, color_map=color_map)
        overlayed_image = image_overlay(image, rgb_gt_mask)

        fig, axes = plt.subplots(1, 4, figsize=(15, 4))
        panels = [
            (image, None),
            (grayscale_gt_mask, "gray"),
            (rgb_gt_mask, None),
            (overlayed_image, None),
        ]
        # Each panel is drawn exactly once.
        for ax, title, (panel, cmap) in zip(axes, titles, panels):
            ax.set_title(title)
            ax.imshow(panel, cmap=cmap)
            ax.axis("off")

        plt.show()
        # Release the figure so looping over many samples does not pile them up.
        plt.close(fig)
def get_model():
    """
    Build the SegFormer segmentation model with a 4-class head.

    Weights are loaded from the local directory ``/kaggle/input/refbsrb/abc``.
    ``ignore_mismatched_sizes=True`` tolerates checkpoint weights whose shapes
    differ from the model's, such as a head saved with another label count.
    """
    checkpoint_dir = "/kaggle/input/refbsrb/abc"
    return SegformerForSemanticSegmentation.from_pretrained(
        checkpoint_dir, num_labels=4, ignore_mismatched_sizes=True
    )
# Build the model and load the cross-entropy checkpoint, mapping its tensors
# to the CPU.
model = get_model()
model.load_state_dict(
    torch.load(
        "/kaggle/input/weights-1560/1/best_cross_entropy.pth",
        map_location=torch.device('cpu'),
    )
)


In [None]:

# Define utility functions

def save_model(name, mod):
    """
    Write a model's ``state_dict`` (parameters and buffers) to a file.

    Args:
        name (str): Destination file path.
        mod (nn.Module): Model to save.
    """
    state = mod.state_dict()
    torch.save(state, name)

def mapping(pred):
    """
    Render a class-index map, or per-class scores, as an RGB image.

    Colours: 0 black, 1 red, 2 green, 3 blue (channel values 0..255).

    Args:
        pred (np.ndarray): Either an (H, W) array of class indices 0..3, or a
            (C, H, W) array of per-class scores, reduced with argmax over the
            first axis.

    Returns:
        np.ndarray: float64 array of shape (H, W, 3).

    Raises:
        KeyError: If a value is not one of the class indices 0..3.
    """
    palette = np.array(
        [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255]], dtype=np.float64
    )
    labels = np.asarray(pred)
    if labels.ndim == 3:
        labels = np.argmax(labels, 0)

    # Vectorised palette lookup over the input's own size, instead of a
    # per-pixel Python loop over a hard-coded 512x512 grid.
    indices = labels.astype(np.intp)
    if labels.size and (
        np.any(indices != labels) or indices.min() < 0 or indices.max() >= len(palette)
    ):
        raise KeyError("mapping expects class indices in 0..3")
    return palette[indices]

def evaluate(model, loader, device, num_classes, epoch_idx, criteria, total_epochs, validation=True):
    """
    Run the model over ``loader`` and plot its predictions.

    For every batch, the first sample's input image, predicted mask, a blend
    of the two and the ground-truth mask are drawn in a 2x2 figure. The
    figure is shown and saved to ``/kaggle/working/<batch index>.jpg``. No
    loss or metric is computed.

    Args:
        model (nn.Module): Model returning a dict with ``'logits'``.
        loader (DataLoader): Yields ``(data, target)`` batches.
        device (torch.device): Device the batches are moved to.
        num_classes (int): Number of classes (currently unused).
        epoch_idx (int): Current epoch index, shown in the progress bar.
        criteria: Loss function (currently unused).
        total_epochs (int): Total number of epochs, shown in the progress bar.
        validation (bool, optional): Progress-bar label "Valid" (True) or
            "Test" (False). Defaults to True.

    Returns:
        dict: Model outputs of the last batch, or ``None`` if ``loader`` is empty.
    """
    model.eval()

    # Stays None for an empty loader instead of raising UnboundLocalError.
    outputs = None

    with tqdm(total=len(loader), ncols=122, ascii=True) as tq:
        tq.set_description(f"{'Valid' if validation else 'Test'} :: Epoch: {epoch_idx}/{total_epochs}")

        for en, (data, target) in enumerate(loader):
            tq.update(1)

            data, target = data.to(device), target.to(device)
            with torch.no_grad():
                outputs = model(pixel_values=data, return_dict=True)
                # Resize the logits to the spatial size of the target mask.
                upsampled_logits = nn.functional.interpolate(
                    outputs['logits'], size=target.shape[-2:], mode="bilinear", align_corners=False
                )

            # Plot the first sample of the batch. mapping() returns colours in
            # 0..255; rescale to 0..1 so they display correctly and can be
            # blended with the float input image.
            rgb_pred_mask = mapping(upsampled_logits[0].cpu().numpy()) / 255.0
            rgb_gt_mask = mapping(target[0].cpu().numpy()) / 255.0
            image = data[0].cpu().numpy().transpose(1, 2, 0)

            fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(20, 10))
            axes[0, 0].imshow(image)
            axes[0, 1].imshow(rgb_pred_mask)
            axes[1, 0].imshow(0.8 * image + 0.2 * rgb_pred_mask)
            axes[1, 1].imshow(rgb_gt_mask)
            fig.tight_layout()
            fig.savefig("/kaggle/working/" + str(en) + ".jpg")
            plt.show()
            # Release the figure; otherwise one figure per batch stays in memory.
            plt.close(fig)

    return outputs

def main(*, model, optimizer, ckpt_path, configs=None, pin_memory=False, device="cpu"):
    """
    Evaluate ``model`` on the validation split once per configured epoch.

    Despite the name, no optimisation step is performed: each "epoch" is one
    call to ``evaluate`` on the validation loader.

    Args:
        model (nn.Module): Model to evaluate.
        optimizer: Optimizer given to the LR scheduler when
            ``configs["LR_SCHEDULER"]`` is truthy.
        ckpt_path: Unused.
        configs (dict): Must provide "LR_SCHEDULER", "NUM_EPOCHS" and
            "NUM_CLASSES"; also forwarded to ``get_dataloader``.
            "SCHLR_MILESTONES" is written into it when a scheduler is built.
        pin_memory (bool): Forwarded to ``get_dataloader``.
        device: Device the batches are moved to.

    Returns:
        The result of the last ``evaluate`` call, or ``None`` when
        ``configs["NUM_EPOCHS"]`` is 0.
    """
    # Only the validation loader is used.
    _, _, valid_loader, _ = get_dataloader(configs=configs, pin_memory=pin_memory, num_workers=TrainingConfig.NUM_WORKERS)

    if configs["LR_SCHEDULER"]:
        # Decrease the LR by 10x after 50% of the epochs. Nothing here steps
        # the scheduler, but constructing it records 'initial_lr' in the
        # optimizer's param groups.
        configs["SCHLR_MILESTONES"] = [configs["NUM_EPOCHS"] // 2]
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=configs["SCHLR_MILESTONES"], gamma=0.1)

    focal_loss = FocalLoss()

    # Stays None when NUM_EPOCHS is 0 instead of raising UnboundLocalError.
    output = None
    for epoch in range(configs["NUM_EPOCHS"]):
        gc.collect()  # free memory between epochs

        output = evaluate(
            model=model,
            loader=valid_loader,
            device=device,
            criteria=focal_loss,
            num_classes=configs["NUM_CLASSES"],
            epoch_idx=epoch + 1,
            total_epochs=configs["NUM_EPOCHS"],
        )

    return output
# Deterministic seeding for reproducible runs.
seed_everything(seed_value=41)

# Use the GPU when available and build a fresh model on that device.
DEVICE, GPU_AVAILABLE = get_default_device()
model = get_model()
model.to(DEVICE)

LR = TrainingConfig.LEARNING_RATE
WD = TrainingConfig.WEIGHT_DECAY
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD, amsgrad=True)

# Hyper-parameters; keys include those main() reads from its configs argument.
HPARAMS = {
    'IMG_SIZE': (DatasetConfig.IMG_HEIGHT, DatasetConfig.IMG_WIDTH),
    'MODEL_NAME': ModelConfig.MODEL_NAME,
    'BATCH_SIZE': TrainingConfig.BATCH_SIZE,
    'NUM_EPOCHS': TrainingConfig.NUM_EPOCHS,
    'OPTIMIZER': "AdamW",
    'LEARNING_RATE': TrainingConfig.LEARNING_RATE,
    'WEIGHT_DECAY': TrainingConfig.WEIGHT_DECAY,
    'LR_SCHEDULER': "MultiStepLR",
    'NUM_CLASSES': 4,
}
train_loader = dataset
class FocalLoss(nn.Module):
    """
    Multi-class focal loss on raw logits:
    ``-alpha * (1 - p)^gamma * log(p)`` for the probability ``p`` of the
    target class, with ``p`` clamped to at least 1e-8 inside the log.

    Args:
        alpha (float): Scaling factor.
        gamma (float): Focusing exponent.
        ignore_index (int, optional): Target value whose pixels contribute
            zero loss. It may lie outside [0, C) (e.g. 255 or -100).
        reduction (str): 'mean', 'sum' or 'none'.
    """

    def __init__(self, alpha=1, gamma=2, ignore_index=None, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs (torch.Tensor): Logits of shape (N, C, H, W).
            targets (torch.Tensor): Integer class indices of shape (N, H, W).

        Returns:
            torch.Tensor: Scalar for 'mean'/'sum'; the element-wise
            (N, C, H, W) loss for 'none'.

        Raises:
            ValueError: For an unknown reduction mode.
        """
        inputs_soft = F.softmax(inputs, dim=1)

        valid = None
        if self.ignore_index is not None:
            valid = targets != self.ignore_index
            # one_hot rejects values outside [0, C), so give ignored pixels a
            # dummy class; their loss is zeroed by the mask below.
            targets = torch.where(valid, targets, torch.zeros_like(targets))

        target_one_hot = F.one_hot(targets, num_classes=inputs.shape[1]).permute(0, 3, 1, 2).float()

        focal = -self.alpha * ((1 - inputs_soft) ** self.gamma) * target_one_hot * torch.log(inputs_soft.clamp(min=1e-8))

        if valid is not None:
            focal = focal * valid.unsqueeze(1).float()

        if self.reduction == 'mean':
            # Averages over all N*C*H*W entries, ignored pixels included.
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        if self.reduction == 'none':
            return focal
        raise ValueError(f"Invalid reduction mode: {self.reduction}")
# Load the checkpoint saved as best_miou.pth (by its name, presumably the best
# mean-IoU epoch), mapping its tensors to the CPU.
model.load_state_dict(
    torch.load(
        "/kaggle/input/weights-1560/3/best_miou.pth",
        map_location=torch.device('cpu'),
    )
)

class CFG:
    """Inference settings, read as class attributes (e.g. ``CFG.device``)."""

    data_path = '/kaggle/input/hubmap-hacking-the-human-vasculature/'
    batch_size = 1
    # GPU when one is visible, otherwise CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    th = 0.15
    chepoint_dir = '/kaggle/input/hubmap-checpoint/'
    model_types = ['UnetPlusPlus'] * 3
    encoder_name_list = ['se_resnext50_32x4d', 'se_resnext101_32x4d', 'vgg19_bn']
    is_tta = False
    size = 512  # side length images are resized to in test_aug
    org_size = 512
    encoder_depth = 4
    decoder_channels = [512, 256, 128, 64]

    # Resize, normalize with the ImageNet mean/std, convert to a tensor.
    test_aug = [
        A.Resize(size, size),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
# Put the model on the inference device and cast its floating-point
# parameters and buffers to float32.
model = model.to(CFG.device).float()


In [None]:
%%capture
# Install pycocotools from the files shipped in the /kaggle/input/pycocotools
# dataset: copy them to a writable directory, build from the 2.0.6 sources,
# then pip-install with --no-index so only the local --find-links directory
# is searched.
!mkdir /kaggle/working/packages
!cp -r /kaggle/input/pycocotools/* /kaggle/working/packages
os.chdir("/kaggle/working/packages/pycocotools-2.0.6/")
!python setup.py install
!pip install . --no-index --find-links /kaggle/working/packages/
# Return to /kaggle/working.
os.chdir("/kaggle/working")

In [None]:
from pycocotools import _mask as coco_mask

def TTA(x: torch.Tensor, model: nn.Module):
    """
    Test-time augmentation over the four 90-degree rotations.

    ``x`` is rotated by 0/90/180/270 degrees in its last two dimensions. Each
    rotation is passed through ``model``, whose ``'logits'`` entry is rotated
    back, and the element-wise maximum over the four results is returned.

    Args:
        x (torch.Tensor): Input batch.
        model (nn.Module): Callable returning a dict with ``'logits'``.

    Returns:
        torch.Tensor: Element-wise max of the de-rotated logits.
    """
    restored = []
    for quarter_turns in range(4):
        rotated_input = torch.rot90(x, k=quarter_turns, dims=(-2, -1))
        logits = model(rotated_input)['logits']
        restored.append(torch.rot90(logits, k=-quarter_turns, dims=(-2, -1)))
    return torch.stack(restored, dim=0).max(dim=0).values

class Test:
    """
    Class for testing and evaluating a model on a dataset.
    """

    def encode_binary_mask(self, mask: np.ndarray) -> t.Text:
        """
        Convert a binary mask into OID challenge encoding ascii text.

        Args:
            mask (np.ndarray): Binary mask to encode.

        Returns:
            t.Text: Encoded binary mask as ascii text.
        """
        # check input mask --
        if mask.dtype != np.bool:
            raise ValueError(
                "encode_binary_mask expects a binary mask, received dtype == %s" %
                mask.dtype)

        mask = np.squeeze(mask)
        if len(mask.shape) != 2:
            raise ValueError(
                "encode_binary_mask expects a 2d mask, received shape == %s" %
                mask.shape)

        # convert input mask to expected COCO API input --
        mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1)
        mask_to_encode = mask_to_encode.astype(np.uint8)
        mask_to_encode = np.asfortranarray(mask_to_encode)

        # RLE encode mask --
        encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

        # compress and base64 encoding --
        binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
        base64_str = base64.b64encode(binary_str)
        return base64_str

    def encode_output(self, outputs, idx):
        """
        Encode the outputs into challenge-specific format.

        Args:
            outputs: Model predictions.
            idx: Index of the prediction.

        Returns:
            dict: Encoded outputs in challenge format.
        """
        blood_vessel = torch.argmax(outputs, 0)
        blood_vessel = blood_vessel == 2
        blood_vessel = blood_vessel * 1

        blood_vessel = blood_vessel.cpu().numpy()
        all_encode = {}
        for i in range(blood_vessel.shape[0]):
            list_encode = []
            sliceImage = blood_vessel[i,:,:]
            binarized = sliceImage > 0
            coded_len = self.encode_binary_mask(binarized)
            list_encode.append(coded_len)
            all_encode[idx[i]] = list_encode
        return all_encode

    def get_test_transforms(self):
        """
        Get the test data augmentation transforms.

        Returns:
            A.Compose: Augmentation transforms.
        """
        return A.Compose(CFG.test_aug)

    def test_dataloader(self, image_folder):
        """
        Get the dataloaders for testing.

        Args:
            image_folder: Folder containing test images.

        Returns:
            DataLoader: Test dataloader.
        """
        tl, td, vl, vd = get_dataloader(shuffle_validation=True)
        return vl, vd

    def evaluate(self, model, test_dataloader, weights):
        """
        Evaluate the model on the test dataset.

        Args:
            model: Model to evaluate.
            test_dataloader: Dataloader for the test dataset.
            weights: List of model weights for TTA.

        Returns:
            Tuple: Evaluation results.
        """
        ids = []
        heights = []
        widths = []
        prediction_strings = []
        sample = None
        with torch.no_grad():
            bar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))

            for step, (idn, images, x, y) in bar:
                images = images.to(CFG.device)
                images = torch.unsqueeze(images, 0)
                ls = []
                for weight in weights:
                    model.load_state_dict(torch.load(weight))
                    pred = TTA(images, model)
                    ls.append(pred)
                pred = torch.max(torch.stack(ls, dim=0), 0).values
                _, _, h, w = images.shape
                pred = F.interpolate(pred, size=[h, w], mode='bilinear', align_corners=False)
                pred = pred[:, :, x:x+512, y:y+512]
                pred_scored = torch.softmax(pred, 1)
                if sample is None:
                    sample = pred
                pred_string = ''
                pred = (pred_scored > 0.4).float().cpu().numpy()
                pred_scored = pred_scored.cpu().numpy()
                for m in range(len(pred)):
                    kernel = np.ones(shape=(3, 3), dtype=np.uint8)
                    x = ((pred[m][2] > 0.6) * 255.0)
                    binary_mask = x
                    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask.astype(np.uint8))
                    for i in range(1, num_labels):
                        mask_i = np.zeros_like(binary_mask)
                        mask_i[labels == i] = 1
                        mask = mask_i[:, :, np.newaxis].astype(np.bool)
                        x_min, y_min, x_max, y_max = get_bounding_box(mask[:,:,0])
                        if (x_min < 10) or (x_max > 500) or (y_min < 10) or (y_max > 500):
                            mask_pred = mask_i * pred_scored[0, 2, :, :]
                            score = np.sum(mask_pred) / np.sum(mask_i)
                            encoded = self.encode_binary_mask(cv2.dilate(mask * 255.0, kernel, 3) > 0)
                            if i == 0:
                                pred_string += f"0 {score} {encoded.decode('utf-8')}"
                            else:
                                pred_string += f" 0 {score} {encoded.decode('utf-8')}"
                b, c, h, w = images.shape
                ids.append(idn)
                heights.append(h)
                widths.append(w)
                prediction_strings.append(pred_string)
        return ids, heights, widths, prediction_strings, sample, binary_mask

# Run inference over every test tile and collect the submission columns.
test = Test()

# Call the evaluate method to perform inference and obtain results.
# NOTE(review): in this copy the last list element's glob pattern is cut off
# (the string literal starting "/kaggle/input/weights-1560/*/" is never
# closed), so this cell does not parse. Restore the intended checkpoint
# pattern before running.
ids, heights, widths, prediction_strings, sample, binary_mask = test.evaluate(
    model.float(),
    dataset,
    weights=["/kaggle/input/weights-1560/1/best_cross_entropy.pth",
             "/kaggle/input/weights-1560/2/best_cross_entropy.pth"] + glob.glob("/kaggle/input/weights-1560/*/
             ))