In [1]:
import copy
import math
import os

import numpy as np
import torch
import torch.nn.functional as F
from medpy.io import load
from torch import nn
from torch.autograd import Variable

from tumor_mass_effect import semi_implicit_solver
%load_ext autoreload
%autoreload 2

## All parameters

In [2]:
# Model and solver parameters (passed to semi_implicit_solver in the solver cell).
epsilon = 0.9   # forwarded to the solver; meaning defined in tumor_mass_effect (TODO document)
# Atlas directory. Set TUMOR_ANATOMY_FOLDER to point at a local copy instead of editing the
# notebook. os.path.join(..., '') guarantees the trailing separator that the loading cell
# relies on when it appends file names with `+`.
anatomy_folder = os.path.join(
    os.environ.get('TUMOR_ANATOMY_FOLDER', '/media/Drives/Data/leon_tumor/Atlas/anatomy/'),
    '',
)
# NOTE(review): presumably grey/white-matter diffusivities, but only Dw is handed to the
# solver — Dg is currently unused. Confirm whether that is intended.
Dg = 1.52e-08
Dw = 1.52e-07
dt = 4.0        # time step (units defined by the solver)
rho = 0.1       # coefficient of the reaction term (rho * R c in the model equations)
dx = 0.002      # grid spacing (units defined by the solver)
# Presumably Young's moduli and Poisson ratios, one entry per material; the material order
# is not established in this notebook — TODO confirm against tumor_mass_effect.
E = np.array([2100, 2100, 100, 8000], dtype=np.float32)
nu = np.array([0.4, 0.4, 0.1, 0.45], dtype=np.float32)
gamma = 16000.  # gamma in the elasticity equation: scales the grad(c) forcing
MaxIter = 20    # presumably the iteration cap of the semi-implicit solve — TODO confirm

In [3]:
# Read the three atlas tissue maps in WM, GM, CSF order. `load` returns a 2-tuple; only its
# first element (the image data) is kept.
wm, gm, csf = [
    load(anatomy_folder + fname)[0]
    for fname in ('WM.nii.gz', 'GM.nii.gz', 'CSF.nii.gz')
]

## Simulator

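Governing equations, with tumor concentration $c$, tissue maps $m_i$ (WM, GM, CSF), displacement $\boldsymbol{u}$ and velocity $\boldsymbol{v}$:

$$
\begin{aligned}
\partial_{t} c+\nabla \cdot(c \boldsymbol{v})-\kappa \mathcal{D} c-\rho \mathcal{R} c &=0, \quad c(0)=c_{0}  & \text { in } \Omega \times[0,1] \\
\partial_{t} m_i+\nabla \cdot\left(m_i \boldsymbol{v}\right) &=0, \quad m_i(0)=m_{i0} & \text { in } \Omega \times[0,1] \\
\left(\lambda +\mu\right)\Delta \boldsymbol{u}+\mu\nabla \nabla\cdot \boldsymbol{u}  &=\gamma \nabla c & \text { in } \Omega \times[0,1] \\
\partial_{t} \boldsymbol{u} &=\boldsymbol{v},  \quad \boldsymbol{u}(0)=\mathbf{0} & \text { in } \Omega \times[0,1]
\end{aligned}
$$

**Note:** the standard Navier–Lamé operator is $\mu\Delta\boldsymbol{u}+(\lambda+\mu)\nabla(\nabla\cdot\boldsymbol{u})$. Confirm which coefficient placement `semi_implicit_solver` implements.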

In [4]:
# Build the solver once. Arguments are positional, so their order must match
# semi_implicit_solver's signature in tumor_mass_effect (not visible here).
# NOTE(review): Dg from the parameter cell is not passed — confirm that only Dw is needed.
tomor_solver = semi_implicit_solver(
    Dw, rho, dx, dt,
    E, nu, gamma,
    MaxIter, epsilon,
)

## Variable initialization

In [5]:
# Stack the tissue maps as channels (0 = WM, 1 = GM, 2 = CSF) behind a leading batch axis.
# `Variable` has been a no-op wrapper since PyTorch 0.4, so plain tensors are used; none of
# these tensors require gradients.
m = torch.tensor(np.stack([wm, gm, csf], axis=0)).cuda().unsqueeze(0)
# Zero-pad the three spatial axes (F.pad lists (low, high) pairs starting from the LAST axis),
# then halve the resolution.
pad_widths = (32, 31, 14, 13, 32, 31)
m = F.interpolate(F.pad(m, pad_widths), scale_factor=0.5, mode='trilinear', align_corners=True)
m_init = m.clone()  # untouched copy of the initial tissue maps, passed to every solver step
# Displacement u and velocity v have m's shape and start at zero (u(0) = 0).
# NOTE(review): u, v, c use torch's default float dtype while m keeps the atlas dtype —
# confirm the two agree.
u = torch.zeros(m.shape, device=m.device)
v = torch.zeros(m.shape, device=m.device)
# Tumor concentration: a single channel, zero everywhere except one seed voxel set to 1.
c = torch.zeros((m.shape[0], 1) + tuple(m.shape[2:]), device=m.device)
seed_voxel = (50, 75, 64)
c[(Ellipsis,) + seed_voxel] = 1.0



## Run simulation

In [6]:
# Advance the coupled tumor / tissue / displacement system, keeping a host copy of c and m
# for every step (index 0 is the initial state) for the plotting cell.
stopit = 4  # not used below; kept for interactively singling out specific stored steps
c_list = [c.detach().cpu().numpy()]
m_list = [m.detach().cpu().numpy()]
u_list = []  # displacement history is not recorded (another full-size volume per step)
Tmax = 300
phi_brain = torch.sum(m, 1, keepdim=True)  # channel sum of the initial tissue maps
for t in range(Tmax):
    c, m, u, v = tomor_solver.solver_step(c, m, u, v, phi_brain, m_init)
    c_list.append(c.detach().cpu().numpy())
    m_list.append(m.detach().cpu().numpy())
    print("Finished time step", t+1)
    # Stop once the displacement diverges. The check runs on the device (no full host copy of
    # u each step), and isfinite also catches +/-inf, which a NaN test on the sum can miss.
    if not torch.isfinite(u).all():
        print("Displacement became non-finite at time step", t+1, "- stopping early.")
        break

Finished time step 1
Finished time step 2
Finished time step 3
Finished time step 4
Finished time step 5
Finished time step 6
Finished time step 7
Finished time step 8
Finished time step 9
Finished time step 10
Finished time step 11
Finished time step 12
Finished time step 13
Finished time step 14
Finished time step 15
Finished time step 16
Finished time step 17
Finished time step 18
Finished time step 19
Finished time step 20
Finished time step 21
Finished time step 22
Finished time step 23
Finished time step 24
Finished time step 25
Finished time step 26
Finished time step 27
Finished time step 28
Finished time step 29
Finished time step 30
Finished time step 31
Finished time step 32
Finished time step 33
Finished time step 34
Finished time step 35
Finished time step 36
Finished time step 37
Finished time step 38
Finished time step 39
Finished time step 40
Finished time step 41
Finished time step 42
Finished time step 43
Finished time step 44
Finished time step 45
Finished time step 

KeyboardInterrupt: 

## Plot tumor

In [7]:
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons
%matplotlib tk
from ipywidgets import *
# Interactive viewer: tumor concentration c plus the three tissue channels of m (stacked as
# WM, GM, CSF), browsable by stored time step and by index along the last spatial axis.
plt_list = c_list
n_slices = plt_list[0].shape[4]
# 1-based slider start values: the first stored frame and the middle slice
# (slice 65, i.e. index 64, for a 128-deep volume).
init_time = 1
init_slice = n_slices // 2 + 1

fig, ax = plt.subplots(1, 4, dpi=150)
fig.subplots_adjust(left=0.25, bottom=0.25)
# Draw exactly the frame the sliders report, so labels and images agree before the first
# interaction.
t0, z0 = init_time - 1, init_slice - 1
l1 = ax[0].imshow(c_list[t0][0, 0, :, :, z0])
l2 = ax[1].imshow(m_list[t0][0, 0, :, :, z0])
l3 = ax[2].imshow(m_list[t0][0, 1, :, :, z0])
l4 = ax[3].imshow(m_list[t0][0, 2, :, :, z0])
for panel_ax, title in zip(ax, ('Tumor', 'WM', 'GM', 'CSF')):
    panel_ax.set_title(title)

axcolor = 'lightgoldenrodyellow'
axfreq = plt.axes([0.25, 0.1, 0.65, 0.03], facecolor=axcolor)
axamp = plt.axes([0.25, 0.15, 0.65, 0.03], facecolor=axcolor)

stime_ = Slider(axfreq, 'Time', valmin=1, valmax=len(plt_list), valinit=init_time, valstep=1)
sslice_ = Slider(axamp, 'Slice', valmin=1, valmax=n_slices, valinit=init_slice, valstep=1)

def update(val):
    """Slider callback: redraw all four panels for the selected time step and slice.

    Both sliders are 1-based, so their values are shifted down by one before indexing the
    stored volumes. The `val` argument supplied by matplotlib is ignored; the current
    positions are read directly from the two sliders.
    """
    t_idx = int(stime_.val) - 1
    z_idx = int(sslice_.val) - 1
    panels = (
        (l1, c_list, 0),  # tumor concentration
        (l2, m_list, 0),  # WM channel
        (l3, m_list, 1),  # GM channel
        (l4, m_list, 2),  # CSF channel
    )
    for image, history, channel in panels:
        image.set_data(history[t_idx][0, channel, :, :, z_idx])
    fig.canvas.draw_idle()

# Re-render the panels whenever either slider moves, then open the interactive window.
for slider in (stime_, sslice_):
    slider.on_changed(update)
plt.show()