In [1]:
import torch
import torch.nn as nn

import sys
from os import path
sys.path.append( "../website/apis/models/" )

from WildfireModel import WildfireModel

In [2]:
# define constants
IMG_SZ = 100
TRAIN_NEW_MODEL = True
MODEL_PATH = "../website/apis/models/wfm.pt"

In [3]:
# determine the linear size of the output layer of a convolutional layer
def conv2d_out_sz(in_size, kernel_size, pool_size, padding=0, stride=1):
    return ((in_size - kernel_size + 2*padding)/stride + 1)/pool_size

In [4]:
# use the function we just defined
# the value of fcin explicitly calculated here is used in the definition of the model in 

ks = 5 # kernel size
ps = 2 # pool size
out_chan = 64 # number of output channels 

os = conv2d_out_sz(IMG_SZ,ks,ps) # outputs to 2nd conv2d layer
os = conv2d_out_sz(os,ks,ps) # output size of 2nd conv2d layer
os = conv2d_out_sz(os,ks,ps) # output size 3

fcin = int((int(os)**2)*out_chan)
fcin

5184

In [5]:
import torch.optim as optim
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt

In [6]:
# get the device and the tensor datatype that we will be using, giving preference to NVIDIA GPUs over CPUs
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
torch.set_default_tensor_type(dtype)

In [7]:
# must run the Preprocess.ipynb Jupyter Notebook to generate data before running this notebook
# load the balanced data array to train and test on
data = np.load("balanced_data.npy", allow_pickle=True)
extra_data = np.load("extra_data.npy", allow_pickle=True)

# format x the way torch wants to see it
x = torch.tensor([d[0] for d in data])
x = (x/255.0).view(-1,3,100,100)
# and the extra x
extra_x = torch.tensor([ed[0] for ed in extra_data])
extra_x = (extra_x/255.0).view(-1,3,100,100)

# format y the way torch wants to see it
y = torch.tensor([float(d[1]) for d in data])
# and the extra y
extra_y = torch.tensor([float(ed[1]) for ed in extra_data])

train_x, test_x, train_y, test_y = train_test_split(x, y, test_size = 0.1)
(train_x.shape, train_y.shape), (test_x.shape, test_y.shape)

((torch.Size([3479, 3, 100, 100]), torch.Size([3479])),
 (torch.Size([387, 3, 100, 100]), torch.Size([387])))

In [8]:
def evaluate_accuracy(test_x, test_y, wfm):
    with torch.no_grad():
        y_pred = wfm(test_x).round()
        test_y = test_y.unsqueeze(1)
        correct = torch.tensor(y_pred == test_y)
        return sum(correct).item()/len(correct)

In [9]:
# bounded buffer class, keeps track of 'size' most recent insertions
class BoundedNumericList():
    
    def __init__(self, size):
        self.size = size
        self.nums = []
        self.next_insertion = 0
        
    def insert(self, item):
        if not isinstance(item, (int, float, complex)) or isinstance(item, bool):
            return False
        if len(self.nums) < self.size:
            self.nums += [item]
        else:
            self.nums[self.next_insertion % self.size] = item
        self.next_insertion += 1
        return True

    def average(self):
        if len(self.nums) == 0: return None
        return sum(self.nums) / len(self.nums)

