# The ActionFinder 
## (Ibata, Diakogiannis, Famaey & Monari, 2021)

# An unsupervised network to find the canonical transformation
# from x,v to J,theta

In [None]:
%pylab inline
import torch

In [None]:
# Probably don't want to fiddle with these:
# Floating-point types used throughout: one for torch tensors, the matching one for numpy arrays.
dtype, np_dtype = torch.float64, np.float64

# Keep a single DataLoader worker: the PyTorch version in use (at the time of writing)
# complains if more data workers are requested.
num_workers = 1

# Number of phase-space coordinates per star (3 positions + 3 velocities).
NPhase = 6

## The scale factors 

In [None]:
pos_scale =  20.0   # will scale all input distances  by this factor
vel_scale = 200.0   # will scale all input velocities by this factor
J_scale   = pos_scale*vel_scale  # actions have units of distance*velocity
T_scale   = vel_scale/pos_scale  # angular rates (e.g. proper motions) have units of velocity/distance

# loss function inverse scales:
scale_Ts      =  10.0 # (rough) between Js and Ts
scale_Js_same =   1.0 # to increase importance of identical Js
scale_direc   = 0.001 # for the acceleration direction constraint
scale_Os_same = 100.0 # to increase importance of identical Omegas
scale_dJdt    = 100.0 # for dJdt
scale_LinAl   =1000.0 # to increase importance of linear algebra solution for acceleration


# The reference isochrone model (these values are a good starting point for the Milky Way):
M_iso_ref=1500000.0/(pos_scale*vel_scale**2) # masses in N-body units (i.e. 1.5e6*2.222e5=3.333e11 Msun)
b_iso_ref=5.0/pos_scale                      # scale in N-body units (i.e. 5 kpc)


# Full double-precision pi. The previous expression, torch.acos(torch.zeros(1)).item()*2, was
# evaluated in float32 and gave 3.1415927410..., an error of ~9e-8 that leaked into every
# float64 angle wrap (x % (2*torch.pi)) and degree->radian conversion below.
torch.pi = np.pi
torch.J_scale = torch.tensor(J_scale)
# Per-coordinate scale factors for (x, y, z, vx, vy, vz), as a torch tensor and a numpy array.
torch.W_scale=torch.zeros(6)
torch.W_scale[:3]=pos_scale
torch.W_scale[3:]=vel_scale
W_scale=np.zeros(6)
W_scale[:3]=pos_scale
W_scale[3:]=vel_scale

# Set up the choices for this run:

In [None]:
Epochs = 1024*2                  # Number of epochs to run for

batch_size=128*4
ndata_select=batch_size*2        # training   dataset size
ndata_select_test=batch_size*2   # validation dataset size


genname_root = 'generate_orbits0'

fitsdata_name = '../data/'+genname_root+'.fits'
flname_save = r'../saved_models/ActionFinder0.params'


# file to make predictions for, and output file with predictions:
fitspred_name =     r'../data/generate_orbits0.fits'
fitspred_name_out = r'../data/generate_orbits0_withPreds.fits'


NStars_per_stream = 8 # NOTE: each star contributes NPhase columns to the FITS phase-space tables

# Number of stars that we wish to treat as one stream group. Each dataset item draws this many
# stars at random from the NStars_per_stream available. Choosing NStars < NStars_per_stream
# means different random subsets of each stream are seen from one draw to the next.
NStars = 8

# Explicit check (an assert would be stripped when Python runs with -O).
if NStars > NStars_per_stream:
    raise ValueError("We cannot have NStars > NStars_per_stream")


# Half-depths of the neural networks:
depth_GF_G = 7
depth_GF_P = 5
depth_A    = 7


# Switches:
Fit_Acceleration = False             # set to True on second-pass to get the acceleration field.
use_point_transformation = True      # do we use the optional point-transformation network?
use_validation_set = True            # the network is unsupervised, so validation is not critical if we don't want to propagate solutions
use_simple_loss = False              # if True, the only objective in first pass is to minimize spread of J''
iterate_to_find_Jd = True            # False ==> CHEAT in training of GF using ground-truth J' values (currently also requires use_point_transformation = False)
use_dropout = True                   # for better regularization
write_predictions_on_the_fly = False # do we want to write to fitspred_name_out at each best solution?
write_predictions_on_the_fly = False # do we want to write to fitspred_name_out at each best solution?

## Some astrophysical units and coordinate conversions

In [None]:
yr_unit=31.5576e6          # in s
Gyr_unit=yr_unit * 1.0e9   # in s
AU_unit=149.597870700e6    # in km
kpc_unit=3.0857e16         # in km
vel_unit=kpc_unit/Gyr_unit # 1 kpc/Gyr expressed in km/s, i.e. 1 vel_unit = 0.97780 km/s
PM_unit=AU_unit/yr_unit    # in km/s (=4.74047)
PM_conv_masyr_radGyr=0.20626 # proper motions in mas/yr are divided by this to get rad/Gyr

xr_G=192.85948402*torch.pi/180.0 # Poleski's 2013 alpha_G  NGP in equatorial coordinates
xd_G= 27.12829637*torch.pi/180.0 # his delta_G
sin_xd_G=np.sin(xd_G)
cos_xd_G=np.cos(xd_G)

# Galactic coordinates (l, b) of the pole used in the Equatorial <-> Galactic conversion
# (presumably the north celestial pole -- confirm against Poleski 2013).
xl_NP=122.93193212*torch.pi/180.0
xb_NP= 27.12835496*torch.pi/180.0 
sin_xb_NP=np.sin(xb_NP)
cos_xb_NP=np.cos(xb_NP)

RMAT_EqGal = torch.tensor([
[-0.054875539726,-0.873437108010,-0.483834985808],
[ 0.494109453312,-0.444829589425, 0.746982251810],
[-0.867666135858,-0.198076386122, 0.455983795705]],dtype=dtype) # for use in Equatorial to Galactic coordinate conversion.

## The Solar motion, and corrections

In [None]:
vSun_pec = torch.tensor([11.1, 7.2, 7.25],dtype=dtype) # Schoenrich et al (2010) for U,W; V from Bovy (2020)
vc_Sun   = torch.tensor([0.0, 243.0, 0.0],dtype=dtype) # +/- 8 km/s  Bovy (2020)
Rsun = 8.224 # Galactocentric distance of the Sun in kpc, Bovy (2020)
zsun = 0.000 # 2.8 pc Widmark et al. (2021), rounded down to zero for mu_l and mu_b to be simple in R and z
vSun=vc_Sun+vSun_pec   # total Solar velocity, in km/s at this point

# Convert to the network's internal units: first apply the scale factors ...
Rsun = Rsun/pos_scale
zsun = zsun/pos_scale
vc_Sun = vc_Sun/vel_scale
vSun_pec = vSun_pec/vel_scale
vSun = vSun/vel_scale
# ... then convert velocities from km/s into kpc/Gyr (units with G=1).
vc_Sun = vc_Sun/vel_unit # convert into units with G=1
vSun_pec = vSun_pec/vel_unit
vSun = vSun/vel_unit

# Construct a Dataset class for training
## The data are read from a FITS file

In [None]:
def _read_fits_table(hdul, ext):
    """Return extension `ext` of the open HDU list `hdul` as an astropy Table (data copied)."""
    hdu = hdul[ext]
    # Column names come from the extension's TTYPEn keywords; Table() copies the data by
    # default, so the result stays valid after the file is closed.
    return Table(hdu.data, names=hdu.columns.names)


def _add_scaled_column(dest, source, colname, *divisors):
    """Copy column `colname` from `source` into table `dest`, dividing by each divisor in turn."""
    var = Column(source[colname], name=colname)
    for divisor in divisors:  # applied sequentially to keep the original floating-point order
        var = var/divisor
    dest.add_column(Column(var, name=colname))


