Add the project root to `sys.path`

In [None]:
import os
import sys

# Make the project root (the parent of the notebook's working directory)
# importable, so the local `utils` and `models` packages used below resolve.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

Load raw data

In [None]:
from astropy.io import fits
import pandas as pd
from astropy.wcs import WCS
import dask.array as da

from utils.data.segmentmap import create_from_files
from utils import filename

# Identifier of the dataset variant; the project's filename helpers map it to files.
name = 'dev_s'
segmentmap = create_from_files(name)

# The sky cube is read eagerly by astropy and then wrapped as a dask array;
# its FITS header is read separately (it is needed later for the WCS).
sky_path = filename.data.sky(name)
hi_data = da.from_array(fits.getdata(sky_path))
header = fits.getheader(sky_path)

# Space-separated table from filename.data.true(name).
# NOTE(review): presumably the ground-truth source catalogue — confirm.
df = pd.read_csv(filename.data.true(name), sep=' ')

Create dataset objects

In [None]:
from utils.data.generating import create_data_set_dict

# Build the dataset from the table, the HI cube, the segmentation map and the
# cube's WCS (derived from the FITS header loaded earlier).
# NOTE(review): the meaning of the positional 0.5 is not visible here — confirm
# against create_data_set_dict's signature.
dataset = create_data_set_dict(
    df,
    hi_data,
    segmentmap,
    WCS(header),
    .5,
    side_length=32,
    # NOTE(review): `precuation` looks like a misspelling of `precaution`; kept
    # as-is because it must match create_data_set_dict's parameter name — verify.
    precuation=100,
    freq_band=32,
    spatial_points=1,
)

Create dataloaders and split into train & test sets

In [None]:
from utils.data import splitting

# Split the dataset into two loaders. 0.8 is the split fraction
# (presumably the train share — confirm in splitting.splitted_loaders).
loaders = splitting.splitted_loaders(dataset, .8)
train, test = loaders

Load pretrained 2D model

In [None]:
import segmentation_models_pytorch as smp

# 2D U-Net with a ResNet-18 encoder, one input channel and one output channel.
# NOTE(review): "pretrained" relies on smp's default `encoder_weights` — confirm
# which weights the installed smp version loads.
model = smp.Unet(
    encoder_name='resnet18',
    in_channels=1,
    classes=1,
)

Convert pretrained 2D model to 3D

In [None]:
import torch
from models.convert2Dto3D import Conv3dConverter

# A 5-D all-ones tensor handed to the converter; its 32s match side_length and
# freq_band used when building the dataset.
# NOTE(review): the return value is discarded, so this presumably converts
# `model` in place; the meaning of -1 is not visible here — confirm both.
sample_volume = torch.ones(1, 1, 32, 32, 32)
Conv3dConverter(model, -1, sample_volume)

Create the PyTorch Lightning module, logger and trainer

In [None]:
from datetime import datetime

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_toolbelt.losses import SoftBCEWithLogitsLoss

from models.segmentation import Segmenter

# Number of training epochs passed to the Lightning Trainer.
MAX_EPOCHS = 100

# Lightning module wrapping the (converted) model, the loss and both loaders.
segmenter = Segmenter(model, SoftBCEWithLogitsLoss(), train, test)

# TensorBoardLogger uses a string `version` verbatim as a directory name under
# tb_logs/segmenter/. The old "%Y/%m/%d %H:%M:%S" format produced nested
# directories (from "/") and contained ":", which is invalid in Windows paths.
# Use a flat, lexicographically sortable, filesystem-safe timestamp instead.
version = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
logger = TensorBoardLogger("tb_logs", name="segmenter", version=version)

# NOTE(review): `gpus=` is deprecated in newer Lightning releases in favour of
# accelerator="gpu", devices=1 — update if the installed version rejects it.
trainer = pl.Trainer(max_epochs=MAX_EPOCHS, gpus=1, logger=logger)

Train!

In [14]:
# Run training; the Segmenter module carries its own train/test loaders.
trainer.fit(model=segmenter)

1