In [None]:
from fastMONAI.vision_all import *

from monai.apps import DecathlonDataset
from sklearn.model_selection import train_test_split

In [None]:
# Make sure the processed-data root exists. `parents=True` also creates a missing
# `../data`, which plain `mkdir(exist_ok=True)` would reject with FileNotFoundError.
path = Path('../data/processed')
path.mkdir(parents=True, exist_ok=True)

In [None]:
# Dataset root: axial image slices live in `Axial/`, their segmentation masks in `AxialMask/`.
path = Path('../data/processed') / 'IXI_2d_slices'
fnames = get_image_files(path / 'Axial')
masks = get_image_files(path / 'AxialMask')

In [None]:
# Build one [image_path, mask_path] row per sample, pairing each image with ITS OWN mask.
# `get_image_files` does not promise any particular order, so both lists are sorted
# before being paired positionally.
# NOTE(review): assumes an image and its mask sort to the same position (e.g. identical
# file names in `Axial/` and `AxialMask/`) — TODO confirm the naming convention.
img_files = sorted(fnames)
mask_files = sorted(masks)
assert len(img_files) == len(mask_files), (
    f'{len(img_files)} images but {len(mask_files)} masks')
training_data = [[img, mask] for img, mask in zip(img_files, mask_files)]

In [None]:
# One row per sample: integer column 0 holds the image path and column 1 the mask path
# (ColReader(0) / ColReader(1) in the DataBlock read these labels).
df = pd.DataFrame(data=training_data)
df.shape

In [None]:
# Hold out 10% of the rows as a test set; the fixed random_state makes the split reproducible.
train_df, test_df = train_test_split(df, random_state=24, test_size=0.1)
(train_df.shape, test_df.shape)

In [None]:
# Collect per-file metadata for every mask so preprocessing settings can be suggested.
med_dataset = MedDataset(dtype=MedMask, img_list=masks, max_workers=12)

In [None]:
# Peek at the first few metadata rows.
med_dataset.df.head(n=5)

In [None]:
# Dataset-level summary table produced by MedDataset (inspected in the next cell).
summary_df = med_dataset.summary()

In [None]:
# First rows of the summary table.
summary_df.head(n=5)

In [None]:
# Suggested preprocessing settings, forwarded to MedDataBlock as `resample=` and `reorder=`.
# NOTE(review): presumably `resample` is a target voxel spacing and `reorder` a flag — confirm.
(resample, reorder) = med_dataset.suggestion()
(resample, reorder)

In [None]:
# Largest image size in the dataset, presumably measured after resampling to `resample`.
# NOTE(review): `img_size` is only displayed; the transform `size` in the next cell is
# hard-coded instead of derived from it — confirm that is intentional.
img_size = med_dataset.get_largest_img_size(resample=resample)
img_size

In [None]:
# Batch size and the fixed spatial size that PadOrCrop pads/crops every sample to.
bs, size = 1, [256, 256, 256]

In [None]:
# Per-item transforms: intensity z-normalisation, pad/crop to `size`,
# then a mild random affine augmentation (up to 5 degrees of rotation, no scaling).
item_tfms = [
    ZNormalization(),
    PadOrCrop(size),
    RandomAffine(scales=0, degrees=5, isotropic=True),
]

In [None]:
# Input is a MedImage and the target is a mask; paths are read from DataFrame columns 0 and 1.
# Samples are split randomly into train/valid with a fixed seed; the preprocessing
# suggested by MedDataset (resample / reorder) is applied on load.
dblock = MedDataBlock(
    blocks=(ImageBlock(cls=MedImage), MedMaskBlock),
    get_x=ColReader(0),
    get_y=ColReader(1),
    splitter=RandomSplitter(seed=24),
    item_tfms=item_tfms,
    resample=resample,
    reorder=reorder,
)

In [None]:
# Build train/validation DataLoaders from the training split only; test_df stays held out.
dls = dblock.dataloaders(train_df, bs=bs)

In [None]:
# Number of samples in the (training, validation) splits.
tuple(len(split.items) for split in (dls.train_ds, dls.valid_ds))

In [None]:
# Visual sanity check of one batch of inputs and targets.
# NOTE(review): anatomical_plane=0 selects the plane to display — confirm which axis 0 maps to.
dls.show_batch(anatomical_plane=0)

### Create and train a 3D model

In [None]:
from monai.losses import DiceCELoss
from monai.networks.nets import UNet

In [None]:
# Segmentation class labels, read from the codes file next to the data.
# NOTE(review): assumes one whitespace-free label per line; labels containing spaces would be
# split into several columns by np.loadtxt.
# `ndmin=1` keeps a one-line file as a 1-element array. Without it np.loadtxt returns a
# 0-d array and `len(codes)` raises TypeError.
codes = np.loadtxt(path/'code-kopi_kuttet.txt', dtype=str, ndmin=1)
n_classes = len(codes)
codes, n_classes

In [None]:
# 3D MONAI UNet: feature channels 16 -> 256 across five levels, stride-2 downsampling
# between levels, two residual units (num_res_units=2); one output channel per class.
# NOTE(review): in_channels must equal the channel count of the input images. The data
# folder is named IXI_2d_slices, and single-modality slices are likely 1-channel. Confirm
# the value 4 (e.g. against dls.one_batch()[0].shape[1]).
model = UNet(
    spatial_dims=3,
    in_channels=4,
    out_channels=n_classes,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)

In [None]:
# Combined Dice + cross-entropy loss: labels are one-hot encoded, logits go through
# softmax, and the background class is included in the Dice term.
dice_ce = DiceCELoss(to_onehot_y=True, include_background=True, softmax=True)
loss_func = CustomLoss(loss_func=dice_ce)

In [None]:
# fastai Learner using the Ranger optimizer and the multi_dice_score metric.
# Mixed precision is not enabled here; append `.to_fp16()` to the Learner to turn it on.
learn = Learner(
    dls,
    model,
    loss_func=loss_func,
    opt_func=ranger,
    metrics=multi_dice_score,
)

In [None]:
# Learning-rate range test to help pick `lr`.
# NOTE(review): the suggested value is not captured; `lr` is set by hand in the next cell.
learn.lr_find()

In [None]:
# Peak learning rate for the flat-then-cosine schedule.
lr = 0.1

In [None]:
# Train with a flat learning rate that is cosine-annealed toward the end.
n_epochs = 20
learn.fit_flat_cos(n_epochs, lr)

In [None]:
# Persist the trained model under the name 'fastMONAI-model'.
learn.save('fastMONAI-model')

In [None]:
# Compare predictions against ground truth; ds_idx=1 presumably selects the validation split.
learn.show_results(anatomical_plane=0, ds_idx=1)