In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image

import os
import glob
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import time
import copy

from torchvision.models import (
    resnet50, ResNet50_Weights,
    googlenet, GoogLeNet_Weights,
    vgg16, VGG16_Weights,
    vit_b_16, ViT_B_16_Weights
)

# 1. Configuration
# Dataset locations (the path looks like a Colab Google Drive mount).
BASE_DIR = "/content/drive/MyDrive/기학기"
REAL_PATH = f"{BASE_DIR}/face_real"  # images labelled 0 by main()
FAKE_PATH = f"{BASE_DIR}/face_fake"  # images labelled 1 by main()
# Best checkpoint written by train_model (name says resnet50 regardless of architecture).
MODEL_SAVE_PATH = "/content/best_resnet50.pth"

# Training hyper-parameters.
IMG_SIZE = 224          # side length of the square input images
EPOCHS = 30             # upper bound; early stopping may end training sooner
BATCH_SIZE = 64
NUM_SAMPLES = 10_000    # cap on the number of images drawn from real + fake
LEARNING_RATE = 1e-4
PATIENCE = 3            # epochs without val-loss improvement before stopping


# 2. ResizeWithPad
class ResizeWithPad:
    """Letterbox an image into a square canvas without distorting it.

    The image is scaled so that its longer side equals ``target_size``
    (preserving the aspect ratio) and pasted, centred, onto a black
    ``target_size`` x ``target_size`` RGB canvas.

    Args:
        target_size: Side length, in pixels, of the square output image.
    """

    def __init__(self, target_size):
        self.target_size = target_size

    def __call__(self, img):
        """Return a new ``target_size`` x ``target_size`` RGB PIL image."""
        w, h = img.size
        scale = self.target_size / max(w, h)

        # round() rather than int(): the float product for the long side can
        # come out as e.g. 223.999..., which int() truncates, leaving a 1-px
        # black strip. max(1, ...) keeps extremely thin images from being
        # resized to a zero-pixel dimension; min(...) guards against rounding
        # past the canvas.
        new_w = min(self.target_size, max(1, round(w * scale)))
        new_h = min(self.target_size, max(1, round(h * scale)))

        resized_img = img.resize((new_w, new_h), Image.LANCZOS)
        canvas = Image.new("RGB", (self.target_size, self.target_size), (0, 0, 0))

        # Centre the image; an odd leftover pixel ends up on the right/bottom.
        pad_x = (self.target_size - new_w) // 2
        pad_y = (self.target_size - new_h) // 2
        canvas.paste(resized_img, (pad_x, pad_y))

        return canvas