def get_data_as_pandas(fits_flname_read):
    """
    Read the training FITS file and return its contents as scaled pandas DataFrames.

    The file must contain three table extensions with the same number of rows (one row per stream):
      ext 1: dis###, xl###, xb### for each star, plus the reference actions J1, J2, J3
      ext 2: vr###, ld###, bd### for each star
      ext 3: reference angles op###, oz###, or### for each star
    where ### runs over 000 .. NStars_per_stream-1.

    Returns:
        physical: per star (dis, xl, xb, vr, ld, bd), distances divided by pos_scale, vr converted
                  to G=1 units and divided by vel_scale, proper motions converted to rad/Gyr and
                  divided by T_scale.
        Js:       J1, J2, J3 divided by J_scale.
        thetas:   the reference angles, unscaled.

    Raises:
        ValueError: if the three extensions do not have the same number of rows.
    """
    with fits.open(fits_flname_read) as tab:
        orig_data  = _read_fits_table(tab, 1)
        orig_datav = _read_fits_table(tab, 2)
        orig_datap = _read_fits_table(tab, 3)

    n = len(orig_data)
    if len(orig_datav) != n or len(orig_datap) != n:
        raise ValueError("The fits file must be sensible: extensions 1, 2 and 3 differ in length")

    physical_data = Table()
    Js_data = Table()
    thetas_data = Table()

    # Expected inputs:
    # dis (distance in kpc), xl (Galactic longitude in radians), xb (Galactic latitude in radians)
    # vr (Heliocentric radial velocity in km/s),
    # ld (mu_l*cos(b), proper motion along Galactic longitude in mas/yr)
    # bd (mu_b, proper motion along Galactic latitude in mas/yr)
    for i in range(NStars_per_stream):
        ii = "%03d" % (i,)
        _add_scaled_column(physical_data, orig_data,  'dis' + ii, pos_scale)
        _add_scaled_column(physical_data, orig_data,  'xl' + ii)
        _add_scaled_column(physical_data, orig_data,  'xb' + ii)
        # km/s -> G=1 units, then the network's velocity scale
        _add_scaled_column(physical_data, orig_datav, 'vr' + ii, vel_unit, vel_scale)
        # mas/yr -> rad/Gyr (G=1 units), then the network's angular-rate scale
        _add_scaled_column(physical_data, orig_datav, 'ld' + ii, PM_conv_masyr_radGyr, T_scale)
        _add_scaled_column(physical_data, orig_datav, 'bd' + ii, PM_conv_masyr_radGyr, T_scale)

    # The reference actions J1,J2,J3 (useful for checks, if not available give dummy values!)
    for colname in ('J1', 'J2', 'J3'):
        _add_scaled_column(Js_data, orig_data, colname, J_scale)

    # The reference angles op,oz,or conjugate to J1,J2,J3 (useful for checks, if not available give dummy values!)
    for i in range(NStars_per_stream):
        ii = "%03d" % (i,)
        for prefix in ('op', 'oz', 'or'):
            _add_scaled_column(thetas_data, orig_datap, prefix + ii)

    physical = physical_data.to_pandas()
    Js       = Js_data.to_pandas()
    thetas   = thetas_data.to_pandas()

    return physical, Js, thetas

In [None]:
from torch.utils.data import dataset

from astropy.table import Column
import astropy.io.fits as fits 
from astropy.table import Table
import numpy as np
import pandas as pd


