In [1]:
# BSC THESIS - MACHINE LEARNING FOR UNEXPLODED ORDNANCE (UXO)
# THIS NOTEBOOK WAS DEVELOPED BY JONAS KNUDSEN, ADAPTING CODE FROM
# https://towardsdatascience.com/understanding-and-implementing-faster-r-cnn-a-step-by-step-guide-11acfff216b0

In [None]:
# %%%%%%%%%% MODULES %%%%%%%%%% #
import os
import torch
import torchvision
from torchvision import ops
import torchvision.models as models
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchmetrics.classification import BinaryStatScores
import time
import tqdm

In [2]:
# %%%%%%%%%% GLOBAL  %%%%%%%%%% #
# Pick the compute device: the first CUDA GPU if PyTorch can see one, otherwise the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")  # 'cpu' works
# Let cuDNN auto-tune its convolution algorithms. Not strictly necessary, but it
# can speed up the network and improve its memory footprint.
torch.backends.cudnn.benchmark = True

# Dataset file ids start at this number; `dsn` is the short alias used by the loaders.
dataset_start_number = 1000001
dsn = dataset_start_number

In [3]:
# %%%%%%%%%% DATASET %%%%%%%%%% #

### JONAS ###
class UXO_dataset(Dataset):
    """In-memory dataset of UXO images and their bounding boxes.

    Every sample listed in ``file_list`` is loaded eagerly in ``__init__``.
    File ids are integers starting at the global ``dsn``. Each consecutive run of
    ``split`` ids belongs to one survey folder (``survey_1`` .. ``survey_4``), where
    ``split`` depends on the dataset folder name (last component of ``file_root``).

    Args:
        file_list: iterable of integer file ids to load.
        file_root: dataset folder containing ``survey_<n>/images`` and ``survey_<n>/labels``.
    """

    # Number of consecutive file ids per survey, keyed by dataset folder name.
    SURVEY_SPLITS = {
        'augmentedData': 10000,
        'augmentedData_hard': 10000,
        'augmentedData_big': 25000,
        'augmentedData_hard_big': 25000,
    }
    # Number of survey folders in each dataset (survey_1 .. survey_4).
    N_SURVEYS = 4
    # Print a progress/timing message every this many files.
    PRINT_EVERY = 5000

    # takes a list of file ids and a directory root
    def __init__(self, file_list, file_root):
        self.file_list = file_list
        self.file_root = file_root

        # normpath drops a trailing separator, so 'data/augmentedData/' still
        # yields 'augmentedData' (plain split('/')[-1] would give '').
        self.folder_name = os.path.basename(os.path.normpath(self.file_root))

        self.all_images, self.all_BBs = self.get_data()

    def __len__(self):
        return self.all_images.shape[0]

    def __getitem__(self, index):
        image = self.all_images[index].float()
        BB = self.all_BBs[index]

        return image, BB

    @classmethod
    def _split_for_folder(cls, folder_name):
        """Return how many file ids belong to one survey of the given dataset folder.

        Raises:
            ValueError: if the folder name is not a known dataset.
        """
        try:
            return cls.SURVEY_SPLITS[folder_name]
        except KeyError:
            raise ValueError(
                f"Unknown dataset folder {folder_name!r}; expected one of {sorted(cls.SURVEY_SPLITS)}"
            ) from None

    @classmethod
    def _survey_number(cls, file_id, split, start=None):
        """Map a file id to its survey folder number as a string ('1' .. '4').

        Args:
            file_id: integer file id.
            split: number of ids per survey.
            start: first file id of the dataset; defaults to the global ``dsn``.

        Raises:
            ValueError: if the id lies outside all surveys. Previously such ids silently
                reused the survey number of the preceding id (or hit an unbound name).
        """
        if start is None:
            start = dsn
        offset = file_id - start
        if not 0 <= offset < cls.N_SURVEYS * split:
            raise ValueError(
                f"File id {file_id} is outside the {cls.N_SURVEYS} surveys starting at "
                f"{start} ({split} ids per survey)"
            )
        return str(int(offset) // split + 1)

    def _load_sample(self, file_id, survey_number):
        """Load one image / bounding-box pair from ``survey_<survey_number>``.

        Returns:
            image_tensor: the stored array with its last axis moved to the front
                (presumably (H, W, C) -> (C, H, W)); dtype unchanged.
            BB_tensor: (n_boxes, 4) boxes converted from cxcywh to xyxy.
        """
        image_dir = f'{self.file_root}/survey_{survey_number}/images/image_'+str(file_id)+'.npy'
        BB_dir = f'{self.file_root}/survey_{survey_number}/labels/BB_'+str(file_id)+'.npy'

        image = np.load(image_dir)
        BB = np.load(BB_dir)

        # An image without targets is stored as an empty array; give it one placeholder row.
        # NOTE(review): box_convert below turns this placeholder into
        # [-0.5, -0.5, -1.5, -1.5], so it no longer equals the -1 padding value that
        # project_bboxes masks out. Confirm unpad_BB / get_req_anchors treat it as intended.
        if BB.shape == (0,):
            BB = np.array([[-1,-1,-1,-1]],dtype=BB.dtype)

        # change from [row,col,height,width]
        # to the form [x  ,y  ,width,height]
        BB = BB[:,[1,0,3,2]]

        image_tensor = torch.from_numpy(image).permute(2,0,1)
        BB_tensor = ops.box_convert(torch.from_numpy(BB),'cxcywh','xyxy')
        return image_tensor, BB_tensor

    def get_data(self):
        """Load every file id in ``file_list``.

        Returns:
            all_images_stacked: all image tensors stacked along a new first dimension.
            all_BBs_padded: (n_files, max_boxes, 4) xyxy boxes, padded with -1.

        Raises:
            ValueError: unknown dataset folder, a file id outside the surveys, or an
                empty ``file_list``.
        """
        start_time_global = time.perf_counter()
        split = self._split_for_folder(self.folder_name)

        all_images = []
        all_BBs = []
        count = 1
        print_count_number = self.PRINT_EVERY
        # Timer for the periodic progress message; the first message measures from the start.
        start_time = start_time_global

        for i in self.file_list:
            survey_number = self._survey_number(i, split)

            if count % print_count_number == 0:
                end_time = time.perf_counter()
                # Calculate the elapsed time
                elapsed_time = np.round(end_time - start_time,1)

                # Print the elapsed time
                print('Loaded '+ str(count) + ' images. ' + f'Elapsed time {elapsed_time} seconds' )
                start_time = time.perf_counter()

            image_tensor, BB_tensor = self._load_sample(i, survey_number)
            all_images.append(image_tensor)
            all_BBs.append(BB_tensor)
            count += 1

        if not all_images:
            raise ValueError('file_list is empty; there is nothing to load')

        all_images_stacked = torch.stack(all_images, dim=0)
        all_BBs_padded = pad_sequence(all_BBs,batch_first=True,padding_value=-1)

        end_time_global = time.perf_counter()

        # Calculate the elapsed time
        elapsed_time_global = np.round(end_time_global - start_time_global,1)

        # Print the elapsed time
        print('Loaded all '+ str(count-1) + ' images. ' + 'Total time: ' + str(elapsed_time_global)+ ' seconds')
        return all_images_stacked, all_BBs_padded
    
class UXO_dataset_old(Dataset):
    """Legacy loader for a flat folder of ``image_<id>.npy`` / ``BB_<id>.npy`` files.

    All samples are loaded eagerly in ``__init__``.

    Args:
        file_list: iterable of file ids.
        file_root: folder containing the .npy files.
        verbose: if True, print each file id as it is loaded. The original code always
            printed (a leftover debug statement that floods the output); it is now opt-in.
    """
    # takes a list of file ids and a directory root
    def __init__(self, file_list, file_root, verbose=False):
        self.file_list = file_list
        self.file_root = file_root
        self.verbose = verbose

        self.all_images, self.all_BBs = self.get_data()

    def __len__(self):
        return self.all_images.shape[0]

    def __getitem__(self, index):
        image = self.all_images[index].float()
        BB = self.all_BBs[index]

        return image, BB

    def get_data(self):
        """Load every file in ``file_list``.

        Returns:
            (stacked image tensors, (n_files, max_boxes, 4) xyxy boxes padded with -1).

        Raises:
            ValueError: if ``file_list`` is empty.
        """
        all_images = []
        all_BBs = []

        for i in self.file_list:
            image_dir = f'{self.file_root}/image_'+str(i)+'.npy'
            BB_dir = f'{self.file_root}/BB_'+str(i)+'.npy'

            image = np.load(image_dir)
            BB = np.load(BB_dir)
            if self.verbose:
                print(i)

            # An image without targets is stored as an empty array; give it one placeholder row.
            # NOTE(review): box_convert below turns this placeholder into
            # [-0.5, -0.5, -1.5, -1.5], not the -1 padding value.
            if BB.shape == (0,):
                BB = np.array([[-1,-1,-1,-1]],dtype=BB.dtype)

            # change from [row,col,height,width]
            # to the form [x  ,y  ,width,height]
            BB = BB[:,[1,0,3,2]]

            image_tensor = torch.from_numpy(image).permute(2,0,1)
            BB_tensor = ops.box_convert(torch.from_numpy(BB),'cxcywh','xyxy')

            all_images.append(image_tensor)
            all_BBs.append(BB_tensor)

        if not all_images:
            raise ValueError('file_list is empty; there is nothing to load')

        all_images_stacked = torch.stack(all_images, dim=0)
        all_BBs_padded = pad_sequence(all_BBs,batch_first=True,padding_value=-1)

        return all_images_stacked, all_BBs_padded
#############

In [4]:
# %%%%%%%%%%  MODEL  %%%%%%%%%% #

class FeatureExtractor(nn.Module):
    """Backbone made of the first five top-level children of an ImageNet-pretrained ResNet-50.

    NOTE(review): for torchvision's ResNet-50 these are presumably conv1, bn1, relu,
    maxpool and layer1 (the original tutorial kept 8 children). Confirm against torchvision.
    Every backbone parameter is left trainable.
    """
    def __init__(self):
        super().__init__()
        pretrained = models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1).float()
        self.backbone = nn.Sequential(*list(pretrained.children())[:5])
        # Fine-tune the whole backbone instead of freezing it.
        self.backbone.requires_grad_(True)

    def forward(self, img_data):
        """Return the backbone feature map for a batch of images."""
        return self.backbone(img_data)
    

### JONAS ###    
class FeatureExtractor_random(nn.Module):
    """Same truncated ResNet-50 backbone as FeatureExtractor, but randomly initialised.

    Keeps the first five top-level children of an untrained torchvision ResNet-50;
    all parameters are trainable.
    """
    def __init__(self):
        super().__init__()
        untrained = models.resnet50().float()
        self.backbone = nn.Sequential(*list(untrained.children())[:5])
        # Train every backbone parameter.
        self.backbone.requires_grad_(True)

    def forward(self, img_data):
        """Return the backbone feature map for a batch of images."""
        return self.backbone(img_data)

class FeatureExtractor_hyper(nn.Module):
    """Small from-scratch CNN whose output concatenates features from three depths.

    For an input (B, 3, H, W) with H and W divisible by 8, the output is
    (B, 256, H/4, W/4):
      * 64 channels from block1, max-pooled;
      * 128 channels from block3;
      * 64 channels from block5, upsampled by a learned transposed convolution.
    """
    def __init__(self):
        super().__init__()
        # input is (b,3,256,256)

        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(11,11), stride=2, padding=5)
        )

        self.block2 = nn.Sequential(
            nn.ReLU(),
            nn.BatchNorm2d(num_features=64),
            nn.MaxPool2d(kernel_size=(2,2),stride=(2,2)),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(5,5), padding=2)
        )

        self.block3 = nn.Sequential(
            nn.ReLU(),
            nn.BatchNorm2d(num_features=64),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3,3), padding=1)
        )

        self.block4 = nn.Sequential(
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,2), stride=(2,2)),
            nn.Conv2d(128, 128, 3, padding=1)
        )

        self.block5 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, padding=1)
        )

        # Modules that assemble the hyper feature map. They used to be created inside
        # forward(). That gave the transposed convolution new random weights on every
        # call, kept it out of parameters()/state_dict (so it was never trained or
        # saved), and pinned it to the global `device`. Registering them here fixes that.
        # Checkpoints saved before this change lack the upsample_out5 weights.
        self.pool_out1 = nn.MaxPool2d(kernel_size=(2,2), stride=(2,2))
        self.upsample_out5 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=3,
                                                stride=2, padding=1, output_padding=1)

    def forward(self, img_data):
        """Return the (B, 256, H/4, W/4) hyper feature map for a batch of images."""
        img_data = img_data.float()
        out1 = self.block1(img_data)
        out2 = self.block2(out1)
        out3 = self.block3(out2)
        out4 = self.block4(out3)
        out5 = self.block5(out4)

        # Bring out1 and out5 to out3's resolution, then stack along channels.
        hyper_out1 = self.pool_out1(out1)       # halves rows and cols
        hyper_out5 = self.upsample_out5(out5)   # doubles rows and cols

        return torch.cat((hyper_out1, out3, hyper_out5), dim=1)
    
class FeatureExtractor_hyperV2(nn.Module):
    """Hyper-feature CNN built from conv-BN-ReLU stages with max-pooling and dropout.

    For an input (B, 3, H, W) with H and W divisible by 8, the output is
    (B, 256, H/4, W/4):
      * 64 channels from block1, max-pooled;
      * 128 channels from block2;
      * 64 channels from block3, upsampled by a learned transposed convolution.
    """
    def __init__(self):
        super().__init__()
        # input is (b,3,256,256)

        self.dropout_p = 0.3

        self.block1 = nn.Sequential(
            self.conv_bnorm_relu(in_channels=3, out_channels=64),
            nn.MaxPool2d(kernel_size=(2,2), stride=(2,2)),
            nn.Dropout(p=self.dropout_p),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3,3), stride=1, padding=1)
        )

        self.block2 = nn.Sequential(
            self.conv_bnorm_relu(in_channels=64, out_channels=128),
            nn.MaxPool2d(kernel_size=(2,2), stride=(2,2)),
            nn.Dropout(p=self.dropout_p),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3,3), stride=1, padding=1)
        )

        self.block3 = nn.Sequential(
            self.conv_bnorm_relu(in_channels=128, out_channels=64),
            nn.MaxPool2d(kernel_size=(2,2), stride=(2,2)),
            nn.Dropout(p=self.dropout_p),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3,3), stride=1, padding=1)
        )

        # Modules that assemble the hyper feature map. They used to be created inside
        # forward(). That gave the transposed convolution new random weights on every
        # call, kept it out of parameters()/state_dict (so it was never trained or
        # saved), and pinned it to the global `device`. Registering them here fixes that.
        # Checkpoints saved before this change lack the upsample_out3 weights.
        self.pool_out1 = nn.MaxPool2d(kernel_size=(2,2), stride=(2,2))
        self.upsample_out3 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=3,
                                                stride=2, padding=1, output_padding=1)

    def forward(self, img_data):
        """Return the (B, 256, H/4, W/4) hyper feature map for a batch of images."""
        img_data = img_data.float()

        out1 = self.block1(img_data)
        out2 = self.block2(out1)
        out3 = self.block3(out2)

        # Bring out1 and out3 to out2's resolution, then stack along channels.
        hyper_out1 = self.pool_out1(out1)       # halves rows and cols
        hyper_out3 = self.upsample_out3(out3)   # doubles rows and cols

        return torch.cat((hyper_out1, out2, hyper_out3), dim=1)

    def conv_bnorm_relu(self, in_channels, out_channels):
        """3x3 same-padding convolution followed by BatchNorm and ReLU."""
        conv_bnorm_relu = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3,3), stride=1, padding=1),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU()
        )

        return conv_bnorm_relu
    
class FeatureExtractor_resnet50hyper(nn.Module):
    """Hyper-feature extractor built from the early layers of an ImageNet-pretrained ResNet-50.

    The ResNet-50 is flattened into a plain list of modules (containers and Bottleneck
    blocks are unpacked into their children), truncated, and run slice by slice in
    ``forward``. The Bottleneck modules themselves are never called, so the residual
    additions performed inside Bottleneck.forward are not reproduced. Only the listed
    sub-modules run, in list order. The outputs of four slices are concatenated
    along the channel dimension.
    """
    def __init__(self):
        super().__init__()
        # input is (b,3,256,256)
        
        # Load the pretrained ResNet-50 model
        model = models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)

        # Flatten all modules recursively.
        # This is a plain Python list, so the modules are registered with this nn.Module
        # (parameters(), .to(), state_dict) only through self.all_modules_seq below.
        self.all_modules = self.flatten_modules_resnet50hyper(model)
        # Keep the first 27 flattened modules
        self.all_modules = self.all_modules[:27]
        # Remove module 11 and 12 since these are only for the residual connection which is not used
        # NOTE(review): relies on torchvision's child ordering placing the first
        # Bottleneck's downsample conv/bn at indices 11-12. Re-check if torchvision changes.
        del self.all_modules[11:13]  
        
        # 25 modules remain; forward() only uses indices 0-21, so the last three are
        # registered (and saved in state_dict) but never run.
        self.all_modules_seq = nn.Sequential(*self.all_modules)
        
    def forward(self, img_data):
        """Run the four module slices in sequence and concatenate their outputs on dim 1."""
        
        img_data = img_data.float()        
        
        # Each slice is wrapped in a temporary nn.Sequential that reuses the already
        # registered modules (no new parameters are created).
        out3  = nn.Sequential(*self.all_modules[0:4])(img_data) # After 1 convolution
        
        out7  = nn.Sequential(*self.all_modules[4:8])(out3) # After 3 convolutions
        
        out14 = nn.Sequential(*self.all_modules[8:15])(out7) # After 6 convolutions
        
        out21 = nn.Sequential(*self.all_modules[15:22])(out14) # After 9 convolutions
        
        # Requires all four outputs to share spatial size (dims 2 and 3).
        out = torch.cat((out3, out7, out14, out21),dim=1)

        return out
    
    def flatten_modules_resnet50hyper(self,model):
        """Recursively list the modules of ``model`` in children() order.

        nn.Sequential / nn.ModuleList / nn.ModuleDict containers and torchvision ResNet
        Bottleneck blocks are expanded into their children; any other module is kept
        whole (its own children are not expanded).
        """
        modules = []
        for m in list(model.children()):
            if isinstance(m, (nn.Sequential, nn.ModuleList, nn.ModuleDict)):
                modules.extend(self.flatten_modules_resnet50hyper(m))
            elif isinstance(m, nn.Module):
                if isinstance(m, models.resnet.Bottleneck):
                    for sub_module in self.flatten_modules_resnet50hyper(m):
                        modules.append(sub_module)
                else:
                    modules.append(m)
        return modules
