# This is a copy of the code that I ran in Kaggle in order to train the model.

### Here I am installing the pytorchvideo and easyfsl packages, which help me train the model.

In [None]:
! pip install pytorchvideo
! pip install easyfsl

### Here I am defining the data processing and loading functions. I am also defining the ASLDataset class for later use.

import torchvision 
from torch.utils.data import Dataset
import pytorchvideo

def compile_data(root, requested_split, split_file, num_classes):
    """Build the list of usable video clips for one dataset split.

    Reads the JSON annotation file ``split_file`` and keeps every entry that
    belongs to the requested split, whose ``<root>/<video id>.mp4`` file
    exists, whose container reports at least 9 frames, and whose annotated
    clip spans at least 16 frames.

    Args:
        root: Directory containing the ``.mp4`` files.
        requested_split: ``"train"`` keeps entries whose ``"subset"`` is
            ``"train"`` or ``"val"``; any other value keeps ``"test"`` entries.
        split_file: Path to a JSON file mapping each video id to a dict with
            ``"subset"`` and ``"action"`` keys. ``"action"`` is indexed as
            ``[class label, first frame, last frame]``.
        num_classes: Not used by this function; kept so existing callers
            keep working.

    Returns:
        A ``(data, count)`` tuple where ``data`` is a list of
        ``(video id, class label, start_frame, num_frames)`` tuples and
        ``count`` is ``len(data)``.
    """
    # Context manager so the annotation file is closed as soon as it is parsed.
    with open(split_file, "r") as split_handle:
        labels = json.load(split_handle)

    if requested_split == "train":
        wanted_subsets = ("train", "val")
    else:
        wanted_subsets = ("test",)

    data = []

    for vid_id, entry in tqdm(labels.items()):
        if entry["subset"] not in wanted_subsets:
            continue

        path = os.path.join(root, vid_id + ".mp4")
        if not os.path.exists(path):
            continue

        # Release the capture explicitly: one decoder handle per video would
        # otherwise stay open until garbage collection.
        capture = cv2.VideoCapture(path)
        try:
            frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            capture.release()

        if frames < 9:
            continue

        action = entry["action"]
        cls_label = action[0]

        # The annotated first frame is shifted down by one; the clip length is
        # inclusive of both the first and the last annotated frame.
        start_frame = action[1] - 1
        num_frames = action[2] - action[1] + 1

        # The loader samples 16 frames per clip, so shorter clips are dropped.
        if num_frames < 16:
            continue

        data.append((vid_id, cls_label, start_frame, num_frames))

    return data, len(data)
        

