In [1]:
# Offline install of timm from the attached Kaggle dataset (--no-index: do not query the package index).
# The wheel found there is timm-0.4.12 (see the install log below).
# NOTE(review): timm is presumably required so load_learner can rebuild the pickled model — confirm.
!pip install timm --no-index --find-links=file:///kaggle/input/../input/sartoriussegmentationmydata/timm/timm

Looking in links: file:///kaggle/input/../input/sartoriussegmentationmydata/timm/timm
Processing /kaggle/input/sartoriussegmentationmydata/timm/timm/timm-0.4.12-py3-none-any.whl
Installing collected packages: timm
Successfully installed timm-0.4.12


In [2]:
from fastai.vision.all import *
import tqdm
import cv2

from timm import create_model
from fastai.vision.learner import _update_first_layer

import skimage.morphology

In [3]:
# Competition data root and its train/test image folders.
path = Path('../input/sartorius-cell-instance-segmentation')
train_path, test_path = (path / split for split in ('train', 'test'))

In [4]:
def mask_label_func(fn):
    """Return the mask file that shares ``fn``'s file name, located under ``mask_path``.

    NOTE(review): ``mask_path`` is not defined anywhere in this notebook, so calling
    this function here would raise NameError. It is never called below; it presumably
    exists only so ``load_learner`` can resolve this name when unpickling a learner
    that used it as its label getter at training time — confirm.
    """
    mask_dir = mask_path
    return mask_dir / fn.name

In [5]:
# Exported fastai Learner (presumably the final training checkpoint).
# load_learner unpickles the file, so only point it at trusted data.
learner_pkl = '../input/sartoriussegmentationmydata/mask_learner_final.pkl'
learn = load_learner(learner_pkl)

In [6]:
# All image files under the test folder; each is predicted in the inference loop.
test_items = get_image_files(path=test_path)

In [7]:
## Stolen (& modified) from https://www.kaggle.com/ammarnassanalhajali/sartorius-segmentation-keras-u-net-inference
def post_process(mask, min_size=300):
    """Split a mask into OpenCV connected components, keeping the larger ones.

    Not used by the submission loop (which calls ``my_post_process``); kept as
    the OpenCV-based alternative.

    Args:
        mask: 2-D mask. It is cast to uint8 before labelling, so fractional
            values below 1 truncate to 0 — threshold probabilities first.
        min_size: a component is kept only if it has strictly more pixels than this.

    Returns:
        List of float32 0/1 arrays shaped like ``mask``, one per kept component,
        in label order.
    """
    n_labels, labelled = cv2.connectedComponents(mask.astype(np.uint8))
    # Label 0 is the background, so real components are numbered 1 .. n_labels - 1.
    candidates = ((labelled == label).astype(np.float32) for label in range(1, n_labels))
    return [component for component in candidates if component.sum() > min_size]

