In [1]:
import random
import gc
import os

from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from make_a_video_pytorch import SpaceTimeUnet
from IPython.display import display, HTML
from IPython.display import Image as Img
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import numpy as np
import torch
import cv2

In [2]:
# Collect unreachable Python objects first so any CUDA tensors they own are
# freed, then hand the now-unused cached blocks back to the driver.
# The collected-object count is left as the cell's displayed output.
n_collected = gc.collect()
torch.cuda.empty_cache()
n_collected

10

In [3]:
SEP = os.path.sep
# Parent directory of the notebook's working directory: everything before the
# last path separator of os.getcwd().
ROOT_PATH = os.getcwd().rpartition(SEP)[0]
DATA_PATH = f'{ROOT_PATH}/DataSets/MAV/UCF101'

In [4]:
# These only take effect if CUDA has not yet been initialised in this kernel,
# so keep this cell ahead of any cell that touches a GPU.
os.environ["CUDA_DEVICE_ORDER"]    = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"   # host-specific GPU ids

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'             DEVICE : {DEVICE}')
if DEVICE.type == 'cuda':
    # current_device() raises when CUDA is unavailable, so only query it on GPU hosts.
    print(f'Current cuda DEVICE : {torch.cuda.current_device()}')
print(f'Count of using GPUs : {torch.cuda.device_count()}')

             DEVICE : cuda
Current cuda DEVICE : 0
Count of using GPUs : 2




In [5]:
def plot_sequence_images(images, interval = 90):
    """Render a sequence of frames inline as an HTML5 video.

    Parameters
    ----------
    images : sequence of arrays
        Frames indexed as ``images[i]``; the first two axes of each frame are
        (height, width). Frames are drawn exactly as given, so colour frames
        must already be in RGB order.
    interval : int, default 90
        Delay between frames in milliseconds.
    """
    dpi  = 72.0
    H, W = images[0].shape[:2]

    # figsize is (width, height) in inches; W/dpi x H/dpi gives one figure
    # pixel per image pixel.
    fig   = plt.figure(figsize = (W / dpi, H / dpi), dpi = dpi)
    image = plt.figimage(images[0])

    def animate(idx):
        image.set_array(images[idx])
        return (image, )

    anim = animation.FuncAnimation(fig, animate, frames = len(images),
                                   interval = interval, repeat_delay = 1, repeat = True)

    try:
        display(HTML(anim.to_html5_video()))
    finally:
        # Close the figure so Jupyter does not also render the empty static
        # canvas, and so repeated calls don't accumulate open figures.
        plt.close(fig)

In [6]:
# Sort so the file order (and the seeded train/test split built from it)
# does not depend on the filesystem's directory-listing order.
video_paths = sorted(os.listdir(DATA_PATH))
# randrange(n) yields 0..n-1, so the index is always valid (randint's upper
# bound is inclusive).
random_idx  = random.randrange(len(video_paths))
sample_path = f'{DATA_PATH}/{video_paths[random_idx]}'
sample_path

'/home/jovyan/dove/projects/DataSets/MAV/UCF101/v_SoccerJuggling_g17_c05.avi'

In [7]:
cap = cv2.VideoCapture(sample_path)
# VideoCapture does not raise on a bad path; it only reports "not opened",
# which would otherwise surface later as an empty frame list.
if not cap.isOpened():
    raise IOError(f'Could not open video: {sample_path}')

In [8]:
# Decode every frame of the opened clip, resized to 128x128 and converted to
# RGB for matplotlib display. The capture is released even if decoding fails.
frames = []
try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        frame = cv2.resize(frame, (128, 128))
        # OpenCV decodes to BGR; matplotlib interprets 3-channel arrays as RGB.
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
finally:
    cap.release()

In [9]:
# Report the clip length, then play it back inline.
n_frames = len(frames)
print(n_frames)
plot_sequence_images(frames)

300


<Figure size 128x128 with 0 Axes>

In [10]:
# 80/20 split of clip file names; fixed random_state keeps it reproducible.
train_paths, test_paths = train_test_split(
    video_paths,
    test_size    = 0.2,
    shuffle      = True,
    random_state = 99,
)

