In [None]:
import copy
import json
import math
import os
import sys
import time
from random import randrange
from typing import List
from typing import Tuple


import numpy as np
import cv2
from matplotlib import pyplot as plt
import torch
from torch import nn, optim
import torch.nn.functional as F


In [None]:
# Tensor work runs on the GPU when torch can see one, otherwise on the CPU.
# `cpu_device` is used to bring results back for numpy / matplotlib.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
cpu_device = torch.device('cpu')
device  # device to use for calculations

# Data file paths and dataset definitions

In [None]:
# All paths are relative to the current working directory.
PROJECT_DIR_PATH = '.'
DATA_DIR_PATH, LABELS_DIR = (os.path.join(PROJECT_DIR_PATH, sub_dir) for sub_dir in ('data', 'labels'))

In [None]:
# Directories with data to be used for each dataset.
# Every split is made of whole data directories (not individual frames), so nearly identical
# neighbouring frames cannot end up in two different splits.
# Each `[...]` chunk corresponds to one blank-line-separated group of the original listing.

TRAINING_DIRS_1 = (
    ['006__11_44_59', '007__11_48_59', '008__11_52_59', '009__11_57_00']
    + ['000__14_15_19', '001__14_19_19', '002__14_23_19', '003__14_27_20', '004__14_31_20']
    + ['012__15_03_21', '013__15_07_21', '014__15_11_21', '015__15_15_21', '016__15_19_21']
    + ['011__13_38_20', '012__13_42_20', '013__13_46_21']
    + ['007__13_22_20']
)

VALIDATION_DIRS_1 = (
    ['004__13_10_20']
    + ['014__13_50_21']
    + ['005__14_35_20', '006__14_39_20', '007__14_43_20', '008__14_47_20']
)

TEST_DIRS_1 = (
    ['008__13_26_20']
    + ['009__14_51_20', '010__14_55_20', '011__14_59_20']
    + ['015__13_54_21']
)

# Miscellaneous definitions and data reading functions

(supporting code only: camera constants, the IR CSV reader, annotation loading and data-augmentation helpers)

In [None]:
# IR camera geometry: 32 columns x 24 rows.
IR_CAMERA_RESOLUTION_X, IR_CAMERA_RESOLUTION_Y = 32, 24
IR_CAMERA_RESOLUTION_XY = (IR_CAMERA_RESOLUTION_X, IR_CAMERA_RESOLUTION_Y)  # (width, height), for opencv functions
IR_CAMERA_RESOLUTION = IR_CAMERA_RESOLUTION_XY[::-1]  # (rows, columns), used to reshape numpy frames

# Temperatures mapped to 0 and 1 respectively when frames are normalized.
TEMPERATURE_NORMALIZATION__MIN, TEMPERATURE_NORMALIZATION__MAX = 20, 35

IR_FRAME_RESIZE_MULTIPLIER = 1

# Heatmap colour range; set either bound to None for auto range.
MIN_TEMPERATURE_ON_PLOT, MAX_TEMPERATURE_ON_PLOT = 20, 30
# Resize factor for RGB frames (not referenced by the cells shown in this notebook section).
RGB_FRAME_RESIZE_MULTIPLIER = 2
# OpenCV interpolation flag, presumably passed as `interpolation` to the heatmap renderer.
IR_FRAME_INTERPOLATION_METHOD = cv2.INTER_CUBIC

In [None]:
class IrDataCsvReader:
    """Reads raw IR frames from a CSV file.

    The first line of the file is skipped (header). Every following non-empty line is one
    frame: its first two comma-separated fields are skipped and the remaining fields are
    parsed as floats and reshaped to ``resolution`` (rows, columns).
    """

    def __init__(self, file_path, resolution=None):
        """
        :param file_path: path of the CSV file to read
        :param resolution: (rows, columns) of one frame; None means IR_CAMERA_RESOLUTION
        """
        self._file_path = file_path
        self._frames, self._raw_frame_data = self.get_frames_from_file(file_path, resolution)

    @staticmethod
    def get_frames_from_file(file_path, resolution=None):
        """Parse the file and return ``(frames, raw_frame_data)``.

        ``frames`` is a list of 2-D float arrays; ``raw_frame_data`` holds the matching
        stripped CSV lines. Blank lines (e.g. an empty trailing line) are ignored instead of
        failing in ``np.reshape``.
        """
        if resolution is None:
            resolution = IR_CAMERA_RESOLUTION
        frames = []
        raw_frame_data = []
        with open(file_path, 'r') as file:
            next(file, None)  # skip the header line
            for raw_line in file:
                frame_line = raw_line.strip()
                if not frame_line:
                    continue
                raw_frame_data.append(frame_line)
                frame_data_1d = [float(x) for x in frame_line.split(',')[2:]]
                frames.append(np.reshape(frame_data_1d, resolution))
        return frames, raw_frame_data

    def get_number_of_frames(self):
        return len(self._frames)

    def get_frame(self, n):
        return self._frames[n]

    def get_raw_frame_data(self, n) -> str:
        return self._raw_frame_data[n]



class BatchTrainingData:
    """
    Stores training data for one batch

    For every frame it keeps the labelled centre points, the raw IR frame and a normalized
    copy in which ``min_temperature`` maps to 0 and ``max_temperature`` to 1 (no clipping).
    """
    def __init__(self, min_temperature=TEMPERATURE_NORMALIZATION__MIN, max_temperature=TEMPERATURE_NORMALIZATION__MAX):
        self.centre_points = []  # type: List[List[tuple]]  # per frame: list of (x, y) points
        self.raw_ir_data = []  # type: List[np.ndarray]
        self.normalized_ir_data = []  # type: List[np.ndarray]  # same data as raw_ir_data, but normalized

        self.min_temperature = min_temperature
        self.max_temperature = max_temperature

    def append_frame_data(self, centre_points, raw_ir_data):
        """Store one frame; ``centre_points`` is kept by reference, not copied."""
        self.centre_points.append(centre_points)
        self.raw_ir_data.append(raw_ir_data)
        scale = 1 / (self.max_temperature - self.min_temperature)
        self.normalized_ir_data.append((raw_ir_data - self.min_temperature) * scale)

    def _flip(self, axis, flip_point):
        """Flip each stored frame along numpy ``axis`` and map its centre points through ``flip_point``.

        The per-frame point lists are rewritten in place (slice assignment).
        """
        for i, raw_frame in enumerate(self.raw_ir_data):
            self.raw_ir_data[i] = np.flip(raw_frame, axis)
            self.normalized_ir_data[i] = np.flip(self.normalized_ir_data[i], axis)
            frame_points = self.centre_points[i]
            frame_points[:] = [flip_point(point) for point in frame_points]

    def flip_horizontally(self):
        """Mirror frames left-right (axis 1); x becomes ``IR_CAMERA_RESOLUTION_X - x``."""
        # NOTE(review): np.flip maps pixel index x to X - 1 - x, while labels use X - x (the same
        # convention as x_on_interpolated_image_to_raw_x) — confirm which one the labels follow.
        self._flip(1, lambda point: (IR_CAMERA_RESOLUTION_X - point[0], point[1]))

    def flip_vertically(self):
        """Mirror frames top-bottom (axis 0); y becomes ``IR_CAMERA_RESOLUTION_Y - y``."""
        self._flip(0, lambda point: (point[0], IR_CAMERA_RESOLUTION_Y - point[1]))


class AugmentedBatchesTrainingData:
    """
    Stores training data for all batches, with data augmentation
    """
    def __init__(self):
        self.batches = []  # type: List[BatchTrainingData]

    def add_training_batch(self, batch: 'BatchTrainingData', flip_and_rotate=True):
        """Store a deep copy of ``batch`` and, optionally, its flipped variants.

        With ``flip_and_rotate`` four batches are stored: plain, flipped horizontally,
        rotated by 180 degrees, and flipped vertically. The caller's ``batch`` is never
        modified; all flips are applied to a private copy.
        """
        self.batches.append(copy.deepcopy(batch))  # plain data
        if not flip_and_rotate:
            return

        working_copy = copy.deepcopy(batch)
        working_copy.flip_horizontally()
        self.batches.append(copy.deepcopy(working_copy))  # flipped horizontally

        working_copy.flip_vertically()
        self.batches.append(copy.deepcopy(working_copy))  # rotated 180 degrees

        working_copy.flip_horizontally()  # the two horizontal flips cancel out
        self.batches.append(working_copy)  # flipped vertically (private copy, no further copy needed)

    def print_stats(self):
        """Print the number of stored batches, total frames and frames per person count."""
        total_number_of_frames = 0
        number_of_frames_with_n_persons = {}
        for batch in self.batches:
            total_number_of_frames += len(batch.centre_points)
            for centre_points in batch.centre_points:
                number_of_persons = len(centre_points)
                number_of_frames_with_n_persons[number_of_persons] = \
                    number_of_frames_with_n_persons.get(number_of_persons, 0) + 1
        frames_persons_details_msg = '\n'.join([f'   {count} frames with {no} persons'
                                               for no, count in number_of_frames_with_n_persons.items()])
        msg = f"AugmentedBatchesTrainingData with {len(self.batches)} BatchTrainingData batches.\n" \
              f"Total number of frames after augmentation: {total_number_of_frames}, with:\n" \
              f"{frames_persons_details_msg}"
        print(msg)


