In [None]:
# Import required modules ...

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader
from torch.optim.lr_scheduler import StepLR
import torchvision
from torchvision import get_video_backend
from torchvision.models.video import r3d_18 
from torchvision import transforms
import os
from tqdm.auto import tqdm
import numpy as np
import time
import av
import random
# Report the installed PyAV version (imported above as `av`).
# NOTE(review): PyAV is presumably the video decoding backend used by the
# torchvision video dataset below -- confirm.
print(f"PyAV version -- {av.__version__}")

# Single seed shared by every random number generator this notebook imports.
SEED = 491
# Seed Python's `random`, NumPy and PyTorch together. Seeding torch alone
# leaves the other two generators (both imported above) non-deterministic,
# so any code path that draws from them would not be reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

from collections import OrderedDict
import warnings
# Suppress every warning for the rest of the session.
# NOTE(review): this also hides deprecation notices from torch/torchvision;
# consider narrowing the filter by category or module so real problems stay
# visible.
warnings.filterwarnings('ignore')

# Disabled diagnostic hook; `run_av_diagnostics` is not defined in the visible
# part of this notebook.
#run_av_diagnostics()

In [None]:
# Datasets and Dataloaders for model training ..

val_split = 0.05   # fraction of the training clips held out for validation
num_frames = 16    # frames per clip
clip_steps = 50    # passed to HMDB51 as step_between_clips
num_workers = 8    # passed to HMDB51 (not to the DataLoaders defined later)
pin_memory = True
import torchvision.transforms as transforms


def _thwc_uint8_to_tchw_float(video):
    """Convert a raw clip to the layout the spatial transforms expect.

    HMDB51 hands each clip to `transform` as a uint8 tensor laid out as
    (T, H, W, C). torchvision's tensor transforms (Resize, crops, flips)
    operate on the LAST TWO axes as (H, W), so the channel axis has to be
    moved in front of the spatial axes first. Without this permute the frame
    height is silently treated as the channel axis and the (W, C) plane is
    what gets resized and cropped.

    Args:
        video: uint8 tensor of shape (T, H, W, C).

    Returns:
        float32 tensor of shape (T, C, H, W) scaled to [0, 1].
    """
    return video.permute(0, 3, 1, 2).float() / 255.0


