In [1]:
import numpy as np
from scipy.optimize import linear_sum_assignment
from PIL import Image
import torchvision.transforms as transforms
import torch

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def compute_scores(img_one, img_two, model):
    transform = transforms.Compose([transforms.PILToTensor(), transforms.Resize((224, 224))])
    img_one = img_one.convert('RGB')
    tensor_one = transform(img_one).to(device)
    emb_one = model(tensor_one.unsqueeze(0)).last_hidden_state[:, 0]
    img_two = img_two.convert('RGB')
    tensor_two = transform(img_two).to(device)
    emb_two = model(tensor_two.unsqueeze(0)).last_hidden_state[:, 0]
    scores = torch.nn.functional.cosine_similarity(emb_one, emb_two)
    return scores.tolist()[0]

In [None]:
def id_f1(preds, true):
    tp, fp, fn = 0, 0, 0
    for i in range(len(preds)):
        for j in range(len(preds[i]['data'])):
            if len(preds[i]['data'][j]['bounding_box']) == 0:
                continue
            if not preds[i]['data'][j]['track_id']:
                fn += 1
            elif preds[i]['data'][j]['track_id'] == true[i]['data'][j]['cb_id']:
                tp += 1
            else:
                fp += 1
    return 2 * tp / (2 * tp + fp + fn)

In [None]:
class KalmanFilter(object):
    """Kalman Filter class keeps track of the estimated state of
    the system and the variance or uncertainty of the estimate.
    Predict and Correct methods implement the functionality
    Reference: https://en.wikipedia.org/wiki/Kalman_filter
    Attributes: None
    """

    def __init__(self):
        """Initialize variable used by Kalman Filter class
        Args:
            None
        Return:
            None
        """
        self.dt = 0.005  # delta time

        self.A = np.array([[1, 0], [0, 1]])  # matrix in observation equations
        self.u = np.zeros((2, 1))  # previous state vector

        # (x,y) tracking object center
        self.b = np.array([[0], [255]])  # vector of observations

        self.P = np.diag((3.0, 3.0))  # covariance matrix
        self.F = np.array([[1.0, self.dt], [0.0, 1.0]])  # state transition mat

        self.Q = np.eye(self.u.shape[0])  # process noise matrix
        self.R = np.eye(self.b.shape[0])  # observation noise matrix
        self.lastResult = np.array([[0], [255]])

    def predict(self):
        """Predict state vector u and variance of uncertainty P (covariance).
            where,
            u: previous state vector
            P: previous covariance matrix
            F: state transition matrix
            Q: process noise matrix
        Equations:
            u'_{k|k-1} = Fu'_{k-1|k-1}
            P_{k|k-1} = FP_{k-1|k-1} F.T + Q
            where,
                F.T is F transpose
        Args:
            None
        Return:
            vector of predicted state estimate
        """
        # Predicted state estimate
        self.u = np.round(np.dot(self.F, self.u))
        # Predicted estimate covariance
        self.P = np.dot(self.F, np.dot(self.P, self.F.T)) + self.Q
        self.lastResult = self.u  # same last predicted result
        return self.u

    def correct(self, b, flag):
        """Correct or update state vector u and variance of uncertainty P (covariance).
        where,
        u: predicted state vector u
        A: matrix in observation equations
        b: vector of observations
        P: predicted covariance matrix
        Q: process noise matrix
        R: observation noise matrix
        Equations:
            C = AP_{k|k-1} A.T + R
            K_{k} = P_{k|k-1} A.T(C.Inv)
            u'_{k|k} = u'_{k|k-1} + K_{k}(b_{k} - Au'_{k|k-1})
            P_{k|k} = P_{k|k-1} - K_{k}(CK.T)
            where,
                A.T is A transpose
                C.Inv is C inverse
        Args:
            b: vector of observations
            flag: if "true" prediction result will be updated else detection
        Return:
            predicted state vector u
        """

        if not flag:  # update using prediction
            self.b = self.lastResult
        else:  # update using detection
            self.b = b
        C = np.dot(self.A, np.dot(self.P, self.A.T)) + self.R
        K = np.dot(self.P, np.dot(self.A.T, np.linalg.inv(C)))

        self.u = np.round(self.u + np.dot(K, (self.b - np.dot(self.A,
                                                              self.u))))
        self.P = self.P - np.dot(K, np.dot(C, K.T))
        self.lastResult = self.u
        return self.u

In [None]:
class Track(object):
    """Track class for every object to be tracked
    Attributes:
        None
    """

    def __init__(self, prediction, trackIdCount):
        """Initialize variables used by Track class
        Args:
            prediction: predicted centroids of object to be tracked
            trackIdCount: identification of each track object
        Return:
            None
        """
        self.track_id = trackIdCount  # identification of each track object
        self.KF = KalmanFilter()  # KF instance to track this object
        self.prediction = np.asarray(prediction)  # predicted centroids (x,y)
        self.skipped_frames = 0  # number of frames skipped undetected
        self.trace = []  # trace path


