In [6]:
import argparse, os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [7]:
parser = argparse.ArgumentParser(description="IMDN")
parser.add_argument("--block_size", type=tuple, default=(64,64,64),
                    help="Block Size")
parser.add_argument("--crop_depth", type=int, default=30,
                    help="crop across z-axis")
parser.add_argument("--dir", type=str,
                    help="dataset_directory")
parser.add_argument("--batch_size", type=int,
                    help="dataset_directory")
parser.add_argument("--sort", type=bool,
                    help="dataset_directory")
parser.add_argument("--debug", type=bool,
                    help="dataset_directory")
parser.add_argument("--preload", type=bool,
                    help="dataset_directory")
args = list(parser.parse_known_args())[0]
args.preload = True
args.debug = False
args.dir = "/storage"
args.batch_size = 4
args.sort = True
args.typ = 'upsampled'
print(args)

Namespace(batch_size=4, block_size=(64, 64, 64), crop_depth=30, debug=False, dir='/storage', preload=True, sort=True, typ='upsampled')


In [8]:
import data.HCP_dataset_h5
import utils
import data.HCP_dataset_h5_test

# Size of the small subject subset used for this exploratory run.
N_DEBUG_SUBJECTS = 2

ids = utils.get_ids()
ids.sort()  # deterministic subject order before taking the subset
ids = ids[:N_DEBUG_SUBJECTS]
# NOTE(review): training and testing datasets are built from the SAME subject ids, so
# anything measured on testing_dataset here is not a held-out evaluation.
training_dataset = data.HCP_dataset_h5.hcp_data(args, ids)
testing_dataset = data.HCP_dataset_h5_test.hcp_data(args, ids)

number of common Subjects  171
raw (173, 207, 173, 7)
ADC (173, 207, 173)
FA (173, 207, 173)
RGB (173, 207, 173, 3)
raw (173, 207, 173, 7)
ADC (173, 207, 173)
FA (173, 207, 173)
RGB (173, 207, 173, 3)


In [9]:
len(testing_dataset)

2

In [10]:
x = testing_dataset[0]

In [11]:
x[0].shape,x[1].shape,x[2].shape,x[3].shape

((173, 207, 173, 8), (173, 207, 173), (173, 207, 173), (173, 207, 173, 3))

In [12]:
len(training_dataset)

45

