In [1]:
# Train a tiny convolutional neural network to rate the spline interpolations 
import torch
import torch.nn as nn
import torch.optim as optim

import sys

WORKING_DIR = "/home/daniel/Documents/Uni/practical-sose23/castellvi/3D-Castellvi-Prediction/"

sys.path.append(WORKING_DIR + "src/")

from dataset.Splines import Splines, ConvexHullDataset
from utils._prepare_data import DataHandler

dataset = [WORKING_DIR  + 'data/dataset-verse19',  WORKING_DIR + 'data/dataset-verse20', WORKING_DIR + 'data/dataset-tri']
data_types = ['rawdata',"derivatives"]
image_types = ["ct"]
master_list = WORKING_DIR + 'src/dataset/Castellvi_list_v3.xlsx'
processor = DataHandler(master_list=master_list ,dataset=dataset, data_types=data_types, image_types=image_types)

dataset = Splines(processor=processor)

# Split dataset into train and test
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42))

####################################
  from .autonotebook import tqdm as notebook_tqdm
  File "/home/daniel/anaconda3/envs/dev-castellvi/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/daniel/anaconda3/envs/dev-castellvi/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/daniel/.local/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/home/daniel/.local/lib/python3.10/site-packages/traitlets/config/application.py", line 1043, in launch_instance
    app.start()
  File "/home/daniel/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 725, in start
    self.io_loop.start()
  File "/home/daniel/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/home/daniel/anaconda3/envs/dev-castellvi/lib/python3.10/asyncio/base_events.py", line

[!] subreg is not in list of legal keys. This name 'sub-verse015_seg-subreg_subreg-castcorr_msk.nii.gz' is invalid. Legal keys are: ['sub', 'ses', 'sequ', 'acq', 'task', 'chunk', 'hemi', 'sample', 'ce', 'trc', 'stain', 'rec', 'proc', 'mod', 'recording', 'res', 'dir', 'echo', 'flip', 'inv', 'mt', 'part', 'space', 'seg', 'source', 'snapshot', 'ovl', 'run', 'label', 'split', 'den', 'desc', 'ct']. 
For use see https://bids-specification.readthedocs.io/en/stable/99-appendices/09-entities.html
[!] Unknown format seg-ano in file sub-verse602_dir-iso_seg-ano.nii.gz
[!] Unknown format iso-ctd in file sub-verse616_dir-iso_iso-ctd.json
[!] "verse549" is not a valid key/value pair. Expected "KEY-VALUE" in verse549_CT-iso_seg-ano.nii.gz
[!] "template" is not a valid key/value pair. Expected "KEY-VALUE" in sub-verse519_template_sacrum_msk.nii.gz
[!] "sacrum" is not a valid key/value pair. Expected "KEY-VALUE" in sub-verse519_template_sacrum_msk.nii.gz
[!] cortex is not in list of legal keys. This na

In [2]:
# Define a tiny convolutional neural network to rate the spline interpolations
class SplineRatingNet(nn.Module):
    def __init__(self):
        super(SplineRatingNet, self).__init__()
        # Input shape: (Batch_Size, 128, 3)
        self.conv1 = nn.Conv1d(3, 16, 3, padding=1)
        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv1d(16, 32, 3, padding=1)
        self.act2 = nn.ReLU()

        self.flat = nn.Flatten()

        self.fc1 = nn.Linear(32 * 128, 128)
        self.act3 = nn.ReLU()

        self.fc2 = nn.Linear(128, 3)

    def forward(self, x):
        # Switch shape to (Batch_Size, 3, 128)
        x = x.permute(0, 2, 1)
        x = self.conv1(x)
        x = self.act1(x)

        x = self.conv2(x)
        x = self.act2(x)

        x = self.flat(x)

        x = self.fc1(x)
        x = self.act3(x)

        x = self.fc2(x)

        return x

In [3]:
# Get random batch from train dataset
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
data, label = next(iter(train_loader))

data.shape, label.shape

(torch.Size([4, 128, 3]), torch.Size([4]))

In [4]:
# Train the network with the Adam optimizer and cross entropy loss
net = SplineRatingNet()
optimizer = optim.Adam(net.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Define a function to train the network
def train(net, optimizer, criterion, train_loader, test_loader, epochs=10):
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data

            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if i % 10 == 9:
                print(f'Epoch {epoch + 1}, batch {i + 1}: loss {running_loss / 10}')
                running_loss = 0.0

        correct = 0
        total = 0
        with torch.no_grad():
            for data in test_loader:
                inputs, labels = data
                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f'Epoch {epoch + 1}: accuracy {100 * correct / total}')

# Train the network
train(net, optimizer, criterion, torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True), torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=True), epochs=10)

