### Imports

In [17]:
import monai
from monai.data import DataLoader, Dataset
from monai.transforms import LoadImaged, Compose, ScaleIntensityd, Resized, EnsureTyped, ConcatItemsd, ToTensord, CenterSpatialCropd, EnsureChannelFirstd
from monai.metrics import ROCAUCMetric
from monai.engines import SupervisedTrainer
from monai.handlers import from_engine, CheckpointLoader
from sklearn.preprocessing import MinMaxScaler

import pandas as pd

import torch

from ignite.metrics import Accuracy

from model_helpers import Repeatd, prepare_batch, get_additional_metrics

### Data loading

In [2]:
df = pd.read_csv("data/labels_ts2024_imp.tsv", sep="\t")
df = df.assign(pet=lambda df: df['pseudo_id'].map(lambda pseudo_id: "data/cropped_nifti/" + pseudo_id + "_pet.nii.gz"))
df = df.assign(ct=lambda df: df['pseudo_id'].map(lambda pseudo_id: "data/cropped_nifti/" + pseudo_id + "_ct.nii.gz"))

df.head()

Unnamed: 0,pseudo_id,sex,staging,px,psa,label,pseudo_patid,set,unknown,age,pet,ct
0,T_33263,M,re,0,0.35,1,96256,test,False,68,data/cropped_nifti/T_33263_pet.nii.gz,data/cropped_nifti/T_33263_ct.nii.gz
1,T_71212,M,re,1,8.7,1,28134,test,True,74,data/cropped_nifti/T_71212_pet.nii.gz,data/cropped_nifti/T_71212_ct.nii.gz
2,T_82650,M,re,1,0.82,1,75859,test,True,70,data/cropped_nifti/T_82650_pet.nii.gz,data/cropped_nifti/T_82650_ct.nii.gz
3,T_23712,M,re,1,932.0,0,20584,test,False,64,data/cropped_nifti/T_23712_pet.nii.gz,data/cropped_nifti/T_23712_ct.nii.gz
4,T_44829,M,re,0,3.77,1,28035,test,True,81,data/cropped_nifti/T_44829_pet.nii.gz,data/cropped_nifti/T_44829_ct.nii.gz


In [3]:
df.shape

(200, 12)

In [4]:
original_df = pd.read_csv("data/labels.tsv", sep="\t")
scaler = MinMaxScaler()
scaler.fit(original_df[["psa"]])
psa_normalized = scaler.transform(df[["psa"]])
df["psa_norm"] = psa_normalized

### Create sets

In [7]:
test_data = df.to_dict('records')

### Defining the transforms

In [8]:
transforms = Compose(
    [
        LoadImaged(keys=["ct","pet"]),
        EnsureChannelFirstd(keys=["ct","pet"]),
        ScaleIntensityd(keys=["ct","pet"]),
        Resized(keys=["ct","pet"], spatial_size=(70, 70, 70)),
        Repeatd(keys=["psa_norm", "px"], target_size=(1, 65, 46, 69)),
        CenterSpatialCropd(keys=["ct", "pet"], roi_size = (65, 46, 69)),
        EnsureTyped(keys=["ct","pet", "psa_norm", "px"]),  
        ConcatItemsd(keys=["ct", "pet", "psa_norm", "px"], name="petct", dim=0),  
                                              
        ToTensord(keys=["petct", "ct", "pet"]),
    ]
) 

### Create data loaders

In [9]:
batchsize = 16

In [11]:
test_ds = Dataset(data=test_data, transform=transforms)
test_loader = DataLoader(test_ds, batch_size=batchsize, num_workers=1, pin_memory=torch.cuda.is_available())

### Create model

In [12]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=4, out_channels=2).to(device)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-5)
auc_metric = ROCAUCMetric()

### Use SupervisedTrainer

##### Create handlers + Trainer and Evaluator

