# Copy of the code that was run on Kaggle to train the model with few-shot classification.

### Paths assume Kaggle's input (`../input`) and working (`/kaggle/working`) directories

#### Installing the pytorchvideo and easyfsl packages used to train the model.

In [None]:
! pip install pytorchvideo
! pip install easyfsl

### Defining the data processing and loading functions, including a custom dataset class.

In [None]:
import torchvision 
from torch.utils.data import Dataset
import pytorchvideo

def compile_data(root, requested_split, split_file, num_classes):
    """Build the list of usable samples for one split of a WLASL-style split file.

    Args:
        root: Directory holding the ``<video_id>.mp4`` files.
        requested_split: ``"train"`` keeps entries whose ``"subset"`` is
            ``"train"`` or ``"val"``; any other value keeps ``"test"`` entries.
        split_file: Path to the JSON split file. Each value must be a dict
            with a ``"subset"`` string and an ``"action"`` list whose first
            three items are ``[class_label, first_frame, last_frame]``
            (frames presumably 1-based and inclusive, given the arithmetic
            below -- confirm against the dataset docs).
        num_classes: Unused; kept so existing callers keep working.

    Returns:
        ``(data, len(data))`` where ``data`` is a list of
        ``(video_id, class_label, start_frame, num_frames)`` tuples with
        ``start_frame = first_frame - 1`` and
        ``num_frames = last_frame - first_frame + 1``.

    Entries are skipped when the video file is missing, when the annotated
    segment is shorter than 16 frames, or when OpenCV reports fewer than 9
    frames for the video.
    """
    # Context manager so the split file handle is closed promptly.
    with open(split_file, "r") as f:
        labels = json.load(f)

    wanted_subsets = ("train", "val") if requested_split == "train" else ("test",)
    data = []

    for vid_id, entry in tqdm(labels.items()):
        if entry["subset"] not in wanted_subsets:
            continue

        path = os.path.join(root, vid_id + ".mp4")
        if not os.path.exists(path):
            continue

        action = entry["action"]
        cls_label = action[0]
        start_frame = action[1] - 1
        num_frames = action[2] - action[1] + 1
        # load_frames_from_video samples 16 frames per clip. Check this cheap
        # condition before paying for opening the video.
        if num_frames < 16:
            continue

        # Only the frame count is needed, so release the capture right away.
        cap = cv2.VideoCapture(path)
        frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.release()
        # NOTE(review): this 9-frame guard is looser than what the loader
        # needs (start_frame + sampled span); confirm it is intentional.
        if frames < 9:
            continue

        data.append((vid_id, cls_label, start_frame, num_frames))

    return data, len(data)
        

