In [1]:
import pandas as pd
import numpy as np
import torch
import json
from normalizer import counter
from torch.utils.data import DataLoader


In [2]:
# All of our fields have names, but we need them in numbers, so define a dictionary to convert.
global magDict
magDict = {
    'TOTUSJH': 0,
    'TOTBSQ': 1,
    'TOTPOT': 2,
    'TOTUSJZ': 3,
    'ABSNJZH': 4,
    'SAVNCPP': 5,
    'USFLUX': 6,
    'TOTFZ': 7,
    'MEANPOT': 8,
    'EPSZ': 9,
    'SHRGT45': 10,
    'MEANSHR': 11,
    'MEANGAM': 12,
    'MEANGBT': 13,
    'MEANGBZ': 14,
    'MEANGBH': 15,
    'MEANJZH': 16,
    'TOTFY': 17,
    'MEANJZD': 18,
    'MEANALP': 19,
    'TOTFX': 20,
    'EPSY': 21,
    'EPSX': 22,
    'R_VALUE': 23,
    'RBZ_VALUE': 24,
    'RBT_VALUE': 25,
    'RBP_VALUE': 26,
    'FDIM': 27,
    'BZ_FDIM': 28,
    'BT_FDIM': 29,
    'BP_FDIM': 30,
    'PIL_LEN': 31,
    'XR_MAX': 32
}

In [3]:
# Get the data from the JSON file, then return it as a tensor of input data and a list of labels
def getDataFromJSON(path="data/train_partition1_data.json", device='cpu', earlyStop=-1):
    # path is the path to the files, device is where to store it (CUDA), earlyStop is how many lines to 
    # read if you don't want the entire file read.
    
    # Get the dictionary to assign names to numbers
    global magDict
    
    # This dataset is heavily skewed, so we need to get the number of each type of flare.
    # This also lets us get the number of lines in the file with a sum.
    # This function also ignores any lines with a value of NaN.
    weights = counter(path, earlyStop)
    lines = np.sum(weights)
    # Check when we want to stop - the end of the file or earlier.
    if earlyStop < 0: length = lines
    else: length = min(earlyStop, lines)
    
    # Get the file and open it. 
    file = open(path)    
    
    # Declare a tensor to hold the data, and a list to hold the labels.
    # Dimensions: 0: number of entries we want. 1: the 33 fields in the data. 2: the 60 observations in each field. 
    tnsr = torch.Tensor().new_empty((length, 33, 60), device=device)
    labels = []
    flares = {'X':0, 'M':1, 'C':2, 'B':3, 'Q':4}
        
    row = -1
    for line in file:
        if 'nan' in line or "NaN" in line:
            continue
        # Load the line as a dictionary. Row is an integer place and v is a smaller dictionary.
        d: dict = json.loads(line)
        row += 1
        for _, v in d.items(): # we use the _ because we don't want the ID.
            if earlyStop > 0 and row >= earlyStop:
                # If we don't want the entire dataset, stop loading more than we want
                return tnsr, labels, weights
            if row % 100 == 0:
                print(f'Now loading event {row}/{length}')
            # append the label to our list
            labels.append(flares[v['label']])
            
            # Break each individual dictionary into dictionaries of observations
            # Key is the string in magDict, and timeDict is a dictionary of observations over time
            for key, timeDict in v['values'].items():
                # Turn our name string into a numeric value
                location = magDict[key]
                # Get the measurements out of the time series dictionary
                for timeStamp, measurement in timeDict.items():
                    tnsr[row][location][int(timeStamp)] = measurement
    print(f'{row} lines loaded.')
    # Close the file. I'm not a heathen                    
    file.close()
    # This might be a good place to perform some post processing, but that's a question for another day.
    # Famous last words.
    return tnsr, labels, weights



In [4]:
# This file has 77270 data points. 
%time train1, labels1, weights1 = getDataFromJSON(path="data/train_partition1_data.json", earlyStop=-1)
print(np.sum(weights1))