# Training pipeline: layout/dtype fix, resize, then random flip + random crop
# as augmentation. On a (T, C, H, W) tensor the random flip/crop parameters
# are drawn once per clip, so all frames of a clip are augmented consistently.
train_tfms = transforms.Compose([
    transforms.Lambda(_thwc_uint8_to_tchw_float),  # (T,H,W,C) uint8 -> (T,C,H,W) float
    transforms.Resize((128, 171)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop((112, 112))
])

# Evaluation pipeline: same preprocessing but deterministic (center crop only).
test_tfms = transforms.Compose([
    transforms.Lambda(_thwc_uint8_to_tchw_float),  # (T,H,W,C) uint8 -> (T,C,H,W) float
    transforms.Resize((128, 171)),
    transforms.CenterCrop((112, 112))
])

# Locations of the videos and of the official split files.
# NOTE(review): absolute paths are machine specific; adjust as needed.
VIDEO_ROOT = '/workspace/hi/data'
ANNOTATION_DIR = '/workspace/hi/data2'

hmdb51_train = torchvision.datasets.HMDB51(VIDEO_ROOT, ANNOTATION_DIR, num_frames,
                                           step_between_clips=clip_steps, fold=1, train=True,
                                           transform=train_tfms, num_workers=num_workers)

hmdb51_test = torchvision.datasets.HMDB51(VIDEO_ROOT, ANNOTATION_DIR, num_frames,
                                          step_between_clips=clip_steps, fold=1, train=False,
                                          transform=test_tfms, num_workers=num_workers)

# Size of the validation subset that is later carved out of the training set.
total_train_samples = len(hmdb51_train)
total_val_samples = round(val_split * total_train_samples)

print(f"number of train samples {total_train_samples}")
print(f"number of validation samples {total_val_samples}")
print(f"number of test samples {len(hmdb51_test)}")


In [None]:
# Sanity check: pull a single training sample by index and print the
# transformed clip's shape followed by its class label.
sample_index = 111
video, audio, label = hmdb51_train[sample_index]
print(video.shape, label, sep="\n")

In [None]:
# Training hyper-parameters and DataLoader settings.
bs = 4              # batch size
lr = 1e-2
gamma = 0.7         # NOTE(review): presumably the StepLR decay factor; unused in the visible cells
total_epochs = 10
config = {}
num_workers = 0     # DataLoader workers; 0 loads batches in the main process

# Pinned host memory only helps when batches are copied to a CUDA device.
kwargs = {'num_workers': num_workers}
if torch.cuda.is_available():
    kwargs['pin_memory'] = True

# Use a dedicated, explicitly seeded generator so the train/validation split
# is the same no matter how many random numbers earlier cells consumed or in
# which order the cells were executed.
split_generator = torch.Generator().manual_seed(SEED)
hmdb51_train_v1, hmdb51_val_v1 = random_split(
    hmdb51_train,
    [total_train_samples - total_val_samples, total_val_samples],
    generator=split_generator,
)

# Previously explored and left disabled: recomputing clips with
# video_clips.compute_clips(16, 1, frame_rate=30) on each dataset, and
# clip-level sampling via RandomClipSampler / UniformClipSampler.

# Only the training loader is shuffled; validation and test are iterated in a
# fixed order so evaluation runs are repeatable.
train_loader = DataLoader(hmdb51_train_v1, batch_size=bs, shuffle=True, **kwargs)
val_loader   = DataLoader(hmdb51_val_v1, batch_size=bs, shuffle=False, **kwargs)
test_loader  = DataLoader(hmdb51_test, batch_size=bs, shuffle=False, **kwargs)


In [None]:
# Peek at the first training batch only: print the batched clip shape and the
# label shape. Nothing is printed if the loader yields no batches.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    video, _, label = first_batch
    print(video.shape)
    print(label.shape)

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader, random_split

def extract_representative_frame(video, clip_idx=None):
    """Pick one 3-channel frame out of a stacked-clip video tensor.

    The channel axis is interpreted as `num_clips` consecutive groups of three
    channels. One group (clip) is selected, and from it the temporally central
    frame is returned.

    Args:
        video: tensor of shape (T, total_channels, H, W), where
            total_channels is expected to be a multiple of 3.
        clip_idx: index of the 3-channel group to use. When None, the middle
            group (num_clips // 2) is used.

    Returns:
        Tensor of shape (3, H, W): frame T // 2 of the selected clip. It is a
        view of `video`, not a copy.

    NOTE(review): the original comments describe total_channels = 240. Since
    torchvision video datasets emit frames as (T, H, W, C) by default, a
    240-wide "channel" axis may actually be the frame height -- confirm that
    the dataset transform moves channels to axis 1 before relying on this.
    """
    frame_count, total_channels, height, width = video.shape
    clip_count = total_channels // 3

    # Fall back to the middle clip when the caller does not choose one.
    chosen_clip = clip_count // 2 if clip_idx is None else clip_idx

    # Split the channel axis into (clip, rgb) so a clip can be indexed directly.
    grouped = video.view(frame_count, clip_count, 3, height, width)

    # Index the central time step and the chosen clip in one go -> (3, H, W).
    return grouped[frame_count // 2, chosen_clip]


# --- Model setup ---
# ResNet-18 takes 3-channel images as input. Load it with pretrained weights
# and swap the final fully connected layer for one sized to the HMDB51 label
# set, then place the model on the GPU when one is available (CPU otherwise).
num_classes = 51  # number of HMDB51 classes
resnet18 = models.resnet18(pretrained=True)
resnet18.fc = nn.Linear(in_features=resnet18.fc.in_features, out_features=num_classes)
resnet18 = resnet18.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

# --- DataLoader setup ---
# Assumes hmdb51_train, hmdb51_test, total_train_samples, total_val_samples
# and SEED were defined by earlier cells.
bs = 8             # batch size
num_workers = 0    # DataLoader workers; 0 loads batches in the main process

# Pinned host memory only helps when batches are copied to a CUDA device.
kwargs = {'num_workers': num_workers}
if torch.cuda.is_available():
    kwargs['pin_memory'] = True

# Split with an explicitly seeded generator: re-running this cell (or running
# it after other cells consumed random numbers) always yields the same
# train/validation partition instead of silently reshuffling it.
split_generator = torch.Generator().manual_seed(SEED)
hmdb51_train_v1, hmdb51_val_v1 = random_split(
    hmdb51_train,
    [total_train_samples - total_val_samples, total_val_samples],
    generator=split_generator,
)

# Only the training loader is shuffled; validation and test are iterated in a
# fixed order so evaluation runs are repeatable.
train_loader = DataLoader(hmdb51_train_v1, batch_size=bs, shuffle=True, **kwargs)
val_loader   = DataLoader(hmdb51_val_v1, batch_size=bs, shuffle=False, **kwargs)
test_loader  = DataLoader(hmdb51_test, batch_size=bs, shuffle=False, **kwargs)

# --- Example training loop ---
# Fine-tunes `resnet18` on one representative frame per clip.
optimizer = torch.optim.SGD(resnet18.parameters(), lr=1e-2, momentum=0.9)
criterion = nn.CrossEntropyLoss()
# Resolve the device once, with the same rule used when the model was placed,
# instead of hard-coding "cuda" and re-evaluating the check for every batch.
device = "cuda" if torch.cuda.is_available() else "cpu"
# tqdm.auto renders a widget inside notebooks and falls back to the text bar
# elsewhere.
from tqdm.auto import tqdm

num_epochs = 10
for i in range(num_epochs):
    # Make train mode explicit (BatchNorm statistics are updated) so the loop
    # stays correct once an eval() based validation pass is added.
    resnet18.train()
    running_loss = 0.0   # sum of per-sample losses over the epoch
    samples_seen = 0

    # Iterate over the DataLoader with a progress bar.
    for batch in tqdm(train_loader, desc="Training batches"):
        video_batch, _, labels = batch  # audio is not used

        # One representative frame per video, each of shape (3, H, W),
        # stacked into a (B, 3, H, W) image batch.
        frames_tensor = torch.stack(
            [extract_representative_frame(video) for video in video_batch]
        ).to(device)
        labels = labels.to(device)

        # Forward pass, loss, backward pass, parameter update.
        optimizer.zero_grad()
        outputs = resnet18(frames_tensor)  # (B, num_classes)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss returns the batch mean; weight it by the batch size
        # so a smaller final batch does not skew the epoch average.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    # Validation and test loops can be written the same way.
    # Report the mean loss over the whole epoch rather than only the last
    # batch; an empty loader yields NaN instead of raising NameError.
    epoch_loss = running_loss / samples_seen if samples_seen else float("nan")
    print(f"{i+1} epoch: train_loss: {epoch_loss:.4f}")