# Change detection (invariant to angles)

In [36]:
import torch
print(torch.cuda.is_available()) # should be True
# t = torch.rand(10, 10).cuda()
# print(t.device) # should be CUDA

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
from PIL import Image
import os
import nibabel as nib


True


#
QA check reveals in /data/heatmaps/originals that theyre just really badly cropped
The idea was to either use bonestructure or mvoe into FA/MD and take more slices, i took the last approach

### Util functions

In [37]:
def remove_images_in_sub_dir(directory):

    for root, dirs, files in os.walk(directory):
        for filename in files:
            print(filename)
            if os.path.isfile(os.path.join(root, filename)) and filename.lower().endswith('.jpg'):
                os.remove(os.path.join(root,filename))

In [38]:
#from PIL import Image

def merge_images(image1_dir, image_id, output_path):
    # Open the images
    for root, dirs, files in os.walk(image1_dir):
        for filename in files:
            if filename.endswith(image_id):
                image1 = Image.open(os.path.join(root, filename))
                try:
                    if "preop" in root:
                        image2 = Image.open(os.path.join(root.replace("preop", "postop"), filename.replace("preop", "postop")))
                    else:
                        image2 = Image.open(os.path.join(root.replace("postop", "preop"), filename.replace("postop", "preop")))
                except FileNotFoundError:
                    print(f"Matching subject (pre and post) not found for {filename}")
                # Get the dimensions of the first image
                width1, height1 = image1.size

                # Get the dimensions of the second image
                width2, height2 = image2.size

                # Calculate the total width and height for the merged image
                total_width = width1 + width2
                max_height = max(height1, height2)

                # Create a new image with the calculated dimensions
                merged_image = Image.new('RGB', (total_width, max_height))

                # Paste the first image onto the merged image
                merged_image.paste(image1, (0, 0))

                # Paste the second image onto the merged image next to the first image
                merged_image.paste(image2, (width1, 0))

                # Save the merged image
                merged_image.save(output_path)

### Start with axial, coronal, saggital midpoint images

In [39]:
def convert_to_image_if_not_exist(directory, output_size=(256, 256), second_dir=None, image_id='mask.nii.gz', ext='nii.gz'):
    angle_map = {
        'saggital': 0,
        'coronal': 1,
        'axial': 2
    }
    for root, dirs, files in os.walk(directory):
        for filename in files:
            if filename.endswith(image_id):
                # Construct input and output paths
                input_path = os.path.join(root, filename)
                nifti_data = nib.load(input_path)
                image_data = nifti_data.get_fdata()
                for key, value in angle_map.items():

                    output_path = os.path.join(root, filename.replace(ext, "") + f"{key}" + ".jpg")
                    # Check if output image already exists
                    if os.path.exists(output_path):
                        print(f"Image {output_path} already exists, skipping...") 
                        continue
                    slicing_tuple = [slice(None)] * len(image_data.shape)
                    slice_index = image_data.shape[value] // 2
                    slicing_tuple[value] = slice_index

                    # Apply the slicing to image_data
                    image_slice = image_data[tuple(slicing_tuple)]

                    # Normalize intensity values
                    min_intensity = np.min(image_slice)
                    max_intensity = np.max(image_slice)
                    image_slice_normalized = (image_slice - min_intensity) / (max_intensity - min_intensity)

                    # Resize the slice to the specified output size
                    image_slice_resized = np.array(Image.fromarray((image_slice_normalized * 255).astype(np.uint8)).resize(output_size))

                    # Convert to image format
                    image = Image.fromarray(image_slice_resized)

                    # Resize the image to the specified output size
                    image = image.resize(output_size)
                    image.save(output_path)
                    if second_dir is not None:
                        output_dir = os.path.join(second_dir, filename.split('_')[0])
                        os.makedirs(output_dir, exist_ok=True)
                # Save the image
                        image.save(os.path.join(f'{second_dir}/{filename.split("_")[0]}', os.path.splitext(filename)[0] + f"{key}" + ".jpg"))
                    print(f"Converted image {input_path} to {output_path} with size {output_size}")

### Converting raw preop

In [42]:
dir = "./data/raw/preop/BTC-preop"
output_size = (256, 256)  # Specify the desired output size
#convert the masked images
convert_to_image_if_not_exist(dir, output_size, image_id='mask.nii.gz', ext='nii.gz')
#convert the original images
convert_to_image_if_not_exist(dir, output_size, image_id='T1w.nii.gz', ext='nii.gz')

