# Installing requirements

In [None]:
# Only when running in Colab
!pip install git+https://github.com/AdrianUrbanski/Cell_nuclei_segmentation.git

In [None]:
!pip install pytorch-lightning
!pip install pytorch-toolbelt
!pip install imagecodecs
!pip install stardist
!pip install wandb

# Imports

In [None]:
from __future__ import annotations

from os import listdir

import imageio.v2 as imageio
import matplotlib.pyplot as plt
import wandb
from csbdeep.utils import normalize
from google.colab import drive
from stardist.models import StarDist2D
from stardist.plot import render_label

# Mounting Google Drive

In [None]:
# Make Google Drive visible under /content/drive (the dataset lives there).
drive.mount(mountpoint='/content/drive')

# Logging in to Wandb

In [None]:
# Authenticate with Weights & Biases (may prompt for an API key).
# NOTE(review): no wandb run is started and no W&B callback is passed to
# model.train in the visible cells, so nothing appears to be logged — confirm.
wandb.login()

# Loading data

In [None]:
PATH = '/content/drive/MyDrive/Cell_segmentation'


def _first_channel(img):
    """Return a 2-D *img* unchanged, otherwise its first channel ``img[:, :, 0]``."""
    return img if img.ndim == 2 else img[:, :, 0]


def load_split(split):
    """Load the images and instance masks of one dataset split.

    Every file in ``{PATH}/{split}/img`` is paired with the file of the same
    name in ``{PATH}/{split}/mask``.

    Args:
        split: Sub-directory of ``PATH`` to read, e.g. ``'train'`` or ``'val'``.

    Returns:
        ``(names, imgs, masks)``: the sorted file names; the images reduced to a
        single channel and normalized with csbdeep's ``normalize``; and the
        masks reduced to a single channel and cast to ``int``.
    """
    # listdir order is arbitrary, so sort to make e.g. val_imgs[0] reproducible;
    # skip hidden entries such as .ipynb_checkpoints, which imread cannot read.
    names = sorted(f for f in listdir(f'{PATH}/{split}/img') if not f.startswith('.'))
    # Convert each image right after reading it instead of keeping every raw array.
    imgs = [normalize(_first_channel(imageio.imread(f'{PATH}/{split}/img/{f}'))) for f in names]
    masks = [_first_channel(imageio.imread(f'{PATH}/{split}/mask/{f}')).astype(int) for f in names]
    return names, imgs, masks


train_names, train_imgs, train_masks = load_split('train')
# files_names keeps holding the validation file names, as in the original cell.
files_names, val_imgs, val_masks = load_split('val')

# Loading pretrained model

In [None]:
# Start from StarDist's published 2D fluorescence model and fine-tune it below.
model = StarDist2D.from_pretrained('2D_versatile_fluo')
# from_pretrained reports an unknown name on stderr and returns None rather than
# raising; fail here instead of with an AttributeError in a later cell.
if model is None:
    raise RuntimeError("Could not load the pretrained StarDist2D model '2D_versatile_fluo'")

In [None]:
# Segment the first validation image with the pretrained weights and overlay
# the predicted label image (element 0 of the result) on the input.
example_pred = model.predict_instances(val_imgs[0])
plt.imshow(
    render_label(example_pred[0], img=val_imgs[0])
)

# Fine-tuning the pretrained model

In [None]:
# Fine-tune the pretrained network on the training split, validating on the
# validation split; augmenter=None supplies no augmentation function.
# NOTE(review): csbdeep appears to load pretrained models read-only
# (basedir=None), in which case no checkpoints are written and the model keeps
# the last-epoch weights — confirm, and save the weights explicitly if needed.
# NOTE(review): no W&B callback is passed, so this run is not logged to wandb.
model.train(
    train_imgs,
    train_masks,
    validation_data=(val_imgs, val_masks),
    augmenter=None,
    epochs=10,
    steps_per_epoch=30,
)

In [None]:
# Re-tune the model's prediction thresholds for the fine-tuned weights on the
# validation split; the returned value is shown as the cell output.
model.optimize_thresholds(
    val_imgs,
    val_masks,
)

In [None]:
# Repeat the example segmentation with the fine-tuned model and tuned
# thresholds, overlaying the predicted label image on the input.
example_pred = model.predict_instances(val_imgs[0])
plt.imshow(
    render_label(example_pred[0], img=val_imgs[0])
)