In [None]:
# Print the NVIDIA GPU status report via the nvidia-smi command-line tool (IPython shell escape).
! nvidia-smi

In [None]:
import collections
import json
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from matplotlib.pyplot import imshow
from torch.utils.data import DataLoader
from torchvision import transforms

from retinanet import coco_eval
from retinanet import model
from retinanet.dataloader import CocoDataset_inOrder, collater, Resizer, AspectRatioBasedSampler, Augmenter, Normalizer


# Experiment configuration read by the cells below.
# The project root can be overridden with the CL_ROOT environment variable so the notebook is not tied
# to one machine; when CL_ROOT is unset or empty, the original author's path is used.
root_path = os.environ.get('CL_ROOT') or '/home/deeplab307/Documents/Anaconda/Shiang/CL/'
method = 'w_distillation'   # sub-directory under model/ (checkpoints) and valResult/Voc/ (cached results)
data_split = '15+1'         # passed to the datasets as data_split; also a checkpoint sub-directory
start_round = 2             # dataset start_round for training data; checkpoint/result round in validation
batch_size = 2              # batch size handed to AspectRatioBasedSampler

def get_checkpoint_path(method, now_round, epoch):
    """Return the checkpoint file path for a training method, round and epoch.

    Layout:
        <root_path>/model/<method>/round<now_round>/<data_split>/voc_retinanet_<epoch>_checkpoint.pt

    ``root_path`` and ``data_split`` are read from the module-level configuration
    (reading a global needs no ``global`` statement).
    """
    round_dir = 'round{}'.format(now_round)
    file_name = 'voc_retinanet_{}_checkpoint.pt'.format(epoch)
    return os.path.join(root_path, 'model', method, round_dir, data_split, file_name)

def readCheckpoint(method, now_round, epoch, retinanet, optimizer = None, scheduler = None, map_location = None):
    """Restore model (and optionally optimizer / scheduler) state from a saved checkpoint.

    Args:
        method, now_round, epoch: identify the checkpoint file via ``get_checkpoint_path``.
        retinanet: module that receives ``checkpoint['model_state_dict']``.
        optimizer: if given, restored from ``checkpoint['optimizer_state_dict']``.
        scheduler: if given, restored from ``checkpoint['scheduler_state_dict']``.
        map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` lets a GPU-saved checkpoint be
            read without CUDA. ``None`` (default) keeps ``torch.load``'s default device mapping.

    Returns:
        The loaded checkpoint dict, so callers can read other entries (such as
        ``'rehearsal_samples'``) without a second ``torch.load``.
    """
    print('readcheckpoint at Round{} Epoch{}'.format(now_round, epoch))
    checkpoint = torch.load(get_checkpoint_path(method, now_round, epoch), map_location=map_location)
    retinanet.load_state_dict(checkpoint['model_state_dict'])
    # Identity checks: `!= None` would go through a possibly-overloaded __ne__.
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint
    

# Derive the training-data path from root_path (the same DataSet/VOC<year> layout validation() uses)
# instead of repeating the absolute prefix, so changing root_path also moves the training data.
coco_path = os.path.join(root_path, 'DataSet', 'VOC2012')

# Round-`start_round` VOC2012 training split with augmentation; used below to size the
# classification head and to build a training loader.
dataset_train = CocoDataset_inOrder(coco_path, set_name='TrainVoc2012', dataset='voc',
                                    transform=transforms.Compose([Normalizer(), Augmenter(), Resizer()]),
                                    start_round=start_round, data_split=data_split)

# Training loader that uses the project's AspectRatioBasedSampler as its batch sampler.
sampler = AspectRatioBasedSampler(dataset_train, batch_size=batch_size, drop_last=False)
dataloader_train = DataLoader(dataset_train, num_workers=2, collate_fn=collater, batch_sampler=sampler)

retinanet = model.resnet50(num_classes=dataset_train.num_classes(), pretrained=True)

# NOTE(review): checkpoint restoring is disabled below, so `retinanet` keeps the weights produced by
# model.resnet50(...) above and stays on the CPU. Re-enable readCheckpoint(...) before running inference.
# optimizer = optim.Adam(retinanet.parameters(), lr=1e-5)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
# loss_hist = collections.deque(maxlen=500)