def load_frames_from_video(path, start_frame, num_frames, num_samples=16):
    """Read ``num_samples`` RGB frames spread evenly over a video segment.

    Args:
        path: Path to the video file.
        start_frame: 0-based index of the segment's first frame
            (``compile_data`` stores ``action[1] - 1``).
        num_frames: Length of the segment in frames.
        num_samples: Number of frames to return. Defaults to 16.

    Returns:
        A ``uint8`` tensor of shape ``(num_samples, H, W, 3)`` in RGB order
        (one ``cv2.cvtColor(..., COLOR_BGR2RGB)`` image per sampled frame).

    Raises:
        RuntimeError: If the video cannot be opened or runs out of frames
            before ``num_samples`` frames have been read.
    """
    # Stride between sampled frames. It is never 0, so short segments still
    # advance. When num_frames >= num_samples, every sample stays inside the
    # segment: (num_samples - 1) * interval <= num_frames - 1.
    interval = max(num_frames // num_samples, 1)

    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise RuntimeError(f"Could not open video {path!r}")
    try:
        # start_frame is already 0-based, so seek to it directly.
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        frames = []
        for sample_idx in range(num_samples):
            ok, frame = cap.read()
            if not ok:
                raise RuntimeError(
                    f"{path!r}: read only {sample_idx} of {num_samples} frames "
                    f"starting at frame {start_frame}"
                )
            frames.append(torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            # Step over the frames between samples. grab() advances without
            # returning the frame, so no repeated seeks are needed.
            if sample_idx < num_samples - 1:
                for _ in range(interval - 1):
                    cap.grab()
    finally:
        cap.release()
    return torch.stack(frames)

class ASLDataset(Dataset):
    """Sign-language video dataset yielding ``(clip, label)`` pairs.

    The sample list is built once, at construction, by ``compile_data``.
    ``self.labels`` and ``get_labels()`` expose the class label of every
    sample, in dataset order, for episodic samplers that group samples by
    class.
    """

    def __init__(self, root, requested_split, split_file, num_classes, transforms):
        """Collect samples for ``requested_split`` from ``split_file``.

        Args:
            root: Directory holding the ``<video_id>.mp4`` files.
            requested_split: Split name forwarded to ``compile_data``.
            split_file: Path to the JSON split file.
            num_classes: Stored as-is; forwarded to ``compile_data``.
            transforms: Callable applied to a ``{"video": tensor}`` dict in
                ``__getitem__``; must return a dict with a ``"video"`` key.
        """
        self.data, self.length = compile_data(root, requested_split, split_file, num_classes)
        self.root = root
        self.requested_split = requested_split
        self.split_file = split_file
        self.num_classes = num_classes
        self.do = transforms
        # Per-sample class labels, exposed for episodic samplers.
        self.labels = self.get_labels()

    def get_labels(self):
        """Return the class label of every sample, in dataset order."""
        return [sample[1] for sample in self.data]

    def __getitem__(self, index):
        """Return ``(transformed_clip, class_label)`` for sample ``index``."""
        vid_id, label, start_frame, num_frames = self.data[index]
        path = os.path.join(self.root, vid_id + ".mp4")

        # Frames come back as (T, H, W, C); reorder to (C, T, H, W).
        imgs = load_frames_from_video(path, start_frame, num_frames).permute(3, 0, 1, 2)
        sample = self.do({"video": imgs})
        return sample["video"], label

    def __len__(self):
        return self.length

### Defining the transforms for the train and test datasets.
##### Slightly modified from the transforms in the official PyTorch Hub X3D example.

In [None]:
# Transforms adapted from the official PyTorch Hub X3D example: https://pytorch.org/hub/facebookresearch_pytorchvideo_x3d/
from torchvision.transforms import Compose, Lambda, RandomCrop, RandomHorizontalFlip
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
    RandomCropVideo,
    RandomHorizontalFlipVideo
)
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample, 
    Normalize, 
    RandomShortSideScale,
)

# Per-channel normalization statistics used by the PyTorch Hub X3D example.
mean = [0.45, 0.45, 0.45]
std = [0.225, 0.225, 0.225]

transform_params = {
    "side_size": 256,
    "num_frames": 16,
    "crop_size": 256,
}

# Clips reach these transforms as uint8 (C, T, H, W) tensors that already hold
# 16 sampled frames, so no temporal subsampling is done here.
# Training: scale to [0, 1], apply the normalization the pretrained weights
# were trained with, then augment with a random short-side rescale, a random
# crop and a horizontal flip.
train_transform = ApplyTransformToKey(
    key="video",
    transform=Compose(
        [
            Lambda(lambda x: x / 255.0),
            NormalizeVideo(mean, std),
            RandomShortSideScale(min_size=transform_params["side_size"], max_size=320),
            RandomCrop(transform_params["crop_size"]),
            RandomHorizontalFlip(p=0.5),
        ]
    ),
)

# Evaluation: same scaling and normalization, then a deterministic
# short-side resize and a center crop.
test_transform = ApplyTransformToKey(
    key="video",
    transform=Compose(
        [
            Lambda(lambda x: x / 255.0),
            NormalizeVideo(mean, std),
            ShortSideScale(size=transform_params["side_size"]),
            CenterCropVideo(
                crop_size=(transform_params["crop_size"], transform_params["crop_size"])
            ),
        ]
    ),
)

### Importing and defining the model

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from easyfsl.data_tools import TaskSampler
import os
import cv2
import random
import json

