In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:80% !important; }</style>"))

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from skimage.restoration import (denoise_tv_chambolle, denoise_bilateral,
                                 denoise_wavelet, estimate_sigma)
import warnings
from torch import optim
from torch.optim.lr_scheduler import StepLR



In [2]:
from config import *
from anatomy import *
from renderer import *
from siren import *

In [3]:
# config = Config(np.array([[0.3,0.6]]), TYPE=0, NUM_HEART_BEATS=2.0, NUM_SDFS=2)
# body = Body(config, [Organ(config,[0.6,0.6],0.2,0.2,'pseudo_heart','const2'),
#                      Organ(config,[0.2,0.2],0.1,0.1,'simple_sin','const2')])

config = Config(np.array([[0.5]]), TYPE=0, NUM_HEART_BEATS=2.0, NUM_SDFS=1)
body = Body(config, [Organ(config,[0.6,0.6],0.2,0.2,'pseudo_heart','const2')])

In [4]:
def get_pretraining_sdfs(config, sdf=None):
    
    
    if sdf is None:
        pretraining_sdfs = np.zeros((config.IMAGE_RESOLUTION,config.IMAGE_RESOLUTION,config.TOTAL_CLICKS,config.NUM_SDFS))
        for i in range(config.NUM_SDFS):
            cfg = Config(np.array([[np.random.rand()]]), config.TYPE, config.NUM_HEART_BEATS, 1)
            organ = Organ(cfg, [np.random.rand()*0.7+0.1,
                                np.random.rand()*0.7+0.1], 0.1, 0.1, 'simple_sin', 'simple_sin2')
            body = Body(cfg,[organ])
            sdf = SDFGt(cfg, body)
            all_thetas = np.linspace(0., config.THETA_MAX, config.TOTAL_CLICKS)
            for j in range(config.TOTAL_CLICKS):
                pretraining_sdfs[...,j,i] = denoise_tv_chambolle(sdf(all_thetas[j])[...,0].detach().cpu().numpy())

    else:
        pretraining_sdfs = np.zeros((config.IMAGE_RESOLUTION,config.IMAGE_RESOLUTION,config.TOTAL_CLICKS,config.NUM_SDFS))
        for i in range(config.NUM_SDFS):
            all_thetas = np.linspace(0., config.THETA_MAX, config.TOTAL_CLICKS)
            for j in range(config.TOTAL_CLICKS):
                pretraining_sdfs[...,j,i] = occ_to_sdf(np.round(denoise_tv_chambolle(sdf_to_occ(sdf(all_thetas[j]))[...,0].detach().cpu().numpy()))[...,np.newaxis])[...,0]
            
    return pretraining_sdfs

pretraining_sdfs = get_pretraining_sdfs(config)
print(np.mean(np.sqrt(np.gradient(pretraining_sdfs,axis=0)**2 + np.gradient(pretraining_sdfs,axis=1)**2)))
print(np.mean(np.abs(np.gradient(pretraining_sdfs,axis=2))))

0.999576892943909
0.02488891680498027


