In [12]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image

import os
import glob
import copy
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

from tqdm import tqdm

from torchvision.models import (
    googlenet, GoogLeNet_Weights,
    resnet50, ResNet50_Weights,
    vgg16, VGG16_Weights,
    vit_b_16, ViT_B_16_Weights
)

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

BASE_DIR = "/content/drive/MyDrive/기학기"
REAL_PATH = BASE_DIR + "/face_real"
FAKE_PATH = BASE_DIR + "/face_fake"
MODEL_SAVE_PATH = "/content/best_googlenet_fft.pth"

IMG_SIZE = 224
EPOCHS = 30
BATCH_SIZE = 64
NUM_SAMPLES = 10000
LEARNING_RATE = 1e-4
PATIENCE = 3

# 2. ResizeWithPad
class ResizeWithPad:
    """긴 변 기준으로 리사이즈 후, 짧은 변은 패딩해서 정사각형으로 맞추는 Transform (PIL -> PIL)"""
    def __init__(self, target_size):
        self.target_size = target_size

    def __call__(self, img):
        w, h = img.size
        scale = self.target_size / max(w, h)

        new_w = int(w * scale)
        new_h = int(h * scale)

        resized_img = img.resize((new_w, new_h), Image.LANCZOS)
        canvas = Image.new("RGB", (self.target_size, self.target_size), (0, 0, 0))

        pad_x = (self.target_size - new_w) // 2
        pad_y = (self.target_size - new_h) // 2
        canvas.paste(resized_img, (pad_x, pad_y))

        return canvas


# 3. FFT Magnitude Transform
class FFTMag:

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        x_fft = torch.fft.fft2(x)
        x_fft_shift = torch.fft.fftshift(x_fft)
        mag = torch.abs(x_fft_shift)
        mag = torch.log1p(mag)

        mean = mag.mean()
        std = mag.std()
        mag = (mag - mean) / (std + 1e-8)

        return mag

# 4. Dataset
class DeepfakeDataset(Dataset):
    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img_path = self.paths[idx]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return img, label.unsqueeze(0)


# 5. Fine-Tuning 모델
def get_model_finetune(name, device, mode="full"):
    """
    name: 'googlenet', 'resnet50', 'vgg16', 'vit' 중 하나
    mode: 'head' -> classifier만 학습, 'full' -> 전체 파라미터 학습
    """
    print(f"\n=== Loading {name} (pretrained, mode={mode}, use_fft=True) ===")

    if name == "googlenet":
        weights = GoogLeNet_Weights.DEFAULT
        model = googlenet(weights=weights)
        in_f = model.fc.in_features
        model.fc = nn.Linear(in_f, 1)
        head_params = model.fc.parameters()

    elif name == "resnet50":
        weights = ResNet50_Weights.DEFAULT
        model = resnet50(weights=weights)
        in_f = model.fc.in_features
        model.fc = nn.Linear(in_f, 1)
        head_params = model.fc.parameters()

    elif name == "vgg16":
        weights = VGG16_Weights.DEFAULT
        model = vgg16(weights=weights)
        in_f = model.classifier[6].in_features
        model.classifier[6] = nn.Linear(in_f, 1)
        head_params = model.classifier[6].parameters()

    elif name == "vit":
        weights = ViT_B_16_Weights.DEFAULT
        model = vit_b_16(weights=weights)
        in_f = model.heads.head.in_features
        model.heads.head = nn.Linear(in_f, 1)
        head_params = model.heads.head.parameters()

    else:
        raise ValueError("지원하지 않는 모델 이름입니다.")

    # Freeze 설정
    if mode == "head":
        print("→ Classifier만 학습 (Feature extractor는 Freeze)")
        for p in model.parameters():
            p.requires_grad = False
        for p in head_params:
            p.requires_grad = True
    else:
        print("→ 전체 파라미터 Fine-Tuning")
        for p in model.parameters():
            p.requires_grad = True

    return model.to(device)