# readCheckpoint(method, start_round, 30, retinanet)#, optimizer, scheduler)
# retinanet = retinanet.cuda()

In [None]:
# Load the w_distillation round-2 / epoch-30 checkpoint only to inspect its entries.
# map_location='cpu' avoids needing (or filling) a GPU for this metadata-only look.
c = torch.load(get_checkpoint_path('w_distillation', 2, 30), map_location='cpu')

In [None]:
# Display the checkpoint's 'rehearsal_samples' entry.
# NOTE(review): contents are not visible here — presumably the image ids kept for rehearsal; confirm.
c['rehearsal_samples']

In [None]:
# Show every rehearsal image of the checkpoint loaded above.
# NOTE(review): `rehearsal_imgs` is not defined anywhere else in this notebook; it is assumed to be the
# checkpoint's 'rehearsal_samples' list — confirm.
rehearsal_imgs = c['rehearsal_samples']
# NOTE(review): the '<first 4 chars>_<rest>.jpg' names built below are VOC2012-style, so the VOC2012
# image folder is assumed here (this cell previously relied on a `path` global that was not yet set).
rehearsal_img_dir = os.path.join(root_path, 'DataSet', 'VOC2012', 'images')

for img in rehearsal_imgs:
    img_name = str(img)
    img_file = os.path.join(rehearsal_img_dir, img_name[:4] + '_' + img_name[4:] + '.jpg')
    im = cv2.imread(img_file)
    if im is None:  # cv2.imread returns None (no exception) for a missing/unreadable file
        print('could not read', img_file)
        continue
    plt.figure(figsize=(10, 10))
    imshow(im)

In [None]:
# Load cached round-1 detections on TestVoc2007 (epoch 50, the "_1" results file) for threshold analysis.
now_round = 1
path = os.path.join(root_path, 'valResult', 'Voc', method, "round{}".format(now_round))
result_file = os.path.join(path, '{}_bbox_results_{}_for{}epoch_1.json'.format("TestVoc2007", now_round, 50))
with open(result_file) as f:  # `with` closes the handle; a bare open() passed to json.load leaks it
    result = json.load(f)

In [None]:
# Count how many cached detections survive score thresholds 0.1, 0.2, 0.3 and 0.4.
# Each line printed is "<threshold> <total detections> <kept detections>".
for step in range(1, 5):
    # Round away float noise: 0.1 * 3 == 0.30000000000000004 would otherwise be the threshold used.
    threshold = round(0.1 * step, 1)
    kept = [det for det in result if det['score'] >= threshold]
    print(threshold, len(result), len(kept))

In [None]:
from pycocotools.cocoeval import COCOeval
from collections import defaultdict
from tqdm import tqdm
import copy
def _load_filtered_detections(dataset, root_path, method, now_round, epoch, threshold):
    """Read the cached detection JSON of `dataset` and keep detections with score >= `threshold`.

    File: <root_path>/valResult/Voc/<method>/round<now_round>/<set_name>_bbox_results_<now_round>_for<epoch>epoch.json
    Prints "<threshold> <total> <kept>".
    """
    result_dir = os.path.join(root_path, 'valResult', 'Voc', method, "round{}".format(now_round))
    result_file = os.path.join(
        result_dir, '{}_bbox_results_{}_for{}epoch.json'.format(dataset.set_name, now_round, epoch))
    with open(result_file) as f:  # closes the handle, unlike a bare open() passed to json.load
        detections = json.load(f)
    kept = [det for det in detections if det['score'] >= threshold]
    print(threshold, len(detections), len(kept))
    return kept


def _evaluate_per_class(dataset, coco_true, coco_pred):
    """Run COCOeval (bbox) once for each id in `dataset.seen_class_id`, restricted to `dataset.image_ids`.

    Returns:
        Two dicts keyed by class name: COCOeval ``stats[1]`` (AP at IoU=0.50) and
        ``stats[8]`` (AR at IoU=0.50:0.95, maxDets=100).
    """
    coco_evaluator = COCOeval(coco_true, coco_pred, 'bbox')
    coco_evaluator.params.imgIds = dataset.image_ids
    precision_result = {}
    recall_result = {}
    for class_id in dataset.seen_class_id:
        class_name = dataset.cocoHelper.catIdToName(class_id)[0]
        print('Evaluate {}:'.format(class_name))
        coco_evaluator.params.catIds = [class_id]
        coco_evaluator.evaluate()
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
        precision_result[class_name] = coco_evaluator.stats[1]
        recall_result[class_name] = coco_evaluator.stats[8]
    return precision_result, recall_result


