In [1]:
import os
import sys
import glob
import math
import time
import shutil
import traceback
import multiprocessing

import yaml
import torch as t
import numpy as np
import pandas as pd
from torch import nn
from torch.optim import Adam
from astropy import units as u
from astropy import constants as c
from torch.autograd import Variable
from scipy.interpolate import interp1d

In [2]:
def MP(taskid_lst=None, func=None, Nprocs=24):
    """Evaluate ``func`` on every task id in parallel worker processes.

    The task list is cut into at most ``Nprocs`` contiguous chunks; each chunk is
    handled by one child process that sends back ``{task_id: func(task_id)}``.

    Parameters
    ----------
    taskid_lst : sequence
        Hashable task identifiers; each one is passed to ``func``.
    func : callable
        Called as ``func(task_id)`` inside a child process. Closures and lambdas
        are fine because children are created with the ``fork`` start method.
    Nprocs : int
        Maximum number of worker processes.

    Returns
    -------
    dict
        Mapping from every task id to ``func(task_id)``.

    Raises
    ------
    RuntimeError
        If ``func`` raised in any worker; the worker tracebacks are included.
    """
    # 'fork' is required: `worker` is a nested function and callers pass closures,
    # neither of which can be pickled by the 'spawn'/'forkserver' start methods
    # (forkserver becomes the Linux default in Python 3.14).
    ctx=multiprocessing.get_context('fork')

    def worker(chunk, out_q):
        try:
            outdict={}
            for tid in chunk:
                outdict[tid]=func(tid)
        except Exception:
            # Always report back: a worker that dies without putting anything on the
            # queue would leave the parent blocked forever in out_q.get().
            out_q.put(('error', traceback.format_exc()))
            return
        out_q.put(('ok', outdict))

    taskid_lst=list(taskid_lst) if taskid_lst is not None else []
    if len(taskid_lst)==0:
        return {}
    chunksize=int(math.ceil(len(taskid_lst)/float(Nprocs)))
    # Only non-empty chunks get a process.
    chunks=[taskid_lst[i:i+chunksize] for i in range(0,len(taskid_lst),chunksize)]
    out_q=ctx.Queue()
    procs=[ctx.Process(target=worker,args=(chunk,out_q)) for chunk in chunks]
    for p in procs:
        p.start()
    resultdict={}
    errors=[]
    # Drain the queue before join(): a child cannot exit while its queued data is unread.
    for _ in procs:
        status,payload=out_q.get()
        if status=='ok':
            resultdict.update(payload)
        else:
            errors.append(payload)
    for p in procs:
        p.join()
    if errors:
        raise RuntimeError('MP worker failed:\n'+'\n'.join(errors))
    return resultdict

In [3]:
# Physical constants as plain floats in cgs units.
sigmaT=c.sigma_T.to('cm^2').value #Thomson cross section in cm^2
cli=c.c.to('cm/s').value #speed of light in cm/s
kbc=c.k_B.to('erg/K').value #Boltzmann constant in erg/K
hli=c.h.to('erg*s').value #Planck constant in erg*s
mele=c.m_e.to('g').value #electron mass in g
stfBlz=c.sigma_sb.to('erg/cm^2/K^4/s').value #Stefan-Boltzmann constant in erg/cm^2/K^4/s
secondsInADay=(1*u.day).to('s').value
# Saha-equation prefactor (1/2)*(h^2/(2*pi*m_e*k_B))^(3/2) in cgs (~2.07e-16, hence the name).
# Multiplied by (g_lower/g_upper)*T^-1.5*exp(E/kT) it gives n_lower/(n_upper*n_e).
v2d07e16=1/2*(hli**2/2/np.pi/kbc/mele)**(3/2)
oneEv=(1*u.eV).to('erg').value #1 eV in erg



In [None]:
# Read the run configuration (YAML) and derive the model inputs in cgs units.
yamlName='N14_Model/runner.yml'
with open(yamlName,'r') as reader:
    howToRun=yaml.safe_load(reader)
ModelSaver=howToRun['files']['save_dir']
linesFile=howToRun['files']['line_file']
levelFile=howToRun['files']['level_file']
lineWide=float(howToRun['material']['lines']['line_wide'])
# Velocity bounds are given in units of c; converted to cm/s.
veloMin=float(howToRun['material']['velocity']['min'])*cli
veloMax=float(howToRun['material']['velocity']['max'])*cli
# Time since explosion: config value in days -> seconds.
expTime=float(howToRun['material']['density']['exp_time'])*secondsInADay
densDataType=howToRun['material']['density']['data_type']
if densDataType=='file':
    # NOTE(review): t_0 is used without unit conversion while expTime is in seconds,
    # so t_0 must be given in seconds in the config -- confirm.
    densT0=float(howToRun['material']['density']['t_0'])
    densFile=howToRun['material']['density']['file_name']
    densData=np.loadtxt(densFile)
    veloData=densData[:,0]*100000#Convert from km/s to cm/s. 
    # Rescale tabulated densities from t_0 to expTime with the (t_0/t)^3 law.
    densData=densData[:,1]*(densT0/expTime)**3
    if veloMax>=veloData.max():
        print('Gesa Warning: The upper boundary is too high. ')
        veloMax=veloData.max()
    if veloMin<=veloData.min():
        # Raise instead of `assert False`, which is stripped under `python -O`.
        raise ValueError('Gesa Error: The lower boundary is too low. ')
    radiData=veloData*expTime
elif densDataType=='power_law':
    dSlope=float(howToRun['material']['density']['slope'])
    N0=float(howToRun['material']['density']['pivot_dens'])
    V0=float(howToRun['material']['density']['pivot_velo'])*cli
    R0=V0*expTime
else:
    raise ValueError('Gesa Error: Unknown density data_type '+repr(densDataType)+'. ')
elemDataType=howToRun['material']['element']['data_type']
if elemDataType=='file':
    elemFile=howToRun['material']['element']['file_name']
    elemData=np.loadtxt(elemFile)
    veloForElem=elemData[:,0]*100000
    radiForElem=veloForElem*expTime
    # Remaining columns: one abundance column per element (column k -> element k+1).
    elemData=elemData[:,1:]
    maxElem=elemData.shape[1]
    elemSum=elemData.sum(axis=0)
elif elemDataType=='one_zone':
    elemList=t.tensor(np.array(howToRun['material']['element']['number_ratio'])).cuda()
    maxElem=len(elemList)
    elemSum=elemList
else:
    raise ValueError('Gesa Error: Unknown element data_type '+repr(elemDataType)+'. ')
extraEngType=howToRun['material']['extra_energy']['data_type']
if extraEngType=='file':
    extraBounFile=howToRun['material']['extra_energy']['boun_file']
    extraSourFile=howToRun['material']['extra_energy']['source_file']
    extraBounVal=howToRun['material']['extra_energy']['boun_value']
    extraSourVal=howToRun['material']['extra_energy']['source_value']
    exBounData=np.loadtxt(extraBounFile)
    exSourData=np.loadtxt(extraSourFile)
    exEnVelo=exBounData[:,0]*100000
    exEnRadi=exEnVelo*expTime
    exEnData=exBounData[:,1]*extraBounVal+exSourData[:,1]*extraSourVal
    # Negative heating is clipped to zero.
    exEnData[exEnData<0]=0
