In [1]:
import cv2
from segmentation_models_pytorch import FPN
import torch
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pandas as pd
from sklearn.utils import shuffle
from pathlib import PosixPath,PurePosixPath,Path,PurePath
import batch_norm_methods as bn_adapt
from torch.utils.data import DataLoader,Dataset
from torch.nn.parallel import DataParallel
import gc
import matplotlib.pyplot as plt
from tqdm import tqdm
import nrrd  
import numpy as np
import pydicom
import os
import SimpleITK as sitk


In [None]:
# root_dir=Path('../data/UCSF_ComeBACK/xnat.radiology.ucsf.edu/San_Francisco_cohort/')  # SF CMBK-0101
root_dir=Path('../data/UCSF_ComeBACK/xnat.radiology.ucsf.edu/Other_sites') # Other site CMBK-0102, CMBK-0103, CMBK-0104
# root_dir=Path('../data/UCSF_ComeBACK/COMEBACK_Controls/data/mri/')  # SF controls
bsize=24
alpha=0.2 # Lower alpha, more weight to Target Dataset statistics
device='cpu'

### For the patient cohorts:

In [None]:
images = [str(fname) for fname in list(PurePath.joinpath(Path(root_dir)).glob("**/*.dcm"))]

slice_ids = [Path(x).stem for x in images]  # .stem gives the filename without extension
patient_ids = [x.split('/')[8] for x in images]
df = pd.DataFrame(data=list(zip(patient_ids, images, slice_ids)), columns=['patient_id', 'image', 'slice_id'])
test_df = shuffle(df)
test_df.head(3)

metadata = pd.read_csv('../data/UCSF_ComeBACK/xnat.radiology.ucsf.edu/dicom_metadata_mid_third_slices.csv')
metadata = metadata[metadata['PathToFolder'].str.contains('CMBK-0104')]
metadata['FileName'] = metadata['FileName'].str.replace('.dcm', '')

test_df = test_df[test_df['image'].str.contains('|'.join(metadata['FileName']))]

  metadata = pd.read_csv('/Users/tmcsween21/Documents/data/UCSF_ComeBACK/xnat.radiology.ucsf.edu/dicom_metadata_mid_third_slices.csv')


### For the Control cohort:

In [4]:
# # Get full file paths instead of just filenames
# images = [str(fname) for fname in list(PurePath.joinpath(Path(root_dir)).glob("**/*.DCM"))]

# # subset images for entris that contain the string _T2w_
# images = [x for x in images if '_T2w_' in x]

# # make this into a dataframe
# control_files = pd.DataFrame(images, columns=['image'])

# # split by / and put the last part in a new column called lsice id
# control_files['slice_id'] = control_files['image'].apply(lambda x: x.split('/')[-1])

# # split the slice id by _ and take the last part as the slice number and convert to int
# control_files['slice_num'] = control_files['slice_id'].apply(lambda x: int(x.split('_')[-1].split('.')[0]))

# # plit image by / and take the 8th part as the patient id
# control_files['patient_id'] = control_files['image'].apply(lambda x: x.split('/')[9])

# # group by patient id and order the slice numbers into a list
# control_files_grouped = control_files.groupby('patient_id')['slice_num'].apply(list).reset_index()

# # for each row divide the length of the slice number by 3 and round down to the nearest whole number
# control_files_grouped['num_slices'] = control_files_grouped['slice_num'].apply(lambda x: np.floor(len(x)/3))

# # reorder the numbers in each list to be in sequential order
# control_files_grouped['slice_num'] = control_files_grouped['slice_num'].apply(lambda x: sorted(x))

# # remove the first x number of slices from the list where x is the number of slices divided by 3
# control_files_grouped['middle_third'] = control_files_grouped.apply(lambda x: x['slice_num'][int(x['num_slices']):], axis=1)

# # remove the last x number of slices from the list where x is the number of slices divided by 3
# control_files_grouped['middle_third'] = control_files_grouped.apply(lambda x: x['middle_third'][:-int(x['num_slices'])], axis=1)

# # drop the slice_num and num_slices columns
# control_files_grouped.drop(['slice_num', 'num_slices'], axis=1, inplace=True)

