## 1. Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm
import timm
import random

# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {DEVICE}')



Using device: cuda


## 2. Configuration

In [2]:
# Data paths
PATH_DATA_TRAIN = r'/kaggle/input/action-video/data/data_train'
PATH_DATA_TEST = r'/kaggle/input/action-video/data/test'

# Model parameters
NUM_FRAMES = 16        # frames sampled per clip
FRAME_STRIDE = 2       # passed to the datasets' frame-index sampler
IMG_SIZE = 224         # square frame side fed to the ViT

# Training parameters
BATCH_SIZE = 16
EPOCHS = 4
BASE_LR = 1e-4         # learning rate of the pretrained backbone
HEAD_LR = 5e-4         # learning rate of the new classification head
WEIGHT_DECAY = 0.05
GRAD_ACCUM_STEPS = 4   # effective batch size = BATCH_SIZE * GRAD_ACCUM_STEPS

PRETRAINED_NAME = 'vit_small_patch16_224'

# Reproducibility: seed Python's `random` (used by the training augmentation)
# and torch (classification-head initialisation, DataLoader shuffling and the
# per-worker seeds derived from it; torch.manual_seed also seeds all CUDA
# devices). cuDNN kernels may still be non-deterministic on GPU.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

print(f"Train data: {PATH_DATA_TRAIN}")
print(f"Test data: {PATH_DATA_TEST}")
print(f"Model: {PRETRAINED_NAME}")
print(f"Frames per video: {NUM_FRAMES}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Epochs: {EPOCHS}")
print(f"Seed: {SEED}")

Train data: /kaggle/input/action-video/data/data_train
Test data: /kaggle/input/action-video/data/test
Model: vit_small_patch16_224
Frames per video: 16
Batch size: 16
Epochs: 4


## 3. Lightweight ViT Model

In [3]:
class LightweightViTForAction(nn.Module):
    """Frame-level ViT with temporal mean pooling for video action recognition.

    Every frame of a clip is encoded independently by a timm backbone (created
    with ``num_classes=0`` so it returns one feature vector per frame). The
    per-frame features are averaged over time, and a linear head maps the
    averaged feature to class logits.

    Args:
        num_classes: Number of output classes of the linear head.
        pretrained_name: timm model name of the backbone; its pretrained
            weights are loaded via ``timm.create_model(..., pretrained=True)``.
    """

    def __init__(self, num_classes=51, pretrained_name='vit_small_patch16_224'):
        super().__init__()

        # num_classes=0 drops timm's own classifier so the backbone yields features.
        self.vit = timm.create_model(pretrained_name, pretrained=True, num_classes=0)

        # Width of the per-frame feature vector produced by the backbone.
        self.embed_dim = self.vit.num_features

        # Simple linear classification head on the time-averaged feature.
        self.head = nn.Linear(self.embed_dim, num_classes)

    def forward(self, video):
        '''
        Args:
            video: [B, T, C, H, W] - batch of video clips
        Returns:
            logits: [B, num_classes]
        '''
        B, T, C, H, W = video.shape

        # Fold time into the batch so the image backbone sees B*T frames.
        # reshape (not view) so non-contiguous inputs, e.g. permuted clips, work too.
        frames = video.reshape(B * T, C, H, W)

        # Extract per-frame features with the backbone.
        features = self.vit(frames)  # [B*T, embed_dim]

        # Restore the time axis.
        features = features.reshape(B, T, self.embed_dim)

        # Temporal pooling: unweighted mean over the T frames.
        pooled = features.mean(dim=1)  # [B, embed_dim]

        # Classification
        return self.head(pooled)

# Report the model definition and which backbone the config cell selected.
print("Lightweight ViT defined", f"  Backbone: {PRETRAINED_NAME}", sep="\n")

Lightweight ViT defined
  Backbone: vit_small_patch16_224


## 4. Data Augmentation

