In [1]:
import sys
sys.path.append('C:\\Users\\DELL\\Graduation Project\\Codes and Scripts\\type-identification-mri-sequences\\')

import os
import random
import cv2
import time
import glob
import pandas as pd
import nibabel as nib
import numpy as np
import imageio
from scipy import ndimage
import shutil

import torch
import torch.utils.data as data
import tensorflow as tf

from models import select_net
from time_util import time_format
from MedicalDataset import MedicalDataset

from HelperFunctions import preprocess_image_train, generate_images_GIF, predict_image, black_seq_generator
from SynthModels import ResNet, unet_model, Discriminator, old_squeeze, SqueezeAttention, gans_SqueezeAttention

 The versions of TensorFlow you are currently using is 2.6.0 and is not supported. 
Some things might work, some things might not.
If you were to encounter a bug, do not file an issue.
If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. 
You can find the compatibility matrix in TensorFlow Addon's readme:
https://github.com/tensorflow/addons


In [2]:
def fix_random_seeds(seed: int = 1) -> None:
    """Seed every RNG used by the pipeline so repeated runs give identical results.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU and every visible CUDA device), and configures cuDNN for
    deterministic kernel selection.

    Args:
        seed: Value used for every generator. Defaults to 1, the value the
            pipeline has always used.
    """
    # `deterministic` alone is not enough: with the cuDNN autotuner enabled a
    # different (possibly non-deterministic) algorithm can be chosen per run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored when CUDA is unavailable; covers all GPUs, not only the current one.
    torch.cuda.manual_seed_all(seed)

def csv_maker(path, delete_segmentation=True):
    """Build the inference table of MRI volumes found in one case folder.

    Segmentation label maps (``*seg.nii.gz``) are not MRI sequences, so they are
    always excluded from the table. By default (the historical behavior) they
    are also deleted from disk. Entries that are not NIfTI files are ignored,
    for example an ``images`` folder left behind by an interrupted run of
    ``extract_images``, so they are never handed to the classifier.

    Args:
        path: Folder containing the case's ``.nii`` / ``.nii.gz`` volumes.
        delete_segmentation: If True (default), permanently remove
            ``*seg.nii.gz`` files from ``path``.

    Returns:
        Tuple ``(test_path, test_data_list, num_rows, test_data)``:
        - ``test_path``: the input path.
        - ``test_data_list``: the sorted list of volume paths.
        - ``num_rows``: its length.
        - ``test_data``: a DataFrame with a ``path`` column and an all-NaN
          ``label`` column.
    """
    test_path = path
    test_data_list = []
    for entry in sorted(glob.glob(os.path.join(test_path, '*'))):
        if not os.path.isfile(entry):
            continue
        if entry.endswith('seg.nii.gz'):
            # Labels are not an input sequence; optionally purge them from disk.
            if delete_segmentation:
                os.remove(entry)
            continue
        if entry.endswith(('.nii', '.nii.gz')):
            test_data_list.append(entry)

    num_rows = len(test_data_list)
    test_data = pd.DataFrame({'path': test_data_list, 'label': np.nan})
    print(test_data)
    return test_path, test_data_list, num_rows, test_data
    
