In [None]:
https://medium.com/dair-ai/a-simple-neural-network-from-scratch-with-pytorch-and-google-colab-c7f3830618e0

In [2]:
import torch
import torch.nn as nn

In [3]:
X = torch.tensor(([2, 9], [1, 5], [3, 6]), dtype=torch.float) # 3 X 2 tensor
y = torch.tensor(([92], [100], [89]), dtype=torch.float) # 3 X 1 tensor
xPredicted = torch.tensor(([4, 8]), dtype=torch.float) # 1 X 2 tensor

In [4]:
print(X.size())
print(y.size())

torch.Size([3, 2])
torch.Size([3, 1])


In [5]:
# scale units
X_max, _ = torch.max(X, 0)
xPredicted_max, _ = torch.max(xPredicted, 0)

X = torch.div(X, X_max)
xPredicted = torch.div(xPredicted, xPredicted_max)
y = y / 100  # max test score is 100

In [6]:
class Neural_Network(nn.Module):
    def __init__(self, ):
        super(Neural_Network, self).__init__()
        # parameters
        # TODO: parameters can be parameterized instead of declaring them here
        self.inputSize = 2
        self.outputSize = 1
        self.hiddenSize = 3
        
        # weights
        self.W1 = torch.randn(self.inputSize, self.hiddenSize) # 3 X 2 tensor
        self.W2 = torch.randn(self.hiddenSize, self.outputSize) # 3 X 1 tensor
        
    def forward(self, X):
        self.z = torch.matmul(X, self.W1) # 3 X 3 ".dot" does not broadcast in PyTorch
        self.z2 = self.sigmoid(self.z) # activation function
        self.z3 = torch.matmul(self.z2, self.W2)
        o = self.sigmoid(self.z3) # final activation function
        return o
        
    def sigmoid(self, s):
        return 1 / (1 + torch.exp(-s))
    
    def sigmoidPrime(self, s):
        # derivative of sigmoid
        return s * (1 - s)
    
    def backward(self, X, y, o):
        self.o_error = y - o # error in output
        self.o_delta = self.o_error * self.sigmoidPrime(o) # derivative of sig to error
        self.z2_error = torch.matmul(self.o_delta, torch.t(self.W2))
        self.z2_delta = self.z2_error * self.sigmoidPrime(self.z2)
        self.W1 += torch.matmul(torch.t(X), self.z2_delta)
        self.W2 += torch.matmul(torch.t(self.z2), self.o_delta)
        
    def train(self, X, y):
        # forward + backward pass for training
        o = self.forward(X)
        self.backward(X, y, o)
        
    def saveWeights(self, model):
        # we will use the PyTorch internal storage functions
        torch.save(model, "NN")
        # you can reload model with all the weights and so forth with:
        # torch.load("NN")
        
    def predict(self):
        print ("Predicted data based on trained weights: ")
        print ("Input (scaled): \n" + str(xPredicted))
        print ("Output: \n" + str(self.forward(xPredicted)))

In [7]:
NN = Neural_Network()
for i in range(1000):  # trains the NN 1,000 times
    print ("#" + str(i) + " Loss: " + str(torch.mean((y - NN(X))**2).detach().item()))  # mean sum squared loss
    NN.train(X, y)
NN.saveWeights(NN)
NN.predict()

#0 Loss: 0.15037821233272552
#1 Loss: 0.10608363896608353
#2 Loss: 0.07739616185426712
#3 Loss: 0.05857440456748009
#4 Loss: 0.04585493728518486
#5 Loss: 0.03696621581912041
#6 Loss: 0.030551893636584282
#7 Loss: 0.025787657126784325
#8 Loss: 0.022158406674861908
#9 Loss: 0.01933235675096512
#10 Loss: 0.017089400440454483
#11 Loss: 0.015279454179108143
#12 Loss: 0.013797607272863388
#13 Loss: 0.012568830512464046
#14 Loss: 0.011538361199200153
#15 Loss: 0.01066552009433508
#16 Loss: 0.009919575415551662
#17 Loss: 0.009276955388486385
#18 Loss: 0.008719311095774174
#19 Loss: 0.00823226198554039
#20 Loss: 0.007804302033036947
#21 Loss: 0.00742624094709754
#22 Loss: 0.007090594619512558
#23 Loss: 0.00679122656583786
#24 Loss: 0.006523087155073881
#25 Loss: 0.00628199428319931
#26 Loss: 0.006064425688236952
#27 Loss: 0.005867434665560722
#28 Loss: 0.005688503384590149
#29 Loss: 0.005525509361177683
#30 Loss: 0.005376615095883608
#31 Loss: 0.0052402629517018795
#32 Loss: 0.00511508807539939

#529 Loss: 0.002801816212013364
#530 Loss: 0.002800900721922517
#531 Loss: 0.002799984300509095
#532 Loss: 0.0027990664821118116
#533 Loss: 0.002798146568238735
#534 Loss: 0.0027972266543656588
#535 Loss: 0.00279630976729095
#536 Loss: 0.002795388223603368
#537 Loss: 0.002794469939544797
#538 Loss: 0.0027935465332120657
#539 Loss: 0.0027926235925406218
#540 Loss: 0.0027917048428207636
#541 Loss: 0.002790778875350952
#542 Loss: 0.0027898543048650026
#543 Loss: 0.002788927173241973
#544 Loss: 0.002788003534078598
#545 Loss: 0.0027870796620845795
#546 Loss: 0.0027861499693244696
#547 Loss: 0.0027852163184434175
#548 Loss: 0.002784289186820388
#549 Loss: 0.0027833671774715185
#550 Loss: 0.0027824293356388807
#551 Loss: 0.002781500108540058
#552 Loss: 0.0027805715799331665
#553 Loss: 0.002779637696221471
#554 Loss: 0.0027787021826952696
#555 Loss: 0.0027777664363384247
#556 Loss: 0.0027768334839493036
#557 Loss: 0.002775897504761815
#558 Loss: 0.002774960594251752
#559 Loss: 0.0027740199584

  "type " + obj.__name__ + ". It won't be checked "