In [4]:
class VideoTransform:
    """Clip-level augmentation and normalization for a stack of frames.

    One set of random parameters is drawn per call and applied to every frame,
    so all frames of a clip receive the same crop/flip.

    Args:
        image_size: Side length of the square output frames.
        is_train: When True, apply random rescale (0.8-1.0), random crop,
            resize to ``image_size`` and a 50% horizontal flip; when False,
            only resize to ``image_size`` x ``image_size``.
    """

    def __init__(self, image_size=224, is_train=True):
        self.image_size = image_size
        self.is_train = is_train
        # Per-channel (x - 0.5) / 0.5, i.e. inputs in [0, 1] map to [-1, 1].
        self.mean = [0.5, 0.5, 0.5]
        self.std = [0.5, 0.5, 0.5]

    def _augment(self, frames):
        """Random rescale + crop + resize + optional hflip, shared by all frames.

        Random draws happen in this exact order: scale, top, left, flip.
        """
        side = self.image_size
        height, width = frames.shape[-2:]
        factor = random.uniform(0.8, 1.0)
        scaled_h = int(height * factor)
        scaled_w = int(width * factor)
        frames = TF.resize(frames, [scaled_h, scaled_w], interpolation=InterpolationMode.BILINEAR)
        top = random.randint(0, max(0, scaled_h - side))
        left = random.randint(0, max(0, scaled_w - side))
        # The crop can be smaller than `side` when the rescaled frame is small;
        # the resize below then brings it back up to side x side.
        frames = TF.crop(frames, top, left, min(side, scaled_h), min(side, scaled_w))
        frames = TF.resize(frames, [side, side], interpolation=InterpolationMode.BILINEAR)
        if random.random() < 0.5:
            frames = TF.hflip(frames)
        return frames

    def __call__(self, frames):
        """Transform a clip tensor whose last two dims are (H, W).

        Presumably ``[T, C, H, W]`` float frames from ``ToTensor`` — TODO confirm
        with callers. Returns the frames normalized one by one and re-stacked
        along a new first dimension.
        """
        if not self.is_train:
            side = self.image_size
            frames = TF.resize(frames, [side, side], interpolation=InterpolationMode.BILINEAR)
        else:
            frames = self._augment(frames)
        return torch.stack([TF.normalize(single, self.mean, self.std) for single in frames])

print("Augmentation defined")

Augmentation defined


## 5. Dataset Classes

In [5]:
class VideoDataset(Dataset):
    """Labelled clips stored as ``root/<class_name>/<video_dir>/<frame files>``.

    Class indices follow the sorted order of the class directory names. Each
    item is ``(clip, label)``, where ``clip`` is ``num_frames`` frames picked by
    ``_select_indices`` and passed through ``VideoTransform``. Video
    directories without any .jpg/.jpeg/.png frame are skipped.
    """

    def __init__(self, root, num_frames=16, frame_stride=2, image_size=224, is_train=True):
        self.root = Path(root)
        self.num_frames = num_frames
        self.frame_stride = frame_stride
        self.transform = VideoTransform(image_size, is_train)
        self.to_tensor = transforms.ToTensor()
        self.classes = sorted(entry.name for entry in self.root.iterdir() if entry.is_dir())
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}
        self.samples = []
        for class_name in self.classes:
            label = self.class_to_idx[class_name]
            clip_dirs = sorted(d for d in (self.root / class_name).iterdir() if d.is_dir())
            for clip_dir in clip_dirs:
                # Frame order = lexicographic sort of the frame paths.
                clip_frames = sorted(
                    f for f in clip_dir.iterdir() if f.suffix.lower() in {'.jpg', '.jpeg', '.png'}
                )
                if clip_frames:
                    self.samples.append((clip_frames, label))

    def __len__(self):
        return len(self.samples)

    def _select_indices(self, total):
        """Pick ``num_frames`` frame indices out of ``total`` available frames.

        Takes every ``frame_stride``-th point of a ``num_frames * frame_stride``
        point linspace over ``[0, total - 1]`` and truncates to integers, so the
        picks span (almost) the whole video; short videos repeat indices.
        """
        if total <= 0:
            raise ValueError("No frames")
        count = self.num_frames
        if total == 1:
            return torch.zeros(count, dtype=torch.long)
        grid = torch.linspace(0, total - 1, steps=max(count * self.frame_stride, count))
        picked = grid[::self.frame_stride].long()
        shortfall = count - picked.numel()
        if shortfall > 0:
            # Pad by repeating the last index.
            filler = picked.new_full((shortfall,), picked[-1].item())
            picked = torch.cat([picked, filler], dim=0)
        return picked[:count]

    def _load_frame(self, path):
        """Open one image file as RGB and convert it with ``ToTensor``."""
        with Image.open(path) as img:
            return self.to_tensor(img.convert("RGB"))

    def __getitem__(self, idx):
        frame_paths, label = self.samples[idx]
        chosen = self._select_indices(len(frame_paths))
        clip = torch.stack([self._load_frame(frame_paths[int(k.item())]) for k in chosen])
        return self.transform(clip), label