In [10]:
if TRAIN_NEW_MODEL:

    # instantiate our model, initial optimizer, and loss function
    wfm = WildfireModel().to(device)
    optimizer = optim.AdamW(wfm.parameters(), lr=1e-4)
    loss_function = nn.BCELoss()

    # declare constants controlling the training process
    BATCH_SIZE = 32
    EPOCHS = 5000
    ROLLING_ACCURACY_SIZE = 4

    # instantiate our bounded buffer to keep track of the last ROLLING_ACCURACY_SIZE epochh accuracies
    past_epoch_accuracies = BoundedNumericList(ROLLING_ACCURACY_SIZE)
    highest_rolling_accuracy = -1

    for epoch in range(EPOCHS): # loop over all of our data EPOCH times
        for i in range(0, len(train_x), BATCH_SIZE): # iterate over our batches
            # grab the ith batch
            batch_x = train_x[i : i+BATCH_SIZE]
            batch_y = train_y[i : i+BATCH_SIZE]
            batch_y = torch.unsqueeze(batch_y, 1)
            
            # zero our gradient
            wfm.zero_grad()

            # pass the batch through the model
            outputs = wfm(batch_x)
            
            # compute the loss between the outputs and the expected
            loss = loss_function(outputs, batch_y)
            
            # update the model's weights
            loss.backward()
            optimizer.step()

        # calculate the most recent epoch's accuracy
        epoch_accuracy = evaluate_accuracy(test_x, test_y, wfm)
        
        # add it to the list of past accuracies
        past_epoch_accuracies.insert(epoch_accuracy)
        
        # calculate the rolling average
        rolling_accuracy = past_epoch_accuracies.average()

        print(f"Epoch: {epoch}. Loss: {loss}. Rolling Accuracy: {round(rolling_accuracy,3)}")

        # save the model if the most recent updates have been beneficial
        if rolling_accuracy > highest_rolling_accuracy:
            print(f"Saving model at epoch {epoch}.")
            torch.save(wfm, MODEL_PATH)
            highest_rolling_accuracy = rolling_accuracy

  correct = torch.tensor(y_pred == test_y)


Epoch: 0. Loss: 0.5216521620750427. Rolling Accuracy: 0.693
Saving model at epoch 0.
Epoch: 1. Loss: 0.5007863640785217. Rolling Accuracy: 0.708
Saving model at epoch 1.
Epoch: 2. Loss: 0.4932693839073181. Rolling Accuracy: 0.713
Saving model at epoch 2.
Epoch: 3. Loss: 0.49051037430763245. Rolling Accuracy: 0.714
Saving model at epoch 3.
Epoch: 4. Loss: 0.4838408827781677. Rolling Accuracy: 0.72
Saving model at epoch 4.
Epoch: 5. Loss: 0.4464470148086548. Rolling Accuracy: 0.72
Saving model at epoch 5.
Epoch: 6. Loss: 0.47642460465431213. Rolling Accuracy: 0.717
Epoch: 7. Loss: 0.44293412566185. Rolling Accuracy: 0.717
Epoch: 8. Loss: 0.4491131901741028. Rolling Accuracy: 0.716
Epoch: 9. Loss: 0.4557952582836151. Rolling Accuracy: 0.717
Epoch: 10. Loss: 0.4580579996109009. Rolling Accuracy: 0.724
Saving model at epoch 10.
Epoch: 11. Loss: 0.48614469170570374. Rolling Accuracy: 0.737
Saving model at epoch 11.
Epoch: 12. Loss: 0.49574014544487. Rolling Accuracy: 0.758
Saving model at ep

Epoch: 116. Loss: 0.005764631554484367. Rolling Accuracy: 0.92
Epoch: 117. Loss: 0.008570926263928413. Rolling Accuracy: 0.923
Saving model at epoch 117.
Epoch: 118. Loss: 0.0172397643327713. Rolling Accuracy: 0.922
Epoch: 119. Loss: 0.00382635067217052. Rolling Accuracy: 0.929
Saving model at epoch 119.
Epoch: 120. Loss: 0.00781748816370964. Rolling Accuracy: 0.935
Saving model at epoch 120.
Epoch: 121. Loss: 0.006375241558998823. Rolling Accuracy: 0.935
Epoch: 122. Loss: 0.0018237588228657842. Rolling Accuracy: 0.936
Saving model at epoch 122.
Epoch: 123. Loss: 0.004078424070030451. Rolling Accuracy: 0.928
Epoch: 124. Loss: 0.004775104112923145. Rolling Accuracy: 0.926
Epoch: 125. Loss: 0.002644350752234459. Rolling Accuracy: 0.925
Epoch: 126. Loss: 0.004019938874989748. Rolling Accuracy: 0.926
Epoch: 127. Loss: 0.0016985032707452774. Rolling Accuracy: 0.93
Epoch: 128. Loss: 0.004800158552825451. Rolling Accuracy: 0.933
Epoch: 129. Loss: 0.003527397522702813. Rolling Accuracy: 0.936


