In [None]:
import os
import random
import time

import numpy as np
import scipy.io as sio
import torch
import torch.optim as optim
from torch.utils import data

from loader.COSMOS_data_loader import COSMOS_data_loader
from models.unet import Unet
from utils.train import BayesianQSM_train
from utils.medi import *
from utils.data import *
from utils.files import *

In [None]:
# Pin the notebook to GPU index 1. CUDA reads this variable when it is first
# initialised in the kernel, so it typically has no effect if a CUDA call has
# already happened (e.g. when re-running this cell mid-session).
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory containing the trained checkpoint ('weights.pt'). Set the
# BAYESIAN_QSM_WEIGHT_DIR environment variable to point elsewhere instead of
# editing this hard-coded default.
rootDir = os.environ.get('BAYESIAN_QSM_WEIGHT_DIR', '/data/Jinwei/Bayesian_QSM/weight')

In [None]:
# Patch geometry shared by the data loader and the reconstruction step:
# cubic patches of 64 elements per axis, extracted with a stride of 21 along
# each of the three axes (so neighbouring patches overlap).
patchSize = (64,) * 3
extraction_step = (21,) * 3

In [None]:
# Test-split dataset. The case_validation / case_test / test_dir arguments pick
# which data the loader serves; their exact meaning is defined in
# COSMOS_data_loader. Note: despite its name, `dataLoader` is the dataset that
# the torch DataLoader below wraps (later cells read `dataLoader.volSize`).
dataLoader = COSMOS_data_loader(
    patchSize=patchSize,
    extraction_step=extraction_step,
    split='Test',
    case_validation=7,
    case_test=6,
    test_dir=2,
)

# One patch per batch, in dataset order; presumably the patch-to-volume
# reconstruction later on depends on that order, so keep shuffle=False.
testLoader = data.DataLoader(dataset=dataLoader, batch_size=1, shuffle=False)

In [None]:
# 3D U-Net with one input channel and two output channels; the inference cell
# reads channel 0 as the mean and channel 1 as the std map.
unet3d = Unet(input_channels=1, output_channels=2, num_filters=[2**i for i in range(5, 10)])
unet3d.to(device)
# Inference only: switch dropout / batch-norm layers (if Unet has any) to
# their evaluation behaviour.
unet3d.eval()
# map_location lets a checkpoint that was saved from a GPU load on the
# CPU fallback device chosen in the config cell.
unet3d.load_state_dict(torch.load(os.path.join(rootDir, 'weights.pt'), map_location=device))

In [None]:
# Run the network over every test patch and collect its two output channels:
# channel 0 -> patches_means, channel 1 -> patches_stds.
# - One forward pass per batch: both channels come from the same prediction
#   (the original ran the network twice per batch).
# - torch.no_grad(): no autograd graph is needed for inference, which saves
#   GPU memory and time.
patches_means, patches_stds = [], []
with torch.no_grad():
    for rdfs, _, _, _ in testLoader:  # masks / weights / qsms are unused here
        outputs = unet3d(rdfs.to(device))
        patches_means.append(outputs[:, 0, ...].cpu().numpy())
        patches_stds.append(outputs[:, 1, ...].cpu().numpy())

# Stack per-batch arrays along the batch axis -> one array per output channel.
patches_means = np.concatenate(patches_means, axis=0)
patches_stds = np.concatenate(patches_stds, axis=0)

In [None]:
# Reassemble the overlapping patch predictions into full volumes of size
# dataLoader.volSize, using the same extraction stride as the loader. How
# overlaps are combined is up to reconstruct_patches (utils) — presumably an
# average; confirm there.
QSM, STD = (
    reconstruct_patches(patch_stack, dataLoader.volSize, extraction_step)
    for patch_stack in (patches_means, patches_stds)
)

In [None]:
# Write each reconstructed volume to its own MATLAB file in the working
# directory, stored under a variable named after the file:
# QSM.mat holds variable 'QSM', STD.mat holds variable 'STD'.
for var_name, volume in (('QSM', QSM), ('STD', STD)):
    sio.savemat(var_name + '.mat', {var_name: volume})