In [8]:
# ==============================
# 0. 기본 설정 & 라이브러리
# ==============================
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image

import os
import glob
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import copy

from torchvision.models import (
    googlenet, GoogLeNet_Weights,
    resnet50, ResNet50_Weights,
    vgg16, VGG16_Weights,
    vit_b_16, ViT_B_16_Weights
)

# Colab-only setup: mount Google Drive at /content/drive so that the dataset
# directories configured under BASE_DIR are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# ==============================
# 1. 설정값 & 경로
# ==============================
# --- Paths -------------------------------------------------------------
# Dataset root on the mounted drive; real and fake face images live in
# sibling sub-directories. The best checkpoint is written to MODEL_SAVE_PATH.
BASE_DIR = "/content/drive/MyDrive/기학기"
REAL_PATH = f"{BASE_DIR}/face_real"
FAKE_PATH = f"{BASE_DIR}/face_fake"
MODEL_SAVE_PATH = "/content/best_googlenet_fft.pth"

# --- Data --------------------------------------------------------------
IMG_SIZE = 224           # side length of the square network input
NUM_SAMPLES = 10_000     # upper bound on images sampled from the full set
BATCH_SIZE = 64

# --- Optimisation ------------------------------------------------------
EPOCHS = 30              # maximum number of epochs; adjust as needed
LEARNING_RATE = 0.0001
PATIENCE = 3             # early-stopping patience (epochs without improvement)


# ==============================
# 2. ResizeWithPad
# ==============================
class ResizeWithPad:
    """Resize so the longer side equals ``target_size``, then zero-pad to a square.

    The aspect ratio is preserved; the resized image is centred on a black
    ``target_size x target_size`` RGB canvas (PIL image in, PIL image out).
    """

    def __init__(self, target_size):
        # Side length, in pixels, of the square output image.
        self.target_size = target_size

    def __call__(self, img):
        """Return a new square RGB image containing the letterboxed ``img``."""
        w, h = img.size
        scale = self.target_size / max(w, h)

        # Round instead of truncating so floating-point error cannot leave the
        # long side one pixel short, and clamp to at least 1 pixel so that a
        # very elongated image does not produce a zero-sized resize target.
        new_w = min(self.target_size, max(1, round(w * scale)))
        new_h = min(self.target_size, max(1, round(h * scale)))

        resized_img = img.resize((new_w, new_h), Image.LANCZOS)
        canvas = Image.new("RGB", (self.target_size, self.target_size), (0, 0, 0))

        # Centre the resized image; any odd leftover pixel goes to the
        # right/bottom edge because of the floor division.
        pad_x = (self.target_size - new_w) // 2
        pad_y = (self.target_size - new_h) // 2
        canvas.paste(resized_img, (pad_x, pad_y))

        return canvas


# ==============================
# 3. FFT Magnitude Transform
# ==============================
class FFTMag:
    """Convert an image tensor into its standardised log-magnitude spectrum.

    For every channel independently: 2-D FFT over the last two (spatial)
    dimensions -> ``fftshift`` so the zero-frequency term sits at the centre
    -> magnitude -> ``log1p``. The whole result is then standardised to zero
    mean and unit standard deviation using statistics over all elements.

    The output has the same shape as the input (e.g. ``[C, H, W]``).
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Restrict both the transform and the shift to the spatial axes.
        # ``fftshift`` without ``dim`` shifts *every* dimension, which would
        # also roll the channel axis and break the per-channel contract.
        spatial_dims = (-2, -1)

        x_fft = torch.fft.fft2(x, dim=spatial_dims)                  # complex tensor
        x_fft_shift = torch.fft.fftshift(x_fft, dim=spatial_dims)    # low freq -> centre
        mag = torch.log1p(torch.abs(x_fft_shift))                    # log(1 + |F|)

        # Standardise over all channels jointly; the epsilon guards against a
        # zero standard deviation (e.g. an all-zero input).
        mean = mag.mean()
        std = mag.std()
        return (mag - mean) / (std + 1e-8)


# ==============================
# 4. Dataset
# ==============================
class DeepfakeDataset(Dataset):
    """Dataset that lazily loads images from disk and pairs them with labels.

    ``paths`` and ``labels`` are parallel sequences; ``transform`` is an
    optional callable applied to the RGB PIL image after loading.
    """

    def __init__(self, paths, labels, transform=None):
        self.paths, self.labels, self.transform = paths, labels, transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` where ``label`` is a float32 tensor with a leading axis of 1."""
        image = Image.open(self.paths[idx]).convert("RGB")
        if self.transform:
            # In this pipeline the transform already includes the FFT step.
            image = self.transform(image)
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        # Add a leading axis so a scalar label becomes shape (1,).
        return image, target[None]