Epoch: 239. Loss: 8.4826506281388e-06. Rolling Accuracy: 0.941
Epoch: 240. Loss: 8.184657417587005e-06. Rolling Accuracy: 0.941
Epoch: 241. Loss: 5.413968210632447e-06. Rolling Accuracy: 0.945
Epoch: 242. Loss: 5.4917181842029095e-06. Rolling Accuracy: 0.944
Epoch: 243. Loss: 3.2187615488510346e-06. Rolling Accuracy: 0.944
Epoch: 244. Loss: 1.689692794570874e-06. Rolling Accuracy: 0.944
Epoch: 245. Loss: 8.785281693235447e-07. Rolling Accuracy: 0.944
Epoch: 246. Loss: 2.4904693418648094e-06. Rolling Accuracy: 0.946
Epoch: 247. Loss: 9.951456831913674e-07. Rolling Accuracy: 0.946
Epoch: 248. Loss: 4.975705110155104e-07. Rolling Accuracy: 0.946
Epoch: 249. Loss: 3.368966190464562e-07. Rolling Accuracy: 0.945
Epoch: 250. Loss: 0.011716227047145367. Rolling Accuracy: 0.934
Epoch: 251. Loss: 0.0002649376692716032. Rolling Accuracy: 0.934
Epoch: 252. Loss: 0.0005973384249955416. Rolling Accuracy: 0.931
Epoch: 253. Loss: 0.00014718247985001653. Rolling Accuracy: 0.93
Epoch: 254. Loss: 1.94825

Epoch: 363. Loss: 2.0239792775100796e-06. Rolling Accuracy: 0.941
Epoch: 364. Loss: 4.35804613516666e-05. Rolling Accuracy: 0.938
Epoch: 365. Loss: 3.7467412767000496e-05. Rolling Accuracy: 0.941
Epoch: 366. Loss: 3.602285005399608e-06. Rolling Accuracy: 0.941
Epoch: 367. Loss: 3.9651108636462595e-06. Rolling Accuracy: 0.944
Epoch: 368. Loss: 2.736685473792022e-06. Rolling Accuracy: 0.946
Epoch: 369. Loss: 2.280563876411179e-06. Rolling Accuracy: 0.948
Epoch: 370. Loss: 2.285745040353504e-06. Rolling Accuracy: 0.948
Epoch: 371. Loss: 7.670913078072772e-07. Rolling Accuracy: 0.948
Epoch: 372. Loss: 1.2905829862575047e-06. Rolling Accuracy: 0.947
Epoch: 373. Loss: 1.09880818399688e-06. Rolling Accuracy: 0.946
Epoch: 374. Loss: 4.820224148716079e-07. Rolling Accuracy: 0.946
Epoch: 375. Loss: 7.826397450116929e-07. Rolling Accuracy: 0.946
Epoch: 376. Loss: 3.213477839381085e-07. Rolling Accuracy: 0.946
Epoch: 377. Loss: 6.012323865434155e-07. Rolling Accuracy: 0.946
Epoch: 378. Loss: 3.203

Epoch: 488. Loss: 9.405440323462244e-06. Rolling Accuracy: 0.941
Epoch: 489. Loss: 5.198829967412166e-06. Rolling Accuracy: 0.939
Epoch: 490. Loss: 4.151767370785819e-06. Rolling Accuracy: 0.937
Epoch: 491. Loss: 1.7622546693019103e-06. Rolling Accuracy: 0.938
Epoch: 492. Loss: 7.256222289697689e-08. Rolling Accuracy: 0.938
Epoch: 493. Loss: 1.969548151237177e-07. Rolling Accuracy: 0.939
Epoch: 494. Loss: 6.805807515775086e-06. Rolling Accuracy: 0.939
Epoch: 495. Loss: 2.1613625449390383e-06. Rolling Accuracy: 0.941
Epoch: 496. Loss: 1.5704770248703426e-06. Rolling Accuracy: 0.943
Epoch: 497. Loss: 9.64049263529887e-07. Rolling Accuracy: 0.943
Epoch: 498. Loss: 1.1817409131253953e-06. Rolling Accuracy: 0.943
Epoch: 499. Loss: 9.145815420197323e-05. Rolling Accuracy: 0.939
Epoch: 500. Loss: 1.0890315934375394e-05. Rolling Accuracy: 0.935
Epoch: 501. Loss: 1.5893099771346897e-05. Rolling Accuracy: 0.935
Epoch: 502. Loss: 1.3853223208570853e-05. Rolling Accuracy: 0.935
Epoch: 503. Loss: 2