def load_data_for_labeled_batches(labeled_batch_dirs) -> BatchTrainingData:
    """Collect the IR frames and accepted labels of all given batch directories.

    For each directory name the frames are read from ``DATA_DIR_PATH/<dir>/ir.csv`` and the
    labels from ``LABELS_DIR/<dir with '/' replaced by '--'>.csv`` (loaded without coordinate
    scaling/flipping). Frames whose annotation is not accepted are reported and skipped.

    :return: one BatchTrainingData holding the frames of every directory, in order
    """
    training_data = BatchTrainingData()
    for batch_subdir in labeled_batch_dirs:
        ir_csv_path = os.path.join(DATA_DIR_PATH, batch_subdir, 'ir.csv')
        labels_path = os.path.join(LABELS_DIR, batch_subdir.replace('/', '--') + '.csv')

        reader = IrDataCsvReader(file_path=ir_csv_path)
        collector = AnnotationCollector.load_from_file(file_path=labels_path, do_not_scale_and_reverse=True)

        for frame_index in range(reader.get_number_of_frames()):
            annotation = collector.get_annotation(frame_index)
            if annotation.accepted:
                training_data.append_frame_data(
                    centre_points=annotation.centre_points,
                    raw_ir_data=reader.get_frame(frame_index))
            else:
                # also reached for frames that were explicitly discarded
                print(f"Frame index {frame_index} from batch '{batch_subdir}' not annotated!")

    return training_data



class AnnotationCollector:
    """Collects FrameAnnotation objects for one data batch and persists them as JSON.

    Annotations are stored per IR frame index. Every accepted or discarded annotation counts
    down an autosave counter: the first one triggers a save, afterwards every
    ANNOTATIONS_BETWEEN_AUTOSAVE-th one does.
    """
    ANNOTATIONS_BETWEEN_AUTOSAVE = 10

    def __init__(self, output_file_path, data_batch_dir_path):
        self._output_file_path = output_file_path
        self._annotations = {}  # ir_frame_index: FrameAnnotation
        self._data_batch_dir_path = data_batch_dir_path
        self._annotations_to_autosave = 1  # the first finished annotation is saved immediately

    def get_annotation(self, ir_frame_index):
        """Return the stored annotation, or a new empty FrameAnnotation if none exists."""
        return self._annotations.get(ir_frame_index, FrameAnnotation())

    def set_annotation(self, ir_frame_index, annotation):
        """Store ``annotation`` and autosave when the countdown of finished annotations reaches 0."""
        self._annotations[ir_frame_index] = annotation
        if annotation.accepted or annotation.discarded:
            self._annotations_to_autosave -= 1
            if self._annotations_to_autosave == 0:
                self._annotations_to_autosave = self.ANNOTATIONS_BETWEEN_AUTOSAVE
                self.save()

    def save(self):
        """Write all annotations to the output file as indented JSON.

        The JSON text is built before any file is opened and is written to a temporary file
        that then atomically replaces the target. A failing ``as_dict``/serialisation or an
        interrupted write therefore can no longer truncate the previously saved annotations.
        """
        data_dict = {
            'output_file_path': self._output_file_path,
            'data_batch_dir_path': self._data_batch_dir_path,
            'annotations': {index: annotation.as_dict()
                            for index, annotation in self._annotations.items()}
        }
        text = json.dumps(data_dict, indent=2)
        tmp_path = os.fspath(self._output_file_path) + '.tmp'
        with open(tmp_path, 'w') as file:
            file.write(text)
        os.replace(tmp_path, self._output_file_path)

    @classmethod
    def load_from_file(cls, file_path, do_not_scale_and_reverse=False):
        """Create a collector from a JSON file written by ``save``.

        JSON object keys are strings, so frame indices are converted back to ``int``.
        ``do_not_scale_and_reverse`` is forwarded to ``FrameAnnotation.from_dict``.
        """
        item = cls(output_file_path=file_path, data_batch_dir_path=None)
        with open(file_path, 'r') as file:
            data_dict = json.load(file)
        item._data_batch_dir_path = data_dict['data_batch_dir_path']
        item._annotations = {int(index): FrameAnnotation.from_dict(annotation_dict, do_not_scale_and_reverse)
                             for index, annotation_dict in data_dict['annotations'].items()}
        return item


# NOTE: x_on_interpolated_image_to_raw_x, y_on_interpolated_image_to_raw_y and
# xy_on_interpolated_image_to_raw_xy were defined twice in this cell with identical bodies
# (the later copy silently shadowed this one). The single remaining definition sits further
# down, next to the inverse raw -> interpolated conversion helpers.


class FrameAnnotation:
    """Annotation of one IR frame: status flags plus labelled person positions.

    In memory, coordinates are in interpolated-image space. ``as_dict`` converts them to
    raw-frame space for saving; ``from_dict`` converts them back unless told not to.
    """
    def __init__(self):
        self.accepted = False  # whether frame was marked as annotated successfully
        self.discarded = False  # whether frame was marked as discarded (ignored)
        self.centre_points = []  # type: List[tuple]  # x, y
        self.rectangles = []  # type: List[Tuple[tuple, tuple]]  # (x_left, y_top), (x_right, y_bottom)

        self.raw_frame_data = None  # Not an annotation, but write it to the result file, just in case

    def as_dict(self):
        """Return a shallow copy of the attributes with points/rectangles converted to raw-frame space."""
        data_dict = copy.copy(self.__dict__)
        data_dict['centre_points'] = [xy_on_interpolated_image_to_raw_xy(point) for point in self.centre_points]
        data_dict['rectangles'] = [(xy_on_interpolated_image_to_raw_xy(rectangle[0]),
                                    xy_on_interpolated_image_to_raw_xy(rectangle[1]))
                                   for rectangle in self.rectangles]
        return data_dict

    @classmethod
    def from_dict(cls, data_dict, do_not_scale_and_reverse=False):
        """Build an annotation from a dict produced by ``as_dict`` (e.g. loaded from JSON).

        New lists are created for the converted points/rectangles, so the lists inside
        ``data_dict`` are left untouched (they used to be overwritten in place).

        :param do_not_scale_and_reverse: if True, coordinates are taken over unchanged
        """
        item = cls()
        item.__dict__.update(data_dict)
        item.centre_points = [xy_on_raw_image_to_xy_on_interpolated_image(point, do_not_scale_and_reverse)
                              for point in item.centre_points]
        item.rectangles = [(xy_on_raw_image_to_xy_on_interpolated_image(rectangle[0], do_not_scale_and_reverse),
                            xy_on_raw_image_to_xy_on_interpolated_image(rectangle[1], do_not_scale_and_reverse))
                           for rectangle in item.rectangles]
        return item


def x_on_interpolated_image_to_raw_x(x):
    """Map x on the resized, left-right mirrored image back to raw-frame x."""
    return IR_CAMERA_RESOLUTION_X - x / IR_FRAME_RESIZE_MULTIPLIER


def y_on_interpolated_image_to_raw_y(y):
    """Map y on the resized image back to raw-frame y (no mirroring on this axis)."""
    return y / IR_FRAME_RESIZE_MULTIPLIER


def xy_on_interpolated_image_to_raw_xy(xy: tuple) -> tuple:
    """Convert an (x, y) point from interpolated-image space to raw-frame space."""
    x, y = xy[0], xy[1]
    return x_on_interpolated_image_to_raw_x(x), y_on_interpolated_image_to_raw_y(y)


def x_on_raw_image_to_x_on_interpolated_image(x):
    """Mirror raw-frame x and scale it to the resized image, rounded to the nearest pixel."""
    return round((IR_CAMERA_RESOLUTION_X - x) * IR_FRAME_RESIZE_MULTIPLIER)


def y_on_raw_image_to_y_on_interpolated_image(y):
    """Scale raw-frame y to the resized image, rounded to the nearest pixel."""
    return round(y * IR_FRAME_RESIZE_MULTIPLIER)


def xy_on_raw_image_to_xy_on_interpolated_image(xy: tuple, do_not_scale_and_reverse=False) -> tuple:
    """Convert an (x, y) point from raw-frame space to interpolated-image space.

    With ``do_not_scale_and_reverse`` the input object is returned unchanged.
    """
    if not do_not_scale_and_reverse:
        return (x_on_raw_image_to_x_on_interpolated_image(xy[0]),
                y_on_raw_image_to_y_on_interpolated_image(xy[1]))
    return xy


