In [None]:
# Notebook to create a FITS file containing the disc density evaluated on the mcfost grid

In [None]:
%matplotlib inline
# import pymcfost as mcfost
import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
from mpl_toolkits.mplot3d import Axes3D

In [None]:
## Pull in the grid structure (run mcfost first with the parameter file and
## "-disk_struct" to get this).
grid_file = 'data_disk/grid.fits'
with fits.open(grid_file) as hdu_list:
    # np.array copies the data into memory so it stays valid after the
    # context manager closes the file.
    grid_data = np.array(hdu_list[0].data)
    grid_header = hdu_list[0].header

# grid_data is indexed [coordinate, azimuth, z, radius]; the density cell reads
# coordinate 0 as the radius, 1 as z and 2 as the azimuth.
_, n_az, n_z, n_rad = grid_data.shape

# Last expression, so the primary header is actually displayed.
grid_header


In [None]:
## Define the density
def sigma(r, r0, sigma0=3.46E-4, index=-1):
    """Power-law surface density Sigma(r) = sigma0 * (r / r0)**index.

    Parameters
    ----------
    r : float or array
        Radius at which to evaluate the surface density.
    r0 : float
        Reference radius, in the same units as ``r``.
    sigma0 : float, optional
        Surface density at ``r = r0``. Per the original author's note, this
        normalisation is "allegedly accounted for by mcfost", i.e. presumably
        rescaled by mcfost downstream. TODO confirm.
    index : float, optional
        Power-law exponent. The default of -1 gives Sigma proportional to 1/r.

    Returns
    -------
    float or array
        ``sigma0 * (r / r0)**index``.
    """
    return sigma0*(r/r0)**index

def scale_height(r, h_over_r=0.2, r_ref=100., flaring=0.25):
    """Flared disc scale height H(r) = h_over_r * r * (r / r_ref)**flaring.

    With the defaults, H/r = 0.2 at r = 100 and H grows as r**1.25, which
    reproduces the original hard-coded model.

    Parameters
    ----------
    r : float or array
        Cylindrical radius, in the grid's length units (presumably au; TODO
        confirm). Must be non-negative for real output.
    h_over_r : float, optional
        Aspect ratio H/r at the reference radius.
    r_ref : float, optional
        Reference radius at which H/r equals ``h_over_r``.
    flaring : float, optional
        Flaring index, i.e. the power of (r / r_ref) multiplying h_over_r * r.

    Returns
    -------
    float or array
        Scale height at ``r``, in the same units as ``r``.
    """
    return h_over_r*r*((r/r_ref)**flaring)

def density_function(z, r, azimuth, r_in=1.0, r_out=100., rho0=1.E-15,
                     rho_floor=1.E-20):
    """Density at cylindrical position (r, z) of a truncated Gaussian disc.

    Inside the open annulus r_in < r < r_out, and within one scale height of
    the midplane (|z| < scale_height(r)), the density is
    ``rho0 * exp(-z**2 / (2 * scale_height(r)**2))``. Everywhere else it is the
    constant floor ``rho_floor``. The defaults reproduce the original
    hard-coded model.

    Parameters
    ----------
    z : float
        Height above the midplane, in the grid's length units.
    r : float
        Cylindrical radius, in the grid's length units.
    azimuth : float
        Accepted for interface compatibility but unused, because the model is
        axisymmetric.
    r_in, r_out : float, optional
        Inner and outer radii of the disc (both excluded).
    rho0 : float, optional
        Midplane density inside the disc.
    rho_floor : float, optional
        Background density returned outside the disc. The evaluation loop
        treats values above 1e-19 as "inside the disc", so keep the floor
        below that.

    Returns
    -------
    float
        The model density at (r, z).
    """
    if not (r_in < r < r_out):
        return rho_floor

    # Computed once and reused for both the truncation test and the Gaussian.
    h = scale_height(r)
    if abs(z) < abs(h):
        # NOTE(review): an alternative normalisation tried here was
        # sigma(r, 1)/(2*np.pi*scale_height(r)). A Gaussian whose column equals
        # sigma would use sqrt(2*pi)*H rather than 2*pi*H. TODO confirm intent.
        return rho0*np.exp(-z**2/(2*(h**2)))
    return rho_floor

# Evaluate the model in every grid cell. The order is radius-major: radius
# outermost, then z, then azimuth innermost. The full cube goes into `density`.
# Every cell whose density exceeds 1e-19 (i.e. is above the model's floor) is
# also recorded as a Cartesian point in the *_store arrays. Only their first
# `mm` entries are meaningful afterwards.
x_store, y_store, z_store, density_store = (
    np.zeros(n_az*n_z*n_rad) for _ in range(4)
)
density = np.zeros((1, n_az, n_z, n_rad))
mm = 0

