In [None]:
import math
import multiprocessing
import os
import shutil
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch as t
import yaml
from astropy import constants as c
from astropy import units as u
from astropy.table import Table
from scipy.interpolate import interp1d,interp2d
from torch import nn
from torch.autograd import Variable
from torch.optim import Adam

In [None]:
# Physical constants in CGS units, taken from astropy so every later cell shares one source.
sigmaT=c.sigma_T.to('cm^2').value #Thomson cross section in cm^2
cli=c.c.to('cm/s').value #speed of light in cm/s
kbc=c.k_B.to('erg/K').value #Boltzmann constant in erg/K
hli=c.h.to('erg*s').value #Planck constant in erg*s
mele=c.m_e.to('g').value #electron mass in g
mpro=c.m_p.to('g').value #proton mass in g
stfBlz=c.sigma_sb.to('erg/cm^2/K^4/s').value #Stefan-Boltzmann constant in erg/cm^2/K^4/s
secondsInADay=(1*u.day).to('s').value
# 1/2*(h^2/(2*pi*k_B*m_e))^(3/2) in CGS; not referenced by the other cells of this notebook.
v2d07e16=1/2*(hli**2/2/np.pi/kbc/mele)**(3/2)
# NOTE(review): Zelem and croSecConst are not referenced by the other cells shown here, and the
# meaning/units of croSecConst are not documented - confirm before relying on them.
Zelem=1
croSecConst=2.815e29
oneEv=(1*u.eV).to('erg').value #1 eV in erg
alphaRefine=float(c.alpha) #fine-structure constant (dimensionless)

In [None]:
# Load the run configuration (training schedule, boundary conditions, ejecta model) from YAML.
yamlName='GammaModel1/runner.yml'
with open(yamlName,'r') as reader:
    howToRun=yaml.safe_load(reader)
ModelSaver=howToRun['files']['save_dir']  # later cells append sub-paths with '+', so this should end in '/'
learnRateSchedule=np.array(howToRun['train']['rate_list'])
pdeWeight=float(howToRun['train']['pde_weight'])
bo1Weight=float(howToRun['train']['boun1_weight'])
bo2Weight=float(howToRun['train']['boun2_weight'])
negaLim=float(howToRun['train']['negative_limit'])
engMax=(float(howToRun['boundary']['energy_max'])*u.MeV).to('erg').value  # MeV -> erg
gammaSample=int(howToRun['boundary']['gamma_sample'])
innerPhoton=float(howToRun['boundary']['inner_phot_rate'])
photonMax=float(howToRun['boundary']['photon_max'])
veloMin=float(howToRun['material']['velocity']['min'])*cli  # given as a fraction of c, stored in cm/s
veloMax=float(howToRun['material']['velocity']['max'])*cli
expTime=float(howToRun['material']['density']['exp_time'])*secondsInADay  # days -> s

# Density model: a tabulated profile (velocity in km/s, density) or a power law in radius.
densDataType=howToRun['material']['density']['data_type']
if densDataType=='file':
    # NOTE(review): densT0 is used without unit conversion while expTime is in seconds, so 't_0'
    # must already be in seconds for the (t_0/expTime)**3 rescaling below to be consistent - confirm.
    densT0=float(howToRun['material']['density']['t_0'])
    densFile=howToRun['material']['density']['file_name']
    densData=np.loadtxt(densFile)
    veloData=densData[:,0]*100000#Convert from km/s to cm/s. 
    densData=densData[:,1]*(densT0/expTime)**3  # rescale the tabulated density from t_0 to expTime
    if veloMax>=veloData.max():
        # Clamp to the table so later interpolation stays inside the tabulated range.
        print('Gesa Warning: The upper boundary is too high. ')
        veloMax=veloData.max()
    if veloMin<=veloData.min():
        # Raise explicitly: a bare `assert False` disappears under `python -O`.
        raise ValueError('Gesa Error: The lower boundary is too low. ')
    radiData=veloData*expTime
elif densDataType=='power_law':
    dSlope=float(howToRun['material']['density']['slope'])
    N0=float(howToRun['material']['density']['pivot_dens'])
    V0=float(howToRun['material']['density']['pivot_velo'])*cli
    R0=V0*expTime
else:
    # Without this, densCalc is never defined and the failure surfaces much later as a NameError.
    raise ValueError('Gesa Error: Unknown density data_type '+repr(densDataType))