# 6. Train (히스토리 저장)
def train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs, patience):

    best_val_loss = float("inf")
    best_model = None
    no_improve = 0

    history = {
        "train_loss": [],
        "val_loss": [],
        "train_acc": [],
        "val_acc": [],
    }

    for epoch in range(epochs):
        model.train()
        total, correct = 0, 0
        running_loss = 0

        pbar = tqdm(train_loader, desc=f"Train Epoch {epoch+1}/{epochs}")
        for X, y in pbar:
            X, y = X.to(device), y.to(device)

            optimizer.zero_grad()
            out = model(X)

            if hasattr(out, "logits"):
                out = out.logits
            elif isinstance(out, (tuple, list)):
                out = out[0]

            loss = criterion(out, y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * X.size(0)
            pred = (torch.sigmoid(out) > 0.5).float()
            correct += (pred == y).sum().item()
            total += y.size(0)

        train_loss = running_loss / len(train_loader.dataset)
        train_acc = correct / total

        # ---------------------- Validation --------------------------
        model.eval()
        total, correct = 0, 0
        running_val = 0

        with torch.no_grad():
            for X, y in tqdm(val_loader, desc=f"Val Epoch {epoch+1}/{epochs}"):
                X, y = X.to(device), y.to(device)
                out = model(X)

                if hasattr(out, "logits"):
                    out = out.logits
                elif isinstance(out, (tuple, list)):
                    out = out[0]

                loss = criterion(out, y)

                running_val += loss.item() * X.size(0)
                pred = (torch.sigmoid(out) > 0.5).float()
                correct += (pred == y).sum().item()
                total += y.size(0)

        val_loss = running_val / len(val_loader.dataset)
        val_acc = correct / total

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)

        print(f"[Epoch {epoch+1}] Train Loss {train_loss:.4f} Acc {train_acc:.4f} | "
              f"Val Loss {val_loss:.4f} Acc {val_acc:.4f}")

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = copy.deepcopy(model.state_dict())
            no_improve = 0
            print("→ Best model 갱신")
            torch.save(best_model, MODEL_SAVE_PATH)
        else:
            no_improve += 1
            print(f"→ 개선 없음 ({no_improve}/{patience})")
            if no_improve >= patience:
                print("EARLY STOPPING!")
                break

    if best_model is not None:
        model.load_state_dict(best_model)
    return model, history


# 7. 학습 곡선 시각화
def plot_history(history, save_prefix="/content/train_history"):
    epochs = range(1, len(history["train_loss"]) + 1)

    # Loss
    plt.figure()
    plt.plot(epochs, history["train_loss"], label="Train Loss")
    plt.plot(epochs, history["val_loss"], label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Train / Val Loss")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_prefix + "_loss.png")
    plt.close()

    # Accuracy
    plt.figure()
    plt.plot(epochs, history["train_acc"], label="Train Acc")
    plt.plot(epochs, history["val_acc"], label="Val Acc")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Train / Val Accuracy")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_prefix + "_acc.png")
    plt.close()

    print(f"▶ 학습 곡선 이미지 저장 완료: {save_prefix}_loss.png, {save_prefix}_acc.png")


# 8. Evaluate + Confusion Matrix
def evaluate_model(model, loader, criterion, device):
    model.eval()
    total, correct = 0, 0
    running_loss = 0

    all_labels = []
    all_preds = []

    with torch.no_grad():
        for X, y in tqdm(loader, desc="TEST"):
            X, y = X.to(device), y.to(device)
            out = model(X)

            if hasattr(out, "logits"):
                out = out.logits
            elif isinstance(out, (tuple, list)):
                out = out[0]

            loss = criterion(out, y)

            running_loss += loss.item() * X.size(0)
            probs = torch.sigmoid(out)
            pred = (probs > 0.5).float()
            correct += (pred == y).sum().item()
            total += y.size(0)

            all_labels.append(y.cpu().numpy())
            all_preds.append(pred.cpu().numpy())

    test_loss = running_loss / len(loader.dataset)
    test_acc = correct / total

    print(f"\nTest Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc*100:.2f}%")

    y_true = np.concatenate(all_labels, axis=0)
    y_pred = np.concatenate(all_preds, axis=0)

    cm = confusion_matrix(y_true, y_pred)
    plt.figure()
    plt.imshow(cm, interpolation="nearest", cmap="Blues")
    plt.title("Confusion Matrix")
    plt.colorbar()
    tick_marks = np.arange(2)
    plt.xticks(tick_marks, ["Real(0)", "Fake(1)"])
    plt.yticks(tick_marks, ["Real(0)", "Fake(1)"])

    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                     ha="center", va="center",
                     color="white" if cm[i, j] > thresh else "black")

    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    plt.tight_layout()
    cm_path = "/content/confusion_matrix.png"
    plt.savefig(cm_path)
    plt.close()
    print(f"▶ Confusion Matrix 이미지 저장 완료: {cm_path}")