Epoch: 615. Loss: 4.5869748532822996e-07. Rolling Accuracy: 0.931
Epoch: 616. Loss: 2.384188348969474e-07. Rolling Accuracy: 0.933
Epoch: 617. Loss: 2.539678973789705e-07. Rolling Accuracy: 0.936
Epoch: 618. Loss: 1.7881409064557374e-07. Rolling Accuracy: 0.937
Epoch: 619. Loss: 8.811125695729061e-08. Rolling Accuracy: 0.937
Epoch: 620. Loss: 1.1661785492833587e-07. Rolling Accuracy: 0.937
Epoch: 621. Loss: 8.29282456038527e-08. Rolling Accuracy: 0.936
Epoch: 622. Loss: 6.219617176839165e-08. Rolling Accuracy: 0.936
Epoch: 623. Loss: 5.1830138403374804e-08. Rolling Accuracy: 0.935
Epoch: 624. Loss: 2.8506574878406354e-08. Rolling Accuracy: 0.936
Epoch: 625. Loss: 9.329427541615587e-08. Rolling Accuracy: 0.937
Epoch: 626. Loss: 6.997070300940322e-08. Rolling Accuracy: 0.937
Epoch: 627. Loss: 5.96046660916727e-08. Rolling Accuracy: 0.937
Epoch: 628. Loss: 3.887260646706636e-08. Rolling Accuracy: 0.936
Epoch: 629. Loss: 3.887260646706636e-08. Rolling Accuracy: 0.935
Epoch: 630. Loss: 2.85

Epoch: 742. Loss: 1.036602625958949e-08. Rolling Accuracy: 0.941
Epoch: 743. Loss: 1.036602625958949e-08. Rolling Accuracy: 0.942
Epoch: 744. Loss: 1.036602625958949e-08. Rolling Accuracy: 0.942
Epoch: 745. Loss: 0.0. Rolling Accuracy: 0.942
Epoch: 746. Loss: 0.0. Rolling Accuracy: 0.942
Epoch: 747. Loss: 0.0. Rolling Accuracy: 0.942
Epoch: 748. Loss: 0.0. Rolling Accuracy: 0.943
Epoch: 749. Loss: 0.0. Rolling Accuracy: 0.943
Epoch: 750. Loss: 0.0. Rolling Accuracy: 0.943
Epoch: 751. Loss: 0.0. Rolling Accuracy: 0.943
Epoch: 752. Loss: 7.774524135584215e-08. Rolling Accuracy: 0.94
Epoch: 753. Loss: 5.7013160414953745e-08. Rolling Accuracy: 0.94
Epoch: 754. Loss: 2.073205251917898e-08. Rolling Accuracy: 0.939
Epoch: 755. Loss: 1.5549041165741073e-08. Rolling Accuracy: 0.939
Epoch: 756. Loss: 1.036602625958949e-08. Rolling Accuracy: 0.941
Epoch: 757. Loss: 5.183013129794745e-09. Rolling Accuracy: 0.941
Epoch: 758. Loss: 0.0. Rolling Accuracy: 0.941
Epoch: 759. Loss: 0.0. Rolling Accuracy