Now loading event 0/71633
Now loading event 100/71633
Now loading event 200/71633
Now loading event 300/71633
Now loading event 400/71633
Now loading event 500/71633
Now loading event 600/71633
Now loading event 700/71633
Now loading event 800/71633
Now loading event 900/71633
Now loading event 1000/71633
Now loading event 1100/71633
Now loading event 1200/71633
Now loading event 1300/71633
Now loading event 1400/71633
Now loading event 1500/71633
Now loading event 1600/71633
Now loading event 1700/71633
Now loading event 1800/71633
Now loading event 1900/71633
Now loading event 2000/71633
Now loading event 2100/71633
Now loading event 2200/71633
Now loading event 2300/71633
Now loading event 2400/71633
Now loading event 2500/71633
Now loading event 2600/71633
Now loading event 2700/71633
Now loading event 2800/71633
Now loading event 2900/71633
Now loading event 3000/71633
Now loading event 3100/71633
Now loading event 3200/71633
Now loading event 3300/71633
Now loading event 3400/716

Now loading event 27700/71633
Now loading event 27800/71633
Now loading event 27900/71633
Now loading event 28000/71633
Now loading event 28100/71633
Now loading event 28200/71633
Now loading event 28300/71633
Now loading event 28400/71633
Now loading event 28500/71633
Now loading event 28600/71633
Now loading event 28700/71633
Now loading event 28800/71633
Now loading event 28900/71633
Now loading event 29000/71633
Now loading event 29100/71633
Now loading event 29200/71633
Now loading event 29300/71633
Now loading event 29400/71633
Now loading event 29500/71633
Now loading event 29600/71633
Now loading event 29700/71633
Now loading event 29800/71633
Now loading event 29900/71633
Now loading event 30000/71633
Now loading event 30100/71633
Now loading event 30200/71633
Now loading event 30300/71633
Now loading event 30400/71633
Now loading event 30500/71633
Now loading event 30600/71633
Now loading event 30700/71633
Now loading event 30800/71633
Now loading event 30900/71633
Now loadin

Now loading event 55100/71633
Now loading event 55200/71633
Now loading event 55300/71633
Now loading event 55400/71633
Now loading event 55500/71633
Now loading event 55600/71633
Now loading event 55700/71633
Now loading event 55800/71633
Now loading event 55900/71633
Now loading event 56000/71633
Now loading event 56100/71633
Now loading event 56200/71633
Now loading event 56300/71633
Now loading event 56400/71633
Now loading event 56500/71633
Now loading event 56600/71633
Now loading event 56700/71633
Now loading event 56800/71633
Now loading event 56900/71633
Now loading event 57000/71633
Now loading event 57100/71633
Now loading event 57200/71633
Now loading event 57300/71633
Now loading event 57400/71633
Now loading event 57500/71633
Now loading event 57600/71633
Now loading event 57700/71633
Now loading event 57800/71633
Now loading event 57900/71633
Now loading event 58000/71633
Now loading event 58100/71633
Now loading event 58200/71633
Now loading event 58300/71633
Now loadin

In [5]:
# This file has 93767 data points. 
%time train2, labels2, weights2 = getDataFromJSON(path="data/train_partition2_data.json", earlyStop=-1)
print(np.sum(weights2))

Now loading event 0/82425
Now loading event 100/82425
Now loading event 200/82425
Now loading event 300/82425
Now loading event 400/82425
Now loading event 500/82425
Now loading event 600/82425
Now loading event 700/82425
Now loading event 800/82425
Now loading event 900/82425
Now loading event 1000/82425
Now loading event 1100/82425
Now loading event 1200/82425
Now loading event 1300/82425
Now loading event 1400/82425
Now loading event 1500/82425
Now loading event 1600/82425
Now loading event 1700/82425
Now loading event 1800/82425
Now loading event 1900/82425
Now loading event 2000/82425
Now loading event 2100/82425
Now loading event 2200/82425
Now loading event 2300/82425
Now loading event 2400/82425
Now loading event 2500/82425
Now loading event 2600/82425
Now loading event 2700/82425
Now loading event 2800/82425
Now loading event 2900/82425
Now loading event 3000/82425
Now loading event 3100/82425
Now loading event 3200/82425
Now loading event 3300/82425
Now loading event 3400/824