In [18]:
trainer = SupervisedTrainer(
    device = device,
    max_epochs = 15,
    train_data_loader = test_loader,
    network = model,
    optimizer = optimizer,
    loss_function = loss_function,
    prepare_batch = prepare_batch,
    key_train_metric = {"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
		additional_metrics=get_additional_metrics("test"),
    amp = False
)

### Prediction

In [19]:
handler = CheckpointLoader(f"runs/model_D/checkpoint_epoch=4.pt", load_dict={"net": model, "opt": optimizer})
handler(trainer)

  checkpoint = torch.load(self.load_path, map_location=self.map_location)


In [22]:
model.eval()
for batch in iter(test_loader):
    IDs = batch["pseudo_id"]
    Preds = model(batch["petct"].to(device)).argmax(dim=1)
    for ID, Pred in zip(IDs, Preds):
        df.loc[df.pseudo_id == ID, 'prediction'] = Pred.item()
        print(ID, Pred.item())
model.train();

T_33263 1
T_71212 0
T_82650 0
T_23712 0
T_44829 1
T_89795 0
T_43412 0
T_86419 0
T_28330 0
T_75117 1
T_02513 0
T_41025 0
T_24242 0
T_15014 1
T_86015 0
T_61820 0
T_62629 0
T_90149 0
T_08645 0
T_06614 0
T_86271 1
T_53582 0
T_56253 0
T_88700 1
T_13953 0
T_77949 0
T_88000 0
T_00321 0
T_14075 0
T_06700 0
T_59741 0
T_83904 0
T_73383 0
T_02449 1
T_40131 0
T_00202 0
T_37996 1
T_55935 0
T_69619 0
T_54173 1
T_62099 1
T_73742 0
T_95726 1
T_73200 0
T_05666 1
T_84574 0
T_91579 1
T_86649 0
T_22240 0
T_86404 1
T_07593 1
T_75877 0
T_41516 1
T_17737 0
T_67652 1
T_14316 0
T_37239 0
T_79624 1
T_52440 0
T_90251 0
T_00758 0
T_57485 1
T_46507 1
T_78867 1
T_30295 1
T_51110 1
T_01120 0
T_52986 0
T_04831 1
T_22941 1
T_27435 1
T_94201 0
T_10857 0
T_14757 1
T_48388 1
T_09628 1
T_09646 1
T_56919 0
T_97180 0
T_28375 0
T_67370 1
T_53915 1
T_04557 0
T_98987 0
T_34340 0
T_83804 0
T_33335 1
T_68767 0
T_73122 1
T_04911 0
T_15619 0
T_73503 0
T_20333 1
T_40054 0
T_61428 0
T_72936 0
T_50183 0
T_71687 1
T_37295 1
T_96039 0


In [23]:
df.to_csv(path_or_buf="analysis/testset_predictions.tsv", sep="\t", index=False)

In [36]:
cm = df.value_counts(["label", "prediction"])
cm

label  prediction
0      0.0           79
1      1.0           63
       0.0           48
0      1.0           10
Name: count, dtype: int64

In [47]:
tn = cm[0,0]
fn = cm[0,1]
tp = cm[1,1]
fp = cm[1,0]
tp, fp, fn, tn

(np.int64(63), np.int64(48), np.int64(10), np.int64(79))

In [25]:
from sklearn.metrics import balanced_accuracy_score, accuracy_score

In [26]:
accuracy_score(df.label, df.prediction)

0.71

In [30]:
balanced_accuracy_score(df.label, df.prediction).item()

0.7276040085028849

In [46]:
specificity = tn/(tn+fp)
sensitivity = tp/(tp+fn)
print(f"specificity = {specificity}")
print(f"sensitivity = {sensitivity}")

specificity = 0.6220472440944882
sensitivity = 0.863013698630137


In [32]:
dfu = df[df.unknown]
dfk = df[~df.unknown]

In [33]:
dfu.shape, dfk.shape

((116, 14), (84, 14))

In [34]:
accuracy_score(dfu.label, dfu.prediction)

0.7413793103448276

In [35]:
accuracy_score(dfk.label, dfk.prediction)

0.6666666666666666