Image ./data/raw/preop/BTC-preop/sub-PAT31/ses-preop/anat/sub-PAT31_ses-preop_T1w.nii_mask.saggital.jpg already exists, skipping...
Image ./data/raw/preop/BTC-preop/sub-PAT31/ses-preop/anat/sub-PAT31_ses-preop_T1w.nii_mask.coronal.jpg already exists, skipping...
Image ./data/raw/preop/BTC-preop/sub-PAT31/ses-preop/anat/sub-PAT31_ses-preop_T1w.nii_mask.axial.jpg already exists, skipping...
Image ./data/raw/preop/BTC-preop/sub-CON09/ses-preop/anat/sub-CON09_ses-preop_T1w.nii_mask.saggital.jpg already exists, skipping...
Image ./data/raw/preop/BTC-preop/sub-CON09/ses-preop/anat/sub-CON09_ses-preop_T1w.nii_mask.coronal.jpg already exists, skipping...
Image ./data/raw/preop/BTC-preop/sub-CON09/ses-preop/anat/sub-CON09_ses-preop_T1w.nii_mask.axial.jpg already exists, skipping...
Image ./data/raw/preop/BTC-preop/sub-CON03/ses-preop/anat/sub-CON03_ses-preop_T1w.nii_mask.saggital.jpg already exists, skipping...
Image ./data/raw/preop/BTC-preop/sub-CON03/ses-preop/anat/sub-CON03_ses-preop_T1w.ni

### Converting raw postop

In [41]:
dir = "./data/raw/postop/BTC-postop"
output_size = (256, 256)  # Specify the desired output size
#convert the masked images
convert_to_image_if_not_exist(dir, output_size, image_id='mask.nii.gz', ext='nii.gz')
#convert the original images
convert_to_image_if_not_exist(dir, output_size, image_id='T1w.nii.gz', ext='nii.gz')

Image ./data/raw/postop/BTC-postop/sub-CON09/ses-postop/anat/sub-CON09_ses-postop_T1w.nii_mask.saggital.jpg already exists, skipping...
Image ./data/raw/postop/BTC-postop/sub-CON09/ses-postop/anat/sub-CON09_ses-postop_T1w.nii_mask.coronal.jpg already exists, skipping...
Image ./data/raw/postop/BTC-postop/sub-CON09/ses-postop/anat/sub-CON09_ses-postop_T1w.nii_mask.axial.jpg already exists, skipping...
Image ./data/raw/postop/BTC-postop/sub-CON03/ses-postop/anat/sub-CON03_ses-postop_T1w.nii_mask.saggital.jpg already exists, skipping...
Image ./data/raw/postop/BTC-postop/sub-CON03/ses-postop/anat/sub-CON03_ses-postop_T1w.nii_mask.coronal.jpg already exists, skipping...
Image ./data/raw/postop/BTC-postop/sub-CON03/ses-postop/anat/sub-CON03_ses-postop_T1w.nii_mask.axial.jpg already exists, skipping...
Image ./data/raw/postop/BTC-postop/sub-PAT25/ses-postop/anat/sub-PAT25_ses-postop_T1w.nii_mask.saggital.jpg already exists, skipping...
Image ./data/raw/postop/BTC-postop/sub-PAT25/ses-postop/

### Converting processed preop

In [44]:
dir = "./data/processed/preop/BTC-preop"
output_size = (256, 256)  # Specify the desired output size
#convert the FA images
convert_to_image_if_not_exist(dir, output_size, image_id='fa.nii.gz', ext='nii.gz')
#convert the MD images
convert_to_image_if_not_exist(dir, output_size, image_id='md.nii.gz', ext='nii.gz')

Converted image ./data/processed/preop/BTC-preop/sub-PAT31/dti/fa.nii.gz to ./data/processed/preop/BTC-preop/sub-PAT31/dti/fa.saggital.jpg with size (256, 256)
Converted image ./data/processed/preop/BTC-preop/sub-PAT31/dti/fa.nii.gz to ./data/processed/preop/BTC-preop/sub-PAT31/dti/fa.coronal.jpg with size (256, 256)
Converted image ./data/processed/preop/BTC-preop/sub-PAT31/dti/fa.nii.gz to ./data/processed/preop/BTC-preop/sub-PAT31/dti/fa.axial.jpg with size (256, 256)
Converted image ./data/processed/preop/BTC-preop/sub-CON09/dti/fa.nii.gz to ./data/processed/preop/BTC-preop/sub-CON09/dti/fa.saggital.jpg with size (256, 256)
Converted image ./data/processed/preop/BTC-preop/sub-CON09/dti/fa.nii.gz to ./data/processed/preop/BTC-preop/sub-CON09/dti/fa.coronal.jpg with size (256, 256)
Converted image ./data/processed/preop/BTC-preop/sub-CON09/dti/fa.nii.gz to ./data/processed/preop/BTC-preop/sub-CON09/dti/fa.axial.jpg with size (256, 256)
Converted image ./data/processed/preop/BTC-preop