#############


class ProposalModule(nn.Module):
    """RPN head: predicts one objectness logit and 4 box offsets per anchor at every feature-map cell.

    Args:
        in_features: number of channels of the incoming feature map.
        hidden_dim: channels of the shared 3x3 convolution.
        n_anchors: anchors per feature-map cell (A).
        p_dropout: dropout probability applied after the shared convolution.
    """
    def __init__(self, in_features, hidden_dim=512, n_anchors=9, p_dropout=0.3):
        super().__init__()
        self.n_anchors = n_anchors
        # Shared 3x3 conv; padding=1 with stride 1 keeps the spatial size of the feature map.
        self.conv1 = nn.Conv2d(in_features, hidden_dim, kernel_size=3, padding=1) 
        self.dropout = nn.Dropout(p_dropout)
        
        # EDIT: stride can be changed!
        # 1x1 heads: A objectness logits and A*4 box offsets per cell.
        self.conf_head = nn.Conv2d(hidden_dim, n_anchors, kernel_size=1)
        self.reg_head = nn.Conv2d(hidden_dim, n_anchors * 4, kernel_size=1)
        
        
    def forward(self, feature_map, pos_anc_ind=None, neg_anc_ind=None, pos_anc_coords=None):
        """Run the head; behaves in 'train' mode only when all three optional args are given.

        Args:
            feature_map: backbone output, (B, in_features, H, W).
            pos_anc_ind / neg_anc_ind: flat indices (over the whole batch) of the positive
                and negative anchors.
            pos_anc_coords: xyxy coordinates of the positive anchors.

        Returns:
            eval mode: (conf_scores_pred, reg_offsets_pred).
            train mode: (conf_scores_pred, reg_offsets_pred, conf_scores_pos,
                conf_scores_neg, offsets_pos, proposals), where proposals are the
                positive anchors decoded with their predicted offsets.

        NOTE(review): conf_scores_pred is (B, A, H, W), so flatten() orders entries by
        batch, anchor, row, column. reg_offsets_pred.view(-1, 4) groups 4 consecutive
        values along the width axis of a single channel, not the 4 offsets of one anchor.
        RegionProposalNetwork.inference reshapes the same way, so training and inference
        agree. Still, confirm these orders match the flattened anchor order from gen_anc_base.
        """
        # determine mode
        if pos_anc_ind is None or neg_anc_ind is None or pos_anc_coords is None:
            mode = 'eval'
        else:
            mode = 'train'
            
        out = self.conv1(feature_map)
        out = F.relu(self.dropout(out))
        
        reg_offsets_pred = self.reg_head(out) # (B, A*4, hmap, wmap)
        conf_scores_pred = self.conf_head(out) # (B, A, hmap, wmap)
        
        if mode == 'train': 
            # get conf scores (logits) of the positive / negative anchors
            conf_scores_pos = conf_scores_pred.flatten()[pos_anc_ind]
            conf_scores_neg = conf_scores_pred.flatten()[neg_anc_ind]
            # get offsets for +ve anchors
            offsets_pos = reg_offsets_pred.contiguous().view(-1, 4)[pos_anc_ind]
            # generate proposals using offsets
            proposals = generate_proposals(pos_anc_coords, offsets_pos)
            # EDIT: added conf_scores_pred and reg_offsets_pred
            return conf_scores_pred,reg_offsets_pred, conf_scores_pos, conf_scores_neg, offsets_pos, proposals 
            
        elif mode == 'eval':
            return conf_scores_pred, reg_offsets_pred

class RegionProposalNetwork(nn.Module):
    """Region proposal network: backbone feature extractor plus an anchor-based ProposalModule.

    Args:
        img_size: (height, width) of the input images in pixels.
        out_size: (height, width) of the feature map returned by the feature extractor.
        out_channels: channel count of that feature map (input of ProposalModule).
    """
    def __init__(self, img_size, out_size, out_channels):
        super().__init__()

        self.img_height, self.img_width = img_size
        self.out_h, self.out_w = out_size

        # downsampling scale factor (image pixels per feature-map cell, integer division)
        self.width_scale_factor = self.img_width // self.out_w
        self.height_scale_factor = self.img_height // self.out_h

        # ANCHOR BOX DEFINITION
        # scales and ratios for anchor boxes (see gen_anc_base: width = scale * ratio, height = scale)
        self.anc_scales = [10] #[5,10] #[2, 4, 6] # [5,10]
        self.anc_ratios = [1]#[0.5,1,2]# [0.5, 1, 1.5] # [1,2,0.5]
        self.n_anc_boxes = len(self.anc_scales) * len(self.anc_ratios)

        self.n_anc_boxes_total = self.n_anc_boxes*self.out_h*self.out_w

        # IoU thresholds for +ve and -ve anchors, forwarded to get_req_anchors.
        # These attributes used to be ignored (get_req_anchors' defaults 0.7 / 0.2
        # applied), so neg_thresh is set to 0.2 to keep the effective training behaviour.
        self.pos_thresh = 0.7
        self.neg_thresh = 0.2

        # weights for loss
        self.w_conf = 1
        self.w_reg = 5

        self.feature_extractor = FeatureExtractor().to(device) # EDIT
        self.proposal_module = ProposalModule(out_channels, n_anchors=self.n_anc_boxes).to(device) # EDIT

    def _generate_anchors(self, batch_size):
        """Return the anchor boxes of one image, repeated for every image in the batch."""
        anc_pts_x, anc_pts_y = gen_anc_centers(out_size=(self.out_h, self.out_w))
        anc_base = gen_anc_base(anc_pts_x, anc_pts_y, self.anc_scales, self.anc_ratios, (self.out_h, self.out_w))
        return anc_base.repeat(batch_size, 1, 1, 1, 1)

    def forward(self, images, gt_bboxes): # EDIT: removed gt_classes
        """Compute the RPN training loss for a batch.

        Args:
            images: image batch passed to the feature extractor.
            gt_bboxes: pixel-space ground-truth boxes per image, padded with -1.

        Returns:
            total_rpn_loss: w_conf * classification loss + w_reg * box-regression loss.
            feature_map: backbone output.
            proposals: positive anchors decoded with their predicted offsets.
            positive_anc_ind_sep: as returned by get_req_anchors.
        """
        batch_size = images.size(dim=0)
        feature_map = self.feature_extractor(images)

        anc_boxes_all = self._generate_anchors(batch_size)

        # ground-truth boxes in feature-map coordinates
        gt_bboxes_proj = project_bboxes(gt_bboxes, self.width_scale_factor, self.height_scale_factor, mode='p2a')

        # get positive and negative anchors amongst other things
        # EDIT: Removed GT_class_pos right after GT_offsets
        positive_anc_ind, negative_anc_ind, GT_conf_scores, \
        GT_offsets, positive_anc_coords, \
        negative_anc_coords, positive_anc_ind_sep = get_req_anchors(
            anc_boxes_all, gt_bboxes_proj,
            pos_thresh=self.pos_thresh, neg_thresh=self.neg_thresh)

        # pass through the proposal module
        conf_scores_pred, reg_offsets_pred, conf_scores_pos, conf_scores_neg, offsets_pos, proposals \
        = self.proposal_module(feature_map, positive_anc_ind, negative_anc_ind, positive_anc_coords)

        cls_loss = calc_cls_loss(conf_scores_pos, conf_scores_neg, batch_size)
        reg_loss = calc_bbox_reg_loss(GT_offsets, offsets_pos, batch_size)

        total_rpn_loss = self.w_conf * cls_loss + self.w_reg * reg_loss

        # EDIT: Removed GT_class_pos right after positive_anc_ind_sep
        return total_rpn_loss, feature_map, proposals, positive_anc_ind_sep

    def inference(self, images, conf_thresh=0.5, nms_thresh=0.4):
        """Propose boxes for a batch of images without computing a loss.

        Gradients are disabled, but ProposalModule's dropout is only disabled in eval
        mode, so call ``.eval()`` first for deterministic output.

        Returns:
            proposals_final: list (one entry per image) of xyxy proposals in feature-map
                coordinates whose sigmoid confidence is >= conf_thresh, after NMS with
                IoU threshold nms_thresh.
            conf_scores_final: matching list of confidences, in the order returned by NMS.
            feature_map: backbone output for the batch.
        """
        with torch.no_grad():
            batch_size = images.size(dim=0)
            feature_map = self.feature_extractor(images)

            anc_boxes_flat = self._generate_anchors(batch_size).reshape(batch_size, -1, 4) # EDIT

            # get conf scores and offsets
            conf_scores_pred, offsets_pred = self.proposal_module(feature_map)
            conf_scores_pred = conf_scores_pred.reshape(batch_size, -1)
            offsets_pred = offsets_pred.reshape(batch_size, -1, 4) # EDIT

            # filter out proposals based on conf threshold and nms threshold for each image
            proposals_final = []
            conf_scores_final = []

            for i in range(batch_size):
                conf_scores = torch.sigmoid(conf_scores_pred[i])
                proposals = generate_proposals(anc_boxes_flat[i].to(device), offsets_pred[i].to(device))

                # filter based on confidence threshold
                conf_idx = torch.where(conf_scores >= conf_thresh)[0]
                conf_scores_pos = conf_scores[conf_idx]
                proposals_pos = proposals[conf_idx]

                # filter based on nms threshold
                nms_idx = ops.nms(proposals_pos, conf_scores_pos, nms_thresh)
                proposals_final.append(proposals_pos[nms_idx])
                conf_scores_final.append(conf_scores_pos[nms_idx])

        return proposals_final, conf_scores_final, feature_map
    
class ClassificationModule(nn.Module): # Not used
    """ROI classification head: ROI pooling, average pooling, one hidden FC layer, class logits.

    Args:
        out_channels: channels of the feature map (input size of the hidden layer).
        n_classes: number of output classes.
        roi_size: output size of ROI pooling, also the average-pooling kernel.
        hidden_dim: width of the hidden fully connected layer.
        p_dropout: dropout probability before the ReLU.
    """
    def __init__(self, out_channels, n_classes, roi_size, hidden_dim=512, p_dropout=0.3):
        super().__init__()
        self.roi_size = roi_size
        # hidden network
        self.avg_pool = nn.AvgPool2d(self.roi_size)
        self.fc = nn.Linear(out_channels, hidden_dim)
        self.dropout = nn.Dropout(p_dropout)
        # classification head
        self.cls_head = nn.Linear(hidden_dim, n_classes)

    def forward(self, feature_map, proposals_list, gt_classes=None):
        """Return class logits, or the cross-entropy loss when gt_classes is given."""
        # ROI pool every proposal, then collapse its spatial grid with average pooling.
        pooled = self.avg_pool(ops.roi_pool(feature_map, proposals_list, self.roi_size))
        pooled = pooled.squeeze(-1).squeeze(-1)

        hidden = F.relu(self.dropout(self.fc(pooled)))
        cls_scores = self.cls_head(hidden)

        if gt_classes is None:
            return cls_scores
        return F.cross_entropy(cls_scores, gt_classes.long())
    
class TwoStageDetector(nn.Module): # Not used
    """Faster R-CNN style detector: RegionProposalNetwork plus ROI ClassificationModule.

    Only ``inference`` is usable. Training via ``forward`` is unsupported with the
    current RegionProposalNetwork (see ``forward``).
    """
    def __init__(self, img_size, out_size, out_channels, n_classes, roi_size):
        super().__init__() 
        self.rpn = RegionProposalNetwork(img_size, out_size, out_channels)
        self.classifier = ClassificationModule(out_channels, n_classes, roi_size)

    def forward(self, images, gt_bboxes, gt_classes):
        """Not supported: always raises NotImplementedError.

        RegionProposalNetwork.forward takes only (images, gt_bboxes). It returns
        (total_rpn_loss, feature_map, proposals, positive_anc_ind_sep), without the
        ground-truth classes of the positive anchors that the classifier needs as
        targets. The previous body therefore always failed with a TypeError; this
        method now reports the real reason.
        """
        raise NotImplementedError(
            "TwoStageDetector.forward is not supported: RegionProposalNetwork.forward takes "
            "(images, gt_bboxes) and does not return the ground-truth classes of the positive "
            "anchors (GT_class_pos) needed to train the classification head."
        )

    def inference(self, images, conf_thresh=0.5, nms_thresh=0.7):
        """Propose boxes with the RPN and classify each proposal.

        Returns:
            proposals_final: list of per-image proposals.
            conf_scores_final: list of per-image RPN confidences.
            classes_final: list of per-image predicted class indices (argmax of softmax).
        """
        batch_size = images.size(dim=0)
        proposals_final, conf_scores_final, feature_map = self.rpn.inference(images, conf_thresh, nms_thresh)
        cls_scores = self.classifier(feature_map, proposals_final)

        # convert scores into probability
        cls_probs = F.softmax(cls_scores, dim=-1)
        # get classes with highest probability
        classes_all = torch.argmax(cls_probs, dim=-1)

        classes_final = []
        # the classifier returns all proposals of the batch concatenated; slice them back per image
        c = 0
        for i in range(batch_size):
            n_proposals = len(proposals_final[i]) # get the number of proposals for each image
            classes_final.append(classes_all[c: c+n_proposals])
            c += n_proposals

        return proposals_final, conf_scores_final, classes_final

# ------------------- Loss Utils ----------------------

def calc_cls_loss(conf_scores_pos, conf_scores_neg, batch_size):
    """Summed binary cross-entropy over anchor objectness logits.

    Positive-anchor logits are scored against target 1, negative-anchor logits against 0.
    ``batch_size`` is accepted for interface compatibility but unused: the loss is not
    normalised per batch.
    """
    logits = torch.cat((conf_scores_pos, conf_scores_neg))
    labels = torch.cat((torch.ones_like(conf_scores_pos), torch.zeros_like(conf_scores_neg)))
    return F.binary_cross_entropy_with_logits(logits, labels, reduction='sum')

def calc_bbox_reg_loss(gt_offsets, reg_offsets_pos, batch_size):
    """Summed smooth-L1 loss between predicted and target offsets of the positive anchors.

    Both tensors must have the same shape. ``batch_size`` is unused (no per-batch
    normalisation) and kept for interface compatibility.
    """
    assert reg_offsets_pos.shape == gt_offsets.shape
    return F.smooth_l1_loss(reg_offsets_pos, gt_offsets, reduction='sum')

### JONAS ###
def calc_confusion(conf_scores_pred,positive_anc_ind):
    """Confusion counts of per-anchor predictions against the positive-anchor labels.

    A 0/1 target vector (1 at positive_anc_ind) with the same length as
    conf_scores_pred is scored with torchmetrics BinaryStatScores at threshold 0.5.

    Returns:
        tensor [tp, fp, tn, fn]; the trailing support entry is dropped.

    NOTE(review): conf_scores_pred are presumably raw logits. torchmetrics appears to
    apply a sigmoid only when values fall outside [0, 1]. Confirm that thresholding
    behaves as intended.
    """
    labels = torch.zeros(len(conf_scores_pred)).to(device)
    labels[positive_anc_ind.detach().cpu().numpy()] = 1

    stat_scores = BinaryStatScores(threshold = 0.5).to(device)
    # [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn)
    return stat_scores(conf_scores_pred, labels)[:-1]
    
def calc_conf_mat(gt_bboxes_proj,proposals_final,pos_thresh,n_anc_boxes_total,list_or_sum='sum'):
    """Count true/false positives/negatives of the final proposals against the ground truth.

    Per image, a (proposal, GT box) pair is a match when their IoU exceeds pos_thresh:
      tp: number of matching (proposal, GT) pairs
      fp: number of proposals minus tp
      fn: number of GT boxes matched by no proposal
      tn: n_anc_boxes_total - (tp + fp + fn)
    NOTE(review): tp counts pairs, so one proposal matching two GT boxes counts twice
    and fp can then become negative.

    Args:
        gt_bboxes_proj: per-image GT boxes; each entry is passed through unpad_BB.
        proposals_final: per-image proposal boxes (xyxy), one entry per image.
        pos_thresh: IoU threshold for a match.
        n_anc_boxes_total: number of anchors per image.
        list_or_sum: 'sum' -> totals over the batch; 'list' -> one value per image.

    Returns:
        (tp, fp, fn, tn) as tensors.

    Raises:
        ValueError: if list_or_sum is neither 'sum' nor 'list'.
    """
    if list_or_sum not in ('sum', 'list'):
        raise ValueError(f"list_or_sum must be 'sum' or 'list', got {list_or_sum!r}")

    tp, fp, fn, tn = [], [], [], []

    # loop over images
    for i in range(len(gt_bboxes_proj)):
        gt_boxes = unpad_BB(gt_bboxes_proj[i]).to(device)
        iou = ops.box_iou(proposals_final[i].to(device), gt_boxes)  # (n_proposals, n_gt)

        # Same result as the former `iou.max(dim=1) > 0`, but `any` is defined on an
        # empty GT dimension, so images without GT boxes no longer raise.
        overlaps_any_gt = (iou > 0).any(dim=1, keepdim=True)
        match = torch.logical_and(iou > pos_thresh, overlaps_any_gt)

        matches_per_gt = match.sum(dim=0)
        fn_i = len(matches_per_gt) - torch.count_nonzero(matches_per_gt)
        tp_i = match.sum()
        fp_i = match.shape[0] - tp_i
        tn_i = n_anc_boxes_total - (fn_i + tp_i + fp_i)

        tp.append(tp_i)
        fp.append(fp_i)
        fn.append(fn_i)
        tn.append(tn_i)

    if not tp:
        # No images: return zero counts instead of failing in torch.stack([]).
        counts = [torch.zeros(0, dtype=torch.long, device=device) for _ in range(4)]
    else:
        counts = [torch.stack(values) for values in (tp, fp, fn, tn)]

    if list_or_sum == 'sum':
        counts = [c.sum() for c in counts]

    return tuple(counts)
