In [1]:
import argparse

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including "False" and "0" -- becomes True.  This
    converter maps the usual spellings explicitly and rejects anything else.
    The empty string maps to False, matching what ``bool("")`` returned.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def _int_tuple(value):
    """Convert a string such as "16,16,16" or "(16, 16, 16)" into a tuple of ints.

    argparse's ``type=tuple`` iterates over the characters of the raw string,
    which turns "16,16,16" into ('1', '6', ',', ...).
    """
    if isinstance(value, (tuple, list)):
        return tuple(int(item) for item in value)
    parts = [part.strip() for part in value.strip().strip("()[]").split(",")]
    parts = [part for part in parts if part]
    if not parts:
        raise argparse.ArgumentTypeError("expected comma-separated integers, got %r" % (value,))
    try:
        return tuple(int(part) for part in parts)
    except ValueError:
        raise argparse.ArgumentTypeError(
            "expected comma-separated integers, got %r" % (value,)) from None


parser = argparse.ArgumentParser(description="DTI_ARB")

# Data specifications
parser.add_argument("--block_size", type=_int_tuple, default=(16,16,16),
                    help="Block Size")
parser.add_argument("--test_block_size", type=_int_tuple, default=(16,16,16),
                    help="Block Size")
parser.add_argument("--crop_depth", type=int, default=15,
                    help="crop across z-axis")
parser.add_argument("--dir", type=str,
                    help="dataset_directory")
parser.add_argument("--batch_size", type=int,
                    help="Batch_size")
parser.add_argument("--sort", type=_str2bool,
                    help="Sort Subject Ids")
parser.add_argument("--debug", type=_str2bool,
                    help="Print additional input")
parser.add_argument("--preload", type=_str2bool,
                    help="Preload data into memory")
parser.add_argument("--ret_points", type=_str2bool, default=False,
                    help="return box point of crops")
parser.add_argument("--thres", type=float, default=0.6,
                    help="threshold for blk emptiness")
parser.add_argument("--offset", type=int, default=20,
                    help="epoch with scale (1,1,1)")
parser.add_argument("--gaps", type=int, default=20,
                    help="number of epochs of gap between each scale change")


# Optimization specifications
parser.add_argument('--lr', type=float, default=1e-4,
                    help='learning rate')
parser.add_argument('--lr_decay', type=int, default=40,
                    help='learning rate decay per N epochs')
parser.add_argument('--decay_type', type=str, default='step',
                    help='learning rate decay type')
parser.add_argument('--gamma', type=float, default=0.5,
                    help='learning rate decay factor for step decay')
parser.add_argument('--optimizer', default='ADAM',
                    choices=('SGD', 'ADAM', 'RMSprop'),
                    help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum')
parser.add_argument('--beta1', type=float, default=0.9,
                    help='ADAM beta1')
parser.add_argument('--beta2', type=float, default=0.999,
                    help='ADAM beta2')
parser.add_argument('--epsilon', type=float, default=1e-8,
                    help='ADAM epsilon for numerical stability')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='weight decay')
parser.add_argument('--start_epoch', type=int, default=0,
                    help='resume from the snapshot, and the start_epoch')

# Loss specifications
parser.add_argument('--loss', type=str, default='1*MSE',
                    help='loss function configuration')
# Numeric default instead of the string '1e6'; argparse produced the same
# float (1000000.0) by running the string through ``type=float``.
parser.add_argument('--skip_threshold', type=float, default=1e6,
                    help='skipping batch that has large error')


# Log specifications
parser.add_argument('--save', type=str, default='DTIArbNet',
                    help='file name to save')
parser.add_argument('--load', type=str, default='.',
                    help='file name to load')
parser.add_argument('--save_models', action='store_true',
                    help='save all intermediate models')
parser.add_argument('--resume', type=int, default=0,
                    help='resume from specific checkpoint')

parser.add_argument('--print_every', type=int, default=200,
                    help='how many batches to wait before logging training status')
# NOTE(review): this help text is identical to --print_every and looks like a
# copy-paste; confirm what --save_every controls before rewording it.
parser.add_argument('--save_every', type=int, default=30,
                    help='how many batches to wait before logging training status')



# Hardware specifications
# parser.add_argument('--n_threads', type=int, default=2,
#                     help='number of threads for data loading')
parser.add_argument('--cpu', type=_str2bool, default=False,
                    help='use cpu only')
# parser.add_argument('--n_GPUs', type=int, default=2,
#                     help='number of GPUs')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed')


# Training specifications
parser.add_argument('--reset', action='store_true',
                    help='reset the training')
parser.add_argument('--pin_mem', action='store_true',
                    help='pin memory for dataloader')
parser.add_argument("--train_set", type=float, default=0.7,
                    help="percentage of data to be used for training")


# Model specifications
parser.add_argument('--model', default='dmri_arb',
                    help='model name')
parser.add_argument('--act', type=str, default='relu',
                    help='activation function')
# The default is the literal string 'None', not the None object.
parser.add_argument('--pre_train', type=str, default= 'None',
                    help='pre-trained model directory')
