In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import transforms
from model_functions import *
from scf import *
from data_loader_scf import *
from model_types import efficientnet_1, custom_model1, efficientnet_2
from torchsummary import summary
from torchvision import models
from copy import deepcopy

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Rebuild the EfficientNet variant (single input channel) and restore a saved checkpoint.
effnet_model = efficientnet_1.Model(in_channels=1)
# NOTE(review): presumably mirrors the configuration the checkpoint was trained with;
# load_state_dict itself does not depend on which parameters are frozen.
effnet_model.unfreeze_params()

# TODO: machine-specific absolute Windows path — make this configurable (e.g. relative to
# the repo root) so the notebook runs on other machines.
state_dict_path = r'C:\AAA\FYP\mimo-radar-drone-detection-fyp\final code\saved_models_type\efficientnet_clipped\from_epoch_40\efficientnet_unfreezed_40_10_epoch_82.pt'
# map_location remaps tensors saved on CUDA onto `device`, so the checkpoint also loads on a
# CPU-only machine. torch.load unpickles the file: only load checkpoints you trust.
state_dict = torch.load(state_dict_path, map_location=device)

# The checkpoint is a dict; the model weights live under the 'Model state' key.
effnet_model.load_state_dict(state_dict['Model state'])
effnet_model = effnet_model.to(device)

In [None]:
# SCF feature directories, relative to this notebook.
train_dir = '../../Data/Dataset/Train/scf/'
test_dir = '../../Data/Dataset/Test/scf/'

# Both loaders share the same batch size and two-class (binary) labelling; only the
# training loader is shuffled. Note that "validation" here is the held-out Test split.
loader_kwargs = dict(batch_size=128, binary=True)
train_data = data_loader(train_dir, shuffle=True, **loader_kwargs)
val_data = data_loader(test_dir, shuffle=False, **loader_kwargs)

In [None]:
# Evaluate the restored checkpoint on the held-out loader (`val_data`, built from test_dir).
# NOTE(review): evaluate_model_binary comes from a star import; presumably it returns the
# binary accuracy — confirm. The value is stored in `accuracy` but not displayed by this cell.
accuracy = evaluate_model_binary(effnet_model, val_data, device)

In [None]:
# Per-sample evaluation on the test split: collect predictions for a confusion matrix and
# bucket sample names by class_label and by whether they were classified correctly.
from data_loader_scf_visualization import *

from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
import numpy as np

# batch_size=1 so each iteration yields exactly one sample together with its name.
validation_data = data_loader_visualize(test_dir, shuffle=False, batch_size=1, binary=True)

classes = ('reference', 'drone')

# Optional substring filter on the sample name. None keeps every sample; for example
# ('reference', 'nlos') restricts the evaluation to names containing either substring.
SAMPLE_NAME_FILTER = None

y_pred = []
y_true = []

# df[class_label] -> {'true': [names classified correctly], 'false': [names misclassified]}
# NOTE(review): assumes class_label only takes values 0..5 — confirm against the loader.
df = {i: {'true': [], 'false': []} for i in range(6)}

effnet_model.eval()

with torch.no_grad():
    for inputs, labels, label_name, class_label in validation_data:
        name = label_name[0]
        if SAMPLE_NAME_FILTER is not None and not any(key in name for key in SAMPLE_NAME_FILTER):
            continue

        inputs = inputs.to(device=device, dtype=torch.float)
        output = effnet_model(inputs)
        # Thresholding at 0 means a positive model output is predicted as class 1
        # (presumably a raw logit, since sigmoid(0) == 0.5).
        predicted = (output > 0).int().reshape(-1).cpu().numpy()
        # Labels are only compared on the CPU, so there is no need to move them to `device`.
        true_labels = labels.reshape(-1).cpu().numpy()

        y_pred.extend(predicted)
        y_true.extend(true_labels)

        bucket = 'true' if predicted[0] == true_labels[0] else 'false'
        df[int(class_label)][bucket].append(name)

# Pin the label order so the matrix is always 2x2 and aligned with `classes`, even when the
# evaluated subset contains only one class. Rows are true labels, columns are predictions.
cf_matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
df_cm = pd.DataFrame(cf_matrix, index=list(classes), columns=list(classes))
# NOTE(review): `plt` is not imported explicitly; presumably it comes from one of the star imports.
plt.figure(figsize=(12, 7))
# fmt='d' prints raw counts; the default '.2g' would render e.g. 1234 as '1.2e+03'.
ax = sn.heatmap(df_cm, annot=True, fmt='d', cmap='Blues')
ax.set(xlabel='Predicted', ylabel='True', title='Confusion matrix (counts)');

In [None]:
# Row-normalised confusion matrix: each row shows how the samples of one true class were
# spread over the predictions, so the diagonal holds per-class recall.
row_totals = cf_matrix.sum(axis=1, keepdims=True)
# Leave rows of absent classes at 0 instead of producing 0/0 -> NaN (plus a RuntimeWarning).
cf_normalized = np.divide(cf_matrix, row_totals,
                          out=np.zeros(cf_matrix.shape, dtype=float),
                          where=row_totals != 0)
df_cm = pd.DataFrame(cf_normalized, index=list(classes), columns=list(classes))
plt.figure(figsize=(12, 7))
ax = sn.heatmap(df_cm, annot=True, fmt='.2f', cmap='Blues', annot_kws={"fontsize": 20})
ax.set(xlabel='Predicted', ylabel='True', title='Confusion matrix (row-normalised)');

In [None]:
# Correct / incorrect counts and accuracy for one class_label bucket of `df`.
index = 5  # which class_label bucket to inspect
n_true = len(df[index]['true'])
n_false = len(df[index]['false'])
print("true", n_true)
print('false', n_false)
# An empty bucket (no samples with this class_label) would otherwise raise ZeroDivisionError.
n_total = n_true + n_false
print('ratio', n_true / n_total if n_total else float('nan'))

df[index]

In [None]:
# Disabled experiment: keep only sample names containing 'nlos' in the buckets of `df` for
# class_labels 1, 4 and 5. If re-enabled it overwrites `df` in place, so any cell reading `df`
# afterwards reports this filtered subset rather than the full test split.
# for i in (1,4,5):
#     df[i]['true'] = [data for data in df[i]['true'] if 'nlos' in data]
#     df[i]['false'] = [data for data in df[i]['false'] if 'nlos' in data]

In [None]:
# Disabled: repeats the per-bucket true/false/ratio summary, presumably intended to run after
# the NLOS filter above. If re-enabled, the ratio line raises ZeroDivisionError when the bucket
# is empty.
# index = 5
# print("true", len(df[index]['true']))
# print('false', len(df[index]['false']))
# print('ratio', len(df[index]['true'])/(len(df[index]['true'])+len(df[index]['false'])))

In [None]:
# Disabled visualisation: prints each sample name from the 'visualize' test subfolder and plots
# data[0][1] (first sample of the batch, index 1 along the next axis — presumably a channel;
# confirm) with values clamped at max=500.
# NOTE(review): re-enabling this rebinds `test_dir`, which the data-loader cell also defines;
# re-running earlier cells afterwards would silently point them at the 'visualize' subfolder.
# Use a distinct name (e.g. visualize_dir) if this is revived.
# from data_loader_scf_visualization import *

# test_dir = '../../Data/Dataset/Test/scf/visualize/'

# val_data_visualize = data_loader_visualize(test_dir, shuffle=False, batch_size=1, binary=True)

# for data, labels, label_name, class_name in val_data_visualize:
#     print(label_name)
#     plt.matshow(torch.clamp(data[0][1], max=500))
#     plt.colorbar()
#     plt.show()