# Identifying YSOs with a neural network

In [1]:
#ML imports
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from sklearn import model_selection
from sklearn.ensemble import RandomForestRegressor

#System/general imports
import math
import imf
import random
import sys
import warnings
if not sys.warnoptions:
    warnings.simplefilter('ignore')
import time

#Data imports
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pickle
# import tables

#Astro imports
from astropy.io import fits
from isochrones import parsec
from isochrones.mist import MIST_Isochrone
from isochrones.parsec import Parsec_Isochrone
mist = Parsec_Isochrone()

def getmagerror(flux,eflux):
    return (-2.5*np.log10(flux-eflux)+2.5*np.log10(flux+eflux))/2


# device = torch.device('cuda:0')

# import pymultinest



In [2]:
clusterx ,clustery =pickle.load(open('ysotrainf1ms.pickle', 'rb'))
clusterx1,clustery1=pickle.load(open('ysotrainf2ms.pickle', 'rb'))
clusterx2,clustery2=pickle.load(open('ysotrainf3ms.pickle', 'rb'))

print(clustery.size())
print(clustery1.size())
print(clustery2.size())

# print(-2.5*np.log10(10**(-mist.mag['G'](mass, 8, feh, distance, AV)/2.5)))
# print(-2.5*np.log10(10**(-mist.mag['G'](mass, 8.3, feh, distance, AV)/2.5)))
# print(-2.5*np.log10(10**(-mist.mag['G'](mass, 7.5, feh, distance, AV)/2.5)))

torch.Size([88089, 6])
torch.Size([87564, 6])
torch.Size([88112, 6])


The following code is used to generate the clusters

Here's our neural network, taking twelve inputs

In [3]:
class Net(nn.Module):
    def __init__(self, input_shape = (1,12)):
        super(Net,self).__init__()
        self.conv1 = nn.Conv1d(1,8,3, padding=1)
        self.conv2 = nn.Conv1d(8,16,3, padding = 1)
        self.conv3 = nn.Conv1d(16,32,3,padding =1)
        n_size =self._get_conv_output(input_shape)
        self.fc1 = nn.Linear(n_size, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 6)
    
    #Generate input sample and forward to get shape
    def _get_conv_output(self, shape):
        bs = 1
        input = torch.rand(bs, *shape)
        output_feat = self._forward_features(input)
        n_size = output_feat.data.view(bs, -1).size(1)
        return n_size
    
    def _forward_features(self, x):
        x = F.max_pool1d(F.relu(self.conv2(F.relu(self.conv1(x)))),2)
        x = F.relu(F.max_pool1d(self.conv3(x), 2))
        return x
    
    def forward(self, x):
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
#Initialize model
BAD_LOSS= 100000000
model = Net()#initialize our network

In [4]:
device = torch.device("cuda")
model=model.to(device)
clusterx = clusterx.to(device)
clustery = clustery.to(device)
clusterx1 = clusterx1.to(device)
clustery1 = clustery1.to(device)
clusterx2 = clusterx2.to(device)
clustery2 = clustery2.to(device)

In [5]:
inputs = clusterx
target = clustery

criterion = nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr = 1e-5, momentum = .9)

best_loss = BAD_LOSS
running_dev_loss = 0
badcount = 0
t = time.time()

a = torch.rand(23500, 1, 12)
print(clusterx.size())
print(a.size())
print(model(clusterx))

torch.Size([88089, 1, 12])
torch.Size([23500, 1, 12])
tensor([[ 0.0303,  0.0283, -0.0191,  0.0382, -0.0018, -0.0234],
        [ 0.0312,  0.0301, -0.0180,  0.0387, -0.0037, -0.0237],
        [ 0.0314,  0.0283, -0.0198,  0.0389, -0.0017, -0.0232],
        ...,
        [ 0.0307,  0.0282, -0.0189,  0.0379, -0.0024, -0.0235],
        [ 0.0305,  0.0283, -0.0194,  0.0386, -0.0016, -0.0233],
        [ 0.0306,  0.0284, -0.0190,  0.0383, -0.0020, -0.0235]],
       device='cuda:0', grad_fn=<AddmmBackward>)


The following would be used to start the process, but we want to resume instead

In [6]:
def resume(savept, n=0):
    model = Net()
    model.load_state_dict(torch.load(savept))
    model=model.to(device)    
    
    clusterx.requires_grad=True
    clustery.requires_grad=True
    optimizer = torch.optim.SGD(model.parameters(), lr = 1e-5, momentum = .9)
    start_time = time.time()
    ran = range(85000)
    
    for epoch in range(1000):
        model.train()
        t = time.time()
        k=random.sample(ran, len(ran))
        running_loss = 0
        for i in range(850):
            inputs = clusterx[k[i*100:(i+1)*100]]
            target = clustery[k[i*100:(i+1)*100]]
            optimizer.zero_grad()#zero parameter gradients

            #forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()
            getloss = loss.item()
            running_loss = getloss + running_loss

        n=n+1
        if n % 1 == 0:
            model.eval()
            optimizer.zero_grad()
            with torch.no_grad():
                dev_outputs = model.forward(clusterx)
                dev_loss = criterion(dev_outputs, clustery)
            print( (n, time.time()-t, running_loss, dev_loss.item()))
        if n % 10 == 0:
            torch.save(model.state_dict(), 'intermediatemodel.pt')
            pickle.dump(n, open('ncount.pickle','wb'))

In [7]:
N = pickle.load(open('ncount.pickle', 'rb'))
resume('intermediatemodel.pt', n=N)

(1351, 9.488616943359375, 10100.341459274292, 10296.072265625)
(1352, 9.385923862457275, 10075.572836875916, 10472.673828125)
(1353, 9.662151098251343, 9941.562607765198, 10336.3359375)
(1354, 10.552769660949707, 9960.607882976532, 10095.548828125)
(1355, 10.004202842712402, 9978.80999469757, 10500.771484375)
(1356, 9.648190259933472, 10036.41046667099, 11183.5966796875)
(1357, 9.958317279815674, 10031.898025035858, 10199.650390625)
(1358, 9.457664489746094, 10052.014582633972, 10062.576171875)
(1359, 9.807851314544678, 9998.059167385101, 10583.271484375)
(1360, 9.784822225570679, 9984.48024559021, 11372.08203125)
(1361, 9.36790132522583, 9996.163826942444, 10802.9609375)
(1362, 9.37591552734375, 9926.866962909698, 11231.490234375)
(1363, 10.189717531204224, 10130.645646095276, 10222.90625)
(1364, 9.340978860855103, 9965.798392772675, 10167.283203125)
(1365, 9.353975296020508, 10013.18553352356, 10388.587890625)
(1366, 9.387884140014648, 9969.673046588898, 10172.70703125)
(1367, 9.6192