In [None]:
import numpy as np
import cmath
import matplotlib.pyplot as plt
from plotly.offline import download_plotlyjs,init_notebook_mode,plot,iplot
# Initialise plotly's offline notebook mode so iplot() figures can render inline.
# NOTE(review): connected=True presumably loads plotly.js from a CDN instead of embedding it — confirm against plotly docs.
init_notebook_mode(connected=True)
import plotly.graph_objs as go

In [None]:
def _fresnel_p_coefficients(theta_i, n1, n2):
    """
    Return (rp, tp, cos_theta_t) for a wave incident at theta_i (radians)
    travelling from refractive index n1 into n2.

    cos_theta_t comes from Snell's law via cmath.sqrt, so beyond the critical
    angle (n1 > n2) it is purely imaginary with a positive imaginary part.
    The same Fresnel formulas then remain valid: |rp| == 1 (with the correct
    reflection phase) and tp describes an evanescent transmitted field, so no
    separate total-internal-reflection branch is needed.
    """
    cos_i = np.cos(theta_i)
    cos_theta_t = cmath.sqrt(1.0 - (n1*np.sin(theta_i)/n2)**2)
    denom = n1*cos_theta_t + n2*cos_i
    rp = (n1*cos_theta_t - n2*cos_i)/denom
    tp = 2.0*n1*cos_i/denom
    return rp, tp, cos_theta_t


def run(angle, n1, n2, frames, interferance=True, graph=True):
    """
    Simulate a plane wave hitting a planar interface and return one field
    snapshot per time step.

    The interface is the line y = 0. The region y < 0 (index n1) shows the
    incident wave, which travels towards +y, plus optionally the reflected
    wave. The region y >= 0 (index n2) shows the transmitted wave. Distances
    are in vacuum wavelengths and time is in wave periods. The grid covers
    x in [0, 2) and y in [-2, 2) with step 0.02.

    Parameters
    ----------
    angle : float
        Angle of incidence in degrees, measured from the interface normal
        (the y axis). Assumed to lie in [0, 90].
    n1, n2 : float
        Refractive indices of the incident and transmitting media.
    frames : int
        Number of snapshots, taken at t = 3*k/frames for k = 0..frames-1
        (i.e. spread over 3 wave periods). A value <= 0 yields an empty list.
    interferance : bool
        If True, the y < 0 region shows incident + reflected waves;
        otherwise it shows the incident wave only.
    graph : bool
        If True, return plotly ``go.Contour`` traces; otherwise return the
        raw real-valued field grids as nested lists (rows follow y,
        columns follow x).

    Returns
    -------
    list
        One entry per frame, either a ``go.Contour`` or a nested list.
    """
    theta_i = np.deg2rad(angle)

    # Time is normalised to wave periods, so omega = 2*pi.
    omega = 2.0*np.pi

    # Distances are normalised to the vacuum wavelength, so k = 2*pi*n.
    k1 = 2.0*np.pi*n1
    k2 = 2.0*np.pi*n2

    rp, tp, cos_theta_t = _fresnel_p_coefficients(theta_i, n1, n2)

    # x is parallel to the interface, y is normal to it.
    k1x = k1*np.sin(theta_i)
    k1y = k1*np.cos(theta_i)
    k2x = k1x               # tangential wavevector is conserved (phase matching)
    k2y = k2*cos_theta_t    # imaginary beyond the critical angle -> decays for y > 0

    t_0 = 0
    t_final = 3
    num_frames = frames
    if num_frames <= 0:
        # Avoid dividing by zero below; there is nothing to render.
        return []
    dt = (t_final - t_0)/num_frames

    x_max = 2.0
    y_min = -2.0
    y_max = 2.0
    x = np.arange(0., x_max, 0.02)
    y = np.arange(y_min, y_max, 0.02)
    X, Y = np.meshgrid(x, y)
    below_interface = Y < 0

    # Spatial phases do not depend on time, so compute them once.
    phase_inc = k1x*X + k1y*Y
    phase_ref = k1x*X - k1y*Y
    phase_trans = k2x*X + k2y*Y

    data = []
    raw = []
    for k in range(num_frames):
        t = t_0 + k*dt
        wt = omega*t

        E_inc = np.exp(1j*(phase_inc - wt)).real
        if interferance:
            E_lower = E_inc + (rp*np.exp(1j*(phase_ref - wt))).real
        else:
            E_lower = E_inc
        E_trans = (tp*np.exp(1j*(phase_trans - wt))).real

        # np.where instead of multiplying by boolean masks: an evanescent term
        # that overflows to inf in y < 0 would otherwise become inf*0 = NaN.
        E = np.where(below_interface, E_lower, E_trans).tolist()
        raw.append(E)
        if graph:
            data.append(go.Contour(z=E, colorscale='Portland', showscale=False,
                                   contours=dict(coloring='heatmap', showlines=False)))

    return data if graph else raw
   

In [None]:
# data = run(angle=45, n1=1.5, n2=1., interferance=False, graph=False)
# layout = dict(width=500, height=500)

In [None]:
# fig=dict(data=[data[0]], layout=layout)
# iplot(fig)

In [None]:
# import json

# with open('heatmaps.JSON', 'w') as outfile: 
#     json.dump(data, outfile)

In [None]:
# Ten snapshots of light passing from n1 = 1.5 into n2 = 1.0 at 40 degrees.
# 40 degrees is just below the critical angle arcsin(1.0 / 1.5) ~ 41.8 degrees,
# so the transmitted wave still propagates into the second medium.
raw = run(
    angle=40,
    n1=1.5,
    n2=1.0,
    frames=10,
    interferance=False,  # lower half shows the incident wave only
    graph=False,         # keep nested lists instead of plotly traces
)

In [None]:
import json

with open('rawHeatData.JSON', 'w') as outfile:
    # Encode all frames to a JSON string first, then write it in a single call.
    outfile.write(json.dumps(raw))