# # create a new row for each slice in the middle third list
# control_files_grouped = control_files_grouped.explode('middle_third')

# # merge the control_files_grouped dataframe with the control_files dataframe on the middle_third column and the patient_id column
# control_files = pd.merge(control_files, control_files_grouped, how='inner', left_on=['patient_id', 'slice_num'], right_on=['patient_id', 'middle_third'])

In [5]:
# test_df = control_files[['patient_id', 'image', 'slice_id']]
# # remove .DCM from the end of the slice_id column
# test_df['slice_id'] = test_df['slice_id'].str.replace('.DCM', '')
# test_df.head(3)


In [6]:
# take a sample of 150 images from test_df
# test_df = test_df.sample(150, random_state=42).reset_index(drop=True)

In [7]:
class SpineSeg(nn.Module):
    def __init__(self):
        super(SpineSeg, self).__init__()
        self.model=FPN(encoder_name='resnet34',in_channels=1,classes=14)
    def forward(self,x):
        return self.model(x)

In [8]:
class DatasetLoader(Dataset):
    def __init__(self,dataset,data_slice, aug_type='test'):
        #Assign dataset
        self.dataset=dataset
        # Data slice of overall dataset
        self.data_slice = data_slice
        # Size of the Data Slice
        self.dataset_size = data_slice.shape[0]
        # Apply transformations based on train or val set
        self.aug_type = aug_type


    def __len__(self):
        return self.dataset_size

    def __getitem__(self, index):
        # Get the filename of the image
        slice_id = self.data_slice.iloc[index]['slice_id']
        patient_id = self.data_slice.iloc[index]['patient_id']
        img_file = self.data_slice.iloc[index]['slice_id'] + '.dcm' 
    
        image_path = self.data_slice.iloc[index]['image']
        # Load DICOM image
        image_raw = self.load_dicom(image_path)
        
        # Apply transformations
        transform_raw = self.apply_img_transforms(image_raw)
        return {"img_file": img_file, "transformed_raw": transform_raw, "slice_id": slice_id, "patient_id": patient_id, "image_path": image_path}

    def load_dicom(self, dicom_path):
        """ Load a DICOM file, rescale it to [0-255], and return as uint8. """
        dicom = pydicom.dcmread(str(PurePosixPath(dicom_path)))

        # Convert pixel data to numpy array
        int16_image = dicom.pixel_array.astype(np.float32)

        # Find min/max values
        min_val = np.min(int16_image)
        max_val = np.max(int16_image)

        # Rescale to 0-255
        if min_val != max_val:
            scale = 255.0 / (max_val - min_val)
            uint8_image = ((int16_image - min_val) * scale).round().astype(np.uint8)
        else:
            # Handle case where all pixel values are the same
            uint8_image = np.ones_like(int16_image, dtype=np.uint8) * min_val

        return uint8_image
    
    # Apply Image Transformations using Albumentations
    # When you normalize the class labels of masks change.
    def apply_img_transforms(self, img_raw):
        if self.aug_type=="test":
            test_transform = A.Compose(
                 [A.Normalize(mean=0.181, std=0.184, always_apply=True, p=1.0),
                 A.Resize(512, 512),
                 ToTensorV2()])
            transformed = test_transform(image=img_raw)
            transformed_img = transformed['image']
            # Tensors are to be converted to float tensors as cv2 is giving byte tensor
            transformed_img = transformed_img.type(torch.FloatTensor)
            return transformed_img

In [9]:
def lumbar_mask(f):
    lumbar_tensor=f
    lumbar_tensor[lumbar_tensor==12] = 0
    lumbar_tensor[lumbar_tensor==13] = 0

    return lumbar_tensor 

### Generate the data loader

In [10]:
dataset={"root_dir":root_dir,"img_dir":"images"}
test_ds = DatasetLoader(dataset,test_df, 'test')
test_loader = DataLoader(dataset=test_ds, batch_size=bsize,num_workers=0)

### Load the original trained models

In [None]:
# load the original trained models:
checkpoint_path="../data/segmentation_model_weights_14032024"  

