In [22]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import nibabel as nib
import numpy as np
import SimpleITK as sitk

In [23]:
# Dataset roots. Set the BRATS_TRAIN_DIR / BRATS_VALIDATION_DIR environment
# variables to point at another copy of BraTS 2020 instead of editing these
# machine-specific default paths.
TRAIN_DATASET_PATH = os.environ.get(
    'BRATS_TRAIN_DIR',
    '/home/islam/Downloads/archive_001/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData',
)
VALIDATION_DATASET_PATH = os.environ.get(
    'BRATS_VALIDATION_DIR',
    '/home/islam/Downloads/archive_001/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData',
)

In [24]:

def correct_bias_field(input_image, path, affine=None, shrink_factor=4):
    """
    Apply N4 bias-field correction to a 3D MRI volume and save it as NIfTI.

    Args:
        input_image (numpy.ndarray): 3D voxel array, e.g. the result of
            ``nib.load(...).get_fdata()``. It is cast to float32 because N4
            only accepts real-valued pixel types.
        path (str): Destination file (.nii or .nii.gz) for the corrected volume.
        affine (numpy.ndarray, optional): 4x4 voxel-to-world matrix stored in
            the output file. Pass the affine of the source NIfTI so the saved
            volume keeps its original geometry. Defaults to the identity.
        shrink_factor (int, optional): Per-axis down-sampling factor used while
            fitting the bias field. The fitted field is then evaluated at full
            resolution. Default is 4.

    Returns:
        nib.Nifti1Image: The corrected volume. Its data array has the same
        shape and axis order as ``input_image``. The image is also written to
        ``path``.
    """
    # The volume is deliberately NOT reoriented. A reorientation such as
    # sitk.DICOMOrient can permute or flip the voxel array, which would
    # misalign the output with files that are not processed here (e.g. the
    # segmentation labels). N4 itself does not depend on orientation.
    raw_img_sitk = sitk.Cast(sitk.GetImageFromArray(input_image), sitk.sitkFloat32)

    # Binary head mask: rescale to 0..255, then apply Li's automatic threshold.
    rescaled = sitk.RescaleIntensity(raw_img_sitk, 0, 255)
    head_mask = sitk.LiThreshold(rescaled, 0, 1)

    # Fit the bias field on a shrunk copy to keep N4 tractable.
    shrink = [shrink_factor] * raw_img_sitk.GetDimension()
    input_image_shrunk = sitk.Shrink(raw_img_sitk, shrink)
    mask_image_shrunk = sitk.Shrink(head_mask, shrink)

    bias_corrector = sitk.N4BiasFieldCorrectionImageFilter()
    bias_corrector.Execute(input_image_shrunk, mask_image_shrunk)

    # Evaluate the fitted log bias field on the full-resolution grid and divide it out.
    log_bias_field = bias_corrector.GetLogBiasFieldAsImage(raw_img_sitk)
    corrected_full_resolution = raw_img_sitk / sitk.Exp(log_bias_field)

    # GetArrayFromImage inverts GetImageFromArray, so the array is back in
    # input_image's axis order. The affine must therefore come from the
    # caller, not from SimpleITK's (reversed, LPS) geometry.
    corrected_np = sitk.GetArrayFromImage(corrected_full_resolution)

    if affine is None:
        affine = np.eye(4)
    nifti_image = nib.Nifti1Image(corrected_np, affine)
    nib.save(nifti_image, path)
    return nifti_image

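# Example usage (the function takes a voxel array plus an output path, writes the
# corrected volume to that path itself, and returns it as a nibabel Nifti1Image):
# volume = nib.load("path_to_your_mri_image.nii.gz").get_fdata().astype(np.float32)
# corrected_image = correct_bias_field(volume, "corrected_image.nii.gz")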

In [25]:
from gibbs_removal import gibbs_removal

def gibbs_removal_nii(input_image_nii, output_filepath, slice_axis=2, n_points=3):
    """
    Apply Gibbs ringing removal to a NIfTI MRI volume and save the result.

    Parameters
    ----------
    input_image_nii : nibabel image, str or os.PathLike
        Volume to correct. This is either an already-loaded nibabel image
        (anything providing ``get_fdata()``, ``affine`` and ``header``) or a
        path to a .nii/.nii.gz file, which is loaded with ``nib.load``.
    output_filepath : str
        Path to save the corrected .nii file.
    slice_axis : int, optional
        Data axis corresponding to the number of acquired slices. Default is 2.
    n_points : int, optional
        Number of neighbor points used to assess local TV. Default is 3.

    Returns
    -------
    numpy.ndarray
        The corrected voxel data as returned by ``gibbs_removal``. The same
        data is also saved to ``output_filepath``.
    """
    # Accept a file path as well as an image object, as the interface is documented.
    if isinstance(input_image_nii, (str, os.PathLike)):
        input_image_nii = nib.load(input_image_nii)

    vol = input_image_nii.get_fdata()  # voxel data as a float NumPy array

    corrected_vol = gibbs_removal(vol, slice_axis=slice_axis, n_points=n_points)

    # Reuse the source affine and header so the output keeps the input's geometry.
    corrected_nii = nib.Nifti1Image(corrected_vol, input_image_nii.affine, input_image_nii.header)
    nib.save(corrected_nii, output_filepath)
    return corrected_vol


In [26]:
def load_study(study_id, dataset_path=None, split='Training'):
    """
    Bias-correct and Gibbs-deringe every MRI modality of one BraTS 2020 study.

    For each modality (t1, t1ce, t2, flair) the function does three things,
    all inside the study's own folder:

    - reads ``BraTS20_<split>_<id>_<modality>.nii``;
    - writes the N4-corrected volume to ``..._<modality>_bias.nii``;
    - writes the subsequently Gibbs-corrected volume to ``..._<modality>_ring.nii``.

    Args:
        study_id (str): Study number as used in the BraTS file names, e.g. ``'001'``.
        dataset_path (str, optional): Folder containing the per-study folders.
            Defaults to ``TRAIN_DATASET_PATH`` (looked up at call time).
        split (str, optional): Split name embedded in BraTS file names,
            e.g. ``'Training'`` or ``'Validation'``. Default is ``'Training'``.

    Returns:
        list of numpy.ndarray: The final (bias- and Gibbs-corrected) volumes,
        in the order t1, t1ce, t2, flair.
    """
    if dataset_path is None:
        dataset_path = TRAIN_DATASET_PATH
    study_name = f'BraTS20_{split}_{study_id}'
    prefix = os.path.join(dataset_path, study_name, study_name)

    corrected_volumes = []
    for modality in ('t1', 't1ce', 't2', 'flair'):
        # Load one modality at a time so only a single input volume is held in memory.
        volume = nib.load(f'{prefix}_{modality}.nii').get_fdata().astype(np.float32)
        bias_corrected = correct_bias_field(volume, f'{prefix}_{modality}_bias.nii')
        corrected_volumes.append(
            gibbs_removal_nii(bias_corrected, f'{prefix}_{modality}_ring.nii', slice_axis=2, n_points=3)
        )
    return corrected_volumes

In [27]:
# Zero-padded BraTS 2020 training study numbers to preprocess.
# The preprocessing loop below iterates over this list in the order given.
train_ids=['192', '136', '014', '191', '294', '235', '026', '150', '131', '144', '022', '018', '115', '210', '053', '273', '176', '352', '020', '359', '309', 
           '123', '264', '160', '149', '348', '260', '369', '182', '122', '366', '238', '230', '248', '322', '029', '221', '339', '145', '031', '334', '263', 
           '279', '007', '038', '074', '025', '347', '002', '127', '280', '082', '064', '172', '362', '114', '231', '043', '165', '278', '016', '197', '067', 
           '307', '331', '200', '203', '344', '346', '327', '089', '275', '178', '059', '300', '008', '017', '284', '224', '297', '218', '351', '019', '096', 
           '283', '036', '336', '075', '311', '360', '308', '081', '320', '364', '277', '342', '292', '316', '057', '243', '032', '179', '356', '217', '326', 
           '125', '232', '048', '242', '285', '170', '132', '139', '291', '164', '233', '054', '035', '215', '196', '357', '202', '324', '254', '270', '133', 
           '183', '350', '046', '162', '343', '073', '288', '173', '228', '214', '166', '213', '009', '353', '298', '319', '141', '155', '129', '241', '257', 
           '072', '256', '239', '249', '027', '229', '159', '094', '110', '174', '108', '318', '253', '281', '119', '024', '113', '194', '068', '041', '177', 
           '306', '310', '290', '272', '120', '042', '367', '187', '354', '154', '104', '087', '330', '091', '258', '338', '188', '088', '321', '341', '340', 
           '045', '169', '227', '148', '325', '086', '286', '153', '247', '050', '167', '199', '212', '021', '085', '315', '274', '333', '006', '109', '299', 
           '100', '034', '077', '083', '246', '101', '328', '126', '049', '295', '195', '220', '137', '152', '244', '175', '223', '013', '063', '102', '304', 
           '361', '205', '066', '023', '158', '234', '084', '180', '240', '001', '058', '118', '337', '030', '282', '103', '156', '116', '143', '117', '216', '313', '171', '190', '293', '365', '012', '266', '252', '040', '236', '004', '128', '301', '015', '245', '011', '204', '314', '161', 
          '047', '051', '157', '358', '186', '289', '044', '181', '261', '037', '265', '105', '302', '317', '209', '039', '106', '355', '093', '198', '095', 
          '134', '076', '312', '255', '349', '135', '193', '323', '168', '219', '185', '010', '250', '335', '251', '065', '080', '267', '206', '124', '363', '303', '005', '146', 
          '056', '142', '140', '329', '268', '092', '201', '097', '269', '147', '071', '090', '211', '151', '070', '078', '112', '207', '296', '062', '271', 
          '111', '259', '060', '189', '345', '130', '226', '225', '287', '138', '028', '052', '237', '003', '098', '033', '069', '262', '184', '099', '208', 
          '276', '121', '055', '368', '305', '163', '061', '107', '222', '079', '332']

# NOTE(review): test_ids and val_ids are empty placeholders and are not used by the
# visible code; no train/validation/test split is performed in this notebook yet.
test_ids=[]
val_ids= []

In [28]:
# Preprocess every training study. With SKIP_EXISTING enabled, studies whose
# final output file already exists are skipped, so an interrupted multi-hour run
# can be resumed. Enable it only when the existing outputs were produced by the
# current preprocessing code; otherwise stale files would be kept.
SKIP_EXISTING = False

for tid in train_ids:
    # The FLAIR Gibbs-corrected file is the last file load_study writes for a study.
    last_output = os.path.join(
        TRAIN_DATASET_PATH,
        f'BraTS20_Training_{tid}',
        f'BraTS20_Training_{tid}_flair_ring.nii',
    )
    if SKIP_EXISTING and os.path.exists(last_output):
        print(f'{tid} (skipped: already processed)')
        continue
    load_study(tid)
    print(tid)

192
136
014
191
294
235
026
150
131
144
022
018
115
210
053
273
176
352
020
359
309
123
264
160
149
348
260
369
182
122
366
238
230
248
322
029
221
339
145
031
334
263
279
007
038
074
025
347
002
127
280
082
064
172
362
114
231
043
165
278
016
197
067
307
331
200
203
344
346
327
089
275
178
059
300
008
017
284
224
297
218
351
019
096
283
036
336
075
311
360
308
081
320
364
277
342
292
316
057
243
032
179
356
217
326
125
232
048
242
285
170
132
139
291
164
233
054
035
215
196
357
202
324
254
270
133
183
350
046
162
343
073
288
173
228
214
166
213
009
353
298
319
141
155
129
241
257
072
256
239
249
027
229
159
094
110
174
108
318
253
281
119
024
113
194
068
041
177
306
310
290
272
120
042
367
187
354
154
104
087
330
091
258
338
188
088
321
341
340
045
169
227
148
325
086
286
153
247
050
167
199
212
021
085
315
274
333
006
109
299
100
034
077
083
246
101
328
126
049
295
195
220
137
152
244
175
223
013
063
102
304
361
205
066
023
158
234
084
180
240
001
058
118
337
030
282
103
156
116
143
