In [None]:
import numpy as np
import torch
from torch import nn

import matplotlib.pyplot as plt

def ProbCorrPlot(Prob_Guess_All, Answer):
    """Plot predicted probability against the observed fraction of true labels.

    Predicted probabilities are binned into ``round(sqrt(N))`` equal-width bins
    on [0, 1]. For every non-empty bin, the fraction of samples whose ``Answer``
    is true is plotted against the bin centre, with an error bar of
    ``1/sqrt(count)``. A dotted diagonal marks the ideal one-to-one relation,
    and a shaded, rescaled histogram shows how the predictions are distributed.

    Parameters
    ----------
    Prob_Guess_All : array_like
        Predicted probability for each sample.
    Answer : array_like
        Ground-truth label for each sample. It is cast to ``bool`` so that 0/1
        labels act as a mask.
    """
    Prob_Guess_Type = np.asarray(Prob_Guess_All, dtype=float)
    # Cast to bool so 0/1 integer labels act as a boolean mask; an integer
    # array would be interpreted as indices and pick the wrong samples.
    Answer_Mask = np.asarray(Answer).astype(bool)

    # ~sqrt(N) equal-width bins on [0, 1]. int() because round() on a NumPy
    # float may return a float, which np.linspace rejects as `num`; max()
    # keeps at least one bin so empty input still yields valid edges.
    nBin = max(1, int(round(np.sqrt(len(Answer_Mask)))))
    edges = np.linspace(0.0, 1.0, nBin+1)
    X = (edges[:-1] + edges[1:])/2
    Prob_Guess_Cor  = Prob_Guess_Type[Answer_Mask]
    Count_All = np.histogram(Prob_Guess_Type, bins=edges)[0]
    Count_Cor = np.histogram(Prob_Guess_Cor , bins=edges)[0]

    # Only bins that actually contain samples have a defined fraction.
    nonempty   = Count_All != 0
    X2         = X[nonempty]
    Count_All2 = Count_All[nonempty]
    Count_Cor2 = Count_Cor[nonempty]
    Prob_Ans_Cor = Count_Cor2/Count_All2
    # 1/sqrt(n) does not depend on the observed fraction p; it equals twice
    # the largest binomial standard error, since sqrt(p(1-p)/n) <= 0.5/sqrt(n).
    Prob_Ans_Cor_SD = 1.0 / np.sqrt(Count_All2)

    # Plotting:
    plt.figure()
    plt.errorbar(X2,Prob_Ans_Cor,Prob_Ans_Cor_SD,capsize=3,ls='none',c='k')
    plt.scatter(X2,Prob_Ans_Cor,marker='.',s=14,c='k')
    plt.plot([0,1],[0,1],':b',lw=2,label='Ideal Correlation')
    # Rescale the histogram so its tallest bar reaches 0.25; max(..., 1)
    # avoids a division by zero when no prediction falls inside [0, 1].
    Norm = 0.25/max(np.max(Count_All), 1)
    plt.fill_between(X,Norm*Count_All,alpha=0.8,color=(0.1,0.5,0.1),label='Probability Density')

    plt.xlim(0,1)
    plt.ylim(0,1)
    plt.title('Probability Accuracy',fontsize=18)
    plt.xlabel('Predicted Probability',fontsize=15)
    plt.ylabel('Fraction of Correct Assignments',fontsize=15)
    plt.legend(fontsize=10)

    plt.show()

# class MaskedLinear(nn.Linear):
#     def __init__(self, mask, **kwargs):
#         super().__init__(*mask.shape, **kwargs)
#         self.mask = mask

#     def forward(self, x):
#         return nn.linear(x, self.weight, self.bias)*self.mask

class NN(nn.Module):
    """Fully connected feed-forward network with ReLU hidden activations.

    ``NN(n_in, h1, ..., n_out)`` stacks ``Linear -> ReLU`` for every hidden
    transition and ends with ``Linear -> Softmax`` over the last dimension.
    With four widths, e.g. ``(15, 10, 3, 1)``, the module layout matches the
    original hard-coded version: Linear layers sit at ``layers.0``,
    ``layers.2`` and ``layers.4``.

    NOTE(review): a softmax over a single output unit is identically 1.0, so
    with ``n_out == 1`` the output carries no information. Confirm the
    intended output activation (e.g. a sigmoid for one binary output).

    Parameters
    ----------
    *layers : int
        Layer widths, input first and output last; at least two are required.

    Raises
    ------
    ValueError
        If fewer than two widths are given.
    """

    def __init__(self, *layers):
        super(NN, self).__init__()
        if len(layers) < 2:
            raise ValueError('NN needs at least an input and an output width.')

        modules = []
        for n_in, n_out in zip(layers, layers[1:]):
            modules += [nn.Linear(n_in, n_out), nn.ReLU()]
        # Swap the trailing ReLU for the output activation. The torch class is
        # `nn.Softmax` (`nn.SoftMax` does not exist); an explicit `dim` avoids
        # PyTorch's implicit-dimension deprecation warning.
        modules[-1] = nn.Softmax(dim=-1)
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Return the network's softmax output for input ``x``."""
        return self.layers(x)

