In [1]:
import bisect
import glob as glob
import itertools
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.io import read_video as read_video
from torchvision.io import write_video as write_video
from torchvision.transforms import ToTensor, Resize

In [2]:
# Glob pattern matching the split folders of the RWF-2000 dataset.
# Set the RWF2000_ROOT environment variable to point at the dataset on another
# machine instead of editing this machine-specific absolute path.
root_path = os.path.join(
    os.environ.get(
        "RWF2000_ROOT",
        r"C:\Users\Rafael Wersom\Documents\CS196\RWF-2000 Dataset_short",
    ),
    "*",
)

In [3]:
class CustomImageDataset(Dataset):
    """Frame-level dataset over the RWF-2000 ``.avi`` videos.

    Every ``.avi`` file two directory levels below the selected split folders
    is decoded up front with ``read_video`` and kept in memory. Each of a
    video's first ``n_frames_per_video`` frames is one sample.

    Args:
        train: if True, use split folders whose *name* contains ``"train"``;
            otherwise use every other split folder.
        root_glob: glob pattern matching the split folders. Defaults to the
            notebook-level ``root_path``.

    Each item is ``(frame, label)``:
        frame: one frame of the video tensor returned by ``read_video``, not
            resized (NOTE(review): presumably ``(H, W, C)`` uint8, which is not
            the ``(N, C, D, H, W)`` layout ``CustomC3DNet`` expects — confirm).
        label: ``True`` unless the video's path contains ``"NonFight"``.
    """

    def __init__(self, train=True, root_glob=None):
        # Maximum number of frames sampled from each video. The original code
        # assumed every video has exactly this many; shorter videos now simply
        # contribute fewer samples instead of causing an IndexError.
        self.n_frames_per_video = 150

        pattern = root_path if root_glob is None else root_glob

        self.paths = []
        # sorted() gives a deterministic sample order across runs and OSes.
        for folder in sorted(glob.glob(pattern)):
            # Test only the folder's own name so that a "train" substring
            # elsewhere in the root path cannot put every folder in one split.
            if train != ("train" in os.path.basename(folder)):
                continue
            for sub_folder in sorted(glob.glob(os.path.join(folder, "*"))):
                self.paths.extend(sorted(glob.glob(os.path.join(sub_folder, "*.avi"))))

        # read_video returns (video_frames, audio_frames, info); the whole video
        # is read. pts_unit="sec" avoids the deprecated "pts" default.
        self.data = [read_video(path, pts_unit="sec") for path in self.paths]

        # NOTE(review): not applied by __getitem__; kept so the attribute
        # remains available to any caller that uses it.
        self.resize = Resize((224, 224))
        self._build_index()

    def _build_index(self):
        """Record cumulative per-video sample counts used by __getitem__."""
        counts = [min(len(video[0]), self.n_frames_per_video) for video in self.data]
        # _offsets[i] == number of samples contributed by videos 0..i inclusive.
        self._offsets = list(itertools.accumulate(counts))

    def __len__(self):
        return self._offsets[-1] if self._offsets else 0

    def __getitem__(self, idx):
        n_samples = len(self)
        # Support negative indices like a sequence does.
        flat_idx = idx + n_samples if idx < 0 else idx
        if not 0 <= flat_idx < n_samples:
            raise IndexError(f"index {idx} is out of range for {n_samples} samples")

        # First video whose cumulative count exceeds flat_idx holds the sample.
        path_idx = bisect.bisect_right(self._offsets, flat_idx)
        start = self._offsets[path_idx - 1] if path_idx else 0

        frame_data = self.data[path_idx][0][flat_idx - start]
        label = "NonFight" not in self.paths[path_idx]
        return frame_data, label

In [6]:
class CustomC3DNet(nn.Module):
    """Compact C3D-style 3D CNN mapping a video clip to ``num_classes`` outputs.

    Args:
        sample_size: height and width of the (square) input frames.
        sample_duration: number of frames in each input clip.
        num_classes: output size of the final linear layer (default 1).

    Raises:
        ValueError: if the clip is too short or the frames too small to
            survive the pooling stages used by ``forward``.

    ``forward`` takes ``x`` of shape
    ``(N, 3, sample_duration, sample_size, sample_size)`` and returns
    ``(N, num_classes)``.
    """

    def __init__(self,
                 sample_size,
                 sample_duration,
                 num_classes=1):

        super().__init__()
        self.group1 = nn.Sequential(
            nn.Conv3d(3, 4, 3, padding=1),
            nn.BatchNorm3d(4),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 2, 2)))
        self.group2 = nn.Sequential(
            nn.Conv3d(4, 8, 3, padding=1),
            nn.BatchNorm3d(8),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)))
        self.group3 = nn.Sequential(
            nn.Conv3d(8, 16, 3, padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU(),
            nn.Conv3d(16, 16, 3, padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)))
        self.group4 = nn.Sequential(
            nn.Conv3d(16, 32, 3, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.Conv3d(32, 32, 3, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)))
        self.group5 = nn.Sequential(
            nn.Conv3d(32, 64, 3, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.Conv3d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1)))
        # NOTE(review): group6 is built (its parameters are registered and
        # saved in state_dict) but forward() never applies it — confirm whether
        # it was meant to be part of the feature extractor.
        self.group6 = nn.Sequential(
            nn.Conv3d(64, 128, 3, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.Conv3d(128, 128, 3, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1)))

        # Derive the flattened feature size from the pooling layers forward()
        # actually uses (padded 3x3x3 convs keep extents unchanged), rather than
        # a hand-written formula that drifts from the architecture.
        depth, side = sample_duration, sample_size
        for group in self._feature_groups():
            pool = group[-1]  # every feature group ends in an nn.MaxPool3d
            depth = self._pooled_extent(depth, pool, 0)
            side = self._pooled_extent(side, pool, 1)
            if depth < 1 or side < 1:
                raise ValueError(
                    f"input of {sample_duration} frames at {sample_size}x{sample_size} "
                    "is too small for the pooling stages of CustomC3DNet")
        n_features = 64 * depth * side * side  # 64 = channels produced by group5

        hidden = 2048  # width of both hidden fully connected layers
        self.fc1 = nn.Sequential(
            nn.Linear(n_features, hidden),
            nn.ReLU(),
            nn.Dropout(0.5))
        self.fc2 = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(0.5))
        self.fc = nn.Sequential(
            nn.Linear(hidden, num_classes))

    def _feature_groups(self):
        """Convolutional groups applied by forward(), in order."""
        return (self.group1, self.group2, self.group3, self.group4, self.group5)

    @staticmethod
    def _pooled_extent(extent, pool, axis):
        """Length along one axis after ``pool`` (an ``nn.MaxPool3d``, ceil_mode off).

        ``axis`` indexes the pool's per-dimension settings: 0 is depth,
        1 and 2 are height and width.
        """
        def pick(value):
            return value[axis] if isinstance(value, tuple) else value

        kernel, stride, padding, dilation = (
            pick(v) for v in (pool.kernel_size, pool.stride, pool.padding, pool.dilation))
        return (extent + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

    def forward(self, x):
        """Apply groups 1-5, flatten per sample, then the fully connected head."""
        out = x
        for group in self._feature_groups():
            out = group(out)
        out = torch.flatten(out, start_dim=1)
        out = self.fc1(out)
        out = self.fc2(out)
        return self.fc(out)