elif extraEngType=='zero':
    print('Gesa Info: There is no extra energy other than the boundary condition. ')
else:
    raise ValueError('Gesa Error: Unknown extra_energy data_type '+repr(extraEngType)+'. ')
lineLimit=float(howToRun['material']['lines']['line_limit'])
allowShrink=howToRun['material']['lines']['allow_shrink']
BBBoun=float(howToRun['boundary']['T_inner'])
freqMinLog=float(howToRun['boundary']['freq_min'])
freqMaxLog=float(howToRun['boundary']['freq_max'])
freqSampler=int(howToRun['boundary']['freq_sample'])
learnRateSchedule=np.array(howToRun['train']['rate_list'])
pdeWeight=float(howToRun['train']['pde_weight'])
bo1Weight=float(howToRun['train']['boun1_weight'])
bo2Weight=float(howToRun['train']['boun2_weight'])
tinyEpochNum=int(howToRun['train']['tiny_epoch'])
tempMultiRate=float(howToRun['train']['temp_rate'])
l1Norm=float(howToRun['train']['l1_norm'])
goodIterate=howToRun['train']['good_iterate']
if goodIterate and tinyEpochNum<2:
    # With good_iterate the weights are only recorded when an evaluated loss improves,
    # and the recorded weights are restored after every batch, so a single tiny step
    # never keeps any update.
    print('Gesa Warning: good_iterate with tiny_epoch < 2 never keeps an update. ')
# Inner and outer boundary radii (cm).
zmin=expTime*veloMin
zmax=expTime*veloMax

In [None]:
# Start every run from an empty model directory and keep a copy of the run config.
# shutil replaces `os.popen('rm -r ...')` / `os.popen('cp ...')`: popen returned
# immediately, so the old directory could still exist when os.mkdir ran (a race only
# hidden by a 0.1 s sleep), and shell strings built from config paths break on spaces.
if os.path.exists(ModelSaver):
    shutil.rmtree(ModelSaver)
    print('Gesa Warning: Wiping out an old PINN model. ')
os.mkdir(ModelSaver)
os.mkdir(ModelSaver+'StepModels/')
os.mkdir(ModelSaver+'Metrics/')
shutil.copy(yamlName,ModelSaver+'HyperData.yml')

In [None]:
def BBSpec(freqHz,BBBoun=BBBoun):
    """Planck specific intensity B_nu(T) in cgs units.

    freqHz : frequency in Hz (float, numpy array or torch tensor).
    BBBoun : temperature in K; defaults to the inner-boundary temperature read from the
             config. A (N, 1) temperature broadcasts against a (N, F) frequency array.
    """
    boltzFactor=np.e**(hli*freqHz/kbc/BBBoun)
    return 2*hli*freqHz**3/cli**2/(boltzFactor-1)
# Log-spaced frequency grid (Hz) on the GPU.
freqGrid=t.tensor(10**np.linspace(freqMinLog,freqMaxLog,num=freqSampler)).cuda()
# Width assigned to each grid point: nu*ln(10)*(log10 range)/freqSampler.
freqWidth=freqGrid*np.log(10)*(freqMaxLog-freqMinLog)/freqSampler
freqGridCpu,freqWidthCpu=freqGrid.cpu(),freqWidth.cpu()
# Normalisation scales for the network outputs.
intenMax=BBSpec(freqGrid).max()
tempMax=BBBoun
# Number of random polar-angle samples used in the temperature update's angular integral.
threeDimStep=200

In [None]:
def _levelJStrings(jCol):
    """Convert a level-table J column to strings ('1', '2', ... or 'nan').

    Float columns hold integer J values with NaN for missing entries; each value is
    formatted on its own so that a column with only *some* NaNs no longer crashes
    (astype('int') raises on NaN). Other dtypes are stringified unchanged.
    """
    if jCol.dtype!=float:
        return np.array(jCol.astype('str'))
    return np.array(['nan' if np.isnan(jVal) else str(int(jVal)) for jVal in jCol],dtype=object)

# levelData[elemNum][ionIndex] -> per-ion level table (ionIndex 0 = neutral);
# levelData[elemNum]['Limit'] / ['DataMask'] are indexed by ionIndex;
# levelData['ElemMask'][elemNum-1] tells whether the element is modelled at all.
levelData={}
elemMask=[]
for i in range(maxElem):
    elemNum=i+1
    levelDataElem={}
    # Ionization limit (erg) per ion index; None for ions without a level file so that
    # Limit[ionIndex] always refers to the right ion. Without the placeholder, an ion
    # missing its file shifted every later ion's limit by one position.
    ionLimEngs=[]
    ionDataMask=[]
    if len(glob.glob(levelFile+str(elemNum)+'_'+'*'))==0:
        print('Gesa Warning: Element '+str(elemNum)+' has no data. ')
        elemMask.append(False)
        continue
    elif elemSum[i]==0:
        print('Gesa Info: Element '+str(elemNum)+' not in the material, not reading the element data. ')
        elemMask.append(False)
        continue
    else:elemMask.append(True)
    for j in range(elemNum):
        ionNum=j+1
        shrinkFile=levelFile+str(elemNum)+'_'+str(ionNum)+'_shrink.csv'
        normFile=levelFile+str(elemNum)+'_'+str(ionNum)+'.csv'
        if os.path.exists(normFile)==False and os.path.exists(shrinkFile)==False:
            print('Gesa Warning: Element '+str(elemNum)+' Ion '+str(ionNum)+' has no level data. ')
            ionDataMask.append(False)
            ionLimEngs.append(None)
            continue
        elif os.path.exists(normFile) and os.path.exists(shrinkFile)==False:fileName=normFile
        elif os.path.exists(shrinkFile) and os.path.exists(normFile)==False:fileName=shrinkFile
        elif os.path.exists(shrinkFile) and os.path.exists(normFile):
            if allowShrink:fileName=shrinkFile
            else:fileName=normFile
        ionDataMask.append(True)
        tableIon=pd.read_csv(fileName,index_col=0)
        # The row with Configuration == 'Limit' carries the ionization energy (eV).
        limEng=tableIon[tableIon['Configuration']=='Limit']['Level (eV)'].iloc[0]*oneEv
        tableIon=tableIon[tableIon['Configuration']!='Limit']
        levelDataIon={}
        levelDataIon['Config']=np.array(tableIon['Configuration'])
        levelDataIon['Term']=np.array(tableIon['Term'].astype('str'))
        levelDataIon['J']=_levelJStrings(tableIon['J'])
        levelDataIon['g']=np.array(tableIon['g'])
        levelDataIon['Level']=np.array(tableIon['Level (eV)'])*oneEv
        levelDataIon['nQua']=np.array(tableIon['nQua'].astype('int'))
        ionLimEngs.append(limEng)
        levelDataElem[ionNum-1]=levelDataIon
    levelDataElem['Limit']=ionLimEngs
    levelDataElem['DataMask']=ionDataMask
    levelData[elemNum]=levelDataElem
levelData['ElemMask']=elemMask

In [None]:
def _pickLineFile(normFile,shrinkFile):
    """Return the line table to read for one ion, or None when neither file exists.

    If both the full and the '_shrink' table exist, `allowShrink` chooses between them.
    """
    hasNorm=os.path.exists(normFile)
    hasShrink=os.path.exists(shrinkFile)
    if hasNorm and hasShrink:
        return shrinkFile if allowShrink else normFile
    if hasNorm:
        return normFile
    if hasShrink:
        return shrinkFile
    return None