Now loading event 27700/82425
Now loading event 27800/82425
Now loading event 27900/82425
Now loading event 28000/82425
Now loading event 28100/82425
Now loading event 28200/82425
Now loading event 28300/82425
Now loading event 28400/82425
Now loading event 28500/82425
Now loading event 28600/82425
Now loading event 28700/82425
Now loading event 28800/82425
Now loading event 28900/82425
Now loading event 29000/82425
Now loading event 29100/82425
Now loading event 29200/82425
Now loading event 29300/82425
Now loading event 29400/82425
Now loading event 29500/82425
Now loading event 29600/82425
Now loading event 29700/82425
Now loading event 29800/82425
Now loading event 29900/82425
Now loading event 30000/82425
Now loading event 30100/82425
Now loading event 30200/82425
Now loading event 30300/82425
Now loading event 30400/82425
Now loading event 30500/82425
Now loading event 30600/82425
Now loading event 30700/82425
Now loading event 30800/82425
Now loading event 30900/82425
Now loadin

Now loading event 55100/82425
Now loading event 55200/82425
Now loading event 55300/82425
Now loading event 55400/82425
Now loading event 55500/82425
Now loading event 55600/82425
Now loading event 55700/82425
Now loading event 55800/82425
Now loading event 55900/82425
Now loading event 56000/82425
Now loading event 56100/82425
Now loading event 56200/82425
Now loading event 56300/82425
Now loading event 56400/82425
Now loading event 56500/82425
Now loading event 56600/82425
Now loading event 56700/82425
Now loading event 56800/82425
Now loading event 56900/82425
Now loading event 57000/82425
Now loading event 57100/82425
Now loading event 57200/82425
Now loading event 57300/82425
Now loading event 57400/82425
Now loading event 57500/82425
Now loading event 57600/82425
Now loading event 57700/82425
Now loading event 57800/82425
Now loading event 57900/82425
Now loading event 58000/82425
Now loading event 58100/82425
Now loading event 58200/82425
Now loading event 58300/82425
Now loadin

82424 lines loaded.
CPU times: user 15min 28s, sys: 2.32 s, total: 15min 30s
Wall time: 15min 30s
82425


In [6]:
# This file has 42986 data points. 
%time train3, labels3, weights3 = getDataFromJSON(path="data/train_partition3_data.json", earlyStop=-1)
print(np.sum(weights3))

Now loading event 0/37759
Now loading event 100/37759
Now loading event 200/37759
Now loading event 300/37759
Now loading event 400/37759
Now loading event 500/37759
Now loading event 600/37759
Now loading event 700/37759
Now loading event 800/37759
Now loading event 900/37759
Now loading event 1000/37759
Now loading event 1100/37759
Now loading event 1200/37759
Now loading event 1300/37759
Now loading event 1400/37759
Now loading event 1500/37759
Now loading event 1600/37759
Now loading event 1700/37759
Now loading event 1800/37759
Now loading event 1900/37759
Now loading event 2000/37759
Now loading event 2100/37759
Now loading event 2200/37759
Now loading event 2300/37759
Now loading event 2400/37759
Now loading event 2500/37759
Now loading event 2600/37759
Now loading event 2700/37759
Now loading event 2800/37759
Now loading event 2900/37759
Now loading event 3000/37759
Now loading event 3100/37759
Now loading event 3200/37759
Now loading event 3300/37759
Now loading event 3400/377

Now loading event 27700/37759
Now loading event 27800/37759
Now loading event 27900/37759
Now loading event 28000/37759
Now loading event 28100/37759
Now loading event 28200/37759
Now loading event 28300/37759
Now loading event 28400/37759
Now loading event 28500/37759
Now loading event 28600/37759
Now loading event 28700/37759
Now loading event 28800/37759
Now loading event 28900/37759
Now loading event 29000/37759
Now loading event 29100/37759
Now loading event 29200/37759
Now loading event 29300/37759
Now loading event 29400/37759
Now loading event 29500/37759
Now loading event 29600/37759
Now loading event 29700/37759
Now loading event 29800/37759
Now loading event 29900/37759
Now loading event 30000/37759
Now loading event 30100/37759
Now loading event 30200/37759
Now loading event 30300/37759
Now loading event 30400/37759
Now loading event 30500/37759
Now loading event 30600/37759
Now loading event 30700/37759
Now loading event 30800/37759
Now loading event 30900/37759
Now loadin

