In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# notebook-wide configuration
TRAIN_NEW_MODEL = False  # gates the training cell; when False that cell is skipped
IMG_SZ = 100             # input image side length fed to the layer-size calculation

In [3]:
# determine the linear size of the output of one convolution + max-pool stage
def conv2d_out_sz(in_size, kernel_size, pool_size, padding=0, stride=1):
    """Return the side length of a feature map after a conv layer and a pool.

    Uses integer floor arithmetic so the result is always a whole number of
    pixels, including for odd intermediate sizes and for ``stride > 1``.
    The pooling step is modelled as a non-overlapping window of ``pool_size``
    (window == stride, no padding).

    Args:
        in_size: side length of the (square) input; whole-valued floats are
            accepted so results of earlier calls can be chained.
        kernel_size: side length of the convolution kernel.
        pool_size: side length of the pooling window.
        padding: zero padding added to each side before the convolution.
        stride: stride of the convolution.

    Returns:
        int: side length after convolution followed by pooling.
    """
    # convolution: floor((in + 2*padding - kernel) / stride) + 1
    conv_out = (int(in_size) + 2 * padding - kernel_size) // stride + 1
    # pooling with window == stride discards any incomplete trailing window
    return conv_out // pool_size

In [4]:
# hyperparameters shared by all three conv + pool stages
ks = 5         # kernel size
ps = 2         # pool size
out_chan = 64  # number of output channels

# thread the image side length through the three conv + pool stages:
# the innermost call is stage 1, the outermost is stage 3
os = conv2d_out_sz(
    conv2d_out_sz(
        conv2d_out_sz(IMG_SZ, ks, ps),
        ks, ps,
    ),
    ks, ps,
)

# input width of the first linear layer: channels * height * width
fcin = out_chan * int(os) ** 2
fcin

5184

In [5]:
class WildfireModel(nn.Module):
    """CNN binary classifier: three conv + max-pool stages, then five linear layers.

    ``forward`` maps a batch of 3-channel images to one sigmoid-squashed value
    per sample (shape ``(batch, 1)``). The input width of the first linear
    layer is taken from the module-level ``fcin``, so the images fed to the
    model must match the size ``fcin`` was computed for.
    """

    def __init__(self):
        super().__init__()

        self.kernel_size = 5

        # convolutional feature extractor: 3 -> 16 -> 32 -> 64 channels
        self.conv1 = nn.Conv2d(in_channels=3,  out_channels=16, kernel_size=self.kernel_size)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=self.kernel_size)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=self.kernel_size)

        # fully connected head narrowing down to a single output unit
        self.fc1 = nn.Linear(fcin, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 32)
        self.fc4 = nn.Linear(32, 8)
        self.fc5 = nn.Linear(8, 1)

    def forward(self, x):
        """Run a batch of images through the network.

        Args:
            x: float tensor of shape ``(batch, 3, H, W)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.

        Raises:
            RuntimeError: if the flattened feature size of ``x`` does not
                match the input width of ``fc1``.
        """
        x = F.max_pool2d(F.leaky_relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.leaky_relu(self.conv2(x)), (2, 2))
        x = F.max_pool2d(F.leaky_relu(self.conv3(x)), (2, 2))

        # flatten per sample and keep the batch dimension fixed, so a wrongly
        # sized input fails loudly in fc1 instead of being silently re-batched
        x = torch.flatten(x, start_dim=1)

        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        x = F.leaky_relu(self.fc3(x))
        x = F.leaky_relu(self.fc4(x))
        # NOTE(review): the leaky_relu ahead of the sigmoid compresses negative
        # pre-activations; it is kept so that models already saved with this
        # architecture keep producing the same predictions.
        x = F.leaky_relu(self.fc5(x))

        return torch.sigmoid(x)

In [6]:
import torch.optim as optim
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt

In [7]:
# Fixed seed so the train/test split is the same on every run. Without it, a
# model reloaded from disk in a later session would be scored on a different
# random split that can overlap the samples it was trained on.
SPLIT_SEED = 42