### Converting processed postop

In [45]:
dir = "./data/processed/postop/BTC-postop"
output_size = (256, 256)  # Specify the desired output size
#convert the FA images
convert_to_image_if_not_exist(dir, output_size, image_id='fa.nii.gz', ext='nii.gz')
#convert the MD images
convert_to_image_if_not_exist(dir, output_size, image_id='md.nii.gz', ext='nii.gz')

Converted image ./data/processed/postop/BTC-postop/sub-CON09/dti/fa.nii.gz to ./data/processed/postop/BTC-postop/sub-CON09/dti/fa.saggital.jpg with size (256, 256)
Converted image ./data/processed/postop/BTC-postop/sub-CON09/dti/fa.nii.gz to ./data/processed/postop/BTC-postop/sub-CON09/dti/fa.coronal.jpg with size (256, 256)
Converted image ./data/processed/postop/BTC-postop/sub-CON09/dti/fa.nii.gz to ./data/processed/postop/BTC-postop/sub-CON09/dti/fa.axial.jpg with size (256, 256)
Converted image ./data/processed/postop/BTC-postop/sub-CON03/dti/fa.nii.gz to ./data/processed/postop/BTC-postop/sub-CON03/dti/fa.saggital.jpg with size (256, 256)
Converted image ./data/processed/postop/BTC-postop/sub-CON03/dti/fa.nii.gz to ./data/processed/postop/BTC-postop/sub-CON03/dti/fa.coronal.jpg with size (256, 256)
Converted image ./data/processed/postop/BTC-postop/sub-CON03/dti/fa.nii.gz to ./data/processed/postop/BTC-postop/sub-CON03/dti/fa.axial.jpg with size (256, 256)
Converted image ./data/p

### Datasets and dataloaders

In [70]:
class imageSets(Dataset):
    """
    Image dataset for each subject in the dataset
    creating only 'correct' and 'incorrect' pairs for now

    Works by passing preop or postop directory to the class
    and finds the corresponding image in the other dir and labels
    """
    def __init__(self, root, image_ids, transform=None):
        self.root = root
        self.transform = transform
        self.data = []
        self.image_ids = image_ids
        for root, dirs, files in os.walk(self.root):
            for filename in files:
                for image_id in self.image_ids:
                    if filename.endswith(image_id):
                        img_1 = Image.open(os.path.join(root, filename))
                        ## finds the corresponding image in the other dir
                        try:
                            if "preop" in root:
                                img_2 = Image.open(os.path.join(root.replace("preop", "postop"), filename.replace("preop", "postop")))
                            else:
                                img_2 = Image.open(os.path.join(root.replace("postop", "preop"), filename.replace("postop", "preop")))
                            if "-CON" in filename or "-CON" in os.path.join(root, filename):
                                # print("control for ", filename)
                                self.data.append((img_1, img_2, 1, os.path.join(root, filename))) # Similar
                            elif "-PAT" in filename or "-PAT" in os.path.join(root, filename):
                                self.data.append((img_1, img_2, 0, os.path.join(root, filename))) # Dissimalar
                            else:
                                print(f"Invalid filename: {os.path.join(root, filename)}")
                        except FileNotFoundError:
                            print(f"Matching subject (pre and post) not found for {os.path.join(root, filename)}")
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        if self.transform:
            img1_file = self.transform(self.data[idx][0])
            img2_file = self.transform(self.data[idx][1])
        return (img1_file, img2_file, self.data[idx][2], self.data[idx][3])

In [90]:
# import torchvision.transforms.functional as TF
# import matplotlib.pyplot as plt
# default_transform = transforms.Compose([
# transforms.Resize(256),
# transforms.ToTensor(),
# ])
# BATCH_SIZE = 1
# train_data = imageSets("./data/processed/preop/BTC-preop", 
#                        image_ids=['fa.coronal.jpg', 'md.coronal.jpg'], 
#                        transform=default_transform)
# loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=False)
# print(len(loader))
# for i in loader:
#     print(i[0].shape)
#     print(i[1].shape)
#     print(i[3])
#     image = TF.to_pil_image(i[0].squeeze(0))
#     # Display the image
#     plt.imshow(image, cmap='gray')  # Specify grayscale colormap
#     plt.axis('off')
#     plt.show()
#     # Display the image
#     image.show()
#     break


### Network