def seq_identification(path, model_file):
    """Classify every volume in a case folder into an MRI sequence type.

    Builds the input table with ``csv_maker`` and loads the resnet18
    sequence classifier (4 slices per volume) from ``model_file``. It then
    predicts one of FLAIR / T1 / T1c / T2 for each volume and reports whether
    all four sequences are present.

    Args:
        path: Case folder containing the NIfTI volumes.
        model_file: State-dict file. Joined onto ``models/``; an absolute path
            is used as-is because ``os.path.join`` discards ``models`` then.

    Returns:
        ``(classes, predicted_classes, test_path, test_data_list, num_rows,
        test_data, paths)``, where ``predicted_classes[k]`` is the prediction
        for the volume at ``paths[k]``.
    """
    test_path, test_data_list, num_rows, test_data = csv_maker(path)
    n_slices = 4
    tridim = False
    # Previously `not '--no-other'` (a leftover CLI flag). A non-empty string is
    # truthy, so that expression was always False; state the value explicitly.
    consider_other_class = False
    architecture = 'resnet18'
    assert architecture in ['resnet18', 'alexnet', 'vgg', 'squeezenet', 'mobilenet']

    fix_random_seeds()
    test_set = MedicalDataset(test_data, min_slices=n_slices,
                              consider_other_class=consider_other_class, test=True)
    test_loader = data.DataLoader(test_set, num_workers=8, pin_memory=True)
    n_test_files = len(test_set)
    classes = ['FLAIR', 'T1', 'T1c', 'T2']

    # Inputs must live on the same device as the weights; previously the net
    # was moved to CUDA while batches stayed on the CPU, crashing on GPU machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = select_net(architecture, n_slices, tridim, consider_other_class).to(device)
    net.load_state_dict(torch.load(os.path.join('models', model_file), map_location=device))
    net.eval()

    predicted_classes = []
    paths = []
    with torch.no_grad():
        for i, (pixel_data, _label, batch_paths) in enumerate(test_loader):
            print('Tested', i + 1, 'of', n_test_files, 'files.')
            if tridim:
                pixel_data = pixel_data.view(-1, 1, 10, 200, 200)

            outputs = net(pixel_data.to(device))
            _, predicted = torch.max(outputs, 1)
            # DataLoader uses batch_size=1, so each batch holds a single volume.
            predicted_classes.append(classes[predicted[0].item()])
            paths.append(batch_paths[0])

    if set(predicted_classes) == set(classes):
        print("ALL GOOD.\nOUT FOR SEGMENTATION\n")
    else:
        print("NOPE MISSING MODES.\nOUT FOR DATA CONVERSION\n")
        inter = list(set(predicted_classes).intersection(set(classes)))
        print("AVAILABLE MODES: {}".format(inter))

    print("\nPrediction:  ", predicted_classes)
    print("Paths: ", paths)
    return classes, predicted_classes, test_path, test_data_list, num_rows, test_data, paths

def extract_images(test_path, predicted_classes, paths):
    """Export every slice along the third axis of each classified volume as PNG.

    Slices of volume ``k`` (``paths[k]``, predicted as ``predicted_classes[k]``)
    are written to ``<test_path>/images/imgs <class>/imgs <k>/<slice:03d>.png``.
    The inner ``imgs <k>`` level exists because
    ``image_dataset_from_directory`` expects one sub-folder per class inside
    the folder it is given.

    Args:
        test_path: Case folder; the ``images`` tree is created inside it.
        predicted_classes: Predicted sequence name for each entry of ``paths``.
        paths: NIfTI volume paths, aligned index-by-index with ``predicted_classes``.

    Returns:
        Path of the ``images`` root folder, i.e. ``<test_path>/images``.

    Raises:
        ValueError: If ``predicted_classes`` and ``paths`` differ in length.
    """
    if len(predicted_classes) != len(paths):
        raise ValueError(
            f"predicted_classes ({len(predicted_classes)}) and paths ({len(paths)}) must align")

    path_until_images = os.path.join(test_path, "images")
    for k, (seq_class, volume_path) in enumerate(zip(predicted_classes, paths)):
        out_dir = os.path.join(path_until_images, "imgs " + str(seq_class), "imgs " + str(k))
        os.makedirs(out_dir, exist_ok=True)
        # One volume in memory at a time instead of all of them at once.
        volume = nib.load(volume_path).get_fdata()
        # Assumes axis 2 is the slice (axial) axis, as in BraTS volumes — TODO confirm for other data.
        for i in range(volume.shape[2]):
            imageio.imwrite(os.path.join(out_dir, f"{i:03d}.png"), volume[:, :, i])
    return path_until_images
        
