In [None]:
# Install/upgrade the aifactory client; it is imported later as `aifactory.score` for submission.
# NOTE(review): the version is unpinned, so a fresh run may pick up a different release.
!pip install -U aifactory

In [None]:
# Runtime dependencies for inference. torch/torchvision are pulled from the
# PyTorch wheel index for CUDA 12.6 (the `cu126` index URL below).
!pip install -U torch==2.7.1 torchvision==0.22.1 --index-url https://download.pytorch.org/whl/cu126
!pip install -U opencv-python==4.10.0.82 numpy==1.26.4 scikit-learn==1.3.2 scipy==1.11.4
# NOTE(review): dlib and timm are unpinned, unlike the packages above —
# consider pinning them so a re-run installs the same versions.
!pip install -U dlib
!pip install -U timm

In [None]:
from pathlib import Path
import multiprocessing
import csv

import cv2
import dlib
import numpy as np
from PIL import Image

import torch
from torchvision.transforms import v2
from models import DinoDiscriminator2

# ---------------------------------------------------------------------------
# Model / preprocessing configuration
# ---------------------------------------------------------------------------
checkpoint = Path("./checkpoints/dino2/attempt1/epoch-004.pt")
img_size = 518    # side length used by both Resize and CenterCrop below
threshold = 0.4   # mean sigmoid probability above this value yields label 1

# Preprocessing applied to every cropped face before it is fed to the model.
_preprocessing_steps = [
    v2.Resize(img_size, v2.InterpolationMode.BICUBIC),
    v2.CenterCrop(img_size),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = v2.Compose(_preprocessing_steps)

# ---------------------------------------------------------------------------
# Input discovery / output location
# ---------------------------------------------------------------------------
test_dataset_path = Path("./data")
# Regular files directly inside the dataset directory, in sorted order.
files = sorted(entry for entry in test_dataset_path.iterdir() if entry.is_file())
print(files)
output_csv_path = Path("submission.csv")
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}
VIDEO_EXTS = {".avi", ".mp4"}

# Leave one core free, but always use between 1 and 8 worker processes.
num_workers = max(1, min(multiprocessing.cpu_count() - 1, 8))
print(f"Using {num_workers} worker processes for preprocessing.")

def get_boundingbox(face, width, height):
    """Compute a square crop region around a detected face rectangle.

    The longer side of ``face`` is enlarged by a factor of 1.3, the square is
    centred on the rectangle, its top-left corner is clamped to be
    non-negative, and its side is then shrunk so it does not run past
    ``width`` / ``height``.

    Args:
        face: Object exposing ``left()``, ``top()``, ``right()`` and
            ``bottom()`` accessors (e.g. a dlib rectangle).
        width: Width of the image the rectangle refers to.
        height: Height of the image the rectangle refers to.

    Returns:
        Tuple ``(x1, y1, size_bb)``: top-left corner and side length of the
        square crop.
    """
    left, top, right, bottom = face.left(), face.top(), face.right(), face.bottom()

    # Side of the enlarged square, based on the longer edge of the detection.
    side = int(max(right - left, bottom - top) * 1.3)
    half_side = side // 2

    mid_x = (left + right) // 2
    mid_y = (top + bottom) // 2

    # Top-left corner, clamped so it never goes negative.
    origin_x = max(int(mid_x - half_side), 0)
    origin_y = max(int(mid_y - half_side), 0)

    # Shrink the square if it would extend past the right or bottom edge.
    side = min(height - origin_y, min(width - origin_x, side))
    return origin_x, origin_y, side

# Per-process cache for the dlib detector; built lazily on first use so each
# worker process constructs it once instead of once per image/frame.
_FACE_DETECTOR = None


def _get_face_detector():
    """Return this process's dlib frontal face detector, creating it on first use."""
    global _FACE_DETECTOR
    if _FACE_DETECTOR is None:
        _FACE_DETECTOR = dlib.get_frontal_face_detector()
    return _FACE_DETECTOR


