In [None]:
import glob
import imageio.v2 as imageio
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils
import torch
from datasets.heart_dataset import HeartDataset, HeartDatasetType
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from utils import natural_key

In [None]:
def visualize(**images):
    """Show every keyword-argument image side by side in a single row.

    Each keyword becomes the panel title (underscores are replaced with
    spaces and the words are title-cased). Each value is squeezed and drawn
    with a grayscale colormap, without axis ticks.
    """
    panel_count = len(images)
    plt.figure(figsize=(10, 2))
    for position, key in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        # Hide the tick marks on both axes.
        plt.xticks([])
        plt.yticks([])
        plt.title(key.replace('_', ' ').title())
        plt.imshow(images[key].squeeze(), cmap='gray')
    plt.show()

In [None]:
# Build a dataset with augmentation switched on, purely to eyeball one sample.
dataset = HeartDataset(use_augmentation=True)

# Index 70 is an arbitrary sample picked for the visual sanity check.
image, mask = dataset[70]

visualize(
    image=image,
    mask=mask,
)

# Shown as the cell output: value range / mean of the image and range of the mask.
image.min(), image.max(), image.mean(), mask.min(), mask.max()

In [None]:
import ssl

# SECURITY: HTTPS certificate verification stays enabled by default.
# Pretrained encoder weights are fetched over the network and then
# deserialized, so skipping verification exposes the download to tampering.
# Set the flag to True only as a deliberate, temporary workaround on a
# trusted network with a broken certificate store.
ALLOW_UNVERIFIED_HTTPS = False
if ALLOW_UNVERIFIED_HTTPS:
    ssl._create_default_https_context = ssl._create_unverified_context

ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['heart']
ACTIVATION = 'sigmoid'
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines without CUDA instead of failing at the first `.to(DEVICE)`.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# create segmentation model with pretrained encoder
# in_channels=1: the network takes single-channel input.
model = smp.FPN(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
    in_channels=1
)

# Alternative architecture, kept for reference:
#model = smp.Unet(
#            encoder_name=ENCODER,
#            encoder_weights=ENCODER_WEIGHTS,
#            classes=len(CLASSES),
#            activation=ACTIVATION,
#            in_channels=1,
#            decoder_channels=(128, 64, 32, 16, 8)
#        )

# NOTE(review): preprocessing_fn is not used by any cell visible in this
# notebook — confirm whether HeartDataset should apply it or drop this line.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
# One dataset object per split.
train_dataset = HeartDataset(dataset_type=HeartDatasetType.TRAIN)
valid_dataset = HeartDataset(dataset_type=HeartDatasetType.VALIDATION)
test_dataset = HeartDataset(dataset_type=HeartDatasetType.TEST)

# Training batches are shuffled; validation and test are read one sample at
# a time in their stored order.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=12,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Binary cross-entropy is the active loss; Dice loss is the kept alternative.
#loss = smp.utils.losses.DiceLoss()
loss = smp.utils.losses.BCELoss()

# Metrics reported for every epoch. IoU gets an explicit 0.5 threshold, the
# remaining metrics are built with their default arguments.
smp_metrics = smp.utils.metrics
metrics = [smp_metrics.IoU(threshold=0.5)]
metrics += [
    metric_cls()
    for metric_cls in (
        smp_metrics.Fscore,
        smp_metrics.Accuracy,
        smp_metrics.Recall,
        smp_metrics.Precision,
    )
]

# Adam over all model parameters, as a single parameter group.
optimizer = torch.optim.Adam([{'params': model.parameters(), 'lr': 1e-4}])

In [None]:
# Keyword arguments shared by the training and the validation runner.
epoch_settings = dict(loss=loss, metrics=metrics, device=DEVICE, verbose=True)

# The training runner additionally receives the optimizer.
train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **epoch_settings)

valid_epoch = smp.utils.train.ValidEpoch(model, **epoch_settings)

In [None]:
# Training schedule. The decay epoch is derived from the epoch count: a
# hard-coded decay epoch of 25 can never be reached in a 10-epoch run, so the
# learning-rate drop would silently never happen. The 5/8 ratio yields epoch 6
# for 10 epochs (and epoch 25 for a 40-epoch run).
NUM_EPOCHS = 10
LR_DECAY_EPOCH = NUM_EPOCHS * 5 // 8
DECAYED_LR = 1e-5
BEST_MODEL_PATH = './best_model_interactive.pth'

