In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, models
from torchvision.io import read_video
from torchvision.datasets import VideoDataset
from torch.utils.data import random_split

In [None]:
# Run configuration read by later cells.
batch_size, learning_rate, num_epochs = 4, 1e-3, 10
# Output width passed to VideoClassifier (size of its final Linear layer).
num_classes = 1

In [None]:
def video_loader(path, frame_transform=None, frame_stride=5):
    """Decode a video file and return a temporally subsampled, transformed clip.

    Args:
        path: Path to a video file readable by ``torchvision.io.read_video``.
        frame_transform: Callable applied to every kept frame. Each frame is
            handed over as a PIL image, which is what PIL-based pipelines such as
            ``Resize -> ToTensor -> Normalize`` expect. Defaults to the
            notebook-level ``transform`` (looked up at call time), preserving
            the original behaviour.
        frame_stride: Keep every ``frame_stride``-th decoded frame (default 5).

    Returns:
        Tensor of the transformed frames stacked along a new leading time axis.

    Raises:
        ValueError: if ``frame_stride`` is not a positive integer.
        RuntimeError: if no frames could be decoded from ``path``.
    """
    if frame_stride < 1:
        raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")
    if frame_transform is None:
        # Notebook-global pipeline, defined in a later cell.
        frame_transform = transform

    # pts_unit="sec" avoids read_video's per-call deprecation warning; with the
    # default start/end (whole clip) the decoded frames are the same.
    # Frames come back as uint8 tensors laid out (T, H, W, C).
    frames, _, _ = read_video(path, pts_unit="sec")
    frames = frames[::frame_stride]
    if frames.shape[0] == 0:
        raise RuntimeError(f"no frames decoded from video: {path}")

    # ToTensor/Resize reject raw HWC tensors, so round-trip each frame through PIL.
    to_pil = transforms.functional.to_pil_image
    return torch.stack([frame_transform(to_pil(frame.numpy())) for frame in frames])

In [None]:
# Per-frame preprocessing: fixed 224x224 resize, conversion to a float tensor,
# then normalisation with the mean/std values torchvision's pretrained ImageNet
# models were trained with.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
class CustomVideoDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(frames, label)`` pairs for video classification.

    Subclasses ``torch.utils.data.Dataset`` directly; ``torchvision.datasets``
    provides no ``VideoDataset`` base class.

    Args:
        video_paths: Indexable sequence of video file paths, parallel to ``labels``.
        labels: Indexable sequence of labels, one per entry of ``video_paths``.
        transform: Optional per-frame callable. It receives each kept frame as a
            PIL image and should return a ``(C, H, W)`` tensor. When ``None``,
            frames are returned as float tensors in ``[0, 1]``, shaped (C, H, W).
        frame_stride: Keep every ``frame_stride``-th decoded frame (default 5).

    Raises:
        ValueError: if ``video_paths`` and ``labels`` differ in length, or if
            ``frame_stride`` is not positive.
    """

    def __init__(self, video_paths, labels, transform=None, frame_stride=5):
        super().__init__()
        if len(video_paths) != len(labels):
            raise ValueError(
                "video_paths and labels must have the same length "
                f"({len(video_paths)} != {len(labels)})"
            )
        if frame_stride < 1:
            raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")
        self.video_paths = video_paths
        self.labels = labels
        self.transform = transform
        self.frame_stride = frame_stride

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Decode video ``idx`` and return ``(frames, label)``.

        ``frames`` stacks the per-frame outputs along a leading time axis.
        NOTE(review): clip length varies with video duration, so batching with
        the default collate_fn presumably needs equal-length clips; confirm.
        """
        path = self.video_paths[idx]
        # read_video returns uint8 frames laid out (T, H, W, C).
        frames, _, _ = read_video(path, pts_unit="sec")
        frames = frames[:: self.frame_stride]
        if frames.shape[0] == 0:
            raise RuntimeError(f"no frames decoded from video: {path}")

        if self.transform is None:
            video = frames.permute(0, 3, 1, 2).float().div(255.0)
        else:
            # PIL round-trip lets PIL-based pipelines (Resize -> ToTensor -> ...)
            # accept decoded frames, which they reject as raw HWC tensors.
            to_pil = transforms.functional.to_pil_image
            video = torch.stack(
                [self.transform(to_pil(frame.numpy())) for frame in frames]
            )
        return video, self.labels[idx]


In [None]:
# NOTE(review): `video_paths` and `labels` are not defined in any visible cell;
# on a fresh kernel this cell raises NameError unless an earlier cell creates them.
dataset = CustomVideoDataset(
    video_paths,
    labels,
    transform=transform,
)

In [None]:
# Fraction of samples used for training; the remainder goes to validation.
TRAIN_FRACTION = 0.8
# Seed for the split generator so the train/val partition is reproducible
# under Restart & Run All (random_split otherwise draws from the global RNG).
SPLIT_SEED = 42

train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
class VideoClassifier(nn.Module):
    """Per-frame ResNet-50 encoder followed by an LSTM over the frame sequence.

    Every frame is embedded by a pretrained ResNet-50 whose classification head
    is removed, so the LSTM receives the backbone's pooled feature vector for
    each frame. The LSTM output at the last time step is mapped to
    ``num_classes`` logits.

    Args:
        num_classes: Width of the output layer.
        hidden_size: LSTM hidden size (default 128, matching the original model).
    """

    def __init__(self, num_classes, hidden_size=128):
        super().__init__()
        # NOTE: torchvision >= 0.13 prefers weights=models.ResNet50_Weights.DEFAULT;
        # pretrained=True is kept for compatibility with older versions.
        self.resnet = models.resnet50(pretrained=True)
        # Read the feature width before dropping the head. Previously fc projected
        # each frame to num_classes scalars, so the LSTM saw almost no information.
        feature_dim = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()
        self.lstm = nn.LSTM(
            input_size=feature_dim,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: Tensor shaped (batch_size, num_frames, C, H, W).

        Returns:
            Logits shaped (batch_size, num_classes).

        Raises:
            ValueError: if ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape (batch, frames, C, H, W), got {tuple(x.shape)}"
            )
        n_videos, n_frames = x.shape[:2]
        # Fold time into the batch axis so the 2D backbone sees individual frames.
        frame_feats = self.resnet(x.flatten(0, 1))
        seq = frame_feats.reshape(n_videos, n_frames, -1)
        seq_out, _ = self.lstm(seq)
        return self.fc(seq_out[:, -1, :])