# Identifying YSOs with a neural network

In [18]:
#ML imports
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from sklearn import model_selection
from sklearn.ensemble import RandomForestRegressor

#System/general imports
import math
import imf
import random
import sys
import warnings
if not sys.warnoptions:
    warnings.simplefilter('ignore')
import time

#Data imports
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pickle
# import tables

#Astro imports
from astropy.io import fits
from astropy.table import Table
from isochrones import parsec
from isochrones.mist import MIST_Isochrone
from isochrones.parsec import Parsec_Isochrone
# NOTE(review): despite the name `mist`, this instantiates the PARSEC isochrone
# grid (Parsec_Isochrone), not MIST_Isochrone. The name is left unchanged here
# because other cells may refer to it -- confirm before renaming.
mist = Parsec_Isochrone()

def getmagerror(flux, eflux):
    """Turn a flux uncertainty into a symmetric magnitude uncertainty.

    The value returned is half of the magnitude difference between the
    lower flux bound ``flux - eflux`` and the upper bound ``flux + eflux``.

    Parameters
    ----------
    flux : float or numpy.ndarray
        Measured flux.
    eflux : float or numpy.ndarray
        Uncertainty on ``flux`` (same units, broadcastable against ``flux``).

    Returns
    -------
    float or numpy.ndarray
        Half-width of the magnitude interval. No guard is applied: when
        ``eflux >= flux`` the logarithm of a non-positive number is taken and
        NumPy yields ``nan``/``inf``.
    """
    # Magnitude-like term at the lower flux bound (the fainter end).
    lower_bound_term = -2.5 * np.log10(flux - eflux)
    # Negated magnitude-like term at the upper flux bound (the brighter end).
    upper_bound_term = 2.5 * np.log10(flux + eflux)
    return (lower_bound_term + upper_bound_term) / 2


# device = torch.device('cuda:0')

# import pymultinest

In [2]:
def _load_training_set(path):
    """Load one pickled ``(inputs, targets)`` pair and close the file afterwards."""
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Three pre-generated training sets, each unpacked into inputs (x) and targets (y).
clusterx , clustery  = _load_training_set('ysotrainf1ms.pickle')
clusterx1, clustery1 = _load_training_set('ysotrainf2ms.pickle')
clusterx2, clustery2 = _load_training_set('ysotrainf3ms.pickle')

# Show the target shapes as a quick sanity check.
print(clustery.size())
print(clustery1.size())
print(clustery2.size())

# print(-2.5*np.log10(10**(-mist.mag['G'](mass, 8, feh, distance, AV)/2.5)))
# print(-2.5*np.log10(10**(-mist.mag['G'](mass, 8.3, feh, distance, AV)/2.5)))
# print(-2.5*np.log10(10**(-mist.mag['G'](mass, 7.5, feh, distance, AV)/2.5)))

torch.Size([88089, 6])
torch.Size([87564, 6])
torch.Size([88112, 6])


The code used to generate the clusters is not shown in this section; the training sets are loaded from the pickle files above.

Here's our neural network: a 1-D convolutional network that takes twelve inputs per sample and predicts six outputs.

