Add the project root to `sys.path`

In [None]:
import sys
import os

# The notebook lives one directory below the project root. Put that root on the
# import path so project packages (`utils`, `training`) can be imported below.
# The path is resolved relative to the kernel's current working directory.
module_path = os.path.abspath(os.pardir)
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)

Load dataset

In [None]:
import torch
from utils import filehandling, filename

# Dataset selection parameters: together they identify which preprocessed
# dataset directory is read. `cube_side` is reused later in the notebook.
size = 'dev_l'        # dataset size identifier
cube_side = 32        # side length of each cube
transform = 'minmax'  # preprocessing transform name
prob = 50             # NOTE(review): meaning not visible in this notebook — confirm in utils.filename

directory = filename.processed.dataset(size, cube_side, transform, prob)
dataset = filehandling.read_splitted_dataset(directory)

Split into train & test sets

In [None]:
from utils.data import splitting

# Presumably the fraction of samples assigned to `train` (0.8 -> 80/20 split) — confirm
# in utils.data.splitting. `test` is passed to the Segmenter below as the held-out set.
train_fraction = .8
train, test = splitting.train_val_split(dataset, train_fraction)

# Display the sizes of both subsets.
(len(train), len(test))

Load pretrained 2D model

In [None]:
import segmentation_models_pytorch as smp

# 2D U-Net with an ImageNet-pretrained ResNet-101 encoder, adapted to a single input
# channel and a single output class; it is converted to 3D in the next cell.
model = smp.Unet(
    encoder_name='resnet101',
    encoder_weights='imagenet',
    in_channels=1,
    classes=1,
)

Convert pretrained 2D model to 3D

In [None]:
%%capture
import torch
from training.convert2Dto3D import Conv3dConverter

Conv3dConverter(model, -1, torch.ones(1, 1, 32, 32, 32))

In [None]:
import numpy as np

seg = train.get_attribute('segmentmap')

# Fixed decision threshold, passed to the Segmenter below.
# Disabled alternative: set it to the fraction of positive voxels over the training
# segmentation maps, i.e.
#   sum(map(torch.sum, seg)) / sum(map(lambda t: torch.prod(torch.tensor(t.shape)), seg))
threshold = 0.5

In [None]:
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

# Keep only the single best checkpoint according to the logged 'Jaccard' metric
# (higher is better), considered every 3 epochs. Files are prefixed with a fresh model id.
# NOTE(review): `period` is deprecated in newer PyTorch Lightning releases in favour of
# `every_n_epochs` — update if the installed PL version changes.
model_id = filename.models.new_id()
checkpoint_name = str(model_id) + '-{epoch:02d}-{Jaccard:.2f}'
checkpoint_callback = ModelCheckpoint(
    dirpath=filename.models.directory,
    filename=checkpoint_name,
    monitor='Jaccard',
    mode='max',
    save_top_k=1,
    period=3,
)

In [None]:
from utils import filename
from utils.data.generating import get_hi_shape

# Shape read from the sky data file for the selected dataset `size`; it is needed
# when choosing the visualisation sample below.
sky_file = filename.data.sky(size)
hi_shape = get_hi_shape(sky_file)

Create Lightning objects

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from datetime import datetime
from pytorch_toolbelt import losses

from training.segmentation import Segmenter, get_vis_id

# Pick a test-set sample id for visualisation, using the sky shape and a minimum voxel count.
# NOTE(review): the exact selection rule lives in training.segmentation.get_vis_id — confirm.
min_vis_voxels = 300
vis_id = get_vis_id(test, hi_shape, min_vis_voxels)

# Joint objective: log-Jaccard loss plus BCE-with-logits (pytorch_toolbelt JointLoss,
# equal weights by default).
loss = losses.JointLoss(losses.JaccardLoss(mode='binary', log_loss=True), losses.SoftBCEWithLogitsLoss())

# TensorBoardLogger uses `version` as a directory name under tb_logs/segmenter, so the
# timestamp must be path-safe: '/' would create nested directories, and ':' (plus spaces)
# is illegal or awkward in file paths, e.g. on Windows.
version = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
logger = TensorBoardLogger("tb_logs", name="segmenter", version=version)

learning_rate = 1e-2
segmenter = Segmenter(model, loss, train, test, vis_id=vis_id, threshold=threshold, lr=learning_rate)

# NOTE(review): `gpus=` is the pre-2.0 PyTorch Lightning API (newer: accelerator/devices).
trainer = pl.Trainer(max_epochs=500, gpus=1, logger=logger, callbacks=[checkpoint_callback])

Train!

In [None]:
# Run the training loop; checkpoints and TensorBoard logs are written by the
# callback and logger configured above.
trainer.fit(model=segmenter)