#############

In [5]:
# %%%%%%%%%%  UTILS  %%%%%%%%%% #
# -------------- Data Utils -------------------

def parse_annotation(annotation_path, image_dir, img_size): # Not used
    '''
    Read bounding-box annotations from an XML file and rescale them to img_size.

    The file is expected to contain <image name=... width=... height=...> elements,
    each with <box label=... xtl=... ytl=... xbr=... ybr=...> children.

    Args:
        annotation_path: path of the XML annotation file.
        image_dir: directory joined in front of each image name.
        img_size: (height, width) of the resized images the boxes should match.

    Returns:
        gt_boxes_all: list of (n_boxes, 4) xyxy tensors, one per image
            ((0, 4) for an image without boxes).
        gt_classes_all: list of label-string lists, one per image.
        img_paths: list of image paths.
    '''
    # ElementTree is not imported in the notebook's import cell; without this line
    # the function raised NameError on `ET`.
    import xml.etree.ElementTree as ET

    img_h, img_w = img_size

    # ET.parse opens the file itself (binary mode, so the XML encoding declaration is honoured).
    root = ET.parse(annotation_path).getroot()

    img_paths = []
    gt_boxes_all = []
    gt_classes_all = []
    # get image paths
    for object_ in root.findall('image'):
        img_paths.append(os.path.join(image_dir, object_.get("name")))

        # get raw image size
        orig_w = int(object_.get("width"))
        orig_h = int(object_.get("height"))

        # get bboxes and their labels
        groundtruth_boxes = []
        groundtruth_classes = []
        for box_ in object_.findall('box'):
            xmin = float(box_.get("xtl"))
            ymin = float(box_.get("ytl"))
            xmax = float(box_.get("xbr"))
            ymax = float(box_.get("ybr"))

            # rescale bboxes
            bbox = torch.Tensor([xmin, ymin, xmax, ymax])
            bbox[[0, 2]] = bbox[[0, 2]] * img_w/orig_w
            bbox[[1, 3]] = bbox[[1, 3]] * img_h/orig_h

            groundtruth_boxes.append(bbox.tolist())

            # get labels
            groundtruth_classes.append(box_.get("label"))

        # reshape keeps non-empty results unchanged and gives (0, 4) instead of (0,) for no boxes
        gt_boxes_all.append(torch.Tensor(groundtruth_boxes).reshape(-1, 4))
        gt_classes_all.append(groundtruth_classes)

    return gt_boxes_all, gt_classes_all, img_paths

# -------------- Prepocessing utils ----------------

def calc_gt_offsets(pos_anc_coords, gt_bbox_mapping):
    """Regression targets that move each positive anchor onto its matched GT box.

    Both inputs are (N, 4) xyxy boxes paired row by row. With centres (x, y) and sizes
    (w, h) of the GT box (g) and anchor (a):
        tx = (gx - ax) / aw,   ty = (gy - ay) / ah,
        tw = log(gw / aw),     th = log(gh / ah).
    Returns an (N, 4) tensor of [tx, ty, tw, th].
    """
    anc = ops.box_convert(pos_anc_coords, in_fmt='xyxy', out_fmt='cxcywh')
    gt = ops.box_convert(gt_bbox_mapping, in_fmt='xyxy', out_fmt='cxcywh')

    anc_cx, anc_cy, anc_w, anc_h = (anc[:, k] for k in range(4))
    gt_cx, gt_cy, gt_w, gt_h = (gt[:, k] for k in range(4))

    targets = [
        (gt_cx - anc_cx) / anc_w,
        (gt_cy - anc_cy) / anc_h,
        torch.log(gt_w / anc_w),
        torch.log(gt_h / anc_h),
    ]
    return torch.stack(targets, dim=-1)

def gen_anc_centers(out_size): # EDIT THIS FOR CHANGING STRIDE!
    """Anchor centre coordinates: one per feature-map cell, placed at the cell centre.

    Args:
        out_size: (out_h, out_w) of the feature map.

    Returns:
        (anc_pts_x, anc_pts_y): 1-D tensors 0.5, 1.5, ..., out_w - 0.5 and
        0.5, 1.5, ..., out_h - 0.5, on `device`.
    """
    out_h, out_w = out_size
    half_cell = 0.5

    anc_pts_x = torch.arange(0, out_w, device=device) + half_cell
    anc_pts_y = torch.arange(0, out_h, device=device) + half_cell
    return anc_pts_x, anc_pts_y

def project_bboxes(bboxes, width_scale_factor, height_scale_factor, mode='a2p'):
    """Rescale boxes between pixel space and activation-map (feature-map) space.

    mode 'a2p' multiplies x coordinates (columns 0, 2) by width_scale_factor and y
    coordinates (columns 1, 3) by height_scale_factor; 'p2a' divides instead.
    The result is truncated to int32. Entries that were -1 (padding) are restored to -1,
    and the output has the same shape as ``bboxes``.
    """
    assert mode in ['a2p', 'p2a']

    n_images = bboxes.size(dim=0)
    scaled = bboxes.clone().reshape(n_images, -1, 4)
    # remember the -1 padding before scaling changes those values
    padding_mask = scaled == -1
    scaled = scaled.double() # EDIT: Jonas added this

    x_cols, y_cols = [0, 2], [1, 3]
    if mode == 'p2a':
        # pixel image to activation map
        scaled[:, :, x_cols] /= width_scale_factor
        scaled[:, :, y_cols] /= height_scale_factor
    else:
        # activation map to pixel image
        scaled[:, :, x_cols] *= width_scale_factor
        scaled[:, :, y_cols] *= height_scale_factor

    scaled = scaled.int() # EDIT: Jonas added this
    scaled.masked_fill_(padding_mask, -1)
    scaled.resize_as_(bboxes)
    return scaled

def generate_proposals(anchors, offsets):
    """Decode predicted (tx, ty, tw, th) offsets against xyxy anchors into xyxy proposals.

    With anchor centre (cx, cy) and size (w, h):
        cx' = cx + tx * w,   cy' = cy + ty * h,
        w'  = w * exp(tw),   h'  = h * exp(th).
    anchors and offsets are (N, 4), paired row by row; the result is on `device`.
    """
    anc_cxcywh = ops.box_convert(anchors, in_fmt='xyxy', out_fmt='cxcywh')
    anc_cx, anc_cy, anc_w, anc_h = (anc_cxcywh[:, k] for k in range(4))
    tx, ty, tw, th = (offsets[:, k] for k in range(4))

    decoded = torch.zeros_like(anc_cxcywh).to(device) # EDIT
    decoded[:, 0] = anc_cx + tx * anc_w
    decoded[:, 1] = anc_cy + ty * anc_h
    decoded[:, 2] = anc_w * torch.exp(tw)
    decoded[:, 3] = anc_h * torch.exp(th)

    return ops.box_convert(decoded, in_fmt='cxcywh', out_fmt='xyxy')

def gen_anc_base(anc_pts_x, anc_pts_y, anc_scales, anc_ratios, out_size):
    """Build every anchor box of one image, in xyxy feature-map coordinates.

    For each centre (xc, yc) and each (scale, ratio) pair, the anchor has width
    scale * ratio and height scale, and is clipped to out_size = (out_h, out_w).
    Anchor shapes are ordered scale-major, then ratio.

    This vectorised version replaces a Python double loop over all cells that
    built one small tensor per cell on every call (i.e. every forward pass).
    The arithmetic and dtypes are unchanged, so the values are identical.

    Returns:
        tensor of shape [1, len(anc_pts_x), len(anc_pts_y), n_anchor_boxes, 4] in the
        default float dtype, on `device`.

    NOTE(review): dim 1 follows anc_pts_x and dim 2 follows anc_pts_y, i.e.
    [1, Wmap, Hmap, ...] (the old inline comment claimed [1, Hmap, Wmap, ...]).
    Flattening therefore orders anchors x-major, whereas a (B, A, H, W) conv output
    flattens y-major. Confirm this matches how ProposalModule indexes its predictions.
    """
    # (width, height) of every anchor shape, scale-major then ratio
    anc_shapes = [(scale * ratio, scale) for scale in anc_scales for ratio in anc_ratios]
    n_anc_boxes = len(anc_shapes)

    # Same compute dtype as an element of anc_pts combined with a Python float.
    x_dtype = torch.result_type(anc_pts_x, 0.5)
    y_dtype = torch.result_type(anc_pts_y, 0.5)

    xs = anc_pts_x.to(x_dtype).reshape(-1, 1, 1)   # (Nx, 1, 1)
    ys = anc_pts_y.to(y_dtype).reshape(1, -1, 1)   # (1, Ny, 1)
    half_w = torch.tensor([w for w, _ in anc_shapes], dtype=x_dtype, device=xs.device) / 2   # (A,)
    half_h = torch.tensor([h for _, h in anc_shapes], dtype=y_dtype, device=ys.device) / 2   # (A,)

    grid = (xs.shape[0], ys.shape[1], n_anc_boxes)
    corners = [
        (xs - half_w).expand(grid),   # xmin
        (ys - half_h).expand(grid),   # ymin
        (xs + half_w).expand(grid),   # xmax
        (ys + half_h).expand(grid),   # ymax
    ]
    default_dtype = torch.get_default_dtype()
    anc_boxes = torch.stack([c.to(device=device, dtype=default_dtype) for c in corners], dim=-1)
    anc_boxes = ops.clip_boxes_to_image(anc_boxes, size=out_size)

    return anc_boxes.unsqueeze(0)

def get_iou_mat(batch_size, anc_boxes_all, gt_bboxes_all):
    """
    Compute, image by image, the IoU of every anchor box with every
    ground-truth box.

    Parameters
    ----------
    batch_size : int
        Number of images B; anc_boxes_all is reshaped to (B, -1, 4).
    anc_boxes_all : torch.Tensor of anchor boxes whose last dim is 4 ('xyxy').
    gt_bboxes_all : torch.Tensor of shape (B, N, 4) in 'xyxy' format
        (no format conversion is done here).

    Returns
    -------
    torch.Tensor of shape (B, n_anchors_per_image, N) on `device`.
    """
    per_image_anchors = anc_boxes_all.reshape(batch_size, -1, 4)
    n_anchors = per_image_anchors.size(dim=1)
    n_gt = gt_bboxes_all.size(dim=1)

    ious_mat = torch.zeros((batch_size, n_anchors, n_gt)).to(device)
    for b in range(batch_size):
        ious_mat[b, :] = ops.box_iou(per_image_anchors[b].to(device),
                                     gt_bboxes_all[b].to(device))

    return ious_mat

def get_req_anchors(anc_boxes_all, gt_bboxes_all, pos_thresh=0.7, neg_thresh=0.2):
    # EDIT: Jonas removed gt_classes_all (right after gt_bboxes_all)
    '''
    Select positive and negative anchors for a batch and build their
    regression / confidence targets.
    
    Input
    ------
    anc_boxes_all - torch.Tensor of shape (B, w_amap, h_amap, n_anchor_boxes, 4)
        all anchor boxes for a batch of images
    gt_bboxes_all - torch.Tensor of shape (B, max_objects, 4)
        padded ground truth boxes for a batch of images. Padding rows only
        work if they have zero IoU with every anchor (presumably the -1 rows
        written by project_bboxes - TODO confirm)
    pos_thresh - an anchor whose IoU with any gt box exceeds this is positive
    neg_thresh - anchors whose best IoU over all gt boxes is below this are
        candidate negatives
        
    EDIT: The original tutorial also took gt_classes_all - torch.Tensor of
    shape (B, max_objects); this version has no class targets.
        
    Returns
    ---------
    positive_anc_ind -  torch.Tensor of shape (n_pos,)
        indices of positive anchors into the batch-flattened anchor list
        (B * w_amap * h_amap * n_anchor_boxes entries). An anchor that is
        positive for several gt boxes appears once per such gt box.
    negative_anc_ind - torch.Tensor of shape (n_pos,)
        indices (same flattening) of negative anchors, drawn uniformly WITH
        replacement from the anchors whose max IoU is below neg_thresh
    GT_conf_scores - torch.Tensor of shape (n_pos,), max IoU of each +ve anchor
        over all gt boxes
    GT_offsets -  torch.Tensor of shape (n_pos, 4),
        calc_gt_offsets between each +ve anchor and the gt box it overlaps the
        most (not necessarily the gt box that made it positive)
    positive_anc_coords - (n_pos, 4) coords of +ve anchors (for visualization)
    negative_anc_coords - (n_pos, 4) coords of -ve anchors (for visualization)
    positive_anc_ind_sep - torch.Tensor of shape (n_pos,)
        batch (image) index of each entry of positive_anc_ind
    
    EDIT: We do not have GT_class_pos

    NOTE(review): an anchor chosen by condition 1 (best anchor for some gt
    box) may still have IoU < neg_thresh, so it can also be sampled as a
    negative.
    NOTE(review): torch.randint with an empty range raises, so this presumably
    fails if no anchor in the batch qualifies as negative.
    '''
    
    # get the size and shape parameters
    B, w_amap, h_amap, A, _ = anc_boxes_all.shape
    N = gt_bboxes_all.shape[1] # max number of groundtruth bboxes in a batch
    
    # get total number of anchor boxes in a single image
    tot_anc_boxes = A * w_amap * h_amap
    
    # get the iou matrix which contains iou of every anchor box
    # against all the groundtruth bboxes in an image -> (B, tot_anc_boxes, N)
    iou_mat = get_iou_mat(B, anc_boxes_all, gt_bboxes_all)
    
    # for every groundtruth bbox in an image, find the iou 
    # with the anchor box which it overlaps the most
    max_iou_per_gt_box, _ = iou_mat.max(dim=1, keepdim=True)
    
    # get positive anchor boxes
    
    # condition 1: the anchor box with the max iou for every gt bbox
    # (the > 0 test excludes gt rows that overlap nothing, e.g. padding)
    positive_anc_mask = torch.logical_and(iou_mat == max_iou_per_gt_box, max_iou_per_gt_box > 0) 
    # condition 2: anchor boxes with iou above a threshold with any of the gt bboxes
    positive_anc_mask = torch.logical_or(positive_anc_mask, iou_mat > pos_thresh)
    
    # torch.where on the (B, tot_anc_boxes, N) mask: [0] is the batch index of every True entry
    positive_anc_ind_sep = torch.where(positive_anc_mask)[0] # get separate indices in the batch
    # combine all the batches and get the idxs of the +ve anchor boxes
    # (mask becomes (B * tot_anc_boxes, N); [0] is the row, repeated once per matching gt box)
    positive_anc_mask = positive_anc_mask.flatten(start_dim=0, end_dim=1)
    positive_anc_ind = torch.where(positive_anc_mask)[0]
    
    # for every anchor box, get the iou and the idx of the
    # gt bbox it overlaps with the most
    max_iou_per_anc, max_iou_per_anc_ind = iou_mat.max(dim=-1)
    max_iou_per_anc = max_iou_per_anc.flatten(start_dim=0, end_dim=1)
    
    # get iou scores of the +ve anchor boxes
    GT_conf_scores = max_iou_per_anc[positive_anc_ind]
    
    # get gt classes of the +ve anchor boxes
    
    # EDIT: Jonas commented this section out
    # # expand gt classes to map against every anchor box
    # gt_classes_expand = gt_classes_all.view(B, 1, N).expand(B, tot_anc_boxes, N)
    # # for every anchor box, consider only the class of the gt bbox it overlaps with the most
    # GT_class = torch.gather(gt_classes_expand, -1, max_iou_per_anc_ind.unsqueeze(-1)).squeeze(-1)
    # # combine all the batches and get the mapped classes of the +ve anchor boxes
    # GT_class = GT_class.flatten(start_dim=0, end_dim=1)
    # GT_class_pos = GT_class[positive_anc_ind]
    
    # get gt bbox coordinates of the +ve anchor boxes
    
    # expand all the gt bboxes to map against every anchor box
    gt_bboxes_expand = gt_bboxes_all.view(B, 1, N, 4).expand(B, tot_anc_boxes, N, 4)
    # for every anchor box, consider only the coordinates of the gt bbox it overlaps with the most
    GT_bboxes = torch.gather(gt_bboxes_expand, -2, max_iou_per_anc_ind.reshape(B, tot_anc_boxes, 1, 1).repeat(1, 1, 1, 4))
    # combine all the batches and get the mapped gt bbox coordinates of the +ve anchor boxes
    # (B, tot_anc_boxes, 1, 4) -> (B * tot_anc_boxes, 4)
    GT_bboxes = GT_bboxes.flatten(start_dim=0, end_dim=2)
    GT_bboxes_pos = GT_bboxes[positive_anc_ind]
    
    # get coordinates of +ve anc boxes
    anc_boxes_flat = anc_boxes_all.flatten(start_dim=0, end_dim=-2).to(device) # flatten all the anchor boxes # EDIT
    positive_anc_coords = anc_boxes_flat[positive_anc_ind]
    
    # calculate gt offsets
    GT_offsets = calc_gt_offsets(positive_anc_coords, GT_bboxes_pos)
    
    # get -ve anchors
    
    # condition: select the anchor boxes with max iou less than the threshold
    negative_anc_mask = (max_iou_per_anc < neg_thresh)
    negative_anc_ind = torch.where(negative_anc_mask)[0]
    # sample -ve samples to match the +ve samples
    # (uniform with replacement; consumes the global torch RNG)
    negative_anc_ind = negative_anc_ind[torch.randint(0, negative_anc_ind.shape[0], (positive_anc_ind.shape[0],))]
    negative_anc_coords = anc_boxes_flat[negative_anc_ind]
    
    # EDIT: Jonas removed GT_class_pos (right after GT_offsets)
    return positive_anc_ind, negative_anc_ind, GT_conf_scores, GT_offsets, \
         positive_anc_coords, negative_anc_coords, positive_anc_ind_sep