In [13]:
import math
import torch.nn.functional as F
def pad(x, multiple=16):
    """Zero-pad the last two dims of ``x`` up to the next multiple of ``multiple``.

    The padding is split as evenly as possible between both sides; when the
    amount is odd, the extra element goes after (bottom / right).

    Args:
        x: 4-D tensor unpacked as (N, C, H, W); only H and W are padded.
        multiple: both H and W are rounded up to a multiple of this value
            (default 16, the value that used to be hard-coded).

    Returns:
        ``(padded, (h_pad, w_pad, h_mult, w_mult))``. Here ``h_pad`` and ``w_pad``
        are ``[before, after]`` lists, and ``h_mult`` and ``w_mult`` are the padded
        sizes. The second element can be unpacked straight into ``unpad``.

    Raises:
        ValueError: if ``multiple`` is not positive.
    """
    if multiple < 1:
        raise ValueError(f"multiple must be a positive integer, got {multiple}")
    _, _, h, w = x.shape
    # Ceiling division: smallest multiple of `multiple` that is >= the size.
    h_mult = -(-h // multiple) * multiple
    w_mult = -(-w // multiple) * multiple
    # floor / ceil halves of the total padding.
    h_pad = [(h_mult - h) // 2, (h_mult - h + 1) // 2]
    w_pad = [(w_mult - w) // 2, (w_mult - w + 1) // 2]
    # F.pad takes pairs starting from the LAST dim: (w_before, w_after, h_before, h_after).
    x = F.pad(x, w_pad + h_pad)
    return x, (h_pad, w_pad, h_mult, w_mult)

def unpad(x, h_pad, w_pad, h_mult, w_mult):
    """Crop the last two dims of ``x`` back to their size before ``pad``.

    ``h_pad`` and ``w_pad`` are ``[before, after]`` pairs, and ``h_mult`` and
    ``w_mult`` are the padded sizes, in the same layout as ``pad``'s second
    return value.
    """
    top, bottom = h_pad[0], h_pad[1]
    left, right = w_pad[0], w_pad[1]
    rows = slice(top, h_mult - bottom)
    cols = slice(left, w_mult - right)
    return x[..., rows, cols]

def resize(data):
    """Collate a list of samples into an ``(inputs, targets)`` tensor pair.

    Each sample is indexed as ``sample[0..3]``:
      * ``sample[0]`` is 4-D; its first two axes are merged,
        (A, B, C, K) -> (A*B, C, K).
      * ``sample[1]`` and ``sample[2]`` each get a new axis 3 and are
        concatenated with ``sample[3]`` along axis 3 to form the target.
    Both lists are stacked along a new leading batch axis with NumPy and
    wrapped with ``torch.from_numpy``.
    """
    inputs = []
    targets = []
    for sample in (data[k] for k in range(len(data))):
        raw = sample[0]
        n0, n1, n2, n3 = raw.shape[0], raw.shape[1], raw.shape[2], raw.shape[3]
        inputs.append(raw.reshape((n0 * n1, n2, n3)))
        channels = [
            np.expand_dims(sample[1], axis=3),
            np.expand_dims(sample[2], axis=3),
            sample[3],
        ]
        targets.append(np.concatenate(channels, axis=3))
    return torch.from_numpy(np.stack(inputs)), torch.from_numpy(np.stack(targets))

In [14]:
training_dataset.blk_indx

array([27, 45])

In [15]:
x = testing_dataset[0]

In [16]:
x[0].shape,x[1].shape,x[2].shape,x[3].shape

((173, 207, 173, 8), (173, 207, 173), (173, 207, 173), (173, 207, 173, 3))

In [17]:
# `resize` is the collate_fn that turns a list of samples into (inputs, targets) tensors.
training_data_loader = DataLoader(
    dataset=training_dataset,
    batch_size=12,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    collate_fn=resize,
)
testing_data_loader = DataLoader(
    dataset=testing_dataset,
    batch_size=1,
    pin_memory=True,
    collate_fn=resize,
)

In [18]:
temp = next(iter(testing_data_loader))

In [19]:
temp[0].shape

torch.Size([1, 35811, 173, 8])

In [20]:
# torch.permute(temp[0],(0,3,1,2)).shape

In [21]:
temp[1].shape

torch.Size([1, 173, 207, 173, 5])

In [22]:
# Rebuilds the training loader with batch_size=40, replacing the batch_size=12 one.
# NOTE(review): with 45 training samples and drop_last=True this yields a single batch
# per epoch and silently drops 5 samples.
training_data_loader = DataLoader(
    dataset=training_dataset,
    batch_size=40,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    collate_fn=resize,
)

In [23]:
# Shape check: look at the first training batch only.
for iteration, (lr_tensor, hr_tensor) in enumerate(training_data_loader, start=1):
    print(lr_tensor.shape, hr_tensor.shape)
    break  # one batch is enough

torch.Size([40, 4096, 64, 8]) torch.Size([40, 64, 64, 64, 5])


In [24]:
from Meta_SR_Pytorch.model import metarcan 

# Hyper-parameters for the metarcan model. The values this cell used to assign after
# parsing are now argparse defaults, so every attribute is declared with a matching type.
# NOTE: this rebinds `args`; the dataset arguments from the earlier cell are no longer
# reachable under that name.
parser = argparse.ArgumentParser(description="metarcan")
parser.add_argument("--reduction", type=int, default=16,
                    help="reduction ratio")
parser.add_argument("--n_resblocks", type=int, default=8,
                    help="number of residual blocks")
parser.add_argument("--n_resgroups", type=int, default=8,
                    help="number of residual groups")
parser.add_argument("--n_feats", type=int, default=4,
                    help="number of feature maps")
# Was declared type=int while the notebook assigned [1.2]; scales are floats, one or more.
parser.add_argument("--scale", type=float, nargs="+", default=[1.2],
                    help="scale factor(s)")
parser.add_argument("--n_colors", type=int, default=8,
                    help="number of input channels")
parser.add_argument("--rgb_range", type=int, default=1,
                    help="value range of the data")
parser.add_argument("--res_scale", type=float, default=1.2,
                    help="residual scaling")
# parse_known_args tolerates the extra arguments the Jupyter kernel receives.
args = parser.parse_known_args()[0]
print(args)

Namespace(n_colors=8, n_feats=4, n_resblocks=8, n_resgroups=8, reduction=16, res_scale=1.2, rgb_range=1, scale=[1.2])


In [25]:
# temp_model = metarcan.RCAN(args)

In [26]:
# `resize` returns a tuple; only DataLoader's pin-memory step (active when CUDA is
# available) converts it to a list. Make it a list explicitly so the item assignment
# below also works on CPU-only machines.
temp = list(next(iter(training_data_loader)))
# (N, rows, cols, C) -> (N, C, rows, cols): channels-first input for the model.
temp[0] = torch.permute(temp[0], (0, 3, 1, 2))
temp[0].shape

torch.Size([40, 8, 4096, 64])

In [27]:
temp[1].shape

torch.Size([40, 64, 64, 64, 5])

In [28]:
from deep_cascade_caunet.models import CSEUnetModel

In [29]:
# Applied to 4-D (N, C, rows, cols) inputs: in_chans matches the 8 input channels of the
# collated batches and out_chans the 5 target channels.
model = CSEUnetModel(
    in_chans=8,
    out_chans=5,
    chans=4,
    num_pool_layers=2,
    drop_prob=0.2,
    reduction=4,
)

In [30]:
# Shape check only: no_grad avoids retaining the autograd graph for a whole batch.
# NOTE(review): the model is still in training mode here (nn.Module default), so any
# dropout it applies is active.
with torch.no_grad():
    out = model(temp[0])

In [31]:
out.shape

torch.Size([40, 5, 4096, 64])

In [32]:
out_temp = torch.permute(out,(0,2,3,1))

In [33]:
out_temp.shape

torch.Size([40, 4096, 64, 5])

In [34]:
# Undo the first-two-axes merge done in `resize`. Take the target layout
# (N, 64, 64, 64, 5) from the batch itself instead of hard-coding the batch size.
out_out = out_temp.reshape(temp[1].shape)

In [35]:
out_out.shape

torch.Size([40, 64, 64, 64, 5])

In [36]:
x[0].shape,x[1].shape,x[2].shape,x[3].shape

((173, 207, 173, 8), (173, 207, 173), (173, 207, 173), (173, 207, 173, 3))

In [44]:
# `resize` returns a tuple; only DataLoader's pin-memory step (active when CUDA is
# available) converts it to a list. Make it a list explicitly so the item assignment
# below also works on CPU-only machines.
temp = list(next(iter(testing_data_loader)))
# (1, rows, cols, C) -> (1, C, rows, cols): channels-first input for the model.
temp[0] = torch.permute(temp[0], (0, 3, 1, 2))
temp[0].shape

torch.Size([1, 8, 35811, 173])

In [46]:
t= pad(temp[0])

In [47]:
t[0].shape

torch.Size([1, 8, 35824, 176])

In [48]:
# Whole test volume in one forward pass. no_grad keeps PyTorch from retaining the
# autograd graph for a (1, 8, 35824, 176) input, which is only needed for training.
with torch.no_grad():
    y = model(t[0])

In [49]:
y.shape

torch.Size([1, 5, 35824, 176])

In [50]:
t[1]

([6, 7], [1, 2], 35824, 176)

In [51]:
# Crop the prediction back to the pre-padding size; t[1] is (h_pad, w_pad, h_mult, w_mult).
u = unpad(y, *t[1])

In [52]:
u.shape

torch.Size([1, 5, 35811, 173])

In [68]:
import monai

In [69]:
# 3-D SSIM criterion from MONAI, for (N, C, D, H, W) volumes like `hr` / `out` below.
# NOTE(review): this is a *loss* object — confirm whether it returns 1 - SSIM before
# reading its value as a similarity score. No data_range is passed, so the library
# default applies; confirm it matches the value range of the targets.
ssim = monai.losses.ssim_loss.SSIMLoss(spatial_dims =3)

In [53]:
# `u` is channels-first: (1, 5, 35811, 173) with 35811 = 173 * 207. `resize` merged the
# first two spatial axes of a channels-LAST volume, so channels must be moved last before
# undoing that merge. Reshaping `u` directly would silently interleave the 5 channels with
# spatial positions. This mirrors the permute-then-reshape used for the training batch.
out = torch.permute(u, (0, 2, 3, 1)).reshape(temp[1].shape)  # (1, 173, 207, 173, 5)
out = torch.permute(out, (0, 4, 1, 2, 3))                     # (1, 5, 173, 207, 173)

In [54]:
out.shape

torch.Size([1, 5, 173, 207, 173])

In [55]:
# Target volume channels-first to match `out`: (1, 173, 207, 173, 5) -> (1, 5, 173, 207, 173).
hr = temp[1].permute(0, 4, 1, 2, 3)

In [56]:
hr.shape,out.shape

(torch.Size([1, 5, 173, 207, 173]), torch.Size([1, 5, 173, 207, 173]))

In [None]:
# 3-D SSIM loss between target and prediction on the full test volume.
# NOTE(review): this cell was never executed in the saved notebook (In [None]).
ssim(hr,out)

In [64]:
from torchmetrics import StructuralSimilarityIndexMeasure
import torch
# torchmetrics SSIM metric with an explicit data_range of 1.0.
# NOTE(review): this rebinds `ssim`, replacing the MONAI SSIMLoss created earlier;
# use distinct names if both the loss and the metric are needed.
ssim = StructuralSimilarityIndexMeasure(data_range=1.0)