In [None]:
import os
import sys
import time 
import random
from random import randrange
import json


import numpy as np
import cv2
from matplotlib import pyplot as plt
import importlib
import torch
import torchvision

from torchvision import datasets, transforms
from torch import nn, optim
import torch.nn.functional as F
from torch import nn


In [None]:
# Run on the GPU when one is available; keep an explicit CPU handle for moving results back.
cpu_device = torch.device('cpu')
device = torch.device('cuda') if torch.cuda.is_available() else cpu_device
device

In [None]:
# Seed every RNG the notebook uses so that runs are reproducible.
seed = 43

torch.backends.cudnn.deterministic = True
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)  # same seed as the other RNGs (a hard-coded 0 here would make `seed` only partially effective)
# NOTE(review): on CUDA some ops have no deterministic implementation and will raise with this enabled.
torch.use_deterministic_algorithms(True)

def seed_worker(worker_id):
    """DataLoader `worker_init_fn`: seed numpy and `random` from torch's per-worker seed.

    :param worker_id: index of the worker (unused; required by the DataLoader API)
    """
    worker_seed = torch.initial_seed() % 2**32  # numpy seeds must fit in 32 bits
    np.random.seed(worker_seed)  # the module is imported as `np`; `numpy` is not a defined name
    random.seed(worker_seed)

In [None]:
# Data locations. Override them with environment variables to run on another machine.
# On Google Colab: `from google.colab import drive; drive.mount('/content/drive/')`, then point
# the variables at the mounted folder.
ROOT_DATA_DIR_PATH = os.environ.get(
    'THERMO_PRESENCE_DATA_DIR',
    '/media/data/temporary/thermo-presence/')
OUTPUT_LABELS_DIR = os.environ.get(
    'THERMO_PRESENCE_LABELS_DIR',
    '/home/przemek/Projects/thermo-presence/thermo-presence/data_processing/data_labeling/labeling_output')

In [None]:
# IR camera sensor size, in pixels
IR_CAMERA_RESOLUTION_X = 32
IR_CAMERA_RESOLUTION_Y = 24

IR_CAMERA_RESOLUTION = (IR_CAMERA_RESOLUTION_Y, IR_CAMERA_RESOLUTION_X)  # (rows, cols) - numpy order
IR_CAMERA_RESOLUTION_XY = IR_CAMERA_RESOLUTION[::-1]  # (width, height) - for opencv functions

# Temperature range mapped onto [0, 1] when normalizing frames
TEMPERATURE_NORMALIZATION__MIN = 20
TEMPERATURE_NORMALIZATION__MAX = 35

# Scale between raw IR pixel coordinates and coordinates on the resized (interpolated) image
IR_FRAME_RESIZE_MULTIPLIER = 1

# Colour range used when plotting (set to None for auto range)
MIN_TEMPERATURE_ON_PLOT = 20
MAX_TEMPERATURE_ON_PLOT = 30
IR_FRAME_INTERPOLATION_METHOD = cv2.INTER_CUBIC
RGB_FRAME_RESIZE_MULTIPLIER = 2

In [None]:
# Weights of the pretrained encoder (the file name suggests CIFAR-10 pretraining).
# Colab location: '/content/drive/My Drive/wspoldzielone/ir/cifar10_encder_weights_v1.p'
# Set THERMO_PRESENCE_ENCODER_WEIGHTS to use a different file.
pretrained_encoder_file_path = os.environ.get(
    'THERMO_PRESENCE_ENCODER_WEIGHTS',
    '/home/przemek/Downloads/cifar10_encoder_weights_v1.p')


In [None]:
class IrDataCsvReader:
    """Loads every IR frame of a camera CSV file into memory.

    The first line of the file is skipped (treated as a header). Every following non-empty line
    is split on commas; the first two columns are ignored, and the remaining values are parsed as
    floats and reshaped to one 2D frame.
    """
    def __init__(self, file_path, resolution=None):
        """
        :param file_path: path of the CSV file to read
        :param resolution: (rows, cols) of one frame; defaults to IR_CAMERA_RESOLUTION
        """
        self._file_path = file_path
        self._frames, self._raw_frame_data = self.get_frames_from_file(file_path, resolution)

    @staticmethod
    def get_frames_from_file(file_path, resolution=None):
        """Parse the CSV file.

        :param file_path: path of the CSV file to read
        :param resolution: (rows, cols) of one frame; defaults to IR_CAMERA_RESOLUTION
        :return: (frames, raw_frame_data) - a list of 2D numpy arrays, and the matching stripped text lines
        """
        if resolution is None:
            resolution = IR_CAMERA_RESOLUTION
        frames = []
        raw_frame_data = []
        with open(file_path, 'r') as file:
            next(file, None)  # skip the header line
            for raw_line in file:
                frame_line = raw_line.strip()
                if not frame_line:
                    continue  # tolerate blank lines (e.g. trailing newlines), which cannot be reshaped into a frame
                raw_frame_data.append(frame_line)
                # NOTE(review): the first two columns are presumably frame index / timestamp - confirm
                frame_data_1d = [float(x) for x in frame_line.split(',')[2:]]
                frames.append(np.reshape(frame_data_1d, resolution))
        return frames, raw_frame_data

    def get_number_of_frames(self):
        return len(self._frames)

    def get_frame(self, n):
        return self._frames[n]

    def get_raw_frame_data(self, n) -> str:
        return self._raw_frame_data[n]



