In [None]:
import numpy as np
import matplotlib.pyplot as plt
from plotly.offline import download_plotlyjs,init_notebook_mode,plot,iplot
init_notebook_mode(connected=True)
import plotly.graph_objs as go

In [None]:
class Boundary:
    """
    A plane wave meeting a flat boundary at x = 0 between two media.

    Material 1 (index ``n1``) fills -1 <= x <= 0 and material 2 (index ``n2``)
    fills 0 <= x <= 1. Each wave is stored as a list ``[x, y, frames]`` where
    ``frames`` is an array of shape ``(n_frames, len(y), len(x))`` holding
    snapshots of the field over one full oscillation period. Only
    s-polarisation is currently selected (``self.polarisation = "s"``).
    """

    def __init__(self, angle, n1, n2, interference=True):
        """
        Build the incident, transmitted and (optionally) reflected waves.

        Args:
            angle (float) - [degs] {0 to 90} angle of incidence, measured from
                the boundary normal
            n1 (float) - The incident material's refractive index
            n2 (float) - The second material's refractive index
            interference (bool) - If True (and n1 != n2) the reflected wave is
                computed and added onto ``self.incident``, so material 1 shows
                the superposition of incident and reflected fields
        """
        self.theta = np.deg2rad(angle)
        self.n1 = n1
        self.n2 = n2
        self.polarisation = "s"
        self.theta_crit = self.find_critical_angle()

        self.incident = self.create_wave(theta=self.theta, amplitude=1, material=1)
        # None under total internal reflection (see transmit)
        self.transmitted = self.transmit()
        # Always defined, even when no reflected wave is computed
        self.reflected = None
        if interference and self.n1 != self.n2:
            self.reflected = self.reflect()
            self.incident[2] += self.reflected[2]
        # TODO(review): under total internal reflection r is forced to magnitude 1 with
        # no phase shift, and no evanescent field is drawn in material 2.

    def create_wave(self, theta, amplitude, material, graph_dim=200, n_frames=20):
        """
        Sample a travelling plane wave on one half of the plotting grid.

        Args:
            theta (float) - [rads] direction of propagation, measured from the
                +x axis (the boundary normal) rotating towards -y
            amplitude (float) - peak field amplitude
            material (int) - {1 or 2} Whether the wave is in material 1 or 2
            graph_dim (int) - number of grid points along y; each x half-domain
                gets graph_dim // 2 points
            n_frames (int) - number of time snapshots across one period

        Returns:
            list - ``[x, y, frames]`` with
            ``frames.shape == (n_frames, graph_dim, graph_dim // 2)``

        Raises:
            ValueError - if ``material`` is not 1 or 2
        """
        # np.linspace needs an integer sample count; graph_dim / 2 would be a float.
        half_dim = graph_dim // 2
        if material == 1:
            x = np.linspace(-1, 0, half_dim)
            n = self.n1
        elif material == 2:
            x = np.linspace(0, 1, half_dim)
            n = self.n2
        else:
            raise ValueError("arg 'material' must be 1 or 2")
        y = np.linspace(-1, 1, graph_dim)

        xx, yy = np.meshgrid(x, y)

        # Wave vector direction scaled by the medium's refractive index
        k_x = np.cos(theta) * n
        k_y = -np.sin(theta) * n

        # 8π gives a vacuum wavelength of 0.25 grid units; the phase runs 0 -> 0.25,
        # so the time term sweeps exactly 2π (first and last frames coincide).
        return [x, y, np.array([amplitude * np.cos(8*np.pi * (k_x*xx + k_y*yy - phase))
                                for phase in np.linspace(0, 0.25, n_frames)])]

    def transmit(self):
        """
        Build the transmitted wave in material 2.

        Returns:
            list or None - ``[x, y, frames]`` as from ``create_wave``, or None
            (after printing a message) under total internal reflection
        """
        theta_i = self.theta
        theta_t = self.snell(self.n1, self.n2, theta_i)
        if np.isnan(theta_t):
            print('Total internal reflection')
            return None

        cos_i = np.cos(theta_i)
        cos_t = np.cos(theta_t)
        # Fresnel amplitude transmission coefficient
        if self.polarisation == "s":
            t = (2. * self.n1 * cos_i) / (self.n1 * cos_i + self.n2 * cos_t)
        else:
            t = (2. * self.n1 * cos_i) / (self.n1 * cos_t + self.n2 * cos_i)

        # The transmitted wave continues away from the boundary (+x) at theta_t,
        # keeping the same tangential phase as the incident wave along x = 0.
        return self.create_wave(theta=theta_t, amplitude=t, material=2)

    def reflect(self):
        """
        Build the reflected wave in material 1.

        Returns:
            list or None - ``[x, y, frames]`` as from ``create_wave``, or None
            (after printing a message) when n1 == n2. Under total internal
            reflection theta_t is clamped to π/2, giving |r| = 1.
        """
        if self.n1 == self.n2:
            print('Refractive indices equal - no reflection')
            return None

        theta_i = self.theta
        theta_t = self.snell(self.n1, self.n2, theta_i)
        if np.isnan(theta_t):
            theta_t = 0.5 * np.pi

        # Reversing the normal (x) component of the wave vector while keeping the
        # tangential (y) component sends the wave back away from the boundary and
        # keeps it in phase with the incident wave along x = 0.
        plot_theta_r = np.pi - theta_i

        cos_i = np.cos(theta_i)
        cos_t = np.cos(theta_t)
        # Fresnel amplitude reflection coefficient
        if self.polarisation == "s":
            r = (self.n1 * cos_i - self.n2 * cos_t) / (self.n1 * cos_i + self.n2 * cos_t)
        else:
            r = (self.n1 * cos_t - self.n2 * cos_i) / (self.n1 * cos_t + self.n2 * cos_i)

        return self.create_wave(theta=plot_theta_r, amplitude=r, material=1)

    def find_critical_angle(self):
        """
        Critical angle for total internal reflection, measured from the normal.

        Returns:
            float - [rads] arcsin(n2 / n1) when n1 > n2; otherwise a value just
            above π/2, because no incidence angle can reach it
        """
        if self.n1 > self.n2:
            # Same angle convention as snell(): TIR starts where (n1/n2) sin θ = 1
            return np.arcsin(self.n2 / self.n1)
        else:
            return 0.5*np.pi + 0.0001

    @staticmethod
    def snell(n1, n2, theta_i):
        """
        Finds angle of transmission using Snell's law

        Args:
            n1 (float) - incident medium's refractive index
            n2 (float) - transmissive medium's refractive index
            theta_i (float) - angle of incidence

        Returns:
            float - [rads] transmission angle; NaN when no real solution exists
            (total internal reflection)
        """
        return np.arcsin((n1 / n2) * np.sin(theta_i))

    def plot(self, frame=0):
        """
        Draw one time frame of the field in both materials on a shared colour scale.

        Args:
            frame (int) - index of the time snapshot to draw (default: the first)
        """
        plt.figure(figsize=(10, 10))

        waves = [self.incident]
        # No transmitted wave exists under total internal reflection
        if self.transmitted is not None:
            waves.append(self.transmitted)
        for x, y, frames in waves:
            # A fixed vmin/vmax gives both meshes the same colour mapping; without
            # vmin each mesh auto-scaled its lower limit independently.
            plt.pcolormesh(x, y, frames[frame], vmin=-1, vmax=1)

        plt.xlabel("x (boundary at x = 0)")
        plt.ylabel("y")
        plt.show()
    