In [7]:
train2[0,0]

tensor([19.0601, 18.4611, 25.6917, 17.3464, 18.1081, 18.0577, 18.7077, 17.4748,
        17.8627, 16.9230, 22.6015, 22.5883, 18.8607, 19.9172, 23.7632, 21.3895,
        27.7625, 26.1885, 26.1678, 27.3103, 27.4704, 26.1131, 29.2025, 31.6294,
        28.8686, 27.6369, 26.0752, 20.9289, 19.3941, 17.8339, 23.0833, 19.8953,
        20.2367, 20.7759, 21.6280, 22.9786, 20.4450, 20.1092, 19.7991, 21.1444,
        16.3790, 20.6794, 16.9210, 18.0858, 16.4018, 16.2757, 17.2050, 19.5720,
        19.2548, 19.9467, 18.0555, 18.8806, 18.5460, 20.5339, 18.9813, 16.3843,
        18.5569, 17.3046, 17.9197, 14.8065])

In [8]:
# Define the network. Make sure to end with nn.Softmax activation
import torch.nn as nn
from skorch import NeuralNet

class logRegWithHidden(nn.Module):
    def __init__(self, hidden_size1, hidden_size2, num_classes=5, drop1=.5, input_size=1980):
        super().__init__() 
        self.layer1 = nn.Linear(input_size, hidden_size1)
        self.layer2 = nn.Linear(hidden_size1, hidden_size2)
        self.layerout = nn.Linear(hidden_size2, num_classes)
        #Define a RELU Activation unit
        self.relu = nn.ReLU()  
        self.smax = nn.Softmax(dim=1)
        self.drop = nn.Dropout(p=drop1)

    def forward(self, x):
        #Forward Propagate through the layers as defined above
        y = self.drop(x.reshape(-1, 1980))
        y = self.drop(self.relu(self.layer1(y)))
        y = self.relu(self.layer2(y))
        y = self.smax(self.layerout(y))
        return y





In [22]:
def train(model, inputs, labels, weight, valSets, valLabels, valweight, lr=0.01):
    # TODO: Is this right? How do I determine the weights here?
    weight = torch.Tensor(weight)
    lfc = nn.CrossEntropyLoss(weight=100/weight)
#     valLabels = torch.tensor(valLabels, dtype=torch.int)
    #ideas
    # 1-(weight/np.sum(weight))
    # .2/weight - this one normalizes so that each class is responsible for 20% of the loss
    # 1/weight - this is a bit naive, but the classes with fewer items are weighted more.
    # 1/(weight+1) - makes sure we don't have any pesky zeroes
    # np.sum(weight)/weight if your learning rate is too low.
    
    # Hyperparameters
    batch = 256
    epochs = 10
    
    # Start a dataloader object
    data = list(zip(inputs,labels))
    val = list(zip(valSets,valLabels))
    loader = DataLoader(data, batch_size = batch, num_workers=4)
    valLoader = DataLoader(val, batch_size = int(len(val)/4), num_workers=4)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        batch_loss = []
        for (xtrain,ytrain) in loader:
            output = model(xtrain)
            loss = lfc(output,ytrain)
            opt.zero_grad()
            loss.backward()
            opt.step()
            batch_loss.append(loss.item())
        print(f'The training loss for epoch {epoch+1}/{epochs} was {np.mean(batch_loss)}')
        
        
        model.eval()
        balanced = [[],[],[],[],[]]
        batchLoss = []
        unbalanced = []
        
        for (xval,yval) in valLoader:
            output = model(xval)
            loss = lfc(output,yval)
            batchLoss.append(loss.item())
            corrects = yval.clone().detach() == torch.argmax(output)
            unbalanced.append(np.mean([1 if correct else 0 for correct in corrects.detach()]))
            
            # TODO: figure this one out.
            for i, ans in enumerate(yval):
                balanced[ans].append(corrects[i])
        
        balanced = [np.mean(i) for i in balanced]
        balancedAccuracy = np.mean(balanced)
        
        print(f'The total balanced accuracy for validation was {balancedAccuracy}')
        print(f'The unbalanced validation accuracy is {np.mean(unbalanced)}')
        print(f'The accuracy for each is {balanced}')           
        print(f'The validation loss was :   {epoch+1}/{epochs} was {np.mean(batchLoss)}')

            
        print(f'The validation loss was :   {epoch+1}/{epochs} was {np.mean(batchLoss)}')
    return model