In [5]:
class FourierFeatures(nn.Module):
    '''
    Learning a function as a fourier series
    Refer: https://colab.research.google.com/github/ndahlquist/pytorch-fourier-feature-networks/blob/master/demo.ipynb#scrollTo=QDs4Im9WTQoy
    '''
    
    def __init__(self, input_channels, output_channels, mapping_size = 128, scale=2.5, testing=False):
        super(FourierFeatures, self).__init__()
        
        assert isinstance(input_channels, int), 'input_channels must be an integer'
        assert isinstance(output_channels, int), 'output_channels must be an integer'
        assert isinstance(mapping_size, int), 'maping_size must be an integer'
        assert isinstance(scale, float), 'scale must be an float'
        assert isinstance(testing, bool), 'testing should be a bool'
        
        self.mapping_size = mapping_size
        self.output_channels = output_channels
        self.testing = testing
        
        if self.testing:
            self.B = torch.ones((1, self.mapping_size, self.output_channels))
        else:
            self.B = torch.randn((1, self.mapping_size, self.output_channels))*scale
            
        self.B = self.B.cuda()
        self.net = Siren(input_channels,128,3,(2*self.mapping_size+1)*self.output_channels)
        
    def forward(self, x, t):
        
        assert isinstance(x, torch.Tensor) and len(x.shape) == 2, 'x must be a 2D tensor'
        assert isinstance(t, torch.Tensor) or isinstance(t, float) and t>=-1 and t <=1, 't must be a float between -1 and 1'

        if self.testing:
            fourier_coeffs = torch.ones((x.shape[0],self.mapping_size*2+1, self.output_channels)).type_as(x)
        else:
            fourier_coeffs = self.net(x).view(-1, self.mapping_size*2+1, self.output_channels)
            
        fourier_coeffs_dc = fourier_coeffs[:,-1:,:]
        fourier_coeffs_ac = fourier_coeffs[:,:-1,:]
        
        assert fourier_coeffs_dc.shape == (x.shape[0], 1, self.output_channels), 'Inavild size for fourier_coeffs_dc : {}'.format(fourier_coeffs_dc.shape)
        assert fourier_coeffs_ac.shape == (x.shape[0], self.mapping_size*2, self.output_channels),  'Inavild size for fourier_coeffs_ac : {}'.format(fourier_coeffs_ac.shape)

        t = (2*np.pi*t*self.B).repeat(x.shape[0],1,1)
        
        tsins = torch.cat([torch.sin(t), torch.cos(t)], dim=1).type_as(x)

        assert tsins.shape == (x.shape[0],2*self.mapping_size,self.output_channels)
        series = torch.mul(fourier_coeffs_ac, tsins)
        assert series.shape ==  (x.shape[0],2*self.mapping_size,self.output_channels)
        val_t = torch.mean(series, dim=1, keepdim=True)
        assert val_t.shape == (x.shape[0],1,self.output_channels)
        val_t = val_t + fourier_coeffs_dc
        assert val_t.shape == (x.shape[0],1,self.output_channels)
        
        return val_t.squeeze(1)
    
ff =  FourierFeatures(2, 2, testing=True).cuda()
x = torch.Tensor([[0.0,0.0],[0.0,1.0],[1.0,0.0],[1.0,1.0]]).cuda()
t = 1.0
val_t = ff(x,t)
print(torch.norm(val_t - 1.5*torch.ones(x.shape).type_as(x)))

tensor(0., device='cuda:0')


In [6]:
class SDFNCT(SDF):
    def __init__(self, config):
        super(SDFNCT, self).__init__()
        
        assert isinstance(config, Config), 'config must be an instance of class Config'
        
        self.config = config
        x,y = np.meshgrid(np.linspace(0,1,self.config.IMAGE_RESOLUTION),np.linspace(0,1,self.config.IMAGE_RESOLUTION))
        self.pts = torch.autograd.Variable(2*(torch.from_numpy(np.hstack((x.reshape(-1,1),y.reshape(-1,1)))).cuda().float()-0.5),requires_grad=True)
        
        self.encoder = Siren(2,256,3,config.NUM_SDFS).cuda()
        self.velocity = FourierFeatures(2,config.NUM_SDFS).cuda()
        
    def compute_sdf_t(self, t):
        assert isinstance(t, torch.Tensor) or isinstance(t, float), 't = {} must be a float or a tensor here'.format(t)
        assert t >= -1 and t <= 1, 't = {} is out of range'.format(t)
        
        displacement = self.velocity(self.pts, t)
        init_sdf = self.encoder(self.pts)
        assert init_sdf.shape == displacement.shape
        
        canvas = (init_sdf + displacement)*self.config.SDF_SCALING
        if not (torch.min(canvas) < -1 and torch.max(canvas) > 1):
            warnings.warn('SDF values are in a narrow range between (-1,1)')
            
        canvas = canvas.view(self.config.IMAGE_RESOLUTION,self.config.IMAGE_RESOLUTION,self.config.NUM_SDFS)
        
        return canvas
            
    def forward(self, t):
        
        assert isinstance(t, float), 't = {} must be a float here'.format(t)
        assert t >= -self.config.THETA_MAX and t <= self.config.THETA_MAX, 't = {} is out of range'.format(t)
        
        t = 2*get_phase(self.config,t) - 1
        
        canvas = self.compute_sdf_t(t)        
        assert len(canvas.shape) == 3, 'Canvas must be a 3D tensor, instead is of shape: {}'.format(canvas.shape)
        
        return canvas
    
    def grad(self, t):
        
        assert isinstance(t, float), 't = {} must be a float here'.format(t)
        assert t >= -self.config.THETA_MAX and t <= self.config.THETA_MAX, 't = {} is out of range'.format(t)
        
        t = torch.autograd.Variable(torch.Tensor([2*get_phase(self.config,t) - 1]).cuda().float(),requires_grad=True)
        
        canvas = self.compute_sdf_t(t)/self.config.SDF_SCALING
        
        dc_dxy = gradient(canvas, self.pts)
        assert len(dc_dxy.shape) == 2, 'Must be a 2D tensor, instead is {}'.format(dc_dxy.shape)

        occupancy = sdf_to_occ(canvas)
        assert len(occupancy.shape) == 3
        
        do_dxy = gradient(occupancy, self.pts)
        assert len(do_dxy.shape) == 2, 'Must be a 2D tensor, instead is {}'.format(do_dxy.shape)
        
        dc_dt = gradient(canvas, t)/(np.prod(canvas.shape))
        assert len(dc_dt.shape) == 1, 'Must be a 1D tensor, instead is {}'.format(dc_dt.shape)
        
        eikonal = torch.abs(torch.norm(dc_dxy, dim=1) - 1).mean()
        total_variation_space = torch.norm(do_dxy, dim=1).mean()
        total_variation_time = torch.abs(dc_dt)
        
        return eikonal, total_variation_space, total_variation_time

