## 🔵 STEP 1 — 환경 설정 및 라이브러리

In [None]:
!pip install opencv-python-headless
!pip install scikit-learn
!pip install transformers

In [None]:
import os
from glob import glob

import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from transformers import TimesformerConfig, TimesformerModel
from sklearn.metrics import f1_score

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Clip geometry and task size.
NUM_FRAMES = 16      # frames taken from each clip
IMG_SIZE = 224       # frames are resized to IMG_SIZE x IMG_SIZE
NUM_CLASSES = 2      # real / fake

# Optimisation hyper-parameters.
BATCH_SIZE = 2
EPOCHS = 15
LR = 1e-4

# Dataset locations and the checkpoint destination.
TRAIN_DIR = "/workspace/data/train"
VAL_DIR = "/workspace/data/val"
SAVE_PATH = "/workspace/model/best_model.pth"

# Echo the run configuration so it is visible in the notebook output.
print("Device:", device)
print("Train dir:", TRAIN_DIR)
print("Val dir:", VAL_DIR)

## 🔵 STEP 2 — 전처리 정의

In [None]:
# Shared tail of both pipelines: PIL image -> float tensor -> channel-wise
# normalisation with the mean/std values below.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225],
    ),
]

# Training-time augmentations, listed in the order they are applied.
_train_augmentations = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    # Mild photometric jitter, applied to 38% of the calls.
    transforms.RandomApply(
        [
            transforms.ColorJitter(
                brightness=0.2,
                contrast=0.2,
                saturation=0.2,
                hue=0.02,
            )
        ],
        p=0.38,
    ),
    # Occasional light blur.
    transforms.RandomApply(
        [transforms.GaussianBlur(3, sigma=(0.1, 1.0))],
        p=0.15,
    ),
    # NOTE(review): RandomAdjustSharpness also takes its own `p` (left at the
    # library default here), so sharpening is gated twice -- confirm the
    # intended rate.
    transforms.RandomApply(
        [transforms.RandomAdjustSharpness(1.3)],
        p=0.16,
    ),
    transforms.RandomGrayscale(p=0.05),
    # Slight zoom-in crop, resized back to IMG_SIZE.
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.9, 1.0)),
]

train_transform = transforms.Compose(
    _train_augmentations + _to_normalized_tensor
)

# Validation is deterministic: resize, convert, normalise.
val_transform = transforms.Compose(
    [transforms.Resize((IMG_SIZE, IMG_SIZE))] + _to_normalized_tensor
)

## 🔵 STEP 3 — Dataset & DataLoader

