<a href="https://www.kaggle.com/code/aaalexlit/inference-with-yolo-trained-on-small-dataset?scriptVersionId=159015478" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
# Offline install of the pinned ultralytics 8.1.0 wheel from an attached Kaggle
# dataset: --no-index prevents contacting PyPI, and -f adds that directory as a
# place to look for wheels/dependencies.
!pip install --no-index -f /kaggle/input/download-ultralytics /kaggle/input/download-ultralytics/ultralytics-8.1.0-py3-none-any.whl

In [None]:
from ultralytics import YOLO
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from ultralytics.engine.results import Results

In [None]:
!mkdir trained_model
!cp /kaggle/input/prepare-small-dataset-for-yolo/hacking_human_vasculature_small/small_set/weights/* /kaggle/working/trained_model

In [None]:
# Load the best checkpoint (best.pt) copied into /kaggle/working/trained_model
# by the previous cell.
model = YOLO('/kaggle/working/trained_model/best.pt')


In [None]:
def add_masks(masks):
    """Merge a stack of per-instance masks into a single 0/255 ``uint8`` mask.

    Args:
        masks: array-like stacked along axis 0 (one mask per instance).

    Returns:
        The element-wise sum over axis 0, scaled by 255, clipped to [0, 255]
        and cast to ``uint8``. For binary 0/1 inputs, this marks a pixel 255
        when any instance covers it and 0 otherwise.
    """
    # Summing then clipping turns overlapping instances into a plain union.
    summed = np.sum(masks, axis=0)
    scaled = np.clip(summed * 255, 0, 255)
    return scaled.astype("uint8")

In [None]:
def rle_encode(mask):
    """Run-length encode a mask into a space-separated ``start length`` string.

    Pixels are read in row-major order via ``flatten``. Run starts are 1-based,
    so a foreground run covering flat positions 2 and 3 encodes as ``"2 2"``.

    Args:
        mask: array-like mask. Any non-zero value counts as foreground, so
            0/1, boolean and 0/255 masks all encode identically.

    Returns:
        The RLE string, or ``'1 0'`` when the mask has no foreground pixels.
    """
    # Binarize first: with multi-valued input (e.g. adjacent 1 and 2 pixels)
    # every value change would be treated as a run boundary. That produces an
    # odd number of boundaries, which breaks the pairing below (shape-mismatch
    # error) or yields wrong runs.
    pixel = (np.asarray(mask) != 0).flatten().astype(np.uint8)
    # Pad with background so every run has both a start and an end boundary.
    pixel = np.concatenate([[0], pixel, [0]])
    # Positions where the value changes, shifted to 1-based indices.
    run = np.where(pixel[1:] != pixel[:-1])[0] + 1
    # Turn each (start, end) boundary pair into (start, length).
    run[1::2] -= run[::2]
    rle = ' '.join(str(r) for r in run)
    return rle or '1 0'

In [None]:
def extract_id_from_result(result: Results):
    """Build the submission id ``<dataset>_<name>`` for one prediction.

    ``result.path`` is split on ``/``. The dataset name is the component three
    from the end. The name is the final component up to its first ``.``. For
    example, ``.../<dataset>/<subdir>/<name>.tif`` maps to ``<dataset>_<name>``.
    """
    parts = result.path.split('/')
    stem = parts[-1].partition('.')[0]
    return f'{parts[-3]}_{stem}'

In [None]:
def get_rle_from_result(result: Results):
    """Return the RLE string for the union of all masks predicted in *result*.

    The mask tensor is moved to the CPU, merged with ``add_masks`` and encoded
    with ``rle_encode``. When ``result.masks`` is falsy (e.g. ``None``), the
    empty-mask encoding ``'1 0'`` is returned instead.
    """
    if result.masks:
        instance_masks = result.masks.data.cpu().numpy()
        return rle_encode(add_masks(instance_masks))
    return '1 0'

In [None]:
# Glob for the test .tif slices; `**` is meant to match every sub-directory of
# test/ (extract_id_from_result relies on the <dataset>/<subdir>/<file> layout).
source = '/kaggle/input/blood-vessel-segmentation/test/**/*.tif'
# stream=True makes predict return a lazy generator of Results, so predictions
# are produced one image at a time and `results` can be iterated only once.
# retina_masks=True requests masks at the original image resolution,
# presumably so the RLE pixel positions match the source image — confirm.
# conf=0.5 discards detections below 0.5 confidence.
# NOTE(review): device=[0,1] — verify whether ultralytics predict actually uses
# both GPUs here or only the first one.
results = model.predict(source, stream=True, device=[0,1], retina_masks=True, conf=0.5)

In [None]:
# One submission row per predicted image: its id and the RLE of its merged mask.
# Consuming `results` here drives the (lazy, streamed) inference.
submission_list = [
    {'id': extract_id_from_result(result), 'rle': get_rle_from_result(result)}
    for result in results
]

# Explicit columns keep the header correct even if no images were found.
df = pd.DataFrame(submission_list, columns=['id', 'rle'])
df.to_csv('submission.csv', index=False)