Epoch 1, batch 10: loss 3.6678982734680177
Epoch 1, batch 20: loss 1.0038303315639496
Epoch 1: accuracy 69.23076923076923
Epoch 2, batch 10: loss 0.6232566058635711
Epoch 2, batch 20: loss 0.5422274380922317
Epoch 2: accuracy 80.21978021978022
Epoch 3, batch 10: loss 0.6121211528778077
Epoch 3, batch 20: loss 0.5299663960933685
Epoch 3: accuracy 79.67032967032966
Epoch 4, batch 10: loss 0.5451123654842377
Epoch 4, batch 20: loss 0.5193828165531158
Epoch 4: accuracy 80.76923076923077
Epoch 5, batch 10: loss 0.5898045897483826
Epoch 5, batch 20: loss 0.45612713545560835
Epoch 5: accuracy 81.86813186813187
Epoch 6, batch 10: loss 0.5320951730012894
Epoch 6, batch 20: loss 0.5148334175348281
Epoch 6: accuracy 80.21978021978022
Epoch 7, batch 10: loss 0.4740293949842453
Epoch 7, batch 20: loss 0.6287113785743713
Epoch 7: accuracy 79.67032967032966
Epoch 8, batch 10: loss 0.5690117746591568
Epoch 8, batch 20: loss 0.5069302946329117
Epoch 8: accuracy 79.12087912087912
Epoch 9, batch 10: loss

In [5]:
# Calculate confusion matrix, F1 score for test dataset
from sklearn.metrics import confusion_matrix, f1_score, matthews_corrcoef, cohen_kappa_score, classification_report
import numpy as np

# Get predictions for test dataset
predictions = []
with torch.no_grad():
    for data in torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=True):
        inputs, labels = data
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        predictions.append(predicted)

# Calculate confusion matrix
y_pred = np.concatenate(predictions)
y_true = np.array([label for _, label in test_dataset])

# Print metrics and report
print("Confusion Matrix: \n", confusion_matrix(y_true, y_pred))
print("F1 Score: ", f1_score(y_true, y_pred, average='macro'))
print("MCC: ", matthews_corrcoef(y_true, y_pred))
print("Cohens Kappa: ", cohen_kappa_score(y_true, y_pred))
print("Classification Report: \n", classification_report(y_true, y_pred))


Confusion Matrix: 
 [[141   0   3]
 [ 20   0   2]
 [ 16   0   0]]
F1 Score:  0.29283489096573206
MCC:  0.02070215421574309
Cohens Kappa:  0.012440444679724716
Classification Report: 
               precision    recall  f1-score   support

           0       0.80      0.98      0.88       144
           1       0.00      0.00      0.00        22
           2       0.00      0.00      0.00        16

    accuracy                           0.77       182
   macro avg       0.27      0.33      0.29       182
weighted avg       0.63      0.77      0.70       182



####################################
  _warn_prf(average, modifier, msg_start, len(result))
  File "/home/daniel/anaconda3/envs/dev-castellvi/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/daniel/anaconda3/envs/dev-castellvi/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/daniel/.local/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/home/daniel/.local/lib/python3.10/site-packages/traitlets/config/application.py", line 1043, in launch_instance
    app.start()
  File "/home/daniel/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 725, in start
    self.io_loop.start()
  File "/home/daniel/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/home/daniel/anaconda3/envs/dev-castellvi/lib/python3.10/asyncio/base_events.py",

That is not convincing at all. Let's see if the convex hull is more helpful.

In [6]:
dataset = ConvexHullDataset(processor=processor)

# Split dataset into train and test
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42))


In [14]:
train_dataset[0][0]["last_L"].vertices

array([    0,     2,     3,    13,    14,    29,    45,    57,    61,
          62,    73,    74,    93,   115,   116,   206,   232,   259,
         285,   389,   393,   394,   420,   450,   510,   688,   862,
        1140,  1141,  1148,  1246,  1805,  1820,  1821,  1830,  1831,
        1863,  1864,  1902,  2647,  2662,  3521,  3577,  3578,  3604,
        3622,  4595,  4626,  4727,  5805,  6117,  6977,  7013,  7044,
        7062,  7220,  8196,  8237,  8575, 10970, 11117, 11278, 11439,
       12588, 14119, 14222, 14278, 15700, 15708, 15944, 16226, 17347,
       17365, 17369, 17495, 17561, 17633, 17920, 18060, 19047, 19265,
       19559, 20734, 20780, 21571, 21640, 22491, 22563, 22637, 23154,
       24135, 24196, 24268, 24343, 24860, 25804, 26025, 26100, 26322,
       26396, 27438, 29197, 30691, 32296, 33867, 36926, 39889, 41334,
       42801, 42802, 42811, 42822, 44249, 44251, 44252, 44275, 45669,
       45701, 45759, 47031, 47177, 48413, 55013, 55041, 55046, 55047,
       55049, 55703,