In [1]:
import inspect
import os
import shutil
import sys
import time
import traceback
from collections import OrderedDict, defaultdict

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from tqdm import tqdm
from ultralytics import YOLO

from feeders import tools

In [2]:
def xywh2xyxy(bbox):
    """Convert a centre-format box ``[xc, yc, w, h]`` to corner format.

    Returns a list ``[x1, y1, x2, y2]`` where ``(x1, y1)`` and ``(x2, y2)``
    lie half a width / half a height on either side of the centre.
    """
    cx, cy, width, height = bbox
    half_w = width / 2
    half_h = height / 2
    return [cx - half_w, cy - half_h, cx + half_w, cy + half_h]

def xyxy2xywh(bbox):
    """Convert a corner-format box ``[x1, y1, x2, y2]`` to centre format.

    Returns a list ``[xc, yc, w, h]``: the midpoint of the two corners
    followed by the horizontal and vertical extents.
    """
    left, top, right, bottom = bbox
    centre = [(left + right) / 2, (top + bottom) / 2]
    extent = [right - left, bottom - top]
    return centre + extent


In [3]:
class PoseTrackV8(object):
    """Thin wrapper that runs an Ultralytics ``YOLO`` pose model in tracking mode.

    Args:
        weightPath: Path to the pose weights handed to ``YOLO``.
            NOTE(review): the default is a machine-specific absolute path;
            pass an explicit path when running elsewhere.
        device: Default device string forwarded to ``model.track``.
    """

    def __init__(self, weightPath='/home/k100/Code/weights/yolov8l-pose.pt',
                 device='cuda:0'):
        self.model = YOLO(weightPath)
        self.device = device

    def poseTrack(self, imgs, device=None):
        """Run ByteTrack-configured pose tracking on ``imgs``.

        Args:
            imgs: Input forwarded unchanged to ``self.model.track``.
            device: Optional per-call device override. When ``None`` the
                device given at construction time is used. (Previously this
                argument was accepted but silently ignored.)

        Returns:
            Whatever ``self.model.track`` returns; callers in this notebook
            index it with ``[0]``.
        """
        return self.model.track(imgs,
                                stream=False,
                                device=self.device if device is None else device,
                                tracker="bytetrack.yaml",
                                # persist=True keeps tracker state between calls,
                                # so IDs carry over from frame to frame.
                                persist=True,
                                verbose=False)