def model_selector(missing, path_until_images,
                   checkpoint_path=r'E:\Graduation Project\55-EP-MODELS-FLAIR-MIX\T2-FLAIR'):
    """Load the synthesis generator for the missing sequence and locate its input slices.

    Only FLAIR synthesis is available. It uses a ``gans_SqueezeAttention``
    generator restored from the latest checkpoint in ``checkpoint_path``, fed
    with the exported T2 slices.

    Args:
        missing: Name of the missing sequence (e.g. ``'FLAIR'``).
        path_until_images: Root of the PNG tree written by ``extract_images``.
        checkpoint_path: Directory holding the T2 -> FLAIR generator checkpoints.

    Returns:
        ``(desired_folder, generator_g, AUTOTUNE)``: the T2 image folder, the
        restored Keras generator, and ``tf.data.experimental.AUTOTUNE``.

    Raises:
        NotImplementedError: If no generator exists for ``missing``. Previously
            this returned None and the caller failed with an unpacking TypeError.
        FileNotFoundError: If there is no T2 image folder or no checkpoint.
    """
    if missing != 'FLAIR':
        raise NotImplementedError(
            f"No synthesis model is available for missing sequence {missing!r}; "
            "only 'FLAIR' (from T2) is supported.")

    # The generator translates T2 -> FLAIR, so find the exported T2 slices first
    # (before paying for model construction).
    desired_folder = next(
        (os.path.join(path_until_images, d) for d in os.listdir(path_until_images) if d.endswith("T2")),
        None)
    if desired_folder is None:
        raise FileNotFoundError(f"No T2 image folder found under {path_until_images!r}")

    generator_g = gans_SqueezeAttention().model
    generator_g.summary()  # summary() prints itself; wrapping it in print() also printed "None"
    ckpt = tf.train.Checkpoint(generator_g=generator_g)
    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
    if ckpt_manager.latest_checkpoint is None:
        # Restoring None is a silent no-op and would leave random weights.
        raise FileNotFoundError(f"No checkpoint found in {checkpoint_path!r}")
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print(f'Last Check Point: {ckpt_manager.latest_checkpoint} is restored')

    AUTOTUNE = tf.data.experimental.AUTOTUNE
    return desired_folder, generator_g, AUTOTUNE
    
def load_imgs(desired_folder, AUTOTUNE, image_size=(256, 256)):
    """Load one sequence's exported slice PNGs as a preprocessed ``tf.data`` pipeline.

    Args:
        desired_folder: Folder containing a single class sub-folder of PNG
            slices (as produced by ``extract_images``).
        AUTOTUNE: Prefetch buffer size, normally ``tf.data.experimental.AUTOTUNE``.
        image_size: Size each slice is resized to; defaults to the generator's
            256x256 input.

    Returns:
        A dataset yielding one preprocessed image batch (batch size 1) per
        slice, in file-name order, with labels dropped.
    """
    img_data = tf.keras.preprocessing.image_dataset_from_directory(
        desired_folder,
        seed=123,
        image_size=image_size,
        batch_size=1,
        # Unshuffled data is sorted alphanumerically, keeping 000.png, 001.png, ...
        # in slice order so the predictions can be re-stacked into a volume.
        shuffle=False)
    # The folder-derived label is meaningless here; keep only the image.
    # preprocess_image_train comes from HelperFunctions (behavior not visible here).
    img_data = img_data.cache().map(lambda x, _: preprocess_image_train(x))
    # Prefetch last so the preprocessing step is overlapped with consumption too.
    return img_data.prefetch(buffer_size=AUTOTUNE)