Epoch: 873. Loss: 3.783615341035329e-07. Rolling Accuracy: 0.942
Epoch: 874. Loss: 2.643344316766161e-07. Rolling Accuracy: 0.941
Epoch: 875. Loss: 1.6585671858138085e-07. Rolling Accuracy: 0.941
Epoch: 876. Loss: 1.1402642741131785e-07. Rolling Accuracy: 0.941
Epoch: 877. Loss: 8.292828113098949e-08. Rolling Accuracy: 0.941
Epoch: 878. Loss: 5.18301561669432e-08. Rolling Accuracy: 0.941
Epoch: 879. Loss: 4.146412280192635e-08. Rolling Accuracy: 0.943
Epoch: 880. Loss: 3.1098089436909504e-08. Rolling Accuracy: 0.944
Epoch: 881. Loss: 2.591507097804424e-08. Rolling Accuracy: 0.944
Epoch: 882. Loss: 2.0732056071892657e-08. Rolling Accuracy: 0.946
Epoch: 883. Loss: 1.5549041165741073e-08. Rolling Accuracy: 0.946
Epoch: 884. Loss: 1.036602625958949e-08. Rolling Accuracy: 0.946
Epoch: 885. Loss: 1.036602625958949e-08. Rolling Accuracy: 0.946
Epoch: 886. Loss: 5.183013129794745e-09. Rolling Accuracy: 0.946
Epoch: 887. Loss: 5.183013129794745e-09. Rolling Accuracy: 0.946
Epoch: 888. Loss: 5.1

Epoch: 1010. Loss: 1.3786891486233799e-06. Rolling Accuracy: 0.943
Epoch: 1011. Loss: 1.0573392046353547e-06. Rolling Accuracy: 0.943
Epoch: 1012. Loss: 8.08552726994094e-07. Rolling Accuracy: 0.943
Epoch: 1013. Loss: 6.893427553222864e-07. Rolling Accuracy: 0.943
Epoch: 1014. Loss: 5.338515620678663e-07. Rolling Accuracy: 0.942
Epoch: 1015. Loss: 4.612890904809319e-07. Rolling Accuracy: 0.941
Epoch: 1016. Loss: 3.887266188939975e-07. Rolling Accuracy: 0.941
Epoch: 1017. Loss: 3.2134727234733873e-07. Rolling Accuracy: 0.941
Epoch: 1018. Loss: 2.954321303150209e-07. Rolling Accuracy: 0.941
Epoch: 1019. Loss: 2.5396792580067995e-07. Rolling Accuracy: 0.941
Epoch: 1020. Loss: 2.0732070993290108e-07. Rolling Accuracy: 0.941
Epoch: 1021. Loss: 1.9177164745087794e-07. Rolling Accuracy: 0.941
Epoch: 1022. Loss: 1.762225849688548e-07. Rolling Accuracy: 0.941
Epoch: 1023. Loss: 1.3475842308707797e-07. Rolling Accuracy: 0.941
Epoch: 1024. Loss: 1.2439238616934745e-07. Rolling Accuracy: 0.941
Epo

Epoch: 1134. Loss: 1.0884333789817902e-07. Rolling Accuracy: 0.939
Epoch: 1135. Loss: 5.753164487032336e-07. Rolling Accuracy: 0.94
Epoch: 1136. Loss: 3.368964485161996e-07. Rolling Accuracy: 0.94
Epoch: 1137. Loss: 2.1768680369405047e-07. Rolling Accuracy: 0.941
Epoch: 1138. Loss: 5.338519599717984e-07. Rolling Accuracy: 0.94
Epoch: 1139. Loss: 1.036602625958949e-08. Rolling Accuracy: 0.941
Epoch: 1140. Loss: 9.847783530858578e-07. Rolling Accuracy: 0.941
Epoch: 1141. Loss: 1.580833441039431e-06. Rolling Accuracy: 0.941
Epoch: 1142. Loss: 5.183013129794745e-09. Rolling Accuracy: 0.941
Epoch: 1143. Loss: 6.54122459309292e-06. Rolling Accuracy: 0.941
Epoch: 1144. Loss: 0.004483403638005257. Rolling Accuracy: 0.935
Epoch: 1145. Loss: 0.0001402161142323166. Rolling Accuracy: 0.937
Epoch: 1146. Loss: 1.7085363651858643e-05. Rolling Accuracy: 0.939
Epoch: 1147. Loss: 7.72314888308756e-06. Rolling Accuracy: 0.941
Epoch: 1148. Loss: 4.493830147112021e-06. Rolling Accuracy: 0.95
Epoch: 1149. L