In [47]:
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # Define the architecture for the Siamese network
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.dropout = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(131072, 128)  # Adjust input size based on input dimensions

    def forward(self, input1, input2):
        # Forward pass through the Siamese network
        output1 = F.relu(self.bn1(self.conv1(input1)))
        output1 = F.max_pool2d(output1, kernel_size=2, stride=2)
        output1 = F.relu(self.bn2(self.conv2(output1)))
        output1 = F.max_pool2d(output1, kernel_size=2, stride=2)
        output1 = F.relu(self.bn3(self.conv3(output1)))
        output1 = F.max_pool2d(output1, kernel_size=2, stride=2)

        output2 = F.relu(self.bn1(self.conv1(input2)))
        output2 = F.max_pool2d(output2, kernel_size=2, stride=2)
        output2 = F.relu(self.bn2(self.conv2(output2)))
        output2 = F.max_pool2d(output2, kernel_size=2, stride=2)
        output2 = F.relu(self.bn3(self.conv3(output2)))
        output2 = F.max_pool2d(output2, kernel_size=2, stride=2)

        return output1, output2



## Loss Functions

In [48]:
class ConstractiveLoss(nn.Module):

    def __init__(self,margin =2.0,dist_flag='l2'):
        super(ConstractiveLoss, self).__init__()
        self.margin = margin
        self.dist_flag = dist_flag

    def various_distance(self,out_vec_t0,out_vec_t1):

        if self.dist_flag == 'l2': # Euclidean distance
            distance = F.pairwise_distance(out_vec_t0,out_vec_t1,p=2)
        if self.dist_flag == 'l1': # Manhattan distance
            distance = F.pairwise_distance(out_vec_t0,out_vec_t1,p=1)
        if self.dist_flag == 'cos':# Cosine similarity
            similarity = F.cosine_similarity(out_vec_t0,out_vec_t1)
            distance = 1 - 2 * similarity/np.pi
        return distance

    def forward(self,out_vec_t0,out_vec_t1,label):

        #distance = F.pairwise_distance(out_vec_t0,out_vec_t1,p=2)
        distance = self.various_distance(out_vec_t0,out_vec_t1)
        #distance = 1 - F.cosine_similarity(out_vec_t0,out_vec_t1)
        constractive_loss = torch.sum((1-label)*torch.pow(distance,2 ) + \
                                       label * torch.pow(torch.clamp(self.margin - distance, min=0.0),2))
        return constractive_loss
    
class ConstractiveThresholdHingeLoss(nn.Module):

    def __init__(self,hingethresh=0.0,margin=2.0):
        super(ConstractiveThresholdHingeLoss, self).__init__()
        self.threshold = hingethresh
        self.margin = margin

    def forward(self,out_vec_t0,out_vec_t1,label):

        distance = F.pairwise_distance(out_vec_t0,out_vec_t1,p=2)
        similar_pair = torch.clamp(distance - self.threshold,min=0.0)
        dissimilar_pair = torch.clamp(self.margin- distance,min=0.0)
        #dissimilar_pair = torch.clamp(self.margin-(distance-self.threshold),min=0.0)
        constractive_thresh_loss = torch.sum(
            (1-label)* torch.pow(similar_pair,2) + label * torch.pow(dissimilar_pair,2)
        )
        return constractive_thresh_loss

## Visualization

In [49]:
import cv2
def check_dir(dir):
    if not os.path.exists(dir):
        os.mkdir(dir)
        
def various_distance(out_vec_t0, out_vec_t1,dist_flag):
    if dist_flag == 'l2':
        distance = F.pairwise_distance(out_vec_t0, out_vec_t1, p=2)
    if dist_flag == 'l1':
        distance = F.pairwise_distance(out_vec_t0, out_vec_t1, p=1)
    if dist_flag == 'cos':
        distance = 1 - F.cosine_similarity(out_vec_t0, out_vec_t1)
    return distance

def single_layer_similar_heatmap_visual(output_t0,output_t1,save_change_map_dir,filename,dist_flag):

    interp = nn.Upsample(size=[512,512], mode='bilinear')
    n, c, h, w = output_t0.data.shape
    out_t0_rz = torch.transpose(output_t0.view(c, h * w), 1, 0)
    out_t1_rz = torch.transpose(output_t1.view(c, h * w), 1, 0)
    distance = various_distance(out_t0_rz,out_t1_rz,dist_flag=dist_flag)
    similar_distance_map = distance.view(h,w).data.cpu().numpy()
    similar_distance_map_rz = interp(torch.from_numpy(similar_distance_map[np.newaxis, np.newaxis, :]))
    similar_dis_map_colorize = cv2.applyColorMap(np.uint8(255 * similar_distance_map_rz.data.cpu().numpy()[0][0]), cv2.COLORMAP_JET)
    check_dir(save_change_map_dir)
    # save_change_map_dir_layer = os.path.join(save_change_map_dir,layer_flag)
    # check_dir(save_change_map_dir_layer)
    save_weight_fig_dir = os.path.join(save_change_map_dir, filename + '.jpg')
    cv2.imwrite(save_weight_fig_dir, similar_dis_map_colorize)
    return similar_distance_map_rz.data.cpu().numpy()