def main_seq_ident(path, model_file):
    """Identify the sequences of one case and synthesize the single missing one.

    Steps:
    1. Classify every volume in ``path``.
    2. Require exactly one of FLAIR/T1/T1c/T2 to be missing.
    3. Export the slices as PNGs and run the matching generator on them.
    4. Resample the result onto the grid of the first classified volume and
       save it as ``<case folder name>_<suffix>.nii.gz`` inside ``path``.

    The temporary PNG tree is always removed, even on failure.

    Args:
        path: Case folder containing the NIfTI volumes.
        model_file: Sequence-classifier weights (see ``seq_identification``).

    Raises:
        ValueError: If no sequence, or more than one, is missing.
    """
    # BraTS-style filename suffix for each sequence class (input files use "t1ce" for T1c).
    suffixes = {'FLAIR': 'flair', 'T1': 't1', 'T1c': 't1ce', 'T2': 't2'}

    classes, predicted_classes, test_path, _, _, _, paths = seq_identification(path, model_file)
    print(f"Classes: {classes}")
    print(f"Predicted Classes: {predicted_classes}")
    missing_set = set(classes) - set(predicted_classes)
    print("Missing:", list(missing_set))

    if len(missing_set) > 1:
        raise ValueError(f"Only one missing sequence can be synthesized; missing: {sorted(missing_set)}")
    if len(missing_set) < 1:
        raise ValueError("No sequence is missing; nothing to synthesize.")
    missing = next(iter(missing_set))

    path_until_images = extract_images(test_path, predicted_classes, paths)
    try:
        desired_folder, generator_g, AUTOTUNE = model_selector(missing, path_until_images)
        img_data = load_imgs(desired_folder, AUTOTUNE)
        slices = [predict_image(image, generator_g) for image in img_data]
    finally:
        # Always drop the temporary PNGs; a leftover `images` folder would
        # otherwise pollute the case directory for the next run.
        shutil.rmtree(path_until_images, ignore_errors=True)

    # Assumes all sequences of a case share one grid/affine (true for BraTS),
    # so the first classified volume serves as the geometric reference.
    reference = nib.load(paths[0])
    target_shape = reference.shape

    # Stack slices along the last axis, then nearest-neighbour resample
    # (order=0) back onto the reference grid.
    vol = np.array(slices).transpose(1, 2, 0)
    vol = ndimage.zoom(vol, tuple(t / s for t, s in zip(target_shape, vol.shape)), order=0)

    case_id = os.path.basename(os.path.normpath(test_path))
    out_path = os.path.join(test_path, f"{case_id}_{suffixes[missing]}.nii.gz")
    nib.save(nib.Nifti1Image(vol, reference.affine), out_path)
    print(f"Saved synthesized {missing} volume to {out_path}")

In [3]:
# Case to process and classifier weights (local, machine-specific paths).
CASE_DIR = r"C:\Users\DELL\Desktop\BraTS2021_00000"
MODEL_FILE = r"C:\Users\DELL\Graduation Project\Codes and Scripts\type-identification-mri-sequences\models\b4_sl4.pth"

# NOTE: csv_maker (called inside) deletes any *seg.nii.gz file found in CASE_DIR.
main_seq_ident(path=CASE_DIR, model_file=MODEL_FILE)

                                                path  label
0  C:\Users\DELL\Desktop\BraTS2021_00000\BraTS202...    NaN
1  C:\Users\DELL\Desktop\BraTS2021_00000\BraTS202...    NaN
2  C:\Users\DELL\Desktop\BraTS2021_00000\BraTS202...    NaN
Starting loading data...
Loading C:\Users\DELL\Desktop\BraTS2021_00000\BraTS2021_00000_t1.nii.gz
Loaded 1 / 3 (counting discarded).
Loading C:\Users\DELL\Desktop\BraTS2021_00000\BraTS2021_00000_t1ce.nii.gz
Loaded 2 / 3 (counting discarded).
Loading C:\Users\DELL\Desktop\BraTS2021_00000\BraTS2021_00000_t2.nii.gz
Loaded 3 / 3 (counting discarded).

Loading time: 0h0min0s
Actually loaded: 3 ("Other" class discarded)
Tested 1 of 3 files.
Tested 2 of 3 files.
Tested 3 of 3 files.
NOPE MISSING MODES.
OUT FOR DATA CONVERSION

AVAILABLE MODES: ['T2', 'T1c', 'T1']

Prediction:   ['T1', 'T1c', 'T2']
Paths:  ['C:\\Users\\DELL\\Desktop\\BraTS2021_00000\\BraTS2021_00000_t1.nii.gz', 'C:\\Users\\DELL\\Desktop\\BraTS2021_00000\\BraTS2021_00000_t1ce.nii.gz', 'C:\\Use















Model: "gams_Squeeze-Attention-UNET"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 256, 256, 64) 1792        input_1[0][0]                    
__________________________________________________________________________________________________
instance_normalization (Instanc (None, 256, 256, 64) 128         conv2d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 256, 256, 64) 0           instance_normalization[0][0]     
________________________________________________________________________

Last Check Point: E:\Graduation Project\55-EP-MODELS-FLAIR-MIX\T2-FLAIR\ckpt-54 is restored








Found 146 files belonging to 1 classes.