class BatchTrainingData:
    """
    Stores training data for one batch: per-frame person centre points, raw IR frames, and the
    same frames rescaled linearly so that min_temperature maps to 0 and max_temperature to 1.
    """
    def __init__(self, min_temperature=TEMPERATURE_NORMALIZATION__MIN, max_temperature=TEMPERATURE_NORMALIZATION__MAX):
        self.centre_points = []  # type: List[List[tuple]]  # per frame: (x, y) of every labelled person
        self.raw_ir_data = []  # type: List[np.ndarray]
        self.normalized_ir_data = []  # type: List[np.ndarray]  # same data as raw_ir_data, but normalized

        self.min_temperature = min_temperature
        self.max_temperature = max_temperature

    def append_frame_data(self, centre_points, raw_ir_data):
        """Add one frame.

        `centre_points` is copied, so later flips never modify the caller's list
        (e.g. the list owned by a FrameAnnotation).
        """
        self.centre_points.append(list(centre_points))
        self.raw_ir_data.append(raw_ir_data)

        ir_data_normalized = (raw_ir_data - self.min_temperature) * (1 / (self.max_temperature - self.min_temperature))
        self.normalized_ir_data.append(ir_data_normalized)

    def flip_horizontally(self):
        """Mirror every frame left-right and map each point's x to IR_CAMERA_RESOLUTION_X - x."""
        self.raw_ir_data = [np.flip(frame, 1) for frame in self.raw_ir_data]
        self.normalized_ir_data = [np.flip(frame, 1) for frame in self.normalized_ir_data]
        self.centre_points = [[(IR_CAMERA_RESOLUTION_X - point[0], point[1]) for point in points]
                              for points in self.centre_points]

    def flip_vertically(self):
        """Mirror every frame top-bottom and map each point's y to IR_CAMERA_RESOLUTION_Y - y."""
        self.raw_ir_data = [np.flip(frame, 0) for frame in self.raw_ir_data]
        self.normalized_ir_data = [np.flip(frame, 0) for frame in self.normalized_ir_data]
        self.centre_points = [[(point[0], IR_CAMERA_RESOLUTION_Y - point[1]) for point in points]
                              for points in self.centre_points]


class AugmentedBatchesTrainingData:
    """
    Stores training data for all batches, with data augmentation
    (flipped variants of every added batch, when requested).
    """
    def __init__(self):
        self.batches = []  # type: List[BatchTrainingData]

    def add_training_batch(self, batch: 'BatchTrainingData', flip_and_rotate=True):
        """Append a deep copy of `batch` and, if `flip_and_rotate`, three flipped variants.

        Appended in order: original, flipped horizontally, rotated 180 degrees (flipped both ways),
        flipped vertically. `batch` itself is left unmodified.
        """
        self.batches.append(copy.deepcopy(batch))  # plain data
        if not flip_and_rotate:
            return

        # Flip a private copy so the caller's batch is left untouched.
        augmented = copy.deepcopy(batch)

        augmented.flip_horizontally()
        self.batches.append(copy.deepcopy(augmented))  # flipped horizontally

        augmented.flip_vertically()
        self.batches.append(copy.deepcopy(augmented))  # rotated 180 degrees

        augmented.flip_horizontally()
        self.batches.append(augmented)  # net effect: flipped vertically only

    def print_stats(self):
        """Print the number of batches and how many frames contain each number of persons."""
        total_number_of_frames = 0
        number_of_frames_with_n_persons = {}
        for batch in self.batches:
            total_number_of_frames += len(batch.centre_points)
            for centre_points in batch.centre_points:
                number_of_persons = len(centre_points)
                number_of_frames_with_n_persons[number_of_persons] = \
                    number_of_frames_with_n_persons.get(number_of_persons, 0) + 1
        # Sorted by number of persons so the summary is easy to scan.
        frames_persons_details_msg = '\n'.join(f'   {count} frames with {no} persons'
                                               for no, count in sorted(number_of_frames_with_n_persons.items()))
        msg = f"AugmentedBatchesTrainingData with {len(self.batches)} BatchTrainingData batches.\n" \
              f"Total number of frames after augmentation: {total_number_of_frames}, with:\n" \
              f"{frames_persons_details_msg}"
        print(msg)


def load_data_for_labeled_batches(labeled_batch_dirs) -> BatchTrainingData:
    """Load IR frames together with their accepted annotations for the given recording batches.

    For every batch sub-directory (relative to ROOT_DATA_DIR_PATH), this reads `ir.csv` and the
    matching label file from OUTPUT_LABELS_DIR. The label file is named after the sub-directory
    path, with '/' replaced by '--' and '.csv' appended. Only frames whose annotation was accepted
    are kept; the rest are reported and skipped.

    :param labeled_batch_dirs: iterable of batch sub-directory paths
    :return: one BatchTrainingData holding the kept frames of all batches, in order
    """
    training_data = BatchTrainingData()
    for batch_subdir in labeled_batch_dirs:
        raw_ir_data_csv_file_path = os.path.join(ROOT_DATA_DIR_PATH, batch_subdir, 'ir.csv')
        output_file_with_labels_name = batch_subdir.replace('/', '--') + '.csv'
        annotation_data_file_path = os.path.join(OUTPUT_LABELS_DIR, output_file_with_labels_name)

        raw_ir_data_csv_reader = IrDataCsvReader(file_path=raw_ir_data_csv_file_path)
        # Label coordinates are used exactly as stored in the file (no rescaling / mirroring).
        annotations_collector = AnnotationCollector.load_from_file(
            file_path=annotation_data_file_path, do_not_scale_and_reverse=True)

        for frame_index in range(raw_ir_data_csv_reader.get_number_of_frames()):
            frame_annotations = annotations_collector.get_annotation(frame_index)
            if not frame_annotations.accepted:
                # Distinguish frames deliberately discarded during labeling from frames never labeled.
                if frame_annotations.discarded:
                    print(f"Frame index {frame_index} from batch '{batch_subdir}' discarded during labeling")
                else:
                    print(f"Frame index {frame_index} from batch '{batch_subdir}' not annotated!")
                continue
            training_data.append_frame_data(
                centre_points=frame_annotations.centre_points,
                raw_ir_data=raw_ir_data_csv_reader.get_frame(frame_index))

    return training_data



