In [2]:
import pytorch_lightning
from pytorchvideo.data import LabeledVideoDataset, UniformClipSampler, RandomClipSampler
from torch.utils.data import DataLoader
from torchvision.transforms import CenterCrop, Compose, Lambda, RandomCrop, RandomHorizontalFlip
from transforms import (
    ApplyTransformToKey,
    UniformTemporalSubsample,
    Normalize,
    RandomShortSideScale,
)

In [3]:
import csv 
def read_csv_to_list_of_tuples(filename: str):
    """Parse a ``<video path>,<integer label>`` CSV into LabeledVideoDataset input.

    Args:
        filename: path to a CSV file with no header row. Column 0 is the video
            file path and column 1 an integer class label; extra columns are
            ignored. Blank lines are skipped.

    Returns:
        A list of ``(filepath, {"label": int})`` tuples, in file order.

    Raises:
        ValueError: if a non-blank row has fewer than two columns, or if the
            label column is not an integer.
    """
    data = []

    # newline="" is what the csv module expects: it lets the reader handle
    # \r\n endings and quoted fields containing newlines itself.
    with open(filename, "r", newline="") as file:
        reader = csv.reader(file)
        for row in reader:
            # Blank lines (e.g. a trailing newline) come back as [] and would
            # otherwise crash on row[0].
            if not any(field.strip() for field in row):
                continue
            if len(row) < 2:
                raise ValueError(
                    f"{filename}:{reader.line_num}: expected '<path>,<label>', got {row!r}"
                )
            filepath, label = row[0], row[1]
            data.append((filepath, {"label": int(label)}))

    return data

In [4]:
class DataModule(pytorch_lightning.LightningDataModule):
    """LightningDataModule serving labelled video clips listed in CSV manifests.

    Each manifest row is ``<video path>,<integer label>`` (parsed by
    ``read_csv_to_list_of_tuples``) and is wrapped in a pytorchvideo
    ``LabeledVideoDataset``; the ``"video"`` entry of every sample is
    transformed, the ``"label"`` entry is passed through.

    Args:
        clip_duration: clip length handed to ``UniformClipSampler``
            (pytorchvideo measures it in seconds).
        batch_size: clips per batch for both loaders.
        num_workers: DataLoader worker processes for both loaders.
        train_csv: manifest read by ``train_dataloader``.
        val_csv: manifest read by ``val_dataloader``.
        crop_size: side length of the spatial crop. 244 is kept from the
            original cell; 224 is the more common choice — confirm intent.
    """

    def __init__(self, clip_duration=60, batch_size=8, num_workers=8,
                 train_csv="train.csv", val_csv="val.csv", crop_size=244) -> None:
        super().__init__()
        self.CLIP_DURATION = clip_duration
        self.BATCH_SIZE = batch_size
        self.NUM_WORKERS = num_workers
        self.TRAIN_CSV = train_csv
        self.VAL_CSV = val_csv

        # Deterministic prefix shared by both splits: keep 8 frames, map pixel
        # values from [0, 255] to [0, 1], then normalise per channel.
        # NOTE(review): with num_workers > 0 under the "spawn" start method
        # (Windows/macOS default) this lambda cannot be pickled; use fork or
        # num_workers=0 there.
        preprocess = [
            UniformTemporalSubsample(8),
            Lambda(lambda x: x / 255.0),
            Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
        ]

        # Training: random short-side rescale, random crop and flip as augmentation.
        self.transform = Compose(
            [
                ApplyTransformToKey(
                    key="video",
                    transform=Compose(
                        preprocess
                        + [
                            RandomShortSideScale(min_size=256, max_size=320),
                            RandomCrop(crop_size),
                            RandomHorizontalFlip(p=0.5),
                        ]
                    ),
                ),
            ]
        )

        # Validation: no random augmentation, so val_loss is reproducible.
        # min_size == max_size pins the short side at 256 (assumes the local
        # RandomShortSideScale samples in [min_size, max_size] like
        # pytorchvideo's — TODO confirm).
        self.val_transform = Compose(
            [
                ApplyTransformToKey(
                    key="video",
                    transform=Compose(
                        preprocess
                        + [
                            RandomShortSideScale(min_size=256, max_size=256),
                            CenterCrop(crop_size),
                        ]
                    ),
                ),
            ]
        )

    def _make_loader(self, csv_path, transform):
        """Build a DataLoader over the clips listed in ``csv_path``."""
        dataset = LabeledVideoDataset(
            labeled_video_paths=read_csv_to_list_of_tuples(csv_path),
            clip_sampler=UniformClipSampler(self.CLIP_DURATION),
            decode_audio=False,
            transform=transform,
        )
        # num_workers was previously stored but never used (single-process loading).
        return DataLoader(
            dataset=dataset,
            batch_size=self.BATCH_SIZE,
            num_workers=self.NUM_WORKERS,
        )

    def train_dataloader(self):
        """Loader over ``train_csv`` with random augmentation."""
        return self._make_loader(self.TRAIN_CSV, self.transform)

    def val_dataloader(self):
        """Loader over ``val_csv`` with deterministic preprocessing."""
        return self._make_loader(self.VAL_CSV, self.val_transform)