class Tracker(object):
    """Tracker class that updates track vectors of object tracked
    Attributes:
        None
    """

    def __init__(self, dist_thresh, max_frames_to_skip, max_trace_length,
                 trackIdCount, model):
        """Initialize variable used by Tracker class
        Args:
            dist_thresh: distance threshold. When exceeds the threshold,
                         track will be deleted and new track is created
            max_frames_to_skip: maximum allowed frames to be skipped for
                                the track object undetected
            max_trace_lenght: trace path history length
            trackIdCount: identification of each track object
        Return:
            None
        """
        self.dist_thresh = dist_thresh
        self.max_frames_to_skip = max_frames_to_skip
        self.max_trace_length = max_trace_length
        self.tracks = []
        self.trackIdCount = trackIdCount
        self.model = model

    def Update(self, detections):
        """Update tracks vector using following steps:
            - Create tracks if no tracks vector found
            - Calculate cost using sum of square distance
              between predicted vs detected centroids
            - Using Hungarian Algorithm assign the correct
              detected measurements to predicted tracks
              https://en.wikipedia.org/wiki/Hungarian_algorithm
            - Identify tracks with no assignment, if any
            - If tracks are not detected for long time, remove them
            - Now look for un_assigned detects
            - Start new tracks
            - Update KalmanFilter state, lastResults and tracks trace
        Args:
            detections: detected centroids of object to be tracked
        Return:
            None
        """

        # Create tracks if no tracks vector found
        if (len(self.tracks) == 0):
            for i in range(len(detections)):
                track = Track(detections[i]['center'], self.trackIdCount)
                self.trackIdCount += 1
                self.tracks.append(track)

        # Calculate cost using sum of square distance between
        # predicted vs detected centroids
        N = len(self.tracks)
        M = len(detections)
        cost = np.zeros(shape=(N, M))   # Cost matrix
        for i in range(len(self.tracks)):
            for j in range(len(detections)):
                try:
                    bbox = detections[j]['bbox']
                    img_one = Image.open(detections[j]['img'])
                    w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
                    img_w, img_h = img_one.size
                    img_one = img_one.crop((
                        int(self.tracks[i].prediction[0] - w / 2),
                        int(self.tracks[i].prediction[1] - h / 2),
                        int(self.tracks[i].prediction[0] + w / 2),
                        int(self.tracks[i].prediction[1] + h / 2),
                    ))
                    img_two = Image.open(detections[j]['img'])
                    img_two = img_two.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
                    lam = 0.3
                    cos_sim = compute_scores(img_one, img_two, self.model)
                    diff = self.tracks[i].prediction - detections[j]['center']
                    diff[0] = diff[0] / img_w
                    diff[1] = diff[1] / img_h
                    distance = np.sqrt(diff[0]**2 + diff[1]**2)
                    cost[i][j] = lam * distance + (1 - lam) * (1 - cos_sim)
                except:
                    pass

        # Let's average the squared ERROR
        # print(cost)
        # cost = (0.5) * cost
        # Using Hungarian Algorithm assign the correct detected measurements
        # to predicted tracks
        assignment = []
        for _ in range(N):
            assignment.append(-1)
        row_ind, col_ind = linear_sum_assignment(cost)
        for i in range(len(row_ind)):
            assignment[row_ind[i]] = col_ind[i]

        # Identify tracks with no assignment, if any
        un_assigned_tracks = []
        for i in range(len(assignment)):
            if (assignment[i] != -1):
                # check for cost distance threshold.
                # If cost is very high then un_assign (delete) the track
                if (cost[i][assignment[i]] > self.dist_thresh):
                    assignment[i] = -1
                    un_assigned_tracks.append(i)
                pass
            else:
                self.tracks[i].skipped_frames += 1

        # If tracks are not detected for long time, remove them
        del_tracks = []
        for i in range(len(self.tracks)):
            if (self.tracks[i].skipped_frames > self.max_frames_to_skip):
                del_tracks.append(i)
        if len(del_tracks) > 0:  # only when skipped frame exceeds max
            for id in del_tracks:
                if id < len(self.tracks):
                    del self.tracks[id]
                    del assignment[id]
                else:
                    print("ERROR: id is greater than length of tracks")

        # Now look for un_assigned detects
        un_assigned_detects = []
        for i in range(len(detections)):
                if i not in assignment:
                    un_assigned_detects.append(i)

        # Start new tracks
        if(len(un_assigned_detects) != 0):
            for i in range(len(un_assigned_detects)):
                track = Track(detections[un_assigned_detects[i]]['center'],
                              self.trackIdCount)
                self.trackIdCount += 1
                self.tracks.append(track)

        # Update KalmanFilter state, lastResults and tracks trace
        for i in range(len(assignment)):
            self.tracks[i].KF.predict()

            if(assignment[i] != -1):
                self.tracks[i].skipped_frames = 0
                self.tracks[i].prediction = self.tracks[i].KF.correct(
                                            detections[assignment[i]]['center'], 1)
            else:
                self.tracks[i].prediction = self.tracks[i].KF.correct(
                                            np.array([[0], [0]]), 0)

            if(len(self.tracks[i].trace) > self.max_trace_length):
                for j in range(len(self.tracks[i].trace) -
                               self.max_trace_length):
                    del self.tracks[i].trace[j]

            self.tracks[i].trace.append(self.tracks[i].prediction)
            self.tracks[i].KF.lastResult = self.tracks[i].prediction