max_score = 0

for i in range(NUM_EPOCHS):

    print(f'\nEpoch: {i}')
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    # Keep the checkpoint with the best validation IoU seen so far.
    # The whole model object is saved (not just a state_dict) because later
    # cells call methods directly on the loaded object.
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, BEST_MODEL_PATH)
        print('Model saved!')

    # One-off learning-rate drop. The optimizer holds every model parameter,
    # so the new rate applies to the whole network, not only to the decoder.
    if i == LR_DECAY_EPOCH:
        for param_group in optimizer.param_groups:
            param_group['lr'] = DECAYED_LR
        print(f'Decrease learning rate to {DECAYED_LR:g}!')

In [None]:
# Load a previously saved full-model checkpoint and map it onto DEVICE.
# NOTE(review): this path differs from './best_model_interactive.pth', the
# file written by the training loop in this notebook — confirm which
# checkpoint is meant to be evaluated.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which rejects fully pickled models — this call may need
# weights_only=False there; confirm against the installed version.
best_model = torch.load('./results/heart_segmentation_2/best_model_fpn.pth', map_location=DEVICE)

In [None]:
# Evaluate the loaded checkpoint on the test split, reusing the validation
# runner class together with the loss and metrics defined earlier.
test_epoch = smp.utils.train.ValidEpoch(
    best_model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

# Last expression of the cell, so the value returned by run() is displayed.
test_epoch.run(test_loader)

In [None]:
# Qualitative check on a single validation sample (index 60 is arbitrary).
image, mask = valid_dataset[60]

# Move the sample to DEVICE and add a leading axis (presumably the batch
# dimension the model expects — TODO confirm).
model_input = image.to(DEVICE).unsqueeze(0)
prediction = best_model.predict(model_input)

# The prediction is brought back to the CPU for plotting only.
visualize(image=image, mask=mask, predict=prediction.to('cpu'))

In [None]:
# Load a CT series from disk via the project-local Ct helper.
# NOTE(review): the import sits mid-notebook and the path is an absolute,
# machine-specific location — consider moving the import to the first cell
# and the path to a configurable constant.
from ct import Ct
# Alternative input series, kept for reference:
#ct = Ct('/data/tavi/MOL005/ct/')
ct = Ct('/data/calcium_processed/CS_011/ct/', file_pattern='IM-0001-*.dcm')

In [None]:
# Take one slice of the CT volume (index 160 along the first axis) and divide
# its values by 1000.
# NOTE(review): the /1000 scaling presumably matches the intensity range the
# model was trained on — confirm against HeartDataset.
ct_image = (ct.img[160,:,:] / 1000)

# Two leading axes are added before prediction; the result is moved to the CPU.
prediction = best_model.predict(ct_image.to(DEVICE).unsqueeze(0).unsqueeze(0)).to('cpu')

visualize(
    image=ct_image,
    # Overlay: keep CT values where the prediction is below 0.1, then add
    # twice the prediction on top.
    overlap=ct_image * (prediction<0.1) + (prediction*2),
    #mix=ct_image + (prediction*2),
    predict=prediction
)

# Shown as the cell output: value ranges of the slice and of the prediction.
ct_image.min(), ct_image.max(), prediction.min(), prediction.max()

In [None]:
# Output volume with the same shape as the CT volume, filled slice by slice.
masked = torch.empty(ct.img.shape)

num_slices = ct.img.shape[0]
for slice_idx in range(num_slices):
    # Same scaling and overlay recipe as the single-slice preview.
    ct_slice = ct.img[slice_idx] / 1000
    model_input = ct_slice.to(DEVICE).unsqueeze(0).unsqueeze(0)
    slice_prediction = best_model.predict(model_input).to('cpu')
    # CT values survive only where the prediction is below 0.1; twice the
    # prediction is then added on top.
    keep_ct = slice_prediction < 0.1
    masked[slice_idx] = ct_slice * keep_ct + (slice_prediction * 2)


In [None]:
# Switch matplotlib to the interactive widget backend (ipympl) for the viewer.
%matplotlib widget
from viz import VolumePlot, SliceDirection
# Browse the overlaid volume stored in `masked`, slicing in the sagittal direction.
plotter = VolumePlot(masked, figsize=(9,9))
plotter.direction = SliceDirection.SAGITTAL
plotter.plot_interactive()