class TestDataset(Dataset):
    """Unlabelled test clips stored as ``root/<integer_video_id>/<frame files>``.

    Video directories are ordered by the integer value of their names, which
    also serve as the returned ids. Each item is ``(clip, video_id)`` with
    ``clip`` built like the training clips but with the eval-mode transform.
    """

    def __init__(self, root, num_frames=16, frame_stride=2, image_size=224):
        self.root = Path(root)
        self.num_frames = num_frames
        self.frame_stride = frame_stride
        self.transform = VideoTransform(image_size, is_train=False)
        self.to_tensor = transforms.ToTensor()
        # Every subdirectory name must parse as an int (raises ValueError otherwise).
        candidate_dirs = [entry for entry in self.root.iterdir() if entry.is_dir()]
        self.video_dirs = sorted(candidate_dirs, key=lambda entry: int(entry.name))
        self.video_ids = [int(entry.name) for entry in self.video_dirs]

    def __len__(self):
        return len(self.video_dirs)

    def _select_indices(self, total):
        """Pick ``num_frames`` indices from ``total`` frames.

        Takes every ``frame_stride``-th point of a ``num_frames * frame_stride``
        point linspace over ``[0, total - 1]``.
        """
        if total <= 0:
            raise ValueError("No frames")
        count = self.num_frames
        if total == 1:
            return torch.zeros(count, dtype=torch.long)
        grid = torch.linspace(0, total - 1, steps=max(count * self.frame_stride, count))
        picked = grid[::self.frame_stride].long()
        shortfall = count - picked.numel()
        if shortfall > 0:
            # Pad by repeating the last index.
            filler = picked.new_full((shortfall,), picked[-1].item())
            picked = torch.cat([picked, filler], dim=0)
        return picked[:count]

    def _load_frame(self, path):
        """Open one image file as RGB and convert it with ``ToTensor``."""
        with Image.open(path) as img:
            return self.to_tensor(img.convert("RGB"))

    def __getitem__(self, idx):
        video_dir = self.video_dirs[idx]
        video_id = self.video_ids[idx]
        # Frames are listed on every access (not cached at construction time).
        frame_paths = sorted(
            p for p in video_dir.iterdir() if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}
        )
        chosen = self._select_indices(len(frame_paths))
        clip = torch.stack([self._load_frame(frame_paths[int(k.item())]) for k in chosen])
        return self.transform(clip), video_id

print("Dataset classes defined")

Dataset classes defined


## 6. Training

In [6]:
def train_one_epoch(model, loader, optimizer, scaler, device, grad_accum_steps=1):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    optimizer.zero_grad()
    progress = tqdm(loader, desc="Train", leave=False)
    for batch_idx, (videos, labels) in enumerate(progress):
        videos = videos.to(device)
        labels = labels.to(device)
        with torch.amp.autocast(device_type='cuda', enabled=(device.type == 'cuda')):
            logits = model(videos)
            loss = F.cross_entropy(logits, labels)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        loss_value = loss.item()
        loss = loss / grad_accum_steps
        scaler.scale(loss).backward()
        should_step = ((batch_idx + 1) % grad_accum_steps == 0) or (batch_idx + 1 == len(loader))
        if should_step:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        batch_size = videos.size(0)
        total_loss += loss_value * batch_size
        progress.set_postfix(loss=f"{loss_value:.4f}", acc=f"{correct / max(total, 1):.4f}")
    avg_loss = total_loss / max(total, 1)
    avg_acc = correct / max(total, 1)
    return avg_loss, avg_acc

print("Training functions defined")

Training functions defined


In [7]:
print("Loading training dataset...")
train_dataset = VideoDataset(PATH_DATA_TRAIN, num_frames=NUM_FRAMES, frame_stride=FRAME_STRIDE, image_size=IMG_SIZE, is_train=True)
# Fail early with a readable message instead of a sampler error on an empty folder.
assert len(train_dataset) > 0, f"No training clips found under {PATH_DATA_TRAIN}"
# persistent_workers keeps the worker processes alive across epochs instead of
# re-spawning them every epoch. Pinned memory only helps host-to-GPU copies,
# so it is enabled only when training on CUDA.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=(DEVICE.type == 'cuda'),
    persistent_workers=True,
)
print(f"Train samples: {len(train_dataset)}")
print(f"Classes: {len(train_dataset.classes)}")
print(f"Class names: {train_dataset.classes[:10]}...")

Loading training dataset...
Train samples: 6254
Classes: 51
Class names: ['brush_hair', 'cartwheel', 'catch', 'chew', 'clap', 'climb', 'climb_stairs', 'dive', 'draw_sword', 'dribble']...