def unpad_BB(BB_padded_tensor): # Jonas
    """
    Remove trailing padding rows (rows whose entries are all -1) from a box tensor.

    Parameters
    ----------
    BB_padded_tensor : torch.Tensor of shape (n_boxes, 4)
        Boxes where padding rows are filled with -1 and come after all real boxes.

    Returns
    -------
    torch.Tensor
        The input tensor itself when it has no padding row (or is empty).
        Otherwise a CPU tensor holding the rows before the first padding row.
    """
    BB_padded_np = BB_padded_tensor.detach().cpu().numpy()

    # Nothing to strip; np.min/argmin would raise on an empty array.
    if BB_padded_np.size == 0:
        return BB_padded_tensor

    # A padding row is one where EVERY coordinate is -1. The previous test
    # (global min == -1, then argmin of column 0) had two problems. It truncated
    # at the first real box containing a -1 coordinate. It also missed padding
    # whenever some value was below -1.
    pad_rows = np.all(BB_padded_np == -1, axis=1)
    if not pad_rows.any():
        # if there is no padding, then return the input
        return BB_padded_tensor

    idx = int(np.argmax(pad_rows))  # index of the first padding row

    BB_unpadded_np = BB_padded_np[0:idx, :]
    return torch.from_numpy(BB_unpadded_np)

# # -------------- Visualization utils ----------------

def display_img(img_data, fig, axes):
    """
    Draw a sequence of images, image i on axes[i].

    Parameters
    ----------
    img_data : iterable of images. torch.Tensors are permuted (1, 2, 0),
        i.e. channel-first to channel-last. Any other image must already
        support `.astype` (e.g. an (H, W, C) numpy array).
    fig : passed through unchanged.
    axes : indexable collection of matplotlib axes.

    Returns
    -------
    (fig, axes)
    """
    for i, img in enumerate(img_data):
        if isinstance(img, torch.Tensor):
            # detach/cpu so tensors that require grad or live on the GPU can
            # also be converted to numpy
            img = img.detach().cpu().permute(1, 2, 0).numpy()
        # values are cast (not rescaled) to uint8
        axes[i].imshow(img.astype('uint8'))

    return fig, axes

def display_bbox(bboxes, fig, ax, classes=None, in_format='xyxy', color='y', line_width=3):
    """
    Draw bounding boxes, and optionally their labels, on a matplotlib axis.

    Parameters
    ----------
    bboxes : torch.Tensor or np.ndarray with one 4-value box per row, in `in_format`.
    fig : passed through unchanged.
    ax : matplotlib axis the rectangles (and labels) are added to.
    classes : optional sequence with one label per box. Boxes labelled 'pad'
        are still drawn but get no text.
    in_format : box format accepted by torchvision.ops.box_convert.
    color, line_width : rectangle edge colour and width.

    Returns
    -------
    (fig, ax)
    """
    if isinstance(bboxes, np.ndarray):
        bboxes = torch.from_numpy(bboxes)
    # Same truthiness as `if classes:` for lists/tuples, but also valid for numpy arrays.
    has_classes = classes is not None and len(classes) > 0
    if has_classes:
        assert len(bboxes) == len(classes)
    # convert boxes to xywh format; detach/cpu so `.numpy()` below also works for GPU tensors
    bboxes = ops.box_convert(bboxes.detach().cpu(), in_fmt=in_format, out_fmt='xywh')
    # enumerate keeps the label index aligned with the box. The old manual counter
    # was not advanced on 'pad', so every later label was misaligned.
    for c, box in enumerate(bboxes):
        x, y, w, h = box.numpy()
        # display bounding box
        rect = patches.Rectangle((x, y), w, h, linewidth=line_width, edgecolor=color, facecolor='none')
        ax.add_patch(rect)
        # display category
        if has_classes and classes[c] != 'pad':
            ax.text(x + 5, y + 20, classes[c], bbox=dict(facecolor='yellow', alpha=0.5))

    return fig, ax

def display_grid(x_points, y_points, fig, ax, special_point=None):
    """
    Scatter a white '+' at every (x, y) grid point and, optionally, a red '+'
    at `special_point`.

    Points are drawn with x in the outer loop and y in the inner loop.
    `special_point` is an (x, y) pair and is skipped when falsy (e.g. None).

    Returns
    -------
    (fig, ax)
    """
    grid_pairs = [(gx, gy) for gx in x_points for gy in y_points]
    for gx, gy in grid_pairs:
        ax.scatter(gx, gy, color="w", marker='+')

    if not special_point:
        return fig, ax

    # highlight the requested point on top of the grid
    sx, sy = special_point
    ax.scatter(sx, sy, color="red", marker='+')
    return fig, ax


In [34]:
# %%%%%%%%%%  JONAS  %%%%%%%%%% #

def flatten_modules(model):
    """
    Recursively flatten a network into a list of its leaf-level modules.

    Container modules (nn.Sequential, nn.ModuleList, nn.ModuleDict) and
    torchvision ResNet Bottleneck blocks are expanded into their children.
    Every other child module is kept as-is. Order follows model.children().
    """
    containers = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
    flat = []
    for child in model.children():
        expand = isinstance(child, containers) or isinstance(child, models.resnet.Bottleneck)
        if expand:
            flat.extend(flatten_modules(child))
        elif isinstance(child, nn.Module):
            flat.append(child)
    return flat

def init_RPN(name):
    """
    Build a RegionProposalNetwork (img_size (256, 256), out_size (64, 64),
    out_channels 256) on `device`. Substrings of `name` choose the feature extractor:

      'hyper' + 'hyperV2'        -> FeatureExtractor_hyperV2
      'hyper' + 'resnet50hyper'  -> FeatureExtractor_resnet50hyper
      any other 'hyper'          -> FeatureExtractor_hyper
      'random'                   -> FeatureExtractor_random (applied last, so it wins)

    With none of these substrings, the extractor built by RegionProposalNetwork is kept.
    """
    img_size, out_size, out_channels = (256, 256), (64, 64), 256

    RPN = RegionProposalNetwork(img_size, out_size, out_channels)
    RPN = RPN.to(device)

    if 'hyper' in name:
        if 'hyperV2' in name:
            RPN.feature_extractor = FeatureExtractor_hyperV2().to(device)
        elif 'resnet50hyper' in name:
            RPN.feature_extractor = FeatureExtractor_resnet50hyper().to(device)
        else:
            RPN.feature_extractor = FeatureExtractor_hyper().to(device)

    if 'random' in name:
        RPN.feature_extractor = FeatureExtractor_random().to(device)

    return RPN
   
def init_file_root_and_list(name):
    """
    Pick the dataset folder and the full list of file ids from `name`.

    For name = 'example':
      'example'          -> augmentedData          (40000 ids)
      'example_big'      -> augmentedData_big      (100000 ids)
      'example_hard'     -> augmentedData_hard     (40000 ids)
      'example_hard_big' -> augmentedData_hard_big (100000 ids)

    File ids are consecutive and start at `dsn`.

    Returns
    -------
    (file_root, file_list) : str, np.ndarray of ints
    """
    file_root = "/scratch/s204219/augmentedData"
    n_files = 40000

    if 'hard' in name:
        file_root += '_hard'
    if 'big' in name:
        file_root += '_big'
        n_files = 100000

    file_list = np.arange(dsn, dsn + n_files)
    return file_root, file_list

def load_data_split(name):
    """
    Load the fixed train/val/test file-id split stored with a dataset.

    The dataset folder is selected from substrings of `name` ('hard' appends
    '_hard', then 'big' appends '_big' to augmentedData). The split is read
    from its data_split/Index_{train,val,test}.npy files.

    Returns
    -------
    (train_indices, val_indices, test_indices) as loaded by np.load.
    """
    file_root = "/scratch/s204219/augmentedData"
    file_root += '_hard' if 'hard' in name else ''
    file_root += '_big' if 'big' in name else ''

    print('Using dataset: ' + file_root.split('/')[-1])

    train_indices = np.load(f'{file_root}/data_split/Index_train.npy')
    val_indices   = np.load(f'{file_root}/data_split/Index_val.npy')
    test_indices  = np.load(f'{file_root}/data_split/Index_test.npy')
    return train_indices, val_indices, test_indices

def create_train_val_test_split(file_list, val_ratio, test_ratio, name, save_path, save=False):
    """
    Split `file_list` in order into consecutive train / val / test chunks.

    The val and test sizes are round(ratio * len(file_list)) (numpy rounding).
    Train gets the remainder. No shuffling is done here.

    If save == True, the three arrays are written to
    {save_path}/Index_{train,val,test}_{name}.npy.

    Returns
    -------
    (train_indices, val_indices, test_indices) : numpy arrays
    """
    n_total = len(file_list)
    n_val = int(np.round(val_ratio * n_total))
    n_test = int(np.round(test_ratio * n_total))
    n_train = int(n_total - (n_val + n_test))
    val_end = n_train + n_val

    train_indices = np.array(file_list[0:n_train])
    val_indices = np.array(file_list[n_train:val_end])
    test_indices = np.array(file_list[val_end:])

    if save == True:
        np.save(f'{save_path}/Index_train_{name}', train_indices)
        np.save(f'{save_path}/Index_val_{name}', val_indices)
        np.save(f'{save_path}/Index_test_{name}', test_indices)

    return train_indices, val_indices, test_indices

def train_model_new(name='example', n_epochs=50, dataset_size=40000, val_ratio=1/7., test_ratio=1/7., save=False, cpm=False, experiment=None):
    """
    Train a Region Proposal Network and (optionally) checkpoint it in
    /scratch/s204219/RPN_new/{name}.

    Parameters
    ----------
    name : str
        Model name. Its substrings select the dataset ('hard', 'big'; see
        init_file_root_and_list) and the feature extractor (see init_RPN).
        In experiment mode, the part after the first '_' must be one of
        'one'..'eight'.
    n_epochs : int
        Number of epochs to run in this call.
    dataset_size : int
        Number of file ids drawn at random. 40000 or 100000 reuses the fixed
        split stored next to the dataset (load_data_split). Any other value
        creates a fresh split with val_ratio / test_ratio.
    val_ratio, test_ratio : float
        Split fractions used when a fresh split is created.
    save : bool
        Write split indices, best/last checkpoints and loss curves.
    cpm : bool
        "Choose previous model": resume from {save_path}/Last_RPN_{name}.
    experiment : bool or None
        Data-size experiment switch. None falls back to a notebook-level global
        `experiment` (False if undefined), which earlier versions relied on.

    Returns
    -------
    None
    """
    if experiment is None:
        experiment = globals().get('experiment', False)

    save_path = f"/scratch/s204219/RPN_new/{name}"
    if save:
        if os.path.exists(save_path) and not cpm:
            choice = input(f"{name} already exists. Do you want to overwrite it? (y/n)")
            if choice.lower() != 'y':
                return
            print('overwriting')
        if not os.path.exists(save_path):
            os.makedirs(save_path)

    batch_size = 50

    if not experiment:
        np.set_printoptions(precision=3)

        file_root, file_list = init_file_root_and_list(name=name)

        torch.manual_seed(1)
        np.random.seed(1)

        # Randomly select dataset_size file ids from the full list
        file_list = np.random.choice(file_list, size=dataset_size, replace=False)

        # Full-size datasets ship with a fixed split; other sizes get a fresh one.
        use_old_split = dataset_size in (40000, 100000)

        if use_old_split:
            train_indices, val_indices, test_indices = load_data_split(name=name)
            # Only persist when asked to. save_path may not exist otherwise.
            if save:
                np.save(f'{save_path}/Index_train_{name}',train_indices)
                np.save(f'{save_path}/Index_val_{name}',val_indices)
                np.save(f'{save_path}/Index_test_{name}',test_indices)
        else:
            print('Creating random data split')
            train_indices, val_indices, _ = create_train_val_test_split(file_list=file_list, val_ratio=val_ratio,
                                                                        test_ratio=test_ratio, name=name,
                                                                        save_path=save_path, save=save)
    else:
        # Experiment: will more data lead to better performance, and how much is needed?
        file_root = '/scratch/s204219/augmentedData_hard_big'
        train_indices = np.load('/scratch/s204219/augmentedData_hard_big/data_split/Index_train.npy')
        val_indices   = np.load('/scratch/s204219/augmentedData_hard_big/data_split/Index_val.npy')
        test_indices  = np.load('/scratch/s204219/augmentedData_hard_big/data_split/Index_test.npy')

        n_train = len(train_indices)
        print(n_train)
        name_split = name.split('_')[1]

        number_dict = {'one': 1, 'two': 2, 'three': 3, 'four': 4,
                       'five': 5, 'six': 6, 'seven': 7, 'eight': 8}
        value = number_dict[name_split]

        # since data_split is 8:1:1
        # NOTE(review): this keeps value/10 of the *training* split. If "value
        # tenths of the full dataset" was intended, the divisor should be 8.
        n_train = int(np.round((value*n_train)/10))
        print(n_train)
        train_indices = train_indices[:n_train]

        if save:
            np.save(f'{save_path}/Index_train_{name}',train_indices)
            np.save(f'{save_path}/Index_val_{name}',val_indices)
            np.save(f'{save_path}/Index_test_{name}',test_indices)

    dataset_train    = UXO_dataset(file_list=train_indices, file_root=file_root)
    dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=False)
    n_train = len(dataset_train)

    dataset_val      = UXO_dataset(file_list=val_indices, file_root=file_root)
    dataloader_val   = DataLoader(dataset_val, batch_size=batch_size, shuffle=False)
    n_val = len(dataset_val)

    RPN = init_RPN(name)

    learning_rate = 1e-3
    optimizer = optim.Adam(RPN.parameters(), lr=learning_rate)

    if cpm:
        # Resume from the last checkpoint of this model
        previous_model = f'Last_RPN_{name}'
        # map_location lets a checkpoint written on a GPU be resumed on CPU
        checkpoint = torch.load(f"{save_path}/{previous_model}", map_location=device)
        RPN.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Using previous model with name: \n{previous_model}\n")

        loss_list = list(np.load(f'{save_path}/Loss_train_{name}.npy'))
        loss_list_val = list(np.load(f'{save_path}/Loss_val_{name}.npy'))
        best_loss = np.min(loss_list_val)
    else:
        start_epoch = 0
        loss_list = []
        loss_list_val = []
        best_loss = float('inf')

    # Final epoch of THIS call. The "last" checkpoint is written there,
    # regardless of how many epochs the resumed checkpoint had.
    last_epoch = start_epoch + n_epochs - 1

    for e in range(start_epoch, start_epoch + n_epochs):
        RPN.train()
        print('Epoch '+str(e+1)+': ')
        total_loss = 0
        for img_batch, gt_bboxes_batch in dataloader_train:
            img_batch = img_batch.to(device)
            gt_bboxes_batch = gt_bboxes_batch.to(device)

            # forward pass
            loss,_,_,_ = RPN(img_batch, gt_bboxes_batch)

            # backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss/n_train
        print('Avg. train loss: ' + str(np.round(avg_loss,2)))
        loss_list.append(avg_loss)

        # Validation. The fixed seeds keep the random negative-anchor sampling
        # identical across epochs. Forking/restoring the RNG state stops that
        # reseed from also resetting the stream the next training epoch draws from.
        total_loss_val = 0
        np_rng_state = np.random.get_state()
        with torch.random.fork_rng(), torch.no_grad():
            torch.manual_seed(1)
            np.random.seed(1)

            RPN.eval()
            for img_batch, gt_bboxes_batch in dataloader_val:
                img_batch = img_batch.to(device)
                gt_bboxes_batch = gt_bboxes_batch.to(device)

                # forward pass
                loss_val,_,_,_ = RPN(img_batch, gt_bboxes_batch)
                total_loss_val += loss_val.item()
        np.random.set_state(np_rng_state)

        avg_loss_val = total_loss_val/n_val
        is_best = avg_loss_val <= best_loss
        if is_best:
            best_loss = avg_loss_val

        print('Avg. val loss: ' + str(np.round(avg_loss_val,2)))
        loss_list_val.append(avg_loss_val)

        if save:
            if is_best:
                print('saving')
                torch.save({'epoch': e,
                                'model_state_dict': RPN.state_dict(),
                                'optimizer_state_dict':optimizer.state_dict(),
                                'loss': total_loss},
                                f'{save_path}/Best_RPN_{name}')

            if e == last_epoch:
                print('saving last model')
                torch.save({'epoch': e,
                                'model_state_dict': RPN.state_dict(),
                                'optimizer_state_dict':optimizer.state_dict(),
                                'loss': total_loss},
                                f'{save_path}/Last_RPN_{name}')

    if save:
        np.save(f'{save_path}/Loss_train_{name}',loss_list)
        np.save(f'{save_path}/Loss_val_{name}',loss_list_val)

    return

