In [2]:
## Load Data
import torch
from torch import nn
from torch.utils.data import DataLoader

In [3]:
import numpy as np

def load_dataset_divide(dataset_dir, rescaled_size, idx_splits, no_split=False):
    """Load divided volumes, landmark points and resolutions, optionally split them.

    Files are read from ``dataset_dir + "divided_{volumes,points,res}_<size>.npy"``,
    where ``<size>`` is the first three entries of ``rescaled_size`` concatenated
    (e.g. (176, 176, 48) -> "17617648"). ``dataset_dir`` is concatenated as-is,
    so it must end with a path separator.

    :param dataset_dir: directory prefix containing the ``.npy`` files
    :param rescaled_size: sequence whose first three entries name the files
    :param idx_splits: sequence of three index arrays (train, val, test);
        ignored when ``no_split`` is True
    :param no_split: return the full arrays instead of splitting
    :return: ``(x, y, res)`` when ``no_split``; otherwise the 9-tuple
        ``(x_train, y_train, res_train, x_val, y_val, res_val, x_test, y_test, res_test)``.
        All arrays are float32.
    """
    size_str = f"{rescaled_size[0]}{rescaled_size[1]}{rescaled_size[2]}"

    x_dataset_path = dataset_dir + "divided_volumes_" + size_str + ".npy"
    y_dataset_path = dataset_dir + "divided_points_" + size_str + ".npy"
    res_dataset_path = dataset_dir + "divided_res_" + size_str + ".npy"

    x_dataset = np.load(x_dataset_path).astype('float32')
    y_dataset = np.load(y_dataset_path).astype('float32')
    res_dataset = np.load(res_dataset_path).astype('float32')

    # Duplicate each resolution entry so there are twice as many rows as in
    # res_dataset (presumably one per half of each divided volume -- TODO confirm).
    # The row count is derived (-1) instead of hard-coded, so any dataset size loads.
    # NOTE(review): this duplicates whole rows only if res_dataset is (N, 1, 3);
    # for an (N, 3) array the repeat interleaves components -- confirm stored shape.
    res_dataset_rep = np.repeat(res_dataset, 2, axis=1).reshape(-1, 1, 3)

    # without splitting to Train, Val and Test
    if no_split:
        return x_dataset, y_dataset, res_dataset_rep

    # Train, val and test are selected the same way; collect them in that order.
    outputs = []
    for k in range(3):
        split_idx = idx_splits[k]
        outputs.extend((x_dataset[split_idx], y_dataset[split_idx], res_dataset_rep[split_idx]))
    return tuple(outputs)

In [4]:
def get_data_splits(pat_splits, split=False, aug_num=50):
    """Expand per-patient split lists into flat per-sample index arrays.

    Samples are assumed to be stored contiguously per patient: patient ``p``
    owns indices ``[p * block, (p + 1) * block)``. Here ``block`` is
    ``aug_num``, doubled when ``split`` is True.

    :param pat_splits: sequence of patient-id collections (e.g. train/val/test);
        any number of entries is accepted
    :param split: True when every augmented volume contributes two samples
    :param aug_num: number of augmented samples per patient
    :return: list with one integer index array per entry of ``pat_splits``
        (an empty entry yields an empty *integer* array, usable for indexing)
    """
    block = aug_num * 2 if split else aug_num
    return [
        np.asarray(
            [idx for pat in pats for idx in range(pat * block, pat * block + block)],
            dtype=int,
        )
        for pats in pat_splits
    ]

In [5]:
pat_splits = [
    np.asarray([2, 4, 18, 17, 12, 10, 6, 0, 11, 16, 9, 14, 5, 19]),  # train patients
    np.asarray([3, 13]),                                             # validation patients
    np.asarray([8, 7, 1, 15]),                                       # test patients
]
dataset_dir = "/data/gpfs/projects/punim1836/Data/divided/17617648/"
rescaled_size = (176, 176, 48)
data_splits = get_data_splits(pat_splits, split=True, aug_num=50)

