In [1]:
import os
from skimage import io, transform
import numpy as np
from tqdm import tqdm
from model import FireNet
from data_augmentation import prepare_dataset

import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

from torch.utils.tensorboard import SummaryWriter

In [None]:
# Fetch / prepare the raw image data via the project helper data_augmentation.prepare_dataset().
# NOTE(review): the saved output shows a ~510 MB download that stops partway and the cell has no
# execution count ("In [None]"), so it was apparently interrupted in this session. Later cells
# read 'training_dataset/' and 'testing_dataset/' from disk — confirm whether prepare_dataset()
# skips the download when those folders already exist before re-running this cell.
prepare_dataset()

DOWNLOADING TRAINING DATASET

Downloading: 0% [0 / 510779994] bytes
Downloading: 0% [8192 / 510779994] bytes
Downloading: 0% [16384 / 510779994] bytes
Downloading: 0% [24576 / 510779994] bytes
Downloading: 0% [32768 / 510779994] bytes
Downloading: 0% [40960 / 510779994] bytes
Downloading: 0% [49152 / 510779994] bytes
Downloading: 0% [57344 / 510779994] bytes
Downloading: 0% [65536 / 510779994] bytes
Downloading: 0% [73728 / 510779994] bytes
Downloading: 0% [81920 / 510779994] bytes
Downloading: 0% [90112 / 510779994] bytes
Downloading: 0% [98304 / 510779994] bytes
Downloading: 0% [106496 / 510779994] bytes
Downloading: 0% [114688 / 510779994] bytes
Downloading: 0% [122880 / 510779994] bytes
Downloading: 0% [131072 / 510779994] bytes
Downloading: 0% [139264 / 510779994] bytes
Downloading: 0% [147456 / 510779994] bytes
Downloading: 0% [155648 / 510779994] bytes
Downloading: 0% [163840 / 510779994] bytes
Downloading: 0% [172032 / 510779994] bytes
Downloading: 0% [180224 / 510779994] bytes

In [2]:
# TensorBoard event files for this run are written under runs/firenet_experiment_1.
writer = SummaryWriter(log_dir='runs/firenet_experiment_1')

In [3]:
# Seed torch's RNG before the model is built so that weight initialisation, dropout masks and
# DataLoader shuffling repeat across Restart & Run All. torch.manual_seed seeds every device;
# some cuDNN convolution kernels may still be non-deterministic.
SEED = 0
torch.manual_seed(SEED)

# Train on the GPU when one is visible, otherwise fall back to the CPU.
print(torch.cuda.is_available())
net = FireNet()
device = "cuda" if torch.cuda.is_available() else "cpu"
net.to(device)

True


FireNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (drop1): Dropout(p=0.5, inplace=False)
  (drop2): Dropout(p=0.2, inplace=False)
  (dense1): Linear(in_features=2304, out_features=256, bias=True)
  (dense2): Linear(in_features=256, out_features=128, bias=True)
  (dense3): Linear(in_features=128, out_features=2, bias=True)
)

In [4]:
# Root folder of the training images; it holds one sub-folder per class.
TRAINING_PATH = "training_dataset"
# Class names. A sample's label is the index of its folder name in this list (Fire -> 0, NoFire -> 1).
CATEGORIES = ["Fire", "NoFire"]