In [None]:
class VideoDataset(Dataset):
    """Clip-level dataset over folders of pre-extracted JPEG frames.

    Expected layout: ``root/<class>/<clip_dir>/*.jpg`` where ``<class>`` is
    ``real`` (label 0) or ``fake`` (label 1).  Every clip directory is one
    sample; a class folder that does not exist is skipped.

    Args:
        root: Directory containing the ``real`` and/or ``fake`` folders.
        num_frames: Number of frames returned for every clip.
        transform: Callable applied to each PIL RGB frame.  It must return a
            tensor, because the per-frame results are stacked.
    """

    def __init__(self, root, num_frames, transform):
        self.num_frames = num_frames
        self.transform = transform
        self.files = []
        self.class_to_idx = {"real": 0, "fake": 1}

        for cls in ["real", "fake"]:
            cls_dir = os.path.join(root, cls)
            if not os.path.isdir(cls_dir):
                continue

            # os.listdir returns entries in arbitrary order; sort so that the
            # sample index -> clip mapping is reproducible across runs.
            for vid_dir in sorted(os.listdir(cls_dir)):
                full = os.path.join(cls_dir, vid_dir)
                if os.path.isdir(full):
                    self.files.append((full, self.class_to_idx[cls]))

        print(f"[VideoDataset] root={root}, samples={len(self.files)}")

    def __len__(self):
        """Return the number of clip directories found."""
        return len(self.files)

    def _transform_frame(self, img, rng_state):
        """Apply ``self.transform`` to one frame from a fixed RNG state.

        Rewinding torch's global CPU generator before every call makes
        transforms that draw from it (torchvision's random transforms do)
        pick the same flip / crop / jitter parameters for every frame of a
        clip, instead of re-sampling them frame by frame.
        """
        torch.set_rng_state(rng_state)
        return self.transform(img)

    def load_frames(self, folder):
        """Load exactly ``num_frames`` transformed frames from ``folder``.

        The ``*.jpg`` files are taken in sorted filename order.  Clips that
        are too short are padded by repeating their last frame, clips that
        are too long are truncated to their first ``num_frames`` frames, and
        an empty folder yields all-black frames so training does not crash.

        Returns:
            Tensor of stacked per-frame transform outputs, ``[T, 3, H, W]``
            for the transforms used in this notebook.
        """
        jpgs = sorted(glob(os.path.join(folder, "*.jpg")))

        # One RNG snapshot per clip: every frame is transformed from this same
        # state, which keeps the augmentation temporally consistent.
        rng_state = torch.get_rng_state()

        if not jpgs:
            # Empty clip: fall back to black frames (kept to avoid errors).
            dummy = Image.new("RGB", (IMG_SIZE, IMG_SIZE))
            return torch.stack([
                self._transform_frame(dummy, rng_state)
                for _ in range(self.num_frames)
            ])

        # Too few frames: pad by repeating the last one.
        if len(jpgs) < self.num_frames:
            jpgs = jpgs + [jpgs[-1]] * (self.num_frames - len(jpgs))

        # Too many frames: keep only the first ``num_frames`` (could be
        # replaced by sliding-window or random sampling if needed).
        jpgs = jpgs[:self.num_frames]

        frames = []
        for path in jpgs:
            # The context manager closes the file handle right away;
            # convert() returns an independent, fully loaded image.
            with Image.open(path) as raw:
                img = raw.convert("RGB")
            frames.append(self._transform_frame(img, rng_state))

        # [T, 3, H, W]
        return torch.stack(frames)

    def __getitem__(self, idx):
        """Return ``(frames, label)`` for the clip at position ``idx``."""
        folder, label = self.files[idx]
        frames = self.load_frames(folder)      # [T, 3, 224, 224]
        label = torch.tensor(label, dtype=torch.long)
        return frames, label


# ==========================
# Dataloader
# ==========================
# One dataset per split; only the training split uses the augmenting pipeline.
train_dataset = VideoDataset(
    root=TRAIN_DIR,
    num_frames=NUM_FRAMES,
    transform=train_transform,
)
val_dataset = VideoDataset(
    root=VAL_DIR,
    num_frames=NUM_FRAMES,
    transform=val_transform,
)

# Options common to both loaders; only the shuffling differs per split.
_loader_options = dict(
    batch_size=BATCH_SIZE,
    num_workers=8,
    pin_memory=True,
)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

## 🔵 STEP 4 — TimeSformer 모델 정의