In [8]:
# Stolen from: https://www.kaggle.com/arunamenon/cell-instance-segmentation-unet-eda
# Run-length encoding stolen from https://www.kaggle.com/rakhlin/fast-run-length-encoding-python
# Modified by someone
def rle_encoding(x):
    """Run-length encode the pixels of ``x`` whose value is exactly 1.

    Pixels are numbered from 1 in the order of ``x.flatten()`` (row-major for
    NumPy arrays). The result is a space-separated string of ``start length``
    pairs, or ``''`` when no pixel equals 1.
    """
    is_fg = x.flatten() == 1
    # Pad with background on both sides so every run has a rising and a falling edge.
    padded = np.concatenate(([False], is_fg, [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1])
    run_starts, run_stops = edges[0::2], edges[1::2]
    pairs = np.column_stack((run_starts + 1, run_stops - run_starts)).ravel()
    return ' '.join(str(value) for value in pairs)

In [9]:
# Stolen from: https://www.kaggle.com/awsaf49/sartorius-fix-overlap
# Modified to count along the stacking axis (the last axis by default).
def check_overlap(msk, axis=-1):
    """Return True if any pixel is covered by more than one instance mask.

    Args:
        msk: instance masks stacked along ``axis`` (the inference loop stacks
            them with ``np.stack(..., axis=-1)``). Any non-zero value counts
            as foreground.
        axis: axis along which the instance masks are stacked; defaults to the
            last axis, matching the original behavior.

    Returns:
        NumPy bool: True when at least one pixel belongs to two or more instances.
    """
    # Use the builtin ``bool``: the ``np.bool`` alias was deprecated in NumPy 1.20
    # and removed in 1.24, where it raises AttributeError.
    foreground = msk.astype(bool)
    return np.any(np.count_nonzero(foreground, axis=axis) > 1)

In [10]:
def my_post_process(mask, cutoff = 0.5, min_object_size = 100.):
    """Split a thresholded mask into connected instances and drop small ones.

    Args:
        mask: 2-D array; pixels with value ``> cutoff`` are foreground.
        cutoff: foreground threshold.
        min_object_size: minimum pixel count (inclusive) for an instance to be kept.

    Returns:
        List of uint8 0/1 arrays shaped like ``mask``, one per kept instance,
        ordered by ascending label id. Empty if nothing survives.
    """
    lab_mask = skimage.morphology.label(mask > cutoff)
    # Pixel count of every label in one pass (index = label id, 0 = background),
    # instead of building a full-size mask per label just to measure it.
    sizes = np.bincount(lab_mask.ravel())
    # Only labels that actually occur, in ascending order. Background is skipped
    # explicitly: the previous ``set.remove(0)`` raised KeyError whenever the whole
    # mask was foreground (no label 0 present).
    present_labels = np.flatnonzero(sizes)
    return [
        (lab_mask == label).astype(np.uint8)
        for label in present_labels
        if label != 0 and sizes[label] >= min_object_size
    ]

In [11]:
# Inference: predict a mask per test image, upsample it to 704x520 (width x height),
# split it into instances and emit one (image id, RLE) row per instance.
submission = []
for fn in test_items:
    pred = learn.predict(fn)
    # NOTE(review): pred[0] is presumably the decoded (argmax) mask — confirm.
    pred_mask = pred[0].numpy().astype(np.uint8)
    upsampled = Image.fromarray(pred_mask).resize((704, 520), resample = Image.BILINEAR)
    resized_mask = np.array(upsampled)
    instance_predictions = my_post_process(resized_mask)
    # Images with no instances, or whose instances overlap, get a single empty row.
    # ``or`` short-circuits, so np.stack never sees an empty list.
    if len(instance_predictions) == 0 or check_overlap(np.stack(instance_predictions, axis = -1)):
        submission.append((fn.stem, ''))
        continue
    submission.extend((fn.stem, rle_encoding(instance_mask)) for instance_mask in instance_predictions)

In [12]:
# One row per (image id, RLE-encoded instance); empty RLE means no instances.
submission_columns = ['id', 'predicted']
df = pd.DataFrame(submission, columns=submission_columns)
df.head()

Unnamed: 0,id,predicted
0,7ae19de7bc2a,49 2 60 14 753 4 762 16 1457 25 2160 27 2864 26 3569 25 4273 24 4976 25 5680 23 6384 21 7088 19 7791 18 8495 16 9198 15 9901 15 10605 13 11308 13 12011 13 12715 12 13418 12 14122 11 14825 11 15528 11 16232 11 16935 12 17639 12 18342 13 19045 15 19748 16 20452 16 21155 17 21858 17 22561 17 23265 15 23968 13 24673 8 25377 4
1,7ae19de7bc2a,137 20 842 19 1546 18 2251 16 2955 15 3660 9 4365 5 5070 3
2,7ae19de7bc2a,274 7 979 8 1684 10 2389 11 3094 13 3798 15 4502 16 5207 16 5911 17 6616 18 7320 19 8025 20 8730 19 9434 21 10139 20 10844 19 11549 18 12253 18 12958 17 13662 17 14367 15 15071 15 15776 14 16480 14 17185 13 17890 12 18594 12 19299 11 20004 10 20708 10 21413 8 22117 7 22822 5 23527 3
3,7ae19de7bc2a,446 7 1151 6 1854 8 2558 10 3263 10 3966 11 4670 12 5374 13 6078 13 6782 13 7486 13 8191 12 8895 12 9599 12 10303 12 11008 9 11713 5 12419 1
4,7ae19de7bc2a,490 17 515 18 1195 21 1219 18 1901 41 2606 19 2627 19 3310 19 3334 17 4014 19 4039 16 4717 20 4743 16 5421 20 5448 15 6125 19 6152 15 6829 19 6857 14 7534 19 7562 14 8239 18 8266 14 8943 18 8971 13 9648 17 9676 12 10353 16 10381 12 11057 16 11086 11 11761 16 11792 9 12466 14 12497 8 13176 7 13202 7 13881 4 13907 6 14612 5 15317 3


In [13]:
# Write the submission file without the DataFrame index column.
df.to_csv(path_or_buf='submission.csv', index=False)