Add project root to path

In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

Load dataset

In [2]:
import torch
from utils import filename
from utils import filehandling

size = 'dev_l'
cube_side = 32
transform = 'minmax'
prob = 50

directory = filename.processed.dataset(size, cube_side, transform, prob)
dataset = filehandling.read_splitted_dataset(directory)

  0%|          | 0/167 [00:00<?, ?it/s]

Split to train & test

In [3]:
from utils.data import splitting

train, test = splitting.train_val_split(dataset, .8)

Load pretrained 2D model

In [4]:
import segmentation_models_pytorch as smp

model = smp.Unet(encoder_name='resnet101', in_channels=1, classes=1, encoder_weights='imagenet')

Convert pretrained 2D model to 3D

In [5]:
import torch
from models.convert2Dto3D import Conv3dConverter

Conv3dConverter(model, -1, torch.ones(1, 1, 32, 32, 32))

Conv3dConverter(
Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
        (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
        (bn3): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace

In [6]:
import numpy as np
seg = train.get_attribute('segmentmap')
threshold = 0.5#sum(map(torch.sum, seg)) / sum(map(lambda t: torch.prod(torch.tensor(t.shape)),seg))

In [7]:
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

model_id = filename.models.new_id()
checkpoint_callback = ModelCheckpoint(monitor='Jaccard',save_top_k=1, dirpath=filename.models.directory,
                                      filename=str(model_id) +'-{epoch:02d}-{val_loss:.2f}', mode='max')

In [8]:
from utils.data.generating import get_hi_shape
hi_shape = get_hi_shape(filename.data.sky(size))

Create Lightning objects

In [9]:
import pytorch_lightning as pl
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loggers import TensorBoardLogger
from datetime import datetime
from pytorch_toolbelt import losses

from models.segmentation import Segmenter, get_vis_id

min_vis_voxels = 1000
vis_id = get_vis_id(test, hi_shape, min_vis_voxels)

loss = losses.JointLoss(losses.JaccardLoss(mode='binary', log_loss=True), losses.SoftBCEWithLogitsLoss())

version = datetime.now().strftime("%Y/%m/%d %H:%M:%S")
logger = TensorBoardLogger("tb_logs", name="segmenter", version=version)
segmenter = Segmenter(model, loss, train, test, logger, vis_id=vis_id, threshold=threshold)

trainer = pl.Trainer(max_epochs=500, gpus=1, logger=logger, callbacks=[checkpoint_callback])

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Train!

In [None]:
trainer.fit(segmenter)

Set SLURM handle signals.

  | Name     | Type      | Params
---------------------------------------
0 | loss_fct | JointLoss | 0     
1 | model    | Unet      | 112 M 
---------------------------------------
112 M     Trainable params
0         Non-trainable params
112 M     Total params


Validation sanity check: 0it [00:00, ?it/s]

  value = torch.tensor(value, device=device, dtype=torch.float)


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]