## Training

In [94]:
import matplotlib.pyplot as plt

def train(siamese_net,  optimizer, criterion, epochs=100, patience=3, save_dir='models', data_dir='./data/raw/preop/BTC-preop', image_ids=['nii_mask.saggital.jpg'], model_name='masked.pth'):
    default_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    ])
    BATCH_SIZE = 1
    train_data = imageSets(data_dir, image_ids=image_ids, transform=default_transform)
    loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=False)
    print(f"Number of samples: {len(loader)}")
    siamese_net.train()
    
    print("\nStarting training...")
    total_loss = 0
    best_loss = float('inf')
    consecutive_no_improvement = 0
    for epoch in range(epochs):
        epoch_loss = 0.0000
        for i, (img1_set, img2_set, label, filename) in enumerate(loader):
            # img1_set = img1_set.cuda()
            # img2_set = img2_set.cuda()
            # label = label.cuda()'
            optimizer.zero_grad()
            output1, output2 = siamese_net(img1_set, img2_set)
            loss = criterion(output1, output2, label)  # Add unsqueeze to match output shape
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print('Epoch [%d/%d], Loss: %.4f' % (epoch+1, epochs, epoch_loss))
                # Check for early stopping
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            consecutive_no_improvement = 0
            # Save the best model
            save_path = os.path.join(save_dir, model_name)
            torch.save(siamese_net.state_dict(), save_path)
            print(f'Saved best model to {save_path}')
        else:
            consecutive_no_improvement += 1
            if consecutive_no_improvement >= patience:
                print(f'Early stopping at epoch {epoch+1} as no improvement for {patience} consecutive epochs.')
                break


#### Masked model with 3 axis

In [97]:

# Initialize Siamese network
siamese_net = SiameseNetwork()
# siamese_net = siamese_net.cuda()  # Move the network to GPU
save_dir = './models'
if os.path.exists(os.path.join(save_dir, 'masked.pth')):
    siamese_net.load_state_dict(torch.load(os.path.join(save_dir, 'masked.pth')))
    print('Loaded the best model')
else:
    criterion = ConstractiveLoss(margin=0.0)
    optimizer = optim.Adam(siamese_net.parameters(), lr=0.001)
    # Train the Siamese network
    train(siamese_net,  optimizer, criterion, epochs=70, patience=5, save_dir='./models', data_dir='./data/raw/preop/BTC-preop', 
          image_ids=['nii_mask.saggital.jpg', 'nii_mask.coronal.jpg', 'nii_mask.axial.jpg'], model_name='masked.pth')

Matching subject (pre and post) not found for ./data/raw/preop/BTC-preop/sub-PAT31/ses-preop/anat/sub-PAT31_ses-preop_T1w.nii_mask.saggital.jpg
Matching subject (pre and post) not found for ./data/raw/preop/BTC-preop/sub-PAT31/ses-preop/anat/sub-PAT31_ses-preop_T1w.nii_mask.axial.jpg
Matching subject (pre and post) not found for ./data/raw/preop/BTC-preop/sub-PAT31/ses-preop/anat/sub-PAT31_ses-preop_T1w.nii_mask.coronal.jpg
Matching subject (pre and post) not found for ./data/raw/preop/BTC-preop/sub-PAT14/ses-preop/anat/sub-PAT14_ses-preop_T1w.nii_mask.saggital.jpg
Matching subject (pre and post) not found for ./data/raw/preop/BTC-preop/sub-PAT14/ses-preop/anat/sub-PAT14_ses-preop_T1w.nii_mask.coronal.jpg
Matching subject (pre and post) not found for ./data/raw/preop/BTC-preop/sub-PAT14/ses-preop/anat/sub-PAT14_ses-preop_T1w.nii_mask.axial.jpg
Matching subject (pre and post) not found for ./data/raw/preop/BTC-preop/sub-CON01/ses-preop/anat/sub-CON01_ses-preop_T1w.nii_mask.saggital.jpg


####  Raw unmasked model 3 axis