def train_model(name='example', n_epochs=50, dataset_size=40000, val_ratio=1/7., test_ratio=1/7., save=False, cpm=False):
    """
    Legacy RPN training routine writing to /scratch/s204219/RPN/{name}.

    Always uses the augmentedData dataset. It draws `dataset_size` random file
    ids and splits them with create_train_val_test_split.

    Parameters
    ----------
    name : str
        Model name; also selects the feature extractor (see init_RPN).
    n_epochs : int
        Number of epochs to run in this call.
    dataset_size : int
        Number of file ids drawn from the 40000 available.
    val_ratio, test_ratio : float
        Split fractions.
    save : bool
        Write split indices, best/last checkpoints and loss curves.
    cpm : bool
        "Choose previous model": resume from {save_path}/Last_RPN_{name}.

    Returns
    -------
    None
    """
    save_path = f"/scratch/s204219/RPN/{name}"
    if save:
        if os.path.exists(save_path) and not cpm:
            choice = input(f"{name} already exists. Do you want to overwrite it? (y/n)")
            if choice.lower() != 'y':
                return
            print('overwriting')
        if not os.path.exists(save_path):
            os.makedirs(save_path)

    batch_size = 50

    np.set_printoptions(precision=3)

    file_root = "/scratch/s204219/augmentedData"
    file_list = np.arange(dsn,dsn+40000)
    np.random.seed(1)

    # Randomly select numbers from the array
    file_list = np.random.choice(file_list, size=dataset_size, replace=False)

    # (The previous version passed a `use_old_split` keyword that
    # create_train_val_test_split does not accept, which raised TypeError.)
    train_indices, val_indices, _ = create_train_val_test_split(file_list=file_list, val_ratio=val_ratio,
                                                                test_ratio=test_ratio, name=name,
                                                                save_path=save_path, save=save)

    n_train = len(train_indices)
    n_val = len(val_indices)
    print('Number of training data: ' + str(n_train))
    print('Number of validation data: ' + str(n_val))

    # One dataset per split, as in train_model_new. The previous samplers used
    # np.argsort(ids). That yields positions 0..len-1 of a shared dataset, so
    # validation was drawn from the TRAINING positions.
    dataset_train = UXO_dataset(file_list=train_indices, file_root=file_root)
    dataset_val = UXO_dataset(file_list=val_indices, file_root=file_root)
    dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
    dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False)

    RPN = init_RPN(name)

    learning_rate = 1e-3
    optimizer = optim.Adam(RPN.parameters(), lr=learning_rate)

    if cpm:
        # Resume from the last checkpoint of this model
        previous_model = f'Last_RPN_{name}'
        # map_location lets a checkpoint written on a GPU be resumed on CPU
        checkpoint = torch.load(f"{save_path}/{previous_model}", map_location=device)
        RPN.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Using previous model with name: \n{previous_model}\n")

        loss_list = list(np.load(f'{save_path}/Loss_train_{name}.npy'))
        loss_list_val = list(np.load(f'{save_path}/Loss_val_{name}.npy'))
        best_loss = np.min(loss_list_val)
    else:
        start_epoch = 0
        loss_list = []
        loss_list_val = []
        best_loss = float('inf')

    # Final epoch of THIS call; the "last" checkpoint is written there.
    last_epoch = start_epoch + n_epochs - 1

    for e in range(start_epoch, start_epoch + n_epochs):
        RPN.train()
        print('Epoch '+str(e+1)+': ')
        total_loss = 0
        for img_batch, gt_bboxes_batch in dataloader_train:
            img_batch = img_batch.to(device)
            gt_bboxes_batch = gt_bboxes_batch.to(device)

            # forward pass
            loss,_,_,_ = RPN(img_batch, gt_bboxes_batch)

            # backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss/n_train
        print('Avg. train loss: ' + str(np.round(avg_loss,2)))
        loss_list.append(avg_loss)

        # val
        total_loss_val = 0
        with torch.no_grad():
            RPN.eval()
            for img_batch, gt_bboxes_batch in dataloader_val:
                img_batch = img_batch.to(device)
                gt_bboxes_batch = gt_bboxes_batch.to(device)

                # forward pass
                loss_val,_,_,_ = RPN(img_batch, gt_bboxes_batch)
                total_loss_val += loss_val.item()

        avg_loss_val = total_loss_val/n_val
        is_best = avg_loss_val <= best_loss
        if is_best:
            best_loss = avg_loss_val

        print('Avg. val loss: ' + str(np.round(avg_loss_val,2)))
        loss_list_val.append(avg_loss_val)

        if save:
            if is_best:
                print('saving')
                torch.save({'epoch': e,
                                'model_state_dict': RPN.state_dict(),
                                'optimizer_state_dict':optimizer.state_dict(),
                                'loss': total_loss},
                                f'{save_path}/Best_RPN_{name}')

            if e == last_epoch:
                print('saving last model')
                torch.save({'epoch': e,
                                'model_state_dict': RPN.state_dict(),
                                'optimizer_state_dict':optimizer.state_dict(),
                                'loss': total_loss},
                                f'{save_path}/Last_RPN_{name}')

    if save:
        np.save(f'{save_path}/Loss_train_{name}',loss_list)
        np.save(f'{save_path}/Loss_val_{name}',loss_list_val)

    return

def test_model(name='temp',Best_or_Last='Best'):
    """
    Evaluate a trained RPN on its saved test split.

    Loads /scratch/s204219/RPN_new/{name}/{Best_or_Last}_RPN_{name} and the
    Index_test_{name}.npy split stored next to it. It prints the summed
    per-batch loss divided by the number of test samples, and saves that value
    to Loss_test_{Best_or_Last}_{name}.npy.

    Parameters
    ----------
    name : str
        Model name (also selects dataset and feature extractor).
    Best_or_Last : str
        'Best' or 'Last' checkpoint.

    Returns
    -------
    None
    """
    previous_model_path = f"/scratch/s204219/RPN_new/{name}"
    previous_model = f'{previous_model_path}/{Best_or_Last}_RPN_{name}'
    batch_size = 50

    file_root_test, _ = init_file_root_and_list(name=name)
    file_list_test = np.load(f'{previous_model_path}/Index_test_{name}.npy')

    torch.manual_seed(seed=1)
    np.random.seed(seed=1)

    dataset_test    = UXO_dataset(file_list=file_list_test, file_root=file_root_test)
    dataloader_test = DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False)

    n_batches = len(dataloader_test)

    RPN = init_RPN(name)

    # map_location lets a checkpoint written on a GPU be evaluated on a CPU-only host
    checkpoint = torch.load(previous_model, map_location=device)
    RPN.load_state_dict(checkpoint['model_state_dict'])

    total_loss_test = 0
    print('Beginning test')
    with torch.no_grad():
        # Re-seed after model construction so the (random) negative-anchor
        # sampling in the loss is reproducible.
        torch.manual_seed(seed=1)
        np.random.seed(seed=1)

        RPN.eval()
        for k, (img_batch, gt_bboxes_batch) in enumerate(dataloader_test, start=1):
            img_batch = img_batch.to(device)
            gt_bboxes_batch = gt_bboxes_batch.to(device)

            # forward pass
            loss_test,_,_,_ = RPN(img_batch, gt_bboxes_batch)
            total_loss_test += loss_test.item()

            if k % 10 == 0:
                print(f'Batch {k}/{n_batches}')

    avg_loss_test = total_loss_test/len(dataset_test)
    print('Avg. test loss: ' + str(avg_loss_test))

    np.save(f'{previous_model_path}/Loss_test_{Best_or_Last}_{name}',avg_loss_test)

    return

def tune_model(name='temp',Best_or_Last='Best',train_or_val_or_test = 'val',conf_thresh_arr=[0.6,0.7],nms_thresh_arr=[0.1,0.2],save=False):
    """
    Compute averaged confusion matrices of an RPN for a grid of confidence
    and NMS thresholds.

    For every (conf_thresh, nms_thresh) pair, RPN.inference is run on the
    chosen split. calc_conf_mat's tp/fp/fn/tn are summed over batches and
    divided by the number of files in the split.

    Parameters
    ----------
    name : str
        Model name; the checkpoint is read from /scratch/s204219/RPN_new/{name}.
    Best_or_Last : str
        'Best' or 'Last' checkpoint.
    train_or_val_or_test : str
        Which saved split (Index_{split}_{name}.npy) to evaluate.
    conf_thresh_arr, nms_thresh_arr : sequences of float
        Threshold grid (not modified).
    save : bool
        If True, save the confusion array and both threshold arrays.
        'val' uses the un-prefixed file names (Confusion_{Best_or_Last}_{name}).
        Other splits insert '{split}_' (e.g. Confusion_test_{Best_or_Last}_{name}).

    Returns
    -------
    np.ndarray of shape (len(conf_thresh_arr), len(nms_thresh_arr), 2, 2)
        [[tp, fp], [fn, tn]] per threshold pair. It was previously returned only when save was False.
    """
    previous_model_path = f"/scratch/s204219/RPN_new/{name}"
    previous_model = f'{previous_model_path}/{Best_or_Last}_RPN_{name}'

    if save:
        # 'val' keeps the historic names. Other splits are prefixed, which also
        # defines the paths for 'train' (previously a NameError).
        split_tag = '' if train_or_val_or_test == 'val' else f'{train_or_val_or_test}_'
        save_path = f'{previous_model_path}/Confusion_{split_tag}{Best_or_Last}_{name}'
        save_path_conf = f'{previous_model_path}/conf_thresh_arr_{split_tag}{Best_or_Last}_{name}'
        save_path_nms = f'{previous_model_path}/nms_thresh_arr_{split_tag}{Best_or_Last}_{name}'

    file_root, _ = init_file_root_and_list(name=name)
    file_list = np.load(f'{previous_model_path}/Index_{train_or_val_or_test}_{name}.npy')

    n = len(file_list)
    batch_size = 50
    dataset_test = UXO_dataset(file_list=file_list, file_root=file_root)
    dataloader_test = DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False, num_workers=2)

    n_batches = int(np.ceil(len(dataset_test)/batch_size))

    img_size = (256,256)
    out_size = (64,64)

    height_scale_factor= img_size[0]/out_size[0]
    width_scale_factor = img_size[1]/out_size[1]

    RPN = init_RPN(name)

    # map_location lets a checkpoint written on a GPU be evaluated on a CPU-only host
    checkpoint = torch.load(previous_model, map_location=device)
    RPN.load_state_dict(checkpoint['model_state_dict'])
    # A freshly built module is in training mode. Switch to eval so layers such
    # as BatchNorm/Dropout (if the feature extractor has any) behave as in
    # test_model and do not update running statistics.
    RPN.eval()

    n_anc_boxes_total = RPN.n_anc_boxes_total
    pos_thresh = 0.3  # IoU threshold handed to calc_conf_mat

    confusion = torch.zeros((len(conf_thresh_arr),len(nms_thresh_arr),2,2),device=device)
    count_1 = 1
    number_of_combinations = len(conf_thresh_arr)*len(nms_thresh_arr)
    with torch.no_grad():
        for j,nms_thresh in enumerate(nms_thresh_arr):
            for i,conf_thresh in enumerate(conf_thresh_arr):

                print(f'Beginning conf- and nms threshold combination number {count_1}/{number_of_combinations}')
                tp, fp, fn, tn = 0, 0, 0, 0

                for count, (img_batch, gt_bboxes_batch) in enumerate(dataloader_test, start=1):
                    img_batch = img_batch.to(device)
                    gt_bboxes_batch = gt_bboxes_batch.to(device)

                    gt_bboxes_proj = project_bboxes(gt_bboxes_batch, width_scale_factor, height_scale_factor, mode='p2a')

                    proposals_final, conf_scores_final, feature_map = RPN.inference(img_batch,conf_thresh=conf_thresh,nms_thresh=nms_thresh)

                    tp_batch,fp_batch,fn_batch,tn_batch = \
                    calc_conf_mat(gt_bboxes_proj,proposals_final,pos_thresh,n_anc_boxes_total,list_or_sum='sum')

                    tp+=tp_batch
                    fp+=fp_batch
                    fn+=fn_batch
                    tn+=tn_batch

                    if count % 50 == 0:
                        print(f'Batch {count}/{n_batches} completed')

                print('Done')
                count_1+=1

                # Averaging
                tp, fp, fn, tn = tp/n, fp/n, fn/n, tn/n

                confusion[i,j,:,:] = torch.tensor([[tp,fp],[fn,tn]],device=device)

    confusion_np = confusion.detach().cpu().numpy()

    if save:
        conf_thresh_arr_np = np.array(conf_thresh_arr)
        nms_thresh_arr_np  = np.array(nms_thresh_arr)

        np.save(f'{save_path}',confusion_np)
        np.save(f'{save_path_conf}',conf_thresh_arr_np)
        np.save(f'{save_path_nms}',nms_thresh_arr_np)

    if train_or_val_or_test == 'val':
        print('Done tuning')
    if train_or_val_or_test == 'test':
        print('Done testing')

    return confusion_np

def find_conf_and_nms_thresh(name='baseline', Best_or_Last='Best', root="/scratch/s204219/RPN_new", print_option=False):
    """Select the (confidence, NMS) threshold pair with the fewest proposals
    among those meeting the TP/FN targets on the tuning data.

    The tuning results are read from ``{root}/{name}``:
    ``Confusion_{Best_or_Last}_{name}.npy`` (indexed ``[conf_idx, nms_idx, row, col]``
    with ``[0,0]`` = TP, ``[0,1]`` = FP, ``[1,0]`` = FN) together with the grids
    ``conf_thresh_arr_{Best_or_Last}_{name}.npy`` and
    ``nms_thresh_arr_{Best_or_Last}_{name}.npy``.

    For each NMS threshold a confidence threshold is admissible when
    ``0.75 <= TP < 1.0`` and ``FN <= 0.25``; the last admissible entry of the
    confidence grid is kept. The NMS threshold whose kept confidence threshold
    gives the smallest TP + FP wins.

    Args:
        name: model name (sub-directory of ``root``).
        Best_or_Last: which checkpoint's tuning results to read, 'Best' or 'Last'.
        root: directory holding one folder per model (defaults to the original
            hard-coded location).
        print_option: print the per-NMS intermediate results when True.

    Returns:
        ``(best_conf_thresh, best_nms_thresh)``. If no NMS threshold has an
        admissible confidence threshold, ``best_conf_thresh`` is ``float('inf')``,
        ``best_nms_thresh`` is the first NMS grid value, and a warning is issued.
    """
    import warnings

    model_dir = f"{root}/{name}"
    confusion = np.load(f'{model_dir}/Confusion_{Best_or_Last}_{name}.npy')
    conf_thresh_arr = np.load(f'{model_dir}/conf_thresh_arr_{Best_or_Last}_{name}.npy')
    nms_thresh_arr = np.load(f'{model_dir}/nms_thresh_arr_{Best_or_Last}_{name}.npy')

    number_of_proposals = []  # TP+FP at the kept conf threshold; inf if none is admissible
    conf_thresh_optimal = []  # kept conf threshold per NMS threshold; inf if none is admissible

    for i in range(len(nms_thresh_arr)):
        tp = confusion[:, i, 0, 0]
        fp = confusion[:, i, 0, 1]
        fn = confusion[:, i, 1, 0]

        admissible = np.flatnonzero((tp >= 0.75) & (tp < 1.0) & (fn <= 0.25))

        if admissible.size > 0:
            # Use the grid position directly instead of searching for the float value again.
            best_idx = admissible[-1]
            conf_thresh_optimal.append(conf_thresh_arr[best_idx])
            number_of_proposals.append((tp + fp)[best_idx])
            if print_option:
                print(conf_thresh_arr[best_idx])
        else:
            conf_thresh_optimal.append(float('inf'))
            number_of_proposals.append(float('inf'))

        if print_option:
            print(number_of_proposals[i])
            print('---------------')

    best_nms_thresh_idx = np.argmin(number_of_proposals)
    best_nms_thresh = nms_thresh_arr[best_nms_thresh_idx]
    best_conf_thresh = conf_thresh_optimal[best_nms_thresh_idx]

    if not np.isfinite(best_conf_thresh):
        # Callers that look this value up on a threshold grid will not find it.
        warnings.warn(f'No admissible confidence/NMS threshold pair found for model {name}; '
                      f'returning conf thresh = inf.')

    if print_option:
        print('Optimal conf thresh: ' + str(np.round(best_conf_thresh,2)))
        print('Optimal nms thresh: ' + str(np.round(best_nms_thresh,2)))

    return best_conf_thresh, best_nms_thresh

def test_model_confusion(name='temp',Best_or_Last='Best',conf_thresh_arr=[0.6,0.7],nms_thresh_arr=[0.1,0.2]):
    """Evaluate a trained RPN on the test split over a grid of thresholds.

    Runs ``tune_model`` in 'test' mode with ``save=True``, so the averaged
    confusion tensor and the threshold grids are written to disk.

    ``tune_model`` returns ``None`` when ``save=True``. The saved tensor is
    therefore read back from
    ``/scratch/s204219/RPN_new/{name}/Confusion_test_{Best_or_Last}_{name}.npy``
    (the file ``print_results`` reads) so that callers actually receive it.

    Args:
        name: model name.
        Best_or_Last: which checkpoint to evaluate ('Best' or 'Last').
        conf_thresh_arr: confidence thresholds to evaluate (not modified).
        nms_thresh_arr: NMS thresholds to evaluate (not modified).

    Returns:
        The confusion array indexed ``[conf_idx, nms_idx, row, col]``
        (``[0,0]`` TP, ``[0,1]`` FP, ``[1,0]`` FN, ``[1,1]`` TN), or ``None`` if
        the saved file is not found at that location.
    """
    confusion = tune_model(name=f'{name}', Best_or_Last=f'{Best_or_Last}',
                           train_or_val_or_test='test',
                           conf_thresh_arr=conf_thresh_arr,
                           nms_thresh_arr=nms_thresh_arr, save=True)

    if confusion is None:
        # NOTE(review): assumes tune_model writes the test confusion where print_results
        # reads it from; if not, None is returned exactly as before.
        saved = f"/scratch/s204219/RPN_new/{name}/Confusion_test_{Best_or_Last}_{name}.npy"
        if os.path.exists(saved):
            confusion = np.load(saved)

    return confusion

def print_results(names=('baseline', 'hyper')):
    """Print (rounded) and return a table of test-set results for several models.

    For each model, the confidence and NMS thresholds chosen on the tuning data by
    ``find_conf_and_nms_thresh`` are located on the test threshold grid. The
    averaged test confusion at that pair gives the number of proposals, recall
    and precision.

    Args:
        names: model names; files are read from ``/scratch/s204219/RPN_new/{name}``.

    Returns:
        ``np.ndarray`` of shape ``(6, len(names))`` with rows
        0 test loss, 1 confidence threshold, 2 NMS threshold,
        3 number of proposals (TP+FP), 4 recall TP/(TP+FN), 5 precision TP/(TP+FP).

    Raises:
        ValueError: if a chosen threshold is not on the model's test grid
            (for example ``inf`` when tuning found no admissible pair).
    """
    Best_or_Last = 'Best'

    def _grid_index(value, grid, label, model):
        # The test grid is loaded from a different file than the tuned value, so compare
        # with a tolerance: a grid built with np.arange can hold 0.6000000000000001 where
        # a literal list holds 0.6, and exact == would then miss.
        hits = np.flatnonzero(np.isclose(grid, value))
        if hits.size == 0:
            raise ValueError(f'{label} threshold {value} of model {model} is not on its test grid {grid}')
        return hits[0]

    test_loss_all = []
    best_conf_thresh = []
    best_nms_thresh = []
    no_prop_all = []
    recall_all = []
    precision_all = []

    for name in names:
        previous_model_path = f"/scratch/s204219/RPN_new/{name}"

        test_loss = np.load(f'{previous_model_path}/Loss_test_{Best_or_Last}_{name}.npy')
        confusion = np.load(f'{previous_model_path}/Confusion_test_{Best_or_Last}_{name}.npy')
        conf_thresh_arr = np.load(f'{previous_model_path}/conf_thresh_arr_test_{Best_or_Last}_{name}.npy')
        nms_thresh_arr = np.load(f'{previous_model_path}/nms_thresh_arr_test_{Best_or_Last}_{name}.npy')

        conf, nms = find_conf_and_nms_thresh(name=f'{name}', Best_or_Last=Best_or_Last)

        i = _grid_index(nms, nms_thresh_arr, 'NMS', name)
        j = _grid_index(conf, conf_thresh_arr, 'Confidence', name)

        tp = confusion[j, i, 0, 0]
        fp = confusion[j, i, 0, 1]
        fn = confusion[j, i, 1, 0]

        test_loss_all.append(test_loss)
        best_conf_thresh.append(conf)
        best_nms_thresh.append(nms)
        no_prop_all.append(tp + fp)
        recall_all.append(tp / (tp + fn))
        precision_all.append(tp / (tp + fp))

    rows = (test_loss_all, best_conf_thresh, best_nms_thresh,
            no_prop_all, recall_all, precision_all)
    mat = np.zeros((len(rows), len(names)))
    for row, values in enumerate(rows):
        mat[row, :] = np.array(values)

    print(np.round(mat,2))

    return mat
        

