In [1]:
import torch
from torch.utils.data import Dataset
import numpy as np
import os
import cv2
import time
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from utils import *
from augmentation import Augmentation

In [2]:
class DatasetLaSOT(Dataset):
    def __init__(self, mode, dir_data, size_template, size_search, size_out, max_frame_sep, neg_prob=0.5, extra_context_template=0.5, min_extra_context_search=0.75, max_extra_context_search=1.0, max_shift=0):
        self.mode = mode
        self.dir_data = dir_data
        self.size_template = size_template
        self.size_search = size_search
        self.size_out = size_out
        self.neg_prob = neg_prob
        self.max_frame_sep = max_frame_sep
        self.extra_context_template = extra_context_template
        self.min_extra_context_search = min_extra_context_search
        self.max_extra_context_search = max_extra_context_search
        self.max_shift = max_shift
        # mean/std for ImageNet‐pretrained backbones
        # Adapt these variables to the backbone used
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)[None,:,None,None]
        self.std  = np.array([0.229, 0.224, 0.225], dtype=np.float32)[None,:,None,None]

        if self.mode == "train":
            file = os.path.join(dir_data, "training_set.txt")
        elif self.mode == "test":
            file = os.path.join(dir_data, "testing_set.txt")
        else:
            raise Exception("the mode must be either train or test")

        # Get number of videos for the training set
        with open(file, 'r') as file:
            self.video_names = [line.strip() for line in file]

        # Get the category names
        self.categories = sorted([name for name in os.listdir(self.dir_data)
              if os.path.isdir(os.path.join(self.dir_data, name))])

        # Get the location of each frame per video and the bounding boxes (decided to keep as separate dictionaries
        self.dict_frames_per_video = {}
        self.dict_bboxes_per_video = {}
        for video_name in self.video_names:
            category = video_name.split('-')[0]
            # This try is for the moment because I haven't downloaded the whole dataset. REMOVE
            self.dict_frames_per_video[video_name] = sorted([os.path.join(dir_data, category, video_name, "img", frame) for frame in os.listdir(os.path.join(dir_data, category, video_name, "img"))])
            self.dict_bboxes_per_video[video_name] = []
            with open(os.path.join(dir_data, category, video_name, "groundtruth.txt"), "r") as f:
                for line in f:
                    bbox = list(map(int, line.strip().split(",")))
                    self.dict_bboxes_per_video[video_name].append(bbox)


        # Get the number of frames per video
        self.dict_n_frames_per_video = {}
        for key, value in self.dict_frames_per_video.items():
            self.dict_n_frames_per_video[key] = len(value)

        # Total number of frames
        self.total_n_frames = sum(self.dict_n_frames_per_video.values())

        # List of frames
        #self.list_frames = []
        #for key in sorted(self.dict_frames_per_video.keys()):
            #self.list_frames.extend(self.dict_frames_per_video[key])

    def get_data_from_idx(self, idx):
        cumulative = 0
        for key in sorted(self.dict_n_frames_per_video):
            value = self.dict_n_frames_per_video[key]
            if idx < cumulative + value:
                return key, idx-cumulative
            cumulative += value
        raise IndexError("Index out of range")

    def __len__(self):
        return self.total_n_frames
        
    
    def visualize_video(self, video_name, with_bboxes=True, fps=30):
        frame_delay_ms = int(1000/fps)
        for n_frame, frame_path in enumerate(self.dict_frames_per_video[video_name]):
            frame = cv2.imread(frame_path)
            if frame is None:
                print(f"Could not read {frame_path}")
                continue

            if with_bboxes:
                x, y, w, h = self.dict_bboxes_per_video[video_name][n_frame]
                # Draw rectangle: image, top-left, bottom-right, color (BGR), thickness
                cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
                
            cv2.imshow("Video Playback", frame)

            if cv2.waitKey(frame_delay_ms) & 0xFF == ord('q'):
                break

        cv2.destroyAllWindows()


    def get_context_bbox(self, bbox, extra_context):
        """
        Convert a tight bbox (x, y, w, h) into a square context bbox with margin.
        Returns (cx, cy, size).
        """
        x, y, w, h = bbox
        cx = x + w / 2.
        cy = y + h / 2.
        # context padding = (w+h)*extra_context
        pad = (w + h) * extra_context
        # square size
        size = np.sqrt((w + pad) * (h + pad))
        return cx, cy, size

    def crop_and_resize(self,
                    frame,
                    cx, cy,
                    size,
                    out_size,
                    shift_x=0, shift_y=0):
        """
        Crop a square patch of side 'size' centered at (cx, cy) from frame,
        apply a shift in the out_size coordinate system, pad with border
        replication if needed, and resize to (out_size, out_size).
    
        shift_x, shift_y: pixel offsets **in the resized patch**. Can be
        positive or negative, moving the target around in the crop.
        Returns the patch and the scaling factor.
        """
        h, w = frame.shape[:2]
    
        # 1) convert shifts from output coords → original-frame coords
        #    scale = out_size / size  ⇒  size/out_size = 1/scale
        shift_x_orig = shift_x * size / out_size
        shift_y_orig = shift_y * size / out_size
    
        # 2) adjust the true crop-center in the original frame
        cx = cx + shift_x_orig
        cy = cy + shift_y_orig
    
        # 3) now compute the square-window coords as before
        x1 = cx - size/2
        y1 = cy - size/2
        x2 = x1 + size
        y2 = y1 + size
    
        # 4) compute padding amounts for out‑of‑bounds regions
        left   = int(max(0, -np.floor(x1)))
        top    = int(max(0, -np.floor(y1)))
        right  = int(max(0, np.ceil(x2)  - w))
        bottom = int(max(0, np.ceil(y2)  - h))
    
        # 5) pad & crop
        padded = cv2.copyMakeBorder(
            frame,
            top, bottom, left, right,
            borderType=cv2.BORDER_REPLICATE
        )
        x1p, y1p = x1 + left,  y1 + top
        x2p, y2p = x2 + left,  y2 + top
        patch    = padded[int(y1p):int(y2p), int(x1p):int(x2p)]
    
        # 6) resize & return
        patch_resized = cv2.resize(patch, (out_size, out_size))
        scale = out_size / size
        return patch_resized, scale

    def preprocess_pair(self, frame1, frame2, bbox1, bbox2):
        """
        Given two frames and their tight bboxes, compute exemplar & search patches,
        along with their resized bbox coordinates in patch space.
        Returns:
          exemplar_img, search_img, exemplar_box, search_box
        where boxes are (x, y, w, h) in the resized patch coordinate system.
        
        """
        extra_context_search = random.uniform(self.min_extra_context_search, self.max_extra_context_search)
        print("Extra context search: " + str(extra_context_search))
        shift_x = random.randint(-self.max_shift, self.max_shift)
        shift_y = random.randint(-self.max_shift, self.max_shift)
        
        # Frame 1: exemplar
        cx1, cy1, size1 = self.get_context_bbox(bbox1, self.extra_context_template)
        exemplar, scale1 = self.crop_and_resize(frame1, cx1, cy1, size1, self.size_template, 0, 0)
    
        # bbox1 in exemplar coords: centered
        ex_bbox = [ (self.size_template - bbox1[2]*scale1)/2,
                    (self.size_template - bbox1[3]*scale1)/2,
                    bbox1[2]*scale1,
                    bbox1[3]*scale1 ]
    
        # Frame 2: search
        cx2, cy2, size2 = self.get_context_bbox(bbox2, extra_context_search)
        search, scale2 = self.crop_and_resize(frame2, cx2, cy2, size2, self.size_search, shift_x, shift_y)
    
        # bbox2 in search coords before augment: centered
        sr_bbox = [ (self.size_search - bbox2[2]*scale2)/2 - shift_x,
                    (self.size_search - bbox2[3]*scale2)/2 - shift_y,
                    bbox2[2]*scale2,
                    bbox2[3]*scale2 ]
    
        return exemplar, search, ex_bbox, sr_bbox

    def to_tensor(self, img):
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
        img = (img[None].transpose(0,3,1,2) - self.mean) / self.std
        return torch.from_numpy(img[0])

    def get_negative_sample(self, video_name_first_frame):
        # Decide if the negative sample will have an object or pure background
        if random.random() < 0.5:
            print("With object!")
            with_object = True
        else:
            print("Without object!")
            with_object = False
        # Select a video to sample from
        video_name_second_frame = video_name_first_frame
        while video_name_second_frame == video_name_first_frame:
            video_name_second_frame = random.choice(list(self.dict_n_frames_per_video.keys()))
        print(video_name_second_frame)
        # Random frame
        idx_second_frame = random.randint(0, self.dict_n_frames_per_video[video_name_second_frame] -1)
        second_frame = cv2.imread(self.dict_frames_per_video[video_name_second_frame][idx_second_frame])
        # Get the bounding box of the object
        bbox2 = self.dict_bboxes_per_video[video_name_second_frame][idx_second_frame]
        if with_object:
            # We return directly the new frame and Bounding box
            return second_frame, bbox2
        else: # Get a random patch of the image. For that, we are going to move randomly the bbox center and apply the same logic
            w, h = bbox2[2], bbox2[3]
            # Make sure we are within the limits
            max_x = second_frame.shape[1] - w
            max_y = second_frame.shape[0] - h
            x = random.randint(0,max_x)
            y = random.randint(0,max_y)
            return second_frame, [x, y, w, h]

    def make_rect_tent(self, bbox):
        """
        bbox = [xmin, ymin, w, h] in search-patch pixel coords.
        Returns heatmap of shape (H,W), peak=1 at center,
        linearly decaying to 0 at the box edges.
        """
        stride = self.size_search / self.size_out
    
        # 1) cell centers in pixel coords
        coords = np.arange(self.size_out) * stride + stride/2
        xs, ys = np.meshgrid(coords, coords)  # (H,W)
    
        xmin, ymin, w, h = bbox
        cx = xmin + w/2
        cy = ymin + h/2
    
        # 2) normalized distances in [0,1]
        dx = np.abs(xs - cx) / (w/2)
        dy = np.abs(ys - cy) / (h/2)
    
        # 3) clamp and compute tent
        tx = np.clip(1 - dx, 0, 1)
        ty = np.clip(1 - dy, 0, 1)
        heatmap = tx * ty  # (H,W)
        # Force max to 1
        heatmap /= heatmap.max()

        # 4) build (w,h) regression map
        mask = (heatmap > 0).astype(np.float32)      # (H,W)
        reg_wh = np.zeros((self.size_out, self.size_out, 2), dtype=np.float32)
        reg_wh[..., 0] = w/self.size_search * mask  # normalized width
        reg_wh[..., 1] = h/self.size_search * mask  # normalized height

        return heatmap, reg_wh

    def get_positive_sample(self, video_name, idx_first_frame):
        # Obtain the second idx and image
        idx_second_frame = idx_first_frame
        min_frame = max(0, idx_first_frame-self.max_frame_sep)
        max_frame = min(self.dict_n_frames_per_video[video_name], idx_first_frame+self.max_frame_sep)
        while idx_second_frame == idx_first_frame:
            #idx_second_frame = random.choice(range(self.dict_n_frames_per_video[video_name]))
            idx_second_frame = random.choice(range(min_frame, max_frame))
            
        second_frame = cv2.imread(self.dict_frames_per_video[video_name][idx_second_frame])

        # Obtain bounding boxes
        bbox2 = self.dict_bboxes_per_video[video_name][idx_second_frame]
        return second_frame, bbox2

    #def get_output(self, search_img, search_bbox):
        
            
    def __getitem__(self, idx):
        """ Returns the inputs and output for the learning problem. The input
        consists of an reference image tensor and a search image tensor, the
        output is the corresponding label tensor.

        Args:
            idx: (int) The index of a sequence inside the whole dataset, from
                which the function will choose the reference and search frames.

        Returns:
            ref_frame (torch.Tensor): The reference frame with the
                specified size.
            srch_frame (torch.Tensor): The search frame with the
                specified size.
            label (torch.Tensor): The label created with the specified
                function in self.label_fcn.
        """
        # Negative sample
        if random.random() < self.neg_prob:
            is_positive = False
        else: 
            is_positive = True
        
        # Obtain the video and the index inside that video. Then, read the imae
        video_name, idx_first_frame = self.get_data_from_idx(idx)
        first_frame = cv2.imread(self.dict_frames_per_video[video_name][idx_first_frame])

        # Obtain bounding boxes
        bbox1 = self.dict_bboxes_per_video[video_name][idx_first_frame]

        if is_positive: # Positive sample
            print("Positive!")
            second_frame, bbox2 = self.get_positive_sample(video_name, idx_first_frame)
            # If the width or height of the box are 0 it is actually a negative sample!!
            if bbox2[2]==0 or bbox2[3]==0:
                print("The sample is actually negative!! Random crop.")
                is_positive = False
                # Create a random bbox to patch over it
                x = random.randint(0, 0.9*second_frame.shape[1])
                y = random.randint(0, 0.9*second_frame.shape[0])
                w = random.randint(0.1*second_frame.shape[1], second_frame.shape[1]-x-1)
                h = random.randint(0.1*second_frame.shape[0], second_frame.shape[0]-y-1)
                bbox2 = [x, y, w, h]

        else: # Negative sample
            print("Negative!")
            second_frame, bbox2 = self.get_negative_sample(video_name)

        template, search, bbox1_x1y1wh, bbox2_x1y1wh = self.preprocess_pair(first_frame, second_frame, bbox1, bbox2)

        if is_positive:
            heatmap, reg_wh = self.make_rect_tent(bbox2_x1y1wh)
        else:
            reg_wh = np.zeros((self.size_out, self.size_out, 2), dtype=np.float32)
            heatmap = np.zeros((self.size_out, self.size_out), dtype=np.float32)

        #output = {'template': self.to_tensor(template),
        #          'search': self.to_tensor(search),
        #          'heatmap': heatmap,
        #          'reg_wh': reg_wh}
        return template, search, heatmap, reg_wh
        #return output
        #return first_frame, second_frame, template, search, bbox1_x1y1wh, bbox2_x1y1wh, heatmap, reg_wh