class AnnotationCollector:
    """Holds per-frame annotations (keyed by IR frame index) and persists them as JSON.

    Saving is automatic: the first finalized (accepted or discarded) annotation triggers a save,
    and after that every ANNOTATIONS_BETWEEN_AUTOSAVE-th finalized annotation does.
    """
    ANNOTATIONS_BETWEEN_AUTOSAVE = 10

    def __init__(self, output_file_path, data_batch_dir_path):
        self._output_file_path = output_file_path
        self._data_batch_dir_path = data_batch_dir_path
        self._annotations = {}  # ir_frame_index -> FrameAnnotation
        self._annotations_to_autosave = 1  # countdown; starts at 1 so the first finalized annotation is saved

    def get_annotation(self, ir_frame_index):
        """Return the stored annotation for the frame, or a fresh empty FrameAnnotation."""
        return self._annotations.get(ir_frame_index, FrameAnnotation())

    def set_annotation(self, ir_frame_index, annotation):
        """Store `annotation` and autosave when the countdown of finalized annotations reaches zero."""
        self._annotations[ir_frame_index] = annotation
        if not (annotation.accepted or annotation.discarded):
            return
        self._annotations_to_autosave -= 1
        if self._annotations_to_autosave != 0:
            return
        self._annotations_to_autosave = self.ANNOTATIONS_BETWEEN_AUTOSAVE
        self.save()

    def save(self):
        """Write all annotations to the output file as indented JSON."""
        serialized_annotations = {frame_index: frame_annotation.as_dict()
                                  for frame_index, frame_annotation in self._annotations.items()}
        payload = {
            'output_file_path': self._output_file_path,
            'data_batch_dir_path': self._data_batch_dir_path,
            'annotations': serialized_annotations,
        }
        with open(self._output_file_path, 'w') as output_file:
            json.dump(payload, output_file, indent=2)

    @classmethod
    def load_from_file(cls, file_path, do_not_scale_and_reverse=False):
        """Rebuild a collector from a JSON file written by save().

        JSON object keys are strings, so frame indices are converted back to int.
        """
        loaded = cls(output_file_path=file_path, data_batch_dir_path=None)
        with open(file_path, 'r') as input_file:
            payload = json.load(input_file)
        loaded._data_batch_dir_path = payload['data_batch_dir_path']
        loaded._annotations = {
            int(frame_index): FrameAnnotation.from_dict(annotation_dict, do_not_scale_and_reverse)
            for frame_index, annotation_dict in payload['annotations'].items()
        }
        return loaded








import copy
from typing import List
from typing import Tuple

# Coordinate conversion helpers (x_on_interpolated_image_to_raw_x, y_on_interpolated_image_to_raw_y,
# xy_on_interpolated_image_to_raw_xy) are defined once, next to their inverse functions, further
# down in this cell. FrameAnnotation only calls them at runtime, after the whole cell has executed.

class FrameAnnotation:
    """Labels attached to a single IR frame."""
    def __init__(self):
        self.accepted = False  # whether frame was marked as annotated successfully
        self.discarded = False  # whether frame was marked as discarded (ignored)
        self.centre_points = []  # type: List[tuple]  # x, y
        self.rectangles = []  # type: List[Tuple[tuple, tuple]]  # (x_left, y_top), (x_right, y_bottom)

        self.raw_frame_data = None  # Not an annotation, but write it to the result file, just in case

    def as_dict(self):
        """Serialize to a JSON-friendly dict, converting coordinates to raw IR image space."""
        data_dict = copy.copy(self.__dict__)
        data_dict['centre_points'] = [xy_on_interpolated_image_to_raw_xy(point) for point in self.centre_points]
        data_dict['rectangles'] = [(xy_on_interpolated_image_to_raw_xy(rectangle[0]),
                                    xy_on_interpolated_image_to_raw_xy(rectangle[1]))
                                   for rectangle in self.rectangles]
        return data_dict

    @classmethod
    def from_dict(cls, data_dict, do_not_scale_and_reverse=False):
        """Build an annotation from a dict produced by as_dict().

        :param data_dict: serialized annotation; it is not modified
        :param do_not_scale_and_reverse: if True, keep coordinates as stored instead of converting
            them to interpolated-image space
        """
        item = cls()
        item.__dict__.update(data_dict)
        # Build new lists: rewriting item.centre_points in place would also rewrite
        # data_dict['centre_points'], since update() shares the list objects.
        item.centre_points = [xy_on_raw_image_to_xy_on_interpolated_image(point, do_not_scale_and_reverse)
                              for point in item.centre_points]
        item.rectangles = [(xy_on_raw_image_to_xy_on_interpolated_image(rectangle[0], do_not_scale_and_reverse),
                            xy_on_raw_image_to_xy_on_interpolated_image(rectangle[1], do_not_scale_and_reverse))
                           for rectangle in item.rectangles]
        return item




def x_on_interpolated_image_to_raw_x(x):
    """Map an x coordinate on the resized (interpolated) image back to raw IR pixel units.

    The interpolated image is mirrored horizontally, so the unscaled value is mirrored back
    around IR_CAMERA_RESOLUTION_X.
    """
    return IR_CAMERA_RESOLUTION_X - x / IR_FRAME_RESIZE_MULTIPLIER


def y_on_interpolated_image_to_raw_y(y):
    """Map a y coordinate on the resized (interpolated) image back to raw IR pixel units."""
    return y / IR_FRAME_RESIZE_MULTIPLIER


def xy_on_interpolated_image_to_raw_xy(xy: tuple) -> tuple:
    """Apply the x and y conversions above to an (x, y) pair."""
    x, y = xy[0], xy[1]
    return x_on_interpolated_image_to_raw_x(x), y_on_interpolated_image_to_raw_y(y)


def x_on_raw_image_to_x_on_interpolated_image(x):
    """Inverse of x_on_interpolated_image_to_raw_x: mirror, scale and round to a whole pixel."""
    return round((IR_CAMERA_RESOLUTION_X - x) * IR_FRAME_RESIZE_MULTIPLIER)


def y_on_raw_image_to_y_on_interpolated_image(y):
    """Inverse of y_on_interpolated_image_to_raw_y: scale and round to a whole pixel."""
    return round(IR_FRAME_RESIZE_MULTIPLIER * y)


def xy_on_raw_image_to_xy_on_interpolated_image(xy: tuple, do_not_scale_and_reverse=False) -> tuple:
    """Convert a raw (x, y) pair to interpolated-image coordinates.

    With `do_not_scale_and_reverse`, the input is returned unchanged.
    """
    if not do_not_scale_and_reverse:
        return (x_on_raw_image_to_x_on_interpolated_image(xy[0]),
                y_on_raw_image_to_y_on_interpolated_image(xy[1]))
    return xy