# PLOTTING

def plot_pos_neg_ABs(): 
    """Debug plot of ground-truth boxes and positive/negative anchor boxes.

    Loads the first batch of 16 samples (files ``dsn`` .. ``dsn+399`` under
    /scratch/s204219/augmentedData), takes samples 11 and 12, generates anchors on
    a 64x64 feature grid (scales 8 and 10, ratio 1), and labels them with
    ``get_req_anchors`` (pos_thresh 0.7, neg_thresh 0.3). Draws GT boxes, positive
    anchors (green) and negative anchors (red) side by side. Everything is
    hard-coded; takes no arguments and returns None.
    """
    # %% Plot positive and negative anchor boxes
    img_size = (256,256)
    out_size = (64,64)
    out_channels = 256  # not used below

    # Pixel-per-feature-cell ratios, used to project boxes between image and feature space
    height_scale_factor= img_size[0]/out_size[0]
    width_scale_factor = img_size[1]/out_size[1]


    file_root = "/scratch/s204219/augmentedData"
    file_list = np.arange(dsn,dsn+400)

    dataset = UXO_dataset(file_list=file_list, file_root=file_root)
    batch_size = 16

    # ONLY WORKS IF num_workers = 0 (per the original author)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=0) 

    img_batch, gt_bboxes_batch = next(iter(dataloader))


    # Two example images (batch positions 11 and 12)
    img_data_all = img_batch[11:13].to(device)
    gt_bboxes_all = gt_bboxes_batch[11:13].to(device)

    out_h,out_w = 64,64

    anc_pts_x, anc_pts_y = gen_anc_centers(out_size=(out_h, out_w))
    anc_scales = [8,10]#[2, 4, 6]
    anc_ratios = [1]
    n_anc_boxes = len(anc_scales) * len(anc_ratios) # number of anchor boxes for each anchor point (not used below)

    anc_base = gen_anc_base(anc_pts_x, anc_pts_y, anc_scales, anc_ratios, (out_h, out_w))

    # One copy of the anchor grid per image in the batch
    anc_boxes_all = anc_base.repeat(img_data_all.size(dim=0), 1, 1, 1, 1)



    pos_thresh = 0.7
    neg_thresh = 0.3

    # project gt bboxes onto the feature map
    gt_bboxes_proj = project_bboxes(gt_bboxes_all, width_scale_factor, height_scale_factor, mode='p2a')
    positive_anc_ind, negative_anc_ind, GT_conf_scores, \
    GT_offsets, positive_anc_coords, \
    negative_anc_coords, positive_anc_ind_sep = get_req_anchors(anc_boxes_all, gt_bboxes_proj, pos_thresh, neg_thresh)

    # project anchor coords to the image space
    pos_anc_proj = project_bboxes(positive_anc_coords, width_scale_factor, height_scale_factor, mode='a2p')
    neg_anc_proj = project_bboxes(negative_anc_coords, width_scale_factor, height_scale_factor, mode='a2p')

    # grab +ve and -ve anchors for each image separately
    # positive_anc_ind_sep holds, per positive anchor, the index of the image it belongs to.

    anc_idx_1 = torch.where(positive_anc_ind_sep == 0)[0]
    anc_idx_2 = torch.where(positive_anc_ind_sep == 1)[0]

    pos_anc_1 = pos_anc_proj[anc_idx_1]
    pos_anc_2 = pos_anc_proj[anc_idx_2]

    # NOTE(review): the negatives are indexed with positions derived from the POSITIVE anchors'
    # image membership. Unless get_req_anchors returns negatives aligned 1:1 with positives per
    # image, neg_anc_1 / neg_anc_2 may contain anchors of the other image — confirm against
    # get_req_anchors (negative_anc_ind may be what should be used here).
    neg_anc_1 = neg_anc_proj[anc_idx_1]
    neg_anc_2 = neg_anc_proj[anc_idx_2]

    nrows, ncols = (1, 2)
    fig, axes = plt.subplots(nrows, ncols, figsize=(16, 8))

    fig, axes = display_img(img_data_all.detach().cpu(), fig, axes)

    # plot groundtruth bboxes
    fig, _ = display_bbox(gt_bboxes_all[0].detach().cpu(), fig, axes[0])
    fig, _ = display_bbox(gt_bboxes_all[1].detach().cpu(), fig, axes[1])

    # plot positive anchor boxes
    fig, _ = display_bbox(pos_anc_1.detach().cpu(), fig, axes[0], color='g')
    fig, _ = display_bbox(pos_anc_2.detach().cpu(), fig, axes[1], color='g')

    # plot negative anchor boxes
    fig, _ = display_bbox(neg_anc_1.detach().cpu(), fig, axes[0], color='r')
    fig, _ = display_bbox(neg_anc_2.detach().cpu(), fig, axes[1], color='r')
    
    return

def plot_proposals(name='temp',Best_or_Last='Best',train_or_val_or_test='test',conf_thresh=0.6,nms_thresh=0.1,idx=0,number=4,no_proposals=10):
    """Show RPN proposals for ``2*number`` images, two images per figure.

    The images are positions ``idx`` .. ``idx + 2*number - 1`` of the saved split
    index ``Index_{train_or_val_or_test}_{name}.npy``. Each panel title shows the
    TP/FP values that ``calc_conf_mat`` (``pos_thresh=0.3``) reports for the
    proposals that are drawn.

    Args:
        name: model name; checkpoint and index live in /scratch/s204219/RPN_new/{name}.
        Best_or_Last: which checkpoint to load ('Best' or 'Last').
        train_or_val_or_test: split whose index file supplies the images.
        conf_thresh: confidence threshold passed to ``RPN.inference``.
        nms_thresh: NMS threshold passed to ``RPN.inference``.
        idx: first position in the split index.
        number: number of figures to draw.
        no_proposals: at most this many proposals (the first ones returned) per image.
    """
    previous_model_path = f"/scratch/s204219/RPN_new/{name}"
    previous_model = f'{previous_model_path}/{Best_or_Last}_RPN_{name}'

    index = np.load(f'{previous_model_path}/Index_{train_or_val_or_test}_{name}.npy')

    file_root, _ = init_file_root_and_list(name=name)
    file_list = index[idx:idx + number * 2]
    dataset = UXO_dataset(file_list=file_list, file_root=file_root)

    RPN = init_RPN(name)
    width_scale_factor = RPN.width_scale_factor
    height_scale_factor = RPN.height_scale_factor
    n_anc_boxes_total = RPN.n_anc_boxes_total

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(previous_model, map_location=device)
    RPN.load_state_dict(checkpoint['model_state_dict'])

    img_batch, bb_batch = dataset[:]  # the whole selection as one batch
    gt_bboxes_proj = project_bboxes(bb_batch, width_scale_factor, height_scale_factor, mode='p2a')

    proposals_final, conf_scores_final, feature_map = RPN.inference(img_batch.to(device),conf_thresh=conf_thresh,nms_thresh=nms_thresh)

    # Keep only the first no_proposals proposals per image, so the scored and drawn boxes agree.
    for k in range(number * 2):
        proposals_final[k] = proposals_final[k][:no_proposals, :]

    pos_thresh = 0.3
    tp, fp, fn, tn = calc_conf_mat(gt_bboxes_proj, proposals_final, pos_thresh, n_anc_boxes_total, list_or_sum='list')

    for i in range(number):
        pair = (2 * i, 2 * i + 1)

        projected = []
        for k in pair:
            if proposals_final[k].shape == torch.Size([0, 4]):
                # Draw a single all-zero box when there are no proposals. Created as a float
                # tensor: torch.tensor([[0,0,0,0]]) would be int64, unlike box coordinates.
                proposals_final[k] = torch.zeros((1, 4), device=device)
            projected.append(project_bboxes(proposals_final[k], width_scale_factor, height_scale_factor, mode='a2p').detach().cpu())

        fig, axes = plt.subplots(1, 2, figsize=(10, 5))
        fig, axes = display_img(img_batch[2 * i:2 * i + 2], fig, axes)

        for col, k in enumerate(pair):
            fig, _ = display_bbox(projected[col], fig, axes[col])
            axes[col].set_title(f'tp={tp[k]}, ' + f'fp={fp[k]}')

        plt.setp(axes, xticks=[], yticks=[])

    return

def plot_train_val_test_proposals(name='temp',Best_or_Last='Best',conf_thresh=0.6,nms_thresh=0.2,idx=[0,0,0],number=2,no_proposals=10):
    """Compare RPN proposals on train, validation and test images in one grid.

    Columns 0/1/2 show ``number`` images of the train/val/test split, taken from the
    saved split indices starting at ``idx[0]``/``idx[1]``/``idx[2]``. Each panel shows
    the first ``no_proposals`` proposals and the ground-truth boxes (red). Its title
    gives the TP/FP values from ``calc_conf_mat`` (``pos_thresh=0.3``); the top row
    also names the split.

    Args:
        name: model name; files live in /scratch/s204219/RPN_new/{name}.
        Best_or_Last: which checkpoint to load ('Best' or 'Last').
        conf_thresh: confidence threshold passed to ``RPN.inference``.
        nms_thresh: NMS threshold passed to ``RPN.inference``.
        idx: start position in the train, val and test index (not modified).
        number: images (rows) per split.
        no_proposals: at most this many proposals per image.
    """
    previous_model_path = f"/scratch/s204219/RPN_new/{name}"
    previous_model = f'{previous_model_path}/{Best_or_Last}_RPN_{name}'

    file_root, _ = init_file_root_and_list(name=name)

    # (column title, dataset) for each split, in column order
    splits = []
    for split_no, (split, title) in enumerate((('train', 'Train'), ('val', 'Validation'), ('test', 'Test'))):
        split_index = np.load(f'{previous_model_path}/Index_{split}_{name}.npy')
        start = idx[split_no]
        splits.append((title, UXO_dataset(file_list=split_index[start:start + number], file_root=file_root)))

    RPN = init_RPN(name)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(previous_model, map_location=device)
    RPN.load_state_dict(checkpoint['model_state_dict'])

    n_anc_boxes_total = RPN.n_anc_boxes_total
    height_scale_factor = RPN.height_scale_factor
    width_scale_factor = RPN.width_scale_factor

    # squeeze=False keeps axes 2-D, so axes[j, col] also works when number == 1.
    fig, axes = plt.subplots(number, 3, figsize=(8, 8), squeeze=False)
    pos_thresh = 0.3

    for col, (title, dataset) in enumerate(splits):
        img_batch, bb_batch = dataset[:]  # [:] is very important: the whole split as one batch

        gt_bboxes_proj = project_bboxes(bb_batch, width_scale_factor, height_scale_factor, mode='p2a')

        proposals_final, conf_scores_final, feature_map = RPN.inference(img_batch.to(device),conf_thresh=conf_thresh,nms_thresh=nms_thresh)
        for j in range(number):
            proposals_final[j] = proposals_final[j][:no_proposals, :]

        tp, fp, fn, tn = calc_conf_mat(gt_bboxes_proj, proposals_final, pos_thresh, n_anc_boxes_total, list_or_sum='list')

        for j in range(number):
            if proposals_final[j].shape == torch.Size([0, 4]):
                # Draw a single all-zero box when there are no proposals. Created as a float
                # tensor: torch.tensor([[0,0,0,0]]) would be int64, unlike box coordinates.
                proposals_final[j] = torch.zeros((1, 4), device=device)

            proposals_proj = project_bboxes(proposals_final[j], width_scale_factor, height_scale_factor, mode='a2p').detach().cpu()

            ax = axes[j, col]
            img = img_batch[j, :, :, :].permute(1, 2, 0).numpy()
            ax.imshow(img.astype('uint8'))

            fig, _ = display_bbox(proposals_proj, fig, ax)
            fig, _ = display_bbox(bb_batch[j], fig, ax, color='r')

            counts = f'tp={tp[j]}, ' + f'fp={fp[j]}'
            ax.set_title(f'{title}\n{counts}' if j == 0 else counts)

    plt.setp(axes, xticks=[], yticks=[])
    plt.suptitle(f'Model: {name}\n{no_proposals} best proposals',fontsize = 20,y=1.05)

    return

def plot_train_val_loss(name='overfitted',n_epochs='all', move=0, ylim=(0, 150)):
    """Plot training and validation loss per epoch and annotate the validation minimum.

    Reads ``Loss_train_{name}.npy`` and ``Loss_val_{name}.npy`` from
    /scratch/s204219/RPN_new/{name}.

    Args:
        name: model name.
        n_epochs: number of leading epochs to plot, or 'all'. Clipped to the
            length of the shorter loss curve.
        move: extra horizontal shift (in epochs) of the annotation text box.
        ylim: y-axis limits; ``None`` keeps matplotlib's automatic limits.
    """
    path = f"/scratch/s204219/RPN_new/{name}"
    train_loss = np.load(f'{path}/Loss_train_{name}.npy')
    val_loss = np.load(f'{path}/Loss_val_{name}.npy')

    # Never index past the end of either curve, even if they differ in length.
    available = min(len(train_loss), len(val_loss))
    if n_epochs == 'all' or n_epochs > available:
        n_epochs = available

    train_loss = train_loss[:n_epochs]
    val_loss = val_loss[:n_epochs]
    epochs = range(1, n_epochs + 1)

    min_val_loss_epoch = np.argmin(val_loss) + 1  # add 1 to make it 1-indexed
    min_val_loss = np.min(val_loss)

    # Vertical text offset: 40 % of the range of the curve with the higher peak.
    loss = train_loss if np.max(train_loss) >= np.max(val_loss) else val_loss
    dy = 2 * (np.max(loss) - np.min(loss)) / 5

    # Horizontal text offset. The box is placed relative to the validation minimum,
    # so the edge checks must use that epoch (not the minimum of the other curve).
    dx = -3
    if min_val_loss_epoch + dx <= 0.5:
        dx += 2
    if min_val_loss_epoch + dx >= n_epochs - 20:
        dx -= n_epochs / 8
    dx += move

    plt.figure(figsize=(8,3))

    plt.plot(epochs, train_loss, label = 'Training loss')
    plt.plot(epochs, val_loss, label = 'Validation loss')
    plt.plot(min_val_loss_epoch, min_val_loss, 'ro',label = 'Min. val. loss')  # 'ro': red dot marker

    arrow_props = dict(facecolor='red', edgecolor='red',
                       arrowstyle="-|>, head_width=0.5, head_length=1.5")

    plt.annotate(f'Epoch: {min_val_loss_epoch}\nLoss: {min_val_loss:.2f}',
                 xy=(min_val_loss_epoch, min_val_loss), xycoords='data',
                 xytext=(min_val_loss_epoch + dx, min_val_loss + dy), textcoords='data',
                 arrowprops=arrow_props,
                 bbox=dict(boxstyle="round", fc="white", ec="red", lw=1.5, alpha=0.9))

    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend(loc='best')

    if n_epochs == 20:
        plt.xticks(range(0,n_epochs+1,2))
    elif n_epochs == 100:
        plt.xticks(range(0,n_epochs+1,10))

    if ylim is not None:
        plt.ylim(*ylim)

    plt.title(f'Model: {name}\nTrain- and validation loss')

    return

def plot_confusion(name='temp', Best_or_Last='Best', nms_thresh='all'):
    """Plot the four averaged confusion entries against the confidence threshold.

    The 2x2 panel grid mirrors the confusion matrix: TP top-left, FP top-right,
    FN bottom-left, TN bottom-right. With ``nms_thresh='all'`` there is one curve per
    NMS threshold of the tuning grid. Otherwise only the grid value nearest to
    ``nms_thresh`` is drawn, in colour ``C{index}``. Black lines mark TP = 0.75 and
    FN = 0.25.
    """
    model_dir = f"/scratch/s204219/RPN_new/{name}"
    confusion = np.load(f'{model_dir}/Confusion_{Best_or_Last}_{name}.npy')
    conf_thresh_arr = np.load(f'{model_dir}/conf_thresh_arr_{Best_or_Last}_{name}.npy')
    nms_thresh_arr = np.load(f'{model_dir}/nms_thresh_arr_{Best_or_Last}_{name}.npy')

    fig, axes = plt.subplots(2, 2, figsize=(16, 8))

    # (nms grid index, extra plot kwargs) for every curve to draw
    if nms_thresh == 'all':
        curves = [(k, {}) for k in range(len(nms_thresh_arr))]
    else:
        nearest = np.argmin(abs(nms_thresh_arr - nms_thresh))
        curves = [(nearest, {'color': f'C{nearest}'})]

    # panel position == (row, col) of the confusion-matrix cell it shows
    cells = {(0, 0): 'TP', (0, 1): 'FP', (1, 0): 'FN', (1, 1): 'TN'}

    for k, extra in curves:
        label = f'NMS thresh = {np.round(nms_thresh_arr[k],2)}'
        for r, c in cells:
            axes[r, c].plot(conf_thresh_arr, confusion[:, k, r, c], marker='.', label=label, **extra)

    alpha_label = 'Confidence threshold ' + r'$ (\alpha)$'
    for (r, c), ylabel in cells.items():
        axes[r, c].set_xlabel(alpha_label, fontsize=13)
        axes[r, c].set_ylabel(ylabel, fontsize=15)

    axes[0, 0].legend(loc='best')

    # Reference lines for the TP and FN targets
    axes[0, 0].plot(conf_thresh_arr, [0.75] * len(conf_thresh_arr), color='black')
    axes[1, 0].plot(conf_thresh_arr, [0.25] * len(conf_thresh_arr), color='black')

    plt.suptitle(f'Model: {name}\nAverage confusion',fontsize = 20)
    plt.subplots_adjust(wspace=0.15, hspace=0.25)
    plt.setp(axes, xticks=np.arange(0.0, 1.1, 0.1))

    return