In [5]:
import timm
# Frame-level 2-D image backbone from timm, randomly initialised (no pretrained weights).
# NOTE(review): num_classes is not passed, so timm's default classifier head size
# is used — TODO set num_classes to the number of distinct labels in train.csv.
model_name = "inception_v4"
model = timm.create_model(
    model_name=model_name,
    pretrained=False,
)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [7]:
class VideoClassificationLightningModule(pytorch_lightning.LightningModule):
    """Lightning wrapper that trains a frame-level image backbone on video clips.

    The default backbone (the notebook-level ``model`` built with
    ``timm.create_model('inception_v4')``) is a 2-D image network, so
    ``forward`` folds the temporal axis into the batch, classifies every frame
    independently and averages the per-frame logits into one prediction per
    clip.

    Args:
        backbone: image classifier mapping (N, C, H, W) -> (N, num_classes).
            Defaults to the notebook-global ``model`` for backward compatibility.
        lr: Adam learning rate. The default 1e-1 is kept from the original
            cell; it is very high for Adam — 1e-3 or 1e-4 is more typical.
    """

    def __init__(self, backbone=None, lr=1e-1):
        super().__init__()
        # Fall back to the global ``model`` so existing no-arg calls keep working.
        self.model = model if backbone is None else backbone
        self.lr = lr

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        # 4-D input is already a batch of images: pass straight through.
        if x.dim() == 4:
            return self.model(x)
        # Assumes clips are batched as (B, C, T, H, W), pytorchvideo's
        # per-clip (C, T, H, W) layout after collation — TODO confirm.
        b, c, t, h, w = x.shape
        frames = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        frame_logits = self.model(frames)
        # Average the T per-frame predictions of each clip.
        return frame_logits.reshape(b, t, -1).mean(dim=1)

    def _shared_step(self, batch, stage):
        """Compute and log the cross-entropy loss for one batch of ``stage``."""
        logits = self(batch["video"])
        # cross_entropy needs int64 class indices. The old torch.Tensor(...) call
        # tried to build a float tensor from the collated LongTensor labels.
        labels = torch.as_tensor(batch["label"], dtype=torch.long, device=logits.device)
        loss = F.cross_entropy(logits, labels)
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:
# Seed Python/NumPy/torch RNGs (and dataloader workers) so the training run is
# reproducible under Restart & Run All. Note: the backbone's weights were
# initialised in the earlier model cell, before this seed is applied.
SEED = 42
MAX_EPOCHS = 10
pytorch_lightning.seed_everything(SEED, workers=True)

classification_module = VideoClassificationLightningModule()
data_module = DataModule()
trainer = pytorch_lightning.Trainer(max_epochs=MAX_EPOCHS)
trainer.fit(classification_module, data_module)