In [5]:
import astropy.io.fits as pf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torch
from torchsummary import summary

In [6]:
training = pf.open('SDSS_QSO_star_training.fits')
test = pf.open('SDSS_QSO_star_testing.fits')

In [7]:
color_train = torch.from_numpy(np.array(training[0].data,dtype=np.float32))
color_test = torch.from_numpy(np.array(test[0].data,dtype=np.float32))

In [8]:
test_labels = np.zeros((len(test[0].data),2))
for i in range(0,len(test_labels)):
    test_labels[i,test[1].data[i]]=1
test_labels = torch.from_numpy(np.array(test_labels,dtype=np.float32))

In [9]:
train_labels = np.zeros((len(training[0].data),2))
for i in range(0,len(train_labels)):
    train_labels[i,training[1].data[i]]=1
train_labels = torch.from_numpy(np.array(train_labels,dtype=np.float32))

In [20]:
class fully_connected_NN_QSO_star(torch.nn.Module):
    def __init__(self):
        super(fully_connected_NN_QSO_star, self).__init__()
        self.fc1 = torch.nn.Sequential(
            torch.nn.Linear(4, 2),
            #torch.nn.ReLU()
            )
      
    def forward(self, x):
        x = self.fc1(x)
        out = x
        return out

In [21]:
model = fully_connected_NN_QSO_star()

In [22]:
learningRate = 0.1
epochs = 1000
class_weights = torch.FloatTensor([1,1])

criterion = torch.nn.CrossEntropyLoss(class_weights)
# Just the loss function : here we use the default MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learningRate)

In [23]:
accuracy_array = []
accuracy_test_array = []
loss_array = []
loss_test_array = []
epoch_array = []

In [24]:
batch_size = 1024
N_total_train = len(color_train[:,:])
print("epoch","accuracy","loss", "accuracy_test","loss_test")
for epoch in range(epochs):

    ''' simple mini-batch calculation'''
    for start_index_batch in range(0,N_total_train,batch_size):
        # Clear gradient buffers because we don't want any gradient from previous epoch to carry forward, dont want to cummulate gradients
        optimizer.zero_grad()
        end_index = min(start_index_batch + batch_size, N_total_train)
        # get output from the model, given the inputs
        #print(len(color_train[start_index_batch:end_index,:]))
        outputs = model(color_train[start_index_batch:end_index,:])
        
        # get loss for the predicted output
        loss = criterion(outputs, train_labels[start_index_batch:end_index,:])
        # get gradients w.r.t to parameters
        loss.backward()
        # update parameters
        optimizer.step()
    
    ### Below is for recording the performance metrics
    #------------------------------------------------------------------   
    outputs_all = model(color_train)
    pred_y = torch.max(outputs_all, 1)[1].data.squeeze()
    accuracy = torch.sum(pred_y == torch.from_numpy(np.array(training[1].data,dtype=np.float32))).type(torch.FloatTensor) / pred_y.size(0)

    epoch_array.append(epoch)

    loss_array.append(float(loss.detach().numpy()))
    accuracy_array.append(float(accuracy.numpy()))


    outputs_test = model(color_test) 
    loss_test = criterion(outputs_test, test_labels)

    pred_y_test = torch.max(outputs_test, 1)[1].data.squeeze()
    accuracy_test = torch.sum(pred_y_test == torch.from_numpy(np.array(test[1].data,dtype=np.float64))).type(torch.FloatTensor) / pred_y_test.size(0)
    accuracy_test_array.append(float(accuracy_test.numpy()))
    loss_test_array.append(float(loss_test.detach().numpy()))

    if epoch % 50 ==0:
        print(epoch,accuracy.numpy(),loss.detach().numpy(), accuracy_test.numpy(),loss_test.detach().numpy())

epoch accuracy loss accuracy_test loss_test
0 0.8450506 0.43430966 0.84066504 0.40165943
50 0.96256477 0.33628348 0.9597044 0.34213924
100 0.96351415 0.333279 0.9620136 0.34000406
150 0.96298814 0.3331649 0.96166724 0.33949295
200 0.9626802 0.3332175 0.96166724 0.33934605
250 0.96257764 0.33325067 0.9615518 0.33929572
300 0.9625263 0.33326566 0.9615518 0.3392758
350 0.9624237 0.3332718 0.96166724 0.33926722
400 0.96237236 0.33327407 0.96166724 0.3392633
450 0.9623467 0.3332749 0.9615518 0.3392615
500 0.9623467 0.3332752 0.96166724 0.33926064
550 0.96233386 0.33327496 0.96166724 0.33926028
600 0.96233386 0.33327502 0.96166724 0.3392601
650 0.96232104 0.33327514 0.96166724 0.33925995
700 0.96232104 0.33327505 0.96166724 0.33925992
750 0.96232104 0.33327484 0.96166724 0.33925992
800 0.96232104 0.33327478 0.96166724 0.33925992
850 0.96232104 0.33327472 0.96166724 0.33925992
900 0.96232104 0.33327472 0.96166724 0.33925992
950 0.96232104 0.33327475 0.96166724 0.33925992


In [None]:
import 