def _to_tensors(samples):
    """Convert a sequence of (image, label) records into model-ready tensors.

    Args:
        samples: sequence whose items hold the image at index 0 and the label
            at index 1.

    Returns:
        (images, labels): images scaled by 1/255 and viewed as
        (-1, 3, IMG_SZ, IMG_SZ); labels as a 1-D float tensor.
    """
    # stack into a single ndarray first: building a tensor from a Python list
    # of arrays is far slower
    images = torch.tensor(np.stack([sample[0] for sample in samples]))
    # NOTE(review): view() only reinterprets memory. If the stored images are
    # laid out (H, W, C) this does not move channels to the front the way
    # permute() would - confirm the storage layout.
    images = (images / 255.0).view(-1, 3, IMG_SZ, IMG_SZ)
    labels = torch.tensor([float(sample[1]) for sample in samples])
    return images, labels


# load the balanced data array to train and test on
# NOTE(review): allow_pickle=True unpickles the file contents, which can
# execute arbitrary code - only load .npy files from a trusted source.
data = np.load("balanced_data.npy", allow_pickle=True)
extra_data = np.load("extra_data.npy", allow_pickle=True)

# format x and y the way torch wants to see them
x, y = _to_tensors(data)
# and the extra x and y
extra_x, extra_y = _to_tensors(extra_data)

train_x, test_x, train_y, test_y = train_test_split(
    x, y, test_size=0.1, random_state=SPLIT_SEED
)
(train_x.shape, train_y.shape), (test_x.shape, test_y.shape)

((torch.Size([514, 3, 100, 100]), torch.Size([514])),
 (torch.Size([58, 3, 100, 100]), torch.Size([58])))

In [8]:
def evaluate_accuracy(test_x, test_y, wfm):
    """Score a model sample by sample against the expected labels.

    Each sample is passed through ``wfm`` on its own (batch of one) with
    gradient tracking disabled; the output is rounded to the nearest class and
    compared with the label.

    Args:
        test_x: tensor of inputs, indexed along its first dimension.
        test_y: labels, one per entry of ``test_x``.
        wfm: callable mapping a batch of inputs to one output value per sample.

    Returns:
        (accuracy, correct, incorrect): accuracy rounded to 3 decimals, and the
        lists of sample indices classified correctly / incorrectly. For empty
        input the accuracy is reported as 0.0 and both lists are empty.
    """
    correct = []
    incorrect = []

    with torch.no_grad():
        for i in range(len(test_x)):
            # slice instead of index so the model still receives a batch dimension
            output = wfm(test_x[i:i + 1])[0]
            predicted_class = torch.round(output).item()
            if predicted_class == float(test_y[i]):
                correct.append(i)
            else:
                incorrect.append(i)

    total = len(correct) + len(incorrect)
    if total == 0:
        # nothing to score: avoid dividing by zero
        return (0.0, correct, incorrect)

    return (round(len(correct) / total, 3), correct, incorrect)

In [9]:
# fixed-capacity buffer that retains only the `size` most recent numeric insertions
class BoundedNumericList():
    """Buffer of at most ``size`` numbers; once full, the oldest entry is overwritten.

    Attributes are plain and public: ``size`` (capacity), ``nums`` (stored
    values) and ``next_insertion`` (count of accepted insertions so far).
    """

    def __init__(self, size):
        self.nums = []
        self.next_insertion = 0
        self.size = size

    def insert(self, item):
        """Store ``item`` if it is an int, float or complex (bools are refused).

        Returns True when the value was stored, False when it was rejected.
        """
        is_number = isinstance(item, (int, float, complex)) and not isinstance(item, bool)
        if not is_number:
            return False

        if len(self.nums) >= self.size:
            # buffer is full: overwrite the slot holding the oldest value
            self.nums[self.next_insertion % self.size] = item
        else:
            self.nums.append(item)

        self.next_insertion += 1
        return True

    def average(self):
        """Return the mean of the stored values, or None when nothing is stored."""
        count = len(self.nums)
        return sum(self.nums) / count if count else None

