In [26]:
import os

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from medpy.io import load

from tumor_mass_effect import semi_implicit_solver
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Simulation parameters

In [27]:
# Solver and tissue parameters, passed (mostly) positionally to
# semi_implicit_solver below. Units are not stated in this notebook -- TODO
# confirm against tumor_mass_effect.
epsilon = 0.8
# Directory holding the WM/GM/CSF atlas maps. Override it with the
# TUMOR_ANATOMY_DIR environment variable instead of editing the notebook.
# It always ends with a path separator because file paths are built from it
# by plain string concatenation.
anatomy_folder = os.path.join(
    os.environ.get('TUMOR_ANATOMY_DIR', '/media/Drives/Data/leon_tumor/Atlas/anatomy/'), '')
Dg = 1.52e-07  # NOTE(review): defined but never passed to the solver -- confirm intended
Dw = 1.52e-06
dt = 6.60154
rho = 0.1553
dx = 0.002
# lamda, mu and gamma are all zero here (presumably switching off the
# mechanical / mass-effect terms -- confirm). Previously used values:
#   lamda: [3.0e+03, 3.0e+03, 1.136e+01] or 2.48e+04
#   mu:    [750., 750., 45.45] or 2758.62
#   gamma: 0.8
lamda = 0.0
mu = 0.0
gamma = 0.0
MaxIter = 50

In [28]:
# medpy's load() returns (voxel data, header); only the voxel data is kept.
# The header is deliberately NOT unpacked into `_`. Assigning `_` in an
# IPython session stops IPython from updating `_` with the last cell output.
wm = load(anatomy_folder + 'WM.nii.gz')[0]
gm = load(anatomy_folder + 'GM.nii.gz')[0]
csf = load(anatomy_folder + 'CSF.nii.gz')[0]
# The three maps are stacked as channels later, which needs identical grids.
assert wm.shape == gm.shape == csf.shape, (wm.shape, gm.shape, csf.shape)

## Variable initialization

In [35]:
# F.pad pads from the LAST dimension backwards:
# (last_lo, last_hi, mid_lo, mid_hi, first_lo, first_hi).
# Presumably this brings the atlas to a cube before halving -- confirm.
PAD = (32, 31, 14, 13, 32, 31)
# Initial tumor seed voxel, in the downsampled grid.
SEED_VOXEL = (54, 64, 64)

# Stack WM, GM and CSF as the 3 channels of a (1, 3, X, Y, Z) tensor on the GPU.
# Cast to float32 BEFORE interpolating, for two reasons:
#   * trilinear interpolation needs a floating dtype;
#   * m then matches u, v and c below (torch.zeros defaults to float32).
# The cast is a no-op when the atlas is already float32.
m = torch.tensor(np.stack([wm, gm, csf], 0)).float().cuda().unsqueeze(0)
m = F.interpolate(F.pad(m, PAD), scale_factor=0.5, mode='trilinear', align_corners=True)

# State fields start at zero on m's device/dtype. Plain tensors already have
# requires_grad=False, so the deprecated Variable wrapper is unnecessary.
u = torch.zeros(m.shape, device=m.device, dtype=m.dtype)
v = torch.zeros(m.shape, device=m.device, dtype=m.dtype)
# Single-channel concentration field: (batch, 1, X, Y, Z).
c = torch.zeros((m.shape[0], 1) + tuple(m.shape[2:]), device=m.device, dtype=m.dtype)
c[..., SEED_VOXEL[0], SEED_VOXEL[1], SEED_VOXEL[2]] = 1.0



## Simulator

In [36]:
# Positional arguments for the semi-implicit tumor solver, in its call order.
# NOTE(review): Dg (defined with the other parameters) is not among them -- confirm intended.
solver_args = (Dw, rho, dx, dt, lamda, mu, gamma, MaxIter, epsilon)
tomor_solver = semi_implicit_solver(*solver_args)

## Run simulation

In [37]:
# Intended for singling out specific data; not used elsewhere in this notebook.
stopit = 4
Tmax = 10  # number of solver steps to run


def _to_numpy(t):
    """Return a detached, host-side numpy view/copy of tensor ``t``."""
    # .detach() already yields a tensor outside the autograd graph, so the
    # extra .data the original chained on was redundant.
    return t.detach().cpu().numpy()


# Brain mask from the initial tissue maps; reused unchanged for every step.
phi_brain = torch.sum(m, 1, keepdim=True)
# Snapshots of the state after each step (index 0 = after the first step).
c_list = []
m_list = []
u_list = []
# NOTE: c, m, u and v are rebound every step. Re-running this cell without
# re-running the initialization cell continues from the previous final state.
# NOTE(review): if the solver holds tensors with requires_grad=True, wrapping
# this loop in torch.no_grad() would avoid keeping an autograd graph across
# steps -- confirm before adding.
for t in range(Tmax):
    c, m, u, v = tomor_solver.solver_step(c, m, u, v, phi_brain)
    c_list.append(_to_numpy(c))
    m_list.append(_to_numpy(m))
    u_list.append(_to_numpy(u))

## Plot tumor

In [38]:
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons
%matplotlib tk
from ipywidgets import *
fig, ax = plt.subplots(dpi=150)
plt.subplots_adjust(left=0.25, bottom=0.25)
l = plt.imshow(c_list[0][0,0,:,:,64])
ax.margins(x=0)

axcolor = 'lightgoldenrodyellow'
axfreq = plt.axes([0.25, 0.1, 0.65, 0.03], facecolor=axcolor)
axamp = plt.axes([0.25, 0.15, 0.65, 0.03], facecolor=axcolor)

stime_ = Slider(axfreq, 'Time', valmin=1, valmax=len(c_list), valinit=5, valstep=1)
sslice_ = Slider(axamp, 'Slice', valmin=1, valmax=c_list[0].shape[4], valinit=5, valstep=1)

def update(val):
    time_ = stime_.val
    slice_ = sslice_.val
    l.set_data(c_list[time_-1][0,0,:,:,slice_-1])
    fig.canvas.draw_idle()


stime_.on_changed(update)
sslice_.on_changed(update)

plt.show()