def load_frames_from_video(path, start_frame, num_frames):
    """Load 16 RGB frames spread evenly across one annotated clip.

    Decoding starts at frame position ``start_frame + 1`` and keeps one frame
    every ``num_frames // 16`` frames (at least every frame), so the samples
    cover the clip instead of only its first 16 frames.

    Args:
        path: Path of the video file to decode.
        start_frame: Clip start as produced by ``compile_data``.
        num_frames: Number of frames in the annotated clip.

    Returns:
        A tensor stacking the 16 sampled frames along a new first dimension.
        Each frame is the decoded image converted from BGR to RGB
        (presumably height x width x channels -- confirm against the caller).

    Raises:
        RuntimeError: If not a single frame could be decoded from the
            requested position.
    """
    samples = 16
    # Stride between kept frames; never below 1 so short clips still advance.
    interval = max(num_frames // samples, 1)

    cap = cv2.VideoCapture(path)
    frames = []
    try:
        # NOTE(review): compile_data already subtracts 1 from the annotated
        # first frame and this adds it back -- confirm which indexing the
        # annotations use. Kept as in the original to avoid shifting clips.
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame + 1)
        for _ in range(samples):
            ret, frame = cap.read()
            if not ret:
                # Ran past the end of the stream (annotations or the reported
                # frame count can overstate the real length).
                break
            frames.append(torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            # Skip the frames between two samples without converting them.
            for _ in range(interval - 1):
                cap.grab()
    finally:
        cap.release()

    if not frames:
        raise RuntimeError(
            f"Could not decode any frame from {path!r} starting at frame {start_frame + 1}"
        )

    # Pad a truncated clip by repeating its last decoded frame so the output
    # always holds exactly 16 frames.
    while len(frames) < samples:
        frames.append(frames[-1])

    return torch.stack(frames)

class ASLDataset(Dataset):
    """Dataset yielding ``(video tensor, label)`` pairs for one split.

    The clip index is built once at construction time by ``compile_data``;
    frames are decoded lazily in ``__getitem__``.

    Args:
        root: Directory containing the ``.mp4`` files.
        requested_split: Split name forwarded to ``compile_data``.
        split_file: Path to the JSON annotation file.
        num_classes: Number of classes; stored and forwarded to
            ``compile_data``.
        transforms: Callable applied to a ``{"video": tensor}`` dict; it must
            return a dict that still has a ``"video"`` key.
    """

    def __init__(self, root, requested_split, split_file, num_classes, transforms):
        self.root = root
        self.requested_split = requested_split
        self.split_file = split_file
        self.num_classes = num_classes
        # The transform callable is stored under the attribute name ``do``.
        self.do = transforms

        clips, clip_count = compile_data(root, requested_split, split_file, num_classes)
        self.data = clips
        self.length = clip_count

    def __len__(self):
        """Return the number of clips found by ``compile_data``."""
        return self.length

    def __getitem__(self, index):
        """Decode, reorder and transform the clip at ``index``.

        Returns:
            A ``(video, label)`` tuple: the transformed ``"video"`` entry and
            the class label stored for this clip.
        """
        clip_id, target, first_frame, span = self.data[index]
        video_file = os.path.join(self.root, clip_id + ".mp4")

        stacked = load_frames_from_video(video_file, first_frame, span)
        # Move the last axis of the stacked frames to the front
        # (presumably channels -- confirm against the loader).
        sample = {"video": stacked.permute(3, 0, 1, 2)}
        sample = self.do(sample)

        return sample["video"], target

### Here I am defining the transforms for the train and test datasets. They are slightly modified versions of the transforms from the official PyTorch example.

In [None]:
# Transforms derived from the PyTorch official X3D training examples: https://pytorch.org/hub/facebookresearch_pytorchvideo_x3d/
from torchvision.transforms import Compose, Lambda, RandomCrop, RandomHorizontalFlip
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
    RandomCropVideo,
    RandomHorizontalFlipVideo
)
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample, 
    Normalize, 
    RandomShortSideScale,
)

# Per-channel normalisation statistics (same value for all three channels).
# In this section they are only referenced by commented-out Normalize steps.
mean = [0.45] * 3
std = [0.225] * 3

# Shared sizes for the video transforms defined below.
transform_params = dict(
    side_size=256,
    num_frames=16,
    crop_size=256,
)

# Training-time pipeline, applied to the "video" entry of each sample dict.
# The scale and crop sizes are read from transform_params so that they stay in
# step with test_transform instead of being repeated as literals.
train_transform = ApplyTransformToKey(
    key="video",
    transform=Compose(
        [
            # Temporal subsampling (UniformTemporalSubsample) is left disabled
            # here; the loader already returns a fixed number of frames.
            # Divide pixel values by 255.
            Lambda(lambda x: x / 255.0),
            # Mean/std normalisation is left disabled:
            # Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
            # Augmentation: random rescale, random crop, random horizontal flip.
            RandomShortSideScale(
                min_size=transform_params["side_size"],
                max_size=320,
            ),
            RandomCrop(transform_params["crop_size"]),
            RandomHorizontalFlip(p=0.5),
        ]
    ),
)

# Evaluation-time pipeline, applied to the "video" entry of each sample dict:
# divide pixel values by 255, rescale the short side, then take a centre crop.
# Temporal subsampling and mean/std normalisation are left disabled.
test_transform = ApplyTransformToKey(
    key="video",
    transform=Compose([
        Lambda(lambda x: x / 255.0),
        ShortSideScale(size=transform_params["side_size"]),
        # Square crop: the same edge length for both dimensions.
        CenterCropVideo(crop_size=(transform_params["crop_size"],) * 2),
    ]),
)