(x_train, y_train, res_train,
 x_val, y_val, res_val,
 x_test, y_test, res_test) = load_dataset_divide(dataset_dir, rescaled_size, data_splits)

# move channel forward: the trailing channel axis goes to position 1, as nn.Conv3d expects
x_train, x_val, x_test = (np.transpose(volumes, (0, 4, 1, 2, 3)) for volumes in (x_train, x_val, x_test))

In [6]:
column_size, row_size, slice_size = 88, 176, 48

# Rescale every split's resolution by size / 2 per axis (dividing by 2 / size):
# coordinates normalised to [-1, 1] span `size` voxels, so one normalised unit is
# size / 2 voxels. Results are cast back to float32.
res_train, res_val, res_test = (
    (res_split / [2 / column_size, 2 / row_size, 2 / slice_size]).astype('float32')
    for res_split in (res_train, res_val, res_test)
)

In [7]:
# transfer the Y: map voxel indices to voxel-centre coordinates in [-1, 1], then
# append the resolution row along axis 1 so each label is (landmarks..., resolution).
image_size = [88, 176, 48]


def _to_normalised(points, size=image_size):
    """Map voxel indices ``i`` to ``(2 * i + 1) / size - 1`` per axis, as float32."""
    # np.asarray(size) alone is int64, which silently promotes float32 points
    # to float64; use float32 so the labels match the float32 inputs and res.
    size_arr = np.asarray(size, dtype=np.float32)
    return ((points * 2 + 1) / size_arr - 1).astype('float32')


y_train_t = _to_normalised(y_train)
y_train_res = np.concatenate((y_train_t, res_train), axis=1)
y_val_t = _to_normalised(y_val)
y_val_res = np.concatenate((y_val_t, res_val), axis=1)
y_test_t = _to_normalised(y_test)
y_test_res = np.concatenate((y_test_t, res_test), axis=1)

In [8]:
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    """Map-style dataset pairing in-memory inputs ``X`` with labels ``Y``.

    Item ``idx`` is ``(X[idx], Y[idx])``, with optional callables applied to the
    input (``transform``) and then to the label (``target_transform``).
    The dataset length is the length of ``Y``.
    """

    def __init__(self, X, Y, transform=None, target_transform=None):
        # Attribute names are kept as-is because external code may read them.
        self.img, self.img_labels = X, Y
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        sample, target = self.img[idx], self.img_labels[idx]
        if self.transform:
            sample = self.transform(sample)
        if self.target_transform:
            target = self.target_transform(target)
        return sample, target

In [9]:
# Wrap each split's inputs and (landmarks + resolution) labels in a Dataset.
train_dataset, val_dataset, test_dataset = (
    CustomImageDataset(features, labels)
    for features, labels in ((x_train, y_train_res), (x_val, y_val_res), (x_test, y_test_res))
)

In [10]:
BATCH_SIZE = 2

# Only the training data is shuffled. Evaluation order is kept fixed so that
# metrics (including any short final batch) are reproducible run to run.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [11]:
# Pull one training batch to sanity-check tensor shapes.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}", f"Labels batch shape: {train_labels.size()}", sep="\n")