# ==============================
# 5. Fine-Tuning 모델
# ==============================
def get_model_finetune(name, device, mode="full"):
    """Build a pretrained backbone with a single-logit head for binary classification.

    Args:
        name: one of 'googlenet', 'resnet50', 'vgg16', 'vit'.
        device: device the returned model is moved to.
        mode: 'head' trains only the replaced classifier layer; any other
            value trains every parameter.

    Returns:
        The model, moved to ``device``.

    Raises:
        ValueError: if ``name`` is not one of the supported architectures.
    """
    print(f"\n=== Loading {name} (pretrained, mode={mode}, use_fft=True) ===")

    if name in ("googlenet", "resnet50"):
        # Both architectures expose their classifier as ``model.fc``.
        if name == "googlenet":
            # aux_logits is deliberately left at its default; the original
            # author notes that passing it explicitly led to a ValueError.
            model = googlenet(weights=GoogLeNet_Weights.DEFAULT)
        else:
            model = resnet50(weights=ResNet50_Weights.DEFAULT)
        model.fc = nn.Linear(model.fc.in_features, 1)
        head = model.fc
    elif name == "vgg16":
        model = vgg16(weights=VGG16_Weights.DEFAULT)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, 1)
        head = model.classifier[6]
    elif name == "vit":
        model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        model.heads.head = nn.Linear(model.heads.head.in_features, 1)
        head = model.heads.head
    else:
        raise ValueError("지원하지 않는 모델 이름입니다.")

    # Freeze policy: in 'head' mode everything except the new layer is frozen.
    head_only = mode == "head"
    if head_only:
        print("→ Classifier만 학습 (Feature extractor는 Freeze)")
    else:
        print("→ 전체 파라미터 Fine-Tuning")

    for param in model.parameters():
        param.requires_grad = not head_only
    if head_only:
        for param in head.parameters():
            param.requires_grad = True

    return model.to(device)


# ==============================
# 6. Train
# ==============================
def train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs, patience,
                save_path=None):
    """Train ``model`` with early stopping on validation loss.

    Each epoch runs one optimisation pass over ``train_loader`` followed by a
    gradient-free pass over ``val_loader``. Whenever the validation loss
    improves, the weights are snapshotted in memory and saved to disk.
    Training stops once ``patience`` consecutive epochs bring no improvement.

    Args:
        model: network producing one logit per sample (or an output object /
            tuple whose primary logits can be unwrapped).
        train_loader: iterable of ``(inputs, targets)`` training batches.
        val_loader: iterable of ``(inputs, targets)`` validation batches.
        criterion: loss taking ``(logits, targets)``.
        optimizer: optimizer stepping the trainable parameters of ``model``.
        device: device the batches are moved to.
        epochs: maximum number of epochs.
        patience: number of non-improving epochs tolerated before stopping.
        save_path: checkpoint destination; defaults to ``MODEL_SAVE_PATH``.

    Returns:
        ``model`` with the best-validation weights loaded (unchanged weights
        if no epoch was run).
    """
    best_val_loss = float("inf")
    best_model = None
    no_improve = 0

    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device,
            desc=f"Train Epoch {epoch+1}/{epochs}",
            optimizer=optimizer, training=True,
        )
        val_loss, val_acc = _run_epoch(
            model, val_loader, criterion, device,
            desc=f"Val Epoch {epoch+1}/{epochs}",
            optimizer=None, training=False,
        )

        print(f"[Epoch {epoch+1}] Train Loss {train_loss:.4f} Acc {train_acc:.4f} | "
              f"Val Loss {val_loss:.4f} Acc {val_acc:.4f}")

        # Early stopping on validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Deep copy: state_dict() returns references to the live tensors,
            # which later optimisation steps would otherwise overwrite.
            best_model = copy.deepcopy(model.state_dict())
            no_improve = 0
            print("→ Best model 갱신")
            torch.save(best_model, MODEL_SAVE_PATH if save_path is None else save_path)
        else:
            no_improve += 1
            print(f"→ 개선 없음 ({no_improve}/{patience})")
            if no_improve >= patience:
                print("EARLY STOPPING!")
                break

    if best_model is not None:
        model.load_state_dict(best_model)
    return model