def _lineJStrings(jCol):
    """Float J column without NaNs -> integer strings; anything else -> plain str()."""
    if jCol.dtype==float and np.sum(np.isnan(jCol))==0:
        return np.array(jCol.astype('int').astype('str'))
    return np.array(jCol.astype('str'))

# lineData[elemNum][ionIndex] -> line list of one ion (ionIndex 0 = neutral), restricted
# to vacuum wavelengths in (100, 1000) nm and to the model frequency range.
lineData={}
for i in range(maxElem):
    if levelData['ElemMask'][i]==False:continue
    elemNum=i+1
    lineDataElem={}
    for ionNum in range(1,elemNum+1):
        fileName=_pickLineFile(linesFile+str(elemNum)+'_'+str(ionNum)+'.csv',
                               linesFile+str(elemNum)+'_'+str(ionNum)+'_shrink.csv')
        if fileName is None:
            print('Gesa Warning: Element '+str(elemNum)+' Ion '+str(ionNum)+' has no line data. ')
            continue
        tableIon=pd.read_csv(fileName,index_col=0)
        wavelength=tableIon['ritz_wl_vac(nm)']
        tableIon=tableIon[(wavelength<1000)&(wavelength>100)].reset_index(drop=True)
        if len(tableIon)==0:continue
        freqIon=(c.c/(np.array(tableIon['ritz_wl_vac(nm)'])*u.nm)).to('Hz').value
        inBand=(freqIon<10**freqMaxLog)&(freqIon>10**freqMinLog)
        bandTable=tableIon[inBand]
        if len(bandTable)==0:
            print('Gesa Warning: Element',elemNum,'Ion',ionNum,' has no line in the wavelength region. ')
            continue
        bandFreq=freqIon[inBand]
        aki=np.array(bandTable['Aki(s^-1)'])
        # Einstein coefficients: B_ul = A_ul / (2 h nu^3 / c^2), g_l B_lu = g_u B_ul.
        bul=aki/(2*hli*bandFreq**3/cli**2)
        lineDataElem[ionNum-1]={
            'Freq':bandFreq,
            'Aki':aki,
            'Conf_l':np.array(bandTable['conf_i'].astype('str')),
            'Term_l':np.array(bandTable['term_i'].astype('str')),
            'J_l':_lineJStrings(bandTable['J_i']),
            'Conf_u':np.array(bandTable['conf_k'].astype('str')),
            'Term_u':np.array(bandTable['term_k'].astype('str')),
            'J_u':_lineJStrings(bandTable['J_k']),
            'Bul':bul,
            'Blu':bul*np.array(bandTable['g_k'])/np.array(bandTable['g_i']),
            'Mask':np.ones(len(bandFreq),dtype=bool),
        }
    lineData[elemNum]=lineDataElem

In [None]:
# Training batch sizes and schedule:
#   collBatch - interior (z, phi) collocation points per batch (transfer-equation residual)
#   bounBatch - boundary points per batch, for each of the two boundaries
#   sourSubBa - chunk size used by subBatchMakeSource when building source terms
#   tempBatch - radial points per batch for the temperature network
#   batchNum  - batches per epoch
#   epochNum  - epochs per learning-rate stage
(collBatch, bounBatch, sourSubBa,
 tempBatch, batchNum, epochNum) = (1500, 1500, 1500, 200, 400, 8)

In [None]:
class IntenW(nn.Module):
    """Fully connected tanh network mapping (z, phi) to intensities on the frequency grid.

    Layers are stored as fc1..fc14 (names kept so saved state_dicts still load).
    forward() caches its input as `Xin` and every pre-activation as out1..out14, which
    partialInput() needs to build the input Jacobian.
    """
    # Output widths of fc1..fc13; fc14 maps to `outDim`.
    _hiddenWidths=(256,256,256,256,256,512,512,512,512,512,2048,2048,2048)

    def __init__(self,outDim=None):
        """outDim: number of outputs; defaults to the global `freqSampler` (frequency bins)."""
        super(IntenW, self).__init__()
        if outDim is None:
            outDim=freqSampler
        widths=(2,)+self._hiddenWidths+(outDim,)
        self.nLayers=len(widths)-1
        # Created in order fc1..fc14, so random initialisation matches the explicit version.
        for k in range(self.nLayers):
            setattr(self,'fc'+str(k+1),nn.Linear(widths[k],widths[k+1]))

    def forward(self,Xin):
        """Xin: (batch, 2) input. Returns the last pre-activation, shape (batch, outDim)."""
        self.Xin=Xin
        pre=self.fc1(Xin)
        self.out1=pre
        for k in range(2,self.nLayers+1):
            pre=getattr(self,'fc'+str(k))(t.tanh(pre))
            setattr(self,'out'+str(k),pre)
        return pre

    def partialInput(self):
        """Jacobian of the last forward() output w.r.t. its input, by the chain rule.

        Must be called after forward() on a (batch, 2) input. Returns shape
        (batch, outDim, 2); entry [b, k, m] = d out_k / d Xin_m for sample b.
        Uses d tanh(a)/da = cosh(a)**-2.
        """
        pout=t.cosh(t.unsqueeze(self.out1,2))**-2*(self.fc1.weight)
        for k in range(2,self.nLayers):
            pout=t.cosh(t.unsqueeze(getattr(self,'out'+str(k)),2))**-2*(getattr(self,'fc'+str(k)).weight@pout)
        return getattr(self,'fc'+str(self.nLayers)).weight@pout
class Tempe(nn.Module):
    """Small SELU network mapping a scaled radius z, shape (N, 1), to a positive factor.

    The output is -log(sigmoid(a)) of the final pre-activation, hence always > 0.
    Layer attributes fc1..fc5 are kept so saved state_dicts still load.
    """
    def __init__(self):
        super(Tempe,self).__init__()
        sizes=[1,64,64,64,64,1]
        for idx,(nIn,nOut) in enumerate(zip(sizes[:-1],sizes[1:]),start=1):
            setattr(self,'fc%d'%idx,nn.Linear(nIn,nOut))
    def forward(self,zin):
        hidden=zin
        for idx in range(1,5):
            hidden=nn.SELU()(getattr(self,'fc%d'%idx)(hidden))
        return -nn.LogSigmoid()(self.fc5(hidden))
# Temperature network, plus forward/backward intensity networks, each paired with a
# "Rec" copy that stores the weights to restore after every training batch.
# Construction order is significant: it decides which random initial weights each net gets.
TempNet=Tempe().double().cuda()
IntenWNetFo,IntenWNetFoRec=IntenW().double().cuda(),IntenW().double().cuda()
IntenWNetFoRec.load_state_dict(IntenWNetFo.state_dict())
IntenWNetBa,IntenWNetBaRec=IntenW().double().cuda(),IntenW().double().cuda()
IntenWNetBaRec.load_state_dict(IntenWNetBa.state_dict())

In [None]:
def integGridAssigner(randomGridNumber):
    """Draw `randomGridNumber` random angles, uniform in [0, pi), as a CUDA tensor."""
    return t.tensor(np.pi*np.random.random(randomGridNumber)).cuda()