def get_extrapolated_ir_frame_heatmap_flipped(
        frame_2d, multiplier, interpolation, min_temp, max_temp, colormap):
    """Render a 2-D temperature frame as a resized, left-right mirrored BGR heatmap.

    :param frame_2d: 2-D array of temperatures, shape (rows, columns)
    :param multiplier: integer resize factor applied to both dimensions
    :param interpolation: OpenCV interpolation flag passed to ``cv2.resize``
    :param min_temp: lower clipping bound; None uses the minimum of the resized frame
    :param max_temp: upper clipping bound; None uses the maximum of the resized frame
    :param colormap: callable mapping uint8 indices to RGBA floats in [0, 1]
        (presumably a matplotlib colormap — confirm for other callers)
    :return: uint8 array of shape (rows*multiplier, cols*multiplier, 3) in BGR order
    """
    new_size = (frame_2d.shape[1] * multiplier, frame_2d.shape[0] * multiplier)  # cv2 wants (width, height)
    frame_resized_not_clipped = cv2.resize(
        src=frame_2d, dsize=new_size, interpolation=interpolation)

    if min_temp is None:
        min_temp = frame_resized_not_clipped.min()
    if max_temp is None:
        max_temp = frame_resized_not_clipped.max()

    frame_resized = np.clip(frame_resized_not_clipped, min_temp, max_temp)
    temperature_range = max_temp - min_temp
    if temperature_range > 0:
        frame_resized_normalized = (frame_resized - min_temp) * (255 / temperature_range)
    else:
        # flat frame (or empty range): avoid dividing by zero, render everything at the low end
        frame_resized_normalized = np.zeros_like(frame_resized)
    frame_resized_normalized_u8 = frame_resized_normalized.astype(np.uint8)
    # Scale colormap floats by 255, not 2**8: 1.0 * 256 does not fit in uint8 and wrapped to 0,
    # turning the brightest colours black.
    heatmap_u8 = (colormap(frame_resized_normalized_u8) * 255).astype(np.uint8)[:, :, :3]
    heatmap_u8_bgr = cv2.cvtColor(heatmap_u8, cv2.COLOR_RGB2BGR)
    heatmap_u8_bgr_flipped = cv2.flip(heatmap_u8_bgr, 1)  # horizontal flip
    return heatmap_u8_bgr_flipped


In [None]:
def draw_airbrush_circle(img, centre, radius):
    """Add a cone-shaped "airbrush" spot to ``img`` in place.

    Each pixel within ``radius`` of ``centre`` is increased by ``1 - distance / radius``
    (1 at the centre, 0 on the rim). ``centre[0]`` indexes axis 0 of ``img`` and
    ``centre[1]`` axis 1; pixels outside the image are skipped. A non-positive radius
    draws nothing (previously ``radius == 0`` raised ZeroDivisionError).
    """
    if radius <= 0:
        return
    rows = range(max(0, round(centre[0] - radius)), min(img.shape[0], round(centre[0] + radius + 1)))
    cols = range(max(0, round(centre[1] - radius)), min(img.shape[1], round(centre[1] + radius + 1)))
    for x in rows:
        for y in cols:
            # stdlib Euclidean norm instead of building an OpenCV matrix for every pixel
            distance_to_centre = math.hypot(centre[0] - x, centre[1] - y)
            if distance_to_centre <= radius:
                img[x, y] += 1 - distance_to_centre / radius
            

def _clamped_slice(centre, half_extent, size):
    """Slice covering centre-half_extent .. centre+half_extent, clipped to [0, size)."""
    # clamp both ends to >= 0: a negative stop would make numpy count from the end
    return slice(max(0, centre - half_extent), max(0, min(size, centre + half_extent + 1)))


def draw_cross(img, centre, cross_width, cross_height):
    """Set the pixels of a '+' shaped cross in ``img`` to 1, in place.

    Both arms are centred on ``round(centre)``. One arm spans +/-cross_width along axis 0
    and +/-cross_height along axis 1; the other arm has the two extents swapped. Parts
    outside the image are clipped. (The arms used to round the centre differently, which
    made the cross asymmetric for half-pixel centres.)
    """
    centre_0, centre_1 = round(centre[0]), round(centre[1])
    for half_0, half_1 in ((cross_width, cross_height), (cross_height, cross_width)):
        img[_clamped_slice(centre_0, half_0, img.shape[0]),
            _clamped_slice(centre_1, half_1, img.shape[1])] = 1
            

def gauss_1d(x, sig):
    """Unnormalised Gaussian exp(-x^2 / (2 * sig^2)); equals 1 at x = 0. Works element-wise on arrays."""
    numerator = np.power(x, 2.)
    denominator = 2 * np.power(sig, 2.)
    return np.exp(-numerator / denominator)
    
    
def draw_gauss(img, sig, centre):
    """Add a 2-D Gaussian blob (``gauss_1d`` of the distance) to ``img`` in place.

    The blob is truncated at ``radius = 3 * sig``; pixels farther away or outside the
    image are left unchanged. ``centre[0]`` indexes axis 0 of ``img``, ``centre[1]`` axis 1.
    """
    radius = 3 * sig
    rows = range(max(0, round(centre[0] - radius)), min(img.shape[0], round(centre[0] + radius + 1)))
    cols = range(max(0, round(centre[1] - radius)), min(img.shape[1], round(centre[1] + radius + 1)))
    for x in rows:
        for y in cols:
            # stdlib Euclidean norm instead of an OpenCV call per pixel
            distance_to_centre = math.hypot(centre[0] - x, centre[1] - y)
            if distance_to_centre <= radius:
                img[x, y] += gauss_1d(distance_to_centre, sig)
    
    
def get_img_reconstructed_from_labels(centre_points):
    """Build the U-Net target image for one frame from its labelled centre points.

    Returns a float array of shape IR_CAMERA_RESOLUTION (rows, columns) containing one
    Gaussian blob (sig=3) per centre point. Labels are stored as (x, y), i.e.
    (column, row), so each point is reversed to (row, column) before drawing.
    """
    img_reconstructed = np.zeros(shape=(IR_CAMERA_RESOLUTION[0], IR_CAMERA_RESOLUTION[1]))
    for centre_point in centre_points:
        row_col_centre = list(centre_point[::-1])
        draw_gauss(img=img_reconstructed, centre=row_col_centre, sig=3)
    return img_reconstructed


# Each person adds one Gaussian blob (sig=3) to the target image. Its pixel sum is about 45-55,
# lower when the blob is cut off at the frame edge. Dividing the pixel sum of a (predicted)
# image by the constant below therefore gives a person count. The value is the average
# pixels-per-person over the training data and must be recomputed if the Gaussian parameters change.
sum_of_values_for_one_person = 51.35


In [None]:
# Load the labelled frames of every split (training, validation, test - in that order).
_training_data_1, _validation_data_1, _test_data_1 = (
    load_data_for_labeled_batches(labeled_batch_dirs=split_dirs)
    for split_dirs in (TRAINING_DIRS_1, VALIDATION_DIRS_1, TEST_DIRS_1)
)

# Only the training split is augmented with flips; validation and test keep the original frames.
augmented_data_training, augmented_data_validation, augmented_data_test = (
    AugmentedBatchesTrainingData() for _ in range(3)
)
augmented_data_training.add_training_batch(_training_data_1, flip_and_rotate=True)
augmented_data_validation.add_training_batch(_validation_data_1, flip_and_rotate=False)
augmented_data_test.add_training_batch(_test_data_1, flip_and_rotate=False)

# Datasets statistics

In [None]:
# Batch/frame counts and frames-per-person-count of the augmented training split.
augmented_data_training.print_stats()

In [None]:
# Batch/frame counts and frames-per-person-count of the (non-augmented) validation split.
augmented_data_validation.print_stats()

In [None]:
# Batch/frame counts and frames-per-person-count of the (non-augmented) test split.
augmented_data_test.print_stats()


# Training and validation datasets loaders

In [None]:
class IrPersonsUnetTrainDataset(torch.utils.data.Dataset):
    """Dataset of ``(frame, target_image)`` pairs for training the U-Net.

    Items enumerate all frames of all batches in ``augmented_data`` in order. ``frame`` is
    the normalized IR frame with a leading channel axis added; ``target_image`` is built from
    the frame's centre points by ``get_img_reconstructed_from_labels``. Items are computed
    lazily and memoised in ``self._cache``.
    """
    def __init__(self, augmented_data: 'AugmentedBatchesTrainingData', transform=None):
        self.augmented_data = augmented_data  # the instance (the class object used to be stored here)
        self.transform = transform  # NOTE(review): stored but never applied in __getitem__
        self._index_to_batch_and_subindex_map = {}

        self._cache = {}

        flat_index = 0
        for batch in augmented_data.batches:
            for subindex in range(len(batch.raw_ir_data)):
                self._index_to_batch_and_subindex_map[flat_index] = (batch, subindex)
                flat_index += 1

    def __len__(self):
        return len(self._index_to_batch_and_subindex_map)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            raise Exception("Not supported")

        if idx not in self._cache:
            batch, subindex = self._index_to_batch_and_subindex_map[idx]
            frame = batch.normalized_ir_data[subindex][np.newaxis, :, :]  # add the channel axis
            target_image = get_img_reconstructed_from_labels(batch.centre_points[subindex])
            self._cache[idx] = frame, target_image

        return self._cache[idx]

    def get_number_of_persons_for_frame(self, idx):
        """Number of labelled persons in the frame at ``idx``."""
        batch, subindex = self._index_to_batch_and_subindex_map[idx]
        return len(batch.centre_points[subindex])
        
    
