In [90]:
import os
from torch.utils.data import Dataset, DataLoader
import torchvision.io as io
import numpy as np
import torch

from torchvision import transforms
from transformers import AutoImageProcessor, TimesformerForVideoClassification, BitsAndBytesConfig

import accelerate




In [91]:
from huggingface_hub import login

login(token="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")

In [92]:
from huggingface_hub import whoami
print(whoami())

{'type': 'user', 'id': '679be20c3e94b783ac5be8a1', 'name': 'AddHe', 'fullname': 'Ada Henc', 'isPro': False, 'avatarUrl': '/avatars/436fa2e8d09fa03f9aa3b2a5e07695b0.svg', 'orgs': [{'type': 'org', 'id': '679bd7f5c8fcaa69de68bf61', 'name': 'cvproject', 'fullname': 'CV Project ', 'avatarUrl': 'https://www.gravatar.com/avatar/4cc3ce8cb0710b3102c755d0a736f86d?d=retro&size=100', 'roleInOrg': 'write', 'isEnterprise': False}], 'auth': {'type': 'access_token', 'accessToken': {'displayName': 'CV_project', 'role': 'fineGrained', 'createdAt': '2025-02-06T21:12:01.721Z', 'fineGrained': {'canReadGatedRepos': True, 'global': ['inference.serverless.write'], 'scoped': [{'entity': {'_id': '679bd7f5c8fcaa69de68bf61', 'type': 'org', 'name': 'cvproject'}, 'permissions': ['repo.content.read', 'repo.write', 'inference.endpoints.infer.write']}, {'entity': {'_id': '679be20c3e94b783ac5be8a1', 'type': 'user', 'name': 'AddHe'}, 'permissions': ['repo.content.read', 'repo.write']}]}}}}


In [93]:
def get_available_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():  # Dla Mac z Apple Silicon
        return torch.device("mps")
    else:
        return torch.device("cpu")

device = get_available_device()
print(f"Using device: {device}")

Using device: cpu


In [94]:
def sample_frames(video: torch.Tensor, num_frames: int):
    """
    Zakłada, że video ma kształt (C, T, H, W) i wybiera równomiernie rozłożone klatki.
    """
    total_frames = video.shape[1]
    if total_frames < num_frames:
        # Jeżeli wideo ma mniej klatek, można uzupełnić paddingiem (np. powielając ostatnią klatkę) - jesli zdecydujemy sie na krotsze filmiki 
        pad = video[:, -1:, :, :].repeat(1, num_frames - total_frames, 1, 1)
        video = torch.cat([video, pad], dim=1)
        return video
    indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    return video[:, indices, :, :]