def get_extrapolated_ir_frame_heatmap_flipped(
        frame_2d, multiplier, interpolation, min_temp, max_temp, colormap):
    """Render an IR frame as an upscaled, horizontally mirrored BGR uint8 heatmap.

    :param frame_2d: 2D array of temperatures
    :param multiplier: integer upscaling factor applied to both axes
    :param interpolation: OpenCV interpolation flag passed to cv2.resize
    :param min_temp: bottom of the colour range; None uses the resized frame's minimum
    :param max_temp: top of the colour range; None uses the resized frame's maximum
    :param colormap: callable mapping a uint8 image to RGBA floats in [0, 1]
        (presumably a matplotlib colormap - confirm at call sites)
    :return: uint8 BGR image, mirrored left-right
    """
    new_size = (frame_2d.shape[1] * multiplier, frame_2d.shape[0] * multiplier)  # cv2 expects (width, height)
    frame_resized_not_clipped = cv2.resize(
        src=frame_2d, dsize=new_size, interpolation=interpolation)

    if min_temp is None:
        min_temp = frame_resized_not_clipped.min()
    if max_temp is None:
        max_temp = frame_resized_not_clipped.max()

    frame_resized = np.clip(frame_resized_not_clipped, min_temp, max_temp)
    temperature_span = max_temp - min_temp
    # A flat frame (zero span) would divide by zero; map it to the bottom of the colour scale instead.
    scale = 255 / temperature_span if temperature_span > 0 else 0.0
    frame_resized_normalized_u8 = ((frame_resized - min_temp) * scale).astype(np.uint8)
    # The colormap returns floats in [0, 1]. Scale by 255, not 256: 1.0 * 256 overflows uint8 and wraps to 0.
    heatmap_u8 = (colormap(frame_resized_normalized_u8)[:, :, :3] * 255).astype(np.uint8)
    heatmap_u8_bgr = cv2.cvtColor(heatmap_u8, cv2.COLOR_RGB2BGR)
    return cv2.flip(heatmap_u8_bgr, 1)  # horizontal flip


In [None]:
# Recording batches (sub-directories of ROOT_DATA_DIR_PATH) used for training and validation.
TRAINING_DIRS_1 = [
    '31_03_21__318__3or4_people/1/006__11_44_59',
    '31_03_21__318__3or4_people/1/007__11_48_59',
    '31_03_21__318__3or4_people/1/008__11_52_59',
    '31_03_21__318__3or4_people/1/009__11_57_00',

    '31_03_21__318__3or4_people/2/000__14_15_19',
    '31_03_21__318__3or4_people/2/001__14_19_19',
    '31_03_21__318__3or4_people/2/002__14_23_19',
    '31_03_21__318__3or4_people/2/003__14_27_20',
    '31_03_21__318__3or4_people/2/004__14_31_20',

    '31_03_21__318__3or4_people/2/012__15_03_21',
    '31_03_21__318__3or4_people/2/013__15_07_21',
    '31_03_21__318__3or4_people/2/014__15_11_21',
    '31_03_21__318__3or4_people/2/015__15_15_21',
    '31_03_21__318__3or4_people/2/016__15_19_21',

    '05_05_2021__0to5_people/011__13_38_20',
    '05_05_2021__0to5_people/012__13_42_20',
    '05_05_2021__0to5_people/013__13_46_21',

    '05_05_2021__0to5_people/007__13_22_20',
    '05_05_2021__0to5_people/008__13_26_20',
    ]


VALIDATION_DIRS_1 = [
    '05_05_2021__0to5_people/004__13_10_20',

    '05_05_2021__0to5_people/014__13_50_21',
    '05_05_2021__0to5_people/015__13_54_21',

    '31_03_21__318__3or4_people/2/005__14_35_20',
    '31_03_21__318__3or4_people/2/006__14_39_20',
    '31_03_21__318__3or4_people/2/007__14_43_20',
    '31_03_21__318__3or4_people/2/008__14_47_20',
    '31_03_21__318__3or4_people/2/009__14_51_20',
    '31_03_21__318__3or4_people/2/010__14_55_20',
    '31_03_21__318__3or4_people/2/011__14_59_20',

]

# Guard against data leakage: no recording may be used for both training and validation.
assert not set(TRAINING_DIRS_1) & set(VALIDATION_DIRS_1), 'training and validation batch directories overlap'

# Load the labelled recordings; only the training set is augmented with flipped copies.
_training_data_1 = load_data_for_labeled_batches(labeled_batch_dirs=TRAINING_DIRS_1)
_validation_data_1 = load_data_for_labeled_batches(labeled_batch_dirs=VALIDATION_DIRS_1)

augmented_data_training = AugmentedBatchesTrainingData()
augmented_data_validation = AugmentedBatchesTrainingData()

augmented_data_training.add_training_batch(_training_data_1, flip_and_rotate=True)
augmented_data_validation.add_training_batch(_validation_data_1, flip_and_rotate=False)

In [None]:
# Summary of the augmented training set: batch count and number of frames per person count.
augmented_data_training.print_stats()

def draw_airbrush_circle(img, centre, radius):
    """Add a cone-shaped blob to `img` in place: +1 at `centre`, falling linearly to 0 at `radius`.

    `centre` is indexed as img[centre[0], centre[1]] and may be fractional. Pixels outside
    `radius` are untouched. A non-positive radius draws nothing (the falloff would divide by zero).
    """
    if radius <= 0:
        return
    for x in range(max(0, round(centre[0] - radius)), min(img.shape[0], round(centre[0] + radius + 1))):
        for y in range(max(0, round(centre[1] - radius)), min(img.shape[1], round(centre[1] + radius + 1))):
            distance_to_centre = np.hypot(centre[0] - x, centre[1] - y)
            if distance_to_centre <= radius:
                img[x, y] += 1 - distance_to_centre / radius
            

