In [None]:
import pandas as pd
from typing import Dict
import json
import urllib
from torchvision.transforms import Compose, Lambda, Resize
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
)
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample,
    UniformCropVideo
) 
from torch.utils.data import Dataset, DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.optimization import AdamW, get_cosine_schedule_with_warmup

import random
import os
import numpy as np

In [None]:
# --- Preprocessing configuration -------------------------------------------
# Target height/width handed to Resize in the transform cell.
side_size = 256

# Per-channel statistics handed to NormalizeVideo in the transform cell.
mean = [0.45] * 3
std = [0.225] * 3

# Number of frames kept by UniformTemporalSubsample.
num_frames = 32

# Ratio used by PackPathway: the slow pathway keeps T // slowfast_alpha frames.
slowfast_alpha = 4

# The values below are not referenced by any cell visible in this notebook
# section; they are kept because other code may read them.
crop_size = 256
sampling_rate = 1
frames_per_second = 30
num_clips = 10
num_crops = 3

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

In [None]:
# Fetch the SlowFast-R50 model (with its pretrained weights) from the
# PyTorchVideo hub entry point.
model = torch.hub.load('facebookresearch/pytorchvideo', 'slowfast_r50', pretrained=True)

# Replace the final projection with a 3-way classifier. The training cell uses
# the labels 0 = normal, 1 = assult, 2 = swoon.
model.blocks[6].proj = nn.Linear(2304, 3)

# Resume from the local checkpoint. `map_location` remaps the stored tensors
# onto the device selected above, so a checkpoint that was saved from a CUDA
# process can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load("./slowfast_base.pt", map_location=device))

# Assign the result instead of leaving a bare expression so the cell does not
# print the full module representation.
model = model.to(device)

In [None]:
class PackPathway(torch.nn.Module):
    """
    Transform that turns one clip tensor into the list
    ``[slow_pathway, fast_pathway]``.

    The fast pathway is the input itself. The slow pathway keeps
    ``T // alpha`` frames of the input, picked at evenly spaced positions
    along dimension 1 (``T`` being the size of that dimension).

    Args:
        alpha: ratio between the number of fast and slow frames. When left as
            ``None`` (the default) the module-level ``slowfast_alpha`` is read
            at call time, which is the behaviour this class always had.
    """
    def __init__(self, alpha=None):
        super().__init__()
        self.alpha = alpha

    def forward(self, frames: torch.Tensor):
        """Return ``[slow_pathway, fast_pathway]`` for ``frames``."""
        alpha = slowfast_alpha if self.alpha is None else self.alpha
        total_frames = frames.shape[1]

        # Build the index on the same device as `frames`: index_select rejects
        # an index tensor that lives on a different device than its input.
        slow_index = torch.linspace(
            0, total_frames - 1, total_frames // alpha, device=frames.device
        ).long()

        # Perform temporal sampling from the fast pathway.
        slow_pathway = torch.index_select(frames, 1, slow_index)
        fast_pathway = frames
        return [slow_pathway, fast_pathway]

In [None]:
class CustomDataset(Dataset):
    """
    Dataset over ``(video_path, label)`` pairs that yields SlowFast inputs.

    Each item is ``(inputs, label)`` where ``inputs`` is the list of tensors
    found under the ``"video"`` key after ``transform`` has been applied,
    moved to ``device``.

    If the transform (or the move to the device) raises, a placeholder pair of
    zero tensors is returned instead, with the value 100 written at index
    ``[0, 0, 0, 0]`` of the first tensor so that callers can recognise and
    skip the sample.

    Args:
        file: indexable collection of ``(path, label)`` entries.
        device: device the returned tensors are moved to.
        transform: callable applied to the dict returned by ``get_clip``.
        train: stored on the instance; items are produced the same way for
            both values.
    """

    # Shapes of the placeholder tensors returned when preprocessing fails.
    FALLBACK_SHAPES = ((3, 8, 256, 256), (3, 32, 256, 256))
    # Marker written into the first placeholder tensor.
    FALLBACK_SENTINEL = 100

    def __init__(self, file,device,transform=None, train=True):
        super().__init__()
        self.file = file
        self.len = len(self.file)
        self.device = device
        self.transform = transform
        self.train = train
        self.datalayer = PackPathway()

    def _fallback_inputs(self):
        """Build the sentinel-marked placeholder tensors on ``self.device``."""
        inputs = [torch.zeros(shape) for shape in self.FALLBACK_SHAPES]
        inputs[0][0][0][0][0] = self.FALLBACK_SENTINEL
        return [tensor.to(self.device) for tensor in inputs]

    def __getitem__(self, idx):
        path = self.file[idx][0]
        label = self.file[idx][1]

        video = EncodedVideo.from_path(path)
        try:
            video_data = video.get_clip(start_sec=0, end_sec=int(video.duration))
        finally:
            # Release the decoder handle once the clip has been extracted.
            video.close()

        try:
            video_data = self.transform(video_data)
            # Use the device given to the constructor rather than a global.
            inputs = [tensor.to(self.device) for tensor in video_data["video"]]
        except Exception:
            # Best effort: hand back a recognisable placeholder instead of
            # failing the whole batch. `Exception` (not a bare except) lets
            # KeyboardInterrupt / SystemExit propagate.
            inputs = self._fallback_inputs()

        return inputs, label

    def __len__(self):
        return self.len