In [3]:
dataset = DatasetLaSOT("train", "/home/rafa/deep_learning/datasets/LaSOT", 127, 255, 127, 10, 0.45, 0.5, 0.75, 1.5, 32)

In [4]:
#dataset.visualize_video("airplane-10")

In [5]:
output = dataset.__getitem__(10000)
print(output['template'].shape)
print(output['search'].shape)
print(output['heatmap'].shape)
print(output['reg_wh'].shape)

Negative!
With object!
bus-10
Extra context search: 1.1480946756504888
torch.Size([3, 127, 127])
torch.Size([3, 255, 255])
(127, 127)
(127, 127, 2)


In [6]:
first_frame, second_frame, template, search, bbox1, bbox2, heatmap, reg_wh = dataset.__getitem__(10000)
print(first_frame.shape)
plt.imshow(cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB))
plt.axis('off')  # Hide axes
plt.show()

Negative!
Without object!
boat-20
Extra context search: 1.3938253609638867


ValueError: not enough values to unpack (expected 8, got 4)

In [None]:
x1, y1, w, h = bbox1
x1, y1, w, h = int(x1), int(y1), int(w), int(h)
cv2.rectangle(template, (x1, y1), (x1+w, y1+h), (0, 255, 0), 2)
plt.imshow(cv2.cvtColor(template.astype(np.uint8), cv2.COLOR_BGR2RGB))
plt.axis('off')  # Hide axes
plt.show()