# Elemental composition: per-velocity tabulated abundances, or one fixed ratio vector everywhere.
elemDataType=howToRun['material']['element']['data_type']
if elemDataType=='file':
    elemFile=howToRun['material']['element']['file_name']
    elemData=np.loadtxt(elemFile)
    veloForElem=elemData[:,0]*100000#Convert from km/s to cm/s. 
    radiForElem=veloForElem*expTime
    elemData=elemData[:,1:]  # one column per element, atomic number 1..maxElem
    maxElem=elemData.shape[1]
    elemSum=elemData.sum(axis=0)
elif elemDataType=='one_zone':
    elemList=t.tensor(np.array(howToRun['material']['element']['number_ratio'])).cuda()
    maxElem=len(elemList)
    elemSum=elemList
else:
    raise ValueError('Gesa Error: Unknown element data_type '+repr(elemDataType))

# Internal gamma-ray source: tabulated per velocity, or none.
gammaDataType=howToRun['material']['gamma_source']['data_type']
if gammaDataType=='file':
    gammaFile=howToRun['material']['gamma_source']['file_name']
    gammaData=np.loadtxt(gammaFile)
    veloForGamma=gammaData[:,0]*100000#Convert from km/s to cm/s. 
    radiForGamma=veloForGamma*expTime
    gammaData=gammaData[:,1]/photonMax  # normalised by photonMax, like the inner boundary value
elif gammaDataType=='zero':
    print('Maybe There is Something I Did Not Write. ')
else:
    raise ValueError('Gesa Error: Unknown gamma_source data_type '+repr(gammaDataType))
zmin=expTime*veloMin  # inner radius in cm
zmax=expTime*veloMax  # outer radius in cm (after any clamping above)
elemZList=t.tensor(np.arange(1,maxElem+1)).cuda()  # atomic numbers 1..maxElem

In [None]:
# Photon energy grid in units of m_e c^2. Bin 0 is engMax; each further bin is the previous energy
# after one Compton scattering, x/(1+x(1-cos phi)), averaged over solid angle.
engxList=[]
xIn=engMax/(mele*cli**2)
engxList.append(xIn)
# The angular quadrature grid does not depend on the energy, so build it (and its cos/sin) once
# instead of regenerating ~2.3 million points on every iteration.
phiBetween=np.linspace(0,np.pi,num=2333333)
oneMinusCos=1-np.cos(phiBetween)
sinPhi=np.sin(phiBetween)
for _ in range(gammaSample-1):
    xIn=engxList[-1]
    xOutList=xIn/(1+xIn*oneMinusCos)
    # Riemann sum of xOut*sin(phi) dphi, times 2*pi for azimuth, divided by 4*pi.
    xOut=np.sum(xOutList*sinPhi)*2*np.pi*np.pi/len(phiBetween)/4/np.pi
    engxList.append(xOut)
engxList=t.tensor(np.array(engxList)).cuda()

In [None]:
# Start from a clean output directory for this run.
# shutil replaces the previous os.popen('rm -r ...') / os.popen('cp ...'): popen returns before
# the child process finishes (the old code raced it with a 0.1 s sleep before os.mkdir), the
# path was spliced unquoted into a shell command, and the pipe handles were never closed.
if os.path.exists(ModelSaver):
    if os.path.isdir(ModelSaver) and not os.path.islink(ModelSaver):
        shutil.rmtree(ModelSaver)
    else:
        os.remove(ModelSaver)
    print('Gesa Warning: Wiping out an old PINN model. ')
os.mkdir(ModelSaver)
os.mkdir(ModelSaver+'StepModels/')
os.mkdir(ModelSaver+'Metrics/')
os.mkdir(ModelSaver+'StepEnergy/')
# Keep a copy of the configuration next to the outputs it produced.
shutil.copy(yamlName,ModelSaver+'HyperData.yml')

In [None]:
# Sizes used by the training loop.
(threeDimStep,  # random phi samples per radius for the angular scattering integral
 collBatch,     # PDE collocation points per mini-batch
 bounBatch,     # boundary points per mini-batch
 batchNum,      # mini-batches per epoch
 epochNum,      # epochs per learning-rate stage
)=(200,2000,2000,200,10)