In [98]:
# Initialize Siamese network
siamese_net_unmasked = SiameseNetwork()
# siamese_net = siamese_net.cuda()  # Move the network to GPU
save_dir = './models'
if os.path.exists(os.path.join(save_dir, 'unmasked.pth')):
    siamese_net_unmasked.load_state_dict(torch.load(os.path.join(save_dir, 'unmasked.pth')))
    print('Loaded the best model')
else:
    criterion = ConstractiveLoss(margin=0.0)
    optimizer = optim.Adam(siamese_net_unmasked.parameters(), lr=0.001)
    # Train the Siamese network
    train(siamese_net_unmasked,  optimizer, criterion, epochs=200, patience=5, save_dir='./models', data_dir='./data/raw/preop/BTC-preop',
          image_ids=['T1w.saggital.jpg', 'T1w.axial.jpg', 'T1w.coronal.jpg'], model_name='unmasked.pth')

Matching subject (pre and post) not found for ./data/raw/preop/BTC-preop/sub-PAT31/ses-preop/anat/sub-PAT31_ses-preop_T1w.axial.jpg
Matching subject (pre and post) not found for ./data/raw/preop/BTC-preop/sub-PAT31/ses-preop/anat/sub-PAT31_ses-preop_T1w.saggital.jpg
Matching subject (pre and post) not found for ./data/raw/preop/BTC-preop/sub-PAT31/ses-preop/anat/sub-PAT31_ses-preop_T1w.coronal.jpg
Matching subject (pre and post) not found for ./data/raw/preop/BTC-preop/sub-PAT14/ses-preop/anat/sub-PAT14_ses-preop_T1w.saggital.jpg
Matching subject (pre and post) not found for ./data/raw/preop/BTC-preop/sub-PAT14/ses-preop/anat/sub-PAT14_ses-preop_T1w.axial.jpg
Matching subject (pre and post) not found for ./data/raw/preop/BTC-preop/sub-PAT14/ses-preop/anat/sub-PAT14_ses-preop_T1w.coronal.jpg
Matching subject (pre and post) not found for ./data/raw/preop/BTC-preop/sub-CON01/ses-preop/anat/sub-CON01_ses-preop_T1w.coronal.jpg
Matching subject (pre and post) not found for ./data/raw/preop/B

##

#### Processed FA model 3 axis

In [99]:
# Initialize Siamese network
siamese_net_fa = SiameseNetwork()
# siamese_net = siamese_net.cuda()  # Move the network to GPU
save_dir = './models'
if os.path.exists(os.path.join(save_dir, 'fa.pth')):
    siamese_net_fa.load_state_dict(torch.load(os.path.join(save_dir, 'fa.pth')))
    print('Loaded the best model')
else:
    criterion = ConstractiveLoss(margin=0.0)
    optimizer = optim.Adam(siamese_net_fa.parameters(), lr=0.001)
    # Train the Siamese network
    train(siamese_net_fa,  optimizer, criterion, epochs=200, patience=5, save_dir='./models', data_dir='./data/processed/preop/BTC-preop',
          image_ids=['fa.saggital.jpg', 'fa.axial.jpg', 'fa.coronal.jpg'], model_name='fa.pth')

Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-PAT31/dti/fa.axial.jpg
Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-PAT31/dti/fa.saggital.jpg
Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-PAT31/dti/fa.coronal.jpg
Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-PAT14/dti/fa.axial.jpg
Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-PAT14/dti/fa.saggital.jpg
Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-PAT14/dti/fa.coronal.jpg
Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-CON01/dti/fa.axial.jpg
Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-CON01/dti/fa.saggital.jpg
Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-CON01/dti/fa.coronal.jpg
Matching subject (pre and post)

#### Processed MD model 3 axis

In [100]:
# Initialize Siamese network
siamese_net_md = SiameseNetwork()
# siamese_net = siamese_net.cuda()  # Move the network to GPU
save_dir = './models'
if os.path.exists(os.path.join(save_dir, 'md.pth')):
    siamese_net_md.load_state_dict(torch.load(os.path.join(save_dir, 'md.pth')))
    print('Loaded the best model')
else:
    criterion = ConstractiveLoss(margin=0.0)
    optimizer = optim.Adam(siamese_net_md.parameters(), lr=0.001)
    # Train the Siamese network
    train(siamese_net_md,  optimizer, criterion, epochs=200, patience=5, save_dir='./models', data_dir='./data/processed/preop/BTC-preop',
          image_ids=['md.saggital.jpg', 'md.axial.jpg', 'md.coronal.jpg'], model_name='md.pth')

Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-PAT31/dti/md.coronal.jpg
Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-PAT31/dti/md.saggital.jpg
Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-PAT31/dti/md.axial.jpg
Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-PAT14/dti/md.coronal.jpg
Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-PAT14/dti/md.saggital.jpg
Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-PAT14/dti/md.axial.jpg
Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-CON01/dti/md.coronal.jpg
Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-CON01/dti/md.saggital.jpg
Matching subject (pre and post) not found for ./data/processed/preop/BTC-preop/sub-CON01/dti/md.axial.jpg
Matching subject (pre and post)