# 3. Dataset
class DeepfakeDataset(Dataset):
    """Map-style dataset of image files paired with binary labels.

    Args:
        paths: Sequence of image file paths.
        labels: Sequence of labels aligned index-by-index with ``paths``
            (``main`` uses 0 for real and 1 for fake).
        transform: Optional callable applied to the RGB PIL image.

    Raises:
        ValueError: If ``paths`` and ``labels`` differ in length.
    """

    def __init__(self, paths, labels, transform=None):
        # Catch misaligned inputs here instead of failing mid-epoch on an IndexError.
        if len(paths) != len(labels):
            raise ValueError(
                f"paths and labels must have the same length "
                f"(got {len(paths)} and {len(labels)})"
            )
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``label`` is a float32 tensor of shape ``(1,)``."""
        # Open inside a context manager so the file handle is released
        # deterministically; convert() loads the pixels and returns an
        # independent image before the file is closed.
        with Image.open(self.paths[idx]) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        # Shape (1,) so a collated batch is (B, 1), matching the single-logit
        # heads built by get_model_finetune.
        label = torch.tensor([self.labels[idx]], dtype=torch.float32)
        return img, label


# 4. Fine-tuning model factory
def get_model_finetune(name, device, mode="full"):
    """Build an ImageNet-pretrained torchvision model with a single-logit head.

    The final classification layer is replaced by ``nn.Linear(in_features, 1)``
    so the model emits one raw logit per image (as expected by
    ``BCEWithLogitsLoss``).

    Args:
        name: One of ``"resnet50"``, ``"vgg16"``, ``"googlenet"``, ``"vit"``.
        device: Device the model is moved to.
        mode: ``"head"`` freezes every parameter except the new head;
            ``"full"`` (default) leaves all parameters trainable.

    Returns:
        The model, moved to ``device``.

    Raises:
        ValueError: If ``name`` or ``mode`` is not supported.
    """
    print(f"\n=== Loading {name} (pretrained, mode={mode}) ===")

    # Validate before downloading any weights, and so that a typo such as
    # mode="haed" fails loudly instead of silently meaning full fine-tuning.
    if mode not in ("full", "head"):
        raise ValueError(f"unsupported mode {mode!r}; expected 'full' or 'head'")

    if name == "resnet50":
        model = resnet50(weights=ResNet50_Weights.DEFAULT)
        model.fc = nn.Linear(model.fc.in_features, 1)
        head = model.fc

    elif name == "vgg16":
        model = vgg16(weights=VGG16_Weights.DEFAULT)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, 1)
        head = model.classifier[6]

    elif name == "googlenet":
        # NOTE(review): aux_logits=True keeps GoogLeNet's auxiliary classifiers,
        # which are never given a 1-logit head here. train_model keeps only
        # out[0] when the forward returns a tuple, so they appear unused by the
        # loss -- consider aux_logits=False (confirm with torchvision docs).
        model = googlenet(weights=GoogLeNet_Weights.DEFAULT, aux_logits=True)
        model.fc = nn.Linear(model.fc.in_features, 1)
        head = model.fc

    elif name == "vit":
        model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        model.heads.head = nn.Linear(model.heads.head.in_features, 1)
        head = model.heads.head

    else:
        raise ValueError(f"지원하지 않는 모델: {name!r}")

    # Freeze configuration
    if mode == "head":
        print("→ Classifier만 학습")
        model.requires_grad_(False)  # freeze the whole backbone...
        head.requires_grad_(True)    # ...then unfreeze only the new head
    else:
        print("→ 전체 파라미터 fine-tuning")
        model.requires_grad_(True)

    return model.to(device)


# 5. Train
def _unwrap_logits(out):
    """Return the primary logits tensor from a model's forward output.

    Tuples/lists yield their first element, objects exposing ``.logits`` yield
    that attribute, and anything else (a plain tensor) is returned unchanged.
    """
    if isinstance(out, (tuple, list)):
        return out[0]
    if hasattr(out, "logits"):
        return out.logits
    return out


def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With an ``optimizer`` the pass trains (backward + step); without one it
    only evaluates, with autograd disabled. The caller is responsible for
    putting the model in train/eval mode.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    total, correct = 0, 0
    running_loss = 0.0

    with torch.set_grad_enabled(training):
        for X, y in tqdm(loader, desc=desc):
            X, y = X.to(device), y.to(device)

            if training:
                optimizer.zero_grad()
            out = _unwrap_logits(model(X))
            loss = criterion(out, y)
            if training:
                loss.backward()
                optimizer.step()

            # criterion returns a batch mean; weight by batch size to get a
            # per-sample average over the whole pass.
            running_loss += loss.item() * X.size(0)
            pred = (torch.sigmoid(out) > 0.5).float()
            correct += (pred == y).sum().item()
            total += y.size(0)

    if total == 0:
        raise ValueError(f"{desc}: loader yielded no samples")
    # Divide by the samples actually seen (not len(dataset)) so the average
    # stays correct with drop_last or a subset sampler.
    return running_loss / total, correct / total


def train_model(model, train_loader, val_loader, criterion, optimizer, device,
                epochs, patience, save_path=None):
    """Train with early stopping on validation loss and restore the best weights.

    After every epoch that lowers the validation loss, the state dict is
    snapshotted and written to ``save_path``. Training stops after
    ``patience`` consecutive epochs without improvement.

    Args:
        model: Model producing one logit per sample.
        train_loader, val_loader: Loaders yielding ``(inputs, targets)`` batches.
        criterion: Loss on ``(logits, targets)``.
        optimizer: Optimizer over the model's trainable parameters.
        device: Device batches are moved to.
        epochs: Maximum number of epochs.
        patience: Non-improving epochs tolerated before stopping.
        save_path: Checkpoint file; defaults to the module-level
            ``MODEL_SAVE_PATH`` (looked up at call time).

    Returns:
        ``model`` with the best validation-loss weights loaded (if any epoch ran).
    """
    if save_path is None:
        save_path = MODEL_SAVE_PATH

    best_val_loss = float("inf")
    best_state = None
    no_improve = 0

    for epoch in range(epochs):
        tag = f"Epoch {epoch+1}/{epochs}"

        # NOTE(review): with get_model_finetune(mode="head"), model.train()
        # still puts any frozen BatchNorm layers (e.g. in resnet50) into
        # training mode, so their running statistics keep updating -- confirm
        # this is intended.
        model.train()
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, f"Train {tag}", optimizer
        )

        model.eval()
        val_loss, val_acc = _run_epoch(
            model, val_loader, criterion, device, f"Val {tag}"
        )

        print(f"[Epoch {epoch+1}] Train Loss {train_loss:.4f} Acc {train_acc:.4f} | "
              f"Val Loss {val_loss:.4f} Acc {val_acc:.4f}")

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # deepcopy: state_dict() returns references to live tensors that
            # later optimizer steps would overwrite.
            best_state = copy.deepcopy(model.state_dict())
            no_improve = 0
            print("→ Best model 갱신")
            torch.save(best_state, save_path)
        else:
            no_improve += 1
            print(f"→ 개선 없음 ({no_improve}/{patience})")
            if no_improve >= patience:
                print("EARLY STOPPING!")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model


# 6. Evaluate
def evaluate_model(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader``, print the metrics and return them.

    A sample is predicted positive when ``sigmoid(logit) > 0.5``.

    Args:
        model: Model producing one logit per sample.
        loader: Loader yielding ``(inputs, targets)`` batches.
        criterion: Loss on ``(logits, targets)``.
        device: Device batches are moved to.

    Returns:
        ``(mean_loss, accuracy)``, with accuracy in ``[0, 1]``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    total, correct = 0, 0
    running_loss = 0.0

    with torch.no_grad():
        for X, y in tqdm(loader, desc="TEST"):
            X, y = X.to(device), y.to(device)
            out = model(X)

            # Some models return a tuple (main output first) or an object
            # exposing .logits; reduce both to the logits tensor.
            if isinstance(out, (tuple, list)):
                out = out[0]
            elif hasattr(out, "logits"):
                out = out.logits

            loss = criterion(out, y)

            # criterion returns a batch mean; weight by batch size.
            running_loss += loss.item() * X.size(0)
            pred = (torch.sigmoid(out) > 0.5).float()
            correct += (pred == y).sum().item()
            total += y.size(0)

    if total == 0:
        raise ValueError("TEST loader yielded no samples")

    # Average over samples actually seen (robust to drop_last / samplers).
    test_loss = running_loss / total
    test_acc = correct / total
    print(f"\nTest Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc*100:.2f}%")
    return test_loss, test_acc


# 7. MAIN
# Same extension list torchvision's ImageFolder accepts.
_IMAGE_EXTENSIONS = (
    ".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp",
)


def _list_images(folder):
    """Return image files directly inside ``folder``, sorted for reproducibility.

    Only names ending in one of ``_IMAGE_EXTENSIONS`` (case-insensitive) are
    kept, so entries such as ``desktop.ini`` or sub-directories without an
    image extension never reach PIL. Sorting makes the order independent of
    the filesystem, so the seeded splits in ``main`` are repeatable.
    """
    return sorted(
        path for path in glob.glob(os.path.join(folder, "*"))
        if path.lower().endswith(_IMAGE_EXTENSIONS)
    )


def _make_loader(paths, labels, transform, *, shuffle, pin_memory):
    """Wrap ``paths``/``labels`` in a DeepfakeDataset and a batched DataLoader."""
    return DataLoader(
        DeepfakeDataset(paths, labels, transform),
        batch_size=BATCH_SIZE, shuffle=shuffle,
        num_workers=2, pin_memory=pin_memory,
    )


def main():
    """Sample, split, train and evaluate the real-vs-fake face classifier.

    Images come from ``REAL_PATH`` (label 0) and ``FAKE_PATH`` (label 1).
    At most ``NUM_SAMPLES`` of them are drawn by stratified sampling, split
    70/20/10 into train/val/test, and used to train the head of a pretrained
    ResNet-50, which is then evaluated on the test split. Returns early, after
    printing a message, if either folder contains no images.
    """
    print("===== START =====")

    # Seed torch so head initialisation, shuffling and random flips are
    # repeatable, matching the fixed random_state used by the splits below.
    torch.manual_seed(42)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("DEVICE:", device)

    # --- Read image paths from the class folders ---
    real_paths = _list_images(REAL_PATH)
    fake_paths = _list_images(FAKE_PATH)

    print(f"Real: {len(real_paths)}, Fake: {len(fake_paths)}")

    if not real_paths or not fake_paths:
        print("❌ Real 또는 Fake 폴더에 이미지가 없습니다. 경로 다시 확인하세요.")
        return

    all_paths = real_paths + fake_paths
    all_labels = [0] * len(real_paths) + [1] * len(fake_paths)  # 0 = real, 1 = fake

    # --- Stratified subsample of at most NUM_SAMPLES items ---
    use_num = min(NUM_SAMPLES, len(all_paths))
    print(f"\n[샘플링] 최대 {NUM_SAMPLES}개 중 {use_num}개 사용")

    if use_num == len(all_paths):
        sample_paths, sample_labels = all_paths, all_labels
        print("→ 전체 데이터 사용")
    else:
        # train_size=use_num keeps exactly use_num items; the rest is discarded.
        sample_paths, _, sample_labels, _ = train_test_split(
            all_paths, all_labels,
            train_size=use_num,
            stratify=all_labels,
            random_state=42,
        )
        print(f"→ 샘플링 완료: {len(sample_paths)}개")

    # --- Train / Val / Test = 70 / 20 / 10 (the 30 % hold-out is split 2:1) ---
    train_paths, temp_paths, train_labels, temp_labels = train_test_split(
        sample_paths, sample_labels, test_size=0.3, stratify=sample_labels, random_state=42
    )
    val_paths, test_paths, val_labels, test_labels = train_test_split(
        temp_paths, temp_labels, test_size=1/3, stratify=temp_labels, random_state=42
    )

    print(f"\nTrain {len(train_paths)}  Val {len(val_paths)}  Test {len(test_paths)}")

    # --- Transforms (ImageNet mean/std, matching the pretrained weights) ---
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    train_tf = transforms.Compose([
        ResizeWithPad(IMG_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    test_tf = transforms.Compose([
        ResizeWithPad(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # --- DataLoaders (images are read from disk on demand) ---
    # Pinned host memory only speeds up copies to a CUDA device.
    pin = device.type == "cuda"
    train_loader = _make_loader(train_paths, train_labels, train_tf, shuffle=True, pin_memory=pin)
    val_loader = _make_loader(val_paths, val_labels, test_tf, shuffle=False, pin_memory=pin)
    test_loader = _make_loader(test_paths, test_labels, test_tf, shuffle=False, pin_memory=pin)

    # --- Model / loss / optimizer (only trainable parameters are optimized) ---
    model = get_model_finetune("resnet50", device, mode="head")
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(
        (p for p in model.parameters() if p.requires_grad),
        lr=LEARNING_RATE,
    )

    # --- Training ---
    model = train_model(model, train_loader, val_loader, criterion, optimizer, device, EPOCHS, PATIENCE)

    # --- Evaluation ---
    evaluate_model(model, test_loader, criterion, device)

    print("===== DONE =====")

# Entry point: run the whole pipeline when this file (or notebook cell) is
# executed directly rather than imported.
if __name__ == "__main__":
    main()

===== START =====
DEVICE: cuda
Real: 34681, Fake: 33854

[샘플링] 최대 10000개 중 10000개 사용
→ 샘플링 완료: 10000개

Train 7000  Val 2000  Test 1000

=== Loading resnet50 (pretrained, mode=head) ===
→ Classifier만 학습


Train Epoch 1/30: 100%|██████████| 110/110 [01:01<00:00,  1.78it/s]
Val Epoch 1/30: 100%|██████████| 32/32 [00:17<00:00,  1.86it/s]


[Epoch 1] Train Loss 0.6845 Acc 0.5674 | Val Loss 0.6747 Acc 0.6295
→ Best model 갱신


Train Epoch 2/30: 100%|██████████| 110/110 [05:12<00:00,  2.84s/it]
Val Epoch 2/30: 100%|██████████| 32/32 [00:10<00:00,  3.20it/s]


[Epoch 2] Train Loss 0.6662 Acc 0.6401 | Val Loss 0.6616 Acc 0.6630
→ Best model 갱신


Train Epoch 3/30: 100%|██████████| 110/110 [00:37<00:00,  2.90it/s]
Val Epoch 3/30: 100%|██████████| 32/32 [00:11<00:00,  2.86it/s]


[Epoch 3] Train Loss 0.6527 Acc 0.6616 | Val Loss 0.6506 Acc 0.6745
→ Best model 갱신


Train Epoch 4/30: 100%|██████████| 110/110 [00:38<00:00,  2.84it/s]
Val Epoch 4/30: 100%|██████████| 32/32 [00:10<00:00,  2.94it/s]


[Epoch 4] Train Loss 0.6405 Acc 0.6716 | Val Loss 0.6426 Acc 0.6875
→ Best model 갱신


Train Epoch 5/30: 100%|██████████| 110/110 [00:41<00:00,  2.62it/s]
Val Epoch 5/30: 100%|██████████| 32/32 [00:10<00:00,  2.96it/s]


[Epoch 5] Train Loss 0.6323 Acc 0.6803 | Val Loss 0.6335 Acc 0.6915
→ Best model 갱신


Train Epoch 6/30: 100%|██████████| 110/110 [00:37<00:00,  2.90it/s]
Val Epoch 6/30: 100%|██████████| 32/32 [00:10<00:00,  2.94it/s]


[Epoch 6] Train Loss 0.6234 Acc 0.6827 | Val Loss 0.6243 Acc 0.7005
→ Best model 갱신


Train Epoch 7/30: 100%|██████████| 110/110 [00:39<00:00,  2.75it/s]
Val Epoch 7/30: 100%|██████████| 32/32 [00:10<00:00,  2.94it/s]


[Epoch 7] Train Loss 0.6163 Acc 0.6917 | Val Loss 0.6183 Acc 0.7050
→ Best model 갱신


Train Epoch 8/30: 100%|██████████| 110/110 [00:38<00:00,  2.88it/s]
Val Epoch 8/30: 100%|██████████| 32/32 [00:10<00:00,  3.16it/s]


[Epoch 8] Train Loss 0.6100 Acc 0.6959 | Val Loss 0.6156 Acc 0.7025
→ Best model 갱신


Train Epoch 9/30: 100%|██████████| 110/110 [00:38<00:00,  2.88it/s]
Val Epoch 9/30: 100%|██████████| 32/32 [00:10<00:00,  2.93it/s]


[Epoch 9] Train Loss 0.6036 Acc 0.7033 | Val Loss 0.6083 Acc 0.7085
→ Best model 갱신


Train Epoch 10/30: 100%|██████████| 110/110 [00:38<00:00,  2.89it/s]
Val Epoch 10/30: 100%|██████████| 32/32 [00:10<00:00,  3.07it/s]


[Epoch 10] Train Loss 0.5984 Acc 0.7044 | Val Loss 0.6049 Acc 0.7110
→ Best model 갱신


Train Epoch 11/30: 100%|██████████| 110/110 [00:37<00:00,  2.94it/s]
Val Epoch 11/30: 100%|██████████| 32/32 [00:11<00:00,  2.89it/s]


[Epoch 11] Train Loss 0.5945 Acc 0.7037 | Val Loss 0.6002 Acc 0.7060
→ Best model 갱신


Train Epoch 12/30: 100%|██████████| 110/110 [00:38<00:00,  2.84it/s]
Val Epoch 12/30: 100%|██████████| 32/32 [00:11<00:00,  2.91it/s]


[Epoch 12] Train Loss 0.5884 Acc 0.7110 | Val Loss 0.5968 Acc 0.7080
→ Best model 갱신


Train Epoch 13/30: 100%|██████████| 110/110 [00:37<00:00,  2.94it/s]
Val Epoch 13/30: 100%|██████████| 32/32 [00:10<00:00,  2.98it/s]


[Epoch 13] Train Loss 0.5831 Acc 0.7214 | Val Loss 0.5919 Acc 0.7160
→ Best model 갱신


Train Epoch 14/30: 100%|██████████| 110/110 [00:38<00:00,  2.85it/s]
Val Epoch 14/30: 100%|██████████| 32/32 [00:11<00:00,  2.88it/s]


[Epoch 14] Train Loss 0.5807 Acc 0.7133 | Val Loss 0.5901 Acc 0.7165
→ Best model 갱신


Train Epoch 15/30: 100%|██████████| 110/110 [00:38<00:00,  2.87it/s]
Val Epoch 15/30: 100%|██████████| 32/32 [00:10<00:00,  3.20it/s]


[Epoch 15] Train Loss 0.5768 Acc 0.7206 | Val Loss 0.5849 Acc 0.7175
→ Best model 갱신


Train Epoch 16/30: 100%|██████████| 110/110 [00:38<00:00,  2.85it/s]
Val Epoch 16/30: 100%|██████████| 32/32 [00:11<00:00,  2.87it/s]


[Epoch 16] Train Loss 0.5734 Acc 0.7207 | Val Loss 0.5817 Acc 0.7210
→ Best model 갱신


Train Epoch 17/30: 100%|██████████| 110/110 [00:38<00:00,  2.85it/s]
Val Epoch 17/30: 100%|██████████| 32/32 [00:10<00:00,  2.95it/s]


[Epoch 17] Train Loss 0.5706 Acc 0.7201 | Val Loss 0.5788 Acc 0.7180
→ Best model 갱신


Train Epoch 18/30: 100%|██████████| 110/110 [00:37<00:00,  2.93it/s]
Val Epoch 18/30: 100%|██████████| 32/32 [00:11<00:00,  2.79it/s]


[Epoch 18] Train Loss 0.5649 Acc 0.7310 | Val Loss 0.5831 Acc 0.7175
→ 개선 없음 (1/3)


Train Epoch 19/30: 100%|██████████| 110/110 [00:38<00:00,  2.89it/s]
Val Epoch 19/30: 100%|██████████| 32/32 [00:11<00:00,  2.90it/s]


[Epoch 19] Train Loss 0.5624 Acc 0.7334 | Val Loss 0.5785 Acc 0.7225
→ Best model 갱신


Train Epoch 20/30: 100%|██████████| 110/110 [00:37<00:00,  2.92it/s]
Val Epoch 20/30: 100%|██████████| 32/32 [00:10<00:00,  3.08it/s]


[Epoch 20] Train Loss 0.5607 Acc 0.7323 | Val Loss 0.5722 Acc 0.7205
→ Best model 갱신


Train Epoch 21/30: 100%|██████████| 110/110 [00:38<00:00,  2.87it/s]
Val Epoch 21/30: 100%|██████████| 32/32 [00:10<00:00,  2.92it/s]


[Epoch 21] Train Loss 0.5597 Acc 0.7307 | Val Loss 0.5763 Acc 0.7225
→ 개선 없음 (1/3)


Train Epoch 22/30: 100%|██████████| 110/110 [00:37<00:00,  2.94it/s]
Val Epoch 22/30: 100%|██████████| 32/32 [00:10<00:00,  3.05it/s]


[Epoch 22] Train Loss 0.5539 Acc 0.7383 | Val Loss 0.5707 Acc 0.7215
→ Best model 갱신


Train Epoch 23/30: 100%|██████████| 110/110 [00:38<00:00,  2.88it/s]
Val Epoch 23/30: 100%|██████████| 32/32 [00:11<00:00,  2.89it/s]


[Epoch 23] Train Loss 0.5529 Acc 0.7344 | Val Loss 0.5665 Acc 0.7270
→ Best model 갱신


Train Epoch 24/30: 100%|██████████| 110/110 [00:38<00:00,  2.86it/s]
Val Epoch 24/30: 100%|██████████| 32/32 [00:09<00:00,  3.24it/s]


[Epoch 24] Train Loss 0.5507 Acc 0.7347 | Val Loss 0.5667 Acc 0.7255
→ 개선 없음 (1/3)


Train Epoch 25/30: 100%|██████████| 110/110 [00:38<00:00,  2.84it/s]
Val Epoch 25/30: 100%|██████████| 32/32 [00:11<00:00,  2.88it/s]


[Epoch 25] Train Loss 0.5480 Acc 0.7421 | Val Loss 0.5625 Acc 0.7315
→ Best model 갱신


Train Epoch 26/30: 100%|██████████| 110/110 [00:38<00:00,  2.88it/s]
Val Epoch 26/30: 100%|██████████| 32/32 [00:10<00:00,  2.95it/s]


[Epoch 26] Train Loss 0.5480 Acc 0.7354 | Val Loss 0.5665 Acc 0.7290
→ 개선 없음 (1/3)


Train Epoch 27/30: 100%|██████████| 110/110 [00:37<00:00,  2.93it/s]
Val Epoch 27/30: 100%|██████████| 32/32 [00:11<00:00,  2.89it/s]


[Epoch 27] Train Loss 0.5458 Acc 0.7351 | Val Loss 0.5589 Acc 0.7330
→ Best model 갱신


Train Epoch 28/30: 100%|██████████| 110/110 [00:38<00:00,  2.82it/s]
Val Epoch 28/30: 100%|██████████| 32/32 [00:10<00:00,  2.91it/s]


[Epoch 28] Train Loss 0.5444 Acc 0.7393 | Val Loss 0.5588 Acc 0.7295
→ Best model 갱신


Train Epoch 29/30: 100%|██████████| 110/110 [00:37<00:00,  2.95it/s]
Val Epoch 29/30: 100%|██████████| 32/32 [00:10<00:00,  2.93it/s]


[Epoch 29] Train Loss 0.5410 Acc 0.7460 | Val Loss 0.5586 Acc 0.7300
→ Best model 갱신


Train Epoch 30/30: 100%|██████████| 110/110 [00:38<00:00,  2.87it/s]
Val Epoch 30/30: 100%|██████████| 32/32 [00:10<00:00,  2.92it/s]


[Epoch 30] Train Loss 0.5366 Acc 0.7471 | Val Loss 0.5577 Acc 0.7300
→ Best model 갱신


TEST: 100%|██████████| 16/16 [00:05<00:00,  2.92it/s]


Test Loss: 0.5708
Test Accuracy: 71.00%
===== DONE =====



