<a href="https://colab.research.google.com/github/KuiMian/ForTest/blob/master/DualStream.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
!pip install torchmetrics pytorch_lightning



In [5]:
from glob import glob
import os
import pandas as pd
from PIL import Image

import torch
from torchvision import transforms as T
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import nn
from torchvision.models import resnet18
from torchvision.models.video import r3d_18
from torchmetrics import Accuracy
from pytorch_lightning import LightningModule

class FrameImageDataset(torch.utils.data.Dataset):
    def __init__(self,
    root_dir='ucf10',
    split='train',
    transform=None
):
        self.frame_paths = sorted(glob(f'{root_dir}/frames/{split}/*/*/*.jpg'))
        self.df = pd.read_csv(f'{root_dir}/metadata/{split}.csv')
        self.split = split
        self.transform = transform

    def __len__(self):
        return len(self.frame_paths)

    def _get_meta(self, attr, value):
        return self.df.loc[self.df[attr] == value]

    def __getitem__(self, idx):
        frame_path = self.frame_paths[idx]
        video_name = frame_path.split('/')[-2]
        video_meta = self._get_meta('video_name', video_name)
        label = video_meta['label'].item()

        frame = Image.open(frame_path).convert("RGB")

        if self.transform:
            frame = self.transform(frame)
        else:
            frame = T.ToTensor()(frame)

        return frame, label


class FrameVideoDataset(torch.utils.data.Dataset):
    def __init__(self,
    root_dir = 'ucf10',
    split = 'train',
    transform = None,
    stack_frames = True
):

        self.video_paths = sorted(glob(f'{root_dir}/videos/{split}/*/*.avi'))
        self.df = pd.read_csv(f'{root_dir}/metadata/{split}.csv')
        self.split = split
        self.transform = transform
        self.stack_frames = stack_frames

        self.n_sampled_frames = 10

    def __len__(self):
        return len(self.video_paths)

    def _get_meta(self, attr, value):
        return self.df.loc[self.df[attr] == value]

    def __getitem__(self, idx):
        video_path = self.video_paths[idx]
        video_name = video_path.split('/')[-1].split('.avi')[0]
        video_meta = self._get_meta('video_name', video_name)
        label = video_meta['label'].item()

        video_frames_dir = self.video_paths[idx].split('.avi')[0].replace('videos', 'frames')
        video_frames = self.load_frames(video_frames_dir)

        if self.transform:
            frames = [self.transform(frame) for frame in video_frames]
        else:
            frames = [T.ToTensor()(frame) for frame in video_frames]

        if self.stack_frames:
            frames = torch.stack(frames).permute(1, 0, 2, 3)


        return frames, label

    def load_frames(self, frames_dir):
        frames = []
        for i in range(1, self.n_sampled_frames + 1):
            frame_file = os.path.join(frames_dir, f"frame_{i}.jpg")
            frame = Image.open(frame_file).convert("RGB")
            frames.append(frame)

        return frames

In [14]:
class DualStreamModel(LightningModule):
    """
    Dual-stream model using same frames as input.
    Temporal stream uses simple frame differences as motion cue.
    """
    def __init__(self, num_classes=10, pretrained=True):
        super().__init__()
        self.num_classes = num_classes

        self.spatial_cnn = resnet18(weights="IMAGENET1K_V1" if pretrained else None)
        self.spatial_cnn.fc = nn.Identity()

        self.temporal_cnn = resnet18(weights="IMAGENET1K_V1" if pretrained else None)
        self.temporal_cnn.fc = nn.Identity()

        self.classifier = nn.Linear(512*2, num_classes)

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, frames):
        B, C, T, H, W = frames.shape

        rgb = frames.permute(0, 2, 1, 3, 4).reshape(B*T, C, H, W)
        spatial_feat = self.spatial_cnn(rgb)
        spatial_feat = spatial_feat.view(B, T, -1).mean(dim=1)

        frame_diff = frames[:, :, 1:, :, :] - frames[:, :, :-1, :, :]
        frame_diff = torch.cat([torch.zeros(B, C, 1, H, W, device=frames.device), frame_diff], dim=2)
        diff = frame_diff.permute(0, 2, 1, 3, 4).reshape(B*T, C, H, W)
        temporal_feat = self.temporal_cnn(diff)
        temporal_feat = temporal_feat.view(B, T, -1).mean(dim=1)

        fused = torch.cat([spatial_feat, temporal_feat], dim=1)
        logits = self.classifier(fused)
        return logits

    def training_step(self, batch, batch_idx):
        frames, labels = batch
        logits = self(frames)
        loss = self.criterion(logits, labels)
        acc = self.accuracy(F.softmax(logits, dim=-1), labels)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        frames, labels = batch
        logits = self(frames)
        loss = self.criterion(logits, labels)
        acc = self.accuracy(F.softmax(logits, dim=-1), labels)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        frames, labels = batch
        logits = self(frames)
        loss = self.criterion(logits, labels)
        acc = self.accuracy(F.softmax(logits, dim=-1), labels)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=3e-4, weight_decay=1e-4)

model = DualStreamModel()
x = torch.rand(1, 3, 10, 64, 64)  # [batch, channels, number of frames, height, width]
print(f"Output shape of model: {model(x).shape}")

Output shape of model: torch.Size([1, 10])


In [13]:
!unzip -q drive/MyDrive/ucf101_noleakage.zip -d .

replace ./ucf101_noleakage/flows/val/HandstandPushups/v_HandStandPushups_g16_c06/flow_9_10.npy? [y]es, [n]o, [A]ll, [N]one, [r]ename: N


In [15]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger("tb_logs", name="DualStream")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 30

root_dir = 'ucf101_noleakage'

transform = T.Compose([T.Resize((64, 64)),T.ToTensor()])

framevideostack_dataset_train = FrameVideoDataset(
    root_dir=root_dir, split="train", transform=transform, stack_frames=True
)
framevideostack_dataset_val = FrameVideoDataset(
    root_dir=root_dir, split="val", transform=transform, stack_frames=True
)

framevideostack_dataset_test = FrameVideoDataset(
    root_dir=root_dir, split="test", transform=transform, stack_frames=True
)

train_loader = DataLoader(framevideostack_dataset_train, batch_size=8, shuffle=True, num_workers=4)
val_loader = DataLoader(framevideostack_dataset_val, batch_size=8, shuffle=False, num_workers=4)
test_loader = DataLoader(framevideostack_dataset_test, batch_size=8, shuffle=False, num_workers=4)

model = DualStreamModel()
trainer = Trainer(max_epochs=epochs, accelerator=device.type, logger=logger)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(model, dataloaders=test_loader)

INFO:pytorch_lightning.utilities.rank_zero:💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name         | Type               | Params | Mode 
------------------------------------------------------------
0 | spatial_cnn  | ResNet             | 11.2 M | train
1 | temporal_cnn | ResNet             | 11.2 M | train
2 | classifier   | Linear             | 10.2 K | train
3 | criterion    | CrossEntropyLoss   | 0      | train
4 | accuracy     | MulticlassAccuracy | 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 1.8741271495819092, 'test_acc': 0.6916666626930237}]