In [None]:
class IntenW(nn.Module):
    """Fully connected tanh network mapping (z, phi) to one intensity value per energy bin.

    Layer attributes are named fc1..fc10 (so saved state_dicts keep the same keys). forward()
    caches each layer's pre-activation as out1..out10, and partialInput() reuses those cached
    values to assemble the Jacobian with respect to the input analytically.
    """
    def __init__(self):
        super().__init__()
        # 2 inputs -> seven 128-wide layers -> two 256-wide layers -> gammaSample outputs.
        widths=[2]+[128]*7+[256]*2+[gammaSample]
        for idx in range(1,len(widths)):
            setattr(self,'fc'+str(idx),nn.Linear(widths[idx-1],widths[idx]))

    def forward(self,Xin):
        self.Xin=Xin
        # Constant offsets added after tanh on the inputs of fc2 and fc3 (no effect on derivatives).
        shifts={1:1,2:-1}
        # The sampling cells draw z in [0,1] and phi in [0,pi]; subtract the midpoint of that box.
        hidden=Xin-t.tensor([0.5,np.pi/2]).cuda()
        for idx in range(1,11):
            pre=getattr(self,'fc'+str(idx))(hidden)
            setattr(self,'out'+str(idx),pre)
            if idx<10:
                hidden=t.tanh(pre)
                if idx in shifts:
                    hidden=hidden+shifts[idx]
        return pre

    def partialInput(self):
        """Jacobian d(out10)/d(Xin) for the most recent forward() call.

        For a 2-D input batch the result has shape (batch, gammaSample, 2). It applies the chain
        rule layer by layer, using d tanh(a)/da = cosh(a)**-2 on the cached pre-activations.
        """
        jac=self.fc1.weight
        for idx in range(1,10):
            if idx>1:
                jac=getattr(self,'fc'+str(idx)).weight@jac
            sech2=t.cosh(t.unsqueeze(getattr(self,'out'+str(idx)),2))**-2
            jac=sech2*jac
        return self.fc10.weight@jac
IntenWNet=IntenW().double().cuda()

In [None]:
def integGridAssigner(randomGridNumber):
    """Draw randomGridNumber polar angles uniformly from [0, pi) as a float64 CUDA tensor."""
    angles=np.pi*np.random.random(randomGridNumber)
    return t.tensor(angles).cuda()
def gammaVeloMaker(zList):
    """Map the normalised coordinate zList linearly onto velocity: 0 -> veloMin, 1 -> veloMax (cm/s)."""
    span=veloMax-veloMin
    return span*zList+veloMin
def KCroSecAdj(x):
    """Klein-Nishina total cross section in units of sigma_T, for photon energy x = E/(m_e c^2).

    x is a tensor (torch.log is used); evaluated elementwise.
    """
    onePlusX=1+x
    onePlusTwoX=1+2*x
    logTerm=t.log(onePlusTwoX)
    firstTerm=onePlusX/x**3*(2*x*onePlusX/onePlusTwoX-logTerm)
    return 3/4*(firstTerm+1/2/x*logTerm-(1+3*x)/onePlusTwoX**2)
# densCalc(radiList): density at radii radiList (cm) for the configured density model.
if densDataType=='power_law':
    def densCalc(radiList):
        """Power-law density N0*(r/R0)**dSlope at radii radiList (cm)."""
        scaled=radiList/R0
        return N0*scaled**dSlope
elif densDataType=='file':
    # Linear interpolation of the tabulated profile (interp1d raises outside the tabulated radii).
    densCalc=interp1d(radiData,densData)
# elemDensCalc(zList): per-element densities at normalised radii zList, returned as a CUDA tensor of
# shape (len(zList), maxElem).
if elemDataType=='file':
    # NOTE(review): scipy's interp2d is deprecated (SciPy 1.10) and removed in SciPy 1.14.
    elemInterp=interp2d(radiForElem,np.arange(1,maxElem+1),elemData.T)
    def elemDensCalc(zList):
        """Tabulated abundances times densCalc at each radius, in the caller's point order.

        interp2d's __call__ sorts its x argument and returns values on that sorted grid. The radii
        are therefore passed in ascending order and the result columns are permuted back. Before
        this fix, unsorted collocation points received abundances that belonged to other radii,
        while densList stayed in the original order.
        """
        radiList=zList.clone().detach().cpu().numpy()*(zmax-zmin)+zmin
        densList=densCalc(radiList)
        order=np.argsort(radiList,kind='mergesort')
        elemSorted=elemInterp(radiList[order],np.arange(1,maxElem+1))  # (maxElem, N), sorted radii
        elemList=elemSorted[:,np.argsort(order)]*densList  # back to input order, then scale
        return t.tensor(elemList).cuda().T