class AstroDataset(dataset.Dataset):
    """
    Class for feeding data within a DataLoader object.

    Each item is one stream: NStars stars drawn from the NStars_per_stream stored in one row of
    the FITS file. It is returned as (variables, Js_ref, thetas_ref) with shapes
    (NStars, NPhase), (1, NPhase//2) and (NStars, NPhase//2). The velocity-like entries of
    `variables` have the Solar motion vSun added (see __init__).

    Args:
        fitsfile:  path of the FITS file read by get_data_as_pandas.
        mode:      'train' (first ndata_select rows), 'validation' (last ndata_select_test rows)
                   or 'all' (every row).
        randomize: if True, each __getitem__ draws a fresh random subset of stars; if False it
                   deterministically uses the first NStars stars of the stream.
    """

    def __init__(self, fitsfile = 'dummy.fits',   mode='train', randomize=True ):
        super().__init__()

        df, df_ref, dfo_ref = get_data_as_pandas(fitsfile)
        self.randomize = randomize

        if len(df) < ndata_select:
            raise ValueError("ndata_select      must not be larger than the dataset")
        if len(df) < ndata_select_test:
            raise ValueError("ndata_select_test must not be larger than the dataset")

        ntot = len(df)
        if mode == 'train':
            rows = slice(0, ndata_select)
        elif mode == 'validation':
            rows = slice(ntot - ndata_select_test, ntot)
        elif mode == 'all':
            rows = slice(0, ntot)
        else:
            raise ValueError("Cannot understand mode for training, available choices::{train,validation,all}, aborting ...")

        self.df      = df.iloc[rows]
        self.df_ref  = df_ref.iloc[rows]
        self.dfo_ref = dfo_ref.iloc[rows]
        self.mode = mode

        nelem = len(self.df)

        # copy=True: the array is modified in place below and must not alias the DataFrame.
        self.variables_arr  = self.df.to_numpy(dtype=np_dtype, copy=True).reshape(
            nelem, NStars_per_stream, NPhase)
        self.Js_ref_arr     = self.df_ref.to_numpy(dtype=np_dtype, copy=True).reshape(
            nelem, 1, NPhase//2)
        self.thetas_ref_arr = self.dfo_ref.to_numpy(dtype=np_dtype, copy=True).reshape(
            nelem, NStars_per_stream, NPhase//2)
        self.thetas_ref_arr = self.thetas_ref_arr % (2 * np.pi) # i.e. [0,2*pi)

        self.vcorr_arr = self._solar_motion_projections(self.variables_arr)

        # Add the Solar motion: radial velocity directly, proper motions as (velocity / distance).
        self.variables_arr[:,:,3:4] = self.variables_arr[:,:,3:4] + self.vcorr_arr[:,:,0:1]
        self.variables_arr[:,:,4:5] = self.variables_arr[:,:,4:5] + self.vcorr_arr[:,:,4:5]
        self.variables_arr[:,:,5:6] = self.variables_arr[:,:,5:6] + self.vcorr_arr[:,:,5:6]

    @staticmethod
    def _solar_motion_projections(variables_arr):
        """
        Project vSun onto each star's (r, l, b) unit vectors.

        Returns an array of shape (nelem, NStars_per_stream, 6) holding
        (v_r, v_l, v_b, v_r/dis, v_l/dis, v_b/dis), where dis is the (scaled) distance stored
        in column 0 of `variables_arr`.
        """
        # Fixed: dis previously held cos(distance), which made columns 3:6 wrong.
        dis    = variables_arr[:,:,0:1]
        cos_xl = np.cos(variables_arr[:,:,1:2])
        sin_xl = np.sin(variables_arr[:,:,1:2])
        cos_xb = np.cos(variables_arr[:,:,2:3])
        sin_xb = np.sin(variables_arr[:,:,2:3])

        rvec = np.concatenate(( cos_xl*cos_xb, sin_xl*cos_xb, sin_xb), axis=2)
        lvec = np.concatenate((-sin_xl       , cos_xl       , np.zeros_like(cos_xl)), axis=2)
        bvec = np.concatenate((-cos_xl*sin_xb,-sin_xl*sin_xb, cos_xb), axis=2)

        vsun = np.asarray(vSun, dtype=np_dtype)  # vSun is a torch tensor
        vr_vcorr = np.einsum('bci,i->bc', rvec, vsun)[:,:,None]
        vl_vcorr = np.einsum('bci,i->bc', lvec, vsun)[:,:,None]
        vb_vcorr = np.einsum('bci,i->bc', bvec, vsun)[:,:,None]

        return np.concatenate((vr_vcorr, vl_vcorr, vb_vcorr,
                               vr_vcorr/dis, vl_vcorr/dis, vb_vcorr/dis), axis=2)

    def __getitem__(self,idx):
        """
        This function returns the input to the network, and ground truth pairs for a particular index.
        """
        if self.randomize:
            pick = np.random.permutation(NStars_per_stream)[:NStars] # NOTE the random permutation of the entries
        else:
            pick = np.arange(NStars)

        variables  = self.variables_arr[idx, pick, :]
        Js_ref     = self.Js_ref_arr[idx, 0:1, :]
        thetas_ref = self.thetas_ref_arr[idx, pick, :]

        return variables, Js_ref, thetas_ref

    def __len__(self):
        """
        This function returns the length of the dataset, total elements.
        """
        return len(self.df)

## Sanity-check the data: print the dataset sizes and inspect the shapes of one sample:

In [None]:
# Build each dataset split once and report its size; the training split is built last, so it is
# the one left in `dataset_check` for the sample inspection below.
_split_checks = (
    ('all',        "Length of full dataset:"),
    ('validation', "Length of validation dataset:"),
    ('train',      "Length of training dataset:"),
)
for _split_mode, _split_msg in _split_checks:
    dataset_check = AstroDataset(fitsfile=fitsdata_name, mode=_split_mode)
    variables = dataset_check.df.iloc[0, :]
    print(_split_msg, len(dataset_check))


# Draw one item (index 10) from the training split and show the shapes of its pieces.
for i in range(1):
    net_input, Js_ref, thetas_ref = dataset_check[10]

net_input.shape, thetas_ref.shape

# Define the network

In [None]:
import torch

class ResDense_block(torch.nn.Module):
    """
    Basic building block, using the philosophy of resnets but for Linear blocks.
    This layer does not change the number of features.

    forward(x) = x + L2(relu(L1(relu(x)))), where L1 (no bias) and L2 are weight-normalised
    Linear layers. Dropout is applied before each Linear layer when the global `use_dropout`
    is set and the block is at least `minwidth_forDropout` units wide.

    NOTE: earlier versions applied each Linear layer twice (see forward). Parameters trained
    with that version are not equivalent under this forward pass.
    """

    def __init__(self, nunits, minwidth_forDropout=256, pDropout=0.5):
        """
        We DECLARE the layers we want to use. These are the elements that will go into the nodes of
        the computational graph. The graph (i.e. how the layers are connected) is defined in the forward function.
        """
        super().__init__()

        self.nunits = nunits
        self.minwidth_forDropout = minwidth_forDropout

        # weight_norm re-parametrises the Linear layer *in place* and returns the same module
        # object, so WN1 is dense1 and WN2 is dense2. Both attributes are kept so that existing
        # state_dict keys still load.
        self.DO1    = torch.nn.Dropout(p=pDropout)
        self.dense1 = torch.nn.Linear(in_features=self.nunits, out_features= self.nunits, bias=False)
        self.WN1    = torch.nn.utils.weight_norm(self.dense1)
        self.DO2    = torch.nn.Dropout(p=pDropout)
        self.dense2 = torch.nn.Linear(in_features=self.nunits, out_features= self.nunits)
        self.WN2    = torch.nn.utils.weight_norm(self.dense2)

    def _dropout_active(self):
        """True when dropout should be applied in this block."""
        # Width test first: narrow blocks never consult the global switch.
        return self.nunits >= self.minwidth_forDropout and use_dropout

    def forward(self, input):
        """
        This function is the definition of the computational graph, i.e. the connection of inputs to layers, and layers interconnection.
        """
        xx = torch.relu(input)
        if self._dropout_active():
            xx = self.DO1(xx)
        xx = self.WN1(xx)   # weight-normalised dense1, applied once

        xx = torch.relu(xx)
        if self._dropout_active():
            xx = self.DO2(xx)
        xx = self.WN2(xx)   # weight-normalised dense2, applied once

        return xx + input

In [None]:
import torch

class CapNet(torch.nn.Module):
    """
    Rodrigo's lobotomization of Foivos' fancy CapNet

    A stack of ResDense_block layers on a Linear stem. The width at level i is
    min(st_features*2**i, 16*st_features). It grows for `depth` levels, then shrinks back
    (skipping levels narrower than out_features). A final Linear layer maps to out_features.
    Linear layers are inserted wherever the width changes between levels.

    Args:
        depth:        number of widening levels (and of narrowing levels).
        in_features:  size of the last input dimension.
        out_features: size of the last output dimension.
        st_features:  width of the first level.
        verbose:      print the layer structure while building (now honoured; it used to be
                      forced to True).

    Raises:
        ValueError: if the widths cannot accommodate in_features / out_features.
    """
    def __init__(self, depth, in_features=NPhase//2, out_features=1, st_features=64, verbose=False):
        super().__init__()

        self.depth = depth
        width_max=(st_features)*2**4

        def width(level):
            # Width of a given level, capped at width_max (non-decreasing in `level`).
            return min(st_features*2**level, width_max)

        if width_max < out_features:
            raise ValueError("width_max too small")
        if st_features < in_features:
            raise ValueError("st_features too small")
        if st_features < out_features:
            raise ValueError("st_features too small")

        if (verbose):
            print ("***** CapNet structure *****")

        encoder = [torch.nn.Linear(in_features=in_features, out_features=st_features)]

        # Widening half.
        for i in range(0,depth):
            if (verbose):
                print ("depth:= {0}, width: {1}, target: {2}".format(i,width(i),out_features))
            encoder.append(ResDense_block(nunits=width(i)))
            if width(i) != width(i+1):
                encoder.append(torch.nn.Linear(in_features=width(i), out_features=width(i+1)))

        current_features=width(depth)

        if (verbose):
            print ("depth:= {0}, width: {1}, target: {2}".format(depth,current_features,out_features))
        encoder.append(ResDense_block(nunits=current_features))

        if current_features < out_features:
            raise ValueError("depth too small")

        # Narrowing half: stop shrinking once a level would be narrower than out_features.
        for i in range(depth-1,-1,-1):
            if width(i) < out_features:
                continue
            current_features=width(i)

            assert width(i+1) >= current_features, "something weird"

            if width(i+1) != current_features:
                encoder.append(torch.nn.Linear(in_features=width(i+1), out_features=current_features))
            if (verbose):
                print ("depth:= {0}, width: {1}, target: {2}".format(i,current_features,out_features))
            encoder.append(ResDense_block(nunits=current_features))

        encoder.append(torch.nn.Linear(in_features=current_features, out_features=out_features))
        self.encoder = torch.nn.Sequential(*encoder)

    def forward(self, input ):
        """Apply the layer stack to the last dimension of `input`."""
        return self.encoder(input)


# The analytic toy model
## output order is J_phi, J_z, J_r

In [None]:
class Isochrone_Analytic(torch.nn.Module):
    """
    Analytic actions, angles, frequencies and accelerations in an isochrone potential
    Phi(r) = -M_iso / (b_iso + sqrt(b_iso**2 + r**2))   (G = 1).
    """
    def __init__(self, verbose=False, tiny=1.0e-10):
        super().__init__()

        # Floor on -E for unbound orbits, and the replacement value for NaN J3.
        self.tiny = tiny

    def forward(self,SetOfStars,M_iso,b_iso):
        """
        Args:
            SetOfStars: tensor whose last dimension is (x, y, z, vx, vy, vz).
            M_iso, b_iso: isochrone mass and scale length (broadcast against the coordinates).

        Returns a tuple of four tensors, each with last dimension 3:
            actions      (Lz, J2, J3)                  i.e. J_phi, J_z, J_r
            angles       (Anglep, Anglez, Angler)      wrapped to [0, 2*pi), NaN replaced by 0
            frequencies  (Omegap, Omegaz, Omegar)
            acceleration (ax, ay, az)
        """
        # largely a pytorch translation (with some cleaning to get back-propagation to work) of J. Bovy's actionAngleIsochrone.py 
        # Given its location in the network, it would pay off to spend some effort rationalizing this routine !!!
        
        x=SetOfStars[...,0:1]
        y=SetOfStars[...,1:2]
        z=SetOfStars[...,2:3]
        vx=SetOfStars[...,3:4]
        vy=SetOfStars[...,4:5]
        vz=SetOfStars[...,5:6]

        r2 = x**2+y**2+z**2
        r  = torch.sqrt(r2)
        r_soft = torch.sqrt(b_iso**2+r2)
        R2 = x**2+y**2
        R  = torch.sqrt(R2)
        phi= torch.atan2(y,x)
        v2 = vx**2+vy**2+vz**2
        v  = torch.sqrt(v2)
        vR = (x*vx+y*vy)/R

        Lx = y*vz-z*vy
        Ly = z*vx-x*vz
        Lz = x*vy-y*vx
    
        L2 = Lx**2+Ly**2+Lz**2
        L  = torch.sqrt(L2)
        
        pot= -M_iso/(b_iso+r_soft)
        E  = pot+0.5*v2
        E[E>-self.tiny]=-self.tiny  # NOTE!!! energy of unbound orbits is fudged here to avoid numerical problems below
        
        J2 = L - torch.abs(Lz)
        J3 = M_iso/torch.sqrt(-2.0*E)-0.5*(L+torch.sqrt(L2+4*M_iso*b_iso))
        J3[torch.isnan(J3)]=self.tiny

  
        #Frequencies
        Omegar= (-2.0*E)**1.5/M_iso
        Omegar[torch.isnan(Omegar)]=0.0
        Omegaz= 0.5*(1.0+L/torch.sqrt(L2+4.0*M_iso*b_iso))*Omegar
        Omegap= Omegaz.clone()
        
        indx= (Lz < 0.0)
        Omegap[indx] *= -1.0

        #Angles
        c  = -M_iso/2.0/E-b_iso
        e  = torch.sqrt( 1.0 - L2/M_iso/c*(1.0+b_iso/c) )

        if b_iso == 0.0:
            coseta = 1.0/e*(1.0-torch.sqrt(r2)/c)
        else:
            s= 1.0+torch.sqrt(1.0+r2/b_iso**2)
            coseta= 1.0/e*(1.0-b_iso/c*(s-2.0))

        # torch.clamp is not in-place: the result must be assigned. Previously it was discarded,
        # so round-off near the turning points (|coseta| slightly > 1) made acos return NaN, and
        # those angles were later silently zeroed.
        coseta = torch.clamp(coseta, min=-1.0,max=1.0)
        eta= torch.acos(coseta)
        costheta= z/r
        sintheta= R/r
        vrindx= ((vR*sintheta+vz*costheta) < 0.0)  # moving inwards
        eta[vrindx] = 2.0*torch.pi-eta[vrindx]
        angler= eta-e*c/(c+b_iso)*torch.sin(eta)
        tan11= torch.atan(torch.sqrt((1.0+e)/(1.0-e))*torch.tan(0.5*eta))
        tan12= torch.atan(torch.sqrt((1.0+e+2.0*b_iso/c)/(1.0-e+2.0*b_iso/c))*torch.tan(0.5*eta))

        vzindx= ((-vz*sintheta+vR*costheta) > 0.0)

        tan11[tan11 < 0.0] += torch.pi
        tan12[tan12 < 0.0] += torch.pi

        # Guard against |Lz| > L from round-off before forming the inclination.
        pindx1= (Lz/L >  1.0)
        pindx2= (Lz/L < -1.0)
        Lz_clamped = Lz.clone()
        Lz_clamped[pindx1]=  L[pindx1]
        Lz_clamped[pindx2]= -L[pindx2]
        
        sini= torch.sqrt(L2-Lz_clamped**2)/L
        tani= torch.sqrt(L2-Lz_clamped**2)/Lz_clamped
        sinpsi= costheta/sini
    
        pindx1= (sinpsi > 1.0)*torch.isfinite(sinpsi)
        sinpsi[pindx1]= 1.0
        pindx2= (sinpsi < -1.0)*torch.isfinite(sinpsi)
        sinpsi[pindx2]= -1.0           
    
        psi= torch.asin(sinpsi)
        psi[vzindx]= torch.pi-psi[vzindx]
    
        # For non-inclined orbits, set Omega=0 by convention
        psi[~torch.isfinite(psi)] = phi[~torch.isfinite(psi)]
        psi= psi % (2.0*torch.pi)
        anglez= psi+Omegaz/Omegar*angler-tan11-1.0/torch.sqrt(1.0+4*M_iso*b_iso/L2)*tan12
        sinu= z/R/tani
    
        pindx3= (sinu > 1.0)*torch.isfinite(sinu)
        sinu[pindx3]= 1.0
        pindx4= (sinu < -1.0)*torch.isfinite(sinu)
        sinu[pindx4]= -1.0           

        u= torch.asin(sinu)
        u[vzindx]= torch.pi-u[vzindx]

        # For non-inclined orbits, set Omega=0 by convention
        u[~torch.isfinite(u)]= phi[~torch.isfinite(u)]
        Omega = phi-u
        anglep= Omega.clone()
    
        # Retrograde orbits (Lz < 0) subtract the vertical angle, prograde ones add it.
        anglep[indx]  -= anglez[indx]
        anglep[~indx] += anglez[~indx]

        indxp = torch.isnan(anglep)
        indxz = torch.isnan(anglez)
        indxr = torch.isnan(angler)

        Anglep= anglep % (2*torch.pi)
        Anglez= anglez % (2*torch.pi)
        Angler= angler % (2*torch.pi)

        Anglep[indxp] = 0.
        Anglez[indxz] = 0.
        Angler[indxr] = 0.
        
        # Acceleration -grad(Phi) of the isochrone potential.
        acc_fac = M_iso/( r_soft * (b_iso+r_soft)**2 )
        ax = -x*acc_fac
        ay = -y*acc_fac
        az = -z*acc_fac        
        
        return (torch.cat([Lz,J2,J3],dim=-1),
                torch.cat([Anglep,Anglez,Angler],dim=-1),
                torch.cat([Omegap,Omegaz,Omegar],dim=-1),
                torch.cat([ax,ay,az],dim=-1) )

## Utility to convert: distance,l,b,vr,mu_l,mu_b -> x,y,z,vx,vy,vz

In [None]:
class conv_Input2xv(torch.nn.Module):
    """
    Convert (distance, l, b, v_r, mu_l*cos b, mu_b) into Galactocentric (x, y, z, vx, vy, vz).

    Input is rank 3 (stars indexed by the first two dimensions, the 6 coordinates last); the
    output has the same shape. Positions are shifted by (-Rsun, 0, +zsun). Velocities are
    rotated from the (r, l, b) frame into Cartesian axes with no further offset, so the
    proper motions presumably already include the Solar-motion correction (as AstroDataset
    applies) -- confirm at call sites.
    """
    def __init__(self, verbose=False):
        super().__init__()

    def forward(self,SetOfStars):
        dis, xl, xb, vr, ldot_cb, bdot = (SetOfStars[..., k:k+1] for k in range(6))

        cl, sl = torch.cos(xl), torch.sin(xl)
        cb, sb = torch.cos(xb), torch.sin(xb)

        # Unit vectors towards the star (r), along increasing l and along increasing b.
        e_r = torch.cat(( cl*cb, sl*cb, sb), dim=2)
        e_l = torch.cat((-sl   , cl   , torch.zeros_like(cl)), dim=2)
        e_b = torch.cat((-cl*sb,-sl*sb, cb), dim=2)
        # basis[..., i, j] is the i-th Cartesian component of the j-th unit vector.
        basis = torch.stack([e_r, e_l, e_b], dim=-1)

        # Line-of-sight and transverse velocities (proper motion times distance).
        v_rlb = torch.cat([vr, ldot_cb*dis, bdot*dis], dim=-1)

        x = dis*basis[:,:,0:1,0] - Rsun         # the position vector from the GC
        y = dis*basis[:,:,1:2,0]
        z = dis*basis[:,:,2:3,0] + zsun

        vx, vy, vz = (torch.einsum('bci,bci->bc', v_rlb, basis[:,:,k,:])[:,:,None]
                      for k in range(3))

        return torch.cat([x, y, z, vx, vy, vz], dim=2)

## Fast per-sample Jacobian (one autograd pass per output component)

In [None]:
def jacobian(inputs, outputs,create_graph=True,retain_graph=True):
    """
    Per-sample Jacobian of `outputs` with respect to `inputs` along their last dimension.

    For each output component i, the gradient of outputs[..., i].sum() with respect to `inputs`
    is taken. Summing over samples gives the per-sample derivatives as long as each output
    sample depends only on the input sample with the same leading indices (e.g. independent
    stars).

    Works for any rank: for inputs of shape (..., D) and outputs of shape (..., K) the result
    has shape (..., K, D). For the rank-3 case (batch, stars, .) this matches the previous
    implementation, which stacked on dim=2.

    Args:
        inputs:  tensor with requires_grad=True that `outputs` was computed from.
        outputs: tensor whose last dimension indexes the quantities to differentiate.
        create_graph, retain_graph: forwarded to torch.autograd.grad (the defaults keep the
            result differentiable, e.g. for use inside a loss).
    """
    grads = [torch.autograd.grad(outputs[..., i].sum(), inputs,
                                 retain_graph=retain_graph,
                                 create_graph=create_graph,
                                 only_inputs=True)[0]
             for i in range(outputs.shape[-1])]
    # Insert the output-component axis just before the input-coordinate axis.
    return torch.stack(grads, dim=-2)

## Next come the encoders for:
### Encoder_GF_G: canonical generating function G
### Encoder_GF_P: canonical point-transformation P
### Encoder_A   : the acceleration

In [None]:
class Encoder_A(torch.nn.Module):
    """
    Acceleration model: the last three input features plus a learned CapNet correction.

    The network sees all NPhase input features. The caller presumably passes positions
    followed by the isochrone acceleration, so the output is a_iso + correction
    (confirm at the call site).
    """
    def __init__(self, verbose=False):
        super().__init__()
        self.ultra  = CapNet(depth=depth_A, in_features=NPhase, out_features=NPhase//2)

    def forward(self,SetOfStars_wacc):
        baseline = SetOfStars_wacc[..., 3:]
        correction = self.ultra(SetOfStars_wacc)
        return baseline + correction


class Encoder_GF_P(torch.nn.Module):
    """
    Point transformation of angles: theta -> (theta + delta(theta)) mod 2*pi.

    The angles are embedded as (cos, sin) so the network input is periodic. The learned shift is
    atan2 of the first three network outputs over the last three.
    """
    def __init__(self, verbose=False):
        super().__init__()
        self.ultra  = CapNet(depth=depth_GF_P, in_features=NPhase, out_features=NPhase)

    def forward(self,SetOfStars_T):
        periodic_features = torch.cat([torch.cos(SetOfStars_T), torch.sin(SetOfStars_T)], dim=-1)
        raw = self.ultra(periodic_features)
        shift = torch.atan2(raw[..., :3], raw[..., 3:])
        return (SetOfStars_T + shift) % (2*torch.pi)


class Encoder_GF_G(torch.nn.Module):
    """
    Scalar generating function G(theta, J) evaluated by a CapNet.

    The input's last dimension holds three angles followed by three actions. The angles enter
    as (cos, sin). The output (last dim 1) is atan2(u1, u0) + u2, built from the network's
    three outputs u.
    """
    def __init__(self, verbose=False):
        super().__init__()
        self.ultra = CapNet(depth=depth_GF_G, in_features=NPhase+NPhase//2, out_features=3)

    def forward(self,SetOfStars_TJ):
        angles, actions = SetOfStars_TJ[..., :3], SetOfStars_TJ[..., 3:]
        features = torch.cat([torch.cos(angles), torch.sin(angles), actions], dim=-1)
        out = self.ultra(features)
        return torch.atan2(out[..., 1:2], out[..., 0:1]) + out[..., 2:3]

# Main ActionFinder network

In [None]:
class ActionFinder(torch.nn.Module):
    """
    Inputs: 
           d,l,b,vh,mu_l,mu_b astrometric phase space coordinates
    Output: 
           loss due to the spread of the J'' values (and others)

    Outline of forward():
      1. convert the inputs to x,v and evaluate the toy isochrone (J, theta,
         frequencies, acceleration) with M and b scaled by the learnable
         `Mb_param`;
      2. find J' for each stream (fixed-point iteration, or the labels when
         `iterate_to_find_Jd` is False) and apply the generating function G:
         J' = J - dG/dtheta,  theta' = theta + dG/dJ';
      3. optionally apply the point transformation P: theta'' = P(theta'),
         J''_i = sum_j (dtheta'_j/dtheta''_i) J'_j;
      4. optionally (Fit_Acceleration) penalise dJ''/dt and the spread of
         dtheta''/dt using a learned acceleration field.

    Relies on notebook globals defined elsewhere: Fit_Acceleration,
    iterate_to_find_Jd, use_point_transformation, use_simple_loss,
    cossimilarity, jacobian, M_iso_ref, b_iso_ref and the scale_* constants.
    """
    def __init__(self,verbose=False,iter_max=10):
        super().__init__()

        self.conv_input2xv = conv_Input2xv(verbose=verbose)
        self.isochrone_analytic = Isochrone_Analytic(verbose=verbose)
        self.encoder_GF_G = Encoder_GF_G(verbose=verbose)
        self.encoder_GF_P = Encoder_GF_P(verbose=verbose)
        self.encoder_A    = Encoder_A(verbose=verbose)
        self.Mb_param     = torch.nn.Parameter(torch.ones((2,1),dtype=torch.float64))

        self.iter_max = iter_max  # max fixed-point iterations used to find J'

    def forward(self,inputs,Jlabels,Tlabels,i,epoch):
        """
        Compute all losses for one batch of streams.

        inputs           : astrometric coordinates; dim 1 indexes the stars of a stream
        Jlabels, Tlabels : true actions/angles. Used only for diagnostics, and as the
                           starting J' when iterate_to_find_Jd is False; the algorithm
                           itself can work completely unsupervised.
        i                : batch index; the first batch (i == 0) prints a summary
        epoch            : unused, kept for interface compatibility

        Returns (loss, loss_Jdd, loss_Jdd_spread, loss_Jdd_neg, loss_J_iso,
                 loss_Tdd, loss_Tdd_vals0, loss_T_iso,
                 loss_dJdt, loss_Os_spread, loss_direc, loss_LinAl,
                 Jdd_mean, Jdd_spread)
        """
        M_iso_toy  = np.abs(M_iso_ref)*self.Mb_param[0] # mass and scale of the toy isochrone model
        b_iso_toy  = np.abs(b_iso_ref)*self.Mb_param[1] # Mb_param[] is optimized by the network

        # ***NOTE*** double dashed variables (Tdd, Jdd) i.e. (Theta'', J'') are in the ***target system***
        #                   dashed variables (Td,  Jd)  i.e. (Theta', J')   result from applying G transformations

        inputs_xv = self.conv_input2xv(inputs)        # convert d,l,b,vh,mu_l,mu_b to xv

        J_iso,  T_iso,  O_iso,  a_iso = self.isochrone_analytic(inputs_xv, M_iso_toy,b_iso_toy)  # (J,Theta, freq & acceleration) of isochrone toy model 
        n_stars = T_iso.shape[1]  # stars per stream (dim 1 is averaged over below)

        if (Fit_Acceleration):
            inputs_x = inputs_xv[...,:3]
            Accxv_model = self.encoder_A( torch.cat([inputs_x,a_iso],dim=-1) )

        if (iterate_to_find_Jd):
            # We now iterate to find J'. Start with J' = mean(J_iso for each stream)
            Jd_mean = torch.mean( J_iso, dim=1, keepdim=True)

            for _ in range(self.iter_max):
                Jd_trial = Jd_mean.clone()                # trial J'
                Jd_mean_fill = torch.cat([Jd_mean]*n_stars,dim=1)

                TJd  = torch.cat( [T_iso,Jd_mean_fill],dim=-1 ) # (T,J')
                GF_G = self.encoder_GF_G(TJd)
                # .data keeps the fixed-point iteration out of the computational graph
                JmJd = jacobian(T_iso,GF_G)[:,:,0,:].data

                # update 
                Jd_mean = torch.mean( J_iso - JmJd, dim=1, keepdim=True)

                if ( torch.max( torch.abs(Jd_trial - Jd_mean)) < 5.e-5 and 
                     torch.mean(torch.abs(Jd_trial - Jd_mean)) < 1.e-6):
                    break
        else:
            # CHEAT by initializing with ground-truth values!!! (useful for exploring what the GF is capable of)
            Jd_mean = Jlabels.clone() # currently, for this to work, also set use_point_transformation=False

        # Find Jd_mean once more, but now retain the computational graph 
        Jd_mean_fill = torch.cat([Jd_mean]*n_stars,dim=1)
        TJd  = torch.cat( [T_iso,Jd_mean_fill],dim=-1 )
        GF_G = self.encoder_GF_G(TJd)
        JmJd = jacobian(T_iso,GF_G)[:,:,0,:]
        Jd_vals = J_iso - JmJd                       # J' given by generating function G
        Jd_mean = torch.mean( Jd_vals, dim=1, keepdim=True)

        # Calculate T' from GF derivative
        Jd_mean_fill = torch.cat([Jd_mean]*n_stars,dim=1)
        TJd  = torch.cat( [T_iso,Jd_mean_fill],dim=-1 )
        GF_G = self.encoder_GF_G(TJd)

        # T' = T + d(G)/d(J'):
        Td_vals = (T_iso + jacobian(Jd_mean_fill,GF_G)[:,:,0,:]) % (2*torch.pi)

        if (use_point_transformation):
            # Add in the extra freedom of a point-transformation
            Tdd_vals = self.encoder_GF_P( Td_vals )      # final T'' from generating function P
            dTdd_dTd = jacobian(Td_vals,Tdd_vals)        # d(T'')/d(T')
            dTd_dTdd = torch.inverse( dTdd_dTd )         # d(T')/d(T'')
            # J_i'' = dT'/dT_i'' . J'
            Jdd_vals = torch.einsum('bsji,bsj->bsi',dTd_dTdd,Jd_vals) # final J'' values
        else:
            Tdd_vals = Td_vals.clone()
            Jdd_vals = Jd_vals.clone()

        Jdd_mean = torch.mean( Jdd_vals, dim=1)

        Jdd_spread = torch.mean(torch.abs(Jdd_mean[:,None,:] - Jdd_vals), dim=1)
        Jdd_diff   = Jdd_mean[:,None,:] - Jlabels

        Tdd_diff   = (Tdd_vals-Tlabels + torch.pi) % (2*torch.pi) - torch.pi # [-pi,pi]
        T_diff_iso = (T_iso-Tlabels + torch.pi) % (2*torch.pi) - torch.pi

        # create Jdd_vals_pos, which we will use later to penalise unphysical -ve J2,J3
        # (J1 is excluded from the mask by taking its absolute value first)
        mask_neg = torch.cat([torch.abs(Jdd_vals[...,0:1]),Jdd_vals[...,1:3]],dim=-1) < 0
        Jdd_vals_pos = Jdd_vals.clone()
        Jdd_vals_pos[mask_neg] = torch.abs(Jdd_vals[mask_neg])

        # T'=0 should correspond to T=0
        TJd0 = torch.cat( [torch.zeros_like(T_iso),Jd_vals],dim=-1 )
        GF_G = self.encoder_GF_G(TJd0)
        Td_vals0  = (jacobian(Jd_vals,GF_G)[:,:,0,:]) % (2*torch.pi)
        if (use_point_transformation):
            Tdd_vals0 = (self.encoder_GF_P( Td_vals0 ) + torch.pi) % (2*torch.pi) - torch.pi # [-pi,pi] network for T' -> T''
        else:
            # Wrap into [-pi,pi) exactly as in the branch above, so that T'=0 (mod 2pi)
            # gives zero loss (a plain `Td_vals0 - pi` would make T'=pi the optimum).
            Tdd_vals0 = (Td_vals0 + torch.pi) % (2*torch.pi) - torch.pi

        if (Fit_Acceleration):
            # Jacobian of the new coordinates wrt xv:
            dJdW = jacobian(inputs_xv, Jdd_vals)
            dTdW = jacobian(inputs_xv, Tdd_vals)

            gradrJ = dJdW[:,:,:,:3] # grad_r of J
            gradvJ = dJdW[:,:,:,3:] # grad_v of J

            # In each stream, exclude from the LinAl loss the star whose grad_v J is closest to singular.
            # The mask is sized from the data (not the global batch_size) so partial batches work.
            mask = torch.ones(gradvJ.shape[:2], dtype=torch.bool, device=gradvJ.device)
            mask_indx = torch.min(torch.abs( gradvJ.det() ),1)[1]
            mask[torch.arange(mask.shape[0], device=mask.device),mask_indx] = False

            gradrJ_dot_v = torch.einsum('bsij,bsj->bsi',gradrJ,inputs_xv[:,:,3:])
            gradvJ_dot_a = torch.einsum('bsij,bsj->bsi',gradvJ,Accxv_model[:,:,:])
            dJdt = gradrJ_dot_v + gradvJ_dot_a  # total derivative of J

            # Solve gradvJ . gradrPhi = gradrJ . v for every star
            rhs = gradrJ_dot_v[:,:,:,None]
            if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'solve'):
                gradrPhi = torch.linalg.solve(gradvJ, rhs)
            else:  # older PyTorch without torch.linalg.solve
                gradrPhi = torch.solve(rhs, gradvJ)[0]
            Accxv_LinAl = -gradrPhi[:,:,:,0] # the linear algebra solution for the acceleration 

            gradrT_dot_v = torch.einsum('bsij,bsj->bsi',dTdW[:,:,:,:3],inputs_xv[:,:,3:])
            gradvT_dot_a = torch.einsum('bsij,bsj->bsi',dTdW[:,:,:,3:],Accxv_model[:,:,:])
            Omegas = gradrT_dot_v + gradvT_dot_a # Omegas == dtheta/dt

            Omegas_mean = torch.mean( Omegas, dim=1)
            Omegas_spread = torch.mean(torch.abs(Omegas_mean[:,None,:] - Omegas), dim=1) # order consistent with Jdd_spread above

        loss_Jdd = torch.mean( torch.abs( Jdd_diff ) )
        loss_Jdd_spread = torch.mean( Jdd_spread )
        loss_Jdd_neg = torch.mean( torch.abs( Jdd_vals-Jdd_vals_pos ) )
        loss_J_iso = torch.mean( torch.abs( torch.mean( J_iso, dim=1, keepdim=True)-Jlabels ) )

        loss_Tdd = torch.mean( torch.abs( Tdd_diff ) )/(scale_Ts)
        loss_Tdd_vals0 = torch.mean(torch.abs( Tdd_vals0 ))/(scale_Ts)
        loss_T_iso = torch.mean( torch.abs( T_diff_iso ))/(scale_Ts)

        if (Fit_Acceleration):
            loss_dJdt = torch.mean( torch.abs( dJdt ) )/(scale_dJdt)
            loss_Os_spread = torch.mean( Omegas_spread )/(scale_Os_same)
            loss_direc = torch.mean( 1 - cossimilarity(-Accxv_model[:,:,0:2],inputs_xv[:,:,0:2]) )/(scale_direc) # cylindrical symmetry
            loss_LinAl = torch.mean(torch.abs( (Accxv_model-Accxv_LinAl)[mask] ) )/(scale_LinAl)
        else:
            loss_dJdt = torch.zeros(1)
            loss_Os_spread = torch.zeros(1)
            loss_direc = torch.zeros(1)
            loss_LinAl = torch.zeros(1)

        if (i==0): # print out the first point in the (randomly shuffled) list, to check all is ok
            for k in range(0,1): 

                torch.set_printoptions(sci_mode=False,linewidth=120)
                print(" ")
                print(" *********** ")
                print(" ")
                print("Jdd_vals     :",(Jdd_vals[k,0,:]).data)
                print("Jlabels      :",(Jlabels[k,0,:]).data)
                print("Jdd_diff     :",(Jdd_diff[k,0,:]).data)
                print("Jdd_spread   :",(Jdd_spread[k,:]).data)
                print(" ")
                print("Tdd_diff     :",(Tdd_diff[k,0,:]).data)
                print(" ")
                print("loss_Jdd_neg :",loss_Jdd_neg)
                print(" ")
                print("Isochrone M,b:",self.Mb_param.data)

                torch.set_printoptions(sci_mode=None,linewidth=None)

        if (Fit_Acceleration):
            loss = ( loss_Jdd_spread + loss_Jdd_neg + loss_Tdd_vals0 +
                     loss_dJdt + loss_Os_spread + loss_direc + loss_LinAl ) 
        else:
            if (use_simple_loss):
                loss = loss_Jdd_spread
            else:
                loss = loss_Jdd_spread + loss_Jdd_neg + loss_Tdd_vals0

        return (loss, loss_Jdd, loss_Jdd_spread, loss_Jdd_neg, loss_J_iso,
                loss_Tdd, loss_Tdd_vals0, loss_T_iso, 
                loss_dJdt, loss_Os_spread, loss_direc, loss_LinAl, 
                Jdd_mean, Jdd_spread)

## Some Convenience functions

In [None]:
# Convenience function for learning rate reduction: 
def reduce_learning_rate(new_lr, YourOptimizer):
    """Set the learning rate of every parameter group of `YourOptimizer` to `new_lr`, in place."""
    for group in YourOptimizer.param_groups:
        group.update(lr=new_lr)

        
def count_parameters(model):
    """Return the number of scalar entries in the trainable (requires_grad) parameters of `model`."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# On the fly visualization
from IPython import display
def generate_image(fig, history,best_epoch, epoch=None):
    """
    Generates a plot during runtime 

    Redraws `fig` with the training (column 1) and validation (column 2) losses
    of `history` against epoch (column 0), highlights `best_epoch`, and refreshes
    the notebook output.

    fig        : matplotlib figure to draw into (it is cleared first)
    history    : 2-D array, one row per epoch: [epoch, train loss, validation loss, ...]
    best_epoch : 1-based epoch of the best validation loss; row best_epoch-1 is highlighted
    epoch      : epoch shown in the title; defaults to the last epoch stored in `history`
    """
    if epoch is None:
        epoch = int(history[-1,0])

    fig.clf()

    ax = fig.add_subplot(1,1,1)
    # Zoomed range for the twin axis: from the lowest loss up to the larger median loss, padded by 5%
    ymin1=min(history[:,1])
    ymin2=min(history[:,2])
    ymax1=np.median(history[:,1])
    ymax2=np.median(history[:,2])
    ymin=min(ymin1,ymin2)
    ymax=max(ymax1,ymax2)
    dy=ymax-ymin
    ymin=ymin-dy*0.05
    ymax=ymax+dy*0.05
    ax.set_ylim([0.0,1.05*max(history[:,1])])
    ax.plot(history[:,0], history[:,1],'.-',label=r'train loss')
    ax.plot(history[:,0], history[:,2],'--',label=r'validation loss')
    ax.set_xlabel(r'epoch')
    ax.set_ylabel(r'Loss')

    ax.scatter(history[best_epoch-1,0], history[best_epoch-1,2],c='r',s=40)

    ax.legend()

    ax2 = ax.twinx()  # instantiate a second axes that shares the same x-axis
    ax2.set_ylim([ymin,ymax])
    ax2.plot(history[:,0], history[:,1],'.-',c='blue')
    ax2.plot(history[:,0], history[:,2],'.-',c='red')

    ax2.set_ylabel('Loss', color='blue')  # we already handled the x-label with ax1
    ax2.tick_params(axis='y', labelcolor='blue')
    ax2.scatter(history[best_epoch-1,0], history[best_epoch-1,2],c='k',s=40)

    fig.suptitle('Epoch:: {}, best valLoss::{:.6f}'.format(epoch,history[:,2].min()), size=20)

    display.clear_output(wait=True)
    # show the figure we drew into, not whatever pyplot currently considers "current"
    display.display(fig)

In [None]:
# Convenience function to get test loss 
def eval_test_loss(tnet, batch_size, epoch, randomize=False, training=False):
    """
    Evaluate `tnet` over the validation split (or the training split if `training`).

    Only full batches are used (drop_last=True). Gradients stay enabled because
    the network's forward pass takes jacobians of its intermediate tensors.

    Returns (mean total loss, mean loss_Jdd, mean loss_Tdd); NaN total losses
    are replaced by 1.0e6 before averaging.
    """
    tnet.eval() # Make evaluation mode, for correct batchnorm behaviour 

    mode = 'train' if training else 'validation'
    dataset_val = AstroDataset(fitsfile=fitsdata_name,mode=mode,randomize=randomize)
    datagen_val = torch.utils.data.DataLoader(dataset_val,batch_size=batch_size,shuffle=False,drop_last=True,num_workers=num_workers)

    loss_list = []
    loss_list_Jd = []
    loss_list_Td = []
    for i, (inputs, Jlabels, Tlabels) in enumerate(datagen_val):
        inputs = inputs.to(device)
        Jlabels = Jlabels.to(device)
        Tlabels = Tlabels.to(device)

        # same settings as in training, so the forward pass behaves identically
        inputs.requires_grad=True
        Jlabels.requires_grad=True

        # evaluate the network that was passed in (not the global `net`)
        (loss, loss_Jd, loss_Jd_spread, loss_Jd_neg, loss_J_iso,
         loss_Td, loss_Td_vals0, loss_T_iso, 
         loss_dJdt, loss_Os_spread, loss_direc, loss_LinAl, 
         Jdd_mean, Jdd_spread) = tnet(inputs,Jlabels,Tlabels,i,epoch)

        loss_list.append(loss.item())
        loss_list_Jd.append(loss_Jd.item())
        loss_list_Td.append(loss_Td.item())

    loss_list_np = np.array(loss_list)
    loss_list_np[np.isnan(loss_list_np)]=1.0e6

    return np.mean(loss_list_np), np.mean(np.array(loss_list_Jd)), np.mean(np.array(loss_list_Td))



# Convenience function to get predictions     
def get_preds(tnet,batch_size,mode='all',randomize=False):
    """
    Run `tnet` over the prediction catalogue and collect per-stream J'' estimates.

    Returns (preds, preds_spread): the concatenated Jdd_mean and Jdd_spread
    outputs of every batch, as numpy arrays (rows follow the dataset order).
    """
    tnet.eval() # Make evaluation mode, for correct batchnorm behaviour 

    dataset_val = AstroDataset(fitsfile=fitspred_name,mode=mode,randomize=randomize)
    datagen_val = torch.utils.data.DataLoader(dataset_val,batch_size=batch_size,shuffle=False)

    preds = []
    preds_spread = []
    epoch = 0
    for i, (inputs, Jlabels, Tlabels) in enumerate(datagen_val):
        inputs = inputs.to(device)
        Jlabels=Jlabels.to(device)
        Tlabels=Tlabels.to(device)
        # as in training/evaluation: the forward pass differentiates w.r.t. the inputs
        inputs.requires_grad=True

        # use the network that was passed in (not the global `net`)
        (loss, loss_Jd, loss_Jd_spread, loss_Jd_neg, loss_J_iso,
         loss_Td, loss_Td_vals0, loss_T_iso, 
         loss_dJdt, loss_Os_spread, loss_direc, loss_LinAl, 
         Jdd_mean, Jdd_spread) = tnet(inputs,Jlabels,Tlabels,i,epoch)

        preds.append(Jdd_mean.detach().cpu().numpy())
        preds_spread.append(Jdd_spread.detach().cpu().numpy())

    preds = np.concatenate(preds,axis=0)
    preds_spread = np.concatenate(preds_spread,axis=0)

    return preds, preds_spread



# Convenience function to make and write out the predictions     
def make_and_write_preds(tnet, batch_size):
    """
    Append the predicted actions to the prediction catalogue and write it out.

    Adds columns J1_pred, J2_pred, J3_pred, J1_spread, J2_spread, J3_spread
    (scaled back by J_scale) to the table in HDU 1 of `fitspred_name` and
    writes the result to `fitspred_name_out`, replacing any existing file.
    """
    import os

    preds, preds_spread = get_preds(tnet,batch_size)

    # Close the input file once the table has been copied into memory
    with fits.open(fitspred_name) as tab:
        labels = tab[1].header
        labels_list = [labels['TTYPE'+str(i)]
                       for i in range(0,len(labels))
                       if 'TTYPE'+str(i) in labels]
        copy_data = Table(tab[1].data,names=labels_list).copy()

    components = ('J1', 'J2', 'J3')
    for j, comp in enumerate(components):
        copy_data.add_column(Column(preds[:,j]*J_scale, name = comp+'_pred'))
    for j, comp in enumerate(components):
        copy_data.add_column(Column(preds_spread[:,j]*J_scale, name = comp+'_spread'))

    if os.path.exists(fitspred_name_out):
        os.remove(fitspred_name_out)

    fits.writeto(fitspred_name_out, np.array(copy_data))

## Some more convenience functions

In [None]:
def _freeze_params_containing(m, key):
    # Disable gradients for every parameter of `m` whose qualified name contains `key`.
    for pname, param in m.named_parameters():
        if key in pname:
            param.requires_grad = False


def freeze_network_GF_G_part(m):
    """Freeze the parameters of `m` named '*encoder_GF_G.ultra.encoder*' (intended for `net.apply`)."""
    _freeze_params_containing(m, "encoder_GF_G.ultra.encoder")


def freeze_network_GF_P_part(m):
    """Freeze the parameters of `m` named '*encoder_GF_P.ultra.encoder*' (intended for `net.apply`)."""
    _freeze_params_containing(m, "encoder_GF_P.ultra.encoder")


def freeze_network_A_part(m):
    """Freeze the parameters of `m` named '*encoder_A.ultra.encoder*' (intended for `net.apply`)."""
    _freeze_params_containing(m, "encoder_A.ultra.encoder")


## HyperParameters of the run and network definitions 

In [None]:
# --- Network, data loader and optimizer for this run (all bound as notebook globals) ---
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#device = 'cpu'

net = ActionFinder().to(device)
net.to(dtype)  # dtype is chosen in the first cell (torch.float64)


# Freeze the '<encoder>.ultra.encoder' parameters of the parts not trained in this run:
#  - Fit_Acceleration: G and P are frozen (the acceleration encoder A is fitted)
#  - otherwise: A is frozen, and P is frozen too when no point transformation is used
# The freeze_network_* helpers match names from the root module, so only the call on `net` itself has an effect.
if (Fit_Acceleration):
    net.apply(freeze_network_GF_G_part)
    net.apply(freeze_network_GF_P_part)

    print("Fitting the acceleration field")
else:
    net.apply(freeze_network_A_part)
    if (use_point_transformation):
        print("also fitting a canonical point-transformation")
    else:
        print("NOT using a point-transformation")
        net.apply(freeze_network_GF_P_part)
    

dataset_train  = AstroDataset(fitsfile=fitsdata_name,mode='train')
datagen_train = torch.utils.data.DataLoader(dataset_train,batch_size=batch_size,shuffle=True,drop_last=True,num_workers=num_workers)


# NOTE(review): `criterion` is not referenced anywhere in the code visible here.
criterion = torch.nn.L1Loss()
# Used as a global by ActionFinder.forward for the acceleration-direction loss;
# dim=2 is the component axis of the (batch, star, component) slices compared there.
cossimilarity = torch.nn.CosineSimilarity(dim=2)


wd=0.0  # weight decay

# Define Optimizer 
lr = 3.e-3  # base learning rate (the training cell starts at lr/10)
optimizer  = torch.optim.Adam(net.parameters(), lr=lr,weight_decay=wd)

# Actual training loop

In [None]:
# Number of trainable parameters, i.e. those still with requires_grad=True after the freezing in the set-up cell
print(count_parameters(net))

In [None]:
#net.load_state_dict(torch.load(flname_save))# Load best model first if necessary

In [None]:
import subprocess
import os.path
import timeit
import datetime

fig = figure(figsize=(8,6)) # Change this for larger image 

history = [] # monitoring: one row [epoch, train loss, val loss, loss_Jd, loss_Td] per epoch
val_loss_criterion = 1.0e30 # Checkpointing: Something very large so it will be picked up upon first epoch 
best_epoch = 1 # defined up front so the plot/prints work even if no epoch ever improves (e.g. NaN losses)

lr_new = lr/10.
reduce_learning_rate(lr_new,optimizer)


loss_Jd_best = 0.0
loss_Td_best = 0.0

# datagen_train (built in the set-up cell) has shuffle=True, so it draws a new
# permutation each time it is iterated; there is no need to rebuild it per epoch.
n_batches = len(datagen_train)

time_init = datetime.datetime.now()
start_time_init = timeit.default_timer()
for epoch in range(1,Epochs+1):  # loop over the dataset multiple times
    start_time = timeit.default_timer()

    net.train() # Make in train mode 
    train_loss = 0.0
    train_loss_Jd = 0.0
    train_loss_Jd_spread = 0.0
    train_loss_Jd_neg = 0.0
    train_loss_J_iso = 0.0
    train_loss_Td = 0.0
    train_loss_Td_vals0 = 0.0
    train_loss_T_iso = 0.0
    train_loss_dJdt = 0.0
    train_loss_Os_spread = 0.0
    train_loss_direc = 0.0
    train_loss_LinAl = 0.0

    loss_list_epoch = []

    for i, (inputs, Jlabels, Tlabels) in enumerate(datagen_train):
        inputs = inputs.to(device) # pass to GPU if available 
        Jlabels=Jlabels.to(device)
        Tlabels=Tlabels.to(device)

        optimizer.zero_grad() 
        # the forward pass takes jacobians w.r.t. quantities derived from the inputs
        inputs.requires_grad=True
        Jlabels.requires_grad=True

        (loss, loss_Jd, loss_Jd_spread, loss_Jd_neg, loss_J_iso,
         loss_Td, loss_Td_vals0, loss_T_iso, 
         loss_dJdt, loss_Os_spread, loss_direc, loss_LinAl, 
         Jdd_mean, Jdd_spreadmad) = net(inputs,Jlabels,Tlabels,i,epoch)

        loss.backward() # Calculate gradients 
        optimizer.step() # update weights 

        # accumulate statistics
        train_loss += loss.item()
        train_loss_Jd += loss_Jd.item()
        train_loss_Jd_spread += loss_Jd_spread.item()
        train_loss_Jd_neg += loss_Jd_neg.item()
        train_loss_J_iso += loss_J_iso.item()
        train_loss_Td += loss_Td.item()
        train_loss_Td_vals0 += loss_Td_vals0.item()
        train_loss_T_iso += loss_T_iso.item()
        train_loss_dJdt += loss_dJdt.item()
        train_loss_Os_spread += loss_Os_spread.item()
        train_loss_direc += loss_direc.item()
        train_loss_LinAl += loss_LinAl.item()

        loss_list_epoch.append(loss.item())

    # Evaluate accuracy after each training epoch 
    net.eval() # make network in evaluation mode
    if (use_validation_set):
        acc, loss_Jd_eval, loss_Td_eval = eval_test_loss(net, batch_size=batch_size, epoch=epoch)
    else:
        acc=train_loss/n_batches
        loss_Jd_eval=train_loss_Jd/n_batches
        loss_Td_eval=train_loss_Td/n_batches

    # ========== Checkpointing ===========================
    if acc < val_loss_criterion:
        best_epoch = epoch  # plain int (ints are immutable, no copy needed)
        print(best_epoch, flname_save)
        val_loss_criterion = acc        
        torch.save(net.state_dict(), flname_save)
        # see https://pytorch.org/tutorials/beginner/saving_loading_models.html
        loss_Jd_best = float(loss_Jd_eval)
        loss_Td_best = float(loss_Td_eval)
        if (write_predictions_on_the_fly and epoch>100): # this might be expensive!
            make_and_write_preds(net, batch_size=batch_size)        
    # ======================================================

    history += [[epoch, train_loss/n_batches,acc,loss_Jd_eval,loss_Td_eval]]
    generate_image(fig,np.array(history),best_epoch)
    print('Epoch::{} \t best epoch::{} \t train_loss::{:.6f} \t val_loss::{:.6f}'.format(epoch, best_epoch, train_loss / n_batches, acc))
    print('train_loss_Jd        ::{:.6f} train_loss_Td        ::{:.6f}'.format(train_loss_Jd/n_batches,train_loss_Td/n_batches))
    print('train_loss_J_iso     ::{:.6f} train_loss_T_iso     ::{:.6f}'.format(train_loss_J_iso/n_batches,train_loss_T_iso/n_batches))
    print('train_loss_Jd_spread ::{:.6f} train_loss_Jd_neg    ::{:.6f}'.format(train_loss_Jd_spread/n_batches,train_loss_Jd_neg/n_batches))
    print('train_loss_Td_vals0  ::{:.6f}'.format(train_loss_Td_vals0/n_batches))
    print('train_loss_dJdt      ::{:.6f} train_loss_direc     ::{:.6f}'.format(train_loss_dJdt/n_batches,train_loss_direc/n_batches))
    print('train_loss_Os_spread ::{:.6f} train_loss_LinAl     ::{:.6f}'.format(train_loss_Os_spread/n_batches,train_loss_LinAl/n_batches))
    print(' ')
    print('loss_Jd_best         ::{:.6f} loss_Td_best         ::{:.6f}'.format(loss_Jd_best,loss_Td_best))

    print('lr::{:.6f}'.format(optimizer.param_groups[0]['lr']))

    print()
    time_now = timeit.default_timer()
    elapsed_time = time_now - start_time
    total_elapsed_time = time_now - start_time_init
    print("elapsed time",elapsed_time)
    print("total   time",total_elapsed_time/3600.0)
    print()

    time_now = datetime.datetime.now()
    time_next = time_now + datetime.timedelta(0,elapsed_time)

    current_time = time_now.strftime("%H:%M:%S")
    next_time = time_next.strftime("%H:%M:%S")

    print("Starting time    :", time_init)
    print("Current  time    :", current_time)
    print("Expect   next at :", next_time)


print('Finished Training')
history_pd = pd.DataFrame(history,columns=['Epoch','trainLoss','acc','loss_Jd_eval','loss_Td_eval'])

# End of Training

# Make predictions, concatenate predictions with original catalog, and write out to a fits file:

In [None]:
# NOTE(review): the training loop only writes predictions on the fly for improving
# epochs with epoch > 100, so the message below assumes such an epoch occurred.
if (write_predictions_on_the_fly):
    print("The predictions were already made during the running of the code.")
else:
    # Load best model first; map_location lets a checkpoint saved on another
    # device (e.g. a GPU) be restored onto the device used now.
    net.load_state_dict(torch.load(flname_save, map_location=device))
    make_and_write_preds(net, batch_size=batch_size)

## Store the history

In [None]:
# Save the per-epoch history (columns Epoch, trainLoss, acc, loss_Jd_eval, loss_Td_eval)
# as a FITS table, overwriting any previous file of the same name.
# NOTE(review): the file name hard-codes "1024", which may not reflect this run (Epochs = 1024*2 in the set-up).
t = Table.from_pandas(history_pd)
t.write('history_DB_1024.fits', overwrite=True)