## Change detection (invariant to angles)

In [1]:
import torch
print(torch.cuda.is_available()) # should be True
# t = torch.rand(10, 10).cuda()
# print(t.device) # should be CUDA

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
from PIL import Image
import os
import nibabel as nib


True


## Start with saggital midpoint images

In [2]:
def convert_to_image_if_not_exist(directory, output_size=(256, 256)):
    for root, dirs, files in os.walk(directory):
        for filename in files:
            if filename.endswith("nii_mask.nii.gz"):
                # Construct input and output paths
                input_path = os.path.join(root, filename)
                output_path = os.path.join(root, os.path.splitext(filename)[0] + "saggital_view" + ".jpg")

                # Check if output image already exists
                if os.path.exists(output_path):
                    print(f"Image {output_path} already exists, skipping...") 
                    continue
                # Saggital - 0, Coronal - 1, Axial - 2
                # Load NIfTI data
                nifti_data = nib.load(input_path)
                image_data = nifti_data.get_fdata()

                slice_index = image_data.shape[0] // 2
                image_slice = image_data[slice_index, :, :]

                # Normalize intensity values
                min_intensity = np.min(image_slice)
                max_intensity = np.max(image_slice)
                image_slice_normalized = (image_slice - min_intensity) / (max_intensity - min_intensity)

                # Resize the slice to the specified output size
                image_slice_resized = np.array(Image.fromarray((image_slice_normalized * 255).astype(np.uint8)).resize(output_size))

                # Convert to image format
                image = Image.fromarray(image_slice_resized)

                # Resize the image to the specified output size
                image = image.resize(output_size)

                # Save the image
                image.save(output_path)
                print(f"Converted masked image {input_path} to {output_path} with size {output_size}")

In [3]:
dir = "./data/raw/preop/BTC-preop"
output_size = (256, 256)  # Specify the desired output size
convert_to_image_if_not_exist(dir, output_size)

Image ./data/raw/preop/BTC-preop/sub-PAT31/ses-preop/anat/sub-PAT31_ses-preop_T1w.nii_mask.niisaggital_view.jpg already exists, skipping...
Image ./data/raw/preop/BTC-preop/sub-CON09/ses-preop/anat/sub-CON09_ses-preop_T1w.nii_mask.niisaggital_view.jpg already exists, skipping...
Image ./data/raw/preop/BTC-preop/sub-CON03/ses-preop/anat/sub-CON03_ses-preop_T1w.nii_mask.niisaggital_view.jpg already exists, skipping...
Image ./data/raw/preop/BTC-preop/sub-PAT25/ses-preop/anat/sub-PAT25_ses-preop_T1w.nii_mask.niisaggital_view.jpg already exists, skipping...
Image ./data/raw/preop/BTC-preop/sub-PAT14/ses-preop/anat/sub-PAT14_ses-preop_T1w.nii_mask.niisaggital_view.jpg already exists, skipping...
Image ./data/raw/preop/BTC-preop/sub-PAT05/ses-preop/anat/sub-PAT05_ses-preop_T1w.nii_mask.niisaggital_view.jpg already exists, skipping...
Image ./data/raw/preop/BTC-preop/sub-CON02/ses-preop/anat/sub-CON02_ses-preop_T1w.nii_mask.niisaggital_view.jpg already exists, skipping...
Image ./data/raw/pre

In [4]:
dir = "./data/raw/postop/BTC-postop"
output_size = (256, 256)  # Specify the desired output size
convert_to_image_if_not_exist(dir, output_size)


Image ./data/raw/postop/BTC-postop/sub-CON09/ses-postop/anat/sub-CON09_ses-postop_T1w.nii_mask.niisaggital_view.jpg already exists, skipping...
Image ./data/raw/postop/BTC-postop/sub-CON03/ses-postop/anat/sub-CON03_ses-postop_T1w.nii_mask.niisaggital_view.jpg already exists, skipping...
Image ./data/raw/postop/BTC-postop/sub-PAT25/ses-postop/anat/sub-PAT25_ses-postop_T1w.nii_mask.niisaggital_view.jpg already exists, skipping...
Image ./data/raw/postop/BTC-postop/sub-PAT05/ses-postop/anat/sub-PAT05_ses-postop_T1w.nii_mask.niisaggital_view.jpg already exists, skipping...
Image ./data/raw/postop/BTC-postop/sub-CON02/ses-postop/anat/sub-CON02_ses-postop_T1w.nii_mask.niisaggital_view.jpg already exists, skipping...
Image ./data/raw/postop/BTC-postop/sub-PAT26/ses-postop/anat/sub-PAT26_ses-postop_T1w.nii_mask.niisaggital_view.jpg already exists, skipping...
Image ./data/raw/postop/BTC-postop/sub-PAT16/ses-postop/anat/sub-PAT16_ses-postop_T1w.nii_mask.niisaggital_view.jpg already exists, skip