In [10]:
# Training cell: only runs when TRAIN_NEW_MODEL is set. Trains a fresh
# WildfireModel with mini-batches and writes a checkpoint to "wfm.pt" every
# time the rolling accuracy reaches a new maximum.
if TRAIN_NEW_MODEL:

    # instantiate our model, initial optimizer, and loss function
    wfm = WildfireModel()
    optimizer = optim.AdamW(wfm.parameters(), lr=1e-4)
    loss_function = nn.BCELoss()

    # declare constants controlling the training process
    BATCH_SIZE = 100
    EPOCHS = 1000
    ROLLING_ACCURACY_SIZE = 4

    # instantiate our bounded buffer to keep track of the last ROLLING_ACCURACY_SIZE epoch accuracies
    past_epoch_accuracies = BoundedNumericList(ROLLING_ACCURACY_SIZE)
    # start below any possible accuracy so the first epoch always saves a checkpoint
    highest_rolling_accuracy = -1

    for epoch in range(EPOCHS): # loop over all of our data EPOCHS times
        # batches are consecutive slices of train_x, visited in the same order every epoch
        for i in range(0, len(train_x), BATCH_SIZE): # iterate over our batches
            # grab the ith batch
            batch_x = train_x[i : i+BATCH_SIZE]
            batch_y = train_y[i : i+BATCH_SIZE]
            # add a trailing dimension so the labels line up with the model's one-value-per-sample output
            batch_y = torch.unsqueeze(batch_y, 1)

            # zero our gradient
            wfm.zero_grad()

            # pass the batch through the model
            outputs = wfm(batch_x)
            
            # compute the loss between the outputs and the expected
            loss = loss_function(outputs, batch_y)
            
            # update the model's weights
            loss.backward()
            optimizer.step()

        # calculate the most recent epoch's accuracy
        # NOTE(review): this scores on test_x/test_y, the same split the final
        # evaluation cell reports on, so checkpoint selection sees the test data
        epoch_accuracy,_,_ = evaluate_accuracy(test_x, test_y, wfm)
        
        # add it to the list of past accuracies
        past_epoch_accuracies.insert(epoch_accuracy)
        
        # calculate the rolling average
        rolling_accuracy = past_epoch_accuracies.average()

        # `loss` here is the loss of the last batch of the epoch, not an epoch average
        print(f"Epoch: {epoch}. Loss: {loss}. Rolling Accuracy: {round(rolling_accuracy,3)}")

        # save the model if the most recent updates have been beneficial
        if rolling_accuracy > highest_rolling_accuracy:
            print(f"Saving model at epoch {epoch}.")
            # the whole model object is saved here, not just its state_dict
            torch.save(wfm, "wfm.pt")
            highest_rolling_accuracy = rolling_accuracy



Epoch: 0. Loss: 0.693611204624176. Rolling Accuracy: 0.534
Saving model at epoch 0.


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 1. Loss: 0.6936047673225403. Rolling Accuracy: 0.534
Epoch: 2. Loss: 0.693601131439209. Rolling Accuracy: 0.534
Epoch: 3. Loss: 0.6935982704162598. Rolling Accuracy: 0.534
Epoch: 4. Loss: 0.6935939788818359. Rolling Accuracy: 0.534
Epoch: 5. Loss: 0.6935860514640808. Rolling Accuracy: 0.534
Epoch: 6. Loss: 0.6935043334960938. Rolling Accuracy: 0.534
Epoch: 7. Loss: 0.6934460997581482. Rolling Accuracy: 0.534
Epoch: 8. Loss: 0.6923684477806091. Rolling Accuracy: 0.564
Saving model at epoch 8.
Epoch: 9. Loss: 0.6572192311286926. Rolling Accuracy: 0.577
Saving model at epoch 9.
Epoch: 10. Loss: 0.6141074299812317. Rolling Accuracy: 0.595
Saving model at epoch 10.
Epoch: 11. Loss: 0.5698967576026917. Rolling Accuracy: 0.612
Saving model at epoch 11.
Epoch: 12. Loss: 0.5315399765968323. Rolling Accuracy: 0.607
Epoch: 13. Loss: 0.49154940247535706. Rolling Accuracy: 0.642
Saving model at epoch 13.
Epoch: 14. Loss: 0.4753129780292511. Rolling Accuracy: 0.664
Saving model at epoch 14.
E