Epoch: 1266. Loss: 0.0. Rolling Accuracy: 0.943
Epoch: 1267. Loss: 0.0. Rolling Accuracy: 0.943
Epoch: 1268. Loss: 0.0. Rolling Accuracy: 0.943
Epoch: 1269. Loss: 0.0. Rolling Accuracy: 0.943
Epoch: 1270. Loss: 0.0. Rolling Accuracy: 0.943
Epoch: 1271. Loss: 7.774519694692117e-09. Rolling Accuracy: 0.943
Epoch: 1272. Loss: 0.0. Rolling Accuracy: 0.943
Epoch: 1273. Loss: 0.0. Rolling Accuracy: 0.943
Epoch: 1274. Loss: 0.0. Rolling Accuracy: 0.944
Epoch: 1275. Loss: 0.0. Rolling Accuracy: 0.944
Epoch: 1276. Loss: 0.0. Rolling Accuracy: 0.944
Epoch: 1277. Loss: 0.0. Rolling Accuracy: 0.944
Epoch: 1278. Loss: 0.0. Rolling Accuracy: 0.943
Epoch: 1279. Loss: 0.0. Rolling Accuracy: 0.942
Epoch: 1280. Loss: 0.0. Rolling Accuracy: 0.94
Epoch: 1281. Loss: 0.0016030673868954182. Rolling Accuracy: 0.935
Epoch: 1282. Loss: 3.7058694601910247e-07. Rolling Accuracy: 0.937
Epoch: 1283. Loss: 2.384192328008794e-07. Rolling Accuracy: 0.938
Epoch: 1284. Loss: 1.6067369301708823e-07. Rolling Accuracy: 0.9

Epoch: 1403. Loss: 2.2287005663201853e-07. Rolling Accuracy: 0.939
Epoch: 1404. Loss: 1.4512458790250093e-07. Rolling Accuracy: 0.939
Epoch: 1405. Loss: 1.2439247143447574e-07. Rolling Accuracy: 0.939
Epoch: 1406. Loss: 1.399415481273536e-07. Rolling Accuracy: 0.94
Epoch: 1407. Loss: 9.847733650758528e-08. Rolling Accuracy: 0.94
Epoch: 1408. Loss: 7.774524846126951e-08. Rolling Accuracy: 0.941
Epoch: 1409. Loss: 6.219620019010108e-08. Rolling Accuracy: 0.941
Epoch: 1410. Loss: 5.18301561669432e-08. Rolling Accuracy: 0.941
Epoch: 1411. Loss: 4.6647137708077935e-08. Rolling Accuracy: 0.941
Epoch: 1412. Loss: 5.18301561669432e-08. Rolling Accuracy: 0.941
Epoch: 1413. Loss: 3.628110434306109e-08. Rolling Accuracy: 0.941
Epoch: 1414. Loss: 3.1098089436909504e-08. Rolling Accuracy: 0.941
Epoch: 1415. Loss: 3.628110434306109e-08. Rolling Accuracy: 0.94
Epoch: 1416. Loss: 2.591507097804424e-08. Rolling Accuracy: 0.939
Epoch: 1417. Loss: 2.0732056071892657e-08. Rolling Accuracy: 0.939
Epoch: 14

Epoch: 1528. Loss: 5.483387940330431e-05. Rolling Accuracy: 0.935
Epoch: 1529. Loss: 2.0499057882261695e-06. Rolling Accuracy: 0.941
Epoch: 1530. Loss: 2.498269168427214e-06. Rolling Accuracy: 0.942
Epoch: 1531. Loss: 1.575659211994207e-06. Rolling Accuracy: 0.948
Epoch: 1532. Loss: 1.1506416512929718e-06. Rolling Accuracy: 0.953
Epoch: 1533. Loss: 8.96669178018783e-07. Rolling Accuracy: 0.953
Epoch: 1534. Loss: 7.334018050642044e-07. Rolling Accuracy: 0.953
Epoch: 1535. Loss: 6.297401000665559e-07. Rolling Accuracy: 0.952
Epoch: 1536. Loss: 5.312617759045679e-07. Rolling Accuracy: 0.952
Epoch: 1537. Loss: 4.6906498596399615e-07. Rolling Accuracy: 0.951
Epoch: 1538. Loss: 4.094597727544169e-07. Rolling Accuracy: 0.951
Epoch: 1539. Loss: 3.6281227266954374e-07. Rolling Accuracy: 0.951
Epoch: 1540. Loss: 3.2393938909081044e-07. Rolling Accuracy: 0.951
Epoch: 1541. Loss: 2.928411504399264e-07. Rolling Accuracy: 0.951
Epoch: 1542. Loss: 2.617429117890424e-07. Rolling Accuracy: 0.951
Epoch:

Epoch: 1657. Loss: 6.16782926954329e-07. Rolling Accuracy: 0.943
Epoch: 1658. Loss: 5.856844040863507e-07. Rolling Accuracy: 0.943
Epoch: 1659. Loss: 5.390367050495115e-07. Rolling Accuracy: 0.943
Epoch: 1660. Loss: 2.4309010768774897e-06. Rolling Accuracy: 0.943
Epoch: 1661. Loss: 2.539680679092271e-07. Rolling Accuracy: 0.941
Epoch: 1662. Loss: 0.00014431244926527143. Rolling Accuracy: 0.936
Epoch: 1663. Loss: 0.00014617046690545976. Rolling Accuracy: 0.935
Epoch: 1664. Loss: 7.2773650572344195e-06. Rolling Accuracy: 0.935
Epoch: 1665. Loss: 2.4671628580108518e-06. Rolling Accuracy: 0.937
Epoch: 1666. Loss: 1.420161424903199e-06. Rolling Accuracy: 0.941
Epoch: 1667. Loss: 7.074851851029962e-07. Rolling Accuracy: 0.941
Epoch: 1668. Loss: 4.1205086631634913e-07. Rolling Accuracy: 0.941
Epoch: 1669. Loss: 2.798832952066732e-07. Rolling Accuracy: 0.94
Epoch: 1670. Loss: 1.8658873557342304e-07. Rolling Accuracy: 0.94
Epoch: 1671. Loss: 1.89180241250142e-07. Rolling Accuracy: 0.941
Epoch: 

Epoch: 1805. Loss: 5.234873583503941e-07. Rolling Accuracy: 0.929
Epoch: 1806. Loss: 2.8506661919891485e-07. Rolling Accuracy: 0.93
Epoch: 1807. Loss: 1.8140582369596814e-07. Rolling Accuracy: 0.932
Epoch: 1808. Loss: 1.347585367739157e-07. Rolling Accuracy: 0.933
Epoch: 1809. Loss: 9.329432515414737e-08. Rolling Accuracy: 0.935
Epoch: 1810. Loss: 8.292828113098949e-08. Rolling Accuracy: 0.935
Epoch: 1811. Loss: 4.6647137708077935e-08. Rolling Accuracy: 0.934
Epoch: 1812. Loss: 3.628110434306109e-08. Rolling Accuracy: 0.933
Epoch: 1813. Loss: 3.628110434306109e-08. Rolling Accuracy: 0.933
Epoch: 1814. Loss: 2.591507097804424e-08. Rolling Accuracy: 0.933
Epoch: 1815. Loss: 2.0732056071892657e-08. Rolling Accuracy: 0.933
Epoch: 1816. Loss: 1.5549041165741073e-08. Rolling Accuracy: 0.933
Epoch: 1817. Loss: 1.5549041165741073e-08. Rolling Accuracy: 0.933
Epoch: 1818. Loss: 1.036602625958949e-08. Rolling Accuracy: 0.933
Epoch: 1819. Loss: 1.036602625958949e-08. Rolling Accuracy: 0.933
Epoch

KeyboardInterrupt: 

In [11]:
wfm = torch.load(MODEL_PATH).to(device)

acc = evaluate_accuracy(test_x, test_y, wfm)
print(f"Model accuracy on test data: {acc}")

extra_acc = evaluate_accuracy(extra_x, extra_y, wfm)
print(f"Model accuracy on extra data: {extra_acc}")

Model accuracy on test data: 0.9534883720930233
Model accuracy on extra data: 0.9274193548387096


  correct = torch.tensor(y_pred == test_y)