In [3]:
class Net(nn.Module):
    """1-D convolutional regressor: three conv layers followed by three dense layers.

    The network expects input shaped ``(batch, *input_shape)`` -- by default
    ``(batch, 1, 12)`` -- and produces ``(batch, 6)``. The first entry of
    ``input_shape`` must be 1 because ``conv1`` has a single input channel.
    """

    def __init__(self, input_shape=(1, 12)):
        super().__init__()
        # Feature extractor: kernel size 3 with padding 1 keeps the sequence length.
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding=1)

        # The width of the flattened conv output depends on input_shape, so it
        # is measured with a probe pass before the dense layers are created.
        flat_features = self._get_conv_output(input_shape)

        # Regression head ending in six outputs.
        self.fc1 = nn.Linear(flat_features, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 6)

    def _get_conv_output(self, shape):
        """Return the flattened feature count the conv stack yields for one sample of ``shape``."""
        probe_batch = 1
        probe = torch.rand(probe_batch, *shape)
        features = self._forward_features(probe)
        return features.view(probe_batch, -1).shape[1]

    def _forward_features(self, x):
        """Run the convolutional part: conv-relu, conv-relu, pool, conv, pool, relu."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool1d(x, 2)
        x = self.conv3(x)
        x = F.max_pool1d(x, 2)
        return F.relu(x)

    def forward(self, x):
        """Map a ``(batch, *input_shape)`` tensor to ``(batch, 6)`` predictions."""
        features = self._forward_features(x)
        # Collapse channels and positions into one feature axis per sample.
        flat = features.view(features.size(0), -1)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
    
#Initialize model
# BAD_LOSS is a large sentinel loss value; a later cell uses it as the
# starting value of best_loss.
BAD_LOSS= 100000000
model = Net()#initialize our network (default input_shape, freshly initialised weights)

In [4]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing on the first .to(device) call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Keep all three training sets on the same device as the model.
clusterx, clustery = clusterx.to(device), clustery.to(device)
clusterx1, clustery1 = clusterx1.to(device), clustery1.to(device)
clusterx2, clustery2 = clusterx2.to(device), clustery2.to(device)

In [5]:
# Aliases for the first training set: network inputs and the targets the
# loss compares against.
inputs = clusterx
target = clustery

# Summed (not averaged) squared error, so the loss value grows with the
# number of elements it is evaluated on.
criterion = nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr = 1e-5, momentum = .9)

# Bookkeeping values initialised here; nothing in this cell updates them.
best_loss = BAD_LOSS
running_dev_loss = 0
badcount = 0
t = time.time()

# Sanity check: print the data shape next to a random tensor with a
# (N, 1, 12) layout, then run one forward pass of the current model.
a = torch.rand(23500, 1, 12)
print(clusterx.size())
print(a.size())
# NOTE(review): this forward pass covers all of clusterx with autograd
# enabled; wrap it in torch.no_grad() if only the values are needed.
print(model(clusterx))

torch.Size([88089, 1, 12])
torch.Size([23500, 1, 12])
tensor([[ 0.0303,  0.0283, -0.0191,  0.0382, -0.0018, -0.0234],
        [ 0.0312,  0.0301, -0.0180,  0.0387, -0.0037, -0.0237],
        [ 0.0314,  0.0283, -0.0198,  0.0389, -0.0017, -0.0232],
        ...,
        [ 0.0307,  0.0282, -0.0189,  0.0379, -0.0024, -0.0235],
        [ 0.0305,  0.0283, -0.0194,  0.0386, -0.0016, -0.0233],
        [ 0.0306,  0.0284, -0.0190,  0.0383, -0.0020, -0.0235]],
       device='cuda:0', grad_fn=<AddmmBackward>)


Training could be started from scratch at this point, but we want to resume from a saved checkpoint instead, using the function below.

In [6]:
def resume(savept, n=0, epochs=1000, n_train=85000, batch_size=100):
    """Resume training of a ``Net`` from a saved checkpoint.

    The weights stored at ``savept`` are loaded into a fresh ``Net``, which is
    then trained with SGD (lr 1e-5, momentum 0.9) on mini-batches drawn from
    the first ``n_train`` rows of the module-level ``clusterx`` / ``clustery``
    tensors. The module-level ``criterion`` and ``device`` are used as well.

    After every epoch the loss over *all* of ``clusterx`` / ``clustery`` is
    evaluated and the tuple ``(n, epoch_seconds, running_loss, dev_loss)`` is
    printed. Whenever ``n`` is a multiple of 10 the model weights are written
    to ``intermediatemodel.pt`` and ``n`` to ``ncount.pickle``.

    Parameters
    ----------
    savept : str or file-like
        Checkpoint produced by ``torch.save(model.state_dict(), ...)``.
    n : int, optional
        Number of epochs already completed; incremented once per epoch.
    epochs : int, optional
        Number of additional epochs to run.
    n_train : int, optional
        Rows ``0 .. n_train-1`` form the training pool. It must not exceed the
        number of rows in ``clusterx``.
    batch_size : int, optional
        Mini-batch size. ``n_train // batch_size`` batches are used per epoch,
        so a trailing partial batch is dropped.

    Returns
    -------
    None
        The trained model is not returned; its weights are only available
        through the periodic ``intermediatemodel.pt`` checkpoint.
    """
    model = Net()
    # map_location lets a checkpoint written on one device be restored on
    # another (e.g. GPU-trained weights on a CPU-only machine).
    model.load_state_dict(torch.load(savept, map_location=device))
    model = model.to(device)

    # The data tensors are intentionally NOT flagged with requires_grad: only
    # the model parameters are optimised, and tracking gradients for the
    # inputs/targets would add autograd work and an unused dataset-sized .grad.
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-5, momentum=.9)
    sample_pool = range(n_train)
    n_batches = n_train // batch_size

    for _ in range(epochs):
        model.train()
        epoch_start = time.time()
        # Fresh random permutation of the training rows for this epoch.
        order = random.sample(sample_pool, len(sample_pool))
        running_loss = 0
        for i in range(n_batches):
            batch_idx = order[i * batch_size:(i + 1) * batch_size]
            inputs = clusterx[batch_idx]
            target = clustery[batch_idx]
            optimizer.zero_grad()  # zero parameter gradients

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        n = n + 1

        # Per-epoch report: loss over the full data set, without autograd.
        model.eval()
        with torch.no_grad():
            dev_outputs = model(clusterx)
            dev_loss = criterion(dev_outputs, clustery)
        print((n, time.time() - epoch_start, running_loss, dev_loss.item()))

        if n % 10 == 0:
            torch.save(model.state_dict(), 'intermediatemodel.pt')
            # Context manager guarantees the counter file is flushed and closed.
            with open('ncount.pickle', 'wb') as counter_file:
                pickle.dump(n, counter_file)

In [27]:
# Forward slashes are accepted on Windows as well as on POSIX systems, so this
# relative path works on either (the previous backslash form was Windows-only).
mwTab = Table.read('./data/mwclustering-061319.fits')
mwTab

SOURCE_ID,LABELS,PROB,RA,DEC,PARALLAX,VLSRRA,VLSRDEC,BP_RP,ABSG,L,L1,B,AGE,av,PHOT_G_MEAN_FLUX,PHOT_G_MEAN_FLUX_ERROR,PHOT_G_MEAN_MAG,PHOT_BP_MEAN_FLUX,PHOT_BP_MEAN_FLUX_ERROR,PHOT_BP_MEAN_MAG,PHOT_RP_MEAN_FLUX,PHOT_RP_MEAN_FLUX_ERROR,PHOT_RP_MEAN_MAG,PARALLAX_ERROR,VLSRL,VLSRB,SLABEL,RADIAL_VELOCITY,RADIAL_VELOCITY_ERROR,VLSRV,J_M,J_MSIGCOM,H_M,H_MSIGCOM,KS_M,KS_MSIGCOM,W1MPRO,W1MPRO_ERROR,W2MPRO,W2MPRO_ERROR,W3MPRO,W3MPRO_ERROR,W4MPRO,W4MPRO_ERROR,NAME,PLOTNAME,ID
int64,int32,float32,float64,float64,float32,float32,float32,float32,float32,float64,float64,float64,float32,float32,float64,float64,float32,float64,float64,float32,float64,float64,float32,float64,float32,float32,int32,float64,float64,float64,float32,float32,float32,float32,float32,float32,float64,float64,float64,float64,float64,float64,float64,float64,bytes17,bytes23,int32
2170296628789233152,682,0.76858854,314.70689194180073,52.329566621578444,1.7047493,-2.7944367,-5.3487525,2.1167336,6.8206105,91.56317025918347,91.56317025918347,4.240555592195001,6.1448393,1.673481,10242.902467121192,53.884176247419276,15.662308,2899.306110216472,85.09694543121097,16.695654,11835.86540024834,172.1274497043808,14.57892,0.036259825414414304,-5.8693695,-1.4030428,-1,,,,12.784,0.029,11.77,0.024,11.113,0.026,,,,,,,,,LDN_988e,LDN_988e (1),1
2168939865797736576,682,0.68571967,314.90886441228935,50.377704156902475,1.7571619,-4.0396943,-5.9315767,0.67275333,1.5655601,90.16603181336296,90.16603181336296,2.868959494636923,6.1448393,1.673481,1376402.9974730948,533.215140925598,10.341501,785318.5834472583,920.777542516155,10.613773,847929.943294906,847.7494691680856,9.94102,0.041739590794115974,-7.1230206,-0.87481856,-1,,,,9.453,0.021,9.37,0.016,9.298,0.022,9.239,0.023,9.265,0.02,9.437,0.04,9.374,,LDN_988e,LDN_988e (1),1
2168944298210166016,682,0.7721724,315.44785979176584,50.30612989441009,1.5677166,-2.5051455,-5.2310324,2.4937096,7.102591,90.33870042353638,90.33870042353638,2.562848847906288,6.1448393,1.673481,6681.055925424997,40.578218399358654,16.126253,1420.5292939731612,44.561939754954835,17.470263,8206.290766804868,167.7014557175414,14.976553,0.05041082139407983,-5.5732546,-1.6057041,-1,,,,13.155,0.022,12.211,0.021,11.848,0.027,11.35,0.023,10.978,0.022,9.295,0.042,7.173,0.095,LDN_988e,LDN_988e (1),1
2168946875190510464,682,0.7110925,315.3347786968243,50.33555337422311,1.599027,-3.3392544,-5.5588126,1.7517471,5.0354056,90.31310955312907,90.31310955312907,2.6365046305879867,6.1448393,1.673481,46654.71890709048,512.4988737009248,14.016127,15542.928811589301,712.3594820746923,14.872556,45336.14439523501,1429.2217123928838,13.120809,0.017662063467430107,-6.374369,-1.1909817,-1,,,,11.68,0.021,10.856,0.017,10.289,0.022,9.515,0.023,8.917,0.02,7.286,0.018,6.005,0.046,LDN_988e,LDN_988e (1),1
2168950581742335616,682,0.9102174,315.55931904844766,50.36047407882485,1.7437564,-2.5445793,-5.3477407,2.306903,6.7165947,90.42660116133291,90.42660116133291,2.545414396134979,6.1448393,1.673481,11794.513234228565,27.279547570041885,15.509166,3159.781784059732,44.24357739733449,16.602245,15368.476959480466,270.24554763297573,14.295342,0.0490561929855524,-5.684296,-1.6619238,-1,,,,,,,,,,,,,,,,,,LDN_988e,LDN_988e (1),1
2168955151587411072,682,0.8176937,315.5117231752189,50.4949443486175,1.6295545,-2.6513894,-5.3867946,1.145791,2.7946603,90.5074663110749,90.5074663110749,2.6571101455766852,6.1448393,1.673481,381610.3384178202,87.08550100776637,11.734316,176594.57183657322,166.39372863169547,12.233945,294785.8544708823,206.6661188395169,11.088154,0.02551326077642788,-5.7856493,-1.6042697,-1,-12.552853360072566,1.0701986299633248,-0.0886239466723161,10.281,0.021,9.978,0.019,9.876,0.022,9.775,0.023,9.771,0.021,9.626,0.048,7.666,0.116,LDN_988e,LDN_988e (1),1
2168955533844548096,682,1.0,315.51420614155546,50.49953999162996,1.7215203,-3.5272715,-5.1660175,2.8698845,7.734706,90.51196296343467,90.51196296343467,2.6589643954102584,6.1448393,1.673481,4500.783642873504,8.35927402939735,16.555145,754.2740678570366,11.679365435111956,18.157566,6161.625812617854,48.986687810535386,15.287682,0.061560148132033324,-6.203509,-0.8036534,-1,,,,13.237,0.06,12.289,0.065,11.879,0.094,,,,,,,,,LDN_988e,LDN_988e (1),1
2168958901098918528,682,0.9973055,315.2883636636069,50.36242411610956,1.6759953,-3.396986,-5.196422,1.4857779,2.689424,90.31378288294434,90.31378288294434,2.676489697375864,6.1448393,1.673481,444756.4990598864,1713.3330020358874,11.56806,176721.81104279603,1689.542832569218,12.233163,403472.43576449493,4194.378879136757,10.747385,0.02367974726143494,-6.1421475,-0.90351456,-1,,,,9.88,0.02,9.288,0.017,8.556,0.02,7.142,0.038,6.127,0.035,3.883,0.015,1.44,0.02,LDN_988e,LDN_988e (1),1
2170293639492240384,682,0.7361647,315.2581602989653,52.48528253646384,1.9271299,-3.4742708,-5.2473803,,,91.90197776253439,91.90197776253439,4.088380294032973,6.1448393,1.673481,1060.7517405232609,2.043976324540662,18.12433,67.06986320649928,5.80847238886815,20.78507,1739.6474225945149,8.591603655901668,16.660767,0.1404486702039244,-6.2331066,-0.8682965,-1,,,,14.148,0.029,13.252,0.032,12.923,0.041,12.684,0.033,12.484,0.032,11.611,0.288,9.267,0.492,LDN_988e,LDN_988e (1),1
2169124068363304192,682,1.0,316.29874557139937,50.388655164270794,1.4513856,-2.9227438,-4.198964,,,90.76172955110428,90.76172955110428,2.2121515704382713,6.1448393,1.673481,1286.2793548268678,1.8106998699385632,17.915028,152.1651594656194,8.56914806227995,19.8956,1970.5452685983835,9.873948030893834,16.525454,0.13981249771318333,-5.072781,-0.6637926,-1,,,,14.339,0.037,13.544,0.043,13.177,0.034,13.002,0.026,12.896,0.027,12.803,,9.333,,LDN_988e,LDN_988e (1),1


 Volume in drive C is Windows
 Volume Serial Number is 9A64-F859

 Directory of C:\Users\sahal\Desktop\YSO ML\Code\data

06/26/2019  01:46 PM    <DIR>          .
06/26/2019  01:46 PM    <DIR>          ..
06/26/2019  12:14 PM        92,298,240 mwclustering-061319.fits
06/26/2019  01:45 PM         1,880,640 Orion+Starhorse.fits
               2 File(s)     94,178,880 bytes
               2 Dir(s)  196,772,282,368 bytes free