def plot_tp_and_fn(name='temp', Best_or_Last='Best'):
    """Plot averaged TP and FN against the confidence threshold and mark the tuned optimum.

    One curve per NMS threshold of the tuning grid. The shaded bands are the target
    regions 0.75 <= TP <= 1.0 and 0.0 <= FN <= 0.25. The (conf, NMS) pair picked by
    ``find_conf_and_nms_thresh`` is marked with a magenta star when it exists. If
    tuning found no admissible pair (conf = inf), the star is skipped instead of
    raising IndexError.
    """
    previous_model_path = f"/scratch/s204219/RPN_new/{name}"
    confusion = np.load(f'{previous_model_path}/Confusion_{Best_or_Last}_{name}.npy')
    conf_thresh_arr = np.load(f'{previous_model_path}/conf_thresh_arr_{Best_or_Last}_{name}.npy')
    nms_thresh_arr = np.load(f'{previous_model_path}/nms_thresh_arr_{Best_or_Last}_{name}.npy')

    conf, nms = find_conf_and_nms_thresh(name=f'{name}', Best_or_Last=f'{Best_or_Last}')

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    for i, nms_thresh in enumerate(nms_thresh_arr):
        axes[0].plot(conf_thresh_arr, confusion[:, i, 0, 0], marker='.', label=f'NMS thresh = {np.round(nms_thresh,2)}')
        axes[1].plot(conf_thresh_arr, confusion[:, i, 1, 0], marker='.')

    nms_hits = np.flatnonzero(np.isclose(nms_thresh_arr, nms))
    conf_hits = np.flatnonzero(np.isclose(conf_thresh_arr, conf))
    found_optimum = nms_hits.size > 0 and conf_hits.size > 0

    if found_optimum:
        i, j = nms_hits[0], conf_hits[0]
        tp = confusion[j, i, 0, 0]
        fp = confusion[j, i, 0, 1]
        fn = confusion[j, i, 1, 0]

        axes[0].scatter(conf,tp,marker='*',label = 'Optimal model',color='magenta',s=100,zorder=10)
        axes[1].scatter(conf,fn,marker='*',label = 'Conf. thresh '+r'$\alpha = $'+str(np.round(conf,2))+\
                        '\n'+'NMS thresh '+r'$\beta = $'+str(np.round(nms,2))+'\n'+\
                        'No. proposals '+r'$P =$'+str(np.round(tp+fp,2)) +\
                        '\n'+r'$TP =$'+str(np.round(tp,2)) + '\n' \
                        r'$FN =$' + str(np.round(fn,2)),\
                        color='magenta',s=100,zorder=10)
    else:
        print(f'No admissible conf/NMS threshold pair for model {name}; optimum not marked.')

    alpha_label = 'Confidence threshold ' + r'$ (\alpha)$'

    axes[0].legend(loc='best')
    axes[0].set_xlabel(alpha_label, fontsize=13)
    axes[0].set_ylabel('TP', fontsize=15)
    axes[0].plot(conf_thresh_arr, [0.75] * len(conf_thresh_arr), color='black')
    axes[0].plot(conf_thresh_arr, [1.0] * len(conf_thresh_arr), color='black')
    axes[0].add_patch(patches.Rectangle((0.0, 0.75), 1.0, 0.25, alpha=0.5, facecolor='darkgray'))

    if found_optimum:
        # Only the star carries a label on this axis; skip the legend when it is absent.
        axes[1].legend(loc='best')
    axes[1].set_xlabel(alpha_label, fontsize=13)
    axes[1].set_ylabel('FN', fontsize=15)
    axes[1].plot(conf_thresh_arr, [0.25] * len(conf_thresh_arr), color='black')
    axes[1].plot(conf_thresh_arr, [0.0] * len(conf_thresh_arr), color='black')
    axes[1].add_patch(patches.Rectangle((0.0, 0.0), 1.0, 0.25, alpha=0.5, facecolor='darkgray'))

    fig.subplots_adjust(top=0.8, wspace=0.15, hspace=2)
    plt.suptitle(f'Model: {name}\nAverage TP and FN for validation data',fontsize = 20)
    plt.setp(axes, xticks=np.arange(0.0, 1.1, 0.1))

    return

def plot_fn(name='temp', Best_or_Last='Best', nms_thresh='all'):
    """Plot averaged FN against the confidence threshold.

    One curve per NMS threshold (``nms_thresh='all'``), or only the grid value
    nearest to ``nms_thresh`` (in colour ``C{index}``). The band 0 <= FN <= 0.25 is
    shaded.
    """
    model_dir = f"/scratch/s204219/RPN_new/{name}"
    confusion = np.load(f'{model_dir}/Confusion_{Best_or_Last}_{name}.npy')
    conf_thresh_arr = np.load(f'{model_dir}/conf_thresh_arr_{Best_or_Last}_{name}.npy')
    nms_thresh_arr = np.load(f'{model_dir}/nms_thresh_arr_{Best_or_Last}_{name}.npy')

    fig, ax = plt.subplots(1, 1, figsize=(10, 5))

    if nms_thresh == 'all':
        for k, nms_value in enumerate(nms_thresh_arr):
            ax.plot(conf_thresh_arr, confusion[:, k, 1, 0], marker='.', label=f'NMS thresh = {np.round(nms_value,2)}')
    else:
        k = np.argmin(abs(nms_thresh_arr - nms_thresh))
        ax.plot(conf_thresh_arr, confusion[:, k, 1, 0], marker='.',
                label=f'NMS thresh = {np.round(nms_thresh_arr[k],2)}', color=f'C{k}')

    ax.legend(loc='best')
    ax.set_xlabel('Confidence threshold ' + r'$ (\alpha)$', fontsize=13)
    ax.set_ylabel('FN', fontsize=15)

    # Borders of the FN target band
    for level in (0.25, 0.0):
        ax.plot(conf_thresh_arr, [level] * len(conf_thresh_arr), color='black')

    ax.set_title(f'Model: {name}\nAverage FN', fontsize=20)
    ax.add_patch(patches.Rectangle((0.0, 0.0), 1.0, 0.25, alpha=0.5, facecolor='darkgray'))
    ax.set_xticks(np.arange(0.0, 1.1, 0.1))

    return

def plot_tp(name='temp', Best_or_Last='Best', nms_thresh='all'):
    """Plot averaged TP against the confidence threshold.

    One curve per NMS threshold (``nms_thresh='all'``), or only the grid value
    nearest to ``nms_thresh`` (in colour ``C{index}``). The band 0.75 <= TP <= 1.0 is
    shaded.
    """
    model_dir = f"/scratch/s204219/RPN_new/{name}"
    confusion = np.load(f'{model_dir}/Confusion_{Best_or_Last}_{name}.npy')
    conf_thresh_arr = np.load(f'{model_dir}/conf_thresh_arr_{Best_or_Last}_{name}.npy')
    nms_thresh_arr = np.load(f'{model_dir}/nms_thresh_arr_{Best_or_Last}_{name}.npy')

    fig, ax = plt.subplots(1, 1, figsize=(10, 5))

    if nms_thresh == 'all':
        for k, nms_value in enumerate(nms_thresh_arr):
            ax.plot(conf_thresh_arr, confusion[:, k, 0, 0], marker='.', label=f'NMS thresh = {np.round(nms_value,2)}')
    else:
        k = np.argmin(abs(nms_thresh_arr - nms_thresh))
        ax.plot(conf_thresh_arr, confusion[:, k, 0, 0], marker='.',
                label=f'NMS thresh = {np.round(nms_thresh_arr[k],2)}', color=f'C{k}')

    ax.legend(loc='best')
    ax.set_xlabel('Confidence threshold ' + r'$ (\alpha)$', fontsize=13)
    ax.set_ylabel('TP', fontsize=15)

    # Borders of the TP target band
    for level in (0.75, 1.0):
        ax.plot(conf_thresh_arr, [level] * len(conf_thresh_arr), color='black')

    ax.add_patch(patches.Rectangle((0.0, 0.75), 1.0, 0.25, alpha=0.5, facecolor='darkgray'))
    ax.set_title(f'Model: {name}\nAverage TP', fontsize=20)
    ax.set_xticks(np.arange(0.0, 1.1, 0.1))

    return

def plot_ROC(name='temp',Best_or_Last='Best'):
    """Draw an ROC curve (FPR on x, TPR on y) for every NMS threshold of a tuning run.

    The points of each curve are the entries of the confidence-threshold grid. Reads
    ``Confusion_{Best_or_Last}_{name}.npy`` and the threshold grids from
    /scratch/s204219/RPN_new/{name}, then shows the figure.
    """
    model_dir = f"/scratch/s204219/RPN_new/{name}"
    confusion = np.load(f'{model_dir}/Confusion_{Best_or_Last}_{name}.npy')
    # Not needed for the plot, but still loaded so the same set of files is required.
    conf_thresh_arr = np.load(f'{model_dir}/conf_thresh_arr_{Best_or_Last}_{name}.npy')
    nms_thresh_arr = np.load(f'{model_dir}/nms_thresh_arr_{Best_or_Last}_{name}.npy')

    plt.figure()
    for k, nms_value in enumerate(nms_thresh_arr):
        cells = confusion[:, k]  # confusion entries for this NMS threshold, one row per conf threshold
        true_pos, false_pos = cells[:, 0, 0], cells[:, 0, 1]
        false_neg, true_neg = cells[:, 1, 0], cells[:, 1, 1]

        tpr = true_pos / (true_pos + false_neg)
        fpr = false_pos / (false_pos + true_neg)

        plt.plot(fpr, tpr, marker='.', label=f'NMS thresh = {np.round(nms_value,2)}')

    plt.legend(loc='best')
    plt.xlabel('fpr')
    plt.ylabel('tpr')
    plt.title(f'Model: {name}\nReceiver operating characteristic (ROC) curve')
    plt.show()

    return

def plot_precision_recall(name='temp',Best_or_Last='Best'):
    """Plot precision against recall over the confidence grid and return F1 scores.

    One curve per NMS threshold of the tuning run read from
    /scratch/s204219/RPN_new/{name}. Each point is one confidence threshold.

    Returns:
        List with one array per NMS threshold holding
        ``F1 = 2TP / (2TP + FP + FN)`` for every confidence threshold.
    """
    previous_model_path = f"/scratch/s204219/RPN_new/{name}"
    confusion = np.load(f'{previous_model_path}/Confusion_{Best_or_Last}_{name}.npy')
    nms_thresh_arr = np.load(f'{previous_model_path}/nms_thresh_arr_{Best_or_Last}_{name}.npy')

    # Start a fresh figure so the curves never land on axes left over from an earlier plot.
    plt.figure()

    F1 = []

    for i, nms_thresh in enumerate(nms_thresh_arr):
        tp = confusion[:, i, 0, 0]
        fn = confusion[:, i, 1, 0]
        fp = confusion[:, i, 0, 1]

        pre = tp / (tp + fp)  # precision
        rec = tp / (tp + fn)  # recall

        # Harmonic mean of precision and recall: 2 / (1/rec + 1/pre) == 2TP / (2TP + FP + FN)
        F1.append(2 * tp / (2 * tp + fp + fn))

        plt.plot(rec, pre, marker='.', label=f'NMS thresh = {np.round(nms_thresh,2)}')

    plt.legend(loc='best')
    plt.xlabel('Recall  '+ r'$\dfrac{tp}{tp+fn}$')
    plt.ylabel('Precision  '+ r'$\dfrac{tp}{tp+fp}$')
    plt.title(f'Model: {name}\nPrecision-Recall curve')
    plt.show()

    return F1

def plot_recall_vs_number_of_propsals(name='temp',nms_thresh='all'):
    """Plot test-set recall against the average number of proposals (TP + FP).

    Each point of a curve is one confidence threshold of the test grid.

    Args:
        name: a model name, or a list of model names to compare in one plot.
        nms_thresh: per model, one of:
            'all' - one curve per NMS threshold on the test grid;
            'best' - the NMS threshold chosen by ``find_conf_and_nms_thresh``
                (the chosen confidence threshold is marked with a star);
            a number - the closest test-grid value is used.
            A single value applies to every model; a list gives one value per model
            (the caller's list is not modified).
    """
    Best_or_Last = 'Best'

    if isinstance(name, str):
        one_model_only = True
        names = [name]
    else:
        one_model_only = False
        names = list(name)

    # Broadcast a single setting to every model (previously a list of names with the
    # default 'all' crashed on str.copy()); always work on a fresh list.
    if one_model_only or isinstance(nms_thresh, (str, int, float)):
        nms_thresh = [nms_thresh] * len(names)
    else:
        nms_thresh = list(nms_thresh)

    best_conf_thresh = []
    nms_thresh_was_best = []

    for j, model in enumerate(names):
        conf, nms = find_conf_and_nms_thresh(name=f'{model}', Best_or_Last=f'{Best_or_Last}')
        best_conf_thresh.append(conf)

        if nms_thresh[j] == 'best':
            nms_thresh[j] = nms
            nms_thresh_was_best.append(True)
        else:
            nms_thresh_was_best.append(False)

    fig, ax = plt.subplots(1, 1, figsize=(8, 3))
    max_no_prop = -1

    for j, model in enumerate(names):
        previous_model_path = f"/scratch/s204219/RPN_new/{model}"
        confusion = np.load(f'{previous_model_path}/Confusion_test_{Best_or_Last}_{model}.npy')
        conf_thresh_arr = np.load(f'{previous_model_path}/conf_thresh_arr_test_{Best_or_Last}_{model}.npy')
        nms_thresh_arr = np.load(f'{previous_model_path}/nms_thresh_arr_test_{Best_or_Last}_{model}.npy')

        if nms_thresh[j] != 'all':
            i = np.argmin(abs(nms_thresh_arr - nms_thresh[j]))
            nms_thresh_temp = nms_thresh_arr[i]

            tp = confusion[:, i, 0, 0]
            fn = confusion[:, i, 1, 0]
            fp = confusion[:, i, 0, 1]

            rec = tp / (tp + fn)  # recall
            no_prop = tp + fp     # number of proposals
            max_no_prop = max(max_no_prop, np.max(no_prop))

            if one_model_only:
                plt.plot(no_prop,rec,marker='.',label = f'{model}, NMS thresh = {np.round(nms_thresh_temp,2)}',color=f'C{i}')
                # Label every third point with its confidence threshold, alternating above/below.
                va = 'bottom'
                for k, (x, y, conf_thresh) in enumerate(zip(no_prop, rec, conf_thresh_arr), start=1):
                    if k % 3 == 0:
                        plt.text(x, y, f'{conf_thresh:.2f}', ha='right', va=va)
                        va = 'top' if va == 'bottom' else 'bottom'
            else:
                plt.plot(no_prop,rec,marker='.',label = f'{model}, NMS thresh = {np.round(nms_thresh_temp,2)}')

            if nms_thresh_was_best[j]:
                # The tuned value comes from the validation grid and the test grid is a different
                # file, so match with a tolerance; inf (no admissible pair) matches nothing.
                hits = np.flatnonzero(np.isclose(conf_thresh_arr, best_conf_thresh[j]))
                if hits.size > 0:
                    y = hits[0]
                    tp_best = confusion[y, i, 0, 0]
                    fn_best = confusion[y, i, 1, 0]
                    fp_best = confusion[y, i, 0, 1]
                    ax.scatter(tp_best + fp_best, tp_best / (tp_best + fn_best), marker='*', s=200, zorder=10)
                else:
                    print(f'Tuned conf thresh {best_conf_thresh[j]} of model {model} is not on its test grid; optimum not marked.')

        if nms_thresh[j] == 'all':
            for i, nms_thresh_temp in enumerate(nms_thresh_arr):
                tp = confusion[:, i, 0, 0]
                fn = confusion[:, i, 1, 0]
                fp = confusion[:, i, 0, 1]

                rec = tp / (tp + fn)  # recall
                no_prop = tp + fp     # number of proposals
                max_no_prop = max(max_no_prop, np.max(no_prop))

                plt.plot(no_prop,rec,marker='.',label = f'NMS thresh = {np.round(nms_thresh_temp,2)}')

    rect0 = patches.Rectangle((0.0, 0.75), max_no_prop, 0.25, alpha=0.5, facecolor='darkgray')

    plt.ylim(-0.05,1.05)
    plt.yticks([0,1/4,1/2,3/4,1])
    ax.add_patch(rect0)  # shaded recall target band 0.75-1.0

    plt.legend(loc='best')
    plt.xlabel('Number of proposals  '+ r'$TP+FP$')
    plt.ylabel('Recall  '+ r'$\dfrac{TP}{TP+FN}$')

    if one_model_only:
        plt.title(f'Model: {names[0]}\nRecall vs. number of propsals')
    else:
        title_string = str(', '.join(names))
        plt.title(f'Models: {title_string}\nRecall vs. number of propsals')

    plt.plot([0.0,max_no_prop],[0.75,0.75],color='black')
    plt.plot([0.0,max_no_prop],[1.0,1.0],color='black')

    plt.show()

    return