def _print_summary(precision_result, recall_result):
    """Print per-class AP/AR tables (sorted by class name), their means, then the bare values."""
    print("Precision:")
    for name, ap in sorted(precision_result.items()):
        print('{:<12} = {:0.2f}'.format(name, ap))
    print("Recall:")
    for name, ar in sorted(recall_result.items()):
        print('{:<12} = {:0.2f}'.format(name, ar))
    print("------------------------------------------")
    print('{:<12} = {:0.2f}'.format('MAP', np.mean(list(precision_result.values()))))
    print('{:<12} = {:0.2f}'.format('Average Recall', np.mean(list(recall_result.values()))))
    # Bare values, one per line, in the same sorted order (convenient for pasting into a table).
    print("Precision:")
    for _, ap in sorted(precision_result.items()):
        print('{:0.2f}'.format(ap))
    print("Recall:")
    for _, ar in sorted(recall_result.items()):
        print('{:0.2f}'.format(ar))


def evaluate_coco(dataset, model, root_path, method, now_round, epoch, threshold=0.05, run_eval=False):
    """Load cached detections for `dataset`, filter them by score, and optionally run per-class COCO eval.

    No inference is performed: the results file
    ``<root_path>/valResult/Voc/<method>/round<now_round>/<set_name>_bbox_results_<now_round>_for<epoch>epoch.json``
    must already exist.

    Args:
        dataset: evaluation dataset; ``set_name`` and ``coco`` are read, plus ``image_ids``,
            ``seen_class_id`` and ``cocoHelper`` when ``run_eval`` is True.
        model: switched to eval mode for the call; its previous train/eval mode is restored afterwards.
        root_path, method, now_round, epoch: locate the results file.
        threshold: detections with ``score < threshold`` are dropped.
        run_eval: when True, also run COCOeval per seen class and print AP@0.5 / AR@100 tables
            (tables only when more than one class is seen). Defaults to False, which returns right
            after loading the predictions.

    Returns:
        ``(coco_true, coco_pred)``. ``coco_pred`` is None when no detection passes ``threshold``.
    """
    was_training = model.training
    model.eval()
    try:
        coco_true = dataset.coco
        detections = _load_filtered_detections(dataset, root_path, method, now_round, epoch, threshold)
        if not detections:
            # COCO.loadRes indexes anns[0], so an empty list would raise IndexError.
            print('No detections with score >= {}; nothing to evaluate.'.format(threshold))
            return (coco_true, None)
        coco_pred = coco_true.loadRes(detections)
        if run_eval:
            precision_result, recall_result = _evaluate_per_class(dataset, coco_true, coco_pred)
            if len(dataset.seen_class_id) > 1:
                _print_summary(precision_result, recall_result)
        return (coco_true, coco_pred)
    finally:
        # Restore whatever mode the model was in on entry.
        model.train(was_training)
def validation(val_model, dataType, model_round, model_epoch, val_round, years=2012, test_flag=False, custom_ids=None, threshold=0.05):
    """Build the VOC evaluation split for `val_round` and score cached detections of a checkpoint.

    Args:
        val_model: model handed to ``evaluate_coco``; put in eval mode with frozen BN first.
        dataType: split prefix, e.g. 'Test' gives the set name 'TestVoc<years>'.
        model_round, model_epoch: which round/epoch's cached results to read.
        val_round: round passed to the dataset as ``start_round``.
        years: VOC year; selects ``<root_path>/DataSet/VOC<years>``.
        test_flag, custom_ids: currently unused (forwarding them to the dataset is disabled below);
            kept so existing calls keep working. ``custom_ids`` defaults to None instead of a
            shared mutable list.
        threshold: minimum detection score kept.

    Returns:
        Whatever ``evaluate_coco`` returns for the built dataset.
    """
    print("-"*100)
    print('Start eval on Round{} Epoch{}!'.format(model_round, model_epoch))

    val_model.eval()
    val_model.freeze_bn()
    set_name = "{}Voc{}".format(dataType, years)

    print('Validation data is {} at Round{}'.format(set_name, val_round))
    dataset_val = CocoDataset_inOrder(os.path.join(root_path, 'DataSet', 'VOC{}'.format(years)), set_name=set_name, dataset='voc',
                                      transform=transforms.Compose([Normalizer(), Resizer()]),
                                      start_round=val_round, data_split=data_split)
    # NOTE(review): test_flag / custom_ids are not forwarded to the dataset:
    #                   test_flag=test_flag,
    #                   custom_ids=custom_ids)
    return evaluate_coco(dataset_val, val_model, root_path, method, model_round, model_epoch, threshold)


