Add the project root to `sys.path` so the local packages (`utils`, `pipeline`, `training`) can be imported

In [None]:
import os
import sys

# Absolute path of the parent of the kernel's working directory (the project root).
# The membership test keeps repeated runs of this cell from appending duplicates.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)


Load dataset

In [None]:
import torch
from utils import filename
from utils import filehandling

# Dataset variant and `prob` setting; together they select the processed dataset directory.
size, prob = 'dev_s', 50

directory = filename.processed.dataset(size, prob)
dataset = filehandling.read_splitted_dataset(directory)

Split into training and test sets

In [None]:
import numpy as np
from utils.data import splitting

# Fixed seed so the 0.8 split is reproducible after a kernel restart.
split_seed = 5
random_state = np.random.RandomState(split_seed)

train, test = splitting.train_val_split(
    dataset,
    .8,
    random_state=random_state,
    train_filter=None,
)
print(len(train), len(test))

Load a pretrained 2D U-Net and convert it to 3D

In [None]:
%%capture
from pipeline.convert2Dto3D import Conv3dConverter
import segmentation_models_pytorch as smp

# 2D U-Net with a ResNet-18 encoder initialised from the 'swsl' pretrained weights;
# single input channel, single output class.
model = smp.Unet(
    encoder_name='resnet18',
    in_channels=1,
    classes=1,
    encoder_weights='swsl',
)

# Convert pretrained 2D model to 3D.
# NOTE(review): the converter's return value is discarded and `model` is used later, so the
# conversion presumably happens in place. The tuple looks like an example input shape
# (batch, channels, D, H, W); the meaning of -1 is not visible here. Confirm both in
# pipeline.convert2Dto3D.
Conv3dConverter(model, -1, (32, 1, 32, 32, 32))

Segmentation map and prediction threshold

In [None]:
import numpy as np

# Threshold handed to TrainSegmenter below (presumably used to binarise model outputs).
threshold = 0.5

# Segmentation map attribute of the training split; not referenced elsewhere in this notebook.
seg = train.get_attribute('segmentmap')

In [None]:
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

model_id = filename.models.new_id()

# Lightning fills in the {epoch} and {sofia_dice} placeholders when it saves a checkpoint.
checkpoint_name = str(model_id) + '-{epoch:02d}-{sofia_dice:.2f}'

# Keep only the single best checkpoint, ranked by the logged 'sofia_dice' metric (higher is better).
# NOTE(review): `period` is deprecated in newer Lightning releases in favour of `every_n_epochs`;
# confirm against the installed version before upgrading.
checkpoint_callback = ModelCheckpoint(
    monitor='sofia_dice',
    mode='max',
    save_top_k=1,
    period=10,
    dirpath=filename.models.directory,
    filename=checkpoint_name,
)

In [None]:
from astropy.io import fits

from utils.data.generating import get_hi_shape
from utils import filename

# Resolve the sky file path once, so the shape and the FITS header are guaranteed to come
# from the same file and the path helper is not evaluated twice.
sky_path = filename.data.sky(size)
hi_shape = get_hi_shape(sky_path)
header = fits.getheader(sky_path)

Create Lightning objects

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from datetime import datetime
from pytorch_toolbelt import losses
from pipeline.segmenter import BaseSegmenter
from training.train_segmenter import TrainSegmenter, get_random_vis_id

# Pick a test-set item to visualise during training. It gets its own seeded RandomState, so the
# choice is reproducible and independent of the split seed.
# NOTE(review): presumably requires at least `min_vis_voxels` voxels; confirm in training.train_segmenter.
min_vis_voxels = 500
vis_id = get_random_vis_id(test, hi_shape, min_vis_voxels, random_state=np.random.RandomState(10))

# Equally weighted (1.0 / 1.0) combination of a binary log-Dice loss and a soft BCE-with-logits loss.
loss = losses.JointLoss(losses.DiceLoss(mode='binary', log_loss=True), losses.SoftBCEWithLogitsLoss(), 1.0, 1.0)

# TensorBoard run name. Lightning uses a string `version` verbatim as the run directory name,
# so it must be filesystem-safe. ':' is invalid in Windows paths, spaces are awkward in shells,
# and '/' would scatter runs across nested date sub-directories.
version = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
logger = TensorBoardLogger("tb_logs", name="resnet", version=version)

# The training split's scale/mean/std are handed to the segmenter (presumably so the same
# input normalisation is applied everywhere; confirm in pipeline.segmenter.BaseSegmenter).
base_segmenter = BaseSegmenter(model, train.get_attribute('scale'), train.get_attribute('mean'),
                               train.get_attribute('std'))
segmenter = TrainSegmenter(base_segmenter, loss, train, test, header, vis_id=vis_id, threshold=threshold, lr=1e-2,
                           batch_size=128)

# gpus=0 trains on CPU. max_epochs is effectively unbounded, so training runs until stopped
# manually. Validation runs every 10 epochs; this is presumably where 'sofia_dice' (monitored by
# the checkpoint callback) is logged.
trainer = pl.Trainer(max_epochs=100000, gpus=0, logger=logger, callbacks=[checkpoint_callback],
                     check_val_every_n_epoch=10)

Train!

In [None]:
# Start training; checkpoints and TensorBoard logs are written by the callbacks configured above.
trainer.fit(model=segmenter)