In [8]:
print("Creating lightweight model...")
model = LightweightViTForAction(
    num_classes=len(train_dataset.classes),
    pretrained_name=PRETRAINED_NAME,
).to(DEVICE)
# Counts every parameter; the size estimate assumes 4 bytes (fp32) per parameter.
num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,}")
print(f"Model size: {num_params * 4 / (1024 ** 2):.2f} MB")  # Approximate

Creating lightweight model...


model.safetensors:   0%|          | 0.00/88.2M [00:00<?, ?B/s]

Model parameters: 21,685,299
Model size: 82.72 MB


In [9]:
# Two parameter groups: the pretrained backbone trains with BASE_LR and the
# freshly initialised classification head with the larger HEAD_LR.
# Group membership uses parameter identity with `model.head`, not a substring
# match on "head" in the parameter name. Some timm backbones may keep
# parameters under a `head.*` prefix even with num_classes=0 (e.g. a norm
# layer), and those belong to the backbone group.
head_param_ids = {id(p) for p in model.head.parameters()}
backbone_params = []
head_params = []
for param in model.parameters():
    if not param.requires_grad:
        continue
    if id(param) in head_param_ids:
        head_params.append(param)
    else:
        backbone_params.append(param)
assert head_params, "classification head has no trainable parameters"

optimizer = torch.optim.AdamW([
    {"params": backbone_params, "lr": BASE_LR},
    {"params": head_params, "lr": HEAD_LR},
], weight_decay=WEIGHT_DECAY)

# Loss scaling is only active when CUDA is available (mixed precision on GPU).
scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())
print(f"Optimizer: AdamW | Base LR: {BASE_LR} | Head LR: {HEAD_LR}")

Optimizer: AdamW | Base LR: 0.0001 | Head LR: 0.0005


In [10]:
checkpoint_path = Path('./lightweight_vit_best.pt')
best_acc = 0.0

# NOTE: there is no validation split, so the "best" checkpoint is chosen by
# *training* accuracy. That accuracy usually keeps rising, so the saved model
# is not selected for generalisation.
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    train_loss, train_acc = train_one_epoch(
        model, train_loader, optimizer, scaler, DEVICE, GRAD_ACCUM_STEPS
    )
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    if train_acc > best_acc:
        best_acc = train_acc
        torch.save(
            {'model': model.state_dict(), 'classes': train_dataset.classes, 'acc': best_acc},
            checkpoint_path,
        )
        print(f" Best model saved (acc: {best_acc:.4f})")

print("\n" + "=" * 40)
print(f"Training completed! Best accuracy: {best_acc:.4f}")
print(f"Model saved to: {checkpoint_path}")


Epoch 1/4


Train:   0%|          | 0/391 [00:00<?, ?it/s]

  Train Loss: 1.8192 | Train Acc: 0.5222
 Best model saved (acc: 0.5222)

Epoch 2/4


Train:   0%|          | 0/391 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a1df3f5cae0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
     Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7a1df3f5cae0>
   Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      self._shutdown_workers()^
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
^^    ^if w.is_alive():^
^ ^^  ^^ ^ ^ ^^ ^^^^^^^^^^

  Train Loss: 0.7302 | Train Acc: 0.7909
 Best model saved (acc: 0.7909)

Epoch 3/4


Train:   0%|          | 0/391 [00:00<?, ?it/s]

  Train Loss: 0.4546 | Train Acc: 0.8671
 Best model saved (acc: 0.8671)

Epoch 4/4


Train:   0%|          | 0/391 [00:00<?, ?it/s]

  Train Loss: 0.3063 | Train Acc: 0.9133
 Best model saved (acc: 0.9133)

Training completed! Best accuracy: 0.9133
Model saved to: lightweight_vit_best.pt


## 7. Inference on Test Set

In [11]:
print("INFERENCE ON TEST SET")

checkpoint_path = Path('./lightweight_vit_best.pt')
print(f"Loading checkpoint from {checkpoint_path}...")
# The checkpoint only holds a state_dict, a list of class names and a float,
# so it loads with weights_only=True, which refuses to unpickle arbitrary objects.
checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
classes = checkpoint['classes']

# Rebuild the architecture. The constructor fetches the timm pretrained weights
# (pretrained=True), which are then overwritten by the fine-tuned state_dict.
model = LightweightViTForAction(num_classes=len(classes), pretrained_name=PRETRAINED_NAME).to(DEVICE)
model.load_state_dict(checkpoint['model'])
model.eval()
print(f"Model loaded (trained acc: {checkpoint['acc']:.4f})")