In [5]:
class TrainingSet(Dataset):
    """In-memory image dataset read from ``<root>/<category>/`` folders.

    Every file in each ``CATEGORIES`` sub-folder is decoded with ``io.imread``.
    Only images of shape H x W x 3 are kept; grayscale (2-D) images, images
    with a different channel count, and files that fail to decode are skipped.
    Skipped paths are recorded in ``self.skipped`` instead of being silently
    dropped. The optional transform runs once, at load time.

    Args:
        transform: callable applied to each H x W x 3 image (e.g.
            ``transforms.ToTensor()``). When ``None``, the decoded array is stored
            unchanged. Previously, passing ``None`` produced an empty dataset.
        root: dataset root directory; defaults to ``TRAINING_PATH``.

    Attributes:
        set: dict with parallel lists ``'image'`` and ``'class'``. The class is the
            index of the sample's folder in ``CATEGORIES``.
        skipped: paths of files that were not loaded.
    """

    def __init__(self, transform=None, root=None):
        self.transform = transform
        root = TRAINING_PATH if root is None else root
        classes = []
        images = []
        self.skipped = []
        for class_num, category in enumerate(CATEGORIES):
            path = os.path.join(root, category)
            # sorted() makes sample order independent of the filesystem's listing order.
            for img in tqdm(sorted(os.listdir(path))):
                img_path = os.path.join(path, img)
                try:
                    image = io.imread(img_path)
                except Exception:
                    # Best effort: non-image / corrupt files are skipped, but remembered.
                    self.skipped.append(img_path)
                    continue
                # Keep 3-channel images only (grayscale has ndim == 2, RGBA has 4 channels).
                if image.ndim != 3 or image.shape[2] != 3:
                    self.skipped.append(img_path)
                    continue
                if self.transform is not None:
                    image = self.transform(image)
                images.append(image)
                classes.append(class_num)
        if self.skipped:
            print('skipped %d unreadable or non-RGB files' % len(self.skipped))
        self.set = {'image': images, 'class': classes}

    def __len__(self):
        """Number of loaded samples."""
        return len(self.set['class'])

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'class': int}`` for sample ``idx``."""
        return {'image': self.set['image'][idx], 'class': self.set['class'][idx]}
            
        

In [6]:
# Decode every training image and convert it to a tensor once, up front.
to_tensor = transforms.ToTensor()
training_set = TrainingSet(transform=to_tensor)
print(len(training_set))

100%|██████████| 4444/4444 [00:01<00:00, 2573.05it/s]
100%|██████████| 5172/5172 [00:01<00:00, 2672.82it/s]

9616





In [7]:
class TestingSet(Dataset):
    """In-memory test-image dataset read from ``<root>/<category>/`` folders.

    Every file in each ``CATEGORIES`` sub-folder is decoded with ``io.imread``.
    Only images of shape H x W x 3 are kept; grayscale (2-D) images, images
    with a different channel count, and files that fail to decode are skipped.
    Skipped paths are recorded in ``self.skipped``. The optional transform runs
    once, at load time.

    Args:
        transform: callable applied to each H x W x 3 image (e.g.
            ``transforms.ToTensor()``). When ``None``, the decoded array is stored
            unchanged. Previously, passing ``None`` produced an empty dataset.
        root: dataset root directory; defaults to ``'testing_dataset'``, the path
            that was previously hard-coded.

    Attributes:
        set: dict with parallel lists ``'image'`` and ``'class'``. The class is the
            index of the sample's folder in ``CATEGORIES``.
        skipped: paths of files that were not loaded.
    """

    def __init__(self, transform=None, root='testing_dataset'):
        self.transform = transform
        classes = []
        images = []
        self.skipped = []
        for class_num, category in enumerate(CATEGORIES):
            path = os.path.join(root, category)
            # sorted() makes sample order independent of the filesystem's listing order.
            for img in tqdm(sorted(os.listdir(path))):
                img_path = os.path.join(path, img)
                try:
                    image = io.imread(img_path)
                except Exception:
                    # Best effort: non-image / corrupt files are skipped, but remembered.
                    self.skipped.append(img_path)
                    continue
                # Keep 3-channel images only (grayscale has ndim == 2, RGBA has 4 channels).
                if image.ndim != 3 or image.shape[2] != 3:
                    self.skipped.append(img_path)
                    continue
                if self.transform is not None:
                    image = self.transform(image)
                images.append(image)
                classes.append(class_num)
        if self.skipped:
            print('skipped %d unreadable or non-RGB files' % len(self.skipped))
        self.set = {'image': images, 'class': classes}

    def __len__(self):
        """Number of loaded samples."""
        return len(self.set['class'])

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'class': int}`` for sample ``idx``."""
        return {'image': self.set['image'][idx], 'class': self.set['class'][idx]}
            
        

