Merge pull request #2 from RaivoKoot/frame_ranges
Changed NUM_FRAMES in annotations.txt to START and END frame for usin…
RaivoKoot committed Dec 9, 2020
2 parents caecde6 + 3e6c74d commit dcbd456
Showing 4 changed files with 70 additions and 42 deletions.
README.md (48 changes: 31 additions & 17 deletions)
@@ -28,7 +28,7 @@ For a demo, visit `demo.py`.
### QuickDemo (demo.py)
```python
root = os.path.join(os.getcwd(), 'demo_dataset') # Folder in which all videos lie in a specific structure
-annotation_file = os.path.join(root, 'annotations.txt') # A row for each video sample as: (VIDEO_PATH NUM_FRAMES CLASS_INDEX)
+annotation_file = os.path.join(root, 'annotations.txt') # A row for each video sample as: (VIDEO_PATH START_FRAME END_FRAME CLASS_INDEX)

""" DEMO 1 WITHOUT IMAGE TRANSFORMS """
dataset = VideoFrameDataset(
@@ -73,12 +73,13 @@ python >= 3.6
### 2. Custom Dataset
To use any dataset, two conditions must be met.
1) The video data must be supplied as RGB frames, each frame saved as an image file. Each video must have its own folder, in which the frames of
-that video lie. The frames of a video inside its folder must be named uniformly as `img_00001.jpg` ... `img_00120.jpg`, if there are 120 frames. The filename template
-for frames is then "img_{:05d}.jpg" (python string formatting, specifying 5 digits after the underscore), and must be supplied to the
-constructor of VideoFrameDataset as a parameter. Each video folder lies inside a `root` folder of this dataset.
+that video lie. The frames of a video inside its folder must be named uniformly with consecutive indices, such as `img_00001.jpg` ... `img_00120.jpg` if there are 120 frames.
+Indices can start at zero or at any other number, and the exact filename template can be chosen freely. The filename template
+for frames in this example is "img_{:05d}.jpg" (Python string formatting, specifying 5 digits after the underscore; see the short example after this list), and must be supplied to the
+constructor of VideoFrameDataset as a parameter. Each video folder must lie inside some `root` folder.
2) To enumerate all video samples in the dataset and their required metadata, a `.txt` annotation file must be manually created that contains a row for each
-video sample in the dataset. The training, validation, and testing datasets must have separate annotation files. Each row must be a space-separated list that contains
-`VIDEO_PATH NUM_FRAMES CLASS_INDEX`. The `VIDEO_PATH` of a video sample should be provided without the `root` prefix of this dataset.
+video sample or video clip (for example, multiple clips per video for action recognition) in the dataset. The training, validation, and testing datasets must have separate annotation files. Each row must be a space-separated list that contains
+`VIDEO_PATH START_FRAME END_FRAME CLASS_INDEX`. The `VIDEO_PATH` of a video sample should be provided without the `root` prefix of this dataset.
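
For example, the frame filename template from 1) expands like this:
```python
template = "img_{:05d}.jpg"
print(template.format(1))    # img_00001.jpg
print(template.format(120))  # img_00120.jpg
```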

This example project demonstrates this using a dummy dataset inside of `demo_dataset/`, which is the `root` dataset folder of this example. The folder
structure looks as follows:
Expand Down Expand Up @@ -108,19 +109,30 @@ demo_dataset
```
-The accompanying annotation `.txt` file contains the following rows
+The accompanying annotation `.txt` file contains the following rows (PATH, START_FRAME, END_FRAME, LABEL_ID)
```
-jumping/0001 17 0
-jumping/0002 18 0
-running/0001 15 1
-running/0002 15 1
+jumping/0001 1 17 0
+jumping/0002 1 18 0
+running/0001 1 15 1
+running/0002 1 15 1
```
+Another annotation file that uses multiple clips from each video could look like this
+```
+jumping/0001 1 8 0
+jumping/0001 5 17 0
+jumping/0002 1 18 0
+running/0001 10 15 1
+running/0001 5 10 1
+running/0002 1 15 1
+```
+(END_FRAME is inclusive)
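
Since each row is a space-separated line, it can be parsed with plain string splitting. A minimal sketch (this helper is illustrative and not part of the library):
```python
# Illustrative helper: split one annotation row into its four fields.
def parse_annotation_row(row):
    path, start_frame, end_frame, label = row.split()
    return path, int(start_frame), int(end_frame), int(label)

print(parse_annotation_row("jumping/0001 1 17 0"))  # ('jumping/0001', 1, 17, 0)
```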

Instantiating a VideoFrameDataset with the `root_path` parameter pointing to `demo_dataset`, the `annotationsfile_path` parameter pointing to the annotation file, and
the `imagefile_template` parameter as "img_{:05d}.jpg", is all that it takes to start using the VideoFrameDataset class.
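
A minimal sketch of that setup (keyword names follow the parameter names quoted in this README; `num_segments` and `frames_per_segment` values are illustrative):
```python
import os
from video_dataset import VideoFrameDataset

root = os.path.join(os.getcwd(), 'demo_dataset')
annotation_file = os.path.join(root, 'annotations.txt')

# Keyword names below follow this README and may need adjusting
# to the exact constructor signature in video_dataset.py.
dataset = VideoFrameDataset(
    root_path=root,
    annotationsfile_path=annotation_file,
    imagefile_template='img_{:05d}.jpg',
    num_segments=3,        # illustrative
    frames_per_segment=1,  # illustrative
)
```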

### 3. Video Frame Sampling Method
When loading a video, only a subset of its frames is loaded. They are chosen in the following way:
-1. The frame indices [1,N] are divided into NUM_SEGMENTS even segments. From each segment, a random start-index is sampled from which FRAMES_PER_SEGMENT consecutive indices are loaded.
+1. The frame index range [START_FRAME, END_FRAME] is divided into NUM_SEGMENTS even segments. From each segment, a random start index is sampled, from which FRAMES_PER_SEGMENT consecutive indices are loaded.
This results in NUM_SEGMENTS*FRAMES_PER_SEGMENT chosen indices, whose frames are loaded as PIL images and put into a list and returned when calling
`dataset[i]`.
![alt text](https://github.com/RaivoKoot/images/blob/main/Sparse_Temporal_Sampling.jpg "Sparse-Temporal-Sampling-Strategy")
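
A rough sketch of this sampling strategy in plain numpy (illustrative only; it assumes the video is long enough for every segment to be non-empty and skips the short-video edge cases that the class handles):
```python
import numpy as np

def sample_frame_indices(start_frame, end_frame, num_segments, frames_per_segment):
    # Number of frames in the clip; END_FRAME is inclusive.
    num_frames = end_frame - start_frame + 1
    # Length of each segment, measured in valid start positions
    # (assumes the video is long enough that this is > 0).
    segment_duration = (num_frames - frames_per_segment + 1) // num_segments
    # One random start offset inside each of the even segments.
    starts = np.arange(num_segments) * segment_duration \
             + np.random.randint(segment_duration, size=num_segments)
    # Expand each start into FRAMES_PER_SEGMENT consecutive indices,
    # then shift from segment offsets into absolute frame ids.
    return (starts[:, None] + np.arange(frames_per_segment)).ravel() + start_frame

# e.g. a 15-frame video starting at frame 1, NUM_SEGMENTS=3, FRAMES_PER_SEGMENT=1
print(sample_frame_indices(1, 15, num_segments=3, frames_per_segment=1))
```
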
@@ -129,27 +141,29 @@ This results in NUM_SEGMENTS*FRAMES_PER_SEGMENT chosen indices, whose frames are
If you do not want to use sparse temporal sampling and instead want to sample a single N-frame continuous
clip from a video, this is possible. Set `NUM_SEGMENTS=1` and `FRAMES_PER_SEGMENT=N`. Because VideoFrameDataset
will choose a random start index per segment and take `FRAMES_PER_SEGMENT` consecutive frames from each sampled start
-index, this will result in a single N-frame continuous clip per video. An example of this is in `demo.py`.
+index, this will result in a single N-frame continuous clip per video that starts at a random index.
+An example of this is in `demo.py`.
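
For instance (a sketch, with the same illustrative keyword names as above):
```python
import os
from video_dataset import VideoFrameDataset

root = os.path.join(os.getcwd(), 'demo_dataset')
annotation_file = os.path.join(root, 'annotations.txt')

# NUM_SEGMENTS=1 and FRAMES_PER_SEGMENT=9 give one continuous
# 9-frame clip per video, starting at a random index.
dataset = VideoFrameDataset(
    root_path=root,
    annotationsfile_path=annotation_file,
    imagefile_template='img_{:05d}.jpg',
    num_segments=1,
    frames_per_segment=9,
)
```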

### 5. Using VideoFrameDataset for training
As demonstrated in `demo.py`, we can use PyTorch's `torch.utils.data.DataLoader` class with VideoFrameDataset to take care of shuffling, batching, and more.
To turn the lists of PIL images returned by VideoFrameDataset into tensors, the transform `video_dataset.ImglistToTensor()` can be supplied
as the `transform` parameter to VideoFrameDataset. This turns a list of N PIL images into a batch of images/frames of shape `N x CHANNELS x HEIGHT x WIDTH`.
-We can further chain preprocessing and augmentation functions that act on batches of images onto the end of `ImglistToTensor()`.
+We can further chain preprocessing and augmentation functions that act on batches of images onto the end of `ImglistToTensor()`, as seen in `demo.py`.

As of `torchvision 0.8.0`, all torchvision transforms can now also operate on batches of images, and they apply deterministic or random transformations
-on the batch identically on all images of the batch. Therefore, any torchvision transform can be used here to apply video-uniform preprocessing and augmentation.
+on the batch identically on all images of the batch. Because a single video tensor (FRAMES x CHANNELS x HEIGHT x WIDTH)
+has the same shape as an image batch tensor (BATCH x CHANNELS x HEIGHT x WIDTH), any torchvision transform can be used here to apply video-uniform preprocessing and augmentation.

REMEMBER:
-Pytorch transforms are applied to individual dataset samples (in this case a video frame PIL list, or a frame tensor after `ImglistToTensor()`) before
+PyTorch transforms are applied to individual dataset samples (in this case a list of PIL images of a video, or a video-frame tensor after `ImglistToTensor()`) before
batching. So, any transform used here must expect its input to be a frame tensor of shape `FRAMES x CHANNELS x HEIGHT x WIDTH` or a list of PIL images if `ImglistToTensor()` is not used.
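
A minimal sketch of such a pipeline (image sizes and loader settings are illustrative; keyword names follow this README):
```python
import torch
from torchvision import transforms
from video_dataset import VideoFrameDataset, ImglistToTensor

preprocess = transforms.Compose([
    ImglistToTensor(),       # list of PIL images -> FRAMES x CHANNELS x HEIGHT x WIDTH tensor
    transforms.Resize(128),  # applied identically to every frame of the video tensor
    transforms.CenterCrop(112),
])

dataset = VideoFrameDataset(
    root_path='demo_dataset',
    annotationsfile_path='demo_dataset/annotations.txt',
    imagefile_template='img_{:05d}.jpg',
    transform=preprocess,
)

dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=2,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

for video_batch, labels in dataloader:
    # video_batch has shape BATCH x FRAMES x CHANNELS x HEIGHT x WIDTH
    break
```
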
### 6. Conclusion
A proper code-based explanation of how to use VideoFrameDataset for training is provided in `demo.py`.

### 7. Upcoming Features
- [x] Add demo for sampling a single continuous-frame clip from videos.
- [ ] Add support for arbitrary labels that are more than just a single integer.
-- [ ] Add support for specifying START_FRAME and END_FRAME for a video instead of NUM_FRAMES.
+- [x] Add support for specifying START_FRAME and END_FRAME for a video instead of NUM_FRAMES.

### 8. Acknowledgements
We thank the authors of TSN for their [codebase](https://github.com/yjxiong/tsn-pytorch), from which we took VideoFrameDataset and adapted it
demo.py (2 changes: 1 addition & 1 deletion)
@@ -136,7 +136,7 @@ def denormalize(video_tensor):
dataset=dataset,
batch_size=2,
shuffle=True,
-num_workers=8,
+num_workers=4,
pin_memory=True
)

demo_dataset/annotations.txt (8 changes: 4 additions & 4 deletions)
@@ -1,4 +1,4 @@
-jumping/0001 17 0
-jumping/0002 18 0
-running/0001 15 1
-running/0002 15 1
+jumping/0001 1 17 0
+jumping/0002 1 18 0
+running/0001 1 15 1
+running/0002 1 15 1
video_dataset.py (54 changes: 34 additions & 20 deletions)
@@ -4,7 +4,6 @@
from PIL import Image
from torchvision import transforms
import torch
-from collections.abc import Callable

class VideoRecord(object):
"""
@@ -14,10 +13,11 @@ class VideoRecord(object):
Args:
root_datapath: the system path to the root folder
of the videos.
-row: A list with three elements where 1) The first
+row: A list with four elements where 1) The first
element is the path to the video sample's frames excluding
-the root_datapath prefix 2) The second element is the number
-of frames in the video 3) The third element is the label index.
+the root_datapath prefix 2) The second element is the starting frame id of the video
+3) The third element is the inclusive ending frame id of the video
+4) The fourth element is the label index.
"""
def __init__(self, row, root_datapath):
self._data = row
@@ -30,11 +30,17 @@ def path(self):

@property
def num_frames(self):
+return self.end_frame() - self.start_frame() + 1 # +1 because end frame is inclusive

+def start_frame(self):
return int(self._data[1])

+def end_frame(self):
+return int(self._data[2])

@property
def label(self):
-return int(self._data[2])
+return int(self._data[3])

class VideoFrameDataset(torch.utils.data.Dataset):
r"""
@@ -46,8 +52,8 @@ class VideoFrameDataset(torch.utils.data.Dataset):
tensors where FRAMES=x if the ``ImglistToTensor()``
transform is used.
-More specifically, the frame range [0,N] is divided into NUM_SEGMENTS
-segments and FRAMES_PER_SEGMENT frames are taken from each segment.
+More specifically, the frame range [START_FRAME, END_FRAME] is divided into NUM_SEGMENTS
+segments and FRAMES_PER_SEGMENT consecutive frames are taken from each segment.
Note:
A demonstration of using this class can be seen
@@ -65,11 +71,11 @@ class VideoFrameDataset(torch.utils.data.Dataset):
inside a ``ROOT_DATA`` folder, each video lies in its own folder,
where each video folder contains the frames of the video as
individual files with a naming convention such as
-img_001.jpg ... img_059.jpg. Numbering must start at 1.
+img_001.jpg ... img_059.jpg.
For enumeration and annotations, this class expects to receive
-the path to a .txt file where each video sample has a row with three
+the path to a .txt file where each video sample has a row with four
space separated values:
-``VIDEO_FOLDER_PATH NUM_FRAMES LABEL_INDEX``.
+``VIDEO_FOLDER_PATH START_FRAME END_FRAME LABEL_INDEX``.
``VIDEO_FOLDER_PATH`` is expected to be the path of a video folder
excluding the ``ROOT_DATA`` prefix. For example, ``ROOT_DATA`` might
be ``home\data\datasetxyz\videos\``, inside of which a ``VIDEO_FOLDER_PATH``
@@ -138,16 +144,16 @@ def _sample_indices(self, record):
segment are to be loaded from.
"""

-average_duration = (record.num_frames - self.frames_per_segment + 1) // self.num_segments
-if average_duration > 0:
-offsets = np.multiply(list(range(self.num_segments)), average_duration) + np.random.randint(average_duration, size=self.num_segments)
+segment_duration = (record.num_frames - self.frames_per_segment + 1) // self.num_segments
+if segment_duration > 0:
+offsets = np.multiply(list(range(self.num_segments)), segment_duration) + np.random.randint(segment_duration, size=self.num_segments)

-# edge cases for when a video only has a tiny number of frames.
-elif record.num_frames > self.num_segments:
-offsets = np.sort(np.random.randint(record.num_frames - self.frames_per_segment + 1, size=self.num_segments))
+# edge case for when a video has fewer than approximately (num_segments*frames_per_segment) frames.
+# random sampling in that case, which will lead to repeated frames.
else:
-offsets = np.zeros((self.num_segments,))
-return offsets + 1
+offsets = np.sort(np.random.randint(record.num_frames, size=self.num_segments))

+return offsets

def _get_val_indices(self, record):
"""
@@ -163,7 +169,8 @@ def _get_val_indices(self, record):

# edge case for when a video does not have enough frames
else:
-offsets = np.zeros((self.num_segments,)) + 1
+offsets = np.sort(np.random.randint(record.num_frames, size=self.num_segments))

return offsets

def _get_test_indices(self, record):
@@ -180,7 +187,7 @@ def _get_test_indices(self, record):

offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])

-return offsets + 1
+return offsets

def __getitem__(self, index):
"""
@@ -218,15 +225,22 @@ def _get(self, record, indices):
2) An integer denoting the video label.
"""

+indices = indices + record.start_frame()
images = list()
+image_indices = list()
for seg_ind in indices:
frame_index = int(seg_ind)
for i in range(self.frames_per_segment):
seg_img = self._load_image(record.path, frame_index)
images.extend(seg_img)
+image_indices.append(frame_index)
if frame_index < record.num_frames:
frame_index += 1

+# sort images by index in case of edge cases where segments overlap each other because the overall
+# video is too short for num_segments*frames_per_segment indices.
+_, images = (list(sorted_list) for sorted_list in zip(*sorted(zip(image_indices, images))))

if self.transform is not None:
images = self.transform(images)