def draw_cross(img, centre, cross_width, cross_height):
    """Set a '+'-shaped region of `img` to 1 in place.

    The centre is rounded to the nearest pixel once, so both arms are aligned on the same pixel.
    One arm extends +-cross_width along the first axis and +-cross_height along the second;
    the other arm swaps the two extents. Parts outside the image are clipped.
    """
    centre_x, centre_y = round(centre[0]), round(centre[1])

    def _clipped_span(c, extent, size):
        # Half-open [lo, hi) index range around c, clipped to the image and never negative in length.
        lo = max(0, c - extent)
        hi = min(size, c + extent + 1)
        return lo, max(lo, hi)

    x_lo, x_hi = _clipped_span(centre_x, cross_width, img.shape[0])
    y_lo, y_hi = _clipped_span(centre_y, cross_height, img.shape[1])
    img[x_lo:x_hi, y_lo:y_hi] = 1

    x_lo, x_hi = _clipped_span(centre_x, cross_height, img.shape[0])
    y_lo, y_hi = _clipped_span(centre_y, cross_width, img.shape[1])
    img[x_lo:x_hi, y_lo:y_hi] = 1
            

def gauss_1d(x, sig):
    """Unnormalized Gaussian exp(-x**2 / (2 * sig**2)); works element-wise on numpy arrays."""
    denominator = 2 * np.power(sig, 2.)
    return np.exp(-np.power(x, 2.) / denominator)
    
    
def draw_gauss(img, sig, centre):
    """Add an unnormalized 2D Gaussian blob to `img` in place.

    Every pixel within 3*sig of `centre` gets exp(-d**2 / (2*sig**2)) added, where d is its
    distance to `centre`; pixels further away are untouched. `centre` is indexed as
    img[centre[0], centre[1]] and may be fractional. Assumes `img` is a 2D float array.
    A non-positive `sig` draws nothing (the Gaussian would divide by zero).
    """
    if sig <= 0:
        return
    radius = 3 * sig
    x_lo = max(0, round(centre[0] - radius))
    x_hi = min(img.shape[0], round(centre[0] + radius + 1))
    y_lo = max(0, round(centre[1] - radius))
    y_hi = min(img.shape[1], round(centre[1] + radius + 1))
    if x_lo >= x_hi or y_lo >= y_hi:
        return  # the blob lies completely outside the image

    # Vectorized over the bounding box instead of a per-pixel Python loop.
    xs = np.arange(x_lo, x_hi)[:, np.newaxis]
    ys = np.arange(y_lo, y_hi)[np.newaxis, :]
    distance_to_centre = np.hypot(centre[0] - xs, centre[1] - ys)
    gauss_values = np.exp(-np.power(distance_to_centre, 2.) / (2 * np.power(sig, 2.)))
    img[x_lo:x_hi, y_lo:y_hi] += np.where(distance_to_centre > radius, 0.0, gauss_values)
    
    
def get_img_reconstructed_from_labels(centre_points, sig=3):
    """Build the regression target for one frame: one Gaussian blob per labelled person.

    :param centre_points: iterable of (x, y) person centres in raw IR pixel coordinates
    :param sig: Gaussian sigma in pixels. The per-person pixel sum used for counting
        (sum_of_values_for_one_person) was measured with sig=3.
    :return: float array of shape IR_CAMERA_RESOLUTION (rows, cols)
    """
    img_reconstructed = np.zeros(shape=IR_CAMERA_RESOLUTION)

    for centre_point in centre_points:
        # Labels are (x, y), but the image is indexed [row, col] == [y, x], so swap them.
        row_col = list(centre_point[::-1])
        draw_gauss(img=img_reconstructed, centre=row_col, sig=sig)

    return img_reconstructed



# Visual sanity check: reconstruct the target image for one validation frame.
# 777 is an arbitrary frame; clamp it so a smaller validation set still works.
example_frame_index = min(777, len(_validation_data_1.centre_points) - 1)
cp = _validation_data_1.centre_points[example_frame_index]
img = get_img_reconstructed_from_labels(cp)
plt.imshow(img)
print(np.sum(img))

# for sig == 3, it is between 45 - 55, depending whether people are on the edge
sum_of_values_for_one_person = 51.35  # total sum of pixels for one person on the reconstructed image, of course it changes with circle radius, etc. Calculated as average from all training data








class IrPersonsUnetTrainDataset(torch.utils.data.Dataset):
    """Dataset of (normalized IR frame, Gaussian target image) pairs for U-Net training.

    Items are built lazily on first access and cached in memory.
    """
    def __init__(self, augmented_data: 'AugmentedBatchesTrainingData', transform=None):
        self.augmented_data = augmented_data  # keep the instance that was passed in, not its class
        self.transform = transform  # NOTE(review): stored but not applied in __getitem__
        self._cache = {}
        # Flat dataset index -> (batch, frame index within that batch)
        self._index_to_batch_and_subindex_map = [
            (batch, j)
            for batch in augmented_data.batches
            for j in range(len(batch.raw_ir_data))
        ]

    def __len__(self):
        return len(self._index_to_batch_and_subindex_map)

    def __getitem__(self, idx):
        """Return (frame with a leading channel axis, 2D target image) for item `idx`."""
        if torch.is_tensor(idx):
            raise Exception("Not supported")

        if idx not in self._cache:
            batch, subindex = self._index_to_batch_and_subindex_map[idx]
            frame = batch.normalized_ir_data[subindex][np.newaxis, :, :]
            target_img = get_img_reconstructed_from_labels(batch.centre_points[subindex])
            self._cache[idx] = (frame, target_img)

        return self._cache[idx]

    def get_number_of_persons_for_frame(self, idx):
        """Number of labelled persons in item `idx`."""
        batch, subindex = self._index_to_batch_and_subindex_map[idx]
        return len(batch.centre_points[subindex])
        
    
    
    
training_dataset = IrPersonsUnetTrainDataset(augmented_data_training)
validation_dataset = IrPersonsUnetTrainDataset(augmented_data_validation)

# Validation uses separate recordings rather than a random split of all frames: consecutive
# frames are almost identical, so a random split would put near-duplicates in both sets.