elif elemDataType=='one_zone':
    def elemDensCalc(zList):
        """Uniform composition: densCalc at each radius times the fixed ratios in elemList."""
        elemRatioList=t.ones([len(zList),len(elemList)]).cuda()*elemList
        radiList=(zList*(zmax-zmin)+zmin).clone().detach().cpu().numpy()
        elemDensList=t.tensor(densCalc(radiList).reshape(-1,1)).cuda()*elemRatioList
        return elemDensList
# gammaCalc(radii): internal gamma-ray source strength (normalised by photonMax) at the given radii.
if gammaDataType=='zero':
    def gammaCalc(zList):
        """No internal source: always 0."""
        return 0
elif gammaDataType=='file':
    # Linear interpolation of the tabulated source in radius.
    gammaCalc=interp1d(radiForGamma,gammaData)

In [None]:
def sourceFuncMaker(IntenWNet,zList,zRadi):
    """Absorption coefficients and emissivity at a set of radial collocation points.

    Parameters
    ----------
    IntenWNet : IntenW network, evaluated on a (z, phi) grid to form the angular integral.
    zList : 1-D tensor of normalised radii (0 at zmin, 1 at zmax).
    zRadi : the same radii in cm; the training loop passes shape (N, 1). Forwarded to gammaCalc.

    Returns
    -------
    kAbsAll : (N, gammaSample) tensor, Compton plus photoabsorption coefficient.
    jEmi : (N, gammaSample) tensor with no autograd history. Photons Compton-scattered out of bin k
        are re-emitted isotropically into bin k+1, and the gamma source is added to bin 0.
    """
    NiList=elemDensCalc(zList)
    # Compton term: sigma_T * Klein-Nishina factor * electron density sum_i N_i Z_i.
    kAbsComp=sigmaT*KCroSecAdj(engxList)*t.sum(NiList*elemZList,axis=1).reshape(-1,1)
    # Photoabsorption term, scaling as Z^5 x^(-7/2).
    # NOTE(review): the common textbook K-shell prefactor is 4*sqrt(2)*sigma_T*alpha^4; confirm 8*sqrt(2) is intended.
    kAbsPhot=sigmaT*alphaRefine**4*8*2**0.5*engxList**(-7/2)*t.sum(NiList*elemZList**5,axis=1).reshape(-1,1)
    kAbsAll=kAbsComp+kAbsPhot
    # Monte-Carlo angular integral with threeDimStep random phi per radius.
    phiGrid=integGridAssigner(threeDimStep)
    # Build the grid from the zList argument. This previously read the global `zftp[:,0]`, which only
    # matched because every caller happened to pass exactly that.
    phiIn,zbig=t.meshgrid(phiGrid,zList)
    zftpBig=t.vstack([zbig.flatten(),phiIn.flatten()]).T
    IntenBig=IntenWNet(zftpBig).reshape(threeDimStep,-1,gammaSample)
    # 2*pi (azimuth) * pi (phi range) / threeDimStep turns the sample sum into the solid-angle integral.
    compPhoton=t.sum(IntenBig*t.unsqueeze(t.sin(phiIn),2),axis=0)*kAbsComp*2*np.pi**2/threeDimStep
    jEmi=t.zeros_like(compPhoton).cuda()
    # Shift scattered photons one bin down in energy; detached so no gradient flows through the source term.
    jEmi[:,1:]=nn.ReLU()(compPhoton[:,:-1]/np.pi/4).clone().detach()
    # Gamma-ray source feeds the highest-energy bin.
    jEmi[:,[0]]=jEmi[:,[0]]+t.tensor(gammaCalc(zRadi.clone().detach().cpu().numpy())).cuda()
    return kAbsAll,jEmi

In [None]:
def _drawCollocation(batchSize,scale,offset):
    """Draw batchNum batches of batchSize (z, phi) points, each column uniform in [offset, offset+scale).

    Returns a list of float64 CUDA tensors of shape (batchSize, 2).
    """
    pts=np.random.random(batchSize*batchNum*2).reshape(-1,2)
    pts=pts*np.array(scale)+np.array(offset)
    return list(t.tensor(pts).cuda().split(batchSize))