class IrPersonsTestDataset(torch.utils.data.Dataset):
    """Dataset of ``(frame, number_of_persons)`` pairs for count-based evaluation.

    Items enumerate all frames of all batches in ``augmented_data`` in order; ``frame`` is the
    normalized IR frame with a leading channel axis, ``number_of_persons`` the label count.
    """
    def __init__(self, augmented_data: 'AugmentedBatchesTrainingData', transform=None):
        self.augmented_data = augmented_data  # the instance (the class object used to be stored here)
        self.transform = transform  # NOTE(review): stored but never applied in __getitem__
        self._index_to_batch_and_subindex_map = dict(enumerate(
            (batch, subindex)
            for batch in augmented_data.batches
            for subindex in range(len(batch.raw_ir_data))
        ))

    def __len__(self):
        return len(self._index_to_batch_and_subindex_map)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            raise Exception("Not supported")

        batch, subindex = self._index_to_batch_and_subindex_map[idx]
        frame = batch.normalized_ir_data[subindex][np.newaxis, :, :]  # add the channel axis
        return frame, len(batch.centre_points[subindex])

    
    
    
# (frame, target image) datasets used to train and validate the U-Net.
training_dataset = IrPersonsUnetTrainDataset(augmented_data_training)
validation_dataset = IrPersonsUnetTrainDataset(augmented_data_validation)

# No random split here: most consecutive frames are almost identical, so a random split would put
# near-duplicates into different splits. Each dataset is built from distinct sequences instead.
trainloader = torch.utils.data.DataLoader(training_dataset, batch_size=16, shuffle=True)
valloader = torch.utils.data.DataLoader(validation_dataset, batch_size=16, shuffle=False)

# (frame, true person count) datasets used for count-based metrics.
dataset_with_real_people_count_valid = IrPersonsTestDataset(augmented_data_validation)
loader_with_real_people_count_valid = torch.utils.data.DataLoader(
    dataset_with_real_people_count_valid, batch_size=16, shuffle=False)

dataset_with_real_people_count_test = IrPersonsTestDataset(augmented_data_test)
loader_with_real_people_count_test = torch.utils.data.DataLoader(
    dataset_with_real_people_count_test, batch_size=16, shuffle=False)

# Show the target image of one (shuffled) training sample and the count implied by its pixel sum.
print('Example training frame:')
xb, yb = next(iter(trainloader))
plt.imshow(yb[0].numpy().squeeze())
print(f'number of persosn based on sum: {np.sum(yb[0].numpy()) / sum_of_values_for_one_person:.3f}')


In [None]:
# Calculate sum of pixels for one person on the image

# total_sum_of_pixels = 0
# total_number_of_people = 0

# for i in range(len(training_dataset)):
#     reconstructed_frame = training_dataset[i][1]
    
#     sum_of_pixels = np.sum(reconstructed_frame)
#     numbe_of_people = training_dataset.get_number_of_persons_for_frame(i)
    
#     plt.imshow(reconstructed_frame)

#     total_sum_of_pixels += sum_of_pixels
#     total_number_of_people += numbe_of_people

# average_pixels_per_person = total_sum_of_pixels / total_number_of_people
# print(f'average_pixels_per_person={average_pixels_per_person}')  # 51.35

# UNET model definition

In [None]:
class AutoEncoder(torch.nn.Module):
    """Small two-level U-Net (despite the name) predicting a single-channel image.

    Encoder: in_channels -> 16 -> 32 channels with two 2x poolings; bottleneck DoubleConv to 64;
    two ExpandBlocks back to 16 channels using the encoder skip features; 1x1 conv + ReLU.
    NOTE(review): input height/width presumably must be divisible by 4 so skips line up — confirm.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.encoder = Encoder(in_channels)
        self.conv = DoubleConv(32, 64, 3, 1)  # bottleneck
        self.upconv1 = ExpandBlock(64, 32, 3, 1)
        self.upconv2 = ExpandBlock(32, 16, 3, 1)
        self.out_conv = nn.Conv2d(16, out_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map (batch, in_channels, H, W) to (batch, H, W): channel 0 of the ReLU-ed output."""
        bottom, skip_deep, skip_shallow = self.encoder(x)
        decoded = self.upconv2(self.upconv1(self.conv(bottom), skip_deep), skip_shallow)
        prediction = self.relu(self.out_conv(decoded))
        return prediction[:, 0, :, :]  # drop the channel axis (only channel 0 is kept)