In [None]:
x1, y1, w, h = bbox2
x1, y1, w, h = int(x1), int(y1), int(w), int(h)
cv2.rectangle(search, (x1, y1), (x1+w, y1+h), (0, 255, 0), 2)
plt.imshow(cv2.cvtColor(search.astype(np.uint8), cv2.COLOR_BGR2RGB))
plt.axis('off')  # Hide axes
plt.show()

In [None]:
print("cx: ", int((x1+w/2)/dataset.size_search*dataset.size_out))
print("cy: ", int((y1+h/2)/dataset.size_search*dataset.size_out))
print(np.max(heatmap))
cy, cx = np.unravel_index(heatmap.argmax(), heatmap.shape)
w, h = reg_wh[cy, cx]
#cx = round((x1+w/2)/dataset.size_search*dataset.size_out)
#cy = round((y1+h/2)/dataset.size_search*dataset.size_out)
print("cx: ", cx)
print("cy: ", cy)
print("w: ", w)
print("h: ", h)
#cv2.rectangle(heatmap, (int(cx-w*dataset.size_out/2), int(cy-w*dataset.size_out/2)), (int(cx+w*dataset.size_out/2), int(cy+w*dataset.size_out/2)), (0, 255, 0), 2)
plt.figure(figsize=(5,5))
plt.imshow(heatmap, cmap='jet', interpolation='bilinear')
plt.colorbar(label='Heatmap value')
plt.title('Rectangular Gaussian Heatmap')
plt.xlabel('Output X')
plt.ylabel('Output Y')