In [7]:
def pretrain_sdf(config, sdf = None, lr = 1e-4):
    

    pretraining_sdfs = get_pretraining_sdfs(config, sdf)
        
    assert len(pretraining_sdfs.shape) == 4, 'Invalid shape : {}'.format(pretraining_sdfs.shape)
    
    sdf = SDFNCT(config)
    gt = torch.from_numpy(pretraining_sdfs).cuda()
    
    optimizer = optim.Adam(list(sdf.parameters()), lr = lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.95)
    
    for itr in range(1000):
        optimizer.zero_grad()
        t = np.random.randint(0,config.TOTAL_CLICKS,1)[0]
        theta = t*(config.THETA_MAX/config.TOTAL_CLICKS)
        
        pred = sdf(theta)
        target = gt[...,t,:]
        assert target.shape == pred.shape, 'target has shape : {} while prediction has shape :{}'.format(target.shape, pred.shape)
        eikonal, _, _ = sdf.grad(theta)

        loss1 = torch.abs(pred - target).mean()
        loss = loss1 + 0.1*eikonal
        loss.backward()
        optimizer.step()
        
        if itr %200 == 0:
            print('itr: {}, loss: {:.4f}, lossP: {:.4f}, lossE: {:.4f}, lr: {:.4f}'.format(itr, loss.item(), loss1.item(), 
                                                                                           eikonal.item(),scheduler.get_last_lr()[0]*10**4))
            scheduler.step()
            
    return sdf

In [8]:
sdf = pretrain_sdf(config)



itr: 0, loss: 27.6447, lossP: 27.5913, lossE: 0.5346, lr: 1.0000
itr: 200, loss: 0.8892, lossP: 0.8589, lossE: 0.3036, lr: 0.9500
itr: 400, loss: 0.4276, lossP: 0.3988, lossE: 0.2876, lr: 0.9025
itr: 600, loss: 0.7906, lossP: 0.7633, lossE: 0.2726, lr: 0.8574
itr: 800, loss: 0.4164, lossP: 0.3925, lossE: 0.2394, lr: 0.8145


In [9]:
def get_gt_sinogram(config, body):
    sdf = SDFGt(config, body)
    renderer = Renderer(config, sdf)
    all_thetas = np.linspace(0,config.THETA_MAX, config.TOTAL_CLICKS)
    sinogram = renderer.forward(all_thetas).detach().cpu().numpy()
    
    return sinogram

gt_sinogram = torch.from_numpy(get_gt_sinogram(config, body)).cuda()

In [10]:
def train(config, sdf, gt_sinogram, lr=1e-4):
        
    renderer = Renderer(config, sdf)
    
    optimizer = optim.Adam(list(sdf.parameters()), lr = lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.95)
    
    for itr in range(1000):
        optimizer.zero_grad()
        t = np.random.randint(0,config.TOTAL_CLICKS,config.BATCH_SIZE)
        theta = t*(config.THETA_MAX/config.TOTAL_CLICKS)
        pred = renderer(theta)
        target = gt_sinogram[:,t]
        loss1 = torch.abs(pred - target).mean()*100
        
        eikonal, total_variation_space, total_variation_time = sdf.grad(theta[0])
        assert target.shape == pred.shape, 'target has shape : {} while prediction has shape :{}'.format(target.shape, pred.shape)
        
        loss = loss1 + 0.1*eikonal + 0.01*total_variation_space + 0.01*total_variation_time
        
        loss.backward()
        optimizer.step()
        
        if itr %200 == 0:
            print('itr: {}, loss: {:.4f}, lossP: {:.4f}, lossE: {:.4f}, lossTVs: {:.4f}, lossTVt: {:.4f}, lr: {:.4f}'.format(itr, loss.item(), loss1.item(), eikonal.item(), 
                                     total_variation_space.item(), total_variation_time.item(),scheduler.get_last_lr()[0]*10**4))
            scheduler.step()
            
    return sdf