class ExpandBlock(nn.Module):
    """One expanding U-Net stage.

    Steps:
    1. 2x upsampling with a transposed conv (in_channels -> in_channels // 2). With kernel 3,
       stride 2, padding 1 and output_padding 1 the spatial size exactly doubles.
    2. Concatenation with ``encoder_features`` along the channel axis.
    3. Two Conv-BN-ReLU layers (-> out_channels).

    conv1 expects in_channels inputs, so ``encoder_features`` must have
    ``in_channels - in_channels // 2`` channels. ``kernel_size`` and ``padding`` apply to the
    two regular convolutions only.

    Base class is nn.Module rather than nn.Sequential: ``forward`` takes two inputs, so the
    Sequential container semantics (iteration, len, chaining) did not apply. The attribute
    names, and therefore the state_dict keys, are unchanged.
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels, in_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor, encoder_features: torch.Tensor) -> torch.Tensor:
        x = self.conv_transpose(x)
        x = torch.cat((x, encoder_features), dim=1)  # skip connection
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x


class Encoder(nn.Module):
    """Contracting path of the U-Net: two ContractBlocks (in_channels -> 16 -> 32 channels).

    ``forward`` returns ``(bottom, conv2_features, conv1_features)``: the twice-pooled output
    followed by the pre-pooling features of each stage, deepest first, for the skip connections.
    """
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv1 = ContractBlock(in_channels, 16, 3, 1)
        self.conv2 = ContractBlock(16, 32, 3, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x, conv1 = self.conv1(x)
        x, conv2 = self.conv2(x)
        return x, conv2, conv1

    def forward_simple(self, x: torch.Tensor) -> torch.Tensor:
        """Return only the bottom (twice-pooled) encoder output, without the skip features.

        The previous body referenced a non-existent ``self.input_conv`` and fed the
        (pooled, features) tuples of the ContractBlocks on as tensors, so it always failed.
        """
        bottom, _, _ = self.forward(x)
        return bottom


class ContractBlock(nn.Module):
    """One contracting U-Net stage: a DoubleConv followed by 2x2 max pooling (stride 2).

    ``forward`` returns ``(pooled, features)``: the downsampled tensor for the next stage and
    the full-resolution DoubleConv output used as skip connection.
    """
    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels, kernel_size, padding=padding)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        skip_features = self.conv(x)
        return self.pool(skip_features), skip_features


class DoubleConv(nn.Sequential):
    """Two Conv2d -> BatchNorm2d -> ReLU stages (stride 1).

    The first stage maps in_channels to out_channels, the second keeps out_channels. Layer
    indices 0..5, and thus the state_dict keys, follow exactly that order.
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int = 0):
        layers = []
        for stage_in_channels in (in_channels, out_channels):
            layers.extend((
                nn.Conv2d(stage_in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ))
        super().__init__(*layers)


# One input channel (the normalized IR frame) and one output channel (the predicted target image).
# float64 weights because the normalized frames are float64 numpy arrays.
unet = AutoEncoder(1, 1).double().to(device)

print('Number of trainable model parameters:')
print(sum(np.prod(p.size()) for p in unet.parameters() if p.requires_grad))  # counted from tensor shapes
print(sum(p.numel() for p in unet.parameters() if p.requires_grad))  # counted with numel(); same number


In [None]:
# optionally - load a pre-trained encoder

# unet.encoder.load_state_dict(torch.load(pretrained_encoder_file_path))
# unet.encoder.eval()

In [None]:
# optionally - load trained model instead of training...

# trained_model_file_path = '/home/przemek/Desktop/thermo_presence_article_files/trained_model_pytorch/unet_v2_cpu2'
# unet.load_state_dict(torch.load(trained_model_file_path))
# unet.eval()
# unet.to(device)

# Model validation function and metrics calculation

In [None]:
def validate_model_with_real_number_of_persons(loader, model, data_plotting_interval=3000, skip_confusion_matrix=False):
    """ Validate the model on data from the loader, calculate and print the results and metrics

    The person count predicted for a frame is the sum of the model's output image divided by
    ``sum_of_values_for_one_person``; rounding it gives the integer predicted label.

    :param loader: iterable of ``(frames, labels)`` batches, ``labels`` holding true person counts
    :param model: callable mapping a batch of frames to a batch of 2-D output images
    :param data_plotting_interval: plot input/output of every n-th frame; values <= 0 disable plotting
    :param skip_confusion_matrix: if True the returned confusion matrix stays all zeros
    :return: ``(accuracy, mae, true_counts, predicted_counts, confusion_matrix)``. ``predicted_counts``
        are unrounded. ``confusion_matrix`` is indexed ``[true][predicted]``; counts outside
        0..MAX_PEOPLE_COUNT are clamped into the edge rows/columns.
    :raises ValueError: if the loader yields no frames
    """
    MAX_PEOPLE_COUNT = 6
    confusion_matrix = np.zeros(shape=(MAX_PEOPLE_COUNT+1, MAX_PEOPLE_COUNT+1), dtype=int)

    def _matrix_bin(count):
        # clamp: counts above the matrix size raised IndexError, negative ones wrapped around
        return min(max(count, 0), MAX_PEOPLE_COUNT)

    correct_count = 0
    tested_frames = 0
    number_of_frames_with_n_persons = {}
    number_of_frames_with_n_persons_predicted_correctly = {}

    mae_sum = 0
    mse_sum = 0

    mae_rounded_sum = 0
    mse_rounded_sum = 0

    vec_real_number_of_persons = []
    vec_predicted_number_of_persons = []

    for frame, labels in loader:
        with torch.no_grad():
            outputs = model(frame.to(device)).to(cpu_device)
        true_labels = labels.numpy()

        for i in range(len(labels)):
            predicted_img = outputs[i].numpy()

            pred_people = np.sum(predicted_img) / sum_of_values_for_one_person
            pred_label = int(round(pred_people))  # plain int, usable as an index and dict key
            true_label = int(true_labels[i])

            if not skip_confusion_matrix:
                confusion_matrix[_matrix_bin(true_label)][_matrix_bin(pred_label)] += 1

            error = abs(pred_people - true_label)
            mae_sum += error
            mse_sum += error * error

            rounded_error = abs(pred_label - true_label)
            mae_rounded_sum += rounded_error
            mse_rounded_sum += rounded_error * rounded_error

            number_of_frames_with_n_persons[pred_label] = \
                number_of_frames_with_n_persons.get(pred_label, 0) + 1

            if true_label == pred_label:
                correct_count += 1
                number_of_frames_with_n_persons_predicted_correctly[pred_label] = \
                    number_of_frames_with_n_persons_predicted_correctly.get(pred_label, 0) + 1

            tested_frames += 1

            vec_real_number_of_persons.append(true_label)
            vec_predicted_number_of_persons.append(pred_people)

            # test the interval first: `tested_frames % 0` would raise ZeroDivisionError
            if data_plotting_interval > 0 and tested_frames % data_plotting_interval == 0:
                plt.imshow(frame[i, 0, :, :])
                plt.show()
                print(f'true_label={true_label}, pred_people={pred_people}')
                plt.imshow(predicted_img)
                plt.show()
                print('#'*30)

    if tested_frames == 0:
        raise ValueError('loader yielded no frames to validate on')

    mae = mae_sum / tested_frames
    mse = mse_sum / tested_frames
    mae_rounded = mae_rounded_sum / tested_frames
    mse_rounded = mse_rounded_sum / tested_frames

    model_accuracy = correct_count / tested_frames

    print(f"Number of tested frames: {tested_frames}")
    print(f"Model Accuracy = {model_accuracy}")
    print('Predicted:\n' + '\n'.join([f'   {count} frames with {no} persons' for no, count in number_of_frames_with_n_persons.items()]))
    print('Predicted correctly:\n' + '\n'.join([f'   {count} frames with {no} persons' for no, count in number_of_frames_with_n_persons_predicted_correctly.items()]))
    print(f'mae: {mae}')
    print(f'mse: {mse}')
    print(f'mae_rounded: {mae_rounded}')
    print(f'mse_rounded: {mse_rounded}')

    return model_accuracy, mae, vec_real_number_of_persons, vec_predicted_number_of_persons, confusion_matrix



# Train the model

In [None]:
def _train_one_epoch(model, train_dl, loss_fn, optimizer):
    """Run one sub-sampled training pass and return the mean loss of the batches used (nan if none)."""
    model.train(True)
    running_loss = 0.0
    step = 0
    for x, y in train_dl:
        # always train on the first batch, afterwards only on ~1 in 100 batches
        if step != 0 and randrange(100) != 1:
            continue
        step += 1

        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    return running_loss / step if step else float('nan')


def _plot_example_output(model, data_loader):
    """Show the model output for the first frame of one batch and print its predicted person count."""
    with torch.no_grad():
        example_frames = next(iter(data_loader))[0].to(device)
        model_output_frame = model(example_frames).to(cpu_device)[0]
    plt.imshow(model_output_frame)
    plt.show()
    print(np.sum(model_output_frame.numpy()) / sum_of_values_for_one_person)


def _mean_loss(model, data_loader, loss_fn):
    """Average ``loss_fn`` over all batches of ``data_loader`` without gradients (nan if empty)."""
    total_loss = 0.0
    batches = 0
    with torch.no_grad():
        for frame, labels in data_loader:
            batches += 1
            total_loss += loss_fn(model(frame.to(device)), labels.to(device)).item()
    return total_loss / batches if batches else float('nan')


def train(model, train_dl, valid_dl_train, valid_dl_real, loss_fn, optimizer, epochs=1):
    """Train ``model`` and keep a copy of the epoch with the lowest person-count MAE.

    Each epoch:
    1. Trains on a random ~1% subset of ``train_dl`` batches (plus the first batch).
    2. Plots one example output from ``train_dl``.
    3. Validates person counts on ``valid_dl_real``.
    4. Computes the validation loss on ``valid_dl_train``.

    All steps use the ``model`` and loaders passed in; previously the global ``unet`` and
    ``trainloader`` were used instead.

    :return: (train_loss, valid_loss, mae_vec, accuracy_vec, best_mae_model) - per-epoch
        histories plus a deep copy of the best model (None when ``epochs`` is 0)
    """
    best_mae_model = None
    best_mae = None

    start_time = time.time()
    train_loss, valid_loss, mae_vec, accuracy_vec = [], [], [], []

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)

        epoch_loss = _train_one_epoch(model, train_dl, loss_fn, optimizer)
        lr = optimizer.param_groups[0]['lr']
        print(f'Train Loss: {epoch_loss:.4f}. lr={lr}')
        train_loss.append(epoch_loss)

        model.train(False)

        _plot_example_output(model, train_dl)

        # full validation of the model on real person counts
        accuracy, mae, _, _, _ = validate_model_with_real_number_of_persons(
            loader=valid_dl_real, model=model, skip_confusion_matrix=True, data_plotting_interval=-1)
        mae_vec.append(mae)
        accuracy_vec.append(accuracy)

        if best_mae is None or mae < best_mae:
            print('New best model saved!')
            best_mae = mae
            best_mae_model = copy.deepcopy(model)

        single_valid_loss = _mean_loss(model, valid_dl_train, loss_fn)
        valid_loss.append(single_valid_loss)
        print(f'Valid loss: {single_valid_loss}')

    time_elapsed = time.time() - start_time
    print(f'Training time: {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    return train_loss, valid_loss, mae_vec, accuracy_vec, best_mae_model


# L1 loss between predicted and target images. MSE was also tried and "sometimes learns strangely":
# loss_fn = nn.MSELoss()
loss_fn = nn.L1Loss()
opt = torch.optim.Adam(unet.parameters(), lr=1e-5)

# `train` returns the copy of the model with the best validation MAE; `unet` is rebound to it.
train_loss, valid_loss, mae_vec, accuracy_vec, unet = train(
    model=unet,
    train_dl=trainloader,
    valid_dl_train=valloader,
    valid_dl_real=loader_with_real_people_count_valid,
    loss_fn=loss_fn,
    optimizer=opt,
    epochs=400,
)


# Plot training progress

In [None]:
# Training (sub-sampled batches) vs. validation loss per epoch, on a log scale.
fig, ax = plt.subplots()
ax.plot(train_loss)
ax.plot(valid_loss)
ax.legend(['train', 'valid'])
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.grid()
ax.set_yscale('log')

In [None]:
# Validation MAE of the predicted person count after every epoch.
fig, ax = plt.subplots()
ax.plot(mae_vec)
ax.set_xlabel('epoch')
ax.set_ylabel('MAE')
ax.grid()

In [None]:
# Validation accuracy (fraction of frames with an exactly correct rounded count) per epoch.
fig, ax = plt.subplots()
ax.plot(accuracy_vec)
ax.set_xlabel('epoch')
ax.set_ylabel('accuracy')  # human-readable axis label instead of the variable name
ax.grid()

# Evaluate the model's predictions on the test dataset

In [None]:
# Evaluate on the held-out test split, plotting every 1000th frame.
# (Pass loader_with_real_people_count_valid instead to inspect the validation split.)
_, _, real_vec, predicted_vec, confusion_matrix = validate_model_with_real_number_of_persons(
    loader=loader_with_real_people_count_test,
    model=unet,
    data_plotting_interval=1000,
)

In [None]:
# Tab-separated text table of the confusion matrix. The matrix is indexed [true][predicted]; the
# table is printed transposed, so row i lists confusion_matrix[j][i] for every j
# (rows = predicted count, columns = true count).
txt = ''.join(
    ''.join(str(confusion_matrix[j][i]) + '\t' for j in range(confusion_matrix.shape[1])) + '\n'
    for i in range(confusion_matrix.shape[0])
)

print(txt)

## Plot prediction graphs (test data)

In [None]:
# Unrounded predicted person count vs. ground truth over the concatenated test frames.
predicted_vec_round = [round(v) for v in predicted_vec]
t = [i * 0.5 for i in range(len(predicted_vec_round))]  # 0.5 s between consecutive frames

fig, ax = plt.subplots(figsize=(11, 4), dpi=150, facecolor='w', edgecolor='k')
ax.grid()
ax.plot(t, real_vec, '--', linewidth=2.5)
ax.plot(t, predicted_vec, linewidth=1)

ax.set_ylim(-0.2, 5.75)
# NOTE(review): separators at 470/2 s and 470*4/2 s and the 470*5/2 s limit presumably assume
# 470 frames per test directory (1 + 3 + 1 grouping of TEST_DIRS_1) — confirm.
ax.vlines([470/2, 470*4/2], -0.2, 5.75, linestyle='dashed', colors='k')
ax.set_xlim(left=0, right=470*5/2)

ax.legend(['ground truth', 'prediction'])
ax.set_xlabel('time [s]')
ax.set_ylabel('person count [-]')


In [None]:
# Ground truth vs. prediction rounded to whole persons; consecutive frames are 0.5 s apart.
# NOTE(review): the dashed separators at multiples of 470 frames presumably mark
# boundaries between test recordings — confirm.
predicted_vec_round = [round(v) for v in predicted_vec]
t = [i * 0.5 for i in range(len(predicted_vec_round))]

fig = plt.figure(figsize=(11, 4), dpi=150, facecolor='w', edgecolor='k')
plt.grid()

# Labels are attached to the artists themselves. With a bare label list,
# matplotlib assigns labels in artist-creation order, and the `vlines` separators
# below would have been labelled as the rounded prediction.
plt.plot(t, real_vec, '--', linewidth=2.5, label='ground truth')
plt.plot(t, predicted_vec_round, linewidth=1, label='rounded prediction')

plt.ylim(-0.2, 6.1)
plt.vlines([470/2, 470*4/2], -0.2, 6.1, linestyle='dashed', colors='k')
plt.xlim(left=0, right=470*5/2)

plt.legend()
plt.xlabel('time [s]')
plt.ylabel('person count [-]')


In [None]:


# Prediction error (ground truth minus prediction) over time, for both the
# rounded and the raw prediction; consecutive frames are 0.5 s apart.
predicted_error_round = [real - pred for real, pred in zip(real_vec, predicted_vec_round)]

prediction_error = np.array(real_vec) - np.array(predicted_vec)
# Error relative to the true count (denominator clamped to 1 for empty frames); not plotted below.
relative_error = [prediction_error[i] / max(1, real_vec[i]) for i in range(len(real_vec))]

fig = plt.figure(figsize=(11, 4), dpi=150, facecolor='w', edgecolor='k')
t = [i * 0.5 for i in range(len(predicted_vec_round))]
plt.grid()
plt.plot(t, predicted_error_round, linewidth=1, label='rounded prediction error')
plt.plot(t, prediction_error, linewidth=1, label='prediction error')

plt.ylim(-2.3, 2.1)
# The separators span this plot's own y limits (not the person-count range of the
# previous plots), so they reach the bottom of the axes.
plt.vlines([470/2, 470*4/2], -2.3, 2.1, linestyle='dashed', colors='k')
plt.xlim(left=0, right=470*5/2)

plt.legend()

plt.xlabel('time [s]')
plt.ylabel('person count prediction error [-]')

In [None]:
# optionally - save the model

# unet_cpu = unet.to(cpu_device)
# torch.save(unet_cpu.state_dict(), 'model2')
# unet.to(device)

# Plot an example frame from the raw IR data

In [None]:
# Example raw IR frame: 24 rows x 32 columns of temperatures in °C (the unit used
# on the colorbar of the plot in this cell).
# NOTE(review): presumably the same frame whose person centres are listed in
# `centre_points` (frame 416 of recording 004__13_10_20) — confirm.
ir_frame = np.array([[24.92, 26.03, 24.15, 24.76, 25.48, 27.32, 27.36, 26.77, 25.12,
        25.47, 24.15, 24.28, 26.81, 28.4 , 28.01, 28.71, 28.57, 28.56,
        26.56, 25.54, 23.49, 23.73, 22.84, 22.81, 23.02, 23.32, 23.05,
        22.93, 22.48, 23.81, 23.06, 23.18],
       [26.87, 25.44, 24.83, 25.28, 26.23, 27.69, 28.57, 28.07, 27.63,
        27.38, 25.5 , 24.76, 26.29, 28.1 , 28.13, 28.02, 28.46, 27.89,
        27.  , 26.71, 24.43, 23.7 , 23.15, 22.89, 23.67, 23.69, 22.92,
        22.82, 23.99, 24.24, 23.25, 23.47],
       [23.72, 24.56, 23.98, 24.44, 24.67, 26.34, 26.76, 27.65, 28.08,
        27.69, 24.86, 25.58, 25.84, 26.9 , 27.28, 27.23, 27.38, 27.26,
        25.97, 25.48, 23.21, 23.34, 22.8 , 23.53, 23.91, 23.98, 23.49,
        23.79, 23.44, 24.18, 22.76, 23.64],
       [25.08, 24.97, 24.8 , 24.41, 24.52, 25.08, 26.24, 26.34, 27.85,
        27.4 , 25.92, 24.55, 25.4 , 26.22, 27.  , 26.78, 26.48, 25.62,
        24.28, 23.67, 23.63, 23.66, 23.08, 23.24, 24.84, 24.73, 24.08,
        23.97, 23.46, 24.48, 23.48, 23.42],
       [24.29, 24.68, 23.43, 24.22, 24.41, 24.78, 24.37, 25.16, 28.24,
        27.88, 25.88, 27.26, 24.42, 25.76, 26.6 , 26.69, 25.35, 24.22,
        23.34, 23.54, 23.12, 23.5 , 22.87, 23.21, 23.77, 24.03, 22.81,
        23.36, 23.07, 23.86, 22.84, 23.41],
       [24.94, 24.44, 24.64, 24.14, 26.21, 26.99, 25.45, 25.62, 29.93,
        27.7 , 28.21, 26.01, 27.03, 25.55, 27.24, 26.03, 25.41, 24.36,
        23.88, 23.68, 23.59, 23.43, 23.44, 23.11, 23.19, 23.44, 22.74,
        23.21, 23.5 , 23.86, 23.29, 23.17],
       [23.91, 24.39, 23.67, 24.53, 27.17, 28.45, 24.57, 24.15, 24.82,
        25.1 , 27.73, 29.65, 26.85, 29.25, 25.81, 25.55, 23.88, 23.79,
        23.53, 23.59, 23.22, 23.34, 23.15, 23.36, 22.85, 23.14, 22.63,
        23.43, 23.01, 23.52, 22.7 , 23.4 ],
       [24.87, 24.42, 24.75, 24.55, 25.5 , 25.9 , 24.19, 24.07, 24.11,
        25.8 , 28.89, 28.46, 29.9 , 28.31, 27.03, 25.82, 23.87, 24.08,
        23.9 , 23.77, 23.72, 23.64, 23.54, 23.15, 23.2 , 23.14, 23.04,
        23.13, 22.96, 23.34, 23.07, 23.13],
       [24.35, 24.58, 24.04, 24.59, 23.9 , 24.13, 23.28, 23.63, 23.9 ,
        24.83, 28.4 , 28.89, 29.88, 27.84, 27.97, 24.5 , 24.09, 24.18,
        23.35, 23.62, 23.49, 23.54, 23.02, 23.55, 23.12, 23.38, 22.66,
        23.16, 22.84, 23.39, 22.82, 23.27],
       [25.06, 25.57, 26.22, 26.  , 24.57, 24.42, 23.73, 23.44, 23.92,
        25.16, 27.18, 29.98, 26.84, 29.15, 25.89, 25.81, 23.84, 23.65,
        23.94, 23.61, 23.57, 23.69, 23.56, 23.19, 23.26, 23.31, 23.47,
        23.25, 23.27, 23.4 , 23.14, 23.69],
       [24.35, 24.41, 24.39, 24.14, 23.74, 23.64, 23.63, 23.63, 23.47,
        23.29, 26.78, 23.66, 27.  , 24.  , 25.18, 24.02, 24.  , 23.73,
        23.55, 23.83, 23.43, 23.88, 23.33, 23.57, 23.09, 23.49, 23.  ,
        23.35, 23.34, 24.03, 23.67, 24.16],
       [24.61, 24.81, 24.15, 23.93, 24.46, 24.1 , 24.25, 23.87, 23.69,
        24.21, 23.79, 25.2 , 23.9 , 25.06, 24.36, 24.  , 23.92, 23.99,
        23.9 , 23.77, 23.65, 23.87, 23.72, 23.49, 23.2 , 23.2 , 22.74,
        23.13, 23.77, 24.07, 24.13, 23.46],
       [24.09, 24.12, 24.31, 23.84, 24.05, 24.1 , 23.95, 24.04, 23.7 ,
        23.72, 23.25, 23.45, 23.73, 23.91, 24.25, 24.06, 23.99, 24.02,
        23.65, 24.03, 23.78, 24.19, 23.6 , 24.  , 23.07, 23.33, 22.88,
        23.67, 24.01, 24.63, 25.33, 24.86],
       [25.21, 25.31, 24.46, 24.29, 24.48, 24.74, 25.13, 24.63, 24.05,
        24.17, 24.05, 23.89, 24.07, 24.15, 24.11, 23.85, 24.02, 24.07,
        24.24, 24.06, 24.37, 24.31, 23.81, 23.86, 23.53, 23.24, 23.78,
        24.04, 24.82, 24.7 , 25.37, 25.25],
       [24.86, 24.43, 24.18, 24.34, 24.18, 24.85, 26.59, 28.64, 25.67,
        24.7 , 24.33, 24.51, 23.94, 24.14, 24.2 , 24.16, 24.13, 24.28,
        24.13, 24.36, 24.27, 24.38, 23.88, 23.97, 23.37, 23.59, 24.82,
        28.25, 25.72, 25.77, 24.01, 24.46],
       [25.4 , 25.  , 25.06, 24.46, 25.01, 25.07, 25.41, 25.96, 24.75,
        24.99, 24.83, 24.9 , 24.66, 24.45, 24.58, 24.15, 24.63, 24.47,
        24.57, 24.18, 24.51, 24.59, 24.11, 24.2 , 24.06, 23.76, 24.73,
        25.52, 25.24, 24.99, 24.37, 24.07],
       [25.03, 24.6 , 24.43, 24.55, 26.85, 27.28, 25.86, 25.75, 24.17,
        24.73, 25.01, 25.23, 24.65, 24.88, 24.46, 25.03, 24.63, 24.75,
        24.52, 24.66, 24.56, 24.74, 24.13, 24.06, 23.71, 24.25, 25.18,
        25.58, 25.35, 25.47, 24.25, 24.64],
       [25.37, 25.37, 25.57, 25.4 , 28.33, 28.9 , 27.46, 26.55, 25.14,
        25.23, 25.4 , 25.29, 25.23, 24.94, 25.14, 24.88, 25.19, 24.92,
        25.08, 24.65, 24.78, 25.03, 24.73, 24.3 , 24.66, 25.24, 26.23,
        26.25, 26.72, 26.8 , 25.13, 24.64],
       [24.89, 24.85, 24.97, 24.95, 27.41, 27.94, 27.82, 27.49, 24.66,
        24.95, 24.55, 25.12, 24.78, 24.76, 24.44, 24.8 , 24.55, 24.72,
        24.66, 24.77, 24.33, 24.9 , 24.74, 24.95, 24.31, 25.36, 25.94,
        26.79, 27.76, 28.84, 25.35, 25.43],
       [25.48, 25.23, 25.37, 24.99, 26.87, 27.71, 28.55, 27.02, 25.34,
        25.07, 24.91, 24.62, 24.86, 24.86, 24.81, 24.84, 24.71, 24.56,
        24.6 , 24.58, 24.86, 24.48, 25.01, 24.75, 24.54, 24.98, 25.93,
        26.78, 26.92, 26.99, 25.18, 24.89],
       [25.51, 24.57, 24.97, 24.85, 25.98, 26.69, 26.61, 25.87, 24.41,
        24.79, 24.35, 24.35, 24.11, 24.5 , 24.16, 24.31, 24.23, 24.16,
        24.09, 24.08, 24.21, 24.51, 24.16, 24.91, 24.62, 24.79, 24.74,
        25.91, 25.02, 25.07, 24.09, 24.79],
       [25.94, 25.45, 25.77, 25.37, 26.86, 27.43, 27.79, 26.72, 25.13,
        24.65, 24.83, 24.3 , 24.62, 24.45, 24.59, 24.07, 24.4 , 24.12,
        24.12, 23.81, 24.17, 24.15, 24.71, 24.39, 24.48, 24.54, 24.81,
        24.31, 24.9 , 24.83, 25.02, 24.33],
       [27.3 , 25.43, 25.73, 26.34, 27.5 , 28.89, 28.77, 28.7 , 27.52,
        27.18, 25.55, 25.42, 24.73, 24.87, 24.42, 24.21, 24.32, 23.97,
        23.67, 24.11, 23.77, 24.4 , 24.78, 25.29, 25.47, 25.95, 25.32,
        25.33, 24.28, 24.62, 25.11, 24.76],
       [28.99, 27.23, 27.41, 26.89, 28.01, 28.28, 29.35, 28.76, 28.5 ,
        27.6 , 26.96, 26.8 , 26.83, 26.22, 26.  , 24.83, 24.03, 23.92,
        23.81, 23.57, 24.  , 24.14, 25.55, 26.06, 26.67, 26.77, 26.19,
        25.02, 24.79, 24.84, 24.02, 23.8 ]])

# Labelled person centres for frame 416 of video '05_05_2021__0to5_people/004__13_10_20'.
# One [a, b] pair per person (five people here).
# NOTE(review): values fit inside a 32 x 24 frame, so presumably (x, y) pixel
# coordinates — confirm against the labelling tool.
centre_points = [
    [16.125, 0.9375],
    [28.1875, 18.6875],
    [7.6875, 2.25],
    [6.375, 18.125],
    [12.75, 7.9375],
]


# Both the raw frame and the label reconstruction are mirrored left-to-right
# (np.fliplr is np.flip(..., axis=1) for 2-D input) before being shown.
flipped_frame = np.fliplr(ir_frame)

# Quick look with the default colormap.
plt.imshow(flipped_frame)
plt.show()

# Same frame with the 'inferno' colormap and a temperature colorbar.
fig = plt.figure(figsize=(6, 4), dpi=150, facecolor='w', edgecolor='k')
plt.imshow(flipped_frame, cmap='inferno')
cbar = plt.colorbar()
cbar.set_label('Temperature [°C]')
plt.show()

# Image rebuilt from the labelled person centres.
print(f'{"#"*60} \n Reconstructed frame (training data):')
fig = plt.figure(figsize=(6, 4), dpi=150, facecolor='w', edgecolor='k')
recontructed_image = get_img_reconstructed_from_labels(centre_points)
recontructed_image_flipped = np.fliplr(recontructed_image)
plt.imshow(recontructed_image_flipped)
plt.show()




# Comparison with the fully connected network from another article

In [None]:
class FNNClassifier(nn.Sequential):
    """Fully connected baseline that classifies a whole IR frame into a person count.

    The input is flattened (keeping the batch dimension) and passed through one
    hidden ReLU layer. The output holds one raw logit per class, with no final
    softmax, where class k means "k persons".

    Args:
        in_features: Number of values in one flattened input frame
            (default 768 = 24 x 32 IR pixels).
        hidden_features: Width of the hidden layer (default 512).
        num_classes: Number of output logits (default 6, i.e. 0-5 persons).
    """

    def __init__(self, in_features: int = 768, hidden_features: int = 512, num_classes: int = 6):
        super().__init__(
            nn.Flatten(),
            nn.Linear(in_features=in_features, out_features=hidden_features),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=hidden_features, out_features=num_classes),
        )
        