## Datasets and dataloaders

In [5]:
class imageSets(Dataset):
    """
    Image dataset for each subject in the dataset
    creating only 'correct' pairs for now
    TODO: create 'incorrect' pairs

    Works by passing preop or postop directory to the class
    and finds the corresponding image in the other dir and labels
    """
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.data = []
        for root, dirs, files in os.walk(self.root):
            for filename in files:
                if filename.endswith("saggital_view.jpg"):
                    img_1 = Image.open(os.path.join(root, filename))
                    ## finds the corresponding image in the other dir
                    try:
                        if "preop" in root:
                            img_2 = Image.open(os.path.join(root.replace("preop", "postop"), filename.replace("preop", "postop")))
                        else:
                            img_2 = Image.open(os.path.join(root.replace("postop", "preop"), filename.replace("postop", "preop")))
                        self.data.append((img_1, img_2, 1))
                    except FileNotFoundError:
                        print(f"Matching subject (pre and post) not found for {filename}")
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        if self.transform:
            img1_file = self.transform(self.data[idx][0])
            img2_file = self.transform(self.data[idx][1])
        return (img1_file, img2_file, self.data[idx][2])

## Network

In [8]:
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # Define the architecture for the Siamese network
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.dropout = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(131072, 128)  # Adjust input size based on input dimensions

    def forward(self, input1, input2):
        # Forward pass through the Siamese network
        output1 = F.relu(self.bn1(self.conv1(input1)))
        output1 = F.max_pool2d(output1, kernel_size=2, stride=2)
        output1 = F.relu(self.bn2(self.conv2(output1)))
        output1 = F.max_pool2d(output1, kernel_size=2, stride=2)
        output1 = F.relu(self.bn3(self.conv3(output1)))
        output1 = F.max_pool2d(output1, kernel_size=2, stride=2)
        output1 = output1.view(output1.size(0), -1)
        output1 = self.dropout(output1)
        output1 = F.relu(self.fc1(output1))

        output2 = F.relu(self.bn1(self.conv1(input2)))
        output2 = F.max_pool2d(output2, kernel_size=2, stride=2)
        output2 = F.relu(self.bn2(self.conv2(output2)))
        output2 = F.max_pool2d(output2, kernel_size=2, stride=2)
        output2 = F.relu(self.bn3(self.conv3(output2)))
        output2 = F.max_pool2d(output2, kernel_size=2, stride=2)
        output2 = output2.view(output2.size(0), -1)
        output2 = self.dropout(output2)
        output2 = F.relu(self.fc1(output2))

        return output1, output2


In [9]:
class ContrastiveLoss(torch.nn.Module):

    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, input1, input2, y):
        diff = input1 - input2
        dist_sq = torch.sum(torch.pow(diff, 2), 1)
        dist = torch.sqrt(dist_sq)
        mdist = self.margin - dist
        dist = torch.clamp(mdist, min=0.0)
        loss = y * dist_sq + (1 - y) * torch.pow(dist, 2)
        loss = torch.sum(loss) / 2.0 / input1.size()[0]
        return loss

In [38]:
# Define the Siamese network architecture


# Train the Siamese network
def train(siamese_net,  optimizer, criterion, epochs=10):
    default_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    ])
    BATCH_SIZE = 4
    train_data = imageSets("./data/raw/preop/BTC-preop", transform=default_transform)
    loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=False)
    siamese_net.train()
    
    print("starting training...")
    total_loss = 0
    for epoch in range(epochs):
        epoch_loss = 0.0000
        for i, (img1_set, img2_set, label) in enumerate(loader):
            # img1_set = img1_set.cuda()
            # img2_set = img2_set.cuda()
            # label = label.cuda()
            print(label)

            output1, output2 = siamese_net(img1_set, img2_set)
            loss = criterion(output1, output2, label)  # Add unsqueeze to match output shape
            loss.backward()
            optimizer.step()
            print(loss.item())
            epoch_loss += loss.item()
           #test
        print('Epoch [%d/%d], Loss: %.4f' % (epoch+1, epochs, epoch_loss))


In [40]:
# Initialize Siamese network
siamese_net = SiameseNetwork()
# siamese_net = siamese_net.cuda()  # Move the network to GPU

# Define loss function and optimizer
# criterion = nn.BCEWithLogitsLoss()
# optimizer = optim.Adam(siamese_net.parameters(), lr=0.001)

criterion = ContrastiveLoss()
optimizer = optim.Adam(siamese_net.parameters(), lr=0.001)
# Train the Siamese network
train(siamese_net,  optimizer, criterion, epochs=10)

