In [1]:
import torch
import os
import cv2
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights
from PIL import Image

In [2]:
import torch.nn.functional as F

class VideoDataset(torch.utils.data.Dataset):
    """Map-style dataset of labelled video clips stored one folder per class.

    Expected layout: ``data_dir/<class_name>/<video_file>``. Class indices are
    assigned from the sorted class-folder names.

    Args:
        data_dir: Root directory containing one sub-directory per class.
        transform: Optional callable applied to every decoded frame (a PIL
            image). ``__getitem__`` stacks the transformed frames with
            ``torch.stack``, so the transform must return equally shaped
            tensors.
        max_frames: Fixed clip length. Longer videos are cut to their first
            ``max_frames`` frames; shorter ones are padded by repeating their
            last frame. ``None`` keeps every decoded frame.
    """

    def __init__(self, data_dir, transform=None, max_frames=16):
        self.data_dir = data_dir
        # Only real, non-hidden directories are classes. Stray files or hidden
        # folders (e.g. ``.ipynb_checkpoints``) would otherwise receive a label.
        self.classes = sorted(
            entry for entry in os.listdir(data_dir)
            if not entry.startswith(".")
            and os.path.isdir(os.path.join(data_dir, entry))
        )
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.video_paths = self._get_video_paths()
        self.transform = transform
        self.max_frames = max_frames

    def __len__(self):
        """Return the number of videos found under ``data_dir``."""
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Decode one video and return ``(video_tensor, label)``.

        Frames are stacked along dim 0 and then permuted with ``(1, 0, 2, 3)``;
        with a transform that yields ``(C, H, W)`` frames the result is
        ``(C, T, H, W)``.

        Raises:
            RuntimeError: If no frame could be decoded from the video file.
        """
        video_path, label = self.video_paths[idx]
        # Decoding stops once max_frames frames are read; the rest would be
        # discarded by _process_frames anyway.
        frames = self.load_frames(video_path, max_frames=self.max_frames)
        if not frames:
            raise RuntimeError(f"No frames could be decoded from {video_path!r}")
        if self.max_frames is not None:
            frames = self._process_frames(frames)
        if self.transform:
            frames = [self.transform(frame) for frame in frames]
        video_tensor = torch.stack(frames, dim=0)
        return video_tensor.permute(1, 0, 2, 3), label

    def _get_video_paths(self):
        """Collect ``(video_path, label)`` pairs for every file in each class folder."""
        video_paths = []
        for class_name in self.classes:
            class_dir = os.path.join(self.data_dir, class_name)
            label = self.class_to_idx[class_name]
            # Sorted so the sample order does not depend on the file system.
            for video_name in sorted(os.listdir(class_dir)):
                if video_name.startswith("."):
                    continue  # skip hidden files such as .DS_Store
                video_paths.append((os.path.join(class_dir, video_name), label))
        return video_paths

    def load_frames(self, video_path, max_frames=None):
        """Decode a video into a list of RGB PIL images.

        Args:
            video_path: Path of the video file to open with OpenCV.
            max_frames: Optional upper bound on the number of frames to decode.
                ``None`` (the default) decodes the whole video.

        Returns:
            The decoded frames in order; empty if the file cannot be opened or
            contains no readable frame.
        """
        frames = []
        cap = cv2.VideoCapture(video_path)
        try:
            while cap.isOpened():
                if max_frames is not None and len(frames) >= max_frames:
                    break
                ret, frame = cap.read()
                if not ret:
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
                frames.append(Image.fromarray(frame))
        finally:
            # Always free the decoder, even if a frame conversion raises.
            cap.release()
        return frames

    def _process_frames(self, frames):
        """Return exactly ``self.max_frames`` frames.

        Longer sequences are truncated to their first ``max_frames`` entries;
        shorter ones are padded by repeating the last frame.

        Raises:
            ValueError: If ``frames`` is empty (there is nothing to repeat).
        """
        if not frames:
            raise ValueError("Cannot pad an empty frame list")
        if len(frames) >= self.max_frames:
            return frames[:self.max_frames]
        num_to_pad = self.max_frames - len(frames)
        return frames + [frames[-1]] * num_to_pad


In [3]:
# Data configuration: root folder of the training videos and clips per batch.
data_dir = '../data_1'
batch_size = 3

# Per-frame preprocessing: resize to 224x224 (the MViT input size), convert to
# a tensor, then normalize each channel (these are the commonly used ImageNet
# mean/std values).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Training set and a shuffling loader over it.
train_dataset = VideoDataset(data_dir, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

In [4]:
# Pull one batch from the training loader as a quick sanity check of the data pipeline.
a = next(iter(train_loader))

In [5]:
# Shape of the video batch, i.e. the first element of the (videos, labels) pair yielded by the loader.
a[0].shape

torch.Size([3, 3, 16, 224, 224])

In [6]:
# Load MViT-v1-B with its default pretrained weights.
# `weights` expects a member of the MViT_V1_B_Weights enum (not the enum class
# itself), and the legacy `pretrained=True` flag is deprecated, so only
# `weights` is passed.
model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)
model



MViT(
  (conv_proj): Conv3d(3, 96, kernel_size=(3, 7, 7), stride=(2, 4, 4), padding=(1, 3, 3))
  (pos_encoding): PositionalEncoding()
  (blocks): ModuleList(
    (0): MultiscaleBlock(
      (norm1): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
      (norm2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
      (attn): MultiscaleAttention(
        (qkv): Linear(in_features=96, out_features=288, bias=True)
        (project): Sequential(
          (0): Linear(in_features=96, out_features=96, bias=True)
        )
        (pool_k): Pool(
          (pool): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 8, 8), padding=(1, 1, 1), groups=96, bias=False)
          (norm_act): Sequential(
            (0): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          )
        )
        (pool_v): Pool(
          (pool): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 8, 8), padding=(1, 1, 1), groups=96, bias=False)
          (norm_act): Sequential(
            (0): LayerNorm((96,

In [7]:
# Replace the classification head with a 2-class head.
# The head deliberately returns raw logits: nn.CrossEntropyLoss applies
# log-softmax internally, so a trailing nn.Softmax would apply softmax twice
# and distort the loss and its gradients. Apply torch.softmax to the outputs at
# inference time if probabilities are needed.
model.head = nn.Sequential(
    nn.Dropout(p=0.5, inplace=True),
    nn.Linear(in_features=768, out_features=2),  # 2 output classes instead of 400
)

In [9]:
# Display the model again to inspect it after the head replacement.
model

MViT(
  (conv_proj): Conv3d(3, 96, kernel_size=(3, 7, 7), stride=(2, 4, 4), padding=(1, 3, 3))
  (pos_encoding): PositionalEncoding()
  (blocks): ModuleList(
    (0): MultiscaleBlock(
      (norm1): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
      (norm2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
      (attn): MultiscaleAttention(
        (qkv): Linear(in_features=96, out_features=288, bias=True)
        (project): Sequential(
          (0): Linear(in_features=96, out_features=96, bias=True)
        )
        (pool_k): Pool(
          (pool): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 8, 8), padding=(1, 1, 1), groups=96, bias=False)
          (norm_act): Sequential(
            (0): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          )
        )
        (pool_v): Pool(
          (pool): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 8, 8), padding=(1, 1, 1), groups=96, bias=False)
          (norm_act): Sequential(
            (0): LayerNorm((96,

In [10]:
# Move the model to the GPU when one is available; fall back to CPU otherwise
# so the cell does not crash on machines without CUDA.
model.to("cuda" if torch.cuda.is_available() else "cpu")

MViT(
  (conv_proj): Conv3d(3, 96, kernel_size=(3, 7, 7), stride=(2, 4, 4), padding=(1, 3, 3))
  (pos_encoding): PositionalEncoding()
  (blocks): ModuleList(
    (0): MultiscaleBlock(
      (norm1): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
      (norm2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
      (attn): MultiscaleAttention(
        (qkv): Linear(in_features=96, out_features=288, bias=True)
        (project): Sequential(
          (0): Linear(in_features=96, out_features=96, bias=True)
        )
        (pool_k): Pool(
          (pool): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 8, 8), padding=(1, 1, 1), groups=96, bias=False)
          (norm_act): Sequential(
            (0): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          )
        )
        (pool_v): Pool(
          (pool): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 8, 8), padding=(1, 1, 1), groups=96, bias=False)
          (norm_act): Sequential(
            (0): LayerNorm((96,

In [11]:
# Class-weighted cross-entropy: errors on class index 1 count 1.6x as much as
# errors on class index 0. The weight tensor is created directly on whatever
# device the model lives on, so it cannot end up on a different device.
criterion = nn.CrossEntropyLoss(
    weight=torch.tensor(
        [1.0, 1.6],
        dtype=torch.float32,
        device=next(model.parameters()).device,
    )
)
# Adam over all model parameters (the whole network is fine-tuned).
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [12]:
# Validation data: same preprocessing as training, read from a separate folder.
val_dir = "../test"
val_dataset = VideoDataset(val_dir, transform=transform)
# No shuffling for evaluation: a fixed batch composition keeps the
# per-batch-averaged validation loss reproducible from run to run.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [13]:
from tqdm import tqdm

# Number of full passes over the training set performed by the training loop.
num_epochs = 10


In [None]:

def _run_epoch(model, loader, criterion, device, optimizer=None, desc=""):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    Args:
        model: Network to train or evaluate.
        loader: Iterable yielding ``(videos, labels)`` batches.
        criterion: Loss taking ``(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.
        optimizer: When given, the model is put in train mode and updated after
            every batch; when ``None`` the model is evaluated without gradients.
        desc: Label shown on the progress bar.

    Returns:
        The loss averaged over batches and the accuracy in percent. Both are
        0 for an empty loader instead of raising ZeroDivisionError.
    """
    is_training = optimizer is not None
    model.train(is_training)
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.set_grad_enabled(is_training):
        for videos, labels in tqdm(loader, desc=desc):
            videos, labels = videos.to(device), labels.to(device)
            if is_training:
                optimizer.zero_grad()
            outputs = model(videos)
            loss = criterion(outputs, labels)
            if is_training:
                loss.backward()
                optimizer.step()
            loss_sum += loss.item()
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)
    mean_loss = loss_sum / max(len(loader), 1)
    accuracy = 100 * n_correct / max(n_seen, 1)
    return mean_loss, accuracy


# Use whatever device the model already lives on so inputs always match it.
device = next(model.parameters()).device

# Training loop
for epoch in range(num_epochs):
    train_loss, train_accuracy = _run_epoch(
        model, train_loader, criterion, device,
        optimizer=optimizer,
        desc=f'Epoch {epoch + 1}/{num_epochs} - Training',
    )
    val_loss, val_accuracy = _run_epoch(
        model, val_loader, criterion, device,
        desc=f'Epoch {epoch + 1}/{num_epochs} - Validation',
    )

    # Print epoch statistics
    print(f'Epoch {epoch + 1}/{num_epochs}, '
          f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, '
          f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')


In [None]:
# Write a single checkpoint file holding the trained weights and the model object.
# NOTE(review): 'model_architecture' pickles the entire nn.Module, which
# duplicates the weights and ties the file to the current class definitions;
# consider saving only the state dict once nothing reads that key.
torch.save(
    obj=dict(
        model_state_dict=model.state_dict(),
        model_architecture=model,
    ),
    f='mvit_b_recall.pth',
)