class ThisNeuralNet:
    """Training/evaluation wrapper bundling an ``NN`` model with its optimizer.

    Holds the hyperparameters, the model (moved to CUDA when available), a
    cross-entropy criterion and an Adam optimizer, and exposes ``train``,
    ``run`` and ``test`` entry points.
    """

    def __init__(self, layers:tuple, learn_rate:float, num_epochs:int, batchsize:int):
        """
        Initializes hyperparameters and the neural network.

        Parameters
        ----------
        layers : tuple
            Layer widths forwarded to ``NN(*layers)``.
        learn_rate : float
            Learning rate of the Adam optimizer.
        num_epochs : int
            Number of passes over the training data in ``train``.
        batchsize : int
            Mini-batch size used by ``train`` and ``run``.
        """

        # Hyperparameters
        self.layers     = layers
        self.learn_rate = learn_rate
        self.num_epochs = num_epochs
        self.batchsize  = batchsize

        # Neural Network:
        self.device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model     = NN(*layers).to(self.device)
        # NOTE(review): nn.CrossEntropyLoss applies log-softmax itself and
        # expects unnormalized logits; if the model already ends in a softmax,
        # the loss sees probabilities. Confirm the intended pairing.
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learn_rate)

    def train(self, data, answers):
        """
        Trains the neural network.

        Runs ``num_epochs`` epochs of mini-batch optimization over shuffled
        ``(data[i], answers[i])`` pairs, minimizing ``self.criterion``.

        Raises
        ------
        ValueError
            If ``data`` and ``answers`` differ in length.
        """

        if len(data) != len(answers):
            raise ValueError('The data and answers must be iterables of the same length.')

        dataset = list(zip(data, answers))
        train_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batchsize, shuffle=True)

        # `run` switches the model to eval mode, so switch back explicitly.
        self.model.train()
        for _ in range(self.num_epochs):
            for batch, ans in train_loader:
                batch = batch.to(self.device)
                ans = ans.to(self.device)

                # Forward pass:
                guess = self.model(batch)
                loss  = self.criterion(guess, ans)

                # Backward pass:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

    def run(self, data):
        """
        Runs the neural network and returns its outputs.

        ``data`` is converted to a tensor of torch's default dtype, which is
        the dtype ``nn.Linear`` parameters are created in. It is then
        processed in order, in chunks of ``batchsize``, without gradient
        tracking. Row ``i`` of the returned tensor (on ``self.device``)
        corresponds to row ``i`` of ``data``.
        """

        inputs = torch.as_tensor(data, dtype=torch.get_default_dtype())
        self.model.eval()
        with torch.no_grad():
            # No shuffling: outputs must stay aligned with the inputs.
            outputs = [self.model(chunk.to(self.device))
                       for chunk in torch.split(inputs, self.batchsize)]
        return torch.cat(outputs)

    def test(self, data, answers):
        """
        Tests the neural network on binary-labelled data.

        Prints two accuracies and then draws a calibration plot with
        ``ProbCorrPlot``:

        * the accuracy of thresholding the network output at 0.5;
        * the "trivial" accuracy of always predicting the majority label.

        Raises
        ------
        ValueError
            If ``data`` and ``answers`` differ in length.
        """

        if len(data) != len(answers):
            raise ValueError('The data and answers must be iterables of the same length.')

        # Work in CPU NumPy from here on so scoring and plotting never mix
        # torch tensors (possibly on the GPU) with NumPy arrays.
        probs = self.run(data).cpu().numpy()
        if probs.ndim > 1 and probs.shape[-1] == 1:
            # (N, 1) -> (N,); otherwise comparing with (N,) answers broadcasts to (N, N).
            probs = probs[..., 0]
        answers = torch.as_tensor(answers).cpu().numpy().astype(bool)
        guesses = probs >= 0.5

        # Trivial Score (always predict the majority label):
        triv  = np.mean(answers)
        triv  = max(triv, 1-triv)

        # Neural Network Score:
        score = np.mean(guesses == answers)

        print('\nRESULTS:')
        print(f'Trivial score = {100*triv:.3f}%')
        print(f'NN score      = {100*score:.3f}%')

        ProbCorrPlot(probs, answers)

# # Save the model checkpoint
# torch.save(model.state_dict(), 'model.ckpt')


In [None]:
# Hyperparameters for ThisNeuralNet:
layers = (15, 10, 3, 1)   # layer widths: input, hidden..., output
learn_rate, num_epochs, batch_size = 1e-3, 3, 50

thisNN = ThisNeuralNet(
    layers=layers,
    learn_rate=learn_rate,
    num_epochs=num_epochs,
    batchsize=batch_size,
)

# Training:
# NOTE(review): 'foo' / 'bar' look like placeholders; substitute the real
# training features and labels before running this cell.
train_data, train_ans = 'foo', 'bar'
thisNN.train(data=train_data, answers=train_ans)

In [None]:
# Testing:
# NOTE(review): 'bar' / 'foo' look like placeholders; substitute the real
# held-out features and labels before running this cell.
test_data, test_ans = 'bar', 'foo'
thisNN.test(data=test_data, answers=test_ans)