## Prediction

### Load Models

In [101]:
# Initialize Siamese network
siamese_net_masked = SiameseNetwork()
siamese_net_unmasked = SiameseNetwork()
siamese_net_fa = SiameseNetwork()
siamese_net_md = SiameseNetwork()

# load weights
siamese_net_masked.load_state_dict(torch.load('./models/masked.pth'))
siamese_net_unmasked.load_state_dict(torch.load('./models/unmasked.pth'))
siamese_net_fa.load_state_dict(torch.load('./models/fa.pth'))
siamese_net_md.load_state_dict(torch.load('./models/md.pth'))

<All keys matched successfully>

In [12]:
# Prediction function
## TODO: val and test loader
## TODO: 
def predict(siamese_net, image1, image2, threshold=0.3):
    siamese_net.eval()  # Set the model to evaluation mode
    with torch.no_grad():
        output1, output2 = siamese_net(image1, image2)
        distance1 = various_distance(output1, output2, 'l2')  # Compute the distance euclidean
        distance = torch.dist(output1, output2, p=2)
        print(f"Distance: {distance}")
        #similarity_score = 1 - distance.item()  # Convert distance to similarity score
        prediction = distance > threshold  # Determine if the pair is similar based on the threshold
    return prediction, output1, output2, distance.item()


### Unmasked raw data

In [13]:
# Prediction
# label 1 is similar, 0 is dissimilar
BATCH_SIZE = 1
default_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    ])
test_data = imageSets("./data/raw/preop/BTC-preop", image_id='T1w.niisaggital_view.jpg', transform=default_transform)
loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)
for i, (img1_set, img2_set, label, filename) in enumerate(loader):
    orig_filename = os.path.splitext(filename[0])[0]
    patient_id = orig_filename.split("_")[0]
    is_similar, output1, output2, distance = predict(siamese_net_unmasked, img1_set, img2_set, 0.014)
    single_layer_similar_heatmap_visual(output1,output2,f"./data/heatmaps/originals/{patient_id}", f'{patient_id}_predic_{label}','l2')
    # Printing the prediction result
    if is_similar:
        print("The pair is similar with a distance of:", distance, " label:", label)
    else:
        print("The pair is dissimilar with a distance of:", distance, " label:", label)

Matching subject (pre and post) not found for sub-PAT31_ses-preop_T1w.niisaggital_view.jpg
Matching subject (pre and post) not found for sub-PAT14_ses-preop_T1w.niisaggital_view.jpg
Matching subject (pre and post) not found for sub-CON01_ses-preop_T1w.niisaggital_view.jpg
Matching subject (pre and post) not found for sub-PAT19_ses-preop_T1w.niisaggital_view.jpg
Matching subject (pre and post) not found for sub-PAT27_ses-preop_T1w.niisaggital_view.jpg
Matching subject (pre and post) not found for sub-PAT29_ses-preop_T1w.niisaggital_view.jpg
Matching subject (pre and post) not found for sub-PAT22_ses-preop_T1w.niisaggital_view.jpg
The pair is dissimilar with a distance of: 0.0004124616098124534  label: tensor([1])
The pair is similar with a distance of: 0.020212870091199875  label: tensor([1])
The pair is dissimilar with a distance of: 0.0  label: tensor([0])
The pair is dissimilar with a distance of: 0.0  label: tensor([0])
The pair is similar with a distance of: 0.014344928786158562  l

### Masked data

In [29]:
# Prediction
# label 1 is similar, 0 is dissimilar
BATCH_SIZE = 1
default_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    ])
train_data = imageSets("./data/raw/preop/BTC-preop", image_id='nii_mask.niisaggital_view.jpg', transform=default_transform)
loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=False)
for i, (img1_set, img2_set, label, filename) in enumerate(loader):
    orig_filename = os.path.splitext(filename[0])[0]
    patient_id = orig_filename.split("_")[0]
    is_similar, output1, output2, distance = predict(siamese_net_masked, img1_set, img2_set, threshold=0.24)
    single_layer_similar_heatmap_visual(output1,output2,f"./data/heatmaps/masked/{patient_id}", f'{patient_id}_predic_{label}','l2')
    # Printing the prediction result
    if is_similar:
        print("The pair is similar with a distance of:", distance, " label:", label)
    else:
        print("The pair is dissimilar with a distance of:", distance, " label:", label)