In [None]:
# Score the cached round-`start_round` detections of several checkpoint epochs on the VOC2007 test split
# (validation round 2). Each distinct epoch is evaluated once: validation() only reads cached result
# files, so repeating an identical call just reprints the same output.
for model_epoch in (60, 50, 40):
    validation(retinanet, 'Test', model_round=start_round, model_epoch=model_epoch, val_round=2,
               years=2007, test_flag=False, custom_ids=[12], threshold=0.05)

In [None]:
# Keep the ground-truth COCO object and the score-filtered predictions (epoch 50) for the plots below.
true, pred = validation(retinanet, 'Test', model_round=start_round, model_epoch=50, val_round=2,
                        years=2007, test_flag=False, custom_ids=[12], threshold=0.05)

In [None]:
# VOC2007 test split at round 1, used below for inspection and for category id -> name lookups.
# data_split comes from the config cell instead of repeating the "15+1" literal.
dataset_test = CocoDataset_inOrder(os.path.join(root_path, 'DataSet', 'VOC{}'.format(2007)), set_name="TestVoc2007", dataset='voc',
                                   transform=transforms.Compose([Normalizer(), Resizer()]),
                                   start_round=1, data_split=data_split)

In [None]:
# Inspect the test dataset's `seen_class_id` attribute.
dataset_test.seen_class_id

In [None]:
# Look up the category id for 'car' (presumably to choose the target class id for the plots below — confirm).
dataset_test.cocoHelper.catNameToId('car')

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import os
import cv2
from matplotlib.pyplot import imshow
import random

# Writes a caption inside an image box: starting at the box's left edge, half-way down its height.
def draw_caption(image, box, caption, color=(255, 0, 255)):
    """Draw `caption` on `image` in place with cv2.putText.

    Args:
        image: OpenCV image array (as returned by cv2.imread), modified in place.
        box: (x1, y1, x2, y2); values are truncated to int.
        caption: text to write.
        color: colour tuple handed to cv2.putText.
    """
    b = np.array(box).astype(int)
    cv2.putText(image, caption, (b[0], b[1] + int((b[3] - b[1]) / 2)), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 1)


# Show up to MAX_FIGURES test images that contain `target_cat_id` and have at least one prediction.
# A prediction is drawn when its score >= 0.5, or always when the image has exactly one GT box:
#   target-class prediction -> (0, 255, 0) box plus score caption,
#   other class             -> (0, 0, 255) box plus score and class-name captions.
# GT boxes are drawn with (255, 0, 0). Colour tuples are in OpenCV's BGR order.
# NOTE(review): the image is passed to matplotlib without BGR->RGB conversion, so red and blue appear swapped.
target_cat_id = 6
MAX_FIGURES = 12
path = os.path.join(root_path, 'DataSet', 'VOC2007', 'images')

# Sort, then shuffle with a fixed seed, so Restart & Run All shows the same images.
target_img_ids = sorted(set(pred.getImgIds()) & set(true.getImgIds(catIds=[target_cat_id])))
random.Random(0).shuffle(target_img_ids)

