In [1]:
import torch
import twixtools
import tkinter as tk
from tkinter import ttk
from tkinter import filedialog as fd
from tkinter.messagebox import showinfo
import numpy as np
import ismrmrd

import ismrmrd
import os
import itertools
import logging
import numpy as np
import numpy.fft as fft
import ctypes
from datetime import datetime
from pathlib import Path
import shutil

from tqdm import tqdm
from mrftools import *
import sys
from matplotlib import pyplot as plt
import scipy

from PIL import Image, ImageDraw, ImageFont
import numpy as np
from matplotlib import pyplot as plt

import os
import pickle
import hashlib

import time
 

# Folder in which simulated dictionaries are cached (file name = hash of the simulation); "" = working directory
dictionaryFolder = ""

# Configure dictionary simulation parameters
dictionaryName = "5pct"
percentStepSize = 5
includeB1 = False
t1Range = (10, 4000)
t2Range = (1, 500)
b1Range = (0.5, 1.55)
b1Stepsize = 0.05
phaseRange = (-np.pi, np.pi)
numSpins = 15
numBatches = 100

# Azure logging configuration. The connection string is a credential, so it is read from the
# environment instead of being written into the notebook; it is "" when the variable is unset.
connectionString = os.environ.get("AZURE_STORAGE_CONNECTION_STRING", "")
tableName = "reconstructionLog"

def ApplyXYZShift(svdData, header, acqHeaders, trajectories, matrixSizeOverride=None):
    """Apply the acquisition's spatial offset to k-space as a linear phase ramp.

    The offset is computed from the first acquisition header only, i.e. every
    spiral/partition is assumed to share the same offset.

    Args:
        svdData: tensor shaped (svdComponents, coils, partitions, readoutPoints, spirals).
        header: ISMRMRD header forwarded to CalculateVoxelOffsetAcquisitionSpace.
        acqHeaders: array of acquisition headers; only acqHeaders[0, 0, 0] is used.
        trajectories: complex spiral trajectories (kx in the real part, ky in the imaginary part);
            after np.array + transpose they must broadcast to (readoutPoints, spirals).
        matrixSizeOverride: forwarded to CalculateVoxelOffsetAcquisitionSpace.

    Returns:
        svdData multiplied by exp(-2*pi*i*(x*kx + y*ky + z*kz)).
    """
    numPartitions, numReadoutPoints, numSpirals = np.shape(svdData)[2:5]
    # For now, assume all spirals/partitions/etc have same offsets applied
    (x_shift, y_shift, z_shift) = CalculateVoxelOffsetAcquisitionSpace(header, acqHeaders[0,0,0], matrixSizeOverride=matrixSizeOverride)
    kxy = torch.t(torch.tensor(np.array(trajectories))).expand((numPartitions, numReadoutPoints, numSpirals))
    # Normalised kz positions -0.5, ..., 0.5 - 1/numPartitions, shaped to broadcast over (readout, spiral).
    # Built from an integer arange so the length is exactly numPartitions: a float arange with
    # step 1/numPartitions can gain an extra element through rounding (e.g. numPartitions = 49).
    kz = (torch.arange(numPartitions) / numPartitions - 0.5)[:, None, None]
    phase = -2*torch.pi*(x_shift*kxy.real + y_shift*kxy.imag + z_shift*kz)
    logging.info(f"K-Space x/y/z shift applied: {x_shift}, {y_shift}, {z_shift}")
    return svdData*torch.complex(torch.cos(phase), torch.sin(phase))

def vertex_of_parabola(points, clamp=False, min=None, max=None):
    """Vertex of the parabola through three points, for a batch of point triples.

    Args:
        points: tensor shaped (batch, 3, 2); points[:, k] is the (x, y) of the k-th point.
        clamp: when True, clamp the vertex x-coordinate to [min, max].
        min, max: clamp bounds; either may be None for a one-sided clamp. Numpy scalars are accepted.

    Returns:
        (xv, yv): tensors shaped (batch,). Degenerate triples (repeated x values, collinear points)
        produce NaN/inf because the divisions are unguarded; clamping leaves NaN as NaN.
    """
    x1 = points[:,0,0]
    y1 = points[:,0,1]
    x2 = points[:,1,0]
    y2 = points[:,1,1]
    x3 = points[:,2,0]
    y3 = points[:,2,1]
    # Coefficients of y = A*x^2 + B*x + C through the three points.
    denom = (x1-x2) * (x1-x3) * (x2-x3)
    A = (x3 * (y2-y1) + x2 * (y1-y3) + x1 * (y3-y2)) / denom
    B = (x3*x3 * (y1-y2) + x2*x2 * (y3-y1) + x1*x1 * (y2-y3)) / denom
    C = (x2 * x3 * (x2-x3) * y1+x3 * x1 * (x3-x1) * y2+x1 * x2 * (x1-x2) * y3) / denom
    xv = -B / (2*A)
    yv = C - B*B / (4*A)
    if clamp and (min is not None or max is not None):
        # torch.clamp returns a new tensor, so the result has to be assigned back.
        xv = torch.clamp(xv,
                         None if min is None else float(min),
                         None if max is None else float(max))
    return (xv, yv)

def GenerateDictionaryLookupTables(dictionaryEntries):
    """Index dictionary entries by their position on the (T1, T2) grid.

    Args:
        dictionaryEntries: structured numpy array with 'T1' and 'T2' fields. Any shape is accepted
            (e.g. the (n, 1) result of indexing with np.argwhere); it is flattened in order.

    Returns:
        uniqueT1s, uniqueT2s: sorted unique T1 / T2 values.
        dictionary1DIndexLookupTable: int array (len(uniqueT1s), len(uniqueT2s)) mapping a grid cell
            to the flat index of its entry. Cells with no entry hold 0; if a cell occurs more than
            once, the last (highest) index is kept.
        dictionary2DIndexLookupTable: int array (n, 2, 1) with the (T1 index, T2 index) of each entry.
    """
    entries = np.ravel(dictionaryEntries)
    uniqueT1s, t1Indices = np.unique(entries['T1'], return_inverse=True)
    uniqueT2s, t2Indices = np.unique(entries['T2'], return_inverse=True)

    dictionary1DIndexLookupTable = np.zeros((len(uniqueT1s), len(uniqueT2s)), dtype=int)
    # ufunc.at is unbuffered, so a repeated cell ends up with the maximum = last index, matching a
    # sequential loop; plain fancy-index assignment gives no ordering guarantee for repeated cells.
    np.maximum.at(dictionary1DIndexLookupTable, (t1Indices, t2Indices),
                  np.arange(len(entries), dtype=dictionary1DIndexLookupTable.dtype))
    # Trailing singleton axis kept for compatibility with callers expecting (n, 2, 1).
    dictionary2DIndexLookupTable = np.stack((t1Indices, t2Indices), axis=1)[:, :, np.newaxis]
    return uniqueT1s, uniqueT2s, dictionary1DIndexLookupTable, dictionary2DIndexLookupTable