Matching subject (pre and post) not found for sub-PAT31_ses-preop_T1w.nii_mask.niisaggital_view.jpg
Matching subject (pre and post) not found for sub-PAT14_ses-preop_T1w.nii_mask.niisaggital_view.jpg
Matching subject (pre and post) not found for sub-CON01_ses-preop_T1w.nii_mask.niisaggital_view.jpg
Matching subject (pre and post) not found for sub-PAT19_ses-preop_T1w.nii_mask.niisaggital_view.jpg
Matching subject (pre and post) not found for sub-PAT27_ses-preop_T1w.nii_mask.niisaggital_view.jpg
Matching subject (pre and post) not found for sub-PAT29_ses-preop_T1w.nii_mask.niisaggital_view.jpg
Matching subject (pre and post) not found for sub-PAT22_ses-preop_T1w.nii_mask.niisaggital_view.jpg
The pair is similar with a distance of: 0.5360137820243835  label: tensor([1])
The pair is similar with a distance of: 0.5467800498008728  label: tensor([1])
The pair is dissimilar with a distance of: 0.01990346610546112  label: tensor([0])
The pair is dissimilar with a distance of: 0.06033733114600

## Visualization of latent space (Dim reduc? -> probably doesnt preserve numerical values..)

## Working with raw data

In [6]:
# raw_preop = nib.load("./data/raw/preop/sub-CON02_ses-preop_T1w.nii.gz")
# raw_postop = nib.load("./data/raw/postop/sub-CON02_ses-postop_T1w.nii.gz")

# proc_preop = nib.load("./data/processed/preop/fa.nii.gz")
# proc_postop = nib.load("./data/processed/postop/fa.nii.gz")

# # This apparently returns voxel level of the data
# data_raw_preop= raw_preop.get_fdata()
# data_raw_postop= raw_postop.get_fdata()

# data_proc_preop= proc_preop.get_fdata()
# data_proc_postop= proc_postop.get_fdata()

# print(data_raw_preop.shape)

In [3]:
# def make_pairs():
#     pass
import sys

mock_pairs = [(data_raw_preop, data_raw_postop), (data_raw_preop, data_raw_preop), (data_raw_postop, data_raw_postop)]
mock_labels = [1, 0, 0]

raw_pairs = mock_pairs
raw_labels = mock_labels
# Convert the processed data into PyTorch tensors
raw_pairs_tensor = torch.tensor(raw_pairs, dtype=torch.float32)
raw_labels_tensor = torch.tensor(raw_labels, dtype=torch.float32)
print(sys.getsizeof(raw_pairs_tensor))
# print(raw_pairs_tensor)

# Create DataLoader for training
# raw_dataset = TensorDataset(raw_pairs_tensor, raw_labels_tensor)
# raw_loader = DataLoader(raw_dataset, batch_size=1, shuffle=False)


88


  raw_pairs_tensor = torch.tensor(raw_pairs, dtype=torch.float32)


In [None]:
# class SiameseNetwork(nn.Module):
#     def __init__(self, input_size, hidden_size, output_size):
#         super(SiameseNetwork, self).__init__()
#         self.subnetwork = SubNetwork(input_size, hidden_size, output_size)
    
#     def forward(self, input1, input2):
#         # Pass inputs through the subnetwork
#         output1 = self.subnetwork(input1)
#         output2 = self.subnetwork(input2)
        
#         # Compute the Euclidean distance between the outputs
#         distance = torch.sqrt(torch.sum(torch.pow(output1 - output2, 2), dim=1))
        
#         # Normalize the distance to [0, 1] range
#         distance = torch.sigmoid(distance)
        
#         return distance

In [4]:
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # Define the architecture for the Siamese network
        self.fc1 = nn.Linear(256*256, 128)  # Adjust input size based on input dimensions
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)  # Output size 1 for binary classification

    def forward(self, input1, input2):
        # Flatten the input tensors
        input1 = input1.view(input1.size(0), -1)
        input2 = input2.view(input2.size(0), -1)
        # Forward pass through the Siamese network
        output1 = F.relu(self.fc1(input1))
        output1 = F.relu(self.fc2(output1))
        output2 = F.relu(self.fc1(input2))
        output2 = F.relu(self.fc2(output2))
        return output1, output2

# Initialize Siamese network
siamese_net = SiameseNetwork()
siamese_net = siamese_net.cuda()  # Move the network to GPU

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(siamese_net.parameters(), lr=0.001)