# 9. MAIN
def main():
    print("===== START =====")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("DEVICE:", device)

    real_paths = glob.glob(os.path.join(REAL_PATH, "*"))
    fake_paths = glob.glob(os.path.join(FAKE_PATH, "*"))

    print(f"Real: {len(real_paths)}, Fake: {len(fake_paths)}")

    if len(real_paths) == 0 or len(fake_paths) == 0:
        print("❌ Real 또는 Fake 폴더에 이미지가 없습니다. 경로 다시 확인하세요.")
        return

    all_paths = real_paths + fake_paths
    all_labels = [0]*len(real_paths) + [1]*len(fake_paths)

    use_num = min(NUM_SAMPLES, len(all_paths))
    print(f"\n[샘플링] 최대 {NUM_SAMPLES}개 중 {use_num}개 사용")

    if use_num == len(all_paths):
        sample_paths = all_paths
        sample_labels = all_labels
        print("→ 전체 데이터 사용")
    else:
        sample_paths, _, sample_labels, _ = train_test_split(
            all_paths, all_labels,
            train_size=use_num,
            stratify=all_labels,
            random_state=42
        )
        print(f"→ 샘플링 완료: {len(sample_paths)}개")

    train_paths, temp_paths, train_labels, temp_labels = train_test_split(
        sample_paths, sample_labels, test_size=0.3, stratify=sample_labels, random_state=42
    )
    val_paths, test_paths, val_labels, test_labels = train_test_split(
        temp_paths, temp_labels, test_size=1/3, stratify=temp_labels, random_state=42
    )

    print(f"\nTrain {len(train_paths)}  Val {len(val_paths)}  Test {len(test_paths)}")

    train_tf = transforms.Compose([
        ResizeWithPad(IMG_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        FFTMag()
    ])

    test_tf = transforms.Compose([
        ResizeWithPad(IMG_SIZE),
        transforms.ToTensor(),
        FFTMag()
    ])

    train_loader = DataLoader(
        DeepfakeDataset(train_paths, train_labels, train_tf),
        batch_size=BATCH_SIZE, shuffle=True,
        num_workers=2, pin_memory=True
    )
    val_loader = DataLoader(
        DeepfakeDataset(val_paths, val_labels, test_tf),
        batch_size=BATCH_SIZE, shuffle=False,
        num_workers=2, pin_memory=True
    )
    test_loader = DataLoader(
        DeepfakeDataset(test_paths, test_labels, test_tf),
        batch_size=BATCH_SIZE, shuffle=False,
        num_workers=2, pin_memory=True
    )

    # --- 모델/옵티마 정의 ---
    model = get_model_finetune("vit", device, mode="full")
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LEARNING_RATE
    )

    # --- 학습 ---
    model, history = train_model(model, train_loader, val_loader, criterion, optimizer, device, EPOCHS, PATIENCE)

    # --- 학습 곡선 시각화 ---
    plot_history(history, save_prefix="/content/train_history")

    # --- 평가 + Confusion Matrix ---
    evaluate_model(model, test_loader, criterion, device)

    print("===== DONE =====")


# 실행
if __name__ == "__main__":
    main()

Mounted at /content/drive
===== START =====
DEVICE: cuda
Real: 34681, Fake: 33854

[샘플링] 최대 10000개 중 10000개 사용
→ 샘플링 완료: 10000개

Train 7000  Val 2000  Test 1000

=== Loading vit (pretrained, mode=full, use_fft=True) ===
Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth


100%|██████████| 330M/330M [00:05<00:00, 66.7MB/s]


→ 전체 파라미터 Fine-Tuning


Train Epoch 1/30: 100%|██████████| 110/110 [08:19<00:00,  4.54s/it]
Val Epoch 1/30: 100%|██████████| 32/32 [00:23<00:00,  1.37it/s]


[Epoch 1] Train Loss 0.7127 Acc 0.5186 | Val Loss 0.7014 Acc 0.5060
→ Best model 갱신


Train Epoch 2/30: 100%|██████████| 110/110 [04:01<00:00,  2.20s/it]
Val Epoch 2/30: 100%|██████████| 32/32 [00:23<00:00,  1.37it/s]


[Epoch 2] Train Loss 0.6873 Acc 0.5519 | Val Loss 0.6817 Acc 0.5795
→ Best model 갱신


Train Epoch 3/30: 100%|██████████| 110/110 [03:57<00:00,  2.16s/it]
Val Epoch 3/30: 100%|██████████| 32/32 [00:23<00:00,  1.36it/s]


[Epoch 3] Train Loss 0.6657 Acc 0.5931 | Val Loss 0.6566 Acc 0.5980
→ Best model 갱신


Train Epoch 4/30: 100%|██████████| 110/110 [03:59<00:00,  2.17s/it]
Val Epoch 4/30: 100%|██████████| 32/32 [00:23<00:00,  1.39it/s]


[Epoch 4] Train Loss 0.6698 Acc 0.5864 | Val Loss 0.6990 Acc 0.5165
→ 개선 없음 (1/3)


Train Epoch 5/30: 100%|██████████| 110/110 [03:58<00:00,  2.17s/it]
Val Epoch 5/30: 100%|██████████| 32/32 [00:23<00:00,  1.36it/s]


[Epoch 5] Train Loss 0.6735 Acc 0.5749 | Val Loss 0.6893 Acc 0.5120
→ 개선 없음 (2/3)


Train Epoch 6/30: 100%|██████████| 110/110 [03:58<00:00,  2.17s/it]
Val Epoch 6/30: 100%|██████████| 32/32 [00:24<00:00,  1.33it/s]


[Epoch 6] Train Loss 0.6554 Acc 0.6051 | Val Loss 0.7283 Acc 0.5965
→ 개선 없음 (3/3)
EARLY STOPPING!
▶ 학습 곡선 이미지 저장 완료: /content/train_history_loss.png, /content/train_history_acc.png


TEST: 100%|██████████| 16/16 [00:12<00:00,  1.31it/s]



Test Loss: 0.6584
Test Accuracy: 60.60%
▶ Confusion Matrix 이미지 저장 완료: /content/confusion_matrix.png
===== DONE =====