# NOTE: worker_init_fn only runs when num_workers > 0 (the default here is 0).
trainloader = torch.utils.data.DataLoader(training_dataset, batch_size=16, shuffle=True, worker_init_fn=seed_worker)
valloader = torch.utils.data.DataLoader(validation_dataset, batch_size=16, shuffle=True, worker_init_fn=seed_worker)

print(len(trainloader))
print(len(valloader))

# Inspect one batch: input frames and their target images.
xb, yb = next(iter(trainloader))
print(xb.shape, yb.shape)  # a bare expression in the middle of a cell is never displayed

plt.imshow(yb[0].numpy().squeeze())

print(f'number of persons based on sum: {np.sum(yb[0].numpy()) / sum_of_values_for_one_person:.3f}')









In [None]:
# Calculate sum of pixels for one person on the image

# total_sum_of_pixels = 0
# total_number_of_people = 0

# for i in range(len(training_dataset)):
#     reconstructed_frame = training_dataset[i][1]
    
#     sum_of_pixels = np.sum(reconstructed_frame)
#     numbe_of_people = training_dataset.get_number_of_persons_for_frame(i)
    
#     plt.imshow(reconstructed_frame)

#     total_sum_of_pixels += sum_of_pixels
#     total_number_of_people += numbe_of_people


# average_pixels_per_person = total_sum_of_pixels / total_number_of_people
# print(f'average_pixels_per_person={average_pixels_per_person}')  # 51.35

In [None]:
def validate_model(loader, model, value_per_person=None, eval_device=None):
    """Compare person counts estimated from model outputs with counts estimated from target images.

    A person count is an image's pixel sum divided by the pixel sum of one person. The "true"
    count is derived the same way from the target image, so it only approximates the real
    number of people.

    :param loader: iterable of (frames, target_images) batches, e.g. a DataLoader
    :param model: torch module returning one predicted image per input frame
    :param value_per_person: pixel sum of one person; defaults to the global sum_of_values_for_one_person
    :param eval_device: device the model runs on; defaults to the global `device`
    :return: (accuracy, average absolute error in persons); both NaN when the loader yields no frames
    """
    if value_per_person is None:
        value_per_person = sum_of_values_for_one_person
    if eval_device is None:
        eval_device = device

    correct_count = 0
    tested_frames = 0
    number_of_frames_with_n_persons = {}  # keyed by the *predicted* number of persons
    number_of_frames_with_n_persons_predicted_correctly = {}
    persons_error_sum = 0

    # Evaluate in eval mode so BatchNorm uses (and does not update) its running statistics;
    # restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        for frame, labels in loader:
            with torch.no_grad():
                outputs = model(frame.to(eval_device)).cpu()
            labels_np = labels.numpy()

            for i in range(len(labels)):
                pred_people = np.sum(outputs[i].numpy()) / value_per_person
                pred_label = round(pred_people)
                # Not entirely true, but good enough for testing; the real number of people would be needed.
                true_label = round(np.sum(labels_np[i]) / value_per_person)

                persons_error_sum += abs(pred_people - true_label)
                number_of_frames_with_n_persons[pred_label] = \
                    number_of_frames_with_n_persons.get(pred_label, 0) + 1
                if true_label == pred_label:
                    correct_count += 1
                    number_of_frames_with_n_persons_predicted_correctly[pred_label] = \
                        number_of_frames_with_n_persons_predicted_correctly.get(pred_label, 0) + 1
                tested_frames += 1
    finally:
        model.train(was_training)

    print(f"Number of tested frames: {tested_frames}")
    if tested_frames == 0:
        return float('nan'), float('nan')  # nothing to average; avoid ZeroDivisionError

    average_prediction_error = persons_error_sum / tested_frames
    model_accuracy = correct_count / tested_frames

    print(f"Model Accuracy = {model_accuracy}")
    print('Predicted:\n' + '\n'.join([f'   {count} frames with {no} persons' for no, count in number_of_frames_with_n_persons.items()]))
    print('Predicted correctly:\n' + '\n'.join([f'   {count} frames with {no} persons' for no, count in number_of_frames_with_n_persons_predicted_correctly.items()]))
    print(f'average_prediction_error: {average_prediction_error}')

    return model_accuracy, average_prediction_error


In [None]:
# Sanity check: fetch one training batch and move its inputs to the compute device.
xb, yb = next(iter(trainloader))
xb = xb.to(device=device)
xb.shape

In [None]:
from typing import Tuple
import torch
from torch import nn
class AutoEncoder(torch.nn.Module):
    """Small U-Net.

    Structure: a two-stage encoder, a DoubleConv bottleneck, two expand blocks fed by the encoder's
    skip features, and a 1x1 output convolution. forward() returns only output channel 0, so the
    result has no channel axis.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.encoder = Encoder(in_channels)
        self.conv = DoubleConv(64, 128, 3, 1)  # bottleneck
        self.upconv1 = ExpandBlock(128, 64, 3, 1)
        self.upconv2 = ExpandBlock(64, 32, 3, 1)
        self.out_conv = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        bottom, skip2, skip1 = self.encoder(x)  # contracting path
        decoded = self.upconv2(self.upconv1(self.conv(bottom), skip2), skip1)
        return self.out_conv(decoded)[:, 0, :, :]  # keep channel 0 only, dropping the channel axis


class ExpandBlock(nn.Module):
    """Decoder stage: upsample by 2, concatenate skip features, then two Conv-BN-ReLU layers.

    Derives from nn.Module rather than nn.Sequential: forward() takes two inputs and does not
    simply chain the children, so Sequential semantics (ordered iteration / calling) would be
    misleading. Child names, and therefore state_dict keys, are unchanged.
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int):
        super().__init__()
        # stride=2 with padding=1 and output_padding=1 doubles the spatial size exactly.
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels, in_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        # Input channels after concatenation: in_channels // 2 upsampled + skip channels
        # (equal to in_channels when the skip has in_channels // 2 channels).
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor, encoder_features: torch.Tensor) -> torch.Tensor:
        x = self.conv_transpose(x)
        x = torch.cat((x, encoder_features), dim=1)  # skip connection along the channel axis
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x