def veloMaker(zList):
    """Linearly map scaled radius z in [0, 1] to velocity in cm/s (0 -> veloMin, 1 -> veloMax)."""
    return veloMin+zList*(veloMax-veloMin)
def DopplerFreqChanger(VeloBeta,phiAngleIn):
    """Comoving-to-lab frequency ratio gamma*(1 - beta*cos(phi)) as an (N, 1) tensor.

    VeloBeta   : v/c per sample.
    phiAngleIn : angle per sample (radians).
    """
    beta=VeloBeta.reshape([-1,1])
    gamma=(1/(1-VeloBeta**2)**0.5).reshape([-1,1])
    cosPhi=t.cos(phiAngleIn.reshape([-1,1]))
    return gamma*(1-cosPhi*beta)
# Total number density as a function of radius (cm).
if densDataType=='power_law':
    def densCalc(radiList):
        return N0*(radiList/R0)**dSlope
elif densDataType=='file':
    densCalc=interp1d(radiData,densData)
# Per-element number densities at scaled radii zList; returns shape (len(zList), maxElem).
if elemDataType=='one_zone':
    def elemDensCalc(zList):
        radiList=(zList*(zmax-zmin)+zmin).clone().detach().cpu().numpy()
        totalDens=t.tensor(densCalc(radiList).reshape(-1,1)).cuda()
        return totalDens*(t.ones([len(zList),len(elemList)]).cuda()*elemList)
elif elemDataType=='file':
    elemInterDict={elemNo:interp1d(radiForElem,elemData[:,elemNo-1]) for elemNo in range(1,maxElem+1)}
    def elemInterp(radiList,elemNumbers):
        elemReturnMat=np.zeros([len(radiList),maxElem])
        for elemNo in elemNumbers:
            elemReturnMat[:,elemNo-1]=elemInterDict[elemNo](radiList)
        return elemReturnMat.T
    def elemDensCalc(zList):
        radiList=zList.clone().detach().cpu().numpy()*(zmax-zmin)+zmin
        elemMat=elemInterp(radiList,np.arange(1,maxElem+1))*densCalc(radiList)
        return t.tensor(elemMat).cuda().T
# Extra (non-radiative) heating as a function of radius.
if extraEngType=='zero':
    def exEnCalc(radi):return 0
elif extraEngType=='file':
    exEnCalc=interp1d(exEnRadi,exEnData)

In [None]:
def NiiCalcul(TEmiList,NiList,Ne):
    """Density of the reference (top) ion of each element's Saha ladder.

    For every ion J with level data, prodHere is the product over ions L = J..elem-1
    (skipping ions without level data) of
        Ne * sum_levels(L) v2d07e16 * g_level/gUpper * T^-1.5 * exp((Limit_L - Level)/kT),
    i.e. the LTE ratio n_J / n_top. The top term (ionJ == elem) contributes 1, so
    Nii = N_elem / sum_J (n_J / n_top) = n_top, the ion just above the last ion with data.

    Parameters
    ----------
    TEmiList : tensor of temperatures (K); expected shape (N, 1) so it broadcasts against
        the per-level arrays and the sum over axis=1 runs over levels.
    NiList : tensor (N, maxElem); column elem-1 is the total density of element `elem`.
    Ne : electron density used in the ratios, shape (N,) as passed by populCalcul.

    Returns
    -------
    list of length maxElem; entry elem-1 is an (N,) tensor (float32 zeros for elements
    that are not modelled).

    The ratios are detached, so no gradient reaches TEmiList through the result.
    """
    NiiList=[]
    for elem in range(1,maxElem+1):
        if levelData['ElemMask'][elem-1]==False:
            NiiList.append(t.zeros(len(TEmiList)).cuda())
            continue
        sumHere2=0  # accumulates sum_J n_J / n_top
        for ionJ in range(0,elem+1):
            if ionJ<elem:
                if levelData[elem]['DataMask'][ionJ]==False:continue
            prodHere=1  # n_ionJ / n_top; stays 1 for the top term ionJ == elem
            for ionL in range(ionJ,elem):
                if levelData[elem]['DataMask'][ionL]==False:continue
                # Statistical weight of the next ion's ground level; 1 for the bare nucleus
                # or when that ion has no level data.
                if elem>ionL+1 and levelData[elem]['DataMask'][ionL+1]:gUpper=levelData[elem][ionL+1]['g'][0]
                else:gUpper=1
                gLower=t.tensor(levelData[elem][ionL]['g']).cuda()
                # Energy of each level below the ionization limit of ion ionL (erg).
                engLevel=t.tensor(levelData[elem]['Limit'][ionL]-levelData[elem][ionL]['Level']).cuda()
                # Saha-Boltzmann: sum over levels of n_(ionL,level) / (n_(ionL+1,ground) * n_e).
                sumHere=t.sum(v2d07e16/gUpper*gLower*TEmiList**(-1.5)*np.e**(engLevel/kbc/TEmiList),axis=1)
                prodHere=(prodHere*sumHere).detach()*Ne
            prodHere=prodHere  # no-op
            sumHere2=sumHere2+prodHere
        Nii=NiList[:,elem-1]/sumHere2
        NiiList.append(Nii)
    return NiiList
def NeCalcul(TEmiList,NiiList,neInput):
    """Electron density implied by charge conservation.

    For each modelled element, the Saha ladder (built exactly as in NiiCalcul) gives
    n_J / n_top for every ion J, so the electrons contributed are
    Nii * sum_J charge_J * (n_J / n_top), where Nii = n_top comes from NiiCalcul.

    Parameters
    ----------
    TEmiList : tensor of temperatures (K); expected shape (N, 1).
    NiiList : list indexed by elem-1 of top-ion densities from NiiCalcul, each (N,).
    neInput : electron density used inside the Saha ratios, shape (N,).

    Returns
    -------
    (N,) tensor of electron density (the int 0 if no element is modelled).
    """
    Ne=0
    for elem in range(1,maxElem+1):
        if levelData['ElemMask'][elem-1]==False:continue
        dataMask=levelData[elem]['DataMask']
        # The top term of the ladder (ionJ == elem, ratio 1) stands for the ion just above
        # the last ion with level data, so its charge is that ion's index. It was counted as
        # `elem`, which is only right when level data exist for every ion; otherwise the
        # electron density was over-counted (e.g. x2 for He with only He I data).
        # If no ion has data, keep the old charge `elem`.
        dataIons=np.where(dataMask)[0]
        topCharge=int(dataIons.max())+1 if len(dataIons)>0 else elem
        sumHere2=0
        Nii=NiiList[elem-1]
        for ionJ in range(0,elem+1):
            if ionJ<elem:
                if dataMask[ionJ]==False:continue
            prodHere=1  # n_ionJ / n_top
            for ionL in range(ionJ,elem):
                if dataMask[ionL]==False:continue
                if elem>ionL+1 and dataMask[ionL+1]:gUpper=levelData[elem][ionL+1]['g'][0]
                else:gUpper=1
                gLower=t.tensor(levelData[elem][ionL]['g']).cuda()
                engLevel=t.tensor(levelData[elem]['Limit'][ionL]-levelData[elem][ionL]['Level']).cuda()
                sumHere=t.sum(v2d07e16/gUpper*gLower*TEmiList**(-1.5)*np.e**(engLevel/kbc/TEmiList),axis=1)
                prodHere=(prodHere*sumHere).detach()*neInput
            charge=topCharge if ionJ==elem else ionJ
            prodHere=prodHere*charge
            sumHere2=sumHere2+prodHere
        Ne=Ne+sumHere2*Nii
    return Ne