n_shown = 0
for imgId in target_img_ids:
    print(imgId)
    img_file = os.path.join(path, "%06d.jpg" % int(imgId))
    im = cv2.imread(img_file)
    if im is None:  # cv2.imread returns None (no exception) for a missing/unreadable file
        print('could not read', img_file)
        continue
    gd_anns = true.loadAnns(true.getAnnIds(imgIds=[imgId]))

    for ann in pred.loadAnns(pred.getAnnIds(imgIds=[imgId])):
        if ann['score'] < 0.5 and len(gd_anns) != 1:
            continue
        x1, y1, w, h = ann['bbox']  # COCO boxes are (x, y, width, height)
        x2, y2 = x1 + w, y1 + h
        score_text = str(round(ann['score'], 3))
        if ann['category_id'] == target_cat_id:
            draw_caption(im, (x1, y1, x2, y2), score_text)
            color = (0, 255, 0)
        else:
            draw_caption(im, (x1, y1, x2, y2), score_text, (0, 255, 255))
            name = dataset_test.cocoHelper.catIdToName(ann['category_id'])[0]
            draw_caption(im, (x1, y1 - 50, x2, y2), name, (0, 255, 0))
            color = (0, 0, 255)
        cv2.rectangle(im, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness=2)

    for ann in gd_anns:
        x1, y1, w, h = ann['bbox']
        cv2.rectangle(im, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), (255, 0, 0), thickness=2)
    print(len(gd_anns))  # number of ground-truth boxes in this image

    plt.figure(figsize=(15, 15))
    imshow(im)
    n_shown += 1
    if n_shown == MAX_FIGURES:
        break

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import os
import cv2
from matplotlib.pyplot import imshow
import random

# Writes a caption inside an image box: starting at the box's left edge, half-way down its height.
def draw_caption(image, box, caption, color=(255, 0, 255)):
    """Draw `caption` on `image` in place with cv2.putText.

    Args:
        image: OpenCV image array (as returned by cv2.imread), modified in place.
        box: (x1, y1, x2, y2); values are truncated to int.
        caption: text to write.
        color: colour tuple handed to cv2.putText.
    """
    b = np.array(box).astype(int)
    cv2.putText(image, caption, (b[0], b[1] + int((b[3] - b[1]) / 2)), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 1)


# Show up to MAX_FIGURES test images that contain `target_cat_id` and have at least one prediction.
# NOTE(review): near-duplicate of the preceding visualisation cell; consider keeping only one.
# Highlighting and selection both use target_cat_id (previously an undefined `catId` and a literal 6).
# A prediction is drawn when its score >= 0.5, or always when the image has exactly one GT box:
#   target-class prediction -> (0, 255, 0) box plus score caption,
#   other class             -> (0, 0, 255) box plus score and class-name captions.
# GT boxes are drawn with (255, 0, 0). Colour tuples are in OpenCV's BGR order.
target_cat_id = 6
MAX_FIGURES = 12
path = os.path.join(root_path, 'DataSet', 'VOC2007', 'images')

# Sort, then shuffle with a fixed seed, so Restart & Run All shows the same images.
target_img_ids = sorted(set(pred.getImgIds()) & set(true.getImgIds(catIds=[target_cat_id])))
random.Random(0).shuffle(target_img_ids)

n_shown = 0
for imgId in target_img_ids:
    print(imgId)
    img_file = os.path.join(path, "%06d.jpg" % int(imgId))
    im = cv2.imread(img_file)
    if im is None:  # cv2.imread returns None (no exception) for a missing/unreadable file
        print('could not read', img_file)
        continue
    gd_anns = true.loadAnns(true.getAnnIds(imgIds=[imgId]))

    for ann in pred.loadAnns(pred.getAnnIds(imgIds=[imgId])):
        if ann['score'] < 0.5 and len(gd_anns) != 1:
            continue
        x1, y1, w, h = ann['bbox']  # COCO boxes are (x, y, width, height)
        x2, y2 = x1 + w, y1 + h
        score_text = str(round(ann['score'], 3))
        if ann['category_id'] == target_cat_id:
            draw_caption(im, (x1, y1, x2, y2), score_text)
            color = (0, 255, 0)
        else:
            draw_caption(im, (x1, y1, x2, y2), score_text, (0, 255, 255))
            name = dataset_test.cocoHelper.catIdToName(ann['category_id'])[0]
            draw_caption(im, (x1, y1 - 50, x2, y2), name, (0, 255, 0))
            color = (0, 0, 255)
        cv2.rectangle(im, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness=2)

    for ann in gd_anns:
        x1, y1, w, h = ann['bbox']
        cv2.rectangle(im, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), (255, 0, 0), thickness=2)
    print(len(gd_anns))  # number of ground-truth boxes in this image

    plt.figure(figsize=(15, 15))
    imshow(im)
    n_shown += 1
    if n_shown == MAX_FIGURES:
        break

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import os
import cv2
from matplotlib.pyplot import imshow
import random