In [None]:
country_balls_amount = 5
track_data = [{'frame_id': 1, 'data': [{'cb_id': 0, 'bounding_box': [953, -106, 1065, -4], 'x': 997, 'y': 0, 'track_id': None}, {'cb_id': 1, 'bounding_box': [-71, 381, 80, 513], 'x': 0, 'y': 493, 'track_id': None}, {'cb_id': 2, 'bounding_box': [-41, -100, 70, -19], 'x': 0, 'y': 0, 'track_id': None}, {'cb_id': 3, 'bounding_box': [20, 703, 141, 786], 'x': 92, 'y': 799, 'track_id': None}, {'cb_id': 4, 'bounding_box': [905, 720, 1031, 801], 'x': 965, 'y': 800, 'track_id': None}]}, {'frame_id': 2, 'data': [{'cb_id': 0, 'bounding_box': [957, -101, 1053, 18], 'x': 997, 'y': 4, 'track_id': None}, {'cb_id': 1, 'bounding_box': [-55, 389, 77, 468], 'x': 23, 'y': 480, 'track_id': None}, {'cb_id': 2, 'bounding_box': [], 'x': 19, 'y': 52, 'track_id': None}, {'cb_id': 3, 'bounding_box': [31, 670, 140, 778], 'x': 88, 'y': 787, 'track_id': None}, {'cb_id': 4, 'bounding_box': [898, 608, 984, 726], 'x': 944, 'y': 722, 'track_id': None}]}, {'frame_id': 3, 'data': [{'cb_id': 0, 'bounding_box': [931, -88,
1037, 19], 'x': 997, 'y': 8, 'track_id': None}, {'cb_id': 1, 'bounding_box': [-17, 376, 94, 462], 'x': 47, 'y': 467, 'track_id': None}, {'cb_id': 2, 'bounding_box': [], 'x': 38, 'y': 103, 'track_id': None}, {'cb_id': 3, 'bounding_box': [], 'x': 84, 'y': 774, 'track_id': None}, {'cb_id': 4, 'bounding_box': [878, 569, 976, 638], 'x': 923, 'y': 651, 'track_id': None}]}, {'frame_id': 4, 'data': [{'cb_id': 0, 'bounding_box': [927, -105, 1056, 14], 'x': 997, 'y': 12, 'track_id': None}, {'cb_id': 1, 'bounding_box': [11, 363, 140, 451], 'x': 70, 'y': 456, 'track_id': None}, {'cb_id': 2, 'bounding_box': [12, 40, 134, 142], 'x': 57, 'y': 152, 'track_id': None}, {'cb_id': 3, 'bounding_box': [27, 664, 135, 777], 'x': 80, 'y': 761, 'track_id': None}, {'cb_id': 4, 'bounding_box':
[835, 484, 981, 605], 'x': 902, 'y': 588, 'track_id': None}]}, {'frame_id': 5, 'data': [{'cb_id': 0, 'bounding_box': [942, -84, 1075, 21], 'x': 997, 'y': 16, 'track_id': None}, {'cb_id': 1, 'bounding_box': [38, 335, 167, 428], 'x': 94, 'y': 445, 'track_id': None}, {'cb_id': 2, 'bounding_box': [], 'x': 76, 'y': 199, 'track_id': None}, {'cb_id': 3, 'bounding_box': [29, 628, 123, 733], 'x': 76, 'y': 747, 'track_id': None}, {'cb_id': 4, 'bounding_box': [807, 441, 948, 535], 'x': 881, 'y': 533, 'track_id': None}]}, {'frame_id': 6, 'data': [{'cb_id': 0, 'bounding_box': [932, -65, 1037, 20], 'x': 997, 'y': 20, 'track_id': None}, {'cb_id': 1, 'bounding_box': [40, 326, 185, 447], 'x': 117, 'y': 435, 'track_id': None}, {'cb_id': 2, 'bounding_box': [], 'x': 96, 'y': 245, 'track_id': None}, {'cb_id': 3, 'bounding_box': [24, 626, 144, 747], 'x': 72, 'y': 733, 'track_id': None}, {'cb_id':
4, 'bounding_box': [801, 384, 919, 500], 'x': 860, 'y': 483, 'track_id': None}]}, {'frame_id': 7, 'data': [{'cb_id': 0, 'bounding_box': [938, -75, 1069, 36], 'x': 997, 'y': 24, 'track_id': None}, {'cb_id': 1, 'bounding_box': [99, 314, 197, 416], 'x': 141, 'y': 427, 'track_id': None}, {'cb_id': 2, 'bounding_box': [40, 195, 192, 268], 'x': 115, 'y': 288, 'track_id': None}, {'cb_id': 3, 'bounding_box': [16, 613, 145, 704], 'x': 69, 'y': 718, 'track_id': None}, {'cb_id': 4, 'bounding_box': [788, 324, 908, 426], 'x': 839, 'y': 441, 'track_id': None}]}, {'frame_id': 8, 'data': [{'cb_id': 0, 'bounding_box': [925, -76, 1076, 43], 'x': 997, 'y': 29, 'track_id': None}, {'cb_id': 1, 'bounding_box': [102, 306, 242, 439], 'x': 164, 'y': 419, 'track_id': None}, {'cb_id': 2, 'bounding_box': [78, 237, 195, 320], 'x': 134, 'y': 329, 'track_id': None}, {'cb_id': 3, 'bounding_box': [16, 598, 143, 717], 'x': 65, 'y': 703, 'track_id': None}, {'cb_id': 4, 'bounding_box': [778, 310, 893, 396], 'x': 818, 'y': 404, 'track_id': None}]}, {'frame_id': 9, 'data': [{'cb_id': 0, 'bounding_box': [946, -79, 1045, 52], 'x': 997, 'y': 33, 'track_id': None}, {'cb_id': 1, 'bounding_box': [134, 294, 256, 416], 'x': 188, 'y': 412, 'track_id': None}, {'cb_id': 2, 'bounding_box': [81, 281, 208, 355], 'x': 153, 'y': 369, 'track_id': None}, {'cb_id': 3, 'bounding_box': [-10, 587, 133, 687], 'x': 61, 'y': 687, 'track_id': None}, {'cb_id': 4, 'bounding_box': [724, 262, 849, 390], 'x': 797, 'y': 374, 'track_id': None}]}, {'frame_id': 10, 'data': [{'cb_id': 0, 'bounding_box': [944, -52, 1054, 41], 'x': 997, 'y': 37, 'track_id': None}, {'cb_id': 1, 'bounding_box': [144, 318, 251, 422], 'x': 211, 'y': 406, 'track_id': None}, {'cb_id': 2, 'bounding_box': [], 'x': 173, 'y': 406, 'track_id': None}, {'cb_id': 3, 'bounding_box': [-6, 558, 124, 672], 'x': 57, 'y': 671, 'track_id': None}, {'cb_id': 4, 'bounding_box': [733, 253, 826, 362], 'x': 776, 'y': 349, 'track_id': None}]}, {'frame_id': 11, 'data': [{'cb_id': 0, 'bounding_box': [941, -67, 1038, 40], 'x': 997, 'y': 41, 'track_id': None}, {'cb_id': 1, 'bounding_box': [179, 292, 293, 410], 'x': 235, 'y': 402, 'track_id': None}, {'cb_id': 2, 'bounding_box': [117, 322, 258, 426], 'x': 192, 'y': 441, 'track_id': None}, {'cb_id': 3, 'bounding_box': [-25, 551, 124, 651], 'x': 53, 'y': 654, 'track_id': None}, {'cb_id': 4, 'bounding_box': [690, 220, 797, 340], 'x': 755, 'y': 329, 'track_id': None}]}, {'frame_id': 12, 'data': [{'cb_id': 0, 'bounding_box': [], 'x': 997, 'y': 45, 'track_id': None}, {'cb_id': 1, 'bounding_box': [217, 282, 307, 405], 'x': 258, 'y': 399, 'track_id': None}, {'cb_id': 2, 'bounding_box': [170, 360, 291, 471], 'x': 211, 'y': 473, 'track_id': None}, {'cb_id': 3, 'bounding_box': [-26, 539, 128, 643], 'x': 49, 'y': 637, 'track_id': None}, {'cb_id': 4, 'bounding_box': [], 'x': 734, 'y': 315, 'track_id': None}]}, {'frame_id': 13, 'data': [{'cb_id': 0, 'bounding_box': [919, -35, 1074, 52], 'x': 998, 'y': 50, 'track_id': None}, {'cb_id': 1, 'bounding_box': [234, 293, 329, 416], 'x': 282, 'y': 397, 'track_id': None}, {'cb_id': 2, 'bounding_box':
[], 'x': 230, 'y': 503, 'track_id': None}, {'cb_id': 3, 'bounding_box': [-18, 517, 96, 624], 'x': 46, 'y': 619, 'track_id': None}, {'cb_id': 4, 'bounding_box': [645, 190, 786, 285], 'x': 713, 'y': 305, 'track_id': None}]}, {'frame_id': 14, 'data': [{'cb_id': 0, 'bounding_box': [936, -33, 1050, 58], 'x': 998, 'y': 54, 'track_id': None}, {'cb_id': 1, 'bounding_box': [235, 301, 385, 401], 'x': 305, 'y': 396, 'track_id': None}, {'cb_id': 2, 'bounding_box': [208, 411, 319, 528], 'x': 249, 'y': 530, 'track_id': None}, {'cb_id': 3, 'bounding_box': [], 'x': 42, 'y': 600, 'track_id': None}, {'cb_id': 4, 'bounding_box': [635, 205, 732, 293], 'x': 692, 'y': 300, 'track_id': None}]}, {'frame_id': 15, 'data': [{'cb_id': 0, 'bounding_box': [950, -46, 1055, 58], 'x': 998, 'y': 58,
'track_id': None}, {'cb_id': 1, 'bounding_box': [266, 285, 376, 398], 'x': 329, 'y': 397, 'track_id': None}, {'cb_id': 2, 'bounding_box': [211, 436, 323, 552], 'x': 269, 'y': 555, 'track_id': None}, {'cb_id': 3, 'bounding_box': [-8, 477, 115, 569], 'x': 38, 'y': 581, 'track_id': None}, {'cb_id': 4, 'bounding_box': [617, 181, 712, 300], 'x': 671, 'y': 299, 'track_id': None}]}, {'frame_id': 16, 'data': [{'cb_id': 0, 'bounding_box': [926, -29, 1047, 57], 'x': 998, 'y': 62, 'track_id': None}, {'cb_id': 1, 'bounding_box': [298, 298, 418, 411], 'x': 352,
'y': 400, 'track_id': None}, {'cb_id': 2, 'bounding_box': [248, 490, 336, 575], 'x': 288, 'y': 577, 'track_id': None}, {'cb_id': 3, 'bounding_box': [-46, 481, 112, 576], 'x': 34, 'y': 562, 'track_id': None}, {'cb_id': 4, 'bounding_box': [603, 189, 721, 316], 'x': 650, 'y': 301, 'track_id': None}]}, {'frame_id': 17, 'data': [{'cb_id': 0, 'bounding_box': [], 'x': 998, 'y': 66, 'track_id': None}, {'cb_id': 1, 'bounding_box': [333, 309, 443, 396], 'x': 376, 'y': 404, 'track_id': None}, {'cb_id': 2, 'bounding_box': [250, 498, 353, 610], 'x': 307, 'y': 596, 'track_id': None}, {'cb_id': 3, 'bounding_box': [-35, 430, 90, 526], 'x': 30, 'y': 542, 'track_id': None}, {'cb_id': 4, 'bounding_box': [578, 218, 691, 320], 'x': 629, 'y': 308, 'track_id': None}]}, {'frame_id': 18, 'data': [{'cb_id': 0, 'bounding_box': [951, -47, 1065, 56], 'x': 998, 'y': 71, 'track_id': None}, {'cb_id': 1, 'bounding_box': [357, 311, 478, 423], 'x': 400, 'y': 409, 'track_id': None}, {'cb_id': 2, 'bounding_box': [278, 529, 402, 614], 'x': 326, 'y': 612, 'track_id': None}, {'cb_id': 3, 'bounding_box': [-23, 401, 105, 509], 'x': 26, 'y': 521, 'track_id': None}, {'cb_id': 4, 'bounding_box': [568, 225, 671, 304], 'x': 608, 'y': 317, 'track_id': None}]}, {'frame_id': 19, 'data': [{'cb_id': 0, 'bounding_box': [926, -34, 1045, 63], 'x': 998, 'y': 75, 'track_id': None}, {'cb_id': 1, 'bounding_box': [], 'x': 423, 'y': 416, 'track_id': None}, {'cb_id': 2, 'bounding_box': [301, 524, 420, 624], 'x': 346, 'y': 626, 'track_id': None}, {'cb_id': 3, 'bounding_box': [-56, 408, 70,
510], 'x': 23, 'y': 500, 'track_id': None}, {'cb_id': 4, 'bounding_box': [], 'x': 587, 'y': 330, 'track_id': None}]}, {'frame_id': 20, 'data': [{'cb_id': 0, 'bounding_box': [942, -25, 1051, 79], 'x': 998, 'y': 79, 'track_id': None}, {'cb_id': 1, 'bounding_box': [368, 330, 500, 437], 'x': 447, 'y': 425, 'track_id': None}, {'cb_id':
2, 'bounding_box': [309, 543, 439, 617], 'x': 365, 'y': 636, 'track_id': None}, {'cb_id': 3, 'bounding_box': [-32, 365, 99, 487], 'x': 19, 'y': 478, 'track_id': None}, {'cb_id': 4, 'bounding_box': [], 'x': 566, 'y': 345, 'track_id': None}]}, {'frame_id': 21, 'data': [{'cb_id': 0, 'bounding_box': [923, -3, 1077, 85], 'x': 998, 'y':
83, 'track_id': None}, {'cb_id': 1, 'bounding_box': [424, 355, 531, 433], 'x': 470, 'y': 435, 'track_id': None}, {'cb_id': 2, 'bounding_box': [326, 529, 427, 654], 'x': 384, 'y': 643, 'track_id': None}, {'cb_id': 3, 'bounding_box': [-36, 373, 83, 453], 'x': 15, 'y': 455, 'track_id': None}, {'cb_id': 4, 'bounding_box': [505, 269, 601, 360], 'x': 545, 'y': 363, 'track_id': None}]}, {'frame_id': 22, 'data': [{'cb_id': 0, 'bounding_box': [932,
-1, 1072, 105], 'x': 998, 'y': 88, 'track_id': None}, {'cb_id': 1, 'bounding_box': [417, 331, 547, 450], 'x': 494, 'y': 448, 'track_id': None}, {'cb_id': 2, 'bounding_box': [359, 535, 467, 642], 'x': 403, 'y': 647, 'track_id': None}, {'cb_id': 3, 'bounding_box': [-44, 336, 66, 433], 'x': 11, 'y': 432, 'track_id': None}, {'cb_id': 4, 'bounding_box': [476, 285, 588, 372], 'x': 524, 'y': 383, 'track_id': None}]}, {'frame_id': 23, 'data': [{'cb_id': 0, 'bounding_box': [922, -10, 1057, 74], 'x': 998, 'y': 92, 'track_id': None}, {'cb_id': 1, 'bounding_box': [448, 365, 575, 450], 'x': 517, 'y': 462, 'track_id': None}, {'cb_id': 2, 'bounding_box': [381, 528, 494, 657], 'x': 422, 'y': 648, 'track_id': None}, {'cb_id': 3, 'bounding_box': [-61, 311, 68, 389], 'x': 7, 'y': 409,
'track_id': None}, {'cb_id': 4, 'bounding_box': [448, 317, 580, 425], 'x': 503, 'y': 405, 'track_id': None}]},
{'frame_id': 24, 'data': [{'cb_id': 0, 'bounding_box': [955, 1, 1043, 105], 'x': 998, 'y': 96, 'track_id': None}, {'cb_id': 1, 'bounding_box': [499, 373, 596, 488], 'x': 541, 'y': 478, 'track_id': None}, {'cb_id': 2, 'bounding_box': [399, 535, 500, 646], 'x': 442, 'y': 645, 'track_id': None}, {'cb_id': 3, 'bounding_box': [-38, 304, 74, 395], 'x': 3, 'y': 384, 'track_id': None}, {'cb_id': 4, 'bounding_box': [], 'x': 482, 'y': 429, 'track_id': None}]}, {'frame_id': 25, 'data': [{'cb_id': 0, 'bounding_box': [919, -13, 1040, 88], 'x': 999, 'y': 100, 'track_id': None}, {'cb_id': 1, 'bounding_box': [], 'x': 564, 'y': 496, 'track_id': None}, {'cb_id': 2, 'bounding_box': [388, 555, 507, 637], 'x': 461, 'y': 639, 'track_id': None}, {'cb_id': 3, 'bounding_box': [], 'x': 0, 'y': 360, 'track_id': None}, {'cb_id': 4, 'bounding_box': [414, 361, 511, 469], 'x': 461, 'y': 453, 'track_id': None}]}, {'frame_id': 26, 'data': [{'cb_id': 0, 'bounding_box': [930, 19, 1045, 104], 'x': 999, 'y': 105, 'track_id': None}, {'cb_id': 1, 'bounding_box': [548, 425, 633, 510], 'x': 588, 'y': 516, 'track_id': None}, {'cb_id': 2, 'bounding_box': [420, 509, 553, 633], 'x': 480, 'y': 629, 'track_id': None}, {'cb_id': 4, 'bounding_box': [364, 379, 495, 491], 'x': 440, 'y': 479, 'track_id': None}]}, {'frame_id': 27, 'data': [{'cb_id': 0, 'bounding_box': [934, 10, 1078, 112], 'x': 999, 'y': 109, 'track_id': None}, {'cb_id': 1, 'bounding_box': [554, 458, 680, 537], 'x': 611, 'y': 539, 'track_id': None}, {'cb_id': 2, 'bounding_box': [436, 511, 540, 616], 'x': 499, 'y': 616, 'track_id': None}, {'cb_id': 4, 'bounding_box': [353, 408, 471, 498], 'x': 419, 'y': 506, 'track_id': None}]}, {'frame_id': 28, 'data': [{'cb_id': 0, 'bounding_box': [937, 32, 1050, 132], 'x': 999, 'y': 113, 'track_id': None}, {'cb_id': 1, 'bounding_box': [570, 478, 677, 549], 'x': 635, 'y': 563, 'track_id': None}, {'cb_id':
2, 'bounding_box': [464, 485, 561, 581], 'x': 519, 'y': 599, 'track_id': None}, {'cb_id': 4, 'bounding_box': [325, 431, 449, 553], 'x': 398, 'y': 533, 'track_id': None}]}, {'frame_id': 29, 'data': [{'cb_id': 0, 'bounding_box': [950, 3, 1041, 136], 'x': 999, 'y': 117, 'track_id': None}, {'cb_id': 1, 'bounding_box': [591, 496, 732, 583], 'x': 658, 'y': 590, 'track_id': None}, {'cb_id': 2, 'bounding_box': [481, 469, 600, 590], 'x': 538, 'y': 578, 'track_id': None}, {'cb_id': 4, 'bounding_box': [301, 442, 436, 550], 'x': 377, 'y': 560, 'track_id': None}]}, {'frame_id': 30, 'data': [{'cb_id': 0, 'bounding_box': [920, 27, 1053, 129], 'x': 999, 'y': 122, 'track_id': None}, {'cb_id': 1, 'bounding_box': [608, 508, 732, 609], 'x': 682, 'y': 619, 'track_id': None}, {'cb_id': 2, 'bounding_box': [512, 436, 637, 561], 'x': 557, 'y': 553, 'track_id': None}, {'cb_id': 4, 'bounding_box': [313, 467, 427, 573], 'x': 356, 'y': 587, 'track_id': None}]}, {'frame_id': 31, 'data': [{'cb_id': 0, 'bounding_box': [940, 40, 1050, 119], 'x': 999, 'y': 126, 'track_id': None}, {'cb_id': 1, 'bounding_box': [656, 546, 757, 656], 'x': 705, 'y': 650, 'track_id': None}, {'cb_id': 2, 'bounding_box': [517, 440, 638, 528], 'x': 576, 'y': 524, 'track_id': None}, {'cb_id': 4, 'bounding_box': [], 'x': 335, 'y': 613, 'track_id': None}]}, {'frame_id': 32, 'data': [{'cb_id': 0, 'bounding_box': [936, 16, 1065, 133], 'x': 999, 'y': 130, 'track_id': None}, {'cb_id':
1, 'bounding_box': [683, 587, 778, 700], 'x': 729, 'y': 684, 'track_id': None}, {'cb_id': 2, 'bounding_box': [524, 408, 650, 511], 'x': 595, 'y': 491, 'track_id': None}, {'cb_id': 4, 'bounding_box': [246, 526, 375, 658], 'x': 314, 'y': 639, 'track_id': None}]}, {'frame_id': 33, 'data': [{'cb_id': 0, 'bounding_box': [942, 41, 1040,
136], 'x': 999, 'y': 135, 'track_id': None}, {'cb_id': 1, 'bounding_box': [712, 608, 804, 732], 'x': 752, 'y':
720, 'track_id': None}, {'cb_id': 2, 'bounding_box': [538, 371, 681, 463], 'x': 615, 'y': 454, 'track_id': None}, {'cb_id': 4, 'bounding_box': [226, 544, 333, 649], 'x': 293, 'y': 664, 'track_id': None}]}, {'frame_id': 34, 'data': [{'cb_id': 0, 'bounding_box': [954, 48, 1075, 143], 'x': 999, 'y': 139, 'track_id': None}, {'cb_id': 1, 'bounding_box': [718, 640, 818, 770], 'x': 776, 'y': 758, 'track_id': None}, {'cb_id': 2, 'bounding_box': [579, 315, 707, 432], 'x': 634, 'y': 413, 'track_id': None}, {'cb_id': 4, 'bounding_box': [216, 596, 321, 689], 'x': 272, 'y': 687, 'track_id': None}]}, {'frame_id': 35, 'data': [{'cb_id': 0, 'bounding_box': [923, 35, 1071, 134], 'x': 999, 'y': 143, 'track_id': None}, {'cb_id': 1, 'bounding_box': [752, 711, 854, 785], 'x': 800, 'y': 800, 'track_id': None}, {'cb_id': 2, 'bounding_box': [601, 270, 732, 374], 'x': 653, 'y': 368, 'track_id': None}, {'cb_id': 4, 'bounding_box': [183, 622, 305, 719], 'x': 251, 'y': 709, 'track_id': None}]}, {'frame_id': 36,
'data': [{'cb_id': 0, 'bounding_box': [931, 65, 1071, 145], 'x': 1000, 'y': 148, 'track_id': None}, {'cb_id': 2, 'bounding_box': [626, 206, 712, 316], 'x': 672, 'y': 318, 'track_id': None}, {'cb_id': 4, 'bounding_box': [167, 644, 274, 726], 'x': 230, 'y': 729, 'track_id': None}]}, {'frame_id': 37, 'data': [{'cb_id': 2, 'bounding_box': [623, 155, 771, 244], 'x': 692, 'y': 263, 'track_id': None}, {'cb_id': 4, 'bounding_box': [141, 641, 261, 756], 'x': 209, 'y': 747, 'track_id': None}]}, {'frame_id': 38, 'data': [{'cb_id': 2, 'bounding_box': [634, 115, 766, 197], 'x': 711, 'y': 204, 'track_id': None}, {'cb_id': 4, 'bounding_box': [136, 663, 250, 772], 'x': 188, 'y': 762, 'track_id': None}]}, {'frame_id': 39, 'data': [{'cb_id': 2, 'bounding_box': [678, 30, 796, 155], 'x': 730, 'y': 141, 'track_id': None}, {'cb_id': 4, 'bounding_box': [103, 690, 218, 767], 'x': 167, 'y': 775, 'track_id': None}]}, {'frame_id': 40, 'data': [{'cb_id': 2, 'bounding_box': [689, -23, 806, 76], 'x': 749, 'y': 73, 'track_id': None}, {'cb_id': 4, 'bounding_box': [101, 695, 221, 795], 'x': 146, 'y': 785, 'track_id': None}]}, {'frame_id': 41, 'data': [{'cb_id': 2, 'bounding_box': [693, -92, 817, 11], 'x': 769, 'y': 0, 'track_id': None}, {'cb_id': 4, 'bounding_box': [73, 674, 190, 786], 'x': 125, 'y': 791, 'track_id': None}]}, {'frame_id': 42,
'data': [{'cb_id': 4, 'bounding_box': [41, 681, 181, 775], 'x': 104, 'y': 794, 'track_id': None}]}, {'frame_id': 43, 'data': [{'cb_id': 4, 'bounding_box': [], 'x': 83, 'y': 793, 'track_id': None}]}, {'frame_id': 44, 'data': [{'cb_id': 4, 'bounding_box': [-14, 678, 128, 770], 'x': 62, 'y': 787, 'track_id': None}]}, {'frame_id': 45,
'data': [{'cb_id': 4, 'bounding_box': [-3, 669, 99, 785], 'x': 41, 'y': 778, 'track_id': None}]}, {'frame_id':
46, 'data': [{'cb_id': 4, 'bounding_box': [], 'x': 20, 'y': 763, 'track_id': None}]}, {'frame_id': 47, 'data':
[{'cb_id': 4, 'bounding_box': [-44, 657, 51, 742], 'x': 0, 'y': 744, 'track_id': None}]}]

In [None]:
import asyncio
import glob
from transformers import AutoModel

imgs = glob.glob('/content/drive/MyDrive/imgs/*')
country_balls = [{'cb_id': x, 'img': imgs[x % len(imgs)]} for x in range(country_balls_amount)]
frames = glob.glob('/content/drive/MyDrive/screens/cb5/*')
model_ckpt = "nateraw/vit-base-beans"
model = AutoModel.from_pretrained(model_ckpt)
model = model.to(device)
print('Started')


def tracker_strong(el):
    tracker = Tracker(160, 30, 5, 0, model)
    detections = [
        {
            'img': frames[el['frame_id'] - 1],
            'bbox': obj['bounding_box'],
            'center': (obj['x'], int(obj['y'] - obj['bounding_box'][1] / 2))
        } for obj in el['data'] if len(obj['bounding_box']) > 0
    ]
    tracker.Update(detections)
    for i in range(len(tracker.tracks)):
        el['data'][i]['track_id'] = int(tracker.tracks[i].track_id)
    return el



preds = []
for el in track_data:
    el = tracker_strong(el)
    preds.append(el)
print(id_f1(preds, track_data))


Some weights of ViTModel were not initialized from the model checkpoint at nateraw/vit-base-beans and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Started
0.7286821705426356