def populCalcul(TEmiList,NiList,iterNum=30,initDens=1e4,cudaOutput=True):
    """LTE electron density and level populations for every modelled element.

    The electron density is found by a damped fixed-point iteration: in each of
    `iterNum` rounds the top-ion densities (NiiCalcul) and the implied electron density
    (NeCalcul) are computed from the current guess, and the guess is replaced by the
    mean of old and new values. There is no convergence test.

    Parameters
    ----------
    TEmiList : tensor of temperatures (K); expected shape (N, 1).
    NiList : tensor (N, maxElem) of total element number densities.
    iterNum : number of fixed-point rounds.
    initDens : initial electron density (a number), or '1-ion' to start from the
        element densities summed over elements.
    cudaOutput : if False, every returned tensor is converted to a CPU numpy array.

    Returns
    -------
    neNew : electron density, shape (N, 1).
    popuList : dict elem -> dict ionJ -> populations of the levels of ion ionJ, shape
        (N, number of levels). The top ion of the ladder only has its ground state, (N, 1).
    """
    if initDens=='1-ion':
        neOld=NiList.sum(axis=1).clone().detach()
        neNew=NiList.sum(axis=1).clone().detach()
    else:
        neOld=initDens*t.ones(len(TEmiList)).cuda()
        neNew=initDens*t.ones(len(TEmiList)).cuda()
    for neRound in range(iterNum):
        NiiList=NiiCalcul(TEmiList,NiList,neNew)
        Ne=NeCalcul(TEmiList,NiiList,neNew)
        neOld=neNew.detach()
        # Damped update: average of the new estimate and the previous guess.
        neNew=((Ne+neOld)/2).detach()
    NiiList=NiiCalcul(TEmiList,NiList,neNew)
    neNew=neNew.reshape(-1,1)
    popuList={}
    for elem in range(1,maxElem+1):
        if levelData['ElemMask'][elem-1]==False:continue
        Nii=NiiList[elem-1]  # NOTE(review): unused; NiiList[elem-1] is re-read below
        popuListIon={}
        # Walk down from the top ion: each ion's level populations follow from the ground
        # state of the ion above it via the Saha-Boltzmann relation.
        for ionJ in range(elem,-1,-1):
            # Index of the ion just above the last ion with level data (loop-invariant).
            highlyIonized=np.where(levelData[elem]['DataMask'])[0].max()+1
            if ionJ==highlyIonized:popuListIon[ionJ]=NiiList[elem-1].reshape(-1,1)
            elif ionJ>highlyIonized:continue
            else:
                if ionJ+1==elem:gUpper=1
                elif levelData[elem]['DataMask'][ionJ+1]==False:gUpper=1
                else:gUpper=levelData[elem][ionJ+1]['g'][0]
                gLower=t.tensor(levelData[elem][ionJ]['g']).cuda()
                engLevel=t.tensor(levelData[elem]['Limit'][ionJ]-levelData[elem][ionJ]['Level']).cuda()
                popuListIon[ionJ]=popuListIon[ionJ+1][:,[0]]*neNew*v2d07e16/gUpper*gLower*TEmiList**(-1.5)*np.e**(engLevel/kbc/TEmiList)
        popuList[elem]=popuListIon
    if cudaOutput==False:
        neNew=neNew.detach().cpu().numpy()
        for elem in popuList.keys():
            for ionJ in popuList[elem].keys():
                popuList[elem][ionJ]=popuList[elem][ionJ].detach().cpu().numpy()
    return neNew,popuList

In [None]:
def BBsourceMakerSparse(lineWide,nuBarDivNu,popuDens):
    """Lab-frame line emissivity and line opacity summed over all active lines.

    Each line has a box profile: it covers grid bins whose comoving frequency lies
    within lineWide*freqWidth of the line centre, with height 1/(2*lineWide*freqWidth).
    (The variable is named "gaus..." but the profile is a box.) Sparse tensors indexed
    (line, sample, frequency bin) keep memory proportional to the covered bins.

    j = h*nu/(4 pi) * A_ul * n_u * profile is divided by nuBarDivNu**2, and
    k = h*nu/(4 pi) * (n_l B_lu - n_u B_ul) * profile is multiplied by nuBarDivNu
    (the invariants j/nu^2 and nu*k).

    Parameters
    ----------
    lineWide : half width of the box profile in grid-width units.
    nuBarDivNu : (N, 1) comoving/lab frequency ratio per sample.
    popuDens : populCalcul output, popuDens[elem][ionJ] with shape (N, levels).

    Returns
    -------
    (jEmiLine, kAbsLine) as numpy arrays of shape (N, F), or the int 0 if no line contributes.
    """
    freqList=freqGrid*nuBarDivNu
    jEmiLine=0
    kAbsLine=0
    for elem in lineData.keys():
        if levelData['ElemMask'][elem-1]==False:continue
        for ionJ in lineData[elem].keys():
            if levelData[elem]['DataMask'][ionJ]==False:continue
            lineMask=lineData[elem][ionJ]['Mask']
            if np.sum(lineMask)==0:continue
            lineFreq=t.tensor(lineData[elem][ionJ]['Freq'][lineMask]).cuda().reshape(-1,1,1)
            Aul=t.tensor(lineData[elem][ionJ]['Aki'][lineMask]).cuda().reshape(-1,1,1)
            Blu=t.tensor(lineData[elem][ionJ]['Blu'][lineMask]).cuda().reshape(-1,1,1)
            Bul=t.tensor(lineData[elem][ionJ]['Bul'][lineMask]).cuda().reshape(-1,1,1)
            popuDensOne=popuDens[elem][ionJ]
            # Match each line's upper/lower level to its row in the level table by the
            # concatenated configuration+term+J string.
            # NOTE(review): assumes every line level matches exactly one level row.
            levelFormat=levelData[elem][ionJ]['Config']+levelData[elem][ionJ]['Term']+levelData[elem][ionJ]['J']
            levelUpper=lineData[elem][ionJ]['Conf_u'][lineMask]+lineData[elem][ionJ]['Term_u'][lineMask]+lineData[elem][ionJ]['J_u'][lineMask]
            levelLower=lineData[elem][ionJ]['Conf_l'][lineMask]+lineData[elem][ionJ]['Term_l'][lineMask]+lineData[elem][ionJ]['J_l'][lineMask]
            upperMask=np.where(levelUpper.reshape(-1,1)==levelFormat)[1]
            lowerMask=np.where(levelLower.reshape(-1,1)==levelFormat)[1]
            maskHere=(t.abs(freqList-lineFreq)<=freqWidth*lineWide)
            if t.sum(maskHere)==0:continue
            maskHere=maskHere.to_sparse()
            gausLineSha=t.sparse.FloatTensor(maskHere._indices(),maskHere._values()/(lineWide*2)/freqWidth[maskHere._indices()[2]],maskHere.size())
            gaInd=gausLineSha._indices()
            gaVal=gausLineSha._values()
            jEmiLineOne=t.sparse.FloatTensor(gaInd,gaVal*Aul[gaInd[0],0,0]*lineFreq[gaInd[0],0,0]*popuDensOne[:,upperMask].T[gaInd[0],gaInd[1]],gausLineSha.size())
            try:
                jEmiLineOne=t.sparse.sum(jEmiLineOne*hli/4/np.pi,0).to_dense()/nuBarDivNu**2
            except Exception:
                # Dump the inputs, then re-raise. The old bare `except` only printed and
                # continued, so the next line failed on the still-sparse tensor with an
                # unrelated error.
                print(jEmiLineOne,maskHere,freqList,lineFreq)
                raise
            jEmiLine=jEmiLine+jEmiLineOne.detach().cpu().numpy()
            del jEmiLineOne
            kAbsLineOne=(popuDensOne[:,lowerMask].T.unsqueeze(2)*Blu-popuDensOne[:,upperMask].T.unsqueeze(2)*Bul)
            kAbsLineOne=t.sparse.FloatTensor(gaInd,gaVal*lineFreq[gaInd[0],0,0]*kAbsLineOne[gaInd[0],gaInd[1],0],gausLineSha.size())
            kAbsLineOne=t.sparse.sum(kAbsLineOne*hli/4/np.pi,0).to_dense()*nuBarDivNu
            kAbsLine=kAbsLine+kAbsLineOne.detach().cpu().numpy()
            del maskHere
            del gausLineSha
            del kAbsLineOne
            del gaInd
            del gaVal
    return jEmiLine,kAbsLine