In [None]:
# Preprocessing steps applied, in this order, to the "video" entry of a clip:
# temporal subsampling, division by 255, normalisation, resize, and finally
# splitting into the [slow, fast] pathway list.
_video_steps = [
    UniformTemporalSubsample(num_frames),
    Lambda(lambda clip: clip / 255.0),
    NormalizeVideo(mean, std),
    Resize((side_size, side_size)),
    PackPathway(),
]

transform = ApplyTransformToKey(key="video", transform=Compose(_video_steps))

In [None]:
from glob import glob

# glob() yields matches in arbitrary order, and the splits below are purely
# positional. Sorting makes the train/test partition the same on every run
# and every machine.
# NOTE(review): this can change which files land in the test split compared
# with earlier unsorted runs -- confirm before comparing against old results.

# Assult clips: the last 213 files are held out for testing.
assult_path = sorted(glob("./dataset/assult/*.mp4"))
train_assult_path = assult_path[:-213]
test_assult_path = assult_path[-213:]

# Swoon clips: first 80% for training, the remainder for testing.
swoon_path = sorted(glob("./dataset/swoon/*.mp4"))
_swoon_cut = int(0.8*len(swoon_path))
train_swoon_path = swoon_path[:_swoon_cut]
test_swoon_path = swoon_path[_swoon_cut:]

# Normal clips: the last 213 files are held out for testing.
normal_path = sorted(glob("./dataset/normal/*.mp4"))
train_normal_path = normal_path[:-213]
test_normal_path = normal_path[-213:]

In [None]:
# Build two AdamW parameter groups: parameters whose name contains one of the
# `no_decay` substrings get weight_decay 0.0, every other parameter gets 0.01.
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

_decayed, _undecayed = [], []
for _name, _param in param_optimizer:
    _exempt = any(key in _name for key in no_decay)
    (_undecayed if _exempt else _decayed).append(_param)

optimizer_grouped_parameters = [
    {'params': _decayed, 'weight_decay': 0.01},
    {'params': _undecayed, 'weight_decay': 0.0},
]

optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, correct_bias=False)
criterion = nn.CrossEntropyLoss()

In [None]:
def test_loop(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and mean loss.

    The model is switched to evaluation mode for the duration of the loop and
    put back into whatever mode it was in afterwards, even if an error occurs.

    Args:
        dataloader: yields ``(X, y)`` batches; ``len(dataloader.dataset)`` is
            used as the sample count for the accuracy.
        model: module called as ``model(X)``; its output is reduced with
            ``argmax(1)`` to obtain predicted labels.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.

    Returns:
        None. The metrics are printed.
    """
    print("Valdidation step")
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Without eval() the layers that behave differently in training
    # (dropout, batch norm) would stay in training behaviour here.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for X, y in tqdm(dataloader, total=num_batches):
                pred = model(X)
                # Follow the prediction's device instead of a global variable.
                y = y.to(pred.device)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    finally:
        model.train(was_training)

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    test_loss /= max(num_batches, 1)
    correct /= max(size, 1)
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [None]:
from tqdm import tqdm
import random

# Label encoding used for both splits -- 0 : normal, 1 : assult, 2 : swoon
test_data = [(clip_path, 0) for clip_path in test_normal_path]
test_data += [(clip_path, 1) for clip_path in test_assult_path]
test_data += [(clip_path, 2) for clip_path in test_swoon_path]

test_dataset = CustomDataset(test_data,device,transform,train=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Upper bound on the normal / assult clips drawn per epoch.
samples_per_class = 851
# Same file the model cell loads its weights from, so a later run resumes
# from the most recent save.
checkpoint_path = "./slowfast_base.pt"

for epoch in range(1,21) :
    # Draw a fresh random subset of the two larger classes every epoch.
    # min() keeps random.sample from raising when a class has fewer clips
    # than `samples_per_class`.
    normal_subset = random.sample(
        train_normal_path, min(samples_per_class, len(train_normal_path)))
    assult_subset = random.sample(
        train_assult_path, min(samples_per_class, len(train_assult_path)))

    train_data = [(clip_path, 0) for clip_path in normal_subset]
    train_data += [(clip_path, 1) for clip_path in assult_subset]
    train_data += [(clip_path, 2) for clip_path in train_swoon_path]

    train_dataset = CustomDataset(train_data,device,transform,train=True)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

    total_loss = 0
    trained_batches = 0
    skipped_batches = 0
    model.train()
    print("------------TRAIN------------")
    for i, d in tqdm(enumerate(train_loader), total=len(train_loader)):
        data, label = d

        # CustomDataset marks a clip whose preprocessing failed by writing 100
        # at index [0, 0, 0, 0] of its first tensor. Skip the whole batch if
        # any sample in it carries that marker.
        if bool((data[0][:, 0, 0, 0, 0] == 100).any()):
            skipped_batches += 1
            continue

        label = label.to(device)
        optimizer.zero_grad()

        output = model(data)
        loss = criterion(output,label)

        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        trained_batches += 1

    print("EPOCH:", epoch)
    # Average over the batches that were actually trained on; dividing by
    # len(train_loader) would understate the loss whenever batches are skipped.
    print("train_loss:{:.6f}".format(total_loss/max(trained_batches, 1)))
    if skipped_batches:
        print("skipped_batches:", skipped_batches)

    if epoch % 5 == 0:
        test_loop(test_loader, model, criterion)

        print("Saving model")
        torch.save(model.state_dict(), checkpoint_path)