Epoch: 124. Loss: 0.12733305990695953. Rolling Accuracy: 0.784
Epoch: 125. Loss: 0.12367036193609238. Rolling Accuracy: 0.788
Epoch: 126. Loss: 0.12471633404493332. Rolling Accuracy: 0.793
Epoch: 127. Loss: 0.12367693334817886. Rolling Accuracy: 0.823
Epoch: 128. Loss: 0.11546500772237778. Rolling Accuracy: 0.784
Epoch: 129. Loss: 0.10786253213882446. Rolling Accuracy: 0.78
Epoch: 130. Loss: 0.10478372871875763. Rolling Accuracy: 0.763
Epoch: 131. Loss: 0.09846310317516327. Rolling Accuracy: 0.737
Epoch: 132. Loss: 0.09582126885652542. Rolling Accuracy: 0.75
Epoch: 133. Loss: 0.09249485284090042. Rolling Accuracy: 0.763
Epoch: 134. Loss: 0.20422758162021637. Rolling Accuracy: 0.754
Epoch: 135. Loss: 0.10763001441955566. Rolling Accuracy: 0.78
Epoch: 136. Loss: 0.10526105016469955. Rolling Accuracy: 0.797
Epoch: 137. Loss: 0.09014960378408432. Rolling Accuracy: 0.789
Epoch: 138. Loss: 0.09227077662944794. Rolling Accuracy: 0.797
Epoch: 139. Loss: 0.10138881951570511. Rolling Accuracy: 0

Epoch: 253. Loss: 0.03802249953150749. Rolling Accuracy: 0.828
Epoch: 254. Loss: 0.04003690928220749. Rolling Accuracy: 0.823
Epoch: 255. Loss: 0.05978149175643921. Rolling Accuracy: 0.841
Epoch: 256. Loss: 0.3902130722999573. Rolling Accuracy: 0.819
Epoch: 257. Loss: 0.07612423598766327. Rolling Accuracy: 0.823
Epoch: 258. Loss: 0.05957796424627304. Rolling Accuracy: 0.806
Epoch: 259. Loss: 0.0499989315867424. Rolling Accuracy: 0.797
Epoch: 260. Loss: 0.05417848378419876. Rolling Accuracy: 0.845
Epoch: 261. Loss: 0.035605404525995255. Rolling Accuracy: 0.845
Epoch: 262. Loss: 0.03921739012002945. Rolling Accuracy: 0.866
Epoch: 263. Loss: 0.03606672212481499. Rolling Accuracy: 0.879
Epoch: 264. Loss: 0.03756803274154663. Rolling Accuracy: 0.875
Epoch: 265. Loss: 0.03382088243961334. Rolling Accuracy: 0.879
Epoch: 266. Loss: 0.04644882678985596. Rolling Accuracy: 0.87
Epoch: 267. Loss: 0.06394635885953903. Rolling Accuracy: 0.853
Epoch: 268. Loss: 0.04874451085925102. Rolling Accuracy: 

Epoch: 381. Loss: 0.04581793397665024. Rolling Accuracy: 0.871
Epoch: 382. Loss: 0.030584771186113358. Rolling Accuracy: 0.871
Epoch: 383. Loss: 0.0230748001486063. Rolling Accuracy: 0.884
Epoch: 384. Loss: 0.036182235926389694. Rolling Accuracy: 0.884
Epoch: 385. Loss: 0.018608959391713142. Rolling Accuracy: 0.875
Epoch: 386. Loss: 0.023121122270822525. Rolling Accuracy: 0.879
Epoch: 387. Loss: 0.021892575547099113. Rolling Accuracy: 0.866
Epoch: 388. Loss: 0.024748992174863815. Rolling Accuracy: 0.871
Epoch: 389. Loss: 0.01962689496576786. Rolling Accuracy: 0.871
Epoch: 390. Loss: 0.021180594339966774. Rolling Accuracy: 0.879
Epoch: 391. Loss: 0.02124510332942009. Rolling Accuracy: 0.893
Epoch: 392. Loss: 0.024761736392974854. Rolling Accuracy: 0.888
Epoch: 393. Loss: 0.02269885316491127. Rolling Accuracy: 0.901
Epoch: 394. Loss: 0.020775992423295975. Rolling Accuracy: 0.897
Epoch: 395. Loss: 0.022947072982788086. Rolling Accuracy: 0.901
Saving model at epoch 395.
Epoch: 396. Loss: 0