In [None]:
def makeAllTheSource(zftp,TempNetIn):
    """Lab-frame absorption and emission coefficients at sample points zftp = (z, phi).

    The temperature comes from TempNetIn (detached, scaled by tempMax), populations from
    populCalcul. Absorption is Thomson opacity sigma_T*n_e plus line opacity. Emission is
    sigma_T*n_e*B_nu(T) at the comoving frequency, divided by nuBarDivNu**2, plus line
    emissivity.

    Returns (totalAbs, totalEmi) as CUDA tensors. They are (N, F) when lines contribute;
    with no lines, totalAbs stays (N, 1).
    """
    TEmiList=TempNetIn(zftp[:,[0]]).detach()*tempMax
    nuBarDivNu=DopplerFreqChanger(veloMaker(zftp[:,0])/cli,zftp[:,1])
    elecDens,popuDens=populCalcul(TEmiList,elemDensCalc(zftp[:,0]),initDens='1-ion')
    sigList=elecDens*sigmaT
    thermSourc=(sigList.reshape([-1,1])*BBSpec(freqGrid*nuBarDivNu,TEmiList.reshape([-1,1]))).clone().detach()
    # Any NaN entries of the thermal term are zeroed before the frame transform.
    thermSourc[t.isnan(thermSourc)]=0
    thermSourc=thermSourc/nuBarDivNu**2
    jEmiLine,kAbsLine=BBsourceMakerSparse(lineWide,nuBarDivNu,popuDens)
    return sigList+t.tensor(kAbsLine).cuda(),thermSourc+t.tensor(jEmiLine).cuda()

In [None]:
def subBatchMakeSource(zftp,TempNetIn):
    """Run makeAllTheSource on zftp in row chunks of `sourSubBa` and concatenate the results.

    Returns (totalAbs, totalEmi) with rows in the same order as zftp.
    """
    pieces=[makeAllTheSource(chunk,TempNetIn) for chunk in zftp.split(sourSubBa)]
    absParts,emiParts=zip(*pieces)
    return t.cat(absParts),t.cat(emiParts)

In [None]:
def lineMaskMakeOne(elem,ionJ,popuDensOne,lineLimit):
    """Flag the lines of one ion that matter at optically thick frequencies.

    Runs on the CPU (it is called inside MP workers) and applies no Doppler shift
    (nuBarDivNu = 1). Each line's box-profile opacity is summed over the radial samples
    and multiplied by (zmax - zmin)/N, giving a per-line, per-frequency-bin optical depth.
    Bins where this ion's lines together exceed `lineLimit` count as opaque. A line is kept
    if its optical depth is positive in at least one opaque bin.

    Parameters
    ----------
    elem, ionJ : element number and ion index (0 = neutral) into lineData / levelData.
    popuDensOne : CPU tensor of this ion's level populations, shape (N samples, levels).
    lineLimit : optical-depth threshold.

    Returns
    -------
    numpy bool array with one entry per line in lineData[elem][ionJ].
    """
    nuBarDivNu=t.ones([len(popuDensOne),1])
    freqList=freqGridCpu*nuBarDivNu
    try:
        lineFreq=t.tensor(lineData[elem][ionJ]['Freq']).reshape(-1,1,1)
    except Exception:
        # Report which ion failed, then re-raise. The old bare `except` only printed and
        # then crashed below with an unrelated NameError on `lineFreq`.
        print(elem,ionJ)
        raise
    Blu=t.tensor(lineData[elem][ionJ]['Blu']).reshape(-1,1,1)
    Bul=t.tensor(lineData[elem][ionJ]['Bul']).reshape(-1,1,1)
    levelFormat=levelData[elem][ionJ]['Config']+levelData[elem][ionJ]['Term']+levelData[elem][ionJ]['J']
    levelUpper=lineData[elem][ionJ]['Conf_u']+lineData[elem][ionJ]['Term_u']+lineData[elem][ionJ]['J_u']
    levelLower=lineData[elem][ionJ]['Conf_l']+lineData[elem][ionJ]['Term_l']+lineData[elem][ionJ]['J_l']
    upperMask=np.where(levelUpper.reshape(-1,1)==levelFormat)[1]
    lowerMask=np.where(levelLower.reshape(-1,1)==levelFormat)[1]
    # Box profile over (line, sample, frequency bin), as in BBsourceMakerSparse.
    maskHere=(t.abs(freqList-lineFreq)<=freqWidthCpu*lineWide).to_sparse()
    gausLineSha=t.sparse.FloatTensor(maskHere._indices(),maskHere._values()/(lineWide*2)/freqWidthCpu[maskHere._indices()[2]],maskHere.size())
    gaInd=gausLineSha._indices()
    gaVal=gausLineSha._values()
    kAbsLineOne=(popuDensOne[:,lowerMask].T.unsqueeze(2)*Blu-popuDensOne[:,upperMask].T.unsqueeze(2)*Bul)
    kAbsLineOne=t.sparse.FloatTensor(gaInd,gaVal*lineFreq[gaInd[0],0,0]*kAbsLineOne[gaInd[0],gaInd[1],0],gausLineSha.size())
    kAbsLineOne=kAbsLineOne.to_dense()*hli/4/np.pi
    # Integrate over the radial samples -> (lines, frequency bins).
    kAbsLineOne=(kAbsLineOne.sum(axis=1)*(zmax-zmin)/len(popuDensOne)).cpu().numpy()
    maskOpac=(kAbsLineOne.sum(axis=0)>lineLimit)
    maskTrans=(kAbsLineOne[:,maskOpac].sum(axis=1)>0)
    return maskTrans
