## Train a sparse semantic unet on RGB image

In this notebook we work out the details for training a sparse semantic unet to predict pixel classes for RGB images. 

Sparse means we don't have to label every pixel in an ROI.  We label each class (including a background class labeled with 1).  We then ignore the unlabeled pixels. 

Imports:  Note on Linux for some reason we have to import and show Napari before importing PyTorch. 

In [None]:
import napari
viewer = napari.Viewer()


from tnia.deeplearning.dl_helper import quantile_normalization
import numpy as np
import torch

from tnia.deeplearning.dl_helper import collect_training_data, divide_training_data
from tnia.plotting.plt_helper import imshow_multi2d, random_label_cmap
from semantic_dataset import SemanticDataset
import random
from unet import *
from torch.utils.tensorboard import SummaryWriter
import datetime
from torch.utils.data import DataLoader
from semantic_helper import train3
from monai.networks.nets import BasicUNet
from pathlib import Path
import os

## Check if Cuda is present

If cuda is not present training will be slow... 

In [None]:
# Ask PyTorch whether a CUDA runtime is usable and how many GPUs it can see.
cuda_present = torch.cuda.is_available()
ndevices = torch.cuda.device_count()

# Only use the GPU when CUDA is available AND at least one device is visible.
use_cuda = cuda_present and ndevices > 0

# "cuda" selects the default GPU; "cuda:1", "cuda:2", ... would pick a specific index.
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print("number of devices:", ndevices, "\tchosen device:", device, "\tuse_cuda=", use_cuda)

## Set Parent Path

This is the path that contains the images we will work with and pre-existing patches that would have been created in notebook ```33_label_semantic_sparse_rgb```

In [None]:

# Root of the tnia example images (not referenced again in the cells shown here).
tnia_images_path = Path(r"D:\images")

# Project folder that holds the ``patches`` directory.  Edit this for your
# machine; the commented line is the Windows location that the active
# assignment used to overwrite.
# parent_path = r'C:\Users\bnort\Documents\...'
parent_path = r'/home/bnorthan/bekonbits/images/Columbia_Semantic_Sparse/'

train_path = os.path.join(parent_path, 'patches')

# The two names used to be swapped (image -> 'ground truth0', label -> 'input0'),
# which made the "does not exist" messages below point at the wrong folder.
image_patch_path = os.path.join(train_path, 'input0')
label_patch_path = os.path.join(train_path, 'ground truth0')

# Trained models are written here; create the folder on first use.
model_path = os.path.join(parent_path, 'models')
os.makedirs(model_path, exist_ok=True)

# Warn early (instead of failing later during data collection) if the patch
# folders are missing.
if not os.path.exists(image_patch_path):
    print('image_patch_path does not exist', image_patch_path)

if not os.path.exists(label_patch_path):
    print('label_patch_path does not exist', label_patch_path)

## Collect training data

Collect the training data that would have been created in ```33_label_semantic_sparse_rgb.ipynb```

In [None]:
# Read the (input, ground truth) patch lists from ``train_path``.  All optional
# processing flags of the helper are switched off here, so the patches are
# presumably returned as stored on disk -- confirm against collect_training_data.
X, Y = collect_training_data(
    train_path,
    sub_sample=1,
    downsample=False,
    normalize_input=False,
    add_trivial_channel=False,
    relabel=False,
)

# Report how many patches were found ...
for caption, patches in (('Number of input images', X), ('Number of ground truth images ', Y)):
    print(caption, len(patches))

# ... and the shape of the first patch of each kind.
for caption, first_patch in (('Size of first input image', X[0]), ('Size of first ground truth image ', Y[0])):
    print(caption, first_patch.shape)

In [None]:
# Split the collected patches into training and validation subsets.
# NOTE(review): the third positional argument (2) is presumably the number of
# patches held out for validation -- confirm against divide_training_data.
X_train, Y_train, X_val, Y_val = divide_training_data(X, Y, 2, to_numpy=False)

# Inputs and labels must stay paired one-to-one after the split.
assert len(X_train) == len(Y_train), "training inputs and labels are out of sync"
assert len(X_val) == len(Y_val), "validation inputs and labels are out of sync"

print('Number of training images', len(X_train))
print('Number of validation images ', len(X_val))

In [None]:
# Convert the train/validation collections into numpy arrays.
X_train, Y_train = np.array(X_train), np.array(Y_train)
X_val, Y_val = np.array(X_val), np.array(Y_val)

# There is no separate test split: the "test" names refer to the validation arrays.
X_test, Y_test = X_val, Y_val

# Display the resulting array shapes as the cell output.
(X_train.shape, Y_train.shape)

In [None]:
def shift_labels_to_zero_based(labels):
    """Return ``labels`` as int16 with every value decreased by one.

    The decrement turns 1-based class labels into 0-based class indices and
    turns 0 into -1.

    The plain ``labels.astype(np.int16) - 1`` is not idempotent: re-running the
    cell would shift the labels a second time.  To make re-runs safe, an array
    that already contains a negative value is treated as already shifted and is
    returned unchanged (apart from the int16 cast).

    Limitation: the guard relies on raw labels being non-negative and on at
    least one value of 0 being present before the shift; an array whose raw
    minimum is 1 cannot be told apart from one that was already shifted.
    """
    labels = np.asarray(labels)
    if labels.size > 0 and labels.min() < 0:
        # A negative value can only come from a previous shift.
        return labels.astype(np.int16)
    return labels.astype(np.int16) - 1


Y_train = shift_labels_to_zero_based(Y_train)
Y_val = shift_labels_to_zero_based(Y_val)
Y_test = shift_labels_to_zero_based(Y_test)

In [None]:
# Smallest and largest label value in the training labels, shown as the cell output.
np.min(Y_train), np.max(Y_train)

## Preview Training Data

Just make sure it looks right and labels correspond to objects properly

In [None]:
# Index of the training patch to preview.  Clamped to the last available patch
# so that training sets with fewer than 11 patches do not raise an IndexError.
n = min(10, len(X_train) - 1)
X_ = X_train[n]
Y_ = Y_train[n]

# One summary line per array: dtype, shape and value range.
print('input:', X_.dtype, X_.shape, X_.min(), X_.max())
print('label:', Y_.dtype, Y_.shape, Y_.min(), Y_.max())

# Show the input next to its label image in a 1 x 2 grid.
fig = imshow_multi2d([X_, Y_], ['input', 'label'], 1, 2)

## Create Datasets

In [None]:
# Options shared by all three datasets.
# NOTE(review): crop_size presumably sets the edge length of the crops that
# SemanticDataset returns -- confirm in semantic_dataset.py.
dataset_options = dict(crop_size=256)

train_dataset = SemanticDataset(X_train, Y_train, **dataset_options)
# The test and validation datasets are both built from the validation arrays.
test_dataset = SemanticDataset(X_val, Y_val, **dataset_options)
val_dataset = SemanticDataset(X_val, Y_val, **dataset_options)

# Smoke test: fetch one randomly chosen training sample.
sample_index = random.randrange(len(train_dataset))
raw, mask = train_dataset[sample_index]

# Move the first axis of the image to the end before reporting the shapes.
raw = np.transpose(raw, (1, 2, 0))
(raw.shape, mask.shape)

## Visualize a few datasets

In [None]:
# Show three training samples: two drawn at random followed by the fixed index
# 1500.  ``None`` in the tuple below means "draw a random index".
for fixed_index in (None, None, 1500):
    if fixed_index is None:
        sample_index = random.randrange(len(train_dataset))
    else:
        sample_index = fixed_index

    raw, mask = train_dataset[sample_index]

    # The image is displayed with its first axis moved to the end; only the
    # first plane of the mask is shown.
    fig = imshow_multi2d(
        [np.transpose(raw, (1, 2, 0)), mask[0]],
        ['Image', 'Labels'],
        1, 2, 10, 10,
        colormaps=['gray', random_label_cmap()],
    )

    # Label range of this sample.
    print(mask.min(), mask.max())

## Set up unet

We use monai BasicUnet.

Since image is RGB in_channels are 3

Since the dataset we are working with has 3 classes, out_channels are 3

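No activation function is applied to the network output, because `CrossEntropyLoss` expects raw logits and applies the softmax itself.
Note: at prediction time the predictor will need to apply a softmax to the network output to obtain class probabilities.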



In [None]:
# Pixels whose target is -1 are excluded from the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
# Tensor type handed to the training helper.
dtype = torch.LongTensor

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 2D U-Net: 3 input channels (RGB) and 3 output channels (one logit map per class).
# ``act`` is deliberately left at the MONAI default.  In BasicUNet ``act`` is the
# activation used inside every convolution block, not an activation on the
# output, so ``act=None`` strips the non-linearities from the whole network.
# The final layer is a plain convolution and already returns raw logits, which
# is what CrossEntropyLoss expects.  features / norm / dropout are also left at
# their defaults.
net = BasicUNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
)

# Continue from a previously saved model when one exists; otherwise train the
# freshly built network from scratch instead of crashing on the missing file.
checkpoint_path = Path(model_path) / 'model_Jan27_batchfix3.pth'
if checkpoint_path.exists():
    # The file holds a whole pickled model (see the save cell), not a state
    # dict, hence weights_only=False (needs a PyTorch release that accepts this
    # argument).  Unpickling can execute arbitrary code: only load checkpoints
    # you created yourself.  map_location lets a checkpoint saved on a GPU load
    # on a CPU-only machine.
    net = torch.load(checkpoint_path, map_location=device, weights_only=False)
else:
    print('checkpoint not found, starting from a new network:', checkpoint_path)

net = net.to(device)
loss_fn = loss_fn.to(device)

# One TensorBoard log directory per run, named after the start time.
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
writer = SummaryWriter(logdir)

## Start training process

In [None]:
# Hyper-parameters
train_batch_size = 32
test_batch_size = 1
learning_rate = 5e-5
training_steps = 5000  # not used by the call below

# Data loaders: only the training loader shuffles and pins memory.
train_loader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    pin_memory=True,
)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size)
val_loader = DataLoader(val_dataset, batch_size=test_batch_size)

# Adam over all network parameters.
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# NOTE(review): 150 is presumably the number of epochs -- confirm against
# semantic_helper.train3.
train3(train_loader, val_loader, net, loss_fn, optimizer, dtype, 150, device)

## save unet

In [None]:

# Persist the whole model object (not just a state dict) in the models folder.
saved_model_path = Path(model_path) / 'model_Jan27_batchfix4.pth'
torch.save(net, saved_model_path)

## Test predictions

In [None]:
# Compare the raw network output in eval mode and in train mode on the same
# training samples.

# Remember the current mode so this inspection cell does not change it.
was_training = net.training

# Inspect at most 20 samples, fewer if the dataset is smaller.
num_previews = min(20, len(train_loader.dataset))

for i in range(num_previews):
    # Fetch the sample once so both modes see exactly the same input.
    features, label = train_loader.dataset[i]
    features_tensor = torch.from_numpy(features).unsqueeze(0).to(device)

    # Image with its first axis moved to the end, for display.
    image = np.transpose(features, (1, 2, 0))

    for mode_name, set_mode in (('eval', net.eval), ('train', net.train)):
        set_mode()

        # Pure inference: no autograd graph is needed.
        # NOTE(review): layers that keep running statistics still update them
        # during a train-mode forward pass, even under no_grad.
        with torch.no_grad():
            predicted = net(features_tensor)

        print(mode_name, predicted.shape, features.shape)

        # One map per output channel; these are the raw network outputs (no softmax applied).
        c1, c2, c3 = (predicted[0, channel].cpu().numpy() for channel in range(3))
        fig = imshow_multi2d([image, c1, c2, c3], ['Image', 'Class 1', 'Class 2', 'Class 3'], 1, 4, 10, 10, colormaps=['gray', 'viridis', 'viridis', 'viridis'])

# Put the network back into the mode it was in before this cell ran.
net.train(was_training)