In [1]:
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from torchmetrics import StructuralSimilarityIndexMeasure
from statistics import median, mean
from matplotlib import pyplot as plt
import numpy as np
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
from torchsummary import summary
import json
from tqdm import tqdm
from glob import glob
import sys
path = "/study/mrphys/skunkworks/kk/mriUnet"
sys.path.insert(0,path)
import unet
from torchvision import transforms
from torch.utils.data import Dataset
import h5py
from sklearn.model_selection import KFold as kf
import os

allImages = sorted(glob("/study/mrphys/skunkworks/training_data//mover01/*/", recursive=True))

In [2]:
def slice2d(array, discardZero=False):
    '''
    slice a 4d array of shape (c-channel, n, n, n) where n in the cube length
    into 3d arrays slices of shape (c, n, n) per each 2d plane
    '''
    result = []
    c, w, h, d = array.shape
    assert (w==h)and(h==d)and(d==w), f"Array must be cubic, got: {w}x{h}x{d}"
    for i in range(w):
        result.append(array[:,i,:,:])
        result.append(array[:,:,i,:])
        result.append(array[:,:,:,i])
    return np.array(result)

In [4]:
def getComplexSlices(path, return_scale=False):

    with h5py.File(path,'r') as hf:
        prefix = 'C_000_0'
        imagestackReal = []
        imagestackImag = []
        for i in range(10):
            n = prefix + str(i).zfill(2)
            image = hf['Images'][n]
            imagestackReal.append(np.array(image['real']))
            imagestackImag.append(np.array(image['imag']))
            if i==0:
                normScale = np.abs(np.array(image['real']+image['real']*1j)).max()
        imagestackReal = np.array(imagestackReal)/normScale
        imagestackImag = np.array(imagestackImag)/normScale
        imagesliceReal = slice2d(imagestackReal)
        imagesliceImag = slice2d(imagestackImag)
        
    if return_scale:
        return imagesliceReal+imagesliceImag*1j, normScale
    else:
        return imagesliceReal+imagesliceImag*1j

In [5]:
class mriSliceDataset(Dataset):
    def __init__(self, sample):
        self.originalPathList = []
        self.accelPathList = []
        self.originalFileList = []
        self.accelFileList = []

        allImages = sorted(glob("/study/mrphys/skunkworks/training_data//mover01/*/", recursive=True))
        folderName  = allImages[sample]
        self.originalPathList.append(folderName + 'processed_data/C.h5')
        self.accelPathList.append(folderName +'processed_data/acc_2min/C.h5')
        
        for originalPath, accelPath in zip(self.originalPathList, self.accelPathList):
            self.originalFileList+= list(getComplexSlices(originalPath))
            self.accelFileList+= list(getComplexSlices(accelPath))
            print('Image ' + originalPath + ' loaded')

    def __getitem__(self, index):
        return self.accelFileList[index], self.originalFileList[index]

    def __len__(self):
        return len(self.accelFileList)

In [6]:
l

  0%|                                                                                            | 0/65 [00:00<?, ?it/s]