In [8]:
# Decode every test image and convert it to a tensor once, up front.
to_tensor = transforms.ToTensor()
test_set = TestingSet(transform=to_tensor)
print(len(test_set))

100%|██████████| 400/400 [00:00<00:00, 2595.77it/s]
100%|██████████| 278/278 [00:00<00:00, 2600.97it/s]

678





In [9]:
# FireNet's last layer emits 2 logits, which are scored against integer class labels.
criterion = nn.CrossEntropyLoss()
# lr=1e-3 is torch's Adam default, spelled out here. eps=1e-7 is presumably chosen to mirror
# the Keras Adam default of the original FireNet implementation — TODO confirm.
optimizer = optim.Adam(params=net.parameters(), lr=1e-3, eps=1e-7)

In [10]:
# Shuffled mini-batches of 32 samples, assembled by 4 worker processes.
trainloader = DataLoader(
    dataset=training_set,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)
# float() and train() both return the module itself, so no rebinding is needed.
# train() switches the dropout layers on.
net.float()
net.train()

FireNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (drop1): Dropout(p=0.5, inplace=False)
  (drop2): Dropout(p=0.2, inplace=False)
  (dense1): Linear(in_features=2304, out_features=256, bias=True)
  (dense2): Linear(in_features=256, out_features=128, bias=True)
  (dense3): Linear(in_features=128, out_features=2, bias=True)
)

In [11]:
# Number of mini-batches per epoch: 9616 samples / batch_size 32, rounded up -> 301.
len(trainloader)

301

In [12]:
NUM_EPOCHS = 100

# Re-assert training mode inside this cell. The evaluation cells below call net.eval(), so
# re-running this cell after them would otherwise train with dropout disabled.
net.train()
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    n_seen = 0
    for data in tqdm(trainloader):
        inputs, labels = data['image'].to(device), data['class'].to(device)

        optimizer.zero_grad()

        outputs = net(inputs.float())
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss() averages over the batch by default. Weighting by batch size keeps
        # the shorter final batch from being over-weighted in the epoch mean.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        n_seen += batch_size
    epoch_loss = running_loss / n_seen if n_seen else float('nan')
    print('global loss for epoch %d : %.3f' % (epoch + 1, epoch_loss))
    writer.add_scalar('Training loss 100', epoch_loss, epoch + 1)
# SummaryWriter buffers events; flush so the whole curve is on disk even if the kernel dies.
writer.flush()