model_name = "x3d_m"
backbone = torch.hub.load("facebookresearch/pytorchvideo:main", model_name, pretrained=True)

# Convert the last stage (blocks[5], the head with proj/activation/
# output_pool) into a feature extractor. Drop the projection and the
# activation, and flatten instead of the original output pooling.
# NOTE(review): the backbone is then expected to return (batch, features)
# embeddings -- confirm the shape against the pytorchvideo head.
head = backbone.blocks[5]
head.proj = nn.Identity()
head.activation = nn.Identity()
head.output_pool = nn.Flatten()

class Model(nn.Module):
    """Prototypical-network few-shot classifier over a feature backbone.

    Each class prototype is the mean backbone embedding of that class's
    support samples. Queries are scored by the negative Euclidean distance to
    every prototype, so the highest score marks the nearest prototype.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, support_imgs, support_labels, query_imgs):
        """Score each query against each class prototype.

        Args:
            support_imgs: Batch of support inputs for the backbone.
            support_labels: 1-D tensor of episode-local class ids, assumed to
                be ``0 .. n_classes - 1`` (as an episodic collate function
                typically produces -- confirm).
            query_imgs: Batch of query inputs for the backbone.

        Returns:
            ``(n_query, n_classes)`` tensor. Column ``k`` is
            ``-distance(query, prototype_k)``.
        """
        support_vector = self.backbone(support_imgs)
        query_vector = self.backbone(query_imgs)

        n_classes = len(torch.unique(support_labels))
        # Boolean-mask each class's support embeddings, average them, and
        # stack the per-class means into an (n_classes, D) prototype matrix.
        prototypes = torch.stack(
            [support_vector[support_labels == label].mean(0) for label in range(n_classes)]
        )
        return -torch.cdist(query_vector, prototypes)
    
# Build the few-shot wrapper around the modified backbone and place it on the GPU.
device = "cuda"
model = Model(backbone).to(device)

### Initializing the dataset and dataloader

In [None]:
# Episode configuration passed to the TaskSampler: number of episodes, plus
# classes, support samples per class and query samples per class per episode.
n_train_episodes = 10000
n_ways = 2
n_shots = 2
n_queries = 2

root = "../input/wlasl-processed/videos"
split_file = "../input/wlasl-processed/nslt_100.json"
num_classes = 100

# Training uses the augmenting train_transform (random rescale/crop/flip).
# The deterministic test_transform is meant for evaluation.
train_dataset = ASLDataset(root, "train", split_file, num_classes, train_transform)

# The TaskSampler below groups samples by per-sample class label read from
# `dataset.labels` (with the easyfsl version imported here -- confirm).
train_dataset.labels = [sample[1] for sample in train_dataset.data]

train_sampler = TaskSampler(
    train_dataset, n_way=n_ways, n_shot=n_shots, n_query=n_queries, n_tasks=n_train_episodes
)
train_loader = DataLoader(
    train_dataset,
    batch_sampler=train_sampler,
    num_workers=2,
    pin_memory=False,
    collate_fn=train_sampler.episodic_collate_fn,
)

### Training the model

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model.train()

# One optimizer step per episode. Query predictions are the negative
# prototype distances from Model.forward, scored with cross-entropy.
for step_idx, (support_imgs, support_labels, query_imgs, query_labels, _) in enumerate(
    tqdm(train_loader)
):
    # Use the same device the model was moved to, rather than hard-coding CUDA.
    support_imgs = support_imgs.to(device)
    support_labels = support_labels.to(device)
    query_imgs = query_imgs.to(device)
    query_labels = query_labels.to(device)

    optimizer.zero_grad()
    scores = model(support_imgs, support_labels, query_imgs)
    loss = criterion(scores, query_labels)
    loss.backward()
    optimizer.step()

    # Log and checkpoint every 100 episodes, including the first.
    if step_idx % 100 == 0:
        print(f"Step: {step_idx}, Loss: {loss.item():.4f}")
        torch.save(model.state_dict(), f"/kaggle/working/FewShotLearning Step {step_idx}")