## Installing requirements

In [None]:
# Only when running in Colab: install this project straight from its GitHub repo.
# NOTE(review): presumably this is what provides the `src.data` / `src.models`
# modules imported below — verify; when running locally from a checkout, skip it.
!pip install git+https://github.com/AdrianUrbanski/Cell_nuclei_segmentation.git

In [None]:
# Third-party dependencies. Only pytorch-lightning and wandb are imported directly
# in this notebook; the others are presumably needed by the project's `src` modules
# (image decoding, segmentation backbones) — confirm.
# NOTE(review): versions are unpinned, so a fresh runtime gets the latest releases
# (e.g. Lightning 2.x, whose Trainer API differs from 1.x).
!pip install pytorch-lightning
!pip install pytorch-toolbelt
!pip install ternausnet
!pip install pretrained-backbones-unet
!pip install imagecodecs
!pip install wandb

# Imports

In [None]:
from __future__ import annotations

import pytorch_lightning as pl
import wandb
from google.colab import drive
from matplotlib import pyplot as plt
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader

from src.data import RandomPatchesDataset
from src.models import UNetLit

# Mounting Google Drive

In [None]:
# Attach Google Drive at /content/drive so the dataset folder under
# /content/drive/MyDrive/Cell_segmentation can be read by the data-loading cell.
drive.mount(mountpoint='/content/drive')

# Logging in to Weights & Biases (wandb)

In [None]:
# Authenticate with Weights & Biases; the training cell logs its run through WandbLogger.
wandb.login()

# Loading data

In [None]:
batch_size = 12
PATH = '/content/drive/MyDrive/Cell_segmentation'


def _make_split(split: str) -> tuple[RandomPatchesDataset, DataLoader]:
    """Build the dataset for one split folder and wrap it in a DataLoader.

    Images are read from ``{PATH}/{split}/img`` and masks from
    ``{PATH}/{split}/mask``; batches hold ``batch_size`` samples.
    """
    dataset = RandomPatchesDataset(f'{PATH}/{split}/img', f'{PATH}/{split}/mask')
    return dataset, DataLoader(dataset, batch_size=batch_size)


# NOTE(review): the training loader is not shuffled. That may be acceptable if
# RandomPatchesDataset samples patches randomly itself; confirm before relying on it.
train_dataset, train_data_loader = _make_split('train')
val_dataset, val_data_loader = _make_split('val')
test_dataset, test_data_loader = _make_split('test')

# Training

In [None]:
# Hyperparameters handed to UNetLit.
# NOTE(review): the key names suggest an optimizer learning rate / epsilon plus a
# step LR schedule (step_size, gamma); confirm against src.models.UNetLit.
config = {
    "lr": 0.001,
    "eps": 1.0e-08,
    "step_size": 4,
    "gamma": 0.1,
}
num_epochs = 100
checkpoints_dir_path = './models'
project = 'cell-nuclei-segmentation'
# Hardware selection, passed to pl.Trainer below. 'auto' uses a GPU when the
# runtime has one and falls back to CPU otherwise. Set accelerator='cpu' to force
# CPU training. (Trainer's old `gpus=` argument was removed in Lightning 2.0.)
accelerator = 'auto'
devices = 'auto'

model = UNetLit(config)
# Keep only the single best checkpoint, ranked by validation accuracy (higher is
# better). UNetLit must log `val_acc` for the monitor and the filename template to work.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath=checkpoints_dir_path,
    filename='model-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,
    mode='max',
)
wandb_logger = WandbLogger(save_dir="logs/", project=project)
trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=num_epochs,
    callbacks=[checkpoint_callback],
    accelerator=accelerator,
    devices=devices,
)

try:
    trainer.fit(model, train_data_loader, val_data_loader)
finally:
    # Close the wandb run even if fit() raises, so it is not left open.
    wandb.finish()

# Evaluating the model and analysing its predictions

In [None]:
import torch

# NOTE(review): this is a local path from another machine and will not exist on
# Colab. Point it at a checkpoint saved by the training cell (e.g.
# checkpoint_callback.best_model_path) or at a copy on Drive.
MODEL_CHECKPOINT_PATH = "/home/maria/Downloads/model-epoch=81-val_acc=0.96.ckpt"
model = UNetLit.load_from_checkpoint(MODEL_CHECKPOINT_PATH)
# Switch to inference behaviour (e.g. dropout off, batch-norm running stats),
# as recommended after load_from_checkpoint.
model.eval()

# Only the first sample of the first test batch is inspected.
img, mask = next(iter(test_data_loader))
# Send the tensors to wherever the restored weights live. A hard-coded .cuda()
# crashes on CPU-only runtimes and when the checkpoint is restored onto the CPU.
real_mask = mask[0].to(model.device)
real_img = img[0].to(model.device)
with torch.no_grad():  # inference only: no autograd graph needed
    # NOTE(review): assumes unsqueeze(dim=1) yields the (N, 1, H, W) layout the
    # UNet expects for this dataset's samples; confirm against RandomPatchesDataset.
    output = model(real_img.float().unsqueeze(dim=1))

In [None]:
# Input image of the sample picked in the evaluation cell, rendered in grayscale.
input_image = real_img.cpu().squeeze()
plt.imshow(input_image, cmap='gray')

In [None]:
# Ground-truth mask for the same sample.
true_mask = real_mask.cpu().squeeze()
plt.imshow(true_mask, cmap='gray')

In [None]:
# Model prediction for the same sample, binarised at 0.5.
# NOTE(review): a 0.5 cut-off assumes the network outputs probabilities. If
# UNetLit returns raw logits, 0 is the matching threshold; confirm.
binary_prediction = (output.detach().cpu() > 0.5).float().squeeze()
plt.imshow(binary_prediction, cmap='gray')