print("\nLoading test dataset...")
test_dataset = TestDataset(PATH_DATA_TEST, num_frames=NUM_FRAMES, frame_stride=FRAME_STRIDE, image_size=IMG_SIZE)
# Pinned memory only helps host-to-GPU copies.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=(DEVICE.type == 'cuda'))
print(f"Test samples: {len(test_dataset)}")

INFERENCE ON TEST SET
Loading checkpoint from lightweight_vit_best.pt...
Model loaded (trained acc: 0.9133)

Loading test dataset...
Test samples: 510


In [12]:
print("\nRunning inference...")
predictions = []
with torch.no_grad():
    for videos, video_ids in tqdm(test_loader, desc="Inference"):
        # Top-1 class index for each clip in the batch.
        pred_indices = model(videos.to(DEVICE)).argmax(dim=1).cpu().numpy()
        batch_ids = video_ids.cpu().numpy()
        predictions.extend(
            (video_id, classes[pred_idx]) for video_id, pred_idx in zip(batch_ids, pred_indices)
        )

# Order by numeric video id (the test loader is unshuffled; this is a safeguard).
predictions.sort(key=lambda pair: pair[0])
print(f"\nTotal predictions: {len(predictions)}")


Running inference...


Inference:   0%|          | 0/32 [00:00<?, ?it/s]


Total predictions: 510


## 8. Evaluate on Test Set (with Ground-Truth Labels)


In [13]:
# Download the ground-truth test labels to ./test_labels.csv from the Google
# Drive file id below (needs the `gdown` CLI and network access).
!gdown "1Xv2CWOqdBj3kt0rkNJKRsodSIEd3-wX_" -O test_labels.csv

Downloading...
From: https://drive.google.com/uc?id=1Xv2CWOqdBj3kt0rkNJKRsodSIEd3-wX_
To: /kaggle/working/test_labels.csv
100%|██████████████████████████████████████| 5.71k/5.71k [00:00<00:00, 21.7MB/s]


In [17]:
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report
# Load ground truth. Ids are compared as strings, so the integer ids coming
# from the test loader match ids parsed from the CSV regardless of dtype.
gt_df = pd.read_csv("test_labels.csv")
test_labels = dict(zip(gt_df['id'].astype(str), gt_df['class']))
# Match predictions with ground truth, recording any id that does not line up
y_pred = []
y_true = []
unmatched_ids = []
for video_id, pred_class in predictions:
    video_id_str = str(video_id)
    if video_id_str in test_labels:
        y_pred.append(pred_class)
        y_true.append(test_labels[video_id_str])
    else:
        unmatched_ids.append(video_id_str)
predicted_ids = {str(video_id) for video_id, _ in predictions}
missing_ids = [label_id for label_id in test_labels if label_id not in predicted_ids]
# Surface coverage problems instead of silently scoring only the overlap.
if unmatched_ids:
    print(f"WARNING: {len(unmatched_ids)} predictions have no ground-truth label (e.g. {unmatched_ids[:5]})")
if missing_ids:
    print(f"WARNING: {len(missing_ids)} labelled videos have no prediction (e.g. {missing_ids[:5]})")
if not y_true:
    raise ValueError("No predictions matched the ids in test_labels.csv")
# Calculate accuracy
accuracy = accuracy_score(y_true, y_pred)
print("=" * 50)
print("TEST SET EVALUATION")
print("=" * 50)
print(f"Total: {len(y_true)} | Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
print()
print(classification_report(y_true, y_pred, zero_division=0))

TEST SET EVALUATION
Total: 510 | Accuracy: 0.6392 (63.92%)

                precision    recall  f1-score   support

    brush_hair       0.80      0.80      0.80        10
     cartwheel       0.43      0.30      0.35        10
         catch       0.78      0.70      0.74        10
          chew       0.82      0.90      0.86        10
          clap       0.83      1.00      0.91        10
         climb       0.82      0.90      0.86        10
  climb_stairs       0.73      0.80      0.76        10
          dive       0.86      0.60      0.71        10
    draw_sword       1.00      0.80      0.89        10
       dribble       0.82      0.90      0.86        10
         drink       0.60      0.60      0.60        10
           eat       0.67      0.20      0.31        10
    fall_floor       0.46      0.60      0.52        10
       fencing       1.00      0.50      0.67        10
     flic_flac       0.55      0.60      0.57        10
          golf       0.88      0.70      0.