# Writes a caption inside an image box: starting at the box's left edge, half-way down its height.
def draw_caption(image, box, caption, color=(255, 0, 255)):
    """Draw `caption` on `image` in place with cv2.putText.

    Args:
        image: OpenCV image array (as returned by cv2.imread), modified in place.
        box: (x1, y1, x2, y2); values are truncated to int.
        caption: text to write.
        color: colour tuple handed to cv2.putText.
    """
    b = np.array(box).astype(int)
    cv2.putText(image, caption, (b[0], b[1] + int((b[3] - b[1]) / 2)), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 1)


# Show up to MAX_FIGURES test images that contain `target_cat_id` and have at least one prediction.
# A prediction is drawn when its score >= 0.5, or always when the image has exactly one GT box:
#   prediction of highlight_cat_id -> (0, 255, 0) box plus score caption,
#   other class                    -> (0, 0, 255) box plus score and class-name captions.
# GT boxes are drawn with (255, 0, 0). Colour tuples are in OpenCV's BGR order.
target_cat_id = 2
# NOTE(review): images are selected by class 2 but predictions of class 6 are highlighted (this was a
# bare literal 6 before). Confirm this is intended; otherwise set highlight_cat_id = target_cat_id.
highlight_cat_id = 6
MAX_FIGURES = 12
path = os.path.join(root_path, 'DataSet', 'VOC2007', 'images')

# Sort, then shuffle with a fixed seed, so Restart & Run All shows the same images.
target_img_ids = sorted(set(pred.getImgIds()) & set(true.getImgIds(catIds=[target_cat_id])))
random.Random(0).shuffle(target_img_ids)

n_shown = 0
for imgId in target_img_ids:
    print(imgId)
    img_file = os.path.join(path, "%06d.jpg" % int(imgId))
    im = cv2.imread(img_file)
    if im is None:  # cv2.imread returns None (no exception) for a missing/unreadable file
        print('could not read', img_file)
        continue
    gd_anns = true.loadAnns(true.getAnnIds(imgIds=[imgId]))

    for ann in pred.loadAnns(pred.getAnnIds(imgIds=[imgId])):
        if ann['score'] < 0.5 and len(gd_anns) != 1:
            continue
        x1, y1, w, h = ann['bbox']  # COCO boxes are (x, y, width, height)
        x2, y2 = x1 + w, y1 + h
        score_text = str(round(ann['score'], 3))
        if ann['category_id'] == highlight_cat_id:
            draw_caption(im, (x1, y1, x2, y2), score_text)
            color = (0, 255, 0)
        else:
            draw_caption(im, (x1, y1, x2, y2), score_text, (0, 255, 255))
            name = dataset_test.cocoHelper.catIdToName(ann['category_id'])[0]
            draw_caption(im, (x1, y1 - 50, x2, y2), name, (0, 255, 0))
            color = (0, 0, 255)
        cv2.rectangle(im, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness=2)

    for ann in gd_anns:
        x1, y1, w, h = ann['bbox']
        cv2.rectangle(im, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), (255, 0, 0), thickness=2)
    print(len(gd_anns))  # number of ground-truth boxes in this image

    plt.figure(figsize=(15, 15))
    imshow(im)
    n_shown += 1
    if n_shown == MAX_FIGURES:
        break

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import os
import cv2
from matplotlib.pyplot import imshow
import random

# Writes a caption inside an image box: starting at the box's left edge, half-way down its height.
def draw_caption(image, box, caption, color=(255, 0, 255)):
    """Draw `caption` on `image` in place with cv2.putText.

    Args:
        image: OpenCV image array (as returned by cv2.imread), modified in place.
        box: (x1, y1, x2, y2); values are truncated to int.
        caption: text to write.
        color: colour tuple handed to cv2.putText.
    """
    b = np.array(box).astype(int)
    cv2.putText(image, caption, (b[0], b[1] + int((b[3] - b[1]) / 2)), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 1)


# False-positive view: up to MAX_FIGURES test images with NO ground-truth box of `target_cat_id` on which
# the model still predicted that class. All predictions of such an image are drawn (no score filter):
# target-class predictions with (0, 255, 0), others with (0, 0, 255); GT boxes with (255, 0, 0) (BGR order).
target_cat_id = 6
MAX_FIGURES = 12
path = os.path.join(root_path, 'DataSet', 'VOC2007', 'images')

# Sort, then shuffle with a fixed seed, so Restart & Run All shows the same images.
non_target = sorted(set(true.getImgIds()) - set(true.getImgIds(catIds=[target_cat_id])))
random.Random(0).shuffle(non_target)

n_shown = 0
for imgId in non_target:
    img_preds = pred.loadAnns(pred.getAnnIds(imgIds=[imgId]))
    if not any(ann['category_id'] == target_cat_id for ann in img_preds):
        continue
    img_file = os.path.join(path, "%06d.jpg" % int(imgId))
    im = cv2.imread(img_file)
    if im is None:  # cv2.imread returns None (no exception) for a missing/unreadable file
        print('could not read', img_file)
        continue

    for ann in img_preds:
        x1, y1, w, h = ann['bbox']  # COCO boxes are (x, y, width, height)
        x2, y2 = x1 + w, y1 + h
        score_text = str(round(ann['score'], 3))
        if ann['category_id'] == target_cat_id:
            draw_caption(im, (x1, y1, x2, y2), score_text)
            color = (0, 255, 0)
        else:
            draw_caption(im, (x1, y1, x2, y2), score_text, (0, 255, 255))
            color = (0, 0, 255)
        cv2.rectangle(im, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness=2)

    for ann in true.loadAnns(true.getAnnIds(imgIds=[imgId])):
        x1, y1, w, h = ann['bbox']
        cv2.rectangle(im, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), (255, 0, 0), thickness=2)

    plt.figure(figsize=(15, 15))
    imshow(im)
    n_shown += 1
    if n_shown == MAX_FIGURES:
        break

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import os
import cv2
from matplotlib.pyplot import imshow
import random

# Writes a caption inside an image box: a quarter of the box width in from its left edge,
# half-way down its height.
def draw_caption(image, box, caption, color):
    """Draw `caption` on `image` in place with cv2.putText.

    Args:
        image: OpenCV image array (as returned by cv2.imread), modified in place.
        box: (x1, y1, x2, y2); values are truncated to int.
        caption: text to write.
        color: colour tuple handed to cv2.putText (required in this variant).
    """
    b = np.array(box).astype(int)
    cv2.putText(image, caption, (b[0] + int((b[2] - b[0]) / 4), b[1] + int((b[3] - b[1]) / 2)), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 1)


# False-positive view for class `catId`: up to MAX_FIGURES test images with no GT box of that class on which
# it was still predicted. All predictions are drawn with a score caption: catId with (0, 255, 0), other
# classes with (0, 0, 255); GT boxes with (255, 0, 0) (BGR order).
catId = 6
MAX_FIGURES = 9
path = os.path.join(root_path, 'DataSet', 'VOC2007', 'images')

# Sort, then shuffle with a fixed seed, so Restart & Run All shows the same images.
non_target_img_ids = sorted(set(true.getImgIds()) - set(true.getImgIds(catIds=[catId])))
random.Random(0).shuffle(non_target_img_ids)

n_shown = 0
for imgId in non_target_img_ids:
    img_preds = pred.loadAnns(pred.getAnnIds(imgIds=[imgId]))
    if not any(ann['category_id'] == catId for ann in img_preds):
        continue
    # VOC2007 images are named by zero-padded 6-digit id, matching the VOC2007 folder above
    # (a VOC2012-style '<yyyy>_<rest>.jpg' name does not fit this folder).
    img_file = os.path.join(path, "%06d.jpg" % int(imgId))
    im = cv2.imread(img_file)
    if im is None:  # cv2.imread returns None (no exception) for a missing/unreadable file
        print('could not read', img_file)
        continue

    for ann in img_preds:
        x1, y1, w, h = ann['bbox']  # COCO boxes are (x, y, width, height)
        x2, y2 = x1 + w, y1 + h
        color = (0, 255, 0) if ann['category_id'] == catId else (0, 0, 255)
        draw_caption(im, (x1, y1, x2, y2), str(round(ann['score'], 3)), color)
        cv2.rectangle(im, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness=2)

    for ann in true.loadAnns(true.getAnnIds(imgIds=[imgId])):
        x1, y1, w, h = ann['bbox']
        cv2.rectangle(im, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), (255, 0, 0), thickness=2)

    plt.figure(figsize=(15, 15))
    imshow(im)
    n_shown += 1
    if n_shown == MAX_FIGURES:
        break

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import os
import cv2
from matplotlib.pyplot import imshow
import random

# Writes a caption inside an image box: a quarter of the box width in from its left edge,
# half-way down its height.
def draw_caption(image, box, caption, color):
    """Draw `caption` on `image` in place with cv2.putText.

    Args:
        image: OpenCV image array (as returned by cv2.imread), modified in place.
        box: (x1, y1, x2, y2); values are truncated to int.
        caption: text to write.
        color: colour tuple handed to cv2.putText (required in this variant).
    """
    b = np.array(box).astype(int)
    cv2.putText(image, caption, (b[0] + int((b[2] - b[0]) / 4), b[1] + int((b[3] - b[1]) / 2)), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 1)


# Same false-positive view as a random sample, but over a chosen list of image ids (`visual_img`):
# images where class `catId` was predicted are drawn with all their predictions ((0, 255, 0) for catId,
# (0, 0, 255) otherwise) and GT boxes ((255, 0, 0)); at most MAX_FIGURES figures.
catId = 6
MAX_FIGURES = 9
path = os.path.join(root_path, 'DataSet', 'VOC2007', 'images')

# NOTE(review): `visual_img` is not defined anywhere in this notebook — presumably a hand-picked id list
# left in the kernel. Fall back to all images without a GT box of catId (sorted) so a fresh kernel works.
if 'visual_img' not in globals():
    visual_img = sorted(set(true.getImgIds()) - set(true.getImgIds(catIds=[catId])))

n_shown = 0
for imgId in visual_img:
    img_preds = pred.loadAnns(pred.getAnnIds(imgIds=[imgId]))
    if not any(ann['category_id'] == catId for ann in img_preds):
        continue
    # VOC2007 images are named by zero-padded 6-digit id, matching the VOC2007 folder above.
    img_file = os.path.join(path, "%06d.jpg" % int(imgId))
    im = cv2.imread(img_file)
    if im is None:  # cv2.imread returns None (no exception) for a missing/unreadable file
        print('could not read', img_file)
        continue

    for ann in img_preds:
        x1, y1, w, h = ann['bbox']  # COCO boxes are (x, y, width, height)
        x2, y2 = x1 + w, y1 + h
        color = (0, 255, 0) if ann['category_id'] == catId else (0, 0, 255)
        draw_caption(im, (x1, y1, x2, y2), str(round(ann['score'], 3)), color)
        cv2.rectangle(im, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness=2)

    for ann in true.loadAnns(true.getAnnIds(imgIds=[imgId])):
        x1, y1, w, h = ann['bbox']
        cv2.rectangle(im, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), (255, 0, 0), thickness=2)

    plt.figure(figsize=(15, 15))
    imshow(im)
    n_shown += 1
    if n_shown == MAX_FIGURES:
        break

In [None]:
import torch

In [None]:
# Scratch check: zero three entries of a 4x20 ones tensor, show it, then average over the first dim.
test = torch.ones(4, 20)
zero_rows, zero_cols = [0, 1, 3], [0, 0, 10]
test[zero_rows, zero_cols] = 0  # sets (0, 0), (1, 0) and (3, 10) to zero in one indexed assignment
print(test)
test = test.mean(dim=0)


In [None]:
# Shape check: a zero tensor built like `test` has the same shape.
torch.zeros_like(test).shape

In [None]:
# Stack two copies of `test` along a new leading dimension and average them back (equals `test`).
torch.stack((test, test), dim=0).mean(dim=0)