#         break
        

In [None]:
sdf = train(config, sdf, gt_sinogram)
sdf = pretrain_sdf(config, sdf, lr = 5e-5)
sdf = train(config, sdf, gt_sinogram, lr=5e-5)
# sdf = pretrain_sdf(config, sdf, lr = 5e-5)
# sdf = train(config, sdf, gt_sinogram, lr=1e-5)

itr: 0, loss: 0.2355, lossP: 0.1948, lossE: 0.3163, lossTVs: 0.6157, lossTVt: 0.2867, lr: 1.0000
itr: 200, loss: 0.2371, lossP: 0.1754, lossE: 0.4858, lossTVs: 1.1495, lossTVt: 0.1601, lr: 0.9500
itr: 400, loss: 0.1922, lossP: 0.1562, lossE: 0.3166, lossTVs: 0.4014, lossTVt: 0.0389, lr: 0.9025
itr: 600, loss: 0.1831, lossP: 0.1358, lossE: 0.3293, lossTVs: 0.7948, lossTVt: 0.6506, lr: 0.8574
itr: 800, loss: 0.1642, lossP: 0.1308, lossE: 0.2889, lossTVs: 0.4402, lossTVt: 0.0163, lr: 0.8145




itr: 0, loss: 21.9571, lossP: 21.9018, lossE: 0.5535, lr: 0.5000
itr: 200, loss: 0.5452, lossP: 0.5168, lossE: 0.2842, lr: 0.4750


In [12]:
def fetch_movie(config, sdf):
    frames = np.zeros((config.IMAGE_RESOLUTION,config.IMAGE_RESOLUTION,config.TOTAL_CLICKS,config.NUM_SDFS))
    for i in range(config.NUM_SDFS):
        all_thetas = np.linspace(0., config.THETA_MAX, config.TOTAL_CLICKS)
        for j in range(config.TOTAL_CLICKS):
            frames[...,j,i] = sdf_to_occ(sdf(all_thetas[j]))[...,0].detach().cpu().numpy()
            
    intensities = config.INTENSITIES.reshape(1,1,1,-1)
    movie = np.sum(frames*intensities, axis=3)
    
    return movie       

In [13]:
movie = fetch_movie(config, sdf)

In [14]:
for i in range(movie.shape[2]):
    plt.imsave('movie/{}.png'.format(i), movie[...,i], cmap='gray')

In [15]:
import os
os.system('zip -r movie.zip movie/')

0

In [89]:
img = sdf(10.0).detach().cpu().numpy()

In [17]:
%matplotlib notebook

plt.imshow(movie[...,0])
plt.show()

<IPython.core.display.Javascript object>

In [None]:
sdfnct = SDFNCT(config).cuda()

In [None]:
# sdfnct(0.0).shape
sdfnct.grad(0.0)

In [22]:
# val_t
# np.pi
img = np.load('test_outputs/renderer_forward.npy')
img.shape

(64, 720)

In [23]:
%matplotlib notebook

plt.imshow(img)
plt.show()
# plt.figure()
# plt.imshow(img[...,1])
# plt.show()

<IPython.core.display.Javascript object>

In [175]:
class test(nn.Module):
    def __init__(self):
        super(test, self).__init__()
#          self.l = torch.nn.Linear(3,10)
        
    def forward(self, x):
        
        return torch.sqrt(x[:,0]**2 + x[:,1]**2 + x[:,2]**2)

In [186]:
x = torch.autograd.Variable(torch.Tensor([[1,1,1],[0,1,0]]),requires_grad=True)

In [187]:
model = test()

In [188]:
y = model(x)

In [190]:
torch.norm(gradient(y,x), dim=1) - 1

tensor([-5.9605e-08,  0.0000e+00], grad_fn=<SubBackward0>)

In [174]:
y.shape

torch.Size([2])