# Double precision — presumably to match the dtype of the frames the FNN
# dataset yields (TODO confirm).
model_fnn = FNNClassifier().double()
model_fnn.to(device)
print('Number of parameters:', sum(p.numel() for p in model_fnn.parameters() if p.requires_grad))


class IrPersonsDatasetForFnn(torch.utils.data.Dataset):
    """Frame-level dataset for the FNN classifier, yielding ``(frame, person_count)``.

    Every frame of every batch in ``augmented_data.batches`` becomes one sample.
    Its label is the number of labelled centre points in that frame.

    Args:
        augmented_data: Object with a ``batches`` attribute. Each batch exposes
            per-frame sequences ``raw_ir_data``, ``normalized_ir_data`` and
            ``centre_points``.
        transform: Optional callable applied to the frame before it is returned.
    """

    def __init__(self, augmented_data: 'AugmentedBatchesTrainingData', transform=None):
        # Keep the instance that was passed in (previously the class object was stored by mistake).
        self.augmented_data = augmented_data
        self.transform = transform
        # Flat sample index -> (batch, frame index within that batch).
        self._index_to_batch_and_subindex_map = {}

        i = 0
        for batch in augmented_data.batches:
            for j in range(len(batch.raw_ir_data)):
                self._index_to_batch_and_subindex_map[i] = (batch, j)
                i += 1

    def __len__(self):
        return len(self._index_to_batch_and_subindex_map)

    def __getitem__(self, idx):
        """Return ``(frame, person_count)``; ``frame`` gets a leading channel axis of size 1."""
        if torch.is_tensor(idx):
            raise Exception("Not supported")

        batch, subindex = self._index_to_batch_and_subindex_map[idx]
        frame = batch.normalized_ir_data[subindex][np.newaxis, :, :]
        # `transform` used to be stored but never applied.
        if self.transform is not None:
            frame = self.transform(frame)
        return frame, len(batch.centre_points[subindex])