def BatchPatternMatchViaMaxInnerProductWithInterpolation(signalTimecourses, dictionaryEntries, dictionaryEntryTimecourses, voxelsPerBatch=500, device=None, radius=1):
    """Match voxel timecourses to dictionary entries, then refine T1/T2 by parabolic interpolation.

    Each voxel is matched to the entry with the largest |inner product| (both sides unit-normalised).
    The |inner products| of the matched entry's (T1, T2) grid neighbours are then summed across the
    other parameter, and a parabola through the three central neighbours gives the interpolated
    T1 (and likewise T2).

    Args:
        signalTimecourses: tensor shaped (numComponents, numVoxels).
        dictionaryEntries: structured numpy array with 'T1', 'T2', 'B1' fields, one record per column
            of dictionaryEntryTimecourses; (n,) or (n, 1) shapes are accepted.
        dictionaryEntryTimecourses: array/tensor shaped (numComponents, numEntries).
        voxelsPerBatch: voxels matched per batch; the final batch absorbs the remainder.
        device: torch device; defaults to CUDA when available, otherwise CPU.
        radius: half-width (>= 1) of the neighbour window on the (T1, T2) grid. The inner products
            are summed over the full window; the parabola uses the three central neighbours.

    Returns:
        (patternMatches, interpolatedMatches, M0): two structured arrays of length numVoxels and a
        complex64 tensor of per-voxel M0 (voxel norm / matched entry norm).

    Raises:
        ValueError: if radius < 1.
    """
    if radius < 1:
        raise ValueError(f"radius must be >= 1, got {radius}")
    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Flatten (n, 1)-shaped selections (e.g. from np.argwhere) so indexing yields one record per voxel.
    dictionaryEntries = np.ravel(dictionaryEntries)
    windowWidth = 2 * radius + 1
    # Window positions of the three central neighbours (grid offsets -1, 0, +1).
    central = slice(radius - 1, radius + 2)
    # offsets[0] / offsets[1]: T1 / T2 grid offsets, each shaped (windowWidth, windowWidth).
    offsets = np.mgrid[-radius:radius + 1, -radius:radius + 1]

    with torch.no_grad():
        uniqueT1s, uniqueT2s, dictionary1DIndexLookupTable, dictionary2DIndexLookupTable = GenerateDictionaryLookupTables(dictionaryEntries)
        numT1s, numT2s = np.shape(dictionary1DIndexLookupTable)

        # Unit-normalise voxels (rows) and dictionary entries (columns) before taking inner products.
        signalsTransposed = torch.t(signalTimecourses)
        signalNorm = torch.linalg.norm(signalsTransposed, dim=1)[:, None]
        normalizedSignals = signalsTransposed / signalNorm

        simulationResults = torch.as_tensor(dictionaryEntryTimecourses).to(torch.complex64)
        simulationNorm = torch.linalg.norm(simulationResults, dim=0)
        normalizedSimulationResults = torch.t(simulationResults / simulationNorm).to(device)
        simulationNorm_device = simulationNorm.to(device)

        numVoxelsTotal = normalizedSignals.shape[0]
        # Always run at least one batch when there are voxels; otherwise inputs smaller than
        # voxelsPerBatch would silently never be matched.
        numBatches = max(1, numVoxelsTotal // voxelsPerBatch) if numVoxelsTotal > 0 else 0
        patternMatches = np.empty(numVoxelsTotal, dtype=DictionaryEntry)
        interpolatedMatches = np.empty(numVoxelsTotal, dtype=DictionaryEntry)
        M0 = torch.zeros(numVoxelsTotal, dtype=torch.complex64)

        with tqdm(total=numBatches) as pbar:
            for i in range(numBatches):
                firstVoxel = i * voxelsPerBatch
                lastVoxel = numVoxelsTotal if i == numBatches - 1 else firstVoxel + voxelsPerBatch
                numVoxels = lastVoxel - firstVoxel

                # Best match = dictionary entry with the largest |inner product|.
                batchSignals = normalizedSignals[firstVoxel:lastVoxel, :].to(device)
                innerProducts = torch.inner(batchSignals, normalizedSimulationResults)
                absInnerProducts = torch.abs(innerProducts)
                maxInnerProductIndices = torch.argmax(absInnerProducts, dim=1)
                bestIndices = maxInnerProductIndices.cpu().numpy()

                # M0 = voxel norm / matched entry norm.
                M0[firstVoxel:lastVoxel] = (signalNorm[firstVoxel:lastVoxel, 0].to(device) / simulationNorm_device[maxInnerProductIndices]).cpu()
                patternValues = dictionaryEntries[bestIndices]
                patternMatches[firstVoxel:lastVoxel] = patternValues

                # (T1, T2) grid indices of each best match and of its neighbours, clipped to the grid.
                indices = dictionary2DIndexLookupTable[bestIndices].reshape(numVoxels, 2)
                neighbor2DIndices = indices[:, :, np.newaxis, np.newaxis] + offsets[np.newaxis]
                neighbor2DIndices[:, 0] = np.clip(neighbor2DIndices[:, 0], 0, numT1s - 1)
                neighbor2DIndices[:, 1] = np.clip(neighbor2DIndices[:, 1], 0, numT2s - 1)
                # (numVoxels, windowWidth, windowWidth): rows step through T1 neighbours, columns through T2.
                neighborDictionaryIndices = dictionary1DIndexLookupTable[neighbor2DIndices[:, 0], neighbor2DIndices[:, 1]]
                neighborInnerProducts = torch.take_along_dim(
                    absInnerProducts,
                    torch.as_tensor(neighborDictionaryIndices.reshape(numVoxels, -1), dtype=torch.long, device=device),
                    dim=1,
                ).reshape(numVoxels, windowWidth, windowWidth)
                neighborDictionaryEntries = dictionaryEntries[neighborDictionaryIndices]

                # Sum of |inner products| across T2 neighbours for each T1 neighbour (and vice versa),
                # restricted to the three central neighbours used for the parabola fit.
                T1InnerProductSums = neighborInnerProducts.sum(dim=2)[:, central]
                T2InnerProductSums = neighborInnerProducts.sum(dim=1)[:, central]
                T1s = torch.tensor(np.ascontiguousarray(neighborDictionaryEntries['T1'][:, central, radius - 1])).to(device)
                T2s = torch.tensor(np.ascontiguousarray(neighborDictionaryEntries['T2'][:, radius - 1, central])).to(device)

                # Points are (parameter value, summed |inner product|), shaped (numVoxels, 3, 2).
                interpT1s, _ = vertex_of_parabola(torch.stack((T1s, T1InnerProductSums), dim=2), clamp=True, min=0, max=np.max(uniqueT1s))
                interpT2s, _ = vertex_of_parabola(torch.stack((T2s, T2InnerProductSums), dim=2), clamp=True, min=0, max=np.max(uniqueT2s))

                interpolatedValues = np.zeros(numVoxels, dtype=DictionaryEntry)
                interpolatedValues['T1'] = interpT1s.cpu().numpy()
                interpolatedValues['T2'] = interpT2s.cpu().numpy()
                # NOTE(review): B1 is always 1 here, even when matching a B1-binned subset.
                interpolatedValues['B1'] = 1

                # Fall back to the plain pattern match where interpolation is not meaningful: matches on
                # the edge of the T1/T2 grid (neighbours were clipped) and non-finite parabola vertices.
                fallback = (
                    (indices[:, 0] == 0) | (indices[:, 0] == numT1s - 1)
                    | (indices[:, 1] == 0) | (indices[:, 1] == numT2s - 1)
                    | ~np.isfinite(interpolatedValues['T1']) | ~np.isfinite(interpolatedValues['T2'])
                )
                interpolatedValues[fallback] = patternValues[fallback]
                interpolatedMatches[firstVoxel:lastVoxel] = interpolatedValues

                pbar.update(1)
                # Release the large per-batch device tensors before the next batch allocates its own.
                del batchSignals, innerProducts, absInnerProducts, neighborInnerProducts

    return patternMatches, interpolatedMatches, M0

def AddText(image, text="NOT FOR DIAGNOSTIC USE", fontSize=12, fontPath="terminess.ttf"):
    """Burn `text` into all four in-plane borders of every slice of a 3-D image, in place.

    The text is rendered once along the top edge and then OR-ed with its 90/180/270 degree
    rotations so it appears on every side. Text pixels are set to the image maximum.

    Args:
        image: array shaped (rows, cols, slices). rows must equal cols, because the overlay is
            combined with its in-plane rotation. Modified in place and also returned.
        text: string to burn in.
        fontSize: size passed to the TrueType loader.
        fontPath: TrueType font file; if it cannot be opened, Pillow's built-in default font is used.

    Returns:
        The same `image` array with the overlay applied.
    """
    matrixsize = np.shape(image)
    img = Image.fromarray(np.uint8(np.zeros((matrixsize[0:2]))))
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype(fontPath, fontSize)
    except OSError:
        # A missing font file should not abort a reconstruction at the export stage.
        logging.warning(f"Could not load font '{fontPath}'; using Pillow's default font.")
        font = ImageFont.load_default()
    _, _, w, h = draw.textbbox((0, 0), text, font=font)
    # PIL x runs along image columns (axis 1), so centre horizontally using the column count.
    draw.text(((matrixsize[1]-w)/2, h/2), text, 255, font=font)
    overlay = np.array(img) > 0
    repeated = np.repeat(overlay[:,:,np.newaxis], matrixsize[2], axis=2)
    for _ in range(3):
        repeated = repeated | np.rot90(repeated)
    image[repeated] = np.max(image)
    return image

def LoadB1Map(matrixSize, b1Filename, resampleToMRFMatrixSize=True, deinterleave=True, deleteB1File=True, b1Folder=None):
    """Load a B1 map saved as `<b1Folder>/<b1Filename>.npy`, falling back to `B1Map_fallback.npy`.

    Args:
        matrixSize: target (x, y, z) size used when resampling.
        b1Filename: base file name without the .npy extension.
        resampleToMRFMatrixSize: zoom to matrixSize (order-5 spline), then flip/rotate into the
            MRF image orientation.
        deinterleave: reorder slices so the first half of the stack goes to odd slice positions
            and the second half to even ones.
        deleteB1File: remove the .npy file after it has been loaded.
        b1Folder: folder holding the B1 files. Defaults to a module-level `b1Folder` if one is
            defined; if neither is set, B1 correction is skipped.

    Returns:
        The B1 map array, or an empty array when no map could be loaded.
    """
    if b1Folder is None:
        b1Folder = globals().get("b1Folder")
    if not b1Folder:
        logging.info("No B1 folder configured. Skipping B1 correction.")
        return np.array([])

    b1Path = os.path.join(b1Folder, b1Filename + ".npy")
    try:
        b1Data = np.load(b1Path)
    except (OSError, ValueError):
        logging.info("No B1 map found with requested filename. Trying fallback. ")
        b1Filename = "B1Map_fallback"
        b1Path = os.path.join(b1Folder, b1Filename + ".npy")
        try:
            b1Data = np.load(b1Path)
        except (OSError, ValueError):
            logging.info("No B1 map found with fallback filename. Skipping B1 correction.")
            return np.array([])

    b1MapSize = np.array(np.shape(b1Data))
    logging.info(f"B1 Input Size: {b1MapSize}")
    if deinterleave:
        numSlices = b1MapSize[2]
        deinterleaved = np.zeros_like(b1Data)
        deinterleaved[:,:,np.arange(1,numSlices,2)] = b1Data[:,:,0:int(np.floor(numSlices/2))]
        deinterleaved[:,:,np.arange(0,numSlices-1,2)] = b1Data[:,:,int(np.floor(numSlices/2)):numSlices]
        b1Data = deinterleaved
    if resampleToMRFMatrixSize:
        b1Data = scipy.ndimage.zoom(b1Data, matrixSize/b1MapSize, order=5)
        b1Data = np.flip(b1Data, axis=2)
        b1Data = np.rot90(b1Data, axes=(0,1))
        b1Data = np.flip(b1Data, axis=0)
    logging.info(f"B1 Output Size: {np.shape(b1Data)}")
    if deleteB1File:
        os.remove(b1Path)
        logging.info(f"Deleted B1 File: {b1Filename}")
    return b1Data
        
def performB1Binning(b1Data, b1Range, b1Stepsize, b1IdentityValue=800):
    """Snap raw B1 values onto the dictionary's B1 grid.

    The grid is np.arange(b1Range[0], b1Range[1], b1Stepsize), and a raw value of
    b1IdentityValue corresponds to a relative B1 of 1. Values are clipped to the grid range,
    then each is mapped to the smallest grid value whose scaled edge is >= the raw value.

    Returns:
        Array shaped like b1Data holding relative B1 grid values.
    """
    gridValues = np.arange(b1Range[0], b1Range[1], b1Stepsize)
    scaledEdges = gridValues * b1IdentityValue
    lowest = np.min(gridValues) * b1IdentityValue
    highest = np.max(gridValues) * b1IdentityValue
    binPositions = np.digitize(np.clip(b1Data, lowest, highest), scaledEdges, right=True)
    snapped = gridValues[binPositions]
    logging.info(f"Binned B1 Shape: {np.shape(snapped)}")
    return snapped

def PatternMatchingViaMaxInnerProductWithInterpolation(combined, dictionary, simulation, voxelsPerBatch=500, b1Binned=None, device=None,):
    """Pattern-match an image volume against the dictionary, per B1 bin when a B1 map is given.

    Args:
        combined: tensor shaped (numSVDComponents, x, y, z) of SVD-domain voxel timecourses.
        dictionary: dictionary parameters with structured `entries` ('T1', 'T2', 'B1').
        simulation: simulation whose `truncatedResults` has one column per dictionary entry.
        voxelsPerBatch: forwarded to the batched matcher.
        b1Binned: optional (x, y, z) array of binned B1 values. Each bin is matched only against
            dictionary entries with that B1; voxels in a 0 bin are left as zero.
        device: torch device; defaults to CUDA when available, otherwise CPU.

    Returns:
        (patternMatches, interpolatedMatches, M0), each shaped (x, y, z). M0 is a numpy array with
        NaN/inf replaced by np.nan_to_num.
    """
    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    sizes = np.shape(combined)
    numSVDComponents = sizes[0]
    matrixSize = sizes[1:4]
    # zeros (not empty): voxels that are never matched must not contain uninitialised memory.
    patternMatches = np.zeros((matrixSize), dtype=DictionaryEntry)
    interpolatedMatches = np.zeros((matrixSize), dtype=DictionaryEntry)
    M0 = torch.zeros((matrixSize), dtype=torch.complex64)
    if b1Binned is not None:
        for uniqueB1 in np.unique(b1Binned):
            logging.info(f"Pattern Matching B1 Value: {uniqueB1}")
            voxelMask = b1Binned == uniqueB1
            if uniqueB1 == 0:
                patternMatches[voxelMask] = 0
                interpolatedMatches[voxelMask] = 0
                continue
            entryIndices = np.argwhere(dictionary.entries['B1'] == uniqueB1)
            simulationTimecourses = torch.t(torch.t(torch.tensor(simulation.truncatedResults))[entryIndices].squeeze())
            dictionaryEntries = dictionary.entries[entryIndices]
            signalTimecourses = combined[:, voxelMask]
            patternMatches[voxelMask], interpolatedMatches[voxelMask], M0[voxelMask] = BatchPatternMatchViaMaxInnerProductWithInterpolation(signalTimecourses, dictionaryEntries, simulationTimecourses, voxelsPerBatch=voxelsPerBatch, device=device)
    else:
        signalTimecourses = torch.reshape(combined, (numSVDComponents,-1))
        # NOTE(review): presumably a non-zero first B1 means the dictionary includes B1 values, in
        # which case only the B1 == 1 entries are used — confirm against DictionaryParameters.
        if dictionary.entries['B1'][0]:
            entryIndices = np.argwhere(dictionary.entries['B1'] == 1)
            simulationTimecourses = torch.t(torch.t(torch.tensor(simulation.truncatedResults))[entryIndices].squeeze())
            dictionaryEntries = dictionary.entries[entryIndices]
        else:
            simulationTimecourses = torch.tensor(simulation.truncatedResults)
            dictionaryEntries = dictionary.entries
        patternMatches, interpolatedMatches, M0 = BatchPatternMatchViaMaxInnerProductWithInterpolation(signalTimecourses, dictionaryEntries, simulationTimecourses, voxelsPerBatch=voxelsPerBatch, device=device)
    patternMatches = np.reshape(patternMatches, (matrixSize))
    interpolatedMatches = np.reshape(interpolatedMatches, (matrixSize))
    M0 = np.reshape(M0, (matrixSize)).numpy()
    M0 = np.nan_to_num(M0)
    return patternMatches, interpolatedMatches, M0

# Takes data input as: [cha z y x], [z y x], or [y x]
def PopulateISMRMRDImage(header, data, acquisition, image_index, colormap=None, window=None, level=None, comment=""):
    """Wrap `data` in an ISMRMRD image carrying the recon FOV and display metadata.

    `data` is transposed before conversion. Window defaults to max(data), level to max(data)/2,
    colormap to "". Row/column direction cosines from the image head are added to the metadata.
    """
    image = ismrmrd.Image.from_array(data.transpose(), acquisition=acquisition, transpose=False)
    image.image_index = image_index

    fov = header.encoding[0].reconSpace.fieldOfView_mm
    image.field_of_view = tuple(ctypes.c_float(extent) for extent in (fov.x, fov.y, fov.z))

    colormap = "" if colormap is None else colormap
    window = np.max(data) if window is None else window
    level = np.max(data)/2 if level is None else level

    meta = ismrmrd.Meta({
        'DataRole': 'Image',
        'ImageProcessingHistory': ['FIRE', 'PYTHON'],
        'WindowCenter': str(level),
        'WindowWidth': str(window),
        'GADGETRON_ColorMap': colormap,
        'GADGETRON_ImageComment': comment,
    })

    head = image.getHead()

    def formatDirection(vector):
        return ["{:.18f}".format(vector[axis]) for axis in range(3)]

    if meta.get('ImageRowDir') is None:
        meta['ImageRowDir'] = formatDirection(head.read_dir)
    if meta.get('ImageColumnDir') is None:
        meta['ImageColumnDir'] = formatDirection(head.phase_dir)

    serialized = meta.serialize()
    logging.debug("Image MetaAttributes: %s", serialized)
    logging.debug("Image data has %d elements", image.data.size)

    image.attribute_string = serialized
    return image

def GenerateRadialMask(coilImageData, svdNum = 0, angularResolution = 0.01, stepSize = 3, fillSize = 3, maxDecay = 15, featheringKernelSize=4, coilCountCutoff = 20):
    """Build a feathered foreground mask by casting rays outwards from the in-plane centre.

    The per-voxel maximum magnitude across coils of one SVD component is thresholded at its global
    mean. Voxels outside the inscribed circle are then zeroed. Rays are cast from the centre every
    `angularResolution` radians in steps of `stepSize` pixels. A square of half-width `fillSize` is
    marked around every above-threshold sample, and around the sub-threshold samples seen since the
    previous hit. A ray stops after `maxDecay` consecutive sub-threshold samples or at the image
    edge. The binary mask is finally smoothed with a `featheringKernelSize`^3 box filter.

    Args:
        coilImageData: tensor shaped (svdComponents, coils, partitions, rows, cols).
        svdNum: SVD component used to build the mask.
        coilCountCutoff: with more coils than this, ray casting is skipped and an all-ones mask is returned.

    Returns:
        numpy array shaped (rows, cols, partitions) with values in [0, 1].
    """
    coilMax = np.max(np.abs(coilImageData[svdNum,:,:,:,:].cpu().numpy()), axis=0)
    if np.shape(coilImageData)[1] > coilCountCutoff:
        return np.moveaxis(np.ones(np.shape(coilMax)), 0, -1)
    threshold = np.mean(coilMax)
    maskIm = np.zeros(np.shape(coilMax))
    numRows, numCols = np.shape(coilMax)[1:3]
    center = np.array([numRows, numCols])/2
    Y, X = np.ogrid[:numCols, :numRows]
    dist_from_center = np.sqrt((X - center[0])**2 + (Y-center[1])**2)
    cylindricalMask = dist_from_center <= numRows/2
    coilMax = cylindricalMask*coilMax

    def _fillSquare(partition, pos):
        # Clamp the lower bounds at 0: a negative slice start would wrap to the far edge and fill nothing.
        rowStart = max(pos[0]-fillSize, 0)
        colStart = max(pos[1]-fillSize, 0)
        maskIm[partition, rowStart:pos[0]+fillSize, colStart:pos[1]+fillSize] = 1

    for partition in range(np.shape(coilMax)[0]):
        for polarAngle in np.arange(0, 2*np.pi, angularResolution):
            decayCounter = 0
            radius = 0
            pendingPositions = []
            while decayCounter < maxDecay:
                radius += stepSize
                pos = (center + [radius*np.cos(polarAngle), radius*np.sin(polarAngle)]).astype(int)
                if not (0 < pos[0] < numRows-1 and 0 < pos[1] < numCols-1):
                    break
                if coilMax[partition, pos[0], pos[1]] > threshold:
                    # A new hit: also fill the gap of sub-threshold samples since the previous hit.
                    for pendingPos in pendingPositions:
                        _fillSquare(partition, pendingPos)
                    pendingPositions.clear()
                    _fillSquare(partition, pos)
                    decayCounter = 0
                else:
                    decayCounter += 1
                    pendingPositions.append(pos)

    # Feather the binary mask with a box (mean) filter; padding='same' keeps the volume size.
    with torch.no_grad():
        maskTensor = torch.tensor(maskIm).to(torch.float32)
        kernel = torch.ones((1, 1, featheringKernelSize, featheringKernelSize, featheringKernelSize),
                            dtype=torch.float32)/(featheringKernelSize*featheringKernelSize*featheringKernelSize)
        # Index [0, 0] instead of squeeze() so a single-partition volume keeps its partition axis.
        maskIm = torch.nn.functional.conv3d(maskTensor[None, None], kernel, padding='same')[0, 0].numpy()
    return np.moveaxis(maskIm, 0, -1)

# Generate Classification Maps from Timecourses and Known Tissue Timecourses
def GenerateClassificationMaps(imageData, dictionary, simulation, matrixSize):
    """Estimate white-matter / grey-matter / CSF fraction maps by matching against tissue mixtures.

    Each voxel timecourse is matched (via BatchedPatternMatchViaMaxInnerProduct) against a dictionary
    of mixtures a*WM + b*GM + c*CSF. The WM/GM/CSF timecourses are weighted sums of the B1 == 1
    simulated timecourses, weighted around reference tissue (T1, T2) values.

    Args:
        imageData: voxel timecourses with the component axis first. It is reshaped to
            (shape[0], -1) and converted with `.to(torch.cfloat)`, so presumably a torch tensor — confirm.
        dictionary: dictionary parameters whose `entries` have 'T1'/'T2' fields.
        simulation: simulation providing `truncatedResults` and `dictionaryParameters.entries`.
        matrixSize: spatial shape each output map is reshaped to.

    Returns:
        (wmFractionMap, gmFractionMap, csfFractionMap), each scaled by 10000.
    """
    ## Run for all pixels
    shape = np.shape(imageData)
    timecourses = imageData.reshape(shape[0], -1)

    ## Set up coefficient dictionary
    # Grid of (a, b, c) mixing coefficients in steps of 0.01 with a + b + c == maxSum
    # (c takes whatever remains after a and b).
    coefficientDictionaryEntries = []
    stepSize = 1/(10**2)
    roundingFactor = 1/stepSize
    maxSum = 1

    for aValue in np.arange(0, maxSum, stepSize):
        remainingForB = maxSum - aValue
        if(remainingForB < stepSize):
            coefficientDictionaryEntries.append([aValue, remainingForB, 0])
        else:   
            for bValue in np.arange(0,remainingForB, stepSize):
                remainingForC = remainingForB - bValue
                coefficientDictionaryEntries.append([aValue, bValue, remainingForC])
    coefficientDictionaryEntries = np.array(coefficientDictionaryEntries)
    # Round to the 0.01 grid to remove floating-point drift accumulated by np.arange.
    coefficientDictionaryEntries = np.round(coefficientDictionaryEntries*roundingFactor)/roundingFactor
    # NOTE(review): `sums` is computed but never used.
    sums = np.sum(coefficientDictionaryEntries, axis=1)
    # DictionaryEntry records are used purely as a 3-field container here: 'T1' holds the WM weight,
    # 'T2' the GM weight and 'B1' the CSF weight (see the mixing loop below).
    coefficientDictionaryEntries = np.array([tuple(i) for i in coefficientDictionaryEntries], dtype=DictionaryEntry)

    ## Timecourse Equation for a voxel
    ## Tissue(t) = sum over dictionary entries of exp(-(((T1_tissue - T1)/sigmaT1)**2 + ((T2_tissue - T2)/sigmaT2)**2)) * entry(t)
    # NOTE(review): sigma = 0.01 is in the same units as the dictionary T1/T2 values; those units are
    # not visible here — confirm against WHITE_MATTER_3T / GREY_MATTER_3T / CSF_3T.
    T1_wm = WHITE_MATTER_3T[0]['T1']; sigmaT1_wm = 0.01
    T2_wm = WHITE_MATTER_3T[0]['T2']; sigmaT2_wm = 0.01
    T1_gm = GREY_MATTER_3T[0]['T1']; sigmaT1_gm = 0.01
    T2_gm = GREY_MATTER_3T[0]['T2']; sigmaT2_gm = 0.01
    T1_csf = CSF_3T[0]['T1']; sigmaT1_csf = 0.01
    T2_csf = CSF_3T[0]['T2']; sigmaT2_csf = 0.01
    # Assumes dictionary.entries and simulation.dictionaryParameters.entries share the same order — confirm.
    T1 = dictionary.entries['T1'][simulation.dictionaryParameters.entries['B1']==1] # Revise to not pass in Dictionary - use the subclass dictionary instead so it matches for sure
    T2 = dictionary.entries['T2'][simulation.dictionaryParameters.entries['B1']==1]
    WmWeights = np.exp( -1 * ( ((T1_wm - T1)/sigmaT1_wm )**2 + ( (T2_wm-T2)/sigmaT2_wm )**2 ))
    GmWeights = np.exp( -1 * ( ((T1_gm - T1)/sigmaT1_gm )**2 + ( (T2_gm-T2)/sigmaT2_gm )**2 )) 
    CsfWeights = np.exp( -1 * ( ((T1_csf - T1)/sigmaT1_csf )**2 + ( (T2_csf-T2)/sigmaT2_csf )**2 )) 

    ## Create timecourses for WM/GM/CSF as weighted sums of the B1 == 1 simulated timecourses
    truncatedResultsIdentityB1 = simulation.truncatedResults[:,simulation.dictionaryParameters.entries['B1']==1]
    WM = np.sum(truncatedResultsIdentityB1* WmWeights, axis=1); GM = np.sum(truncatedResultsIdentityB1 * GmWeights, axis=1); CSF = np.sum(truncatedResultsIdentityB1 * CsfWeights, axis=1)
    # One mixture timecourse per coefficient triple; transposed to (components, numMixtures).
    coefficientDictionaryTimecourses = []
    for coefficients in coefficientDictionaryEntries:
        coefficientTimecourse = coefficients['T1'] * WM + coefficients['T2'] * GM + coefficients['B1'] * CSF
        coefficientDictionaryTimecourses.append(coefficientTimecourse)
    coefficientDictionaryTimecourses = np.array(coefficientDictionaryTimecourses).transpose()

    # Perform Coefficient-Space Pattern Matching (the returned M0 is not used here)
    coefficientPatternMatches,coefficientM0 = BatchedPatternMatchViaMaxInnerProduct(timecourses.to(torch.cfloat), coefficientDictionaryEntries, torch.tensor(coefficientDictionaryTimecourses).to(torch.cfloat))
    # Reinterpret the structured records as an (n, numFields) float32 array.
    # Assumes every DictionaryEntry field is '<f4' — confirm.
    normalizedCoefficients = coefficientPatternMatches.view((np.dtype('<f4'), len(coefficientPatternMatches.dtype.names)))

    # Columns 0/1/2 = WM/GM/CSF weights, scaled by 10000.
    wmFractionMap = np.reshape(normalizedCoefficients[:,0], (matrixSize)) * 10000
    gmFractionMap = np.reshape(normalizedCoefficients[:,1], (matrixSize)) * 10000
    csfFractionMap = np.reshape(normalizedCoefficients[:,2], (matrixSize)) * 10000
    
    return (wmFractionMap, gmFractionMap, csfFractionMap)

def ThroughplaneFFT(nufftResults, device=None):
    """Inverse FFT along the partition axis, one SVD component at a time.

    Args:
        nufftResults: tensor shaped (svdComponents, coils, partitions, rows, cols).
        device: torch device for the FFT; defaults to CUDA when available, otherwise CPU.

    Returns:
        complex64 CPU tensor of the same shape: ifft along partitions followed by ifftshift.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dims = np.shape(nufftResults)
    numSVDComponents, numCoils, numPartitions = dims[0], dims[1], dims[2]
    inPlane = dims[3:5]
    images = torch.zeros((numSVDComponents, numCoils, numPartitions, inPlane[0], inPlane[1]), dtype=torch.complex64)
    for component in tqdm(range(numSVDComponents)):
        componentOnDevice = nufftResults[component].to(device)
        images[component] = torch.fft.ifftshift(torch.fft.ifft(componentOnDevice, dim=1), dim=1)
        del componentOnDevice
    torch.cuda.empty_cache()
    print("Images Shape:", np.shape(images))
    return images
            

def runReconstruction(filename, b1Filename=""):
    """Reconstruct MRF parameter maps from a Siemens twix file and export per-slice TIFFs.

    Steps:
        1. Parse the twix header and readouts.
        2. Optionally load a B1 map.
        3. Build or load the cached dictionary simulation.
        4. SVD-compress the data and apply the k-space position shift.
        5. Run the NUFFT and through-plane FFT.
        6. Coil-combine, mask and pattern-match.
        7. Write T1/T2/M0/WM/GM/CSF TIFFs into the `exports` directory.

    Args:
        filename: path to the Siemens .dat raw data file.
        b1Filename: base name (without .npy) of a B1 map, passed to LoadB1Map.
    """
    multi_twix = twixtools.read_twix(filename)
    twixHeader = multi_twix[-1]['hdr']
    numSpirals = int(twixHeader['Meas']['iNoOfFourierLines']); print(f'Spirals: {numSpirals}')
    numMeasuredPartitions = int(twixHeader['Meas']['iNoOfFourierPartitions']); print(f'Measured Partitions: {numMeasuredPartitions}')
    numUndersampledPartitions = int(twixHeader['MeasYaps']['sKSpace']['lPartitions']); print(f'Undersampled Partitions: {numUndersampledPartitions}')
    centerMeasuredPartition = int(numMeasuredPartitions/2); print(f'Center Measured Partition: {centerMeasuredPartition}') # Fix this to work with partial fourier
    numSets = int(twixHeader['Meas']['iNSet']); print(f'Sets: {numSets}')
    numCoils = int(twixHeader['Meas']['iMaxNoOfRxChannels']); print(f'Coils: {numCoils}')
    xMatSize = twixHeader['MeasYaps']['sKSpace']['lBaseResolution']
    yMatSize = twixHeader['MeasYaps']['sKSpace']['lPhaseEncodingLines']
    zMatSize = twixHeader['MeasYaps']['sKSpace']['lImagesPerSlab']
    matrixSize = np.array([xMatSize, yMatSize, zMatSize]); print(f'Matrix Size: {matrixSize}')
    sliceGeometry = twixHeader['MeasYaps']['sSliceArray']['asSlice'][0]
    xFOV = sliceGeometry['dReadoutFOV']
    yFOV = sliceGeometry['dPhaseFOV']
    zFOV = sliceGeometry['dThickness']

    undersamplingRatio = 1
    if numUndersampledPartitions > 1: # Hack, may not work for multislice 2d
        undersamplingRatio = int(numUndersampledPartitions / (centerMeasuredPartition * 2))
        print(f'Undersampling Ratio: {undersamplingRatio}')
    usePartialFourier = False
    if numMeasuredPartitions*undersamplingRatio < numUndersampledPartitions:
        usePartialFourier = True
        partialFourierRatio = numMeasuredPartitions / (numUndersampledPartitions/undersamplingRatio)
        print(f'Measured partitions is less than expected for undersampling ratio - assuming partial fourier acquisition with ratio: {partialFourierRatio}')

    # Set up sequence parameter arrays
    numTimepoints = numSets*numSpirals
    TRs = np.zeros((numTimepoints, numMeasuredPartitions))
    TEs = np.zeros((numTimepoints, numMeasuredPartitions))
    FAs = np.zeros((numTimepoints, numMeasuredPartitions))
    PHs = np.zeros((numTimepoints, numMeasuredPartitions))
    IDs = np.zeros((numTimepoints, numMeasuredPartitions))

    # Set up raw data and header arrays
    rawdata = None
    header = ismrmrd.xsd.ismrmrdHeader()
    matrixSizeHeader = ismrmrd.xsd.matrixSizeType(xMatSize, yMatSize, zMatSize)
    fovHeader = ismrmrd.xsd.fieldOfViewMm(xFOV, yFOV, zFOV)
    encoding = ismrmrd.xsd.encodingType(reconSpace=ismrmrd.xsd.encodingSpaceType(matrixSize=matrixSizeHeader, fieldOfView_mm=fovHeader))
    header.encoding.append(encoding)
    acqHeaders = np.empty((numUndersampledPartitions, numSpirals, numSets), dtype=ismrmrd.Acquisition)
    discardPre = 0; discardPost = 0

    ## If dictionary simulation is new, upload to Azure so it will exist forever?
    # Spiral trajectory file, density-compensation file and readout length, keyed by in-plane matrix size.
    defaultTrajectory = ('mrf_dependencies/trajectories/SpiralTraj_FOV250_256_uplimit1916_norm.bin',
                         'mrf_dependencies/trajectories/DCW_FOV250_256_uplimit1916.bin',
                         1916)
    knownTrajectories = {
        256: defaultTrajectory,
        400: ('mrf_dependencies/trajectories/SpiralTraj_FOV400_400_uplimit2890_norm.bin',
              'mrf_dependencies/trajectories/DCW_FOV400_400_uplimit2890.bin',
              2890),
    }
    if matrixSize[0] in knownTrajectories:
        trajectoryFilepath, densityFilepath, numToDiscard = knownTrajectories[matrixSize[0]]
    else:
        print('Trajectory unknown, using default')
        trajectoryFilepath, densityFilepath, numToDiscard = defaultTrajectory

    # Process data as it comes in
    for mdb in multi_twix[-1]['mdb']:
        if mdb is None:
            break
        if mdb.is_flag_set('NOISEADJSCAN') or mdb.is_flag_set('PHASCOR'):
            print('Noise')
            continue
        acqHeader = ismrmrd.Acquisition()
        acqHeader.position[0] = mdb.mdh.SliceData.SlicePos.Sag
        acqHeader.position[1] = mdb.mdh.SliceData.SlicePos.Cor
        acqHeader.position[2] = mdb.mdh.SliceData.SlicePos.Tra
        # Direction cosines from the slice quaternion (a, b, c, d).
        quat = mdb.mdh.SliceData.Quaternion
        a = quat[0]; b = quat[1]; c = quat[2]; d = quat[3]

        acqHeader.read_dir[0] = 1.0 - 2.0 * (b * b + c * c)
        acqHeader.phase_dir[0] = 2.0 * (a * b - c * d)
        acqHeader.slice_dir[0] = 2.0 * (a * c + b * d)

        acqHeader.read_dir[1] = 2.0 * (a * b + c * d)
        acqHeader.phase_dir[1] = 1.0 - 2.0 * (a * a + c * c)
        acqHeader.slice_dir[1] = 2.0 * (b * c - a * d)

        acqHeader.read_dir[2] = 2.0 * (a * c - b * d)
        acqHeader.phase_dir[2] = 2.0 * (b * c + a * d)
        acqHeader.slice_dir[2] = 1.0 - 2.0 * (a * a + b * b)
        measuredPartition = mdb.mdh.Counter.Par
        undersampledPartition = mdb.mdh.IceProgramPara[1]
        spiral = mdb.mdh.Counter.Lin
        setIndex = mdb.mdh.Counter.Set
        timepoint = mdb.mdh.IceProgramPara[0]
        TRs[timepoint, measuredPartition] = mdb.mdh.IceProgramPara[2]
        TEs[timepoint, measuredPartition] = mdb.mdh.IceProgramPara[3]
        FAs[timepoint, measuredPartition] = mdb.mdh.IceProgramPara[4]  # Use requested FA
        PHs[timepoint, measuredPartition] = mdb.mdh.IceProgramPara[6]
        IDs[timepoint, measuredPartition] = spiral
        acqHeaders[undersampledPartition, spiral, setIndex] = acqHeader
        if rawdata is None:
            discardPre = int(mdb.mdh.CutOff.Pre / 2); print(f'Discard Pre: {discardPre}') # Fix doubling in sequence - weird;
            discardPost = discardPre + numToDiscard; print(f'Discard Post: {discardPost}') # Fix in sequence
            numReadoutPoints = discardPost-discardPre; print(f'Readout Points: {numReadoutPoints}')
            rawdata = np.zeros([numCoils, numUndersampledPartitions, numReadoutPoints, numSpirals, numSets], dtype=np.complex64)
        readout = mdb.data[:, discardPre:discardPost]
        rawdata[:, undersampledPartition, :, spiral, setIndex] = readout

    B1map = LoadB1Map(matrixSize, b1Filename)
    if np.size(B1map) != 0:
        B1map_binned = performB1Binning(B1map, b1Range, b1Stepsize, b1IdentityValue=800)
        dictionary = DictionaryParameters.GenerateFixedPercent(dictionaryName, percentStepSize=percentStepSize, t1Range=t1Range, t2Range=t2Range, includeB1=True, b1Range=b1Range, b1Stepsize=b1Stepsize)
    else:
        B1map_binned = None
        dictionary = DictionaryParameters.GenerateFixedPercent(dictionaryName, percentStepSize=percentStepSize, t1Range=t1Range, t2Range=t2Range, includeB1=False, b1Range=None, b1Stepsize=None)

    ## Initialize the Sequence
    sequence = SequenceParameters("largescale", SequenceType.FISP)
    print("TRs:", np.min(TRs), np.max(TRs))
    print("TEs:", np.min(TEs), np.max(TEs))
    print("FAs:", np.min(FAs), np.max(FAs))
    print("IDs:", np.min(IDs), np.max(IDs))
    sequence.Initialize(TRs[:,0]/(1000*1000), TEs[:,0]/(1000*1000), FAs[:,0]/(100), PHs[:,0]/(100), IDs[:,0])
    simulation = Simulation(sequence, dictionary, phaseRange=phaseRange, numSpins=numSpins)
    # The cache file is keyed by a hash of the pickled (not yet executed) simulation definition.
    simulationHash = hashlib.sha256(pickle.dumps(simulation)).hexdigest()
    if dictionaryFolder == "":
        dictionaryPath = simulationHash+".simulation"
    else:
        dictionaryPath = dictionaryFolder+"/"+simulationHash+".simulation"
    logging.info(f"Dictionary Path: {dictionaryPath}")
    Path(dictionaryFolder).mkdir(parents=True, exist_ok=True)

    ## Check if dictionary already exists
    if os.path.isfile(dictionaryPath):
        logging.info("Dictionary already exists. Using local copy.")
        # NOTE(review): pickle.load can execute arbitrary code; only load cache files this pipeline wrote.
        with open(dictionaryPath, 'rb') as filehandler:
            simulation = pickle.load(filehandler)
    else:
        ## Simulate the Dictionary
        logging.info("Dictionary not found. Simulating. ")
        simulation.Execute(numBatches=numBatches)
        simulation.CalculateSVD(truncationNumberOverride=10)
        logging.info(f"Simulated {numSpirals*numSets} timepoints")
        del simulation.results
        with open(dictionaryPath, 'wb') as filehandler:
            pickle.dump(simulation, filehandler)

    ## Run the Reconstruction
    svdData = ApplySVDCompression(rawdata, simulation, device=torch.device("cpu"))
    (trajectoryBuffer,trajectories,densityBuffer,_) = LoadSpirals(trajectoryFilepath, densityFilepath, numSpirals)
    svdData = ApplyXYZShift(svdData, header, acqHeaders, trajectories)
    nufftResults = PerformNUFFTs(svdData, trajectoryBuffer, densityBuffer, matrixSize, matrixSize*2)
    del svdData
    coilImageData = ThroughplaneFFT(nufftResults)
    del nufftResults
    imageData, coilmaps = PerformWalshCoilCombination(coilImageData)
    imageMask = GenerateRadialMask(coilImageData, coilCountCutoff=0)
    patternMatchResults, interpolatedResults, M0 = PatternMatchingViaMaxInnerProductWithInterpolation(imageData, dictionary, simulation, b1Binned = B1map_binned, voxelsPerBatch=2000)
    (wmFractionMap, gmFractionMap, csfFractionMap) = GenerateClassificationMaps(imageData, dictionary, simulation, matrixSize)

    import imageio
    saveDir = "exports"
    os.makedirs(saveDir, exist_ok=True)
    inMask = imageMask > 0.1
    T1map_interp = AddText(inMask * interpolatedResults['T1'] * 1000) # to milliseconds
    T2map_interp = AddText(inMask * interpolatedResults['T2'] * 1000) # to milliseconds
    M0map = AddText(inMask * (np.abs(M0) / np.max(np.abs(M0))) * 2**12)
    WMmap = AddText(inMask * wmFractionMap)
    GMmap = AddText(inMask * gmFractionMap)
    CSFmap = AddText(inMask * csfFractionMap)
    # File-name prefix -> volume; written per slice in this order.
    exportMaps = {"t1_": T1map_interp, "t2_": T2map_interp, "m0_": M0map,
                  "wm_": WMmap, "gm_": GMmap, "csf_": CSFmap}
    for sliceIndex in range(np.shape(patternMatchResults)[2]):
        for prefix, volume in exportMaps.items():
            # Join with the directory; plain concatenation wrote "exportst1_0.tif" into the working directory.
            imageio.imwrite(os.path.join(saveDir, prefix + str(sliceIndex) + '.tif'), volume[:,:,sliceIndex])

root = tk.Tk()
def select_file():
    """Ask for a Siemens .dat file, reconstruct it, then close the window.

    Cancelling the dialog leaves the window open so another file can be chosen.
    """
    filetypes = (
        ('Siemens Raw Data', '*.dat'),
        ('All files', '*.*')
    )

    filename = fd.askopenfilename(
        title='Open a file',
        initialdir='/',
        filetypes=filetypes)

    # The dialog returns an empty value when cancelled; don't start a reconstruction on it.
    if not filename:
        return

    showinfo(
        title='',
        message="Beginning Reconstruction: " + filename
    )

    try:
        runReconstruction(filename)
    finally:
        # Close the window (ending mainloop) even if the reconstruction raised.
        root.destroy()


# create the root window
root.title('Open File Dialog')
root.resizable(False, False)
root.geometry('300x150')

# open button
open_button = ttk.Button(
    root,
    text='Select .dat File',
    command=select_file
)
open_button.pack(expand=True)

# run the application
root.mainloop()


100%|██████████| 18.2M/18.2M [00:00<00:00, 373MB/s]

Software version: VD/VE (!?)

Scan  0



  0%|          | 37.1M/8.92G [00:00<00:24, 387MB/s]

Scan  1


100%|██████████| 8.92G/8.92G [00:22<00:00, 425MB/s]


Spirals: 48
Measured Partitions: 30
Undersampled Partitions: 60
Center Measured Partition: 15
Sets: 20
Coils: 18
Matrix Size: [256 256  60]
Undersampling Ratio: 2
Discard Pre: 20
Discard Post: 1936
Readout Points: 1916
Dictionary Parameter set '5pct' initialized with 12504 entries
TRs: 0.0 10500.0
TEs: 1680.0 1680.0
FAs: 0.0 5700.0
IDs: 0.0 47.0
Sequence Parameter set 'largescale' initialized with 960 timepoint definitions


100%|██████████| 48/48 [00:10<00:00,  4.51it/s]


SVD Compressed Raw Data Shape: torch.Size([10, 18, 60, 1916, 48])
Found (48, 1916) spirals


: 

: 

In [2]:
import imageio