In [95]:
class AdDetectionDataset(Dataset):
    def __init__(self, root: str, split: str, transform=None, max_frames=100):
        """
        Inicjalizuje dataset, ładuje ścieżki do plików wideo i przypisuje etykiety.

        :param root: Główny katalog, w którym znajdują się dane (np. "dataset").
        :param split: Określa, którą część danych ładować: 'train', 'validate', 'test'.
        :param transform: Możliwość dodania transformacji do wideo.
        :param max_frames: Maksymalna liczba klatek w sekwencji.
        """
        self.root_dir = os.path.join(root, split)
        self.transform = transform
        self.max_frames = max_frames
        self.video_paths = []
        self.labels = []
        self.idx_to_label = {0: 'content', 1: 'commercial'}

        # Zbieranie plików wideo z folderów 'content' i 'commercial'
        for i, label in self.idx_to_label.items():
            label_dir_path = os.path.join(self.root_dir, label)
            videos_in_dir = os.listdir(label_dir_path)
            self.video_paths.extend([os.path.join(label_dir_path, video) for video in videos_in_dir])
            self.labels.extend([i] * len(videos_in_dir))

    def map_idx_to_label(self, idx):
        """Mapowanie indeksu na etykietę (np. 0 -> 'content', 1 -> 'commercial')."""
        return self.idx_to_label[idx]

    def __len__(self):
        """Zwraca liczbę elementów w zbiorze danych (liczba filmów)."""
        return len(self.video_paths)

    def __getitem__(self, idx):
        """
        Wczytuje wideo, stosuje padding i zwraca tensor.

        :param idx: Indeks pliku w zbiorze danych.
        :returns: Para (wideo, etykieta).
        """
            
        video_path = self.video_paths[idx]
        label = self.labels[idx]

        # Wczytanie pliku wideo
        # video = self.load_video(video_path)
        # video, _, _ = io.read_video(video_path, pts_unit="sec")  # Zwraca tensor (T, H, W, C)

        video, _, _ = io.read_video(video_path, pts_unit="sec")
        video = video.permute(3, 0, 1, 2).float() / 255.0  # (T, H, W, C) -> (C, T, H, W)

        # Wybieramy np. 16 klatek
        video = sample_frames(video, num_frames=16)
        
        if self.transform:
            video = self.transform(video)


        return video, label
        # print(video)
        # print(video.shape)

        # num_frames = video.shape[0]  # Liczba klatek w wideo
        # frames = list(video.numpy())  # Konwersja na listę, aby można było dodać padding



        # # Dodanie paddingu, jeśli mniej niż max_frames
        # if num_frames < self.max_frames:
        #     pad = [np.zeros_like(frames[0])] * (self.max_frames - num_frames)
        #     frames.extend(pad)
        #     mask = [1] * num_frames + [0] * (self.max_frames - num_frames)
        # else:
        #     frames = frames[:self.max_frames]  # Odcinamy nadmiar
        #     mask = [1] * self.max_frames

        # # Konwersja do tensora PyTorch
        # frames_tensor = torch.tensor(frames, dtype=torch.float32).permute(3, 0, 1, 2)  # (C, T, H, W)
        # mask_tensor = torch.tensor(mask, dtype=torch.float32)  # (T,)

        # if self.transform:
        #     frames_tensor = self.transform(frames_tensor)

        # return frames_tensor, label, mask_tensor



In [96]:
# Ścieżka do folderu z danymi
root_dir = "dataset"

# Tworzenie instancji datasetu dla treningu
train_dataset = AdDetectionDataset(root=root_dir, split="train")



#Sprawdzenie
print(len(train_dataset))

# Uzyskiwanie jednego elementu
video, label= train_dataset[0]
print(f"Label: {train_dataset.map_idx_to_label(label)}")


20
Label: content


In [97]:
batch_size = 4  
num_workers = 0  # Liczba wątków do wczytywania danych - dla Windows 0 bo jest problem z wielowatkowscia wynikajaca z działania multiprocessingu

train_loader = DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=num_workers
)

#Załadowanie modelu
# MODEL_NAME = "vit_base_patch16_224"
# model = create_model(MODEL_NAME, pretrained=True, num_classes=400).to(device)
# model.eval()

# processor = AutoImageProcessor.from_pretrained("facebook/timesformer-base-finetuned-k600")
# model = TimesformerForVideoClassification.from_pretrained("facebook/timesformer-base-finetuned-k600")

In [98]:
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=torch.bfloat16  
# )

device_map="auto"   #"sequential"
# model = TimesformerForVideoClassification.from_pretrained(
#     "facebook/timesformer-base-finetuned-k600"
# )
model = TimesformerForVideoClassification.from_pretrained("facebook/timesformer-base-finetuned-k400", force_download=True)

model.to(device)

OSError: Unable to load weights from pytorch checkpoint file for 'C:\Users\User\.cache\huggingface\hub\models--facebook--timesformer-base-finetuned-k400\snapshots\8aaf40ea7d3d282dcb0a5dea01a198320d15d6c0\pytorch_model.bin' at 'C:\Users\User\.cache\huggingface\hub\models--facebook--timesformer-base-finetuned-k400\snapshots\8aaf40ea7d3d282dcb0a5dea01a198320d15d6c0\pytorch_model.bin'. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True.

In [20]:
for batch_videos, batch_labels, batch_masks in train_loader:
    batch_videos, batch_labels, batch_masks = batch_videos.to(device), batch_labels.to(device), batch_masks.to(device)
    print(f"Batch videos shape: {batch_videos.shape}")  # (B, C, T, H, W)
    print(f"Batch labels: {batch_labels}")
    print(f"Batch masks shape: {batch_masks.shape}")  # (B, T)
    break

Batch videos shape: torch.Size([4, 3, 100, 360, 640])
Batch labels: tensor([1, 1, 0, 1])
Batch masks shape: torch.Size([4, 100])


ValueError: too many values to unpack (expected 4)