File not found for sub-PAT31_ses-preop_T1w.nii_mask.niisaggital_view.jpg
File not found for sub-PAT14_ses-preop_T1w.nii_mask.niisaggital_view.jpg
File not found for sub-CON01_ses-preop_T1w.nii_mask.niisaggital_view.jpg
File not found for sub-PAT19_ses-preop_T1w.nii_mask.niisaggital_view.jpg
File not found for sub-PAT27_ses-preop_T1w.nii_mask.niisaggital_view.jpg
File not found for sub-PAT29_ses-preop_T1w.nii_mask.niisaggital_view.jpg
File not found for sub-PAT22_ses-preop_T1w.nii_mask.niisaggital_view.jpg
starting training...
tensor([1, 1, 1, 1])
20.54092025756836
tensor([1, 1, 1, 1])
16.242149353027344
tensor([1, 1, 1, 1])
5.703119277954102
tensor([1, 1, 1, 1])
2.217747926712036
tensor([1, 1, 1, 1])
2.8933587074279785
tensor([1, 1, 1, 1])
0.0
tensor([1, 1, 1, 1])
0.0
tensor([1])
0.0
Epoch [1/10], Loss: 47.5973
tensor([1, 1, 1, 1])
0.0
tensor([1, 1, 1, 1])
0.0
tensor([1, 1, 1, 1])
0.0
tensor([1, 1, 1, 1])
0.0
tensor([1, 1, 1, 1])
0.0
tensor([1, 1, 1, 1])
0.0
tensor([1, 1, 1, 1])
0.0
te

KeyboardInterrupt: 

In [6]:
# raw_preop = nib.load("./data/raw/preop/sub-CON02_ses-preop_T1w.nii.gz")
# raw_postop = nib.load("./data/raw/postop/sub-CON02_ses-postop_T1w.nii.gz")

# proc_preop = nib.load("./data/processed/preop/fa.nii.gz")
# proc_postop = nib.load("./data/processed/postop/fa.nii.gz")

# # This apparently returns voxel level of the data
# data_raw_preop= raw_preop.get_fdata()
# data_raw_postop= raw_postop.get_fdata()

# data_proc_preop= proc_preop.get_fdata()
# data_proc_postop= proc_postop.get_fdata()

# print(data_raw_preop.shape)

In [3]:
# def make_pairs():
#     pass
import sys

mock_pairs = [(data_raw_preop, data_raw_postop), (data_raw_preop, data_raw_preop), (data_raw_postop, data_raw_postop)]
mock_labels = [1, 0, 0]

raw_pairs = mock_pairs
raw_labels = mock_labels
# Convert the processed data into PyTorch tensors
raw_pairs_tensor = torch.tensor(raw_pairs, dtype=torch.float32)
raw_labels_tensor = torch.tensor(raw_labels, dtype=torch.float32)
print(sys.getsizeof(raw_pairs_tensor))
# print(raw_pairs_tensor)

# Create DataLoader for training
# raw_dataset = TensorDataset(raw_pairs_tensor, raw_labels_tensor)
# raw_loader = DataLoader(raw_dataset, batch_size=1, shuffle=False)


88


  raw_pairs_tensor = torch.tensor(raw_pairs, dtype=torch.float32)


In [None]:
# class SiameseNetwork(nn.Module):
#     def __init__(self, input_size, hidden_size, output_size):
#         super(SiameseNetwork, self).__init__()
#         self.subnetwork = SubNetwork(input_size, hidden_size, output_size)
    
#     def forward(self, input1, input2):
#         # Pass inputs through the subnetwork
#         output1 = self.subnetwork(input1)
#         output2 = self.subnetwork(input2)
        
#         # Compute the Euclidean distance between the outputs
#         distance = torch.sqrt(torch.sum(torch.pow(output1 - output2, 2), dim=1))
        
#         # Normalize the distance to [0, 1] range
#         distance = torch.sigmoid(distance)
        
#         return distance

In [4]:
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # Define the architecture for the Siamese network
        self.fc1 = nn.Linear(256*256, 128)  # Adjust input size based on input dimensions
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)  # Output size 1 for binary classification

    def forward(self, input1, input2):
        # Flatten the input tensors
        input1 = input1.view(input1.size(0), -1)
        input2 = input2.view(input2.size(0), -1)
        # Forward pass through the Siamese network
        output1 = F.relu(self.fc1(input1))
        output1 = F.relu(self.fc2(output1))
        output2 = F.relu(self.fc1(input2))
        output2 = F.relu(self.fc2(output2))
        return output1, output2

# Initialize Siamese network
siamese_net = SiameseNetwork()
siamese_net = siamese_net.cuda()  # Move the network to GPU

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(siamese_net.parameters(), lr=0.001)