In [15]:
import cv2
import numpy as np
import dlib
import time
import uuid
from facenet_pytorch import MTCNN, InceptionResnetV1
import torch 
import torchvision.transforms as transforms
from PIL import Image

In [16]:
# Run the models on the GPU when CUDA is available, otherwise on the CPU.
# (The previous expression selected 'cpu' in both branches.)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Embedding network loaded with the 'vggface2' pretrained weights, switched to
# eval mode and moved to the selected device.
resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)

# Maximum Manhattan distance (in pixels) between a detection centre and a
# tracker centre for the detection to count as an already-tracked face.
MaxDist=100

# Face detector; its boxes are used both for enrolment crops and live tracking.
mtcnn = MTCNN(
    image_size=160, margin=0, min_face_size=20,
    thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
    device=device
)
print('Running on device: {}'.format(device))

Running on device: cpu


In [17]:
# Known faces: keys are reference embedding tensors, values are the person's name.
database = {}

In [18]:
trans = transforms.ToTensor()


def load_face_tensor(image_path):
    """Load an image and return the first detected face as a network input.

    The face is located with ``mtcnn`` on an RGB copy of the image, cropped
    from the image as loaded by OpenCV, resized to 160x160 and converted with
    ``transforms.ToTensor``. The crop deliberately uses the same channel order
    and scaling as ``Track.checkDatabase`` so that enrolled and live
    embeddings are produced from identically prepared inputs.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A tensor of shape (1, 3, 160, 160) on ``device``.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
        ValueError: If no face is detected or the detected box is empty
            after clamping it to the image.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError("Could not read image: {}".format(image_path))

    rgb_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    boxes, _ = mtcnn.detect(rgb_image)
    if boxes is None or len(boxes) == 0:
        raise ValueError("No face detected in image: {}".format(image_path))

    (sX, sY, eX, eY) = boxes[0]
    img_h, img_w = image.shape[:2]

    # Same integer arithmetic as before, then clamped to the image: a negative
    # start index would otherwise wrap around and select the wrong region.
    top = int(sY)
    left = int(sX)
    bottom = min(top + int(eY - sY), img_h)
    right = min(left + int(eX - sX), img_w)
    top = max(top, 0)
    left = max(left, 0)
    if bottom <= top or right <= left:
        raise ValueError("Detected face box is empty in image: {}".format(image_path))

    face = cv2.resize(image[top:bottom, left:right, :], (160, 160))
    return torch.unsqueeze(trans(face), 0).to(device)


Mahmud = load_face_tensor("4.jpg")
Mahmud2 = load_face_tensor("3.jpg")
Monsur = load_face_tensor("1.jpeg")

In [19]:
# The embeddings are only compared by distance, so compute them without
# autograd; otherwise every stored tensor keeps its computation graph alive.
# NOTE: this cell rebinds the input names to embeddings, so re-run the image
# loading cell before running it a second time.
with torch.no_grad():
    Mahmud = resnet(Mahmud)
    Monsur = resnet(Monsur)
    Mahmud2 = resnet(Mahmud2)

# Mahmud2 is not enrolled; it is compared against Mahmud in the last cell.
database[Mahmud] = "Mahmud"
database[Monsur] = "Monsur"

In [20]:
class Track():
    """State of one tracked face.

    Attributes are plain fields: ``box`` is ``[startX, startY, endX, endY]``,
    ``tracker`` is the underlying tracker object (it must provide
    ``start_track``), ``name`` is the recognised label or ``None``,
    ``recognized`` tells whether a database match has been found, and
    ``centroid`` caches the centre of ``box``.
    """

    # Largest embedding distance that still counts as a database match.
    MATCH_THRESHOLD = 1

    def __init__(self,bbox,embeddings,confidence,tracker):
        self.box=bbox
        self.embeddings=embeddings
        self.confidence=confidence
        self.recognized=False
        self.name=None
        self.tracker=tracker
        # Defined up front so the attribute exists before the first update.
        self.centroid=self.get_centroid() if bbox is not None else None

    def start_track(self,frame,drect):
        """Start the underlying tracker on ``frame`` at rectangle ``drect``."""
        self.tracker.start_track(frame,drect)

    def update_track(self,new_box=None,new_confidence=None,new_label=None,_tracker=None):
        """Update the stored state; arguments left as ``None`` are kept as is.

        Args:
            new_box: New ``[startX, startY, endX, endY]`` box.
            new_confidence: New detection confidence.
            new_label: New name for the track.
            _tracker: Replacement tracker object.
        """
        if new_box is not None:
            self.box=new_box
        if new_confidence is not None:
            self.confidence=new_confidence
        if new_label is not None:
            self.name=new_label
        if _tracker is not None:
            self.tracker=_tracker
        if self.box is not None:
            self.centroid=self.get_centroid()

    def checkDatabase(self,image=None):
        """Embed the face inside ``self.box`` and look it up in ``database``.

        The crop is resized to 160x160, embedded with ``resnet`` and compared
        against every key of ``database`` by the norm of the difference. The
        name and ``recognized`` flag are only updated when the closest entry
        is nearer than ``MATCH_THRESHOLD``; on a failed lookup they keep their
        previous values.

        Args:
            image: Frame to crop from. Defaults to the module-level ``frame``
                set by the capture loop.

        Returns:
            1 when a database entry matched, otherwise 0 (including when the
            box lies partly outside the image or is empty).
        """
        if image is None:
            image = frame  # module-level frame written by the capture loop

        (sX,sY,eX,eY) = self.box
        img_h, img_w = image.shape[:2]

        # Skip faces that are partly outside the picture.
        if sX<0 or eX>img_w or sY<0 or eY>img_h:
            return 0

        top = int(sY)
        left = int(sX)
        bottom = top + int(eY - sY)
        right = left + int(eX - sX)
        if bottom<=top or right<=left:
            # An empty crop would make cv2.resize fail.
            return 0

        face = cv2.resize(image[top:bottom, left:right, :],(160,160))
        crop_img = torch.unsqueeze(transforms.ToTensor()(face), 0).to(device)
        with torch.no_grad():
            self.embeddings = resnet(crop_img)

        best_dist = float("inf")
        best_name = None
        for key,value in database.items():
            dist = (self.embeddings - key).norm().item()
            if dist<best_dist:
                best_dist = dist
                best_name = value

        if best_name is not None and best_dist<self.MATCH_THRESHOLD:
            self.name = best_name
            self.recognized = True
            return 1
        return 0

    def get_centroid(self):
        """Return the integer ``(x, y)`` centre of ``self.box``."""
        return (int((self.box[0]+self.box[2])/2),int((self.box[1]+self.box[3])/2))
    

In [21]:
class ObjectCount():
    """Per-tracker counter of how often a tracker was seen leaving the frame."""

    def __init__(self):
        # tracker -> number of "gone" observations
        self.gone_count = {}

    def increment_gone_count(self,tracker):
        """Add one to the tracker's count, starting from zero if unseen."""
        if tracker in self.gone_count:
            self.gone_count[tracker] += 1
        else:
            self.gone_count[tracker] = 1

    def get_gone_count(self,tracker):
        """Return the tracker's count, or 0 when it has none yet."""
        if tracker in self.gone_count:
            return self.gone_count[tracker]
        return 0

    def set_gone_count(self,tracker):
        """Reset the tracker's count to zero (registering it if needed)."""
        self.gone_count.update({tracker: 0})

In [22]:
def delete(trackers,rects):
    """Remove the tracker that is least supported by the current detections.

    For every tracker the Manhattan distance from its centroid to the centre
    of its nearest detection is computed; the tracker for which that distance
    is largest is removed. (Taking the largest distance over all
    tracker/detection pairs instead could drop a tracker sitting exactly on
    one detection merely because it is far from another one.)

    Args:
        trackers: List of objects providing ``get_centroid()``; modified in place.
        rects: Iterable of ``(startX, startY, endX, endY)`` detections.

    Returns:
        The same ``trackers`` list, with at most one element removed. It is
        returned unchanged when there are no trackers or no detections.
    """
    if len(trackers)==0 or rects is None or len(rects)==0:
        return trackers

    def nearest_detection_distance(tracker):
        # Distance from this tracker to the closest detection centre.
        centX,centY = tracker.get_centroid()
        return min(
            abs(centX-int((sX+eX)/2))+abs(centY-int((sY+eY)/2))
            for (sX,sY,eX,eY) in rects
        )

    # max() returns the first maximal element, so ties go to the earliest tracker.
    stale = max(trackers,key=nearest_detection_distance)
    trackers.remove(stale)
    return trackers

In [23]:
cap = cv2.VideoCapture(0)
val = "Not recognized"
frame_skip = 5      # run the face detector on every 5th frame only
frame_count = 0
trackers = []
object_counter = ObjectCount()

# Capture properties read as 0 when the backend cannot report them, so fall
# back to the values this notebook was written for.
fps = cap.get(cv2.CAP_PROP_FPS)
if fps <= 0:
    fps = 25.0
frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or 640
frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) or 480
out = cv2.VideoWriter('output_1hr_vid.mp4',cv2.VideoWriter_fourcc('M','J','P','G'), fps, (frame_w,frame_h))

try:
    while True:

        ret,frame = cap.read()

        if not ret:
            print("Video ended")
            break

        if frame_count % frame_skip ==0:
            # The detector works on an RGB PIL image; only build it when needed.
            pil_frame = Image.fromarray(cv2.cvtColor(frame,cv2.COLOR_BGR2RGB))
            bbox_array,confidence = mtcnn.detect(pil_frame)

            if bbox_array is not None:
                for (rect,conf) in zip(bbox_array,confidence):
                    (sX,sY,eX,eY)=rect
                    centX_r=int((sX+eX)/2)
                    centY_r=int((sY+eY)/2)

                    # A detection whose centre is within MaxDist (Manhattan
                    # distance) of an existing tracker is already tracked.
                    already_tracked = False
                    for existing in trackers:
                        centX,centY = existing.get_centroid()
                        if abs(centX-centX_r)+abs(centY-centY_r)<MaxDist:
                            already_tracked = True
                            break

                    if not already_tracked:
                        corr_tracker = dlib.correlation_tracker()
                        drect = dlib.rectangle(int(sX), int(sY), int(eX), int(eY))
                        box=[int(sX),int(sY),int(eX),int(eY)]
                        tracker = Track(box,None,conf,corr_tracker)
                        tracker.start_track(frame, drect)
                        trackers.append(tracker)
                        object_counter.set_gone_count(tracker)

                # Pressing 'r' runs recognition for every tracked face. This is
                # done here, before anything is drawn on the frame, because
                # checkDatabase crops the face from the current frame.
                if cv2.waitKey(40) & 0xFF == ord('r'):
                    for tracker in trackers:
                        if tracker.checkDatabase() == 1:
                            val = "Recognized"

                # More trackers than detected faces: drop one surplus tracker.
                if len(trackers)>len(bbox_array):
                    trackers = delete(trackers,bbox_array)

        frame_width = frame.shape[1]
        # Iterate over a copy: trackers are removed from the list inside the
        # loop, which would otherwise make the iteration skip elements.
        for trk in list(trackers):
            trk.tracker.update(frame)
            pos = trk.tracker.get_position()
            startX = int(pos.left())
            startY = int(pos.top())
            endX = int(pos.right())
            endY = int(pos.bottom())

            trk.update_track([startX,startY,endX,endY])
            centX,centY = trk.get_centroid()

            # Count the frames in which the box sticks out of the left or
            # right edge of the picture.
            if endX>frame_width or startX<0:
                object_counter.increment_gone_count(trk)

            # After 5 such frames the face is considered gone; stop tracking it.
            if object_counter.get_gone_count(trk)>=5:
                trackers.remove(trk)
                continue

            cv2.rectangle(frame, (startX, startY), (endX, endY),(0, 255, 0), 2)
            # A track has no name until it has been recognised.
            cv2.putText(frame, trk.name or "", (centX,centY),cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

        # Status line: red while nothing has been recognised, green afterwards.
        status_colour = (0, 0, 255) if val == "Not recognized" else (0, 255, 0)
        cv2.putText(frame, val, (50,450),cv2.FONT_HERSHEY_SIMPLEX, 0.5, status_colour, 2)
        cv2.imshow("Frame",frame)
        out.write(frame)
        frame_count+=1
        if cv2.waitKey(20) & 0xFF == ord('q'):
            break
finally:
    # Release the camera, the video file and the window even if the loop raised.
    cap.release()
    out.release()
    cv2.destroyAllWindows()

In [None]:
# Squared Euclidean distance between the Mahmud and Mahmud2 embeddings.
euclidena_dist = torch.cdist(Mahmud, Mahmud2).pow(2)
print(euclidena_dist.item())