# Frame-level datasets for the FNN baseline, one per data split.
training_dataset_fnn = IrPersonsDatasetForFnn(augmented_data_training)
validation_dataset_fnn = IrPersonsDatasetForFnn(augmented_data_validation)
test_dataset_fnn = IrPersonsDatasetForFnn(augmented_data_test)

# Shared batch size for all three FNN loaders.
FNN_BATCH_SIZE = 16

# Only the training data is shuffled; validation/test frames keep their original order.
trainloader_fnn = torch.utils.data.DataLoader(training_dataset_fnn, batch_size=FNN_BATCH_SIZE, shuffle=True)
valloader_fnn = torch.utils.data.DataLoader(validation_dataset_fnn, batch_size=FNN_BATCH_SIZE, shuffle=False)
testloader_fnn = torch.utils.data.DataLoader(test_dataset_fnn, batch_size=FNN_BATCH_SIZE, shuffle=False)

In [None]:
# Sanity check: show the first frame of one training batch together with its label.
dataiter = iter(trainloader_fnn)
# DataLoader iterators no longer provide a `.next()` method; use the built-in `next()`.
ir_frames, labels = next(dataiter)

ir_frame_normalized_0 = ir_frames[0].numpy().squeeze()
plt.imshow(ir_frame_normalized_0)
print(f'Persons: {labels[0].item()}')