# parser.add_argument('--extend', type=str, default='.',
#                     help='pre-trained model directory')
# parser.add_argument('--res_scale', type=float, default=1,
#                     help='residual scaling')
# parser.add_argument('--shift_mean', default=True,
#                     help='subtract pixel mean from the input')
# parser.add_argument('--dilation', action='store_true',
#                     help='use dilated convolution')
parser.add_argument('--precision', type=str, default='single',
                    choices=('single', 'half'),
                    help='FP precision for test (single | half)')

# parse_known_args() returns (namespace, leftover_argv); unknown arguments
# (e.g. those the notebook kernel is launched with) are ignored rather than
# treated as errors.
args = parser.parse_known_args()[0]

# Notebook-specific overrides applied on top of the parsed defaults.
args.preload = True
args.debug = False
args.dir = "/storage"
args.batch_size = 16
args.sort = True
# Derive the CUDA switch from --cpu instead of forcing it on, so the two
# settings cannot contradict each other.  With the default (cpu=False) this
# is still True.
args.cuda = not args.cpu
args.scale = (1,1,1)
args.epochs = 400
args.gaps = 20
args.offset = 10   # overrides the parser default of 20
print(args)

Namespace(block_size=(16, 16, 16), test_block_size=(16, 16, 16), crop_depth=15, dir='/storage', batch_size=16, sort=True, debug=False, preload=True, ret_points=False, thres=0.6, offset=10, gaps=20, lr=0.0001, lr_decay=40, decay_type='step', gamma=0.5, optimizer='ADAM', momentum=0.9, beta1=0.9, beta2=0.999, epsilon=1e-08, weight_decay=0, start_epoch=0, loss='1*MSE', skip_threshold=1000000.0, save='DTIArbNet', load='.', save_models=False, resume=0, print_every=200, save_every=30, cpu=False, seed=1, reset=False, pin_mem=False, train_set=0.7, model='dmri_arb', act='relu', pre_train='None', precision='single', cuda=True, scale=(1, 1, 1), epochs=400)


In [2]:
import os

# Restrict the process to GPU 0.  This is set before torch and the project
# modules are imported so it is already in place if any of them touches CUDA
# at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import utility
import data
import utils
# The modules are imported under aliases because the names `model` and `loss`
# are bound to the constructed instances below; without the aliases the
# instances would shadow the modules they were created from.
import model as model_lib
import loss as loss_lib
from trainer import Trainer

# Use only the first `total_vols` subject ids, in sorted order.
ids = utils.get_ids()
ids.sort()
total_vols = 20
ids = ids[:total_vols]

if __name__ == '__main__':
    torch.manual_seed(args.seed)
    checkpoint = utility.checkpoint(args)       ## setting the log and the train information
    if checkpoint.ok:
        loader = data.Data(args,ids= ids)                ## data loader
        model = model_lib.Model(args, checkpoint)
        loss = loss_lib.Loss(args, checkpoint)
        # Training loop is intentionally disabled in this notebook:
#         t = Trainer(args, loader, model, loss, checkpoint)
#         while not t.terminate():
#             t.train()
#             t.test()

        # checkpoint.done()



number of common Subjects  171
Loading Done
Making model...
Preparing loss function:
1.000 * MSE


In [3]:
# Length of the loader's test set as reported by len(); the recorded run shows 6.
len(loader.testing_data)

6

In [4]:
# Pull the first item yielded by the test loader so its contents can be inspected.
x = next(iter(loader.testing_data))

In [5]:
# Show the shapes of the first two entries of the sampled item.
print(f"{x[0].shape} {x[1].shape}")

torch.Size([197, 16, 16, 16, 7]) torch.Size([173, 207, 173, 5])


In [6]:
# Prefer the GPU, but fall back to the CPU when CUDA is not available so that
# later `.to(device)` calls do not fail on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
# Scale reported by the loader for testing; it is passed to the model as its
# second argument in the inference loop. The recorded run printed (1, 1, 1).
sca = loader.get_scale_test()
print(sca)

(1, 1, 1)


In [8]:
# Run the model over every item of the test loader with gradients disabled.
model.eval()
for iteration, (lr_tensor, hr_tensor,pnts,mask) in enumerate(loader.testing_data, 1):
#         pbar.update(1)
        # Only the device transfer depends on args.cuda.
        if args.cuda:
            lr_tensor = lr_tensor.to(device)
            hr_tensor = hr_tensor.to(device)

        # The permute and the forward pass used to sit inside the `if` above,
        # which left `pred_tensor` undefined (NameError below) whenever
        # args.cuda was False.
        # Move the last axis to position 1 before calling the model; for the
        # sample inspected earlier, (197, 16, 16, 16, 7) becomes (197, 7, 16, 16, 16).
        lr_tensor = torch.permute(lr_tensor, (0,4,1,2,3))
        with torch.no_grad():
            pred_tensor = model(lr_tensor,sca)

        # Move axis 1 of the prediction back to the last position.
        pred_tensor = torch.permute(pred_tensor, (0,2,3,4,1))