for i_rad, i_z, i_az in np.ndindex(n_rad, n_z, n_az):
    cell_r = grid_data[0, i_az, i_z, i_rad]
    cell_z = grid_data[1, i_az, i_z, i_rad]
    cell_phi = grid_data[2, i_az, i_z, i_rad]

    density[0, i_az, i_z, i_rad] = density_function(cell_z, cell_r, cell_phi)
    cell_rho = density[0, i_az, i_z, i_rad]
    if not cell_rho > 1.E-19:
        continue

    x_store[mm] = cell_r*np.cos(cell_phi)
    y_store[mm] = cell_r*np.sin(cell_phi)
    z_store[mm] = cell_z
    density_store[mm] = cell_rho
    mm += 1

## Now save the density cube as the primary HDU of a FITS file.
## overwrite=True lets the notebook be re-run: without it, astropy raises an
## error because the file written by a previous run already exists.
hdu = fits.PrimaryHDU(density)
hdu.writeto('density_file.fits', overwrite=True)
hdu.header

In [None]:
## Check what's been made by plotting.
## NOTE FOR SAHL: this cell offers several alternative views (chosen with
## plot_type) that grew out of experimentation. The only goal is a quick
## visual check that the density model looks as intended, so it could
## certainly be made more efficient and accurate.

## Choose the view with plot_type:
##   '3d'      - 3D scatter of every other in-disc cell, coloured by density
##   'xsec'    - cells in the thin slab |y| < 1 (the x-z plane), with the
##               curves z = +/- scale_height(|x|) overlaid in red
##   'ysec'    - cells in the thin slab |x| < 1 (the y-z plane)
##   'surface' - analytic sketch of the z = +/- scale_height(r) surfaces for
##               10 <= r <= 100, plus a translucent wall at r = 100
plot_type = 'ysec'
if plot_type not in ('3d', 'xsec', 'ysec', 'surface'):
    raise ValueError('unknown plot_type: %r' % (plot_type,))

# Only the first mm entries of the *_store arrays were filled by the density cell.
x_in, y_in = x_store[0:mm], y_store[0:mm]
z_in, rho_in = z_store[0:mm], density_store[0:mm]


def _plot_cross_section(ax, horiz, across, z, rho, horiz_label,
                        show_envelope=False, slab_half_width=1.0):
    """Scatter the points with |across| < slab_half_width in the (horiz, z) plane.

    Points are coloured by ``rho``. When ``show_envelope`` is true and the
    slab is not empty, the curves z = +/- scale_height(|horiz|) are overlaid
    in red. Returns the scatter artist so the caller can attach a colorbar.
    """
    in_slab = np.abs(across) < slab_half_width
    horiz_cross = horiz[in_slab]
    img = ax.scatter(horiz_cross, z[in_slab], c=rho[in_slab])
    # np.max raises on an empty array, so skip the envelope for an empty slab.
    if show_envelope and horiz_cross.size > 0:
        extent = np.max(np.abs(horiz_cross))
        h_axis = np.linspace(-extent, extent, 100)
        height = scale_height(np.abs(h_axis))
        ax.plot(h_axis, height, color='red')
        ax.plot(h_axis, -height, color='red')
    ax.set_xlabel(horiz_label)
    ax.set_ylabel('z')
    return img


fig = plt.figure()

if plot_type == '3d':
    ax = fig.add_subplot(111, projection='3d')
    # Every other point only, to keep the scatter light.
    img = ax.scatter(x_in[1::2], y_in[1::2], z_in[1::2], c=rho_in[1::2])
    fig.colorbar(img, label='density')
elif plot_type == 'xsec':
    ax = fig.add_subplot(111)
    img = _plot_cross_section(ax, x_in, y_in, z_in, rho_in, 'x',
                              show_envelope=True)
    fig.colorbar(img, label='density')
elif plot_type == 'ysec':
    ax = fig.add_subplot(111)
    img = _plot_cross_section(ax, y_in, x_in, z_in, rho_in, 'y')
    fig.colorbar(img, label='density')
    ax.set_xlim([-200, 200])
    ax.set_ylim([-50, 50])
elif plot_type == 'surface':
    ax = fig.add_subplot(111, projection='3d')
    # z = +/- scale_height(r) surfaces for 10 <= r <= 100.
    r_mesh, phi_mesh = np.meshgrid(np.linspace(10, 100, 100),
                                   np.linspace(0, 2*np.pi, 100))
    x_mesh = r_mesh*np.cos(phi_mesh)
    y_mesh = r_mesh*np.sin(phi_mesh)
    h_mesh = scale_height(r_mesh)
    ax.plot_surface(x_mesh, y_mesh, h_mesh, color='olive')
    ax.plot_surface(x_mesh, y_mesh, -h_mesh, color='olive')

    # Translucent cylindrical wall at r = 100 spanning -20 <= z <= 20.
    phi_wall, z_wall = np.meshgrid(np.linspace(0, 2*np.pi, 100),
                                   np.linspace(-20, 20, 100))
    # alpha must be numeric; matplotlib rejects the string '0.5'.
    ax.plot_surface(100.*np.cos(phi_wall), 100.*np.sin(phi_wall), z_wall,
                    color='olive', alpha=0.5)

    ax.set_xlim(-110, 110)
    ax.set_ylim(-110, 110)
    ax.set_zlim(-110, 110)  # symmetric, matching the x/y limits

plt.show()