In [None]:
def validate_model_with_real_number_of_persons_fnn(loader, model, data_plotting_interval=3000, skip_confusion_matrix=False, loss_fn=None):
    """Evaluate the FNN person-count classifier on ``loader`` and print a summary.

    The predicted count for a frame is the argmax of the model's output for that
    frame. It is compared with the true count supplied by the loader.

    Args:
        loader: Iterable of ``(frames, labels)`` batches; ``labels`` holds the true
            person count of each frame.
        model: Classifier returning one score per person count (0..5).
        data_plotting_interval: Unused. Kept so the signature mirrors
            ``validate_model_with_real_number_of_persons``.
        skip_confusion_matrix: If True, the returned confusion matrix stays all zeros.
        loss_fn: Optional ``loss_fn(outputs, labels)``. When given, its mean over
            batches is reported as ``total_loss``.

    Returns:
        ``(accuracy, mae_rounded, real_counts, predicted_counts, confusion_matrix,
        total_loss)``, where ``confusion_matrix[true][predicted]`` counts frames and
        ``total_loss`` is 0 when no ``loss_fn`` is given.

    Raises:
        ValueError: If ``loader`` yields no frames.
    """
    MAX_PEOPLE_COUNT = 5
    confusion_matrix = np.zeros(shape=(MAX_PEOPLE_COUNT+1, MAX_PEOPLE_COUNT+1), dtype=int)

    correct_count = 0
    tested_frames = 0
    number_of_frames_with_n_persons = {}  # predicted count -> number of frames
    number_of_frames_with_n_persons_predicted_correctly = {}  # count -> correctly predicted frames

    mae_rounded_sum = 0
    mse_rounded_sum = 0

    vec_real_number_of_persons = []
    vec_predicted_number_of_persons = []

    running_loss = 0
    steps = 0

    for frame, labels in loader:
        with torch.no_grad():
            outputs = model(frame.to(device))
            if loss_fn is not None:
                loss = loss_fn(outputs, labels.to(device))
                running_loss += loss.item()
            outputs = outputs.to(cpu_device)
            steps += 1

        # Convert each batch once instead of calling labels.numpy() for every frame.
        true_labels = labels.numpy().tolist()
        pred_labels = [int(np.argmax(row)) for row in outputs.numpy()]

        for true_label, pred_label in zip(true_labels, pred_labels):
            if not skip_confusion_matrix:
                confusion_matrix[true_label][pred_label] += 1

            rounded_error = abs(pred_label - true_label)
            mae_rounded_sum += rounded_error
            mse_rounded_sum += rounded_error * rounded_error

            number_of_frames_with_n_persons[pred_label] = \
                number_of_frames_with_n_persons.get(pred_label, 0) + 1

            if true_label == pred_label:
                correct_count += 1
                number_of_frames_with_n_persons_predicted_correctly[pred_label] = \
                    number_of_frames_with_n_persons_predicted_correctly.get(pred_label, 0) + 1

            tested_frames += 1

            vec_real_number_of_persons.append(true_label)
            vec_predicted_number_of_persons.append(pred_label)

    # Fail clearly instead of with a ZeroDivisionError when the loader is empty.
    if tested_frames == 0:
        raise ValueError('loader did not yield any frames to validate')

    mae_rounded = mae_rounded_sum / tested_frames
    mse_rounded = mse_rounded_sum / tested_frames

    model_accuracy = correct_count / tested_frames

    # steps >= 1 here, because at least one batch produced frames.
    total_loss = running_loss / steps

    print(f"Number of tested frames: {tested_frames}")
    print(f"Model Accuracy = {model_accuracy}")
    print('Predicted:\n' + '\n'.join([f'   {count} frames with {no} persons' for no, count in number_of_frames_with_n_persons.items()]))
    print('Predicted correctly:\n' + '\n'.join([f'   {count} frames with {no} persons' for no, count in number_of_frames_with_n_persons_predicted_correctly.items()]))
    print(f'mae_rounded: {mae_rounded}')
    print(f'mse_rounded: {mse_rounded}')
    print(f'total_loss: {total_loss}')

    return model_accuracy, mae_rounded, vec_real_number_of_persons, vec_predicted_number_of_persons, confusion_matrix, total_loss



In [None]:
# Train the FNN baseline and keep the snapshot with the lowest validation MAE.
# NOTE: this reuses the names train_loss / valid_loss / mae_vec / accuracy_vec and
# therefore overwrites the U-Net training history plotted earlier in the notebook.
optimizer = optim.SGD(model_fnn.parameters(), lr=0.0001, momentum=0.9)
time0 = time.time()
epochs = 200
# CrossEntropyLoss takes raw logits, which is what FNNClassifier outputs (no final softmax).
criterion = nn.CrossEntropyLoss()

train_loss = []
valid_loss = []
mae_vec = []
accuracy_vec = []

best_mae = None
best_mae_model = None

for e in range(epochs):
    running_loss = 0
    step = 0

    model_fnn.train()
    for images, labels in trainloader_fnn:
        optimizer.zero_grad()

        output = model_fnn(images.to(device))
        loss = criterion(output, labels.to(device))

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        step += 1

    epoch_loss = running_loss / step
    train_loss.append(epoch_loss)
    print("Epoch {} - Training loss: {}".format(e, epoch_loss))
    print("\nTraining Time (in minutes) =", (time.time() - time0) / 60)

    model_fnn.eval()
    accuracy, mae, _, _, _, loss = validate_model_with_real_number_of_persons_fnn(
        loader=valloader_fnn, model=model_fnn, skip_confusion_matrix=True, loss_fn=criterion)

    if best_mae is None or mae < best_mae:
        print('New best MAE model!')
        best_mae = mae
        # Deep copy: `model_fnn` keeps training in later epochs.
        best_mae_model = copy.deepcopy(model_fnn)

    valid_loss.append(loss)
    mae_vec.append(mae)
    accuracy_vec.append(accuracy)

# Continue with the best snapshot; with zero epochs there is none, so keep the current model.
if best_mae_model is not None:
    model_fnn = best_mae_model
        

In [None]:
# Evaluate the best FNN on the test frames with the FNN validator (argmax over the
# class scores). `validate_model_with_real_number_of_persons` is the validator used
# with the U-Net above, not with this classifier.
_, _, _, _, confusion_matrix_fnn, _ = validate_model_with_real_number_of_persons_fnn(
    loader=testloader_fnn, model=model_fnn, skip_confusion_matrix=False)
confusion_matrix_fnn