def detect_and_crop_face_optimized(image: Image.Image, resize_for_detection=640):
    """Detect the largest face in ``image`` and return it as a square crop.

    Detection runs on a copy downscaled to ``resize_for_detection`` pixels wide
    (only when the image is wider than that); the detected rectangle is then
    mapped back to original coordinates so the crop is taken at full
    resolution.

    Args:
        image: Input PIL image; converted to RGB if it is in another mode.
        resize_for_detection: Maximum width, in pixels, of the image handed to
            the face detector.

    Returns:
        A PIL image containing the cropped face, or ``None`` when no face is
        detected or the resulting crop region is empty.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    original_np = np.array(image)
    original_h, original_w, _ = original_np.shape

    if original_w > resize_for_detection:
        # Downscale (preserving aspect ratio) purely to speed up detection.
        scale = resize_for_detection / float(original_w)
        resized_h = int(original_h * scale)
        resized_np = cv2.resize(
            original_np, (resize_for_detection, resized_h), interpolation=cv2.INTER_AREA
        )
    else:
        scale = 1.0
        resized_np = original_np

    # Reuse the cached detector; constructing one per call is wasted work,
    # especially when sampling many frames from a video.
    faces = _get_face_detector()(resized_np, 1)
    if not faces:
        return None

    # Keep only the detection with the largest area.
    face = max(faces, key=lambda rect: rect.width() * rect.height())

    # Map the rectangle from detection-image coordinates back to the original.
    scaled_face_rect = dlib.rectangle(
        left=int(face.left() / scale), top=int(face.top() / scale),
        right=int(face.right() / scale), bottom=int(face.bottom() / scale)
    )
    x, y, size = get_boundingbox(scaled_face_rect, original_w, original_h)
    if size <= 0:
        # A degenerate box would yield an empty array that cannot be turned
        # into an image; treat it the same as "no face found".
        return None

    cropped_np = original_np[y:y + size, x:x + size]
    return Image.fromarray(cropped_np)

def process_single_file(file_path):
    """Extract cropped face images from a single image or video file.

    Images (extensions in ``IMAGE_EXTS``) contribute at most one face crop.
    Videos (extensions in ``VIDEO_EXTS``) are sampled at up to 30 evenly spaced
    frame indices, each contributing at most one face crop. Files with any
    other extension yield an empty list.

    Args:
        file_path: ``pathlib.Path`` of the file to process.

    Returns:
        Tuple ``(filename, face_images, error)`` where ``face_images`` is a
        list of PIL images and ``error`` is ``None`` on success. If any
        exception is raised, ``face_images`` is ``[]`` and ``error`` holds the
        exception message.
    """
    print(f"processing {file_path.name}")

    face_images = []
    ext = file_path.suffix.lower()
    num_frames_to_extract = 30

    try:
        if ext in IMAGE_EXTS:
            # The context manager closes the underlying file handle; the crop
            # returned by the detector is built from a separate array.
            with Image.open(file_path) as image:
                face_img = detect_and_crop_face_optimized(image)
            if face_img is not None:
                face_images.append(face_img)

        elif ext in VIDEO_EXTS:
            cap = cv2.VideoCapture(str(file_path))
            try:
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                if total_frames > 0:
                    # Evenly spaced frame indices across the whole clip.
                    frame_indices = np.linspace(
                        0, total_frames - 1, num_frames_to_extract, dtype=int
                    )
                    for idx in frame_indices:
                        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                        ret, frame = cap.read()
                        if not ret:
                            continue
                        # OpenCV decodes to BGR; the detector path expects RGB.
                        image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                        face_img = detect_and_crop_face_optimized(image)
                        if face_img is not None:
                            face_images.append(face_img)
            finally:
                # Always free the capture, even if a frame fails mid-way.
                cap.release()
    except Exception as e:
        # Best-effort: report the failure to the caller instead of crashing
        # the worker process.
        return file_path.name, [], str(e)

    return file_path.name, face_images, None

if __name__ == "__main__":
    # --- Model setup -------------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device - {device}")

    model = DinoDiscriminator2(pretrained=False).to(device)
    print(f"Loading weights from {checkpoint}...")
    model.load(checkpoint, map_location=device)
    model.eval()
    print("Model successfully loaded.")

    # --- Inference ---------------------------------------------------------
    # Face extraction runs in worker processes; model inference runs here as
    # each file's crops arrive (in completion order, not input order).
    results_to_write = {}
    with multiprocessing.Pool(processes=num_workers) as pool:
        extraction_results = pool.imap_unordered(process_single_file, files)
        for filename, face_images, error in extraction_results:
            if error:
                print(f"Error processing {filename}: {error}")

            if face_images:
                batch = torch.stack([transform(face) for face in face_images]).to(device)

                with torch.no_grad():
                    logits = model(batch)  # noted as [N, 1] by the original author

                # Average the per-crop probabilities into a single score for the file.
                mean_prob = torch.sigmoid(logits.view(-1)).mean()
                preds = (mean_prob > threshold).int().item()
                results_to_write[filename] = preds

                print(f"{filename} - {preds}")
            else:
                # No usable face crops (or an error): fall back to label 0.
                results_to_write[filename] = 0

    # --- Output --------------------------------------------------------------
    print("Writing results to CSV...")
    with open(output_csv_path, mode="w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["filename", "label"])
        # Rows follow the order of `files`; any missing entry defaults to 0.
        writer.writerows(
            [path.name, results_to_write.get(path.name, 0)] for path in files
        )

    print("Inference completed.")

  from .autonotebook import tqdm as notebook_tqdm


[WindowsPath('data/sample_image_1.png'), WindowsPath('data/sample_image_2.png'), WindowsPath('data/sample_image_3.png'), WindowsPath('data/sample_image_4.png'), WindowsPath('data/sample_image_5.png'), WindowsPath('data/sample_image_6.png'), WindowsPath('data/sample_image_7.png')]
Using 8 worker processes for preprocessing.
Using device - cuda
Loading weights from checkpoints\dino\attempt1\epoch-001.pt...
Model successfully loaded.


In [4]:
import os
import time

import aifactory.score as aif

# The task key is a credential and must not be stored in the notebook.
# It is read from the AIFACTORY_TASK_KEY environment variable (a name chosen
# for this notebook). NOTE(review): the key that was previously committed here
# should be considered exposed and rotated.
task_key = os.environ.get("AIFACTORY_TASK_KEY")
if not task_key:
    raise RuntimeError(
        "Set the AIFACTORY_TASK_KEY environment variable before submitting."
    )

# perf_counter is monotonic, so the elapsed time is not affected by clock changes.
t = time.perf_counter()

aif.submit(
    model_name="dino2_eval-4_threshold-0.4",
    key=task_key,
)

# Elapsed submission time in seconds.
print(time.perf_counter() - t)

file : task
jupyter notebook
제출 완료
28.774463176727295