Epoch: 508. Loss: 0.009857363067567348. Rolling Accuracy: 0.888
Epoch: 509. Loss: 0.006505575031042099. Rolling Accuracy: 0.888
Epoch: 510. Loss: 0.008207350037992. Rolling Accuracy: 0.888
Epoch: 511. Loss: 0.009927055798470974. Rolling Accuracy: 0.884
Epoch: 512. Loss: 0.008205131627619267. Rolling Accuracy: 0.892
Epoch: 513. Loss: 0.006350950337946415. Rolling Accuracy: 0.892
Epoch: 514. Loss: 0.007535705808550119. Rolling Accuracy: 0.892
Epoch: 515. Loss: 0.010248412378132343. Rolling Accuracy: 0.892
Epoch: 516. Loss: 0.006442480720579624. Rolling Accuracy: 0.884
Epoch: 517. Loss: 0.006856301333755255. Rolling Accuracy: 0.888
Epoch: 518. Loss: 0.00923363771289587. Rolling Accuracy: 0.892
Epoch: 519. Loss: 0.00902984756976366. Rolling Accuracy: 0.892
Epoch: 520. Loss: 0.00747970724478364. Rolling Accuracy: 0.892
Epoch: 521. Loss: 0.006776456255465746. Rolling Accuracy: 0.897
Epoch: 522. Loss: 0.006859208457171917. Rolling Accuracy: 0.892
Epoch: 523. Loss: 0.006673215422779322. Rollin

Epoch: 637. Loss: 0.002760716248303652. Rolling Accuracy: 0.897
Epoch: 638. Loss: 0.0026690897066146135. Rolling Accuracy: 0.897
Epoch: 639. Loss: 0.00256914971396327. Rolling Accuracy: 0.897
Epoch: 640. Loss: 0.0024015307426452637. Rolling Accuracy: 0.888
Epoch: 641. Loss: 0.0035642858128994703. Rolling Accuracy: 0.884
Epoch: 642. Loss: 0.010062525048851967. Rolling Accuracy: 0.892
Epoch: 643. Loss: 0.008472559042274952. Rolling Accuracy: 0.888
Epoch: 644. Loss: 0.007695671170949936. Rolling Accuracy: 0.896
Epoch: 645. Loss: 0.008950620889663696. Rolling Accuracy: 0.91
Epoch: 646. Loss: 0.008067918941378593. Rolling Accuracy: 0.905
Epoch: 647. Loss: 0.006543498486280441. Rolling Accuracy: 0.914
Saving model at epoch 647.
Epoch: 648. Loss: 0.00457717152312398. Rolling Accuracy: 0.918
Saving model at epoch 648.
Epoch: 649. Loss: 0.004047691822052002. Rolling Accuracy: 0.914
Epoch: 650. Loss: 0.003933355212211609. Rolling Accuracy: 0.91
Epoch: 651. Loss: 0.003404794028028846. Rolling Acc