In [None]:
# Static snapshot: 60 degree incidence from n = 1.0 into n = 1.2, with reflection superposed
boundary = Boundary(n1=1., n2=1.2, angle=60, interference=True)
boundary.plot()

In [None]:
def run(angle, n1, n2, frames, interferance=True, graph=True):
    """
    Simulate a plane wave crossing a flat boundary at y = 0, frame by frame.

    Medium 1 (index ``n1``) occupies y < 0 and medium 2 (index ``n2``) occupies
    y >= 0, on the grid x in [0, 2), y in [-2, 2) with step 0.02 (distances in
    vacuum wavelengths). Frames are sampled evenly over three wave periods.
    Reflection/transmission use the p-polarisation Fresnel amplitude formulas.
    Beyond the critical angle the transmitted wave is evanescent (it decays
    exponentially with y) rather than undefined.

    Args:
        angle (float) - [degs] angle of incidence measured from the boundary normal
        n1 (float) - refractive index of the incident medium
        n2 (float) - refractive index of the transmitting medium
        frames (int) - number of time steps to generate (must be >= 1)
        interferance (bool) - if True, the reflected wave is superposed on the
            incident wave below the boundary; otherwise only the incident wave is shown
        graph (bool) - if True return plotly ``go.Contour`` traces, otherwise raw
            nested lists of field values

    Returns:
        list - one entry per frame: a plotly Contour trace (``graph=True``) or a
        2-D nested list indexed ``[y][x]`` (``graph=False``)

    Raises:
        ValueError - if ``frames`` is less than 1
    """
    if frames < 1:
        raise ValueError("frames must be at least 1")

    theta_i = np.deg2rad(angle)

    # Time normalised in wave periods
    omega = 2.0 * np.pi
    # Distances normalised by the vacuum wavelength of the wave
    k1 = 2.0 * np.pi * n1
    k2 = 2.0 * np.pi * n2

    # Snell's law, evaluated in complex arithmetic: past the critical angle the
    # radicand is negative and the principal square root is +i*sqrt(|.|), so the
    # transmitted wave decays for y > 0 instead of becoming NaN.
    cos_theta_i = np.cos(theta_i)
    cos_theta_t = np.sqrt(1.0 + 0j - (n1 * np.sin(theta_i) / n2) ** 2)

    # x is parallel to the interface, y is normal to it
    k1x = k1 * np.sin(theta_i)
    k1y = k1 * cos_theta_i
    k2x = k1x  # tangential component is conserved across the boundary
    k2y = k2 * cos_theta_t

    # Fresnel amplitude coefficients; under TIR they are complex with |rp| == 1,
    # which carries the reflection phase shift.
    denom = n1 * cos_theta_t + n2 * cos_theta_i
    rp = (n1 * cos_theta_t - n2 * cos_theta_i) / denom
    tp = 2.0 * n1 * cos_theta_i / denom

    # Three wave periods split into `frames` evenly spaced instants
    t_0, t_final = 0.0, 3.0
    interval = (t_final - t_0) / frames

    x = np.arange(0., 2.0, 0.02)
    y = np.arange(-2.0, 2.0, 0.02)
    X, Y = np.meshgrid(x, y)
    in_medium_2 = Y >= 0

    # Spatial parts of each wave as complex phasors; physical field = Re[phasor * e^{-iωt}].
    incident = np.exp(1j * (k1x * X + k1y * Y))
    reflected = rp * np.exp(1j * (k1x * X - k1y * Y))
    # Evaluate the transmitted wave with y clamped to 0 outside medium 2, so an
    # evanescent field cannot overflow in the region where it is discarded anyway.
    transmitted = tp * np.exp(1j * (k2x * X + k2y * np.where(in_medium_2, Y, 0.0)))
    below = incident + reflected if interferance else incident

    results = []
    for k in range(frames):
        t = t_0 + k * interval
        time_phase = np.exp(-1j * omega * t)
        field = np.where(in_medium_2,
                         np.real(transmitted * time_phase),
                         np.real(below * time_phase))
        if graph:
            results.append(go.Contour(z=field.tolist(), colorscale='Viridis', showscale=False,
                                      contours=dict(coloring='heatmap', showlines=False)))
        else:
            results.append(field.tolist())

    return results
   

In [None]:
# Ten animation frames: 45 degree incidence from n = 1.0 into n = 1.5
data = run(angle=45, n1=1., n2=1.5, frames=10, interferance=True, graph=True)
layout = {"width": 500, "height": 500}

In [None]:
# Render only the first simulated frame
first_frame = data[0]
fig = {"data": [first_frame], "layout": layout}
iplot(fig)

In [None]:
# import json

# with open('heatmaps.JSON', 'w') as outfile: 
#     json.dump(data, outfile)

In [None]:
# raw = run(angle=40, n1=1.5, n2=1., frames=10, interferance=False, graph=False)

In [None]:
# import json

# with open('rawHeatData.JSON', 'w') as outfile: 
#     json.dump(raw, outfile)