# Create a Rectangle patch
rect = patches.Rectangle((cx-w*dataset.size_out/2, cy-h*dataset.size_out/2), w*dataset.size_out, h*dataset.size_out, 
                         linewidth=2, edgecolor='white', facecolor='none')

# Add the rectangle to the current axes
plt.gca().add_patch(rect)

plt.show()

In [None]:
print(mask.max())

In [None]:
plt.imshow(cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB))
plt.axis('off')  # Hide axes
plt.show()

In [None]:
bbox1_x1y1x2y2 = x1y1wh_x1y1x2y2(*bbox1)
first_frame_cropped = dataset.crop_roi(first_frame, bbox1_x1y1x2y2, dataset.size_template)

In [None]:
plt.imshow(cv2.cvtColor(first_frame_cropped, cv2.COLOR_BGR2RGB))
plt.axis('off')  # Hide axes
plt.show()

In [None]:
x1, y1, x2, y2 = bbox1
cv2.rectangle(first_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
plt.imshow(cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB))
plt.axis('off')  # Hide axes
plt.show()

In [None]:
cropped = first_frame[y:y+h, x:x+w]
cropped = cv2.resize(cropped, (dataset.size_template, dataset.size_template))
plt.imshow(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
plt.axis('off')  # Hide axes
plt.show()

In [None]:
cropped.shape