class Encoder(nn.Module):
    """Two-stage contracting path of the U-Net.

    The attribute names conv1 / conv2 determine the state_dict keys and must stay stable:
    pretrained weights are loaded into this module by name.
    """
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv1 = ContractBlock(in_channels, 32, 3, 1)
        self.conv2 = ContractBlock(32, 64, 3, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (pooled output of stage 2, pre-pool features of stage 2, pre-pool features of stage 1)."""
        x, conv1 = self.conv1(x)
        x, conv2 = self.conv2(x)
        return x, conv2, conv1

    def forward_simple(self, x: torch.Tensor) -> torch.Tensor:
        """Encode `x` without returning the skip features; equivalent to forward(x)[0]."""
        # Each ContractBlock returns (pooled, features); only the pooled output feeds the next stage.
        x, _ = self.conv1(x)
        x, _ = self.conv2(x)
        return x


class ContractBlock(nn.Module):
    """Encoder stage: DoubleConv followed by 2x2 max-pooling.

    forward() returns (pooled, features); the pre-pool features feed the decoder's skip connection.
    """
    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels, kernel_size, padding=padding)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        skip_features = self.conv(x)
        return self.pool(skip_features), skip_features


class DoubleConv(nn.Sequential):
    """Two consecutive Conv2d -> BatchNorm2d -> ReLU stages (stride 1).

    The first convolution maps in_channels to out_channels and the second keeps out_channels.
    Children are indexed 0..5, which defines the state_dict keys.
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int = 0):
        layers = []
        for stage_in_channels in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        super().__init__(*layers)


unet = AutoEncoder(1, 1).double()
unet = unet.to(device)

# Initialise the encoder from pretrained weights (the file name suggests CIFAR-10).
# map_location lets weights saved from a GPU load on a CPU-only machine.
unet.encoder.load_state_dict(torch.load(pretrained_encoder_file_path, map_location=device))

# NOTE(review): train() calls model.train(True) on the whole network, which switches the encoder
# back to training mode, and the optimizer receives all parameters, so this does not freeze the encoder.
unet.encoder.eval()


In [None]:
def train(model, train_dl, valid_dl, loss_fn, optimizer, epochs=1):
    """Train `model`, evaluating it on `valid_dl` after every epoch.

    Each phase uses a random subset of batches: always the first batch, then every further batch
    with probability 1/200 (consecutive frames are highly redundant). After each epoch, one model
    output for a batch from `train_dl` is plotted and validate_model() is run on `valid_dl`.

    :return: (train_loss, valid_loss, valid_error, accuracy_vec), one entry per epoch
    """
    start = time.time()
    train_loss, valid_loss, valid_error, accuracy_vec = [], [], [], []

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)

        for phase in ['train', 'valid']:
            is_training = phase == 'train'
            model.train(is_training)  # evaluate mode for validation (BatchNorm uses running statistics)
            dataloader = train_dl if is_training else valid_dl

            running_loss = 0.0
            step = 0  # number of batches actually used in this phase
            for x, y in dataloader:
                if step != 0 and randrange(200) != 1:  # do not train on every frame in each epoch
                    continue
                step += 1

                x = x.to(device)
                y = y.to(device)

                if is_training:
                    optimizer.zero_grad()
                    outputs = model(x)
                    loss = loss_fn(outputs, y)
                    loss.backward()
                    optimizer.step()
                else:
                    with torch.no_grad():
                        outputs = model(x)
                        loss = loss_fn(outputs, y)

                running_loss += loss.item()

            # An empty loader would otherwise divide by zero.
            epoch_loss = running_loss / step if step else float('nan')
            lr = optimizer.param_groups[0]['lr']
            print(f'{phase} Loss: {epoch_loss:.4f}. lr={lr}')
            if is_training:
                train_loss.append(epoch_loss)
            else:
                valid_loss.append(epoch_loss)

        # Plot one output frame, from the loader passed in rather than a global one.
        with torch.no_grad():
            example_frames = next(iter(train_dl))[0].to(device)
            model_output_frame = model(example_frames).to(cpu_device)[0]
        plt.clf()  # avoid stacking one image per epoch on the same axes
        plt.imshow(model_output_frame)

        # full validation of the model
        accuracy, prediction_error = validate_model(loader=valid_dl, model=model)
        valid_error.append(prediction_error)
        accuracy_vec.append(accuracy)

        # TODO: optional early stop, e.g. when epoch > 30 and valid_error[-1] < 0.6
        # (the threshold depends on the validation data).

    time_elapsed = time.time() - start
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    return train_loss, valid_loss, valid_error, accuracy_vec


# L1 loss; nn.MSELoss() was also tried but sometimes learns strangely.
loss_fn = nn.L1Loss()
opt = torch.optim.Adam(unet.parameters(), lr=0.001)

train_loss, valid_loss, valid_error, accuracy_vec = train(
    model=unet,
    train_dl=trainloader,
    valid_dl=valloader,
    loss_fn=loss_fn,
    optimizer=opt,
    epochs=450,
)

# Training occasionally gets stuck; a high final training loss means the weights need reinitializing.
if train_loss[-1] > 0.04:
    raise Exception("Training error, reinitialize weights. Network sometimes doesn't learn as it should")

In [None]:
# Training vs validation loss per epoch (log scale).
plt.plot(train_loss)
plt.plot(valid_loss)
ax = plt.gca()
ax.set(xlabel='epoch', ylabel='loss', yscale='log')
ax.legend(['train', 'valid'])
ax.grid()

In [None]:
# Mean absolute person-count error on the validation set, per epoch.
plt.plot(valid_error)
plt.gca().set(xlabel='epoch', ylabel='valid_error')
plt.grid()

In [None]:
# Fraction of validation frames with the exact person count, per epoch.
plt.plot(accuracy_vec)
plt.gca().set(xlabel='epoch', ylabel='accuracy_vec')
plt.grid()

In [None]:
# Final evaluation of the trained network on the validation recordings.
validate_model(model=unet, loader=valloader)

In [None]:
# PATH = "/home/przemek/Projects/thermo-presence/thermo-presence/data_collection/src/trained_model/unet_v2_cpu1"