Epoch: 764. Loss: 0.0013345389161258936. Rolling Accuracy: 0.914
Epoch: 765. Loss: 0.0012117968872189522. Rolling Accuracy: 0.918
Epoch: 766. Loss: 0.000891120929736644. Rolling Accuracy: 0.918
Epoch: 767. Loss: 0.001162125961855054. Rolling Accuracy: 0.918
Epoch: 768. Loss: 0.0015862485161051154. Rolling Accuracy: 0.914
Epoch: 769. Loss: 0.0016406329814344645. Rolling Accuracy: 0.906
Epoch: 770. Loss: 0.0015745424898341298. Rolling Accuracy: 0.91
Epoch: 771. Loss: 0.0012118435697630048. Rolling Accuracy: 0.91
Epoch: 772. Loss: 0.0008953322540037334. Rolling Accuracy: 0.914
Epoch: 773. Loss: 0.0018783779814839363. Rolling Accuracy: 0.923
Epoch: 774. Loss: 0.0018575765425339341. Rolling Accuracy: 0.918
Epoch: 775. Loss: 0.001374130486510694. Rolling Accuracy: 0.914
Epoch: 776. Loss: 0.0014806464314460754. Rolling Accuracy: 0.91
Epoch: 777. Loss: 0.00185579271055758. Rolling Accuracy: 0.901
Epoch: 778. Loss: 0.001892650849185884. Rolling Accuracy: 0.897
Epoch: 779. Loss: 0.00172662921249

Epoch: 892. Loss: 0.005475573241710663. Rolling Accuracy: 0.897
Epoch: 893. Loss: 0.005630459636449814. Rolling Accuracy: 0.897
Epoch: 894. Loss: 0.0036775260232388973. Rolling Accuracy: 0.897
Epoch: 895. Loss: 0.0024385633878409863. Rolling Accuracy: 0.897
Epoch: 896. Loss: 0.0019204235868528485. Rolling Accuracy: 0.901
Epoch: 897. Loss: 0.0013921783538535237. Rolling Accuracy: 0.905
Epoch: 898. Loss: 0.0009997597662732005. Rolling Accuracy: 0.91
Epoch: 899. Loss: 0.0013093978632241488. Rolling Accuracy: 0.91
Epoch: 900. Loss: 0.002462293254211545. Rolling Accuracy: 0.905
Epoch: 901. Loss: 0.0026146459858864546. Rolling Accuracy: 0.901
Epoch: 902. Loss: 0.0020729186944663525. Rolling Accuracy: 0.897
Epoch: 903. Loss: 0.0016028985846787691. Rolling Accuracy: 0.897
Epoch: 904. Loss: 0.0011649971129372716. Rolling Accuracy: 0.897
Epoch: 905. Loss: 0.0008536564419046044. Rolling Accuracy: 0.897
Epoch: 906. Loss: 0.0006923811160959303. Rolling Accuracy: 0.897
Epoch: 907. Loss: 0.0005823411

In [11]:
# Reload the checkpoint written by the training cell.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code - only load checkpoints from a trusted source. Newer PyTorch releases
# may also need an explicit weights_only=False to load a fully pickled model;
# confirm against the installed version.
wfm = torch.load('wfm.pt')

# ---- held-out test split ----
acc, correct, incorrect = evaluate_accuracy(test_x, test_y, wfm)
# misclassified samples whose true label is 1: fires reported as no fire
incorrect_fire = int(test_y[incorrect].sum())
# every other misclassified sample: no fire reported as fire
incorrect_no_fire = len(incorrect) - incorrect_fire

print(
    f"Model accuracy on test data: {acc}",
    f"{incorrect_fire} 'fires' identified as 'no fires' on test data.",
    f"{incorrect_no_fire} 'no fires' identified as 'fires' on test data.",
    sep="\n",
)

# ---- extra data set ----
extra_acc, extra_correct, extra_incorrect = evaluate_accuracy(extra_x, extra_y, wfm)
# same breakdown as above, for the extra samples
extra_incorrect_fire = int(extra_y[extra_incorrect].sum())
extra_incorrect_no_fire = len(extra_incorrect) - extra_incorrect_fire

print(
    f"Model accuracy on extra data: {extra_acc}",
    f"{extra_incorrect_fire} 'fires' identified as 'no fires' on extra data.",
    f"{extra_incorrect_no_fire} 'no fires' identified as 'fires' on extra data.",
    sep="\n",
)

Model accuracy on test data: 0.948
2 'fires' identified as 'no fires' on test data.
1 'no fires' identified as 'fires' on test data.
Model accuracy on extra data: 0.899
0 'fires' identified as 'no fires' on extra data.
42 'no fires' identified as 'fires' on extra data.