In [23]:
model = logRegWithHidden(16384, 4096)
    # print(weights1+weights2, 1.0/np.array(weights1))

In [26]:
print(model)
%time _ = train(model, torch.cat((train1, train2), dim=0), labels1 + labels2, [weights1[i] + weights2[i] for i in range(5)], train3, labels3, weights3, lr = 0.01)

logRegWithHidden(
  (layer1): Linear(in_features=1980, out_features=16384, bias=True)
  (layer2): Linear(in_features=16384, out_features=4096, bias=True)
  (layerout): Linear(in_features=4096, out_features=5, bias=True)
  (relu): ReLU()
  (smax): Softmax(dim=1)
  (drop): Dropout(p=0.5, inplace=False)
)
The training loss for epoch 1/10 was 1.7211452997800123
The total balanced accuracy for validation was 0.2
The unbalanced validation accuracy is 0.002479076173323445
The accuracy for each is [1.0, 0.0, 0.0, 0.0, 0.0]
The validation loss was :   1/10 was 1.649339461326599
The validation loss was :   1/10 was 1.649339461326599
The training loss for epoch 2/10 was 1.7142055123747386
The total balanced accuracy for validation was 0.2
The unbalanced validation accuracy is 0.002479076173323445
The accuracy for each is [1.0, 0.0, 0.0, 0.0, 0.0]
The validation loss was :   2/10 was 1.649339461326599
The validation loss was :   2/10 was 1.649339461326599
The training loss for epoch 3/10 was 1.717

In [12]:
print(model.parameters())
feature_extraction = [child for child in model.children()]
for line in feature_extraction:
    print(line)
    print(f'weights: {line.weight}')

<generator object Module.parameters at 0x7fdc3d1e1890>
Linear(in_features=1980, out_features=60, bias=True)
weights: Parameter containing:
tensor([[-1.6114e-02,  4.9551e-03,  1.9623e-04,  ..., -1.7681e-02,
          6.6794e-03, -9.9341e-03],
        [-9.9802e-05, -9.1554e-03, -1.6190e-02,  ..., -8.0461e-03,
          6.0411e-03,  9.4299e-03],
        [-2.5013e-03,  4.5977e-03,  7.3765e-03,  ...,  1.2560e-02,
         -6.2866e-03,  4.5736e-03],
        ...,
        [ 2.6844e-03,  2.0730e-02,  2.1617e-02,  ...,  7.1570e-03,
         -5.1240e-03, -1.4599e-02],
        [ 1.4485e-02, -1.7104e-02, -8.4965e-03,  ...,  1.8946e-02,
         -5.2009e-03, -7.9523e-03],
        [ 1.4149e-02, -1.3996e-02,  2.1623e-02,  ..., -2.3521e-03,
          1.7077e-03, -1.7793e-02]], requires_grad=True)
Linear(in_features=60, out_features=30, bias=True)
weights: Parameter containing:
tensor([[-0.0104,  0.0486,  0.0563,  ...,  0.0831, -0.0599,  0.0043],
        [ 0.0979, -0.1259, -0.0226,  ...,  0.1205,  0.072

ModuleAttributeError: 'ReLU' object has no attribute 'weight'

In [None]:
_ = train(
    model,
    torch.cat((train1, train2), dim=0),
    labels1 + labels2,
    [weights1[i] + weights2[i] for i in range(5)],
    train3,
    labels3,
    weights3,
    lr = 0.1
)

In [None]:
_ = train(
    model,
    torch.cat((train1, train2), dim=0),
    labels1 + labels2,
    [weights1[i] + weights2[i] for i in range(5)],
    train3,
    labels3,
    weights3,
    lr = 0.001
)