Image /study/mrphys/skunkworks/training_data//mover01/M001/processed_data/C.h5 loaded


  2%|█▎                                                                                  | 1/65 [00:25<27:07, 25.43s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M002/processed_data/C.h5 loaded


  3%|██▌                                                                                 | 2/65 [00:51<26:52, 25.60s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M004/processed_data/C.h5 loaded


  5%|███▉                                                                                | 3/65 [01:16<26:14, 25.40s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M005/processed_data/C.h5 loaded


  6%|█████▏                                                                              | 4/65 [01:41<25:54, 25.49s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M006/processed_data/C.h5 loaded


  8%|██████▍                                                                             | 5/65 [02:06<25:19, 25.32s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M007/processed_data/C.h5 loaded


  9%|███████▊                                                                            | 6/65 [02:31<24:43, 25.15s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M008/processed_data/C.h5 loaded


 11%|█████████                                                                           | 7/65 [02:57<24:29, 25.34s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M009/processed_data/C.h5 loaded


 12%|██████████▎                                                                         | 8/65 [03:25<24:53, 26.19s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M010/processed_data/C.h5 loaded


 14%|███████████▋                                                                        | 9/65 [03:51<24:28, 26.23s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M011/processed_data/C.h5 loaded


 15%|████████████▊                                                                      | 10/65 [04:18<24:05, 26.29s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M012/processed_data/C.h5 loaded


 17%|██████████████                                                                     | 11/65 [04:44<23:39, 26.28s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M013/processed_data/C.h5 loaded


 18%|███████████████▎                                                                   | 12/65 [05:10<23:11, 26.25s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M014/processed_data/C.h5 loaded


 20%|████████████████▌                                                                  | 13/65 [05:37<22:48, 26.31s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M015/processed_data/C.h5 loaded


 22%|█████████████████▉                                                                 | 14/65 [06:03<22:20, 26.28s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M016/processed_data/C.h5 loaded


 23%|███████████████████▏                                                               | 15/65 [06:29<21:49, 26.19s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M017/processed_data/C.h5 loaded


 25%|████████████████████▍                                                              | 16/65 [06:55<21:20, 26.13s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M018/processed_data/C.h5 loaded


 26%|█████████████████████▋                                                             | 17/65 [07:22<21:05, 26.35s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M019/processed_data/C.h5 loaded


 28%|██████████████████████▉                                                            | 18/65 [07:50<21:08, 26.99s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M020/processed_data/C.h5 loaded


 29%|████████████████████████▎                                                          | 19/65 [08:17<20:45, 27.07s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M022/processed_data/C.h5 loaded


 31%|█████████████████████████▌                                                         | 20/65 [08:43<20:04, 26.76s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M023/processed_data/C.h5 loaded


 32%|██████████████████████████▊                                                        | 21/65 [09:11<19:50, 27.05s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M024/processed_data/C.h5 loaded


 34%|████████████████████████████                                                       | 22/65 [09:38<19:23, 27.06s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M025/processed_data/C.h5 loaded


 35%|█████████████████████████████▎                                                     | 23/65 [10:05<18:53, 26.99s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M026/processed_data/C.h5 loaded


 37%|██████████████████████████████▋                                                    | 24/65 [10:32<18:26, 26.99s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M027/processed_data/C.h5 loaded


 38%|███████████████████████████████▉                                                   | 25/65 [11:00<18:13, 27.33s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M028/processed_data/C.h5 loaded


 40%|█████████████████████████████████▏                                                 | 26/65 [11:27<17:37, 27.10s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M029/processed_data/C.h5 loaded


 42%|██████████████████████████████████▍                                                | 27/65 [11:54<17:11, 27.15s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M030/processed_data/C.h5 loaded


 43%|███████████████████████████████████▊                                               | 28/65 [12:20<16:35, 26.90s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M031/processed_data/C.h5 loaded


 45%|█████████████████████████████████████                                              | 29/65 [12:49<16:24, 27.36s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M032/processed_data/C.h5 loaded


 46%|██████████████████████████████████████▎                                            | 30/65 [13:16<15:50, 27.16s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M033/processed_data/C.h5 loaded


 48%|███████████████████████████████████████▌                                           | 31/65 [13:45<15:42, 27.73s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M034/processed_data/C.h5 loaded


 49%|████████████████████████████████████████▊                                          | 32/65 [14:13<15:25, 28.06s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M035/processed_data/C.h5 loaded


 51%|██████████████████████████████████████████▏                                        | 33/65 [14:45<15:29, 29.06s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M036/processed_data/C.h5 loaded


 52%|███████████████████████████████████████████▍                                       | 34/65 [15:17<15:30, 30.02s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M037/processed_data/C.h5 loaded


 54%|████████████████████████████████████████████▋                                      | 35/65 [15:51<15:38, 31.27s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M040/processed_data/C.h5 loaded


 55%|█████████████████████████████████████████████▉                                     | 36/65 [16:27<15:44, 32.56s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M041/processed_data/C.h5 loaded


 57%|███████████████████████████████████████████████▏                                   | 37/65 [16:56<14:43, 31.54s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M042/processed_data/C.h5 loaded


 58%|████████████████████████████████████████████████▌                                  | 38/65 [17:27<14:08, 31.41s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M043/processed_data/C.h5 loaded


 60%|█████████████████████████████████████████████████▊                                 | 39/65 [17:55<13:12, 30.50s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M044/processed_data/C.h5 loaded


 62%|███████████████████████████████████████████████████                                | 40/65 [18:22<12:16, 29.46s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M045/processed_data/C.h5 loaded


 63%|████████████████████████████████████████████████████▎                              | 41/65 [18:50<11:33, 28.88s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M046/processed_data/C.h5 loaded


 65%|█████████████████████████████████████████████████████▋                             | 42/65 [19:20<11:11, 29.20s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M047/processed_data/C.h5 loaded


 66%|██████████████████████████████████████████████████████▉                            | 43/65 [19:46<10:24, 28.40s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M048/processed_data/C.h5 loaded


 68%|████████████████████████████████████████████████████████▏                          | 44/65 [20:14<09:48, 28.01s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M049/processed_data/C.h5 loaded


 69%|█████████████████████████████████████████████████████████▍                         | 45/65 [20:40<09:11, 27.56s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M050/processed_data/C.h5 loaded


 71%|██████████████████████████████████████████████████████████▋                        | 46/65 [21:07<08:41, 27.45s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M051/processed_data/C.h5 loaded


 72%|████████████████████████████████████████████████████████████                       | 47/65 [21:34<08:12, 27.33s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M052/processed_data/C.h5 loaded


 74%|█████████████████████████████████████████████████████████████▎                     | 48/65 [22:01<07:43, 27.28s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M053/processed_data/C.h5 loaded


 75%|██████████████████████████████████████████████████████████████▌                    | 49/65 [22:29<07:18, 27.39s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M054/processed_data/C.h5 loaded


 77%|███████████████████████████████████████████████████████████████▊                   | 50/65 [22:56<06:46, 27.13s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M055/processed_data/C.h5 loaded


 78%|█████████████████████████████████████████████████████████████████                  | 51/65 [23:26<06:31, 27.99s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M056/processed_data/C.h5 loaded


 80%|██████████████████████████████████████████████████████████████████▍                | 52/65 [23:54<06:04, 28.02s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M057/processed_data/C.h5 loaded


 82%|███████████████████████████████████████████████████████████████████▋               | 53/65 [24:23<05:42, 28.51s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M058/processed_data/C.h5 loaded


 83%|████████████████████████████████████████████████████████████████████▉              | 54/65 [24:53<05:16, 28.76s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M059/processed_data/C.h5 loaded


 85%|██████████████████████████████████████████████████████████████████████▏            | 55/65 [25:23<04:51, 29.15s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M060/processed_data/C.h5 loaded


 86%|███████████████████████████████████████████████████████████████████████▌           | 56/65 [25:52<04:22, 29.18s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M061/processed_data/C.h5 loaded


 88%|████████████████████████████████████████████████████████████████████████▊          | 57/65 [26:22<03:56, 29.53s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M062/processed_data/C.h5 loaded


 89%|██████████████████████████████████████████████████████████████████████████         | 58/65 [26:54<03:30, 30.12s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M063/processed_data/C.h5 loaded


 91%|███████████████████████████████████████████████████████████████████████████▎       | 59/65 [27:25<03:03, 30.53s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M064/processed_data/C.h5 loaded


 92%|████████████████████████████████████████████████████████████████████████████▌      | 60/65 [27:55<02:30, 30.13s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M066/processed_data/C.h5 loaded


 94%|█████████████████████████████████████████████████████████████████████████████▉     | 61/65 [28:24<01:59, 29.85s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M067/processed_data/C.h5 loaded


 95%|███████████████████████████████████████████████████████████████████████████████▏   | 62/65 [28:52<01:28, 29.39s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M068/processed_data/C.h5 loaded


 97%|████████████████████████████████████████████████████████████████████████████████▍  | 63/65 [29:22<00:59, 29.52s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M069/processed_data/C.h5 loaded


 98%|█████████████████████████████████████████████████████████████████████████████████▋ | 64/65 [29:51<00:29, 29.26s/it]

Image /study/mrphys/skunkworks/training_data//mover01/M070/processed_data/C.h5 loaded


100%|███████████████████████████████████████████████████████████████████████████████████| 65/65 [30:20<00:00, 28.01s/it]


In [None]:
class noisyDataset(Dataset):
    def __init__(self, sample):
        self.accelPathList = []
        self.accelFileList = []

        allImages = sorted(glob("/study/mrphys/skunkworks/training_data//mover01/*/", recursive=True))
        folderName  = allImages[sample]
        self.accelPathList.append(folderName +'processed_data/acc_2min/C.h5')
        
        for accelPath in self.accelPathList:
            slices, scale = getComplexSlices(accelPath, return_scale=True)
            self.accelFileList+= list(slices)
            self.scale = scale
            print('Image ' + accelPath + ' loaded')

    def __getitem__(self, index):
        return self.accelFileList[index]

    def __len__(self):
        return len(self.accelFileList)
    
def predict(model, dataset, device = torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    model.eval()
    model.to(device)
    X = []
    Y = []
    Z = []
    for i, noisy in tqdm(enumerate(dataset)):
        noisy = torch.tensor(noisy).to(device).unsqueeze(0)
        with torch.no_grad():
            p = model(noisy).cpu().numpy() * dataset.scale
            if i%3==0:
                X.append(p)
            elif i%3==1:
                Y.append(p)
            else:
                Z.append(p)
                
    return np.vstack(X).transpose(1,0,2,3), np.vstack(Y).transpose(1,2,0,3), np.vstack(Z).transpose(1,2,3,0)

def get_prediction(idx, fold):
    
    model = unet.UNet(6,
    6,
    f_maps=32,
    layer_order=['separable convolution', 'relu'],
    depth=3,
    layer_growth=2.0,
    residual=True,
    complex_input=True,
    complex_kernel=True,
    ndims=2,
    padding=1)

    name = f'slice_kfold_{fold}'

    model.load_state_dict(torch.load(f'/study/mrphys/skunkworks/kk/outputs/{name}/weights/{name}_LATEST.pth'))

    dataset = noisyDataset(idx)
    X, Y, Z = predict(model, dataset)
    
    return (X+Y+Z)/3

In [None]:
pred = get_prediction(4, 1)

In [None]:
plt.gray()
n = 130
plt.title("Top-down")
plt.imshow(pred[0][n,:,:].imag)
plt.show()
plt.title("Left-right")
plt.imshow(pred[0][:,n,:].imag)
plt.show()
plt.title("Front-back")
plt.imshow(pred[0][:,:,n].imag)
plt.show()

In [None]:
kfsplitter = kf(n_splits=5, shuffle=True, random_state=69420)
for i, (train_index, test_index) in enumerate(kfsplitter.split(np.arange(65))):
    for idx in test_index:
        print(f'Fold = {i+1}')
        pred = get_prediction(idx, i+1)
        np.save(f'pred/denoised_{idx}.npy', pred)

In [None]:
path = '/study/mrphys/skunkworks/training_data//mover01/M001/processed_data/acc_2min/C.h5'

with h5py.File(path,'r') as hf:
    prefix = 'C_000_0'
    imagestackReal = []
    imagestackImag = []
    for i in range(6):
        n = prefix + str(i).zfill(2)
        image = hf['Images'][n]
        imagestackReal.append(np.array(image['real']))
        imagestackImag.append(np.array(image['imag']))
        if i==0:
            normScale = np.max([np.abs(np.array(image['real'])).max(), np.abs(np.array(image['real'])).max()])
    imagestackReal = np.array(imagestackReal)/normScale
    imagestackImag = np.array(imagestackImag)/normScale
    x = imagestackReal+imagestackImag*1j
    
plt.imshow(x[0,130].real)

# Save to H5

In [None]:
import h5py
import numpy as np

imgIndex = 0
name = allImages[imgIndex].split('/')[-2]
with h5py.File(f'/scratch/mrphys/denoised/comparison_{name}.h5','w') as f:
    grp = f.create_group('Original')
    with h5py.File(allImages[imgIndex]+'processed_data/C.h5','r') as hfOriginal:
        for n in range(6):
            n = 'C_000_0'+ str(n).zfill(2)
            grp.create_dataset(n, data=np.array(hfOriginal['Images'][n]))
        
    grp = f.create_group('Noisy')
    with h5py.File(allImages[imgIndex]+'processed_data/acc_2min/C.h5','r') as hfNoisy:
        for n in range(6):
            n = 'C_000_0'+ str(n).zfill(2)
            grp.create_dataset(n, data=np.array(hfNoisy['Images'][n]))

    pred = np.load(f'pred/denoised_{0}.npy')
    temp = pred.astype(np.dtype([('real','f'),('imag','f')]))
    temp['imag'] = pred.imag
    pred = temp
    grp = f.create_group('Denoised')
    for n in range(6):
        grp.create_dataset('C_000_0'+ str(n).zfill(2), data=pred[n])

In [None]:
with h5py.File(f'/scratch/mrphys/denoised/comparison_{name}.h5','r') as f:
    n = 0
    n = 'C_000_0'+ str(n).zfill(2)
    plt.imshow(f['Denoised'][n]['real'][130])
    plt.show()
    plt.imshow(f['Original'][n]['real'][130])
    plt.show()
    plt.imshow(f['Noisy'][n]['real'][130])
    plt.show()

# Save using the same format

In [7]:
import h5py
import numpy as np
import os
from glob import glob
from tqdm import tqdm

In [8]:
allImages = sorted(glob("/study/mrphys/skunkworks/training_data//mover01/*/", recursive=True))
try:
    os.mkdir('/scratch/mrphys/denoised')
except:
    pass

for i, imgIndex in tqdm(enumerate(range(len(allImages)))):
    name = allImages[imgIndex].split('/')[-2]
    with h5py.File(f'/scratch/mrphys/denoised/denoised_{name}.h5','w') as f:
        pred = np.load(f'pred/denoised_{i}.npy')
        temp = pred.astype(np.dtype([('real','f'),('imag','f')]))
        temp['imag'] = pred.imag
        pred = temp
        grp = f.create_group('Images')
        for n in range(6):
            grp.create_dataset('C_000_0'+ str(n).zfill(2), data=pred[n])

  temp = pred.astype(np.dtype([('real','f'),('imag','f')]))
65it [02:22,  2.19s/it]


In [None]:
with h5py.File(path,'r') as f:
    print(f.keys())
    print(f['Images']['C_000_001'])

In [None]:
with h5py.File(f'/scratch/mrphys/skunkworks/denoised/denoised_{name}.h5','r') as f:
    print(f.keys())
    print(f['Images']['C_000_001'])

In [None]:
import torch.distributed as dist
import torch
import os

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '6969'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()
    
n_gpus = torch.cuda.device_count()
rank = 4
print(rank)
setup(rank, n_gpus)
rank = dist.get_rank()
print(rank)

In [5]:
import pickle
from torch.utils.data import Dataset
class mriSliceDataset(Dataset):
    def __init__(self, sample):
        self.originalPathList = []
        self.accelPathList = []
        self.originalFileList = []
        self.accelFileList = []

        allImages = sorted(glob("/study/mrphys/skunkworks/training_data//mover01/*/", recursive=True))
        folderName  = allImages[sample]
        self.originalPathList.append(folderName + 'processed_data/C.h5')
        self.accelPathList.append(folderName +'processed_data/acc_2min/C.h5')
        
        for originalPath, accelPath in zip(self.originalPathList, self.accelPathList):
            self.originalFileList+= list(getComplexSlices(originalPath))
            self.accelFileList+= list(getComplexSlices(accelPath))
            print('Image ' + originalPath + ' loaded')

    def __getitem__(self, index):
        return self.accelFileList[index], self.originalFileList[index]

    def __len__(self):
        return len(self.accelFileList)
with open(f'/scratch/mrphys/pickled/dataset_{0}.pickle', 'rb') as f:
    data = pickle.load(f)

In [6]:
for X, Y in data:
    print(X.shape)

(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 256)
(16, 256, 

In [3]:
import sys
sys.path.insert(0,"/study3/mrphys/skunkworks/kk/mriUnet")
import unet
import torch
from torch import nn
model = unet.PatchGAN(
    14,
    f_maps=32,
    layer_order=['separable convolution', 'relu', 'batch norm'],
    depth=4,
    layer_growth=2.0,
    residual=True,
    complex_input=False,
    complex_kernel=False,
    ndims=2,
    padding=1
)

In [4]:
device = "cpu" #torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
inp = torch.randn(1,7,256,256)
inp = inp.to(device)
pred = model(inp, inp)

In [4]:
pred

tensor([[[[-2.7383e-01, -3.1528e-01, -1.0539e+00,  ..., -2.9841e-01,
           -4.1028e-02, -8.5437e-02],
          [ 1.3807e-01,  3.1959e-01, -1.9353e-02,  ...,  6.0839e-02,
           -5.8703e-01,  1.2111e-04],
          [-8.2378e-01,  5.7863e-02,  1.4145e+00,  ..., -2.4023e-02,
           -1.4754e-01, -6.0467e-01],
          ...,
          [-1.1502e-01, -2.8118e-01,  6.2925e-02,  ...,  2.1677e-01,
            6.0013e-01, -7.3894e-01],
          [ 2.2050e-02,  7.1407e-01,  5.9331e-02,  ...,  1.6346e-01,
           -6.3900e-01, -1.0993e+00],
          [-2.7264e-01,  2.7381e-02, -8.0088e-01,  ...,  3.9469e-01,
           -4.2766e-01, -4.0391e-01]]]], grad_fn=<ConvolutionBackward0>)

In [6]:
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
y = inp
ssim_loss = (1-ms_ssim(pred[:,:,:,:,0], y.real, data_range=1, size_average=False)).mean() + (1-ms_ssim(pred[:,:,:,:,1], y.imag, data_range=1, size_average=False)).mean()

In [7]:
!nvidia-smi

Tue Mar 21 13:12:57 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Quadro RTX 8000                 Off| 00000000:01:00.0 Off |                  Off |
| 43%   69C    P2              256W / 260W|  42793MiB / 49152MiB |     98%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Quadro RTX 8000                 Off| 00000000:24:00.0 Off |  