In [12]:
## Prepare Model
# Pick the first available backend in priority order: CUDA, Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [13]:
class FCN(nn.Module):
    """Three 3x3x3 Conv3d layers (1 -> 16 -> 16 -> 16 channels), padding 1.

    The first two convolutions are each followed by ReLU and then BatchNorm3d.
    The last convolution's output is returned unactivated.
    """

    def __init__(self):
        super().__init__()
        stack = []
        in_ch = 1
        for _ in range(2):
            stack += [nn.Conv3d(in_ch, 16, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm3d(16)]
            in_ch = 16
        stack.append(nn.Conv3d(16, 16, kernel_size=3, padding=1))
        # Same module order/indices as a hand-written Sequential, so state_dict keys match.
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        return self.layers(x)

In [14]:
kernel_size = 5


class FCN1(nn.Module):
    """Stack of six 'same'-padded Conv3d layers producing 64 output channels.

    Channel progression: in_channels -> 32 -> 64 -> 128 -> 64 -> 128 -> 64. Each
    convolution except the last is followed by BatchNorm3d and ReLU. Module indices
    inside ``self.layers`` are 0..15, so existing state_dict keys still match.

    :param kernel_size: cubic kernel size for every convolution. Its default is bound
        once from the module-level ``kernel_size`` (5), so later reassignment of that
        global no longer silently changes newly built models.
    :param in_channels: number of input channels (default 1)
    """

    def __init__(self, kernel_size: int = kernel_size, in_channels: int = 1):
        super().__init__()
        channels = (32, 64, 128, 64, 128, 64)
        layers = []
        prev = in_channels
        for i, out_ch in enumerate(channels):
            layers.append(nn.Conv3d(prev, out_ch, kernel_size=kernel_size, padding="same"))
            if i < len(channels) - 1:  # the final conv is left unnormalised/unactivated
                layers += [nn.BatchNorm3d(out_ch), nn.ReLU()]
            prev = out_ch
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

In [15]:
import dsntnn

class CoordRegressionNetwork(nn.Module):
    """FCN1 features -> one heatmap per landmark -> DSNT coordinates.

    :param n_locations: number of landmarks (heatmap channels).
    """

    def __init__(self, n_locations):
        super().__init__()
        self.fcn = FCN1()
        # 1x1x1 projection from FCN1's 64 feature channels to one map per landmark.
        self.hm_conv = nn.Conv3d(64, n_locations, kernel_size=1, bias=False)

    def forward(self, images):
        """Return ``(coords, heatmaps)``; heatmaps are flat-softmax normalised."""
        raw_maps = self.hm_conv(self.fcn(images))
        heatmaps = dsntnn.flat_softmax(raw_maps)
        return dsntnn.dsnt(heatmaps), heatmaps

In [16]:
## Loss Function
def mse_with_res(y_true, y_pred, res):
    """Per-sample squared landmark error, in coordinate units and scaled by ``res``.

    :param y_true: [batch_size, num_landmarks, dimension(column, row, slice)]
    :param y_pred: same shape as ``y_true``
    :param res: per-axis scale from coordinate units to mm, shape
        [batch_size, 1, dimension] (shared by all landmarks) or
        [batch_size, num_landmarks, dimension]; broadcast against the error
    :return: tuple ``(loss, diss)``, each of shape [batch_size]:
        ``loss`` -- sum over landmarks and axes of the squared coordinate error;
        ``diss`` -- the same sum after scaling each error by ``res`` (mm^2).
        Neither is averaged over the batch.
    """
    err_diff = y_true - y_pred
    # Broadcasting over the landmark axis replaces an explicit repeat_interleave
    # and gives identical products.
    err_mm = err_diff * res
    loss = torch.sum(err_diff.pow(2), dim=(1, 2))
    diss = torch.sum(err_mm.pow(2), dim=(1, 2))
    return loss, diss

In [17]:
# Instance used only by the smoke-test cells; training builds a fresh model.
model = CoordRegressionNetwork(n_locations=2)
model = model.to(device)

In [18]:
# Smoke-test one forward pass. `.to(device)` (not `.cuda()`) lets this run on MPS/CPU,
# and no_grad keeps the autograd graph of the large 3-D activations from staying
# alive in the notebook namespace while later cells train.
train_features_var = train_features.to(device)
with torch.no_grad():
    coords, heatmaps = model(train_features_var)

In [19]:
# Labels were stacked as (landmark rows 0-1, resolution row 2) along axis 1.
y, res = train_labels[:, 0:2, :], train_labels[:, 2:3, :]

In [20]:
# Move targets to the selected device rather than hard-coding CUDA.
y_var = y.to(device)
res_var = res.to(device)
mse_with_res(y_var, coords, res_var)

In [21]:
def loss_fn(y_pred, heatmaps, y_true_res, reg_weight=0.0):
    """Loss for the DSNT landmark model plus monitoring terms.

    :param y_pred: predicted coordinates, [batch, num_landmarks, 3]
    :param heatmaps: normalised heatmaps that produced ``y_pred``
    :param y_true_res: targets with the resolution stacked along axis 1 -- rows
        ``[:-1]`` are landmark coordinates, the last row the per-axis mm scale
    :param reg_weight: weight of the mean Jensen-Shannon heatmap regulariser in
        the optimised loss. The default 0.0 optimises the coordinate error only,
        i.e. the regulariser is reported but not back-propagated.
    :return: ``(loss, euc_mean, reg_mean)`` -- the loss to back-propagate, the
        averaged mm^2 squared distance, and the averaged regularisation term
    """
    y_true = y_true_res[:, :-1, :]
    res = y_true_res[:, -1:, :]

    # Per-sample squared error in coordinate units and in mm^2.
    coord_losses, dist_losses = mse_with_res(y_true, y_pred, res)
    # Per-location regularisation losses (always computed so they can be monitored).
    reg_losses = dsntnn.js_reg_losses(heatmaps, y_true, sigma_t=1.0)

    loss = dsntnn.average_loss(coord_losses)
    euc_mean = dsntnn.average_loss(dist_losses)
    reg_mean = dsntnn.average_loss(reg_losses)
    if reg_weight:
        # Combine the averaged terms (not the raw tensors, whose shapes may differ).
        loss = loss + reg_weight * reg_mean
    return loss, euc_mean, reg_mean

In [22]:
## Training
from torch import optim

def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation epoch over ``dataloader``.

    Batches are moved to the module-level ``device``. ``loss_fn`` must return
    ``(loss, dis, reg)``; only ``loss`` is back-propagated. On every batch index
    divisible by 100 (starting at 0) the loss, progress, ``dis`` and ``reg`` are printed.
    """
    n_samples = len(dataloader.dataset)
    model.train()
    for step, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        preds, heat = model(inputs)
        loss, dis, reg = loss_fn(preds, heat, targets)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % 100 != 0:
            continue
        loss_value = loss.item()
        current = (step + 1) * len(inputs)
        print(f"loss: {loss_value:>7f}  [{current:>5d}/{n_samples:>5d}] ({dis:>7f} / {reg:>5f})")

def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without gradients and print averages.

    Batches are moved to the module-level ``device``. The three values from
    ``loss_fn`` (loss, mm^2 distance, regularisation) are averaged over samples,
    weighting each batch by its size.

    :return: ``(avg_loss, avg_dis, avg_reg)``
    :raises ValueError: if ``dataloader`` yields no samples
    """
    model.eval()
    n_seen = 0
    test_loss, dis_loss, reg_loss = 0.0, 0.0, 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred, heat = model(X)
            loss, dis, reg = loss_fn(pred, heat, y)
            # Assumes loss_fn returns per-batch means; weighting by batch size keeps
            # a short final batch from counting as much as a full one.
            batch_n = len(X)
            test_loss += loss.item() * batch_n
            dis_loss += dis.item() * batch_n
            reg_loss += reg.item() * batch_n
            n_seen += batch_n
    if n_seen == 0:
        raise ValueError("test(): dataloader yielded no samples")
    test_loss /= n_seen
    dis_loss /= n_seen
    reg_loss /= n_seen
    print(f"Test Error: \nAvg loss: {test_loss:>8f}  Avg dis: {dis_loss:>8f}, Avg reg: {reg_loss:>8f}\n")
    return test_loss, dis_loss, reg_loss

In [24]:
SEED = 42
torch.manual_seed(SEED)  # reproducible weight init and training-batch shuffling

model = CoordRegressionNetwork(n_locations=2).to(device)

# NOTE(review): opt_rms is built but never used below; kept in case other cells use it.
opt_rms = optim.RMSprop(model.parameters(), lr=0.0001)
opt_ada = optim.Adam(model.parameters(), lr=0.0001)

epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, opt_ada)
    # Monitor on the validation split; the test split stays held out until the end.
    test(val_dataloader, model, loss_fn)

# One final evaluation on the held-out test split.
test(test_dataloader, model, loss_fn)
print("Done!")