# Collect garbage first so CUDA memory held by unreachable objects is freed
# before the allocator cache is emptied; show the collected-object count.
n_collected = gc.collect()
torch.cuda.empty_cache()
n_collected

166

In [11]:
class UCF101(Dataset):
    """UCF101 clips decoded eagerly into memory as lists of RGB uint8 frames.

    Parameters
    ----------
    video_paths : iterable of str
        Clip file names (e.g. as returned by ``os.listdir``) or absolute paths.
        The action label is parsed from the file name, e.g.
        ``v_SoccerJuggling_g17_c05.avi`` -> ``'SoccerJuggling'``.
    input_size : int
        Not used by this class; kept for backward compatibility.
    root : str, optional
        Directory that relative ``video_paths`` are resolved against.
        Defaults to the notebook-level ``DATA_PATH``.
    frame_size : int
        Every frame is resized to ``frame_size`` x ``frame_size``.
    num_frames : int, optional
        If given, keep at most this many leading frames per clip. Clip lengths
        are not guaranteed to match, and the default ``DataLoader`` collate
        (batch_size > 1) needs equal-length clips to stack them.

    Raises
    ------
    IOError
        If a clip cannot be opened.
    ValueError
        If a clip yields no decodable frames.
    """

    def __init__(self, video_paths, input_size = 32, root = None,
                 frame_size = 128, num_frames = None):

        if root is None:
            root = DATA_PATH

        self.videos, self.labels = [], []
        for video_path in video_paths:
            # os.path.join leaves an absolute video_path untouched.
            video = self._read_video(os.path.join(root, video_path),
                                     frame_size, num_frames)
            self.labels.append(self._label_from_path(video_path))
            self.videos.append(video)

    @staticmethod
    def _label_from_path(video_path):
        """Return the action name from a ``v_<Action>_gXX_cYY.avi`` file name."""
        return os.path.basename(video_path).split('_')[1]

    @staticmethod
    def _read_video(path, frame_size, num_frames = None):
        """Decode ``path`` into a list of (frame_size, frame_size, 3) RGB frames."""
        cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            raise IOError(f'Could not open video: {path}')

        frames = []
        try:
            while num_frames is None or len(frames) < num_frames:
                ret, frame = cap.read()
                if not ret:
                    break
                frame = cv2.resize(frame, (frame_size, frame_size))
                # OpenCV decodes to BGR; store RGB so display and model input agree.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            cap.release()

        if not frames:
            raise ValueError(f'No decodable frames in video: {path}')
        return frames

    def __len__(self): return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(video, label)`` with video as float32 in [0, 1].

        The stored frames are (T, H, W, C); they are transposed to (C, T, H, W)
        and given an extra leading axis, i.e. (1, C, T, H, W).
        """
        # NOTE(review): the leading unsqueeze makes DataLoader batches
        # (B, 1, C, T, H, W); confirm that is what the downstream model expects.
        video = np.array(self.videos[idx]).transpose(3, 0, 1, 2)
        video = torch.tensor(video, dtype = torch.float32).unsqueeze(0) / 255
        label = self.labels[idx]

        return video, label

In [None]:
# Wrap the training clips in a shuffled loader of 4-clip batches.
train_dataset = UCF101(video_paths = train_paths)
train_loader  = DataLoader(dataset = train_dataset, batch_size = 4, shuffle = True)

In [None]:
# Pull one batch and play back its first clip.
sample = next(iter(train_loader))
videos, labels = sample

first_clip = videos[0].squeeze().detach().cpu().numpy()
# (C, T, H, W) -> (T, H, W, C): one displayable frame per leading index.
video = np.transpose(first_clip, (1, 2, 3, 0))
plot_sequence_images(video)

In [None]:
# Space-time U-Net from make_a_video_pytorch. The 4-tuples appear to be
# per-stage flags aligned with dim_mult (presumably); here temporal
# compression and self-attention are switched on only for the last stage.
unet = SpaceTimeUnet(
    dim                   = 64,
    channels              = 3,
    dim_mult              = (1, 2, 4, 8),
    resnet_block_depths   = (1, 1, 1, 2),
    temporal_compression  = (False, False, False, True),
    self_attns            = (False, False, False, True),
    condition_on_timestep = False,
    attn_pos_bias         = False,
    flash_attn            = True,
)
unet = unet.to(DEVICE)