100%|██████████| 301/301 [00:02<00:00, 144.17it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 1 : 0.515


100%|██████████| 301/301 [00:02<00:00, 145.98it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 2 : 0.432


100%|██████████| 301/301 [00:02<00:00, 142.84it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 3 : 0.379


100%|██████████| 301/301 [00:02<00:00, 144.66it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 4 : 0.329


100%|██████████| 301/301 [00:02<00:00, 144.85it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 5 : 0.293


100%|██████████| 301/301 [00:02<00:00, 145.33it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 6 : 0.257


100%|██████████| 301/301 [00:02<00:00, 144.38it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 7 : 0.242


100%|██████████| 301/301 [00:02<00:00, 144.34it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 8 : 0.214


100%|██████████| 301/301 [00:02<00:00, 142.79it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 9 : 0.188


100%|██████████| 301/301 [00:02<00:00, 145.19it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 10 : 0.182


100%|██████████| 301/301 [00:02<00:00, 146.89it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 11 : 0.165


100%|██████████| 301/301 [00:02<00:00, 144.45it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 12 : 0.148


100%|██████████| 301/301 [00:02<00:00, 145.33it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 13 : 0.145


100%|██████████| 301/301 [00:02<00:00, 144.52it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 14 : 0.119


100%|██████████| 301/301 [00:02<00:00, 143.30it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 15 : 0.126


100%|██████████| 301/301 [00:02<00:00, 144.65it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 16 : 0.110


100%|██████████| 301/301 [00:02<00:00, 144.27it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 17 : 0.100


100%|██████████| 301/301 [00:02<00:00, 143.81it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 18 : 0.088


100%|██████████| 301/301 [00:02<00:00, 142.91it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 19 : 0.090


100%|██████████| 301/301 [00:02<00:00, 140.42it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 20 : 0.085


100%|██████████| 301/301 [00:02<00:00, 147.99it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 21 : 0.072


100%|██████████| 301/301 [00:01<00:00, 153.19it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 22 : 0.077


100%|██████████| 301/301 [00:02<00:00, 147.64it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 23 : 0.073


100%|██████████| 301/301 [00:01<00:00, 155.41it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 24 : 0.064


100%|██████████| 301/301 [00:01<00:00, 157.58it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 25 : 0.059


100%|██████████| 301/301 [00:01<00:00, 157.34it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 26 : 0.067


100%|██████████| 301/301 [00:01<00:00, 156.43it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 27 : 0.056


100%|██████████| 301/301 [00:01<00:00, 158.78it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 28 : 0.054


100%|██████████| 301/301 [00:01<00:00, 157.30it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 29 : 0.064


100%|██████████| 301/301 [00:01<00:00, 157.96it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 30 : 0.052


100%|██████████| 301/301 [00:01<00:00, 155.24it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 31 : 0.045


100%|██████████| 301/301 [00:01<00:00, 155.78it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 32 : 0.048


100%|██████████| 301/301 [00:01<00:00, 156.15it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 33 : 0.059


100%|██████████| 301/301 [00:01<00:00, 155.12it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 34 : 0.046


100%|██████████| 301/301 [00:01<00:00, 157.22it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 35 : 0.043


100%|██████████| 301/301 [00:02<00:00, 146.62it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 36 : 0.041


100%|██████████| 301/301 [00:02<00:00, 148.89it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 37 : 0.042


100%|██████████| 301/301 [00:02<00:00, 149.54it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 38 : 0.043


100%|██████████| 301/301 [00:02<00:00, 149.61it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 39 : 0.043


100%|██████████| 301/301 [00:01<00:00, 155.66it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 40 : 0.042


100%|██████████| 301/301 [00:01<00:00, 151.42it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 41 : 0.039


100%|██████████| 301/301 [00:01<00:00, 153.79it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 42 : 0.042


100%|██████████| 301/301 [00:01<00:00, 154.86it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 43 : 0.045


100%|██████████| 301/301 [00:01<00:00, 157.37it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 44 : 0.048


100%|██████████| 301/301 [00:01<00:00, 151.63it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 45 : 0.036


100%|██████████| 301/301 [00:02<00:00, 146.06it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 46 : 0.031


100%|██████████| 301/301 [00:02<00:00, 145.58it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 47 : 0.037


100%|██████████| 301/301 [00:01<00:00, 154.04it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 48 : 0.035


100%|██████████| 301/301 [00:01<00:00, 155.72it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 49 : 0.035


100%|██████████| 301/301 [00:01<00:00, 157.41it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 50 : 0.037


100%|██████████| 301/301 [00:01<00:00, 156.81it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 51 : 0.044


100%|██████████| 301/301 [00:01<00:00, 156.36it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 52 : 0.024


100%|██████████| 301/301 [00:01<00:00, 154.91it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 53 : 0.028


100%|██████████| 301/301 [00:01<00:00, 154.75it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 54 : 0.035


100%|██████████| 301/301 [00:01<00:00, 158.23it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 55 : 0.035


100%|██████████| 301/301 [00:01<00:00, 160.51it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 56 : 0.024


100%|██████████| 301/301 [00:01<00:00, 159.21it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 57 : 0.028


100%|██████████| 301/301 [00:01<00:00, 157.94it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 58 : 0.030


100%|██████████| 301/301 [00:01<00:00, 157.97it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 59 : 0.026


100%|██████████| 301/301 [00:01<00:00, 155.69it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 60 : 0.028


100%|██████████| 301/301 [00:01<00:00, 154.74it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 61 : 0.036


100%|██████████| 301/301 [00:02<00:00, 144.67it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 62 : 0.032


100%|██████████| 301/301 [00:01<00:00, 153.07it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 63 : 0.034


100%|██████████| 301/301 [00:01<00:00, 152.73it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 64 : 0.028


100%|██████████| 301/301 [00:02<00:00, 149.15it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 65 : 0.024


100%|██████████| 301/301 [00:01<00:00, 154.46it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 66 : 0.027


100%|██████████| 301/301 [00:01<00:00, 155.04it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 67 : 0.022


100%|██████████| 301/301 [00:01<00:00, 155.86it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 68 : 0.024


100%|██████████| 301/301 [00:01<00:00, 155.35it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 69 : 0.023


100%|██████████| 301/301 [00:02<00:00, 148.20it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 70 : 0.021


100%|██████████| 301/301 [00:02<00:00, 148.52it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 71 : 0.034


100%|██████████| 301/301 [00:01<00:00, 156.69it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 72 : 0.026


100%|██████████| 301/301 [00:01<00:00, 156.24it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 73 : 0.024


100%|██████████| 301/301 [00:01<00:00, 153.39it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 74 : 0.022


100%|██████████| 301/301 [00:02<00:00, 132.70it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 75 : 0.020


100%|██████████| 301/301 [00:02<00:00, 146.65it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 76 : 0.023


100%|██████████| 301/301 [00:01<00:00, 157.32it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 77 : 0.021


100%|██████████| 301/301 [00:01<00:00, 157.91it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 78 : 0.026


100%|██████████| 301/301 [00:01<00:00, 160.73it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 79 : 0.019


100%|██████████| 301/301 [00:01<00:00, 154.35it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 80 : 0.026


100%|██████████| 301/301 [00:01<00:00, 157.10it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 81 : 0.027


100%|██████████| 301/301 [00:01<00:00, 155.82it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 82 : 0.027


100%|██████████| 301/301 [00:01<00:00, 156.08it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 83 : 0.018


100%|██████████| 301/301 [00:01<00:00, 157.66it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 84 : 0.020


100%|██████████| 301/301 [00:01<00:00, 154.43it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 85 : 0.018


100%|██████████| 301/301 [00:02<00:00, 148.32it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 86 : 0.028


100%|██████████| 301/301 [00:01<00:00, 156.30it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 87 : 0.029


100%|██████████| 301/301 [00:01<00:00, 156.10it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 88 : 0.021


100%|██████████| 301/301 [00:02<00:00, 149.60it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 89 : 0.013


100%|██████████| 301/301 [00:01<00:00, 151.35it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 90 : 0.022


100%|██████████| 301/301 [00:01<00:00, 157.85it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 91 : 0.018


100%|██████████| 301/301 [00:01<00:00, 153.55it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 92 : 0.023


100%|██████████| 301/301 [00:02<00:00, 149.53it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 93 : 0.017


100%|██████████| 301/301 [00:01<00:00, 155.74it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 94 : 0.014


100%|██████████| 301/301 [00:01<00:00, 157.69it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 95 : 0.028


100%|██████████| 301/301 [00:01<00:00, 158.96it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 96 : 0.023


100%|██████████| 301/301 [00:01<00:00, 156.06it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 97 : 0.019


100%|██████████| 301/301 [00:01<00:00, 157.18it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 98 : 0.020


100%|██████████| 301/301 [00:01<00:00, 154.53it/s]
  0%|          | 0/301 [00:00<?, ?it/s]

global loss for epoch 99 : 0.013


100%|██████████| 301/301 [00:01<00:00, 154.50it/s]

global loss for epoch 100 : 0.023





In [13]:
# Persist only the learned parameters (the state_dict), not the module object itself.
torch.save(obj=net.state_dict(), f='./trained_weights100.pth')

In [14]:
# Rebuild the network from the checkpoint saved above, so the evaluation below exercises
# exactly the persisted weights.
net = FireNet()
net.float()
# Use the device chosen at setup instead of hard-coding .cuda(). map_location remaps
# CUDA-saved tensors so the checkpoint also loads on CPU-only machines.
net.to(device)
# torch.load unpickles the file: only load checkpoints you trust.
net.load_state_dict(torch.load('./trained_weights100.pth', map_location=device))
# NOTE: `optimizer` still references the parameters of the previous `net` object.
# Re-create it before training this instance any further.
net.eval()

FireNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (drop1): Dropout(p=0.5, inplace=False)
  (drop2): Dropout(p=0.2, inplace=False)
  (dense1): Linear(in_features=2304, out_features=256, bias=True)
  (dense2): Linear(in_features=256, out_features=128, bias=True)
  (dense3): Linear(in_features=128, out_features=2, bias=True)
)

In [15]:
# Walk the training images in a fixed order to measure training accuracy.
# NOTE: this rebinds `trainloader`. Re-run the training DataLoader cell before training again,
# otherwise training would use unshuffled batches of 4.
trainloader = DataLoader(
    dataset=training_set,
    batch_size=4,
    shuffle=False,
    num_workers=0,
)
net.float()
net.eval()

FireNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (drop1): Dropout(p=0.5, inplace=False)
  (drop2): Dropout(p=0.2, inplace=False)
  (dense1): Linear(in_features=2304, out_features=256, bias=True)
  (dense2): Linear(in_features=256, out_features=128, bias=True)
  (dense3): Linear(in_features=128, out_features=2, bias=True)
)

In [16]:
# Accuracy of the network on the (unshuffled) training images.
net.eval()  # dropout off, even if this cell is re-run straight after training
correct = 0
total = 0
with torch.no_grad():
    print('evaluate accuracy on training set:')
    for data in tqdm(trainloader):
        images, labels = data['image'].to(device), data['class'].to(device)
        outputs = net(images.float())
        predicted = outputs.argmax(dim=1)  # index of the larger of the 2 logits
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
# '%d' truncated the percentage (e.g. 99.9 -> 99); report two decimals and avoid 0/0.
accuracy = 100 * correct / total if total else float('nan')
print('accuracy: %.2f' % accuracy)

  6%|▌         | 133/2404 [00:00<00:01, 1328.40it/s]

evaluate accuracy on training set:


100%|██████████| 2404/2404 [00:01<00:00, 1342.64it/s]

accuracy: 100





In [17]:
# Walk the held-out test images in a fixed order, 4 at a time, in the main process.
testloader = DataLoader(
    dataset=test_set,
    batch_size=4,
    shuffle=False,
    num_workers=0,
)
net.float()
net.eval()

FireNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (drop1): Dropout(p=0.5, inplace=False)
  (drop2): Dropout(p=0.2, inplace=False)
  (dense1): Linear(in_features=2304, out_features=256, bias=True)
  (dense2): Linear(in_features=256, out_features=128, bias=True)
  (dense3): Linear(in_features=128, out_features=2, bias=True)
)

In [18]:
# Number of evaluation batches: 678 test samples / batch_size 4, rounded up -> 170.
len(testloader)

170

In [19]:
# Accuracy of the network on the held-out test images.
net.eval()  # dropout off, even if this cell is re-run straight after training
correct = 0
total = 0
with torch.no_grad():
    # The message previously said "training set" although this loop reads testloader.
    print('evaluate accuracy on test set:')
    for data in tqdm(testloader):
        images, labels = data['image'].to(device), data['class'].to(device)
        outputs = net(images.float())
        predicted = outputs.argmax(dim=1)  # index of the larger of the 2 logits
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
# '%d' truncated the percentage (e.g. 81.9 -> 81); report two decimals and avoid 0/0.
accuracy = 100 * correct / total if total else float('nan')
print('accuracy: %.2f' % accuracy)

100%|██████████| 170/170 [00:00<00:00, 1256.80it/s]

evaluate accuracy on training set:
accuracy: 81