# unet_cpu = unet.to(cpu_device)
# torch.save(unet_cpu.state_dict(), PATH)


# unet.load_state_dict(torch.load(PATH))
# unet.eval()

In [None]:
class IrPersonsDataset(torch.utils.data.Dataset):
    """Dataset of (normalized IR frame, labelled number of persons) pairs, used for evaluation."""
    def __init__(self, augmented_data: 'AugmentedBatchesTrainingData', transform=None):
        self.augmented_data = augmented_data  # keep the instance that was passed in, not its class
        self.transform = transform  # NOTE(review): stored but not applied in __getitem__
        # Flat dataset index -> (batch, frame index within that batch)
        self._index_to_batch_and_subindex_map = [
            (batch, j)
            for batch in augmented_data.batches
            for j in range(len(batch.raw_ir_data))
        ]

    def __len__(self):
        return len(self._index_to_batch_and_subindex_map)

    def __getitem__(self, idx):
        """Return (frame with a leading channel axis, number of labelled persons) for item `idx`."""
        if torch.is_tensor(idx):
            raise Exception("Not supported")

        batch, subindex = self._index_to_batch_and_subindex_map[idx]
        frame = batch.normalized_ir_data[subindex][np.newaxis, :, :]
        return frame, len(batch.centre_points[subindex])


In [None]:
VALIDATION_DIRS_2 = VALIDATION_DIRS_1

# Reload the validation recordings, this time keeping the labelled person count of every frame.
data_with_people_cound = load_data_for_labeled_batches(labeled_batch_dirs=VALIDATION_DIRS_2)

augmented_data_with_people_count = AugmentedBatchesTrainingData()
augmented_data_with_people_count.add_training_batch(data_with_people_cound, flip_and_rotate=False)

dataset_with_real_people_count = IrPersonsDataset(augmented_data_with_people_count)
loader_with_real_people_count = torch.utils.data.DataLoader(
    dataset_with_real_people_count,
    batch_size=1,
    shuffle=False,  # keep time order for the real-vs-predicted plot
    worker_init_fn=seed_worker,
)



In [None]:
def validate_model_with_real_number_of_persons(loader, model, data_plotting_interval=1000,
                                               value_per_person=None, eval_device=None):
    """Compare person counts predicted from model outputs with the labelled person counts.

    :param loader: iterable of (frames, person_counts) batches, e.g. a DataLoader over IrPersonsDataset
    :param model: torch module returning one predicted image per input frame
    :param data_plotting_interval: every this many frames, plot the input and the predicted image
    :param value_per_person: pixel sum of one person; defaults to the global sum_of_values_for_one_person
    :param eval_device: device the model runs on; defaults to the global `device`
    :return: (accuracy, average absolute error, true counts per frame, predicted unrounded counts per frame);
        accuracy and error are NaN when the loader yields no frames
    """
    if value_per_person is None:
        value_per_person = sum_of_values_for_one_person
    if eval_device is None:
        eval_device = device

    correct_count = 0
    tested_frames = 0
    number_of_frames_with_n_persons = {}  # keyed by the *predicted* number of persons
    number_of_frames_with_n_persons_predicted_correctly = {}
    persons_error_sum = 0

    vec_real_number_of_persons = []
    vec_predicted_number_of_persons = []

    # Eval mode so BatchNorm uses its running statistics; restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        for frame, labels in loader:
            # One forward pass per batch, not one per element of the batch.
            with torch.no_grad():
                outputs = model(frame.to(eval_device)).cpu()
            labels_np = labels.numpy()

            for i in range(len(labels)):
                predicted_img = outputs[i].numpy()
                pred_people = np.sum(predicted_img) / value_per_person
                pred_label = round(pred_people)
                true_label = labels_np[i]

                persons_error_sum += abs(pred_people - true_label)
                number_of_frames_with_n_persons[pred_label] = \
                    number_of_frames_with_n_persons.get(pred_label, 0) + 1
                if true_label == pred_label:
                    correct_count += 1
                    number_of_frames_with_n_persons_predicted_correctly[pred_label] = \
                        number_of_frames_with_n_persons_predicted_correctly.get(pred_label, 0) + 1
                tested_frames += 1

                vec_real_number_of_persons.append(true_label)
                vec_predicted_number_of_persons.append(pred_people)

                if tested_frames % data_plotting_interval == 0:
                    plt.imshow(frame[i, 0, :, :])
                    plt.show()
                    print(f'true_label={true_label}, pred_people={pred_people}')
                    plt.imshow(predicted_img)
                    plt.show()
                    print('#'*30)
    finally:
        model.train(was_training)

    print(f"Number of tested frames: {tested_frames}")
    if tested_frames == 0:
        return float('nan'), float('nan'), vec_real_number_of_persons, vec_predicted_number_of_persons

    average_prediction_error = persons_error_sum / tested_frames
    model_accuracy = correct_count / tested_frames

    print(f"Model Accuracy = {model_accuracy}")
    print('Predicted:\n' + '\n'.join([f'   {count} frames with {no} persons' for no, count in number_of_frames_with_n_persons.items()]))
    print('Predicted correctly:\n' + '\n'.join([f'   {count} frames with {no} persons' for no, count in number_of_frames_with_n_persons_predicted_correctly.items()]))
    print(f'average_prediction_error: {average_prediction_error}')

    return model_accuracy, average_prediction_error, vec_real_number_of_persons, vec_predicted_number_of_persons


# Per-frame true and predicted person counts over the validation recordings (time-ordered).
_, _, real_vec, predicted_vec = validate_model_with_real_number_of_persons(model=unet, loader=loader_with_real_people_count)


In [None]:
# Real vs predicted number of people, frame by frame.
fig = plt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')

plt.grid()
plt.plot(real_vec, linewidth=2.5)
plt.plot(predicted_vec, linewidth=1)

plt.title('number of people')
plt.legend(['real', 'predicted'])
plt.xlabel('time [0.5*s]')
plt.ylabel('people count');  # trailing ';' hides the Text(...) repr in the cell output