# Interior points for the training residual, then an independent set for the monitoring residual.
zftpOrig=_drawCollocation(collBatch,[1,np.pi],[0,0])
zftpTradOrig=_drawCollocation(collBatch,[1,np.pi],[0,0])
# Inner boundary: z=0, phi in [0, pi/2).
zftpBoun1Orig=_drawCollocation(bounBatch,[0,np.pi*0.5],[0,0])
# Outer boundary: z=1, phi in [pi/2, pi).
zftpBoun2Orig=_drawCollocation(bounBatch,[0,np.pi*0.5],[1,np.pi*0.5])

In [None]:
# Per-batch loss histories, written to Metrics/ at the end of every epoch.
lossSumAll,lossPdeAll,lossBo1All,lossBo2All,lossTraAll=[],[],[],[],[]

In [None]:
# Training: one Adam optimiser per learning rate in learnRateSchedule, epochNum epochs each, and
# batchNum mini-batches per epoch. The loss combines the transfer-equation residual, both boundary
# conditions and a penalty on intensities below -negaLim.

# Inner-boundary target: only energy bin 0 (engMax) carries innerPhoton, normalised by photonMax.
# It is the same for every batch, so build it once.
realBoun=t.zeros(gammaSample)
realBoun[0]=innerPhoton/photonMax
realBoun=realBoun.cuda()