def plot_recall_vs_conf_thresh(name='temp',nms_thresh='all'):
    """Plot test-set recall against the confidence threshold for one or more models.

    For each model, reads the test confusion tensor and the confidence/NMS
    threshold grids saved under ``/scratch/s204219/RPN_new/{name}/`` and draws
    recall ``TP / (TP + FN)`` as a function of the confidence threshold.

    Parameters
    ----------
    name : str or sequence of str
        A single model name, or several names to compare in one figure.
    nms_thresh : float, 'best', 'all', or a list of these (one per model)
        Which NMS threshold(s) to plot for each model:
        * a number   -> the closest value in the model's saved NMS grid;
        * ``'best'`` -> the NMS threshold returned by
          ``find_conf_and_nms_thresh``; the tuned confidence threshold is
          additionally marked with a star;
        * ``'all'``  -> one curve per NMS threshold in the saved grid.
        A single (non-list) value is applied to every model.

    Notes
    -----
    The indexing assumes ``confusion[c, n]`` is a 2x2 matrix for confidence
    index ``c`` and NMS index ``n`` with TP at ``[0, 0]``, FP at ``[0, 1]`` and
    FN at ``[1, 0]``. NOTE(review): layout inferred from usage; confirm against
    the code that writes ``Confusion_test_*.npy``.
    """
    Best_or_Last = 'Best'

    def _recall_and_proposals(cm):
        # Recall TP/(TP+FN) and number of proposals TP+FP taken from the two
        # trailing axes of `cm`; 0/0 yields nan without a RuntimeWarning.
        tp = cm[..., 0, 0]
        fn = cm[..., 1, 0]
        fp = cm[..., 0, 1]
        with np.errstate(divide='ignore', invalid='ignore'):
            rec = tp / (tp + fn)
        return rec, tp + fp

    if isinstance(name, str):
        one_model_only = True
        names = [name]
    else:
        one_model_only = False
        names = list(name)

    # One threshold per model: a scalar/str is broadcast to every model, and a
    # list is copied so resolving 'best' below never mutates the caller's list.
    if isinstance(nms_thresh, (list, tuple)):
        nms_thresh = list(nms_thresh)
    else:
        nms_thresh = [nms_thresh] * len(names)

    best_conf_thresh = []
    nms_thresh_was_best = []
    for j, model_name in enumerate(names):
        conf, nms = find_conf_and_nms_thresh(name=f'{model_name}', Best_or_Last=f'{Best_or_Last}')
        best_conf_thresh.append(conf)
        is_best = isinstance(nms_thresh[j], str) and nms_thresh[j] == 'best'
        if is_best:
            nms_thresh[j] = nms
        nms_thresh_was_best.append(is_best)

    fig, ax = plt.subplots(1,1,figsize=(8,3))
    for j, model_name in enumerate(names):
        previous_model_path = f"/scratch/s204219/RPN_new/{model_name}"
        confusion = np.load(f'{previous_model_path}/Confusion_test_{Best_or_Last}_{model_name}.npy')
        conf_thresh_arr = np.load(f'{previous_model_path}/conf_thresh_arr_test_{Best_or_Last}_{model_name}.npy')
        nms_thresh_arr = np.load(f'{previous_model_path}/nms_thresh_arr_test_{Best_or_Last}_{model_name}.npy')

        if isinstance(nms_thresh[j], str) and nms_thresh[j] == 'all':
            # One curve per NMS threshold in the saved grid.
            for i, nms_thresh_temp in enumerate(nms_thresh_arr):
                rec, _ = _recall_and_proposals(confusion[:, i])
                plt.plot(conf_thresh_arr, rec, marker='.',
                         label=f'{model_name}, NMS thresh = {np.round(nms_thresh_temp,2)}')
            continue

        # Snap the requested threshold to the closest value in the saved grid.
        i = int(np.argmin(np.abs(nms_thresh_arr - nms_thresh[j])))
        nms_thresh_temp = nms_thresh_arr[i]
        rec, no_prop = _recall_and_proposals(confusion[:, i])
        label = f'{model_name}, NMS thresh = {np.round(nms_thresh_temp,2)}'

        if one_model_only:
            plt.plot(conf_thresh_arr, rec, marker='.', label=label, color=f'C{i}')
            # Annotate every third point with its number of proposals (TP+FP),
            # placed at that point's confidence threshold and alternating
            # above/below the curve to limit overlapping labels.
            va = 'bottom'
            for k, (x, y, n_prop) in enumerate(zip(conf_thresh_arr, rec, no_prop), start=1):
                if k % 3 == 0:
                    plt.text(x, y, f'{n_prop:.2f}', ha='right', va=va)
                    va = 'top' if va == 'bottom' else 'bottom'
        else:
            plt.plot(conf_thresh_arr, rec, marker='.', label=label)

        if nms_thresh_was_best[j]:
            # Star the tuned operating point. A nearest-value lookup is more
            # robust than exact float equality against the saved grid.
            c = int(np.argmin(np.abs(conf_thresh_arr - best_conf_thresh[j])))
            best_rec, _ = _recall_and_proposals(confusion[c, i])
            ax.scatter(best_conf_thresh[j], best_rec, marker='*', s=200, zorder=10)

    # Shade the recall band [0.75, 1] across the full threshold range.
    rect0 = patches.Rectangle((0.0, 0.75), 1.0, 0.25, alpha=0.5, facecolor='darkgray')

    plt.ylim(-0.05,1.05)
    plt.yticks([0,1/4,1/2,3/4,1])
    ax.add_patch(rect0)

    plt.legend(loc='best')
    plt.xlabel('Confidence threshold  '+ r'$(\alpha)$')
    plt.ylabel('Recall  '+ r'$\dfrac{TP}{TP+FN}$')
    plt.xticks(np.linspace(0.0,1.0,11))

    if one_model_only:
        plt.title(f'Model: {names[0]}\nRecall vs. confidence threshold')
    else:
        title_string = ', '.join(names)
        plt.title(f'Models: {title_string}\nRecall vs. confidence threshold')
    plt.plot([0.0,1.0],[0.75,0.75],color='black')
    plt.plot([0.0,1.0],[1.0,1.0],color='black')
    plt.show()

    return

def plot_experiment(names=None, train_size=80000, model_root="/scratch/s204219/RPN_new"):
    """Plot the best-checkpoint test loss of a series of models against training-data size.

    Parameters
    ----------
    names : list of str, optional
        Models ordered from the smallest to the largest training subset.
        Defaults to the eight ``hyperV2_*_hard_big`` runs.
    train_size : float, optional
        Size associated with the last model. Model ``k`` (1-based) of
        ``len(names)`` is plotted at ``k / len(names) * train_size``.
    model_root : str, optional
        Directory containing one sub-folder per model, each holding
        ``Loss_test_Best_{name}.npy``.
    """
    if names is None:
        names = ['hyperV2_one_hard_big','hyperV2_two_hard_big','hyperV2_three_hard_big','hyperV2_four_hard_big',
                 'hyperV2_five_hard_big','hyperV2_six_hard_big','hyperV2_seven_hard_big','hyperV2_eight_hard_big']
    if len(names) == 0:
        return

    test_loss_all = [np.load(f'{model_root}/{name}/Loss_test_Best_{name}.npy') for name in names]

    # Evenly spaced fractions 1/n, 2/n, ..., 1 of the full training size.
    n_models = len(names)
    data = [k * (1 / n_models) * train_size for k in range(1, n_models + 1)]

    fig, ax = plt.subplots(figsize=(8,3))
    ax.plot(data, test_loss_all, 'o-', label='HyperV2')
    ax.set_title('HyperV2\n Test loss Experiment')
    ax.set_xlabel('Train data size')
    ax.set_ylabel('Test loss')
    ax.legend(loc='best')
    plt.show()

    return

def _plot_survey_proposals(name, Best_or_Last, conf_thresh, nms_thresh, number, no_proposals,
                           image_paths, n_channels=None):
    """Run a trained RPN on stand-alone survey images and draw its proposals.

    Shared implementation of ``plot_proposals_non_augmented`` and
    ``plot_proposals_non_preprocessed``.

    Parameters
    ----------
    name, Best_or_Last : str
        Select the checkpoint ``/scratch/s204219/RPN_new/{name}/{Best_or_Last}_RPN_{name}``.
    conf_thresh, nms_thresh : float
        Passed to ``RPN.inference``.
    number : int
        Number of figures to draw. Each figure shows two consecutive images.
        Clamped to the number of image pairs actually loaded.
    no_proposals : int
        Maximum number of proposals drawn per image.
    image_paths : list of str
        ``.npy`` images, channel axis last (they are permuted to C x H x W).
    n_channels : int or None
        If given, keep only the first ``n_channels`` channels of each image.
    """
    previous_model_path = f"/scratch/s204219/RPN_new/{name}"
    previous_model = f'{previous_model_path}/{Best_or_Last}_RPN_{name}'

    all_images = []
    for path in image_paths:
        image = np.load(path)
        if n_channels is not None:
            image = image[:, :, :n_channels]
        all_images.append(torch.from_numpy(image).permute(2, 0, 1).float())
    img_batch = torch.stack(all_images, dim=0)

    # Images are shown in pairs; asking for more pairs than were loaded used
    # to raise IndexError.
    number = min(number, len(img_batch) // 2)

    RPN = init_RPN(name)
    width_scale_factor = RPN.width_scale_factor
    height_scale_factor = RPN.height_scale_factor

    # map_location allows a GPU-saved checkpoint to be loaded on a CPU-only node.
    checkpoint = torch.load(previous_model, map_location=device)
    RPN.load_state_dict(checkpoint['model_state_dict'])

    proposals_final, _, _ = RPN.inference(img_batch.to(device), conf_thresh=conf_thresh, nms_thresh=nms_thresh)

    for i in range(number):
        pair_proj = []
        for k in (i * 2, i * 2 + 1):
            props = proposals_final[k][:no_proposals, :]
            # Substitute a single dummy box when nothing survived the thresholds.
            if props.shape[0] == 0:
                props = torch.tensor([[0, 0, 0, 0]], device=device)
            pair_proj.append(project_bboxes(props, width_scale_factor, height_scale_factor, mode='a2p').detach().cpu())

        fig, axes = plt.subplots(1, 2, figsize=(10, 5))
        fig, axes = display_img(img_batch[i*2:i*2+2], fig, axes)

        # plot proposals
        for ax, proj in zip(axes, pair_proj):
            fig, _ = display_bbox(proj, fig, ax)

        axes[0].set_title('Survey '+str(i*2+1))
        axes[1].set_title('Survey '+str(i*2+2))
        plt.setp(axes, xticks=[], yticks=[])


def plot_proposals_non_augmented(name='temp',Best_or_Last='Best',conf_thresh=0.6,nms_thresh=0.1,number=2,no_proposals=10):
    """Draw RPN proposals on the four survey images ``blue_ATM/im{1..4}.npy``.

    Only the first three channels of each image are used. See
    ``_plot_survey_proposals`` for the meaning of the parameters.
    """
    image_paths = [f'/scratch/s204219/blue_ATM/im{i}.npy' for i in range(1, 5)]
    _plot_survey_proposals(name, Best_or_Last, conf_thresh, nms_thresh, number, no_proposals,
                           image_paths, n_channels=3)
    return


def plot_proposals_non_preprocessed(name='temp',Best_or_Last='Best',conf_thresh=0.6,nms_thresh=0.1,number=2,no_proposals=10):
    """Draw RPN proposals on the four original survey images ``blue_ATM/org_im{1..4}.npy``.

    Images are used with all their channels. See ``_plot_survey_proposals``
    for the meaning of the parameters.
    """
    image_paths = [f'/scratch/s204219/blue_ATM/org_im{i}.npy' for i in range(1, 5)]
    _plot_survey_proposals(name, Best_or_Last, conf_thresh, nms_thresh, number, no_proposals,
                           image_paths)
    return



In [3]:
# PIPELINE
#
# The suffix of `name` selects the dataset (shown for name = 'example'):
#   'example'          -> augmentedData
#   'example_big'      -> augmentedData_big
#   'example_hard'     -> augmentedData_hard
#   'example_hard_big' -> augmentedData_hard_big

name = 'hyperV2_one_hard_big'
experiment = True       # when True, dataset_size is overridden below
dummy_pipeline = False  # when True, use a tiny threshold grid and 2 epochs
cpm = False             # choose_previous_model

if dummy_pipeline:
    conf_thresh_arr = [0.5, 1]
    nms_thresh_arr = [0.1, 0.2]
    dataset_size, n_epochs = 1000, 2
else:
    conf_thresh_arr = np.linspace(0.0, 1.0, 21)
    nms_thresh_arr = np.linspace(0.05, 0.2, 4)
    dataset_size, n_epochs = 40000, 100

if experiment:
    dataset_size = 100000

# Train, evaluate the loss, tune thresholds on validation, then build the test confusion tensor.
train_model_new(name=name, n_epochs=n_epochs, dataset_size=dataset_size,
                val_ratio=1/7., test_ratio=1/7., save=True, cpm=cpm)
print('')
test_model(name=name, Best_or_Last='Best')
print('')

tune_model(name=name, Best_or_Last='Best', train_or_val_or_test='val',
           conf_thresh_arr=conf_thresh_arr, nms_thresh_arr=nms_thresh_arr, save=True)

test_model_confusion(name=name, Best_or_Last='Best',
                     conf_thresh_arr=conf_thresh_arr, nms_thresh_arr=nms_thresh_arr)

In [1]:
# Plot proposals
#
# Each figure type is behind its own toggle; all are disabled by default.

name = 'baseline'
Best_or_Last = 'Best'
train_or_val_or_test = 'test'

PLOT_TRAIN_VAL_TEST = False
PLOT_ONE_SPLIT = False
PLOT_SURVEYS = False

conf, nms = find_conf_and_nms_thresh(name=name, Best_or_Last=Best_or_Last)

if PLOT_TRAIN_VAL_TEST:
    plot_train_val_test_proposals(name=name, Best_or_Last=Best_or_Last,
                                  conf_thresh=conf, nms_thresh=nms, idx=[9, 6, 2],
                                  number=3, no_proposals=5)

if PLOT_ONE_SPLIT:
    plot_proposals(name=name, Best_or_Last=Best_or_Last, train_or_val_or_test=train_or_val_or_test,
                   conf_thresh=conf, nms_thresh=nms, idx=0, number=10, no_proposals=5)

if PLOT_SURVEYS:
    conf, nms = find_conf_and_nms_thresh(name='baseline', Best_or_Last='Best')
    plot_proposals_non_augmented(name='baseline', Best_or_Last='Best', conf_thresh=conf, nms_thresh=nms,
                                 number=2, no_proposals=5)
    plot_proposals_non_preprocessed(name='baseline', Best_or_Last='Best', conf_thresh=conf, nms_thresh=nms,
                                    number=2, no_proposals=5)


In [4]:
# Compare models
#
# Set `use_best_nms = True` to evaluate every model at its own tuned NMS
# threshold (and star its tuned operating point); otherwise all models share
# the fixed threshold below.

names = ['baseline', 'random', 'hyper', 'hyperV2', 'resnet50hyper']
use_best_nms = False
shared_nms_thresh = 0.15

nms_thresh = ['best'] * len(names) if use_best_nms else [shared_nms_thresh] * len(names)

print_results(names)

# Hand each plotting function its own copy in case one resolves 'best' in place.
plot_recall_vs_number_of_propsals(name=names, nms_thresh=list(nms_thresh))

plot_recall_vs_conf_thresh(name=names, nms_thresh=list(nms_thresh))


In [5]:
# Print data
#
# Display samples 800-999 of a 1000-file slice of the augmented dataset.
file_root = "/scratch/s204219/augmentedData"
file_list = np.arange(dsn + 1000, dsn + 2000)

dataset = UXO_dataset(file_list=file_list, file_root=file_root)

for i in range(800, 1000):
    img = dataset[i][0].permute(1, 2, 0).numpy()
    fig, ax = plt.subplots()
    ax.imshow(img.astype('uint8'))
    ax.set_xticks([])
    ax.set_yticks([])
    # Render and release each figure immediately. Leaving 200 figures open
    # triggers matplotlib's "too many open figures" warning and wastes memory.
    plt.show()
    plt.close(fig)

In [6]:
# Estimating varepsilon
#
# Run the model over its test split and record the number of true positives
# per image. varepsilon = (number of test images) / (number of images with
# exactly `n_tp_target` true positives).

name = 'hyperV2_one_hard_big'
Best_or_Last = 'Best'
train_or_val_or_test = 'test'

# Use the thresholds tuned for the model being evaluated. This was previously
# hard-coded to 'baseline', which applied another model's operating point.
conf_thresh, nms_thresh = find_conf_and_nms_thresh(name=name, Best_or_Last=Best_or_Last)

batch_size = 50
pos_thresh = 0.3
n_tp_target = 2  # NOTE(review): presumably the number of targets per image; confirm

previous_model_path = f"/scratch/s204219/RPN_new/{name}"
previous_model = f'{previous_model_path}/{Best_or_Last}_RPN_{name}'

index = np.load(f'{previous_model_path}/Index_{train_or_val_or_test}_{name}.npy')
file_root, _ = init_file_root_and_list(name=name)

dataset = UXO_dataset(file_list=index, file_root=file_root)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

RPN = init_RPN(name)
width_scale_factor = RPN.width_scale_factor
height_scale_factor = RPN.height_scale_factor
n_anc_boxes_total = RPN.n_anc_boxes_total

# map_location allows a GPU-saved checkpoint to be loaded on a CPU-only node.
checkpoint = torch.load(previous_model, map_location=device)
RPN.load_state_dict(checkpoint['model_state_dict'])
e = checkpoint['epoch']
print(f'checkpoint epoch: {e}')

tp_list = np.zeros(len(dataset))

idx = 0
# No gradients are needed for evaluation (harmless if inference already disables them).
with torch.no_grad():
    for img_batch, gt_bboxes_batch in tqdm.tqdm(dataloader, desc='inference'):
        gt_bboxes_proj = project_bboxes(gt_bboxes_batch, width_scale_factor, height_scale_factor, mode='p2a')
        img_batch = img_batch.to(device)

        proposals_final, _, _ = RPN.inference(img_batch, conf_thresh=conf_thresh, nms_thresh=nms_thresh)

        tp, _, _, _ = calc_conf_mat(gt_bboxes_proj, proposals_final, pos_thresh, n_anc_boxes_total, list_or_sum='list')

        tp_list[idx:idx + len(tp)] = tp.detach().cpu().numpy()
        idx += len(tp)

n_hit = int(np.sum(tp_list == n_tp_target))
eps = len(tp_list) / n_hit if n_hit > 0 else float('inf')

print(eps)