def _unwrap_logits(out):
    """Return the primary logits tensor from a raw model output.

    Some models (the original author cites GoogLeNet) may return an object
    with a ``logits`` attribute or a tuple/list instead of a plain tensor.
    """
    if hasattr(out, "logits"):
        return out.logits
    if isinstance(out, (tuple, list)):
        return out[0]
    return out


def _run_epoch(model, loader, criterion, device, desc, optimizer, training):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With ``training=True`` the model is put in train mode and ``optimizer``
    is stepped on every batch; otherwise the model is put in eval mode and
    gradients are disabled. Both metrics are averaged over the number of
    samples actually seen, so they stay consistent with each other even if
    the loader drops or resamples items.
    """
    model.train(training)
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for X, y in tqdm(loader, desc=desc):
            X, y = X.to(device), y.to(device)

            if training:
                optimizer.zero_grad()

            out = _unwrap_logits(model(X))
            loss = criterion(out, y)

            if training:
                loss.backward()
                optimizer.step()

            # criterion yields a batch mean; weight by batch size to get a sum.
            loss_sum += loss.item() * X.size(0)
            pred = (torch.sigmoid(out) > 0.5).float()
            n_correct += (pred == y).sum().item()
            n_seen += y.size(0)

    return loss_sum / n_seen, n_correct / n_seen


# ==============================
# 7. Evaluate
# ==============================
def evaluate_model(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` and print the mean loss and accuracy.

    The model is switched to eval mode and run without gradient tracking.
    Predictions are obtained by thresholding ``sigmoid(logits)`` at 0.5.
    Nothing is returned; the two metrics are printed.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, targets in tqdm(loader, desc="TEST"):
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            # Unwrap output objects / tuples down to the primary logits tensor.
            if hasattr(logits, "logits"):
                logits = logits.logits
            elif isinstance(logits, (tuple, list)):
                logits = logits[0]

            batch_loss = criterion(logits, targets)
            loss_sum += batch_loss.item() * inputs.size(0)

            hits = (torch.sigmoid(logits) > 0.5).float() == targets
            n_correct += hits.sum().item()
            n_seen += targets.size(0)

    mean_loss = loss_sum / len(loader.dataset)
    print(f"\nTest Loss: {mean_loss:.4f}")
    accuracy_pct = n_correct / n_seen * 100
    print(f"Test Accuracy: {accuracy_pct:.2f}%")


# ==============================
# 8. MAIN
# ==============================
def main():
    """Run the full pipeline: collect images, split, train, and evaluate.

    Steps:
      1. List real/fake image paths and label them 0 (real) / 1 (fake).
      2. Draw a stratified sample of at most ``NUM_SAMPLES`` images.
      3. Split the sample 70 / 20 / 10 into train / validation / test.
      4. Build FFT-magnitude transforms and data loaders.
      5. Fine-tune a pretrained ResNet-50 with early stopping.
      6. Report loss and accuracy on the test split.

    Returns early (without training) if either image folder is empty.
    """
    print("===== START =====")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("DEVICE:", device)

    # --- Collect image paths ---
    # glob returns entries in arbitrary order; sorting makes the seeded
    # (random_state=42) sampling and splits below reproducible across runs.
    real_paths = sorted(glob.glob(os.path.join(REAL_PATH, "*")))
    fake_paths = sorted(glob.glob(os.path.join(FAKE_PATH, "*")))

    print(f"Real: {len(real_paths)}, Fake: {len(fake_paths)}")

    if not real_paths or not fake_paths:
        print("❌ Real 또는 Fake 폴더에 이미지가 없습니다. 경로 다시 확인하세요.")
        return

    all_paths = real_paths + fake_paths
    all_labels = [0] * len(real_paths) + [1] * len(fake_paths)

    # --- Stratified sub-sampling down to NUM_SAMPLES (via train_size) ---
    use_num = min(NUM_SAMPLES, len(all_paths))
    print(f"\n[샘플링] 최대 {NUM_SAMPLES}개 중 {use_num}개 사용")

    if use_num == len(all_paths):
        sample_paths = all_paths
        sample_labels = all_labels
        print("→ 전체 데이터 사용")
    else:
        sample_paths, _, sample_labels, _ = train_test_split(
            all_paths, all_labels,
            train_size=use_num,
            stratify=all_labels,
            random_state=42
        )
        print(f"→ 샘플링 완료: {len(sample_paths)}개")

    # --- Train / Val / Test split ---
    # 30% is held out first, then one third of that hold-out becomes the
    # test set, giving 70% / 20% / 10% overall.
    train_paths, temp_paths, train_labels, temp_labels = train_test_split(
        sample_paths, sample_labels, test_size=0.3, stratify=sample_labels, random_state=42
    )
    val_paths, test_paths, val_labels, test_labels = train_test_split(
        temp_paths, temp_labels, test_size=1/3, stratify=temp_labels, random_state=42
    )

    print(f"\nTrain {len(train_paths)}  Val {len(val_paths)}  Test {len(test_paths)}")

    # --- Transforms (FFT included) ---
    train_tf = transforms.Compose([
        ResizeWithPad(IMG_SIZE),     # PIL -> PIL
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),       # PIL -> Tensor
        FFTMag()                     # Tensor -> FFT magnitude tensor
    ])

    # Validation/test use the same preprocessing without random augmentation.
    test_tf = transforms.Compose([
        ResizeWithPad(IMG_SIZE),
        transforms.ToTensor(),
        FFTMag()
    ])

    # --- DataLoaders (images are read from disk on demand) ---
    train_loader = DataLoader(
        DeepfakeDataset(train_paths, train_labels, train_tf),
        batch_size=BATCH_SIZE, shuffle=True,
        num_workers=2, pin_memory=True
    )
    val_loader = DataLoader(
        DeepfakeDataset(val_paths, val_labels, test_tf),
        batch_size=BATCH_SIZE, shuffle=False,
        num_workers=2, pin_memory=True
    )
    test_loader = DataLoader(
        DeepfakeDataset(test_paths, test_labels, test_tf),
        batch_size=BATCH_SIZE, shuffle=False,
        num_workers=2, pin_memory=True
    )

    # --- Model / loss / optimizer ---
    model = get_model_finetune("resnet50", device, mode="full")
    criterion = nn.BCEWithLogitsLoss()
    # Only hand trainable parameters to the optimizer so that a frozen
    # backbone (mode="head") is skipped.
    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LEARNING_RATE
    )

    # --- Training ---
    model = train_model(model, train_loader, val_loader, criterion, optimizer, device, EPOCHS, PATIENCE)

    # --- Evaluation ---
    evaluate_model(model, test_loader, criterion, device)

    print("===== DONE =====")


# Entry point: run the full pipeline when this module is executed directly
# (and not when it is imported).
if __name__ == "__main__":
    main()

Mounted at /content/drive
===== START =====
DEVICE: cuda
Real: 34681, Fake: 33854

[샘플링] 최대 10000개 중 10000개 사용
→ 샘플링 완료: 10000개

Train 7000  Val 2000  Test 1000

=== Loading resnet50 (pretrained, mode=full, use_fft=True) ===
→ 전체 파라미터 Fine-Tuning


Train Epoch 1/30: 100%|██████████| 110/110 [01:53<00:00,  1.03s/it]
Val Epoch 1/30: 100%|██████████| 32/32 [00:28<00:00,  1.12it/s]


[Epoch 1] Train Loss 0.5685 Acc 0.6933 | Val Loss 0.5059 Acc 0.7460
→ Best model 갱신


Train Epoch 2/30: 100%|██████████| 110/110 [01:51<00:00,  1.01s/it]
Val Epoch 2/30: 100%|██████████| 32/32 [00:28<00:00,  1.11it/s]


[Epoch 2] Train Loss 0.4417 Acc 0.7833 | Val Loss 0.4885 Acc 0.7505
→ Best model 갱신


Train Epoch 3/30: 100%|██████████| 110/110 [01:23<00:00,  1.31it/s]
Val Epoch 3/30: 100%|██████████| 32/32 [00:16<00:00,  1.93it/s]


[Epoch 3] Train Loss 0.3475 Acc 0.8394 | Val Loss 0.5227 Acc 0.7550
→ 개선 없음 (1/3)


Train Epoch 4/30: 100%|██████████| 110/110 [01:11<00:00,  1.54it/s]
Val Epoch 4/30: 100%|██████████| 32/32 [00:16<00:00,  1.96it/s]


[Epoch 4] Train Loss 0.2689 Acc 0.8870 | Val Loss 0.5447 Acc 0.7525
→ 개선 없음 (2/3)


Train Epoch 5/30: 100%|██████████| 110/110 [01:13<00:00,  1.50it/s]
Val Epoch 5/30: 100%|██████████| 32/32 [00:16<00:00,  1.91it/s]


[Epoch 5] Train Loss 0.1940 Acc 0.9186 | Val Loss 0.6614 Acc 0.7545
→ 개선 없음 (3/3)
EARLY STOPPING!


TEST: 100%|██████████| 16/16 [00:07<00:00,  2.02it/s]


Test Loss: 0.4829
Test Accuracy: 75.40%
===== DONE =====