def lineMaskMake(lineDataIn,temperNet,sampleNum=3500):
    """Recompute every ion's line 'Mask' from the current temperature model.

    Evaluates temperature and LTE populations at `sampleNum` evenly spaced z in [0, 1],
    then runs lineMaskMakeOne in parallel (MP) for each ion that has both level and
    line data.

    Returns a new line-data dict. lineDataIn itself is not modified: the per-ion dicts
    are copied, and the arrays inside them are shared.
    """
    # Copy down to the per-ion dicts. The previous top-level .copy() was shallow, so
    # writing 'Mask' below also rewrote the caller's dict.
    lineDataOut={elem:{ionJ:dict(ionDict) for ionJ,ionDict in elemDict.items()}
                 for elem,elemDict in lineDataIn.items()}
    zftp=np.zeros([sampleNum,2])
    zftp[:,0]=np.linspace(0,1,num=sampleNum)
    zftp=t.tensor(zftp).cuda()
    TEmiList=temperNet(zftp[:,[0]]).detach()*tempMax
    NiList=elemDensCalc(zftp[:,0])
    _,popuDens=populCalcul(TEmiList,NiList,initDens='1-ion',cudaOutput=False)
    # Runs in forked worker processes; popuDens holds CPU numpy arrays at this point.
    def mpInRun(inKey):
        elem=int(inKey.split('_')[0])
        ionJ=int(inKey.split('_')[1])
        popuDensOne=t.tensor(popuDens[elem][ionJ])
        maskTrans=lineMaskMakeOne(elem,ionJ,popuDensOne,lineLimit)
        return maskTrans
    inKeyList=[]
    for elem in lineDataIn.keys():
        if levelData['ElemMask'][elem-1]==False:continue
        for ionJ in lineDataIn[elem].keys():
            if levelData[elem]['DataMask'][ionJ]==False:continue
            inKeyList.append(str(elem)+'_'+str(ionJ))
    mpOut=MP(inKeyList,mpInRun)
    for inKey in inKeyList:
        elem,ionJ=inKey.split('_')
        elem,ionJ=int(elem),int(ionJ)
        lineDataOut[elem][ionJ]['Mask']=mpOut[str(elem)+'_'+str(ionJ)]
    return lineDataOut

In [None]:
def _uniformBatches(batchSize,scale,offset):
    """batchNum CUDA batches of shape (batchSize, len(scale)); rows are U[0,1)*scale + offset."""
    nDim=len(scale)
    draws=np.random.random(batchSize*batchNum*nDim).reshape(-1,nDim)
    return list(t.tensor(draws*np.array(scale)+np.array(offset)).cuda().split(batchSize))

# Fixed sample sets, drawn once. The draw order below fixes which global-RNG numbers each set gets.
zftpOrig=_uniformBatches(collBatch,[1,np.pi],[0,0])                    # interior: z in [0,1), phi in [0,pi)
zOrig=_uniformBatches(tempBatch,[1],[0])                              # radii for the temperature update
zftpBoun1Orig=_uniformBatches(bounBatch,[0,np.pi*0.5],[0,0])           # z = 0, phi in [0,pi/2)
zftpBoun2Orig=_uniformBatches(bounBatch,[0,np.pi*0.5],[1,np.pi*0.5])   # z = 1, phi in [pi/2,pi)

In [None]:
# Per-step loss histories, written to Metrics/ after every epoch.
lossPFoAll,lossPBaAll=[],[]   # transfer-equation residual: forward / backward network
lossBoFAll,lossBoBAll=[],[]   # boundary loss: inner (forward) / outer (backward)
lossTemAll=[]                 # temperature-network loss
lossL1FAll,lossL1BAll=[],[]   # L1 weight penalty: forward / backward network