for bigStep in range(len(learnRateSchedule)):
    optim=Adam(IntenWNet.parameters(),lr=learnRateSchedule[bigStep])
    for epoch in range(epochNum):
        for batch in range(batchNum):
            # Monitoring residual on the independent zftpTradOrig points. Only its value is logged,
            # so no autograd graph is built for it.
            with t.no_grad():
                zftp=Variable(zftpTradOrig[batch].clone().detach(),requires_grad=True)
                IntenWNet.zero_grad()
                IntenOut=IntenWNet(zftp)
                XinGrad=IntenW.partialInput(IntenWNet)
                XinGrad[:,:,0]=XinGrad[:,:,0]/(zmax-zmin)
                zRadi=zftp[:,[0]]*(zmax-zmin)+zmin
                kAbsAll,jEmi=sourceFuncMaker(IntenWNet,zftp[:,0],zRadi)
                lossEqu=t.cos(zftp[:,[1]])*XinGrad[:,:,0]-t.sin(zftp[:,[1]])*XinGrad[:,:,1]/zRadi+kAbsAll*IntenOut-jEmi
                lossEqu=(lossEqu/kAbsAll.mean())**2
                lossEqu=t.mean(lossEqu)
                lossTraAll.append(lossEqu.item())

            # Training residual cos(phi) dI/dr - sin(phi)/r dI/dphi + k I - j on the zftpOrig points.
            zftp=Variable(zftpOrig[batch].clone().detach(),requires_grad=True)
            IntenWNet.zero_grad()
            IntenOut=IntenWNet(zftp)
            XinGrad=IntenW.partialInput(IntenWNet)
            # partialInput differentiates w.r.t. normalised z; dividing by (zmax-zmin) gives d/dr.
            XinGrad[:,:,0]=XinGrad[:,:,0]/(zmax-zmin)
            zRadi=zftp[:,[0]]*(zmax-zmin)+zmin
            kAbsAll,jEmi=sourceFuncMaker(IntenWNet,zftp[:,0],zRadi)
            lossEqu=t.cos(zftp[:,[1]])*XinGrad[:,:,0]-t.sin(zftp[:,[1]])*XinGrad[:,:,1]/zRadi+kAbsAll*IntenOut-jEmi
            lossEqu=(lossEqu/kAbsAll.mean())**2  # scaled by the mean opacity
            lossEqu=t.mean(lossEqu)

            lossEqu=t.sum(lossEqu)*pdeWeight#(9)
            lossPdeAll.append(lossEqu.item())
            # Inner boundary (z=0, phi<pi/2): match the injected spectrum realBoun.
            zftpBoun1=Variable(zftpBoun1Orig[batch].clone().detach(),requires_grad=True)        
            IntenOutBoun1=IntenWNet(zftpBoun1)
            lossBoun1=(IntenOutBoun1-realBoun)**2
            lossBoun1=t.mean(lossBoun1)*bo1Weight
            lossBo1All.append(lossBoun1.item())
            # Outer boundary (z=1, phi>=pi/2): zero intensity.
            zftpBoun2=Variable(zftpBoun2Orig[batch].clone().detach(),requires_grad=True)
            IntenOutBoun2=IntenWNet(zftpBoun2)
            lossBoun2=(IntenOutBoun2-0)**2
            lossBoun2=t.mean(lossBoun2)*bo2Weight
            lossBo2All.append(lossBoun2.item())
            # |I+negaLim|-(I+negaLim) is 0 for I >= -negaLim and grows linearly below it.
            lossSym=t.mean((t.abs(IntenOut+negaLim)-IntenOut-negaLim))*10
            loss=lossEqu+lossBoun1+lossBoun2+lossSym
            lossSumAll.append(loss.item())
            # Explicit check: an `assert` would be stripped under `python -O`.
            if np.isnan(loss.item()):
                raise FloatingPointError('Gesa Error: NaN loss at stage '+str(bigStep)+', epoch '+str(epoch)+', batch '+str(batch))

            loss.backward()
            optim.step()
            if batch%20==0:
                print(loss.item(),lossEqu.item(),lossBoun1.item(),lossBoun2.item(),lossSym.item())

        print('epoch '+str(epoch))
        
        # Checkpoint the network and the loss histories after every epoch.
        t.save(IntenWNet.state_dict(),ModelSaver+'CurrentInten.to')
        np.save(ModelSaver+'Metrics/lossSumAll.npy',np.array(lossSumAll))
        np.save(ModelSaver+'Metrics/lossPdeAll.npy',np.array(lossPdeAll))
        np.save(ModelSaver+'Metrics/lossBo1All.npy',np.array(lossBo1All))
        np.save(ModelSaver+'Metrics/lossBo2All.npy',np.array(lossBo2All))
        np.save(ModelSaver+'Metrics/lossTraAll.npy',np.array(lossTraAll))
        
        # Diagnostic: absorbed-energy profile on a uniform radial grid. Evaluation only, so no graph.
        with t.no_grad():
            zftp=np.zeros([collBatch,2])
            zftp[:,0]=np.linspace(0,1,num=collBatch)
            zftp=t.tensor(zftp).cuda()
            zRadi=zftp[:,[0]]*(zmax-zmin)+zmin
            NiList=elemDensCalc(zftp[:,0])
            kAbsComp=sigmaT*KCroSecAdj(engxList)*t.sum(NiList*elemZList,axis=1).reshape(-1,1)
            kAbsPhot=sigmaT*alphaRefine**4*8*2**0.5*engxList**(-7/2)*t.sum(NiList*elemZList**5,axis=1).reshape(-1,1)
            kAbsAll=kAbsComp+kAbsPhot
            phiGrid=integGridAssigner(threeDimStep)
            phiIn,zbig=t.meshgrid(phiGrid,zftp[:,0])
            zftpBig=t.vstack([zbig.flatten(),phiIn.flatten()]).T
            # Undo the photonMax normalisation and weight each bin by its photon energy x*m_e*c^2 (erg).
            IntenBig=IntenWNet(zftpBig).reshape(threeDimStep,-1,gammaSample)*photonMax*engxList*(mele*cli**2)
            IntenBigSum=t.sum(IntenBig*t.unsqueeze(t.sin(phiIn),2),axis=0)
            compPhoton=IntenBigSum*kAbsComp*2*np.pi**2/threeDimStep
            jEmiComp=t.zeros_like(compPhoton).cuda()
            jEmiComp[:,1:]=nn.ReLU()(compPhoton[:,:-1]/np.pi/4).clone().detach()
            # NOTE(review): IntenBigSum*kAbsAll is a raw sum over threeDimStep angles, while jEmiComp
            # carries the 2*pi^2/threeDimStep angular weight and 1/(4*pi). The two terms use different
            # normalisations - confirm the intended formula.
            absorbEnergy=(IntenBigSum*kAbsAll-jEmiComp).sum(axis=1).clone().detach().cpu().numpy()
            # Columns: velocity in km/s, absorbed energy.
            absorbEnergy=np.array([np.linspace(veloMin,veloMax,num=collBatch)*1e-5,absorbEnergy]).T
            np.savetxt(ModelSaver+'CurrentEnergy.txt',absorbEnergy)
        
    t.save(IntenWNet.state_dict(),ModelSaver+'StepModels/Inten_'+str(bigStep)+'.to')
    # NOTE(review): written as plain text by np.savetxt despite the '.to' extension used for checkpoints.
    np.savetxt(ModelSaver+'StepEnergy/Energy_'+str(bigStep)+'.to',absorbEnergy)