In [None]:
class TimeSformerWrapper(nn.Module):
    """TimeSformer backbone followed by a small classification head.

    The backbone is built from a ``TimesformerConfig`` (no pretrained weights
    are loaded here).  Its token embeddings are mean-pooled and passed through
    LayerNorm -> Dropout -> Linear to produce class logits.

    Args:
        num_classes: Number of output classes.
        num_frames: Frames per clip the backbone is configured for.
        img_size: Side length of the square input frames.
    """

    def __init__(self, num_classes=2, num_frames=16, img_size=224):
        super().__init__()

        config = TimesformerConfig(
            num_frames=num_frames,
            num_labels=num_classes,
            image_size=img_size,

            patch_size=16,
            attention_type="divided_space_time",

            num_hidden_layers=8,
            hidden_size=768,
            num_attention_heads=12,
            intermediate_size=3072,

            # TimesformerConfig names its dropout rates `hidden_dropout_prob`
            # and `attention_probs_dropout_prob`.  Passing `dropout=` /
            # `attention_dropout=` is swallowed by **kwargs and leaves the
            # backbone with no dropout at all.
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
        )

        self.backbone = TimesformerModel(config)
        self.cls_head = nn.Sequential(
            nn.LayerNorm(config.hidden_size),
            nn.Dropout(0.2),
            nn.Linear(config.hidden_size, num_classes),
        )

        # Xavier-initialise the linear layer of the head, with a zero bias.
        for m in self.cls_head:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return class logits for a batch of clips.

        Args:
            x: Clip tensor of shape ``[B, T, 3, H, W]``.

        Returns:
            Logits of shape ``[B, num_classes]``.
        """
        out = self.backbone(pixel_values=x)
        # [B, num_tokens, hidden] -> mean pooling over the token axis
        pooled = out.last_hidden_state.mean(dim=1)
        return self.cls_head(pooled)


# Build the classifier from the notebook-level settings, then move it to the
# selected device.
model = TimeSformerWrapper(NUM_CLASSES, NUM_FRAMES, IMG_SIZE)
model = model.to(device)

print(type(model).__name__, "initialized.")

## 🔵 STEP 5 — Optimizer / Scheduler / Loss

In [None]:
# Cross-entropy with a small amount of label smoothing.
criterion = nn.CrossEntropyLoss(label_smoothing=0.02)

# AdamW over every model parameter, with a light weight decay.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=LR,
    weight_decay=5e-5,
)

## 🔵 STEP 6 — Train / Validation / Loop

In [None]:
def train_one_epoch():
    """Run one optimisation pass over ``train_loader``.

    Operates on the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``device`` and ``train_loader``.

    Returns:
        float: Training loss averaged per sample.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    num_samples = 0

    for frames, labels in tqdm(train_loader, desc="Training"):
        # frames: [B, T, 3, 224, 224]
        frames = frames.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(frames)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Assumes `criterion` returns the batch mean (the CrossEntropyLoss
        # default).  Weighting by batch size gives a true per-sample average
        # even when the final batch is smaller than the others.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        num_samples += batch_size

    if num_samples == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError(
            "train_loader yielded no samples; check that the training "
            "directory contains real/ and fake/ clip folders"
        )

    return loss_sum / num_samples


def validate():
    """Evaluate the model on ``val_loader`` without tracking gradients.

    Operates on the notebook-level ``model``, ``criterion``, ``device`` and
    ``val_loader``.

    Returns:
        tuple: ``(loss, accuracy, macro_f1)`` where ``loss`` is averaged per
        sample and ``accuracy`` is a fraction in ``[0, 1]``.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    preds, gts = [], []

    with torch.no_grad():
        for frames, labels in tqdm(val_loader, desc="Validating"):
            frames = frames.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(frames)
            loss = criterion(outputs, labels)
            # Assumes `criterion` returns the batch mean (the CrossEntropyLoss
            # default); weight by batch size so a short final batch is not
            # over-represented in the reported loss.
            loss_sum += loss.item() * labels.size(0)

            # Tensor.tolist() copies to host itself; no .cpu().numpy() hop.
            preds.extend(outputs.argmax(1).tolist())
            gts.extend(labels.tolist())

    if not gts:
        # Without this guard an empty split gives NaN accuracy and then an
        # opaque ZeroDivisionError.
        raise ValueError(
            "val_loader yielded no samples; check that the validation "
            "directory contains real/ and fake/ clip folders"
        )

    acc = np.mean(np.array(preds) == np.array(gts))
    f1 = f1_score(gts, preds, average="macro")
    return loss_sum / len(gts), acc, f1


# ==========================
# Train Loop
# ==========================
# Highest validation macro-F1 seen so far; a checkpoint is written only when
# an epoch strictly beats it.
best_f1 = 0.0

for epoch in range(EPOCHS):
    epoch_no = epoch + 1
    print(f"\n===== EPOCH {epoch_no}/{EPOCHS} =====")

    train_loss = train_one_epoch()
    val_loss, val_acc, val_f1 = validate()

    summary = (
        f"[Epoch {epoch_no:02d}] "
        f"Train Loss: {train_loss:.4f} | "
        f"Val Loss: {val_loss:.4f} | "
        f"Val Acc: {val_acc*100:.2f}% | "
        f"Val F1: {val_f1:.4f}"
    )
    print(summary)

    # No improvement this epoch: keep the previous best checkpoint.
    if not (val_f1 > best_f1):
        continue

    best_f1 = val_f1

    # Store the weights together with the settings needed to rebuild the model.
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "num_frames": NUM_FRAMES,
        "img_size": IMG_SIZE,
        "num_classes": NUM_CLASSES,
    }
    os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)
    torch.save(checkpoint, SAVE_PATH)
    print(f"üî• Best Model Saved: {SAVE_PATH} (F1={best_f1:.4f})")

print("\nTraining finished.")
print(f"Best F1: {best_f1:.4f}")