In [None]:
# Main training loop. Each learning-rate stage (bigStep):
#   1. recompute the line masks from the current temperature model (100 radial samples);
#   2. create fresh Adam optimizers for the two intensity networks and the temperature network;
#   3. for every batch: fit the forward net (rays with touchBoun) and the backward net (all
#      other rays) to the spherical transfer equation plus their boundary conditions, restore
#      the recorded "Rec" weights, then regress TempNet toward a radiative-equilibrium temperature.
for bigStep in range(len(learnRateSchedule)):
    lineData=lineMaskMake(lineData,TempNet,100)
    optimFo=Adam(IntenWNetFo.parameters(),lr=learnRateSchedule[bigStep])
    optimBa=Adam(IntenWNetBa.parameters(),lr=learnRateSchedule[bigStep])
    optimTem=Adam(TempNet.parameters(),lr=learnRateSchedule[bigStep]*tempMultiRate)#
    for epoch in range(epochNum):
        for batch in range(batchNum):
            zftpTwo=zftpOrig[batch].clone().detach()
            zRadiTwo=zftpTwo[:,[0]]*(zmax-zmin)+zmin
            # Rays with impact parameter r*sin(phi) below the inner radius and phi < pi/2
            # go to the forward net, whose boundary condition is the blackbody at z = 0.
            touchBoun=((zRadiTwo*t.sin(zftpTwo[:,[1]]))<zmin)&(zftpTwo[:,[1]]<np.pi/2)
            zftp=zftpTwo[touchBoun.flatten()]
            zRadi=zRadiTwo[touchBoun.flatten()]
            # Source terms are computed once per batch from the current (detached) temperature.
            totalAbs,totalEmi=subBatchMakeSource(zftp,TempNet)
            for tinyStep in range(tinyEpochNum):
                IntenWNetFo.zero_grad()
                IntenOut=IntenWNetFo(zftp)*intenMax
                XinGrad=IntenW.partialInput(IntenWNetFo)*intenMax
                # d/dz -> d/dr, since r = z*(zmax-zmin)+zmin.
                XinGrad[:,:,0]=XinGrad[:,:,0]/(zmax-zmin)
                # Residual of cos(phi) dI/dr - sin(phi)/r dI/dphi + k*I - j = 0.
                lossEqu=t.cos(zftp[:,[1]])*XinGrad[:,:,0]-t.sin(zftp[:,[1]])*XinGrad[:,:,1]/zRadi+totalAbs*IntenOut-totalEmi
                lossEqu=(lossEqu/intenMax/totalAbs.mean())**2
                lossEqu=t.mean(lossEqu)*pdeWeight
                lossPFoAll.append(lossEqu.item())
                zftpBoun1=zftpBoun1Orig[batch].clone().detach()
                IntenOutBoun1=IntenWNetFo(zftpBoun1)*intenMax
                # Inner boundary: blackbody at T_inner (loop-invariant, recomputed every step).
                realBoun=t.tensor(BBSpec(freqGrid)).cuda()
                lossBoun1=((IntenOutBoun1-realBoun)/intenMax)**2
                lossBoun1=t.mean(lossBoun1)*bo1Weight#(10)
                lossBoFAll.append(lossBoun1.item())

                l1Loss=0
                for parOne in IntenWNetFo.parameters():
                    l1Loss=l1Loss+t.mean(t.abs(parOne))
                l1Loss=l1Loss*l1Norm
                lossL1FAll.append(l1Loss.item())
                #lossSym=t.mean((t.abs(IntenOut)-IntenOut)/intenMax)*100000
                loss=lossEqu+lossBoun1+l1Loss#+lossSym
                assert np.isnan(loss.item())==False
                if goodIterate:
                    # Record the weights whenever the (pre-step) loss improves on the best so far.
                    # NOTE(review): the update from the last tiny step is never evaluated and is
                    # discarded by the restore after this loop.
                    if tinyStep==0:lossGood=loss.item()
                    else:
                        if loss.item()<lossGood:
                            IntenWNetFoRec.load_state_dict(IntenWNetFo.state_dict())
                            lossGood=loss.item()
                    loss.backward()
                    optimFo.step()
                else:
                    loss.backward()
                    optimFo.step()
                    if tinyStep==tinyEpochNum-1:
                        IntenWNetFoRec.load_state_dict(IntenWNetFo.state_dict())

            # Same procedure for the backward net on the remaining rays; its boundary condition
            # is zero intensity at z = 1 for phi in [pi/2, pi).
            zftp=zftpTwo[(touchBoun==False).flatten()]
            zRadi=zRadiTwo[(touchBoun==False).flatten()]
            totalAbs,totalEmi=subBatchMakeSource(zftp,TempNet)
            for tinyStep in range(tinyEpochNum):
                IntenWNetBa.zero_grad()
                IntenOut=IntenWNetBa(zftp)*intenMax
                XinGrad=IntenW.partialInput(IntenWNetBa)*intenMax
                XinGrad[:,:,0]=XinGrad[:,:,0]/(zmax-zmin)
                lossEqu=t.cos(zftp[:,[1]])*XinGrad[:,:,0]-t.sin(zftp[:,[1]])*XinGrad[:,:,1]/zRadi+totalAbs*IntenOut-totalEmi
                lossEqu=(lossEqu/intenMax/totalAbs.mean())**2
                lossEqu=t.mean(lossEqu)*pdeWeight
                lossPBaAll.append(lossEqu.item())

                zftpBoun2=zftpBoun2Orig[batch].clone().detach()
                IntenOutBoun2=IntenWNetBa(zftpBoun2)*intenMax
                lossBoun2=((IntenOutBoun2-0)/intenMax)**2
                lossBoun2=t.mean(lossBoun2)*bo2Weight
                lossBoBAll.append(lossBoun2.item())

                l1Loss=0
                for parOne in IntenWNetBa.parameters():
                    l1Loss=l1Loss+t.mean(t.abs(parOne))
                l1Loss=l1Loss*l1Norm
                lossL1BAll.append(l1Loss.item())
                #lossSym=t.mean((t.abs(IntenOut)-IntenOut)/intenMax)*100000
                loss=lossEqu+lossBoun2+l1Loss#+lossSym
                assert np.isnan(loss.item())==False
                if goodIterate:
                    if tinyStep==0:lossGood=loss.item()
                    else:
                        if loss.item()<lossGood:
                            IntenWNetBaRec.load_state_dict(IntenWNetBa.state_dict())
                            lossGood=loss.item()
                    loss.backward()
                    optimBa.step()
                else:
                    loss.backward()
                    optimBa.step()
                    if tinyStep==tinyEpochNum-1:
                        IntenWNetBaRec.load_state_dict(IntenWNetBa.state_dict())

            # Keep only the recorded weights.
            IntenWNetFo.load_state_dict(IntenWNetFoRec.state_dict())
            IntenWNetBa.load_state_dict(IntenWNetBaRec.state_dict())
            # Temperature update: integrate the intensity over angle (random phi grid) and
            # frequency at tempBatch radii.
            zTemp=zOrig[batch].clone().detach()
            radiTemp=zTemp*(zmax-zmin)+zmin
            TempNet.zero_grad()
            phiGrid=integGridAssigner(threeDimStep)
            phiIn,zbig=t.meshgrid(phiGrid,zTemp[:,0])
            zftpBig=t.vstack([zbig.flatten(),phiIn.flatten()]).T
            zRadiBig=zftpBig[:,[0]]*(zmax-zmin)+zmin
            touchBoun=((zRadiBig*t.sin(zftpBig[:,[1]]))<zmin)&(zftpBig[:,[1]]<np.pi/2)

            IntenBig=(IntenWNetFo(zftpBig)*touchBoun+IntenWNetBa(zftpBig)*(1-1*touchBoun)).reshape(threeDimStep,-1,freqSampler)*intenMax
            IntenBig=IntenBig.clone().detach()
            # Integral of I over solid angle (2*pi*sin(phi) dphi) and frequency, per radius.
            EnergyAbsoSum=t.sum(IntenBig*t.unsqueeze(t.sin(phiIn),2)*freqGrid,axis=(0,2))*np.log(10)*(freqMaxLog-freqMinLog)*2*np.pi**2/threeDimStep/freqSampler
            # Floor at 0.001.
            EnergyAbsoSum=nn.ReLU()(EnergyAbsoSum.reshape([-1,1])-0.001)+0.001
            addExtraEnergy=t.tensor(exEnCalc(radiTemp.clone().detach().cpu().numpy())).cuda()

            TEmiPred=TempNet(zTemp)*tempMax
            NiList=elemDensCalc(zTemp[:,0])
            elecDens,popuDens=populCalcul(TEmiPred,NiList,initDens='1-ion')
            sigList=elecDens*sigmaT
            # Target temperature: T = ((Q/(n_e*sigma_T) + integral of I)/(4*sigma_SB))^(1/4), detached.
            TEmiTrue=((addExtraEnergy/sigList+EnergyAbsoSum)/4/stfBlz)**0.25
            TEmiTrue=t.tensor(TEmiTrue.clone().detach().cpu().numpy()).cuda()

            lossTem=t.sum((TEmiTrue-TEmiPred)**2)*1e-6
            lossTemAll.append(lossTem.item())
            lossTem.backward()
            optimTem.step()
            # NOTE(review): lossBoun1 is from the forward net; loss, lossEqu and l1Loss are from the backward net.
            print(loss.item(),lossEqu.item(),lossBoun1.item(),lossBoun2.item(),lossTem.item(),l1Loss.item())

        print('epoch '+str(epoch))
        # Checkpoint the recorded weights and all loss histories after every epoch.
        t.save(IntenWNetFoRec.state_dict(),ModelSaver+'CurrentIntenForward.to')
        t.save(IntenWNetBaRec.state_dict(),ModelSaver+'CurrentIntenBackward.to')
        t.save(TempNet.state_dict(),ModelSaver+'CurrentTemp.to')
        np.save(ModelSaver+'Metrics/lossPFoAll.npy',np.array(lossPFoAll))
        np.save(ModelSaver+'Metrics/lossPBaAll.npy',np.array(lossPBaAll))
        np.save(ModelSaver+'Metrics/lossBoFAll.npy',np.array(lossBoFAll))
        np.save(ModelSaver+'Metrics/lossBoBAll.npy',np.array(lossBoBAll))
        np.save(ModelSaver+'Metrics/lossTemAll.npy',np.array(lossTemAll))
        np.save(ModelSaver+'Metrics/lossL1FAll.npy',np.array(lossL1FAll))
        np.save(ModelSaver+'Metrics/lossL1BAll.npy',np.array(lossL1BAll))
    # Per-stage snapshot of the models.
    t.save(IntenWNetFoRec.state_dict(),ModelSaver+'StepModels/IntenForward_'+str(bigStep)+'.to')
    t.save(IntenWNetBaRec.state_dict(),ModelSaver+'StepModels/IntenBackward_'+str(bigStep)+'.to')
    t.save(TempNet.state_dict(),ModelSaver+'StepModels/Temp_'+str(bigStep)+'.to')