checkpoints = list(Path(checkpoint_path).glob("SpineSeg*"))
models=[]
for checkpoint in checkpoints:
    model_checkpoint = checkpoint
    model= SpineSeg()
    state = torch.load(model_checkpoint, map_location=torch.device('cpu') )['model_state']
    model.load_state_dict(state)
    bn_adapt.adapt_weightedBN(model,alpha=alpha)
    model.to(device)
    if device == 'cuda':
        model = DataParallel(model)
    model.train()
    models.append(model)

| Found 36 modules to be replaced.
| Found 36 modules to be replaced.
| Found 36 modules to be replaced.
| Found 36 modules to be replaced.
| Found 36 modules to be replaced.


### Run the domain adaptation

In [22]:
print(len(test_loader))
#BN Adapt
for i,test_data in tqdm(enumerate(test_loader)):
    print(i)

    if i == 29:
        break
    test_images=test_data['transformed_raw'].to(device)
    for model in models:
        _ = model(test_images)
print('Domain Adaptation Complete')

35


0it [00:00, ?it/s]

0


1it [00:52, 52.65s/it]

1


2it [01:44, 51.99s/it]

2


3it [02:40, 53.89s/it]

3


4it [03:37, 55.08s/it]

4


5it [04:35, 56.18s/it]

5


6it [05:26, 54.43s/it]

6


7it [06:18, 53.60s/it]

7


8it [07:07, 52.24s/it]

8


9it [07:56, 51.19s/it]

9


10it [08:44, 50.09s/it]

10


11it [09:32, 49.71s/it]

11


12it [10:20, 49.06s/it]

12


13it [11:13, 50.35s/it]

13


14it [12:03, 50.06s/it]

14


15it [12:54, 50.27s/it]

15


16it [13:41, 49.38s/it]

16


17it [14:26, 48.26s/it]

17


18it [15:15, 48.28s/it]

18


19it [16:01, 47.66s/it]

19


20it [16:48, 47.55s/it]

20


21it [17:34, 47.08s/it]

21


22it [18:19, 46.43s/it]

22


23it [19:11, 47.92s/it]

23


24it [20:02, 49.09s/it]

24


25it [20:55, 50.26s/it]

25


26it [21:43, 49.44s/it]

26


27it [22:28, 48.19s/it]

27


28it [23:15, 47.84s/it]

28


29it [23:59, 49.63s/it]

29
Domain Adaptation Complete





In [None]:
for i, test_data in tqdm(enumerate(test_loader)):
    if i == 30:
        continue
    try:
        dicom_path = test_data['img_file']  # Modify based on dataset structure
        print(dicom_path)
        test_images = test_data['transformed_raw'].to(device)

        for model in models:
            _ = model(test_images)

    except ValueError as e:
        print(f"Skipping corrupted file: {dicom_path} | Error: {e}")
        continue  # Skip to the next item

0it [00:00, ?it/s]

['9999.92684734282789915878910181250787153122-5-13-7ulnc1.dcm', '9999.31699787059162127861485996968769613189-5-14-19rvrpa.dcm', '9999.148422095646504608214593584451361252229-5-12-4fukxd.dcm', '9999.53615604897189716474771048794532318298-5-9-1d6zsqq.dcm', '9999.200329822020012919450143079036443272282-5-15-1a3i0yy.dcm', '9999.158883181974054731955445435652959704831-4-12-1q3k1tf.dcm', '9999.109740486636312245581545813160398581463-5-16-dlaxd4.dcm', '9999.199228736767843594664666847966965742768-5-17-92h4yp.dcm', 'Anon-6510483331405995986.dcm', '9999.275724190541001283042949349443038646921-5-16-sql1gg.dcm', '9999.148422095646504608214593584451361252229-5-16-21kqs9.dcm', '9999.217169010983060717519489818559071473823-5-16-2qzyen.dcm', '9999.112717985666400007954654845309037232460-5-19-1axaroi.dcm', '9999.200329822020012919450143079036443272282-5-17-56lwow.dcm', '9999.145046372870071345392296811313495476221-5-9-10l9mgk.dcm', '9999.148561699873986463543854905843721193269-5-17-1sb25i3.dcm', '9999

0it [00:23, ?it/s]


KeyboardInterrupt: 

### Save the domain adapted model weights

In [None]:
folder = '../data/segmentation_model_weights_domain_adapted_UCSF_CMBK-0104'

# save the models
for i in range(len(models)):
    torch.save({'model_state': model.state_dict()},  f'{folder}/SpineSeg_adapted_{i}.pth')

### Load the domain adapted model weights

In [None]:
checkpoint_path=folder
checkpoints = list(Path(checkpoint_path).glob("SpineSeg*"))

models = []
device = torch.device("cpu")

for checkpoint in checkpoints:
    model = SpineSeg()  
    bn_adapt.adapt_weightedBN(model, alpha=alpha)  
    state = torch.load(checkpoint, map_location=device)['model_state']
    model.load_state_dict(state) 
    
    model.to(device)
    model.eval() 
    models.append(model)

| Found 36 modules to be replaced.
| Found 36 modules to be replaced.
| Found 36 modules to be replaced.
| Found 36 modules to be replaced.
| Found 36 modules to be replaced.


In [25]:
def convert_dicom_to_nrrd(dicom_path, nrrd_path):
    # Load the DICOM file
    dicom_data = pydicom.dcmread(dicom_path)

    # Get pixel array and convert it to SimpleITK format
    pixel_array = dicom_data.pixel_array.astype(np.float32)  # Ensure correct type
    sitk_image = sitk.GetImageFromArray(pixel_array)

    # Save the image as an NRRD file
    sitk.WriteImage(sitk_image, nrrd_path)

In [None]:
save_dir = "../data/UCSF_ComeBACK/mask_evaluation/CMBK-0104_domain_adapted_alpha_02"

batch_count = 0  # Counter to track number of batches processed
saved_count = 0  # Counter to track saved images

for batch in test_loader:
    
    if batch_count <=32:
        batch_count += 1
        continue


    print(batch_count)

    images = batch['transformed_raw'].to(device)
    batch_patient_ids = batch['patient_id']  # Extract patient IDs from the batch
    batch_slice_ids = batch['slice_id']  # Extract slice IDs from the batch
    batch_dicom_paths = batch['image_path']  # Extract DICOM paths from the batch

    # Compute predictions
    predictions = 0
    with torch.no_grad():
        for model in models:
            model.eval()
            predictions += model(images)

    outputs = predictions / len(models)
    outputs = outputs.argmax(axis=1)
    outputs = lumbar_mask(outputs)
    outputs = outputs.detach().cpu().numpy()

    batch_count += 1

    # Loop through images in batch
    for i in range(len(outputs)):
        # if saved_count >= 40:  # Stop once 40 images are saved
        #     break

        # Get correct identifiers from the batch
        patient_id = batch_patient_ids[i]
        slice_id = batch_slice_ids[i]
        dicom_image = batch_dicom_paths[i]

        if patient_id == 'CMBK-0104-00102':
            continue
        if patient_id == 'CMBK-0104-00003':
            continue
        if patient_id == 'CMBK-0104-00027':
            continue

        mask_save_path = os.path.join(save_dir, f"{patient_id}_{slice_id}_segmentation.nrrd")
        print(saved_count, patient_id, slice_id)
        print('##################')

        # Save the mask NRRD file
        corrected_mask = outputs[i].T  # Flip vertically & transpose
        nrrd.write(mask_save_path, corrected_mask.astype(np.uint8))

        # Save the DICOM as NRRD
        nrrd_save_path = os.path.join(save_dir, f"{patient_id}_{slice_id}.nrrd")
        convert_dicom_to_nrrd(dicom_image, nrrd_save_path)

        # Save visualization
        plt.figure(figsize=(8, 8))
        plt.axis('off')
        plt.subplot(1, 2, 1)
        plt.imshow(images[i].squeeze().cpu().numpy(), cmap='gray')
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(outputs[i], cmap='Paired')

        png_save_path = os.path.join(save_dir, f"{patient_id}_{slice_id}_eval.png")
        plt.savefig(png_save_path)
        plt.close()

        saved_count += 1

    # if saved_count >= 40:  # Stop iterating over batches if we reach 40 images
    #     break

# Clean up memory
gc.collect()
torch.cuda.empty_cache()



ValueError: The length of the pixel data in the dataset (414368 bytes) doesn't match the expected length (524288 bytes). The dataset may be corrupted or there may be an issue with the pixel data handler.