class KptTrack(object):
    """Accumulates per-track keypoint histories and derives boxes from keypoints.

    Attributes:
        gg: Stored as given; not read anywhere inside this class.
        track_history: Maps track ID -> list of per-frame keypoint tensors.
        drop_counting: Maps track ID -> number of consecutive frames in which
            that ID was absent.
        max_miss: A track is discarded once its miss counter exceeds this value.
    """

    def __init__(self, gg=None):
        self.gg = gg
        self.track_history = defaultdict(list)
        self.drop_counting = defaultdict(int)
        self.max_miss = 500

    def tracking(self, trackIDs, kpts):
        """Append this frame's keypoints to each track and age the absent tracks.

        Args:
            trackIDs: Iterable of hashable track IDs visible in this frame.
            kpts: Iterable of keypoint tensors, aligned with ``trackIDs``.

        Returns:
            ``(trackIDsNew, kptSeqs)`` where ``kptSeqs[i]`` is the full stacked
            history (``torch.stack``) of track ``trackIDsNew[i]``.
        """
        visible = set(trackIDs)
        # Snapshot the absent IDs first: entries are deleted inside the loop.
        absent = [tid for tid in self.track_history if tid not in visible]
        for tid in absent:
            if self.drop_counting[tid] > self.max_miss:
                del self.drop_counting[tid]
                del self.track_history[tid]
            else:
                self.drop_counting[tid] += 1

        trackIDsNew = []
        kptSeqs = []
        for trackID, kpt in zip(trackIDs, kpts):
            # The track is visible again: restart its miss counter so that
            # max_miss bounds *consecutive* misses, not misses accumulated over
            # the whole lifetime of an intermittently detected track.
            self.drop_counting.pop(trackID, None)
            track = self.track_history[trackID]
            track.append(kpt)
            trackIDsNew.append(trackID)
            kptSeqs.append(torch.stack(track))
        return trackIDsNew, kptSeqs

    def sameShaper(self, kptSeqsList):
        """Trim every sequence to the shortest length and stack them into one array.

        Tensors are converted with ``.numpy()`` first (decided by the type of
        the first element). Each sequence keeps its *last* ``min_len`` entries.
        A single-element list is returned as ``np.array(kptSeqsList)`` untrimmed.
        """
        if isinstance(kptSeqsList[0], torch.Tensor):
            kptSeqsList = [kptSeq.numpy() for kptSeq in kptSeqsList]
        if len(kptSeqsList) == 1:
            return np.array(kptSeqsList)
        min_len = min(len(kptSeq) for kptSeq in kptSeqsList)
        return np.array([kptSeq[-min_len:] for kptSeq in kptSeqsList])

    def getBBOX_from_kpt(self, kpt, outsizeW=2, outsizeH=1.5):
        """Compute an enlarged integer box enclosing the given keypoints.

        Args:
            kpt: 3-D NumPy array whose last axis is indexed as
                ``(x, y, confidence)``; it is copied, not modified.
            outsizeW: Factor applied to the box width around its centre.
            outsizeH: Factor applied to the box height around its centre.

        Returns:
            ``np.ndarray`` ``[x1, y1, x2, y2]`` of ints with negatives set to 0.
            The box is not clipped against any upper image bound.
        """
        kptNp = kpt.copy()
        lowConf = kptNp[..., 2] < 0.5
        # Overwrite low-confidence points with the per-axis maximum so they can
        # never pull the box minimum down; the maximum itself is unaffected.
        fill_xy = kptNp[..., :2].max(axis=(0, 1))
        kptNp[lowConf, :2] = fill_xy

        x1 = np.min(kptNp[:, :, 0])
        y1 = np.min(kptNp[:, :, 1])
        x2 = np.max(kptNp[:, :, 0])
        y2 = np.max(kptNp[:, :, 1])

        # Grow the tight box around its centre.
        xc, yc, w, h = xyxy2xywh([x1, y1, x2, y2])
        w *= outsizeW
        h *= outsizeH
        x1, y1, x2, y2 = xywh2xyxy([xc, yc, w, h])
        bbox = np.array([x1, y1, x2, y2]).astype(int)
        bbox[bbox < 0] = 0
        return bbox
        
class MotionRecognizeSkateFormerRGB(object):
    """Inference wrapper around a SkateFormerRGB checkpoint.

    Builds the model named by ``modelType``, loads ``ckptPath`` into it, and
    offers helpers to turn tracked keypoint sequences plus a frame into model
    inputs, run the model, and draw the predictions.

    Args:
        ckptPath: Path of the state-dict checkpoint read by ``torch.load``.
        modelType: Dotted ``module.ClassName`` string of the model class.
        device: Device string used for the weights and all input tensors.
        clsNamePath: Text file with one class name per line; when missing,
            numeric class indices are drawn instead.
    """

    def __init__(self, ckptPath='./work_dir/ntu/cs/SkateFormerRGB_j/runs-last_model_Epoch90_acc98.pt',
                 modelType='model.SkateFormer.SkateFormerRGB_',
                 device='cuda:0',
                 clsNamePath='./nturgbd120_cls.txt'
                 ):

        self.model_args = {'num_classes': 120, 'num_people': 2, 'num_points': 17,
                           'kernel_size': 7, 'num_heads': 32, 'attn_drop': 0.5,
                           'head_drop': 0.0, 'rel': True, 'drop_path': 0.2,
                           'type_1_size': [8, 2], 'type_2_size': [8, 17],
                           'type_3_size': [8, 2], 'type_4_size': [8, 17],
                           'mlp_ratio': 4.0, 'index_t': True}

        self.ckptPath = ckptPath
        self.modelType = modelType
        self.device = device
        self.load_model()
        self.maxPersonNum = 2
        self.kptNum = 17
        self.maxFrameNum = 100
        self.dim = 3
        # Feature buffer layout: (batch, maxFrameNum, maxPersonNum*kptNum*dim) = (batch, 100, 102)
        self.p_interval = [0.95]
        self.window_size = 64
        self.thres = 64
        self.clsNamePath = clsNamePath
        self.load_clsNameTxt()

    def load_clsNameTxt(self):
        """Load class names from ``self.clsNamePath`` into ``self.clsList``.

        Sets ``self.clsList`` to ``None`` (and prints a notice) when the file
        does not exist.
        """
        if os.path.exists(self.clsNamePath):
            with open(self.clsNamePath, 'r') as f:
                # Strip only the newline: slicing off the last character would
                # truncate the final name when the file lacks a trailing newline.
                self.clsList = [line.rstrip('\n') for line in f]
        else:
            print(f'clsNamePath ({self.clsNamePath}) is not exists so skip')
            self.clsList = None

    def cropPreprocess(self, crop):
        """Convert a 256x128 BGR crop into a channel-first RGB tensor in ``[0, 1]``.

        Args:
            crop: ``(256, 128, 3)`` image array; the size is asserted.

        Returns:
            ``torch.Tensor`` of shape ``(3, 256, 128)`` (float64, since the
            array is cast with ``astype(float)``).
        """
        cropH, cropW, _ = crop.shape
        assert cropH == 256 and cropW == 128
        crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB).astype(float) / 255
        return torch.from_numpy(np.transpose(crop, (2, 0, 1)))

    def import_class(self, import_str):
        """Import and return the attribute named by a dotted ``module.Name`` string.

        Raises:
            ImportError: If the module imports but lacks the requested attribute.
        """
        mod_str, _sep, class_str = import_str.rpartition('.')
        __import__(mod_str)
        try:
            return getattr(sys.modules[mod_str], class_str)
        except AttributeError:
            raise ImportError('Class %s cannot be found (%s)' % (class_str, traceback.format_exception(*sys.exc_info())))

    def load_model(self):
        """Instantiate the model, load the checkpoint and move it to ``self.device``."""
        print('[modelType] : ', self.modelType)
        Model = self.import_class(self.modelType)
        self.model = Model(**self.model_args)
        # map_location lets a checkpoint saved on another device load here.
        weights = torch.load(self.ckptPath, map_location=self.device)
        # Drop any 'module.' prefix from the parameter names.
        weights = OrderedDict([[k.split('module.')[-1], v.to(self.device)] for k, v in weights.items()])
        self.model.load_state_dict(weights)
        self.model.to(self.device)
        # Inference only: without eval() the dropout / drop-path configured in
        # model_args stays active and predictions become stochastic.
        self.model.eval()

    def normalization(self, data, frameShape=(1080, 1920)):
        """Divide index 0 of the last axis by the width and index 1 by the height.

        The operation is performed in place on ``data``, which is also returned.

        Args:
            data: Float array/tensor, modified in place.
            frameShape: ``(height, width)`` of the source frame.
        """
        height, width = frameShape
        data[..., 0] /= width
        data[..., 1] /= height
        return data

    def _cropPerson(self, img, bbox):
        """Cut ``bbox`` out of ``img``, resize to 128x256 (WxH) and preprocess it."""
        x1, y1, x2, y2 = (int(v) for v in bbox)
        crop = img[y1:y2, x1:x2]
        if crop.size == 0:
            # Degenerate box (e.g. every keypoint was low-confidence): use a
            # blank crop instead of letting cv2.resize fail on an empty array.
            crop = np.zeros((256, 128, 3), dtype=img.dtype)
        return self.cropPreprocess(cv2.resize(crop, (128, 256)))

    def kptsPreprocess(self, kptSeqs, img):
        """Build model inputs from tracked keypoint sequences and the current frame.

        Args:
            kptSeqs: List of tensors indexed as ``(1, T, kptNum, dim)``, one per
                tracked person. NOTE: they are normalised in place.
            img: The frame the crops are taken from; ``img.shape[:-1]`` is used
                as ``(height, width)``.

        Returns:
            ``(kptsNN, crops, bboxs)``:
            ``kptsNN`` is a float32 array ``(N, maxFrameNum, kptNum*maxPersonNum*dim)``,
            zero-padded; ``crops`` is an ``(N, 3, 256, 128)`` tensor; ``bboxs``
            is the list of integer boxes the crops were taken from.
            For an empty ``kptSeqs`` all three are empty.
        """
        featDim = self.kptNum * self.maxPersonNum * self.dim
        if len(kptSeqs) == 0:
            # No tracked person: torch.stack([]) would raise.
            return (np.zeros((0, self.maxFrameNum, featDim), dtype=np.float32),
                    torch.empty((0, 3, 256, 128), dtype=torch.float64),
                    [])

        imgScale = img.shape[:-1]
        kpts = [kptSeq[-1, :, :] for kptSeq in kptSeqs]
        # getBBOX_from_kpt reads no instance state, so a fresh KptTrack is
        # equivalent to (and removes the dependency on) the notebook-global
        # ``KptTracker`` instance created in a later cell.
        bboxFromKpt = KptTrack().getBBOX_from_kpt
        bboxs = [bboxFromKpt(kpt[:, :, :].numpy()) for kpt in kpts]
        crops = torch.stack([self._cropPerson(img, bbox) for bbox in bboxs], dim=0)

        # NOTE(review): normalisation runs on the flattened (T, kptNum*dim)
        # view, so only the first keypoint's x and y are actually scaled.
        # Left as is because the checkpoint's training-time preprocessing is
        # not visible here - confirm before changing.
        kptsN = [self.normalization(kptSeq.reshape(kptSeq.shape[1],
                                                   self.kptNum * self.dim * kptSeq.shape[0]),
                                    frameShape=imgScale)
                 for kptSeq in kptSeqs]
        kptsNN = np.zeros((len(kptsN), self.maxFrameNum, featDim), dtype=np.float32)
        for dIdx, kptN in enumerate(kptsN):
            if kptN.shape[-2] > self.maxFrameNum:
                # Keep the most recent maxFrameNum frames (a slice, not a
                # single row as the previous integer index produced).
                kptN = kptN[-self.maxFrameNum:, :]
            kptsNN[dIdx, :kptN.shape[0], :kptN.shape[-1]] = kptN
        return kptsNN, crops, bboxs

    def inference(self, Inputdata, Inputcrops):
        """Run the model and return the top class index and its probability per sample.

        Args:
            Inputdata: Array reshapeable to
                ``(N, maxFrameNum, maxPersonNum, kptNum, dim)``.
            Inputcrops: Tensor of ``N`` crops; ``len(Inputcrops)`` defines ``N``.

        Returns:
            ``(prs, scores)``: CPU tensors of length ``N`` with the arg-max class
            index and the corresponding softmax probability.
        """
        batch = len(Inputcrops)
        if batch == 0:
            return torch.empty(0, dtype=torch.long), torch.empty(0)

        # (N, T, M, V, C) -> (N, C, T, V, M)
        inputdata = Inputdata.reshape(
            (batch, self.maxFrameNum, self.maxPersonNum, self.kptNum, self.dim)
        ).transpose(0, 4, 1, 3, 2)
        datas = np.empty((batch, self.dim, self.window_size, self.kptNum, self.maxPersonNum))
        t_indexs = np.empty((batch, self.window_size))
        # Per sample: number of frames whose channel-0 values, summed over
        # persons and joints, are non-zero.
        valid_frame_nums = np.sum(np.sum(np.sum(inputdata, axis=-1), axis=-1) != 0, axis=-1)[..., 0]

        for i in range(batch):
            # presumably resamples the valid frames to window_size and returns
            # the matching time indices - see feeders.tools for the contract
            data, t_index = tools.valid_crop_uniform(inputdata[i],
                                                    valid_frame_nums[i],
                                                    self.p_interval,
                                                    self.window_size,
                                                    self.thres)
            datas[i] = data
            t_indexs[i] = t_index
        datas = torch.from_numpy(datas).float().to(self.device)
        t_indexs = torch.from_numpy(t_indexs).float().to(self.device)
        Inputcrops = Inputcrops.float().to(self.device)

        with torch.no_grad():
            preds = self.model(datas, t_indexs, Inputcrops)
        # Normalise over the class axis; dim=0 normalised across the batch and
        # made every score 1.0 for a single-sample batch.
        preds = torch.nn.functional.softmax(preds, dim=1)
        scores = torch.max(preds, dim=1).values.cpu()
        prs = torch.argmax(preds, dim=1).cpu()

        return prs, scores

    def drawBoundingBox(self, img, bboxs, actIs, actSs, fontScale=1, thickness=2):
        """Draw each box with an ``action(score%)`` label onto ``img`` in place.

        Args:
            img: Image drawn on and returned.
            bboxs: Iterable of boxes; the first four values are x1, y1, x2, y2.
            actIs: Predicted class indices, aligned with ``bboxs``.
            actSs: Scores in ``[0, 1]``, aligned with ``bboxs``.
            fontScale: Font scale for the label text.
            thickness: Text thickness; values below 1 are raised to 1.
        """
        thickness = max(1, int(thickness))
        fontFace = cv2.FONT_HERSHEY_COMPLEX
        for box, actI, actS in zip(bboxs, actIs, actSs):
            # Plain ints keep OpenCV's point parsing happy for any input type.
            x1, y1, x2, y2 = (int(v) for v in box[:4])
            actI = int(actI)
            actName = self.clsList[actI] if self.clsList is not None else actI
            label = f'{actName}({int(actS * 100)}%)'
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 6)
            labelSize = cv2.getTextSize(label, fontFace, fontScale, thickness)
            # Filled rectangle behind the text, anchored at the box's top-left.
            textRight = x1 + labelSize[0][0]
            textTop = y1 - labelSize[0][1]
            cv2.rectangle(img, (x1, y1), (textRight, textTop), (0, 255, 0), cv2.FILLED)
            cv2.putText(img, label, (x1, y1), fontFace, fontScale, (0, 0, 0), thickness)
        return img

    def drawRecongizeResult(self, img, bboxs, actIs, actSs, fontScale=1, thickness=2):
        """Return a copy of ``img`` with the recognition results drawn on it.

        ``fontScale`` and ``thickness`` are forwarded to ``drawBoundingBox``
        (they used to be replaced by the hard-coded defaults).
        """
        return self.drawBoundingBox(img.copy(), bboxs, actIs, actSs,
                                    fontScale=fontScale, thickness=thickness)
        
class VideoCaptureSave(object):
    """Buffers frames in memory and writes them to an MP4 file on demand.

    Call ``init_parameter`` first, then ``addImg2list`` per frame, then
    ``saveImgList2mp4`` once to flush the buffer and close the file.
    """

    def __init__(self):
        self.img_list = []
        self.fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # MP4 codec
        self.fps = 30  # Default frames per second
        self.frame_width = None
        self.frame_height = None
        self.video_writer = None

    def _requireInitialized(self):
        """Raise a clear error when ``init_parameter`` has not been called yet."""
        if self.video_writer is None:
            raise RuntimeError('VideoCaptureSave.init_parameter() must be called first')

    def init_parameter(self, frame_width, frame_height, output_file, fps=30):
        """Set the output size / path / fps and open the video writer.

        Raises:
            IOError: If the writer could not be opened for ``output_file``.
        """
        # Close a writer left over from an earlier call before replacing it.
        if self.video_writer is not None:
            self.video_writer.release()

        self.fps = fps
        self.frame_width = int(frame_width)
        self.frame_height = int(frame_height)
        self.output_file = output_file
        self.video_writer = cv2.VideoWriter(output_file, self.fourcc, self.fps,
                                            (self.frame_width, self.frame_height))
        # An unopened writer drops every frame silently; fail loudly instead of
        # reporting a successful save later.
        if not self.video_writer.isOpened():
            raise IOError(f'Could not open video writer for {output_file}')

    def addImg2list(self, img):
        """Resize ``img`` to the configured output size and buffer it."""
        self._requireInitialized()
        self.img_list.append(cv2.resize(img, (self.frame_width, self.frame_height)))

    def saveImgList2mp4(self):
        """Write the buffered frames to the output file, then close it.

        The writer is released even if a write fails. After a successful save
        the buffer is emptied so the frames are not written again into a later
        output.
        """
        self._requireInitialized()
        print(f'😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄 {len(self.img_list)}')
        size = (self.frame_width, self.frame_height)
        try:
            for img in self.img_list:
                # Frames added through addImg2list already match; only frames
                # appended to img_list directly still need resizing.
                if (img.shape[1], img.shape[0]) != size:
                    img = cv2.resize(img, size)
                self.video_writer.write(img)
        finally:
            self.video_writer.release()
        self.img_list = []
        print(f"Video saved successfully.  ---> {self.output_file}")

In [4]:
# Build the pipeline singletons used by the cells below.
# Constructing the recogniser loads its checkpoint (load_model runs in __init__).
MotionRecognizeSkateFormerRGBer = MotionRecognizeSkateFormerRGB()
# NOTE: MotionRecognizeSkateFormerRGB.kptsPreprocess looks this instance up by
# its global name, so it must exist before that method is called.
KptTracker = KptTrack()
# Constructing the pose tracker loads the YOLO weights from its default path.
PoseTracker = PoseTrackV8()
VideoCaptureSaver = VideoCaptureSave()

[modelType] :  model.SkateFormer.SkateFormerRGB_


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [5]:
# Run pose tracking over every frame of the input video and keep the results.
videoPath = './samples/ntu_sample.avi'
# videoPath = './samples/vlc-record-2024-04-24-15h03m46s-0x645A1775_20230509183809_20230509185044.avi-.mp4'
cap = cv2.VideoCapture(videoPath)
# A capture that failed to open just returns ret=False forever, which would
# silently produce an empty result list; fail early instead.
if not cap.isOpened():
    raise IOError(f'Could not open video: {videoPath}')

outs = []   # one tracking result per frame, consumed by the next cell
count = 0   # number of frames read
try:
    while True:
        ret, img = cap.read()
        if not ret:
            break
        out = PoseTracker.poseTrack(img)[0]
        outs.append(out)
        count += 1
finally:
    # Always free the capture handle, even if tracking raises mid-video.
    cap.release()

In [6]:
# Recognise actions frame by frame and render the annotated video.
VideoCaptureSaver.init_parameter(960, 540, 'out.mp4')
for out in tqdm(outs):
    img = out.orig_img
    imgH, imgW, _ = img.shape
    trackIDs = out.boxes.id
    trackIDs = trackIDs.cpu().numpy() if trackIDs is not None else np.array([])
    keypoints = out.keypoints.data.cpu() if len(trackIDs) else torch.tensor([])
    # Always update the tracker, even on empty frames, so absent tracks age.
    trackIDs, kptSeqs = KptTracker.tracking(trackIDs, keypoints)
    if not kptSeqs:
        # Nobody tracked in this frame: there is nothing to classify, and
        # stacking an empty crop list would raise. Keep the raw frame so the
        # output video stays in sync with the input.
        VideoCaptureSaver.addImg2list(img)
        continue
    kptsFeatures, crops, bboxes = MotionRecognizeSkateFormerRGBer.kptsPreprocess(
        [kptSeq.unsqueeze(0) for kptSeq in kptSeqs], img)
    preds, scores = MotionRecognizeSkateFormerRGBer.inference(kptsFeatures, crops)
    # Scale the label relative to a 1920-px-wide reference frame; thickness
    # must stay >= 1 or narrow frames would ask OpenCV for a zero-width stroke.
    drawScale = imgW / 1920
    img_plot = MotionRecognizeSkateFormerRGBer.drawRecongizeResult(
        img, bboxes, preds, scores,
        fontScale=1 * drawScale, thickness=max(1, int(2 * drawScale)))
    VideoCaptureSaver.addImg2list(img_plot)
VideoCaptureSaver.saveImgList2mp4()

100%|███████████████████████████████████████████████████████████████████████████████████████████████| 72/72 [00:02<00:00, 33.45it/s]


😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄😄 72
Video saved successfully.  ---> out.mp4


In [7]:
# Render the detected skeleton of a single image on a plain white canvas.
result = PoseTracker.poseTrack('sample.png')[0]
h, w, _ = result.orig_img.shape
# Swap the photo for an all-white image of the same size so that only the
# pose overlay remains visible in the plot.
result.orig_img = np.full((h, w, 3), 255, dtype='uint8')
cv2.imwrite('sample_pose.png', result.plot(labels=False, boxes=False))

True