# 얼굴 표정 기반 우울증 판별 모델 학습 노트북

이 노트북은 얼굴 표정 이미지를 이용해 우울/비우울을 판별하는 이진 분류 모델을 학습합니다. 
모델은 사전 학습된 **ResNet50**을 기반으로 하며, 훈련 과정에서 다음 기능을 제공합니다:

- **tqdm** 진행률 표시로 배치별 학습 상태 모니터링
- **tensorboardX** 로깅을 통한 loss/accuracy/F1-score 모니터링
- **모델 파라미터 저장**: 에폭마다 모델 가중치를 파일로 저장
- **F1-score 및 classification report** 계산
- **Grad-CAM 시각화**: 모델이 어떤 영역을 중요하게 보는지 시각적으로 표시

현재 노트북은 batch size 32, 혼합 정밀도 학습 비활성화(`use_mixed_precision=False`) 설정을 기준으로 작성되었습니다.

데이터 디렉터리 구조는 다음과 같다고 가정합니다:
```
data/
  train/
    train_image/<emotion>/<image files>
    train_label/<json files>
  vali/
    vali_image/<emotion>/<image files>
    vali_label/<json files>
```
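라벨 JSON 파일은 다음과 같은 형식이어야 합니다 (키는 `file_name`, `emotion`). `emotion` 값은 anger, anxiety, hurt, joy, neutral, sadness, surprise 중 하나여야 하며, 그 외의 값을 가진 항목이나 이미지 파일이 존재하지 않는 항목은 데이터셋에서 제외됩니다: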
```json
[
  {"file_name": "0001.jpg", "emotion": "joy"},
  {"file_name": "0002.jpg", "emotion": "anxiety"}
]
```

In [1]:

import os
import json
import time
import random
from pathlib import Path
from typing import List, Tuple

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import models, transforms

from tqdm.auto import tqdm
from tensorboardX import SummaryWriter
from sklearn.metrics import classification_report, f1_score

# Seeding helper
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also puts cuDNN into deterministic mode and disables its benchmark mode.

    Args:
        seed: Integer seed handed to every generator.
    """
    # Same order as a manual sequence of calls: random, NumPy, torch CPU, torch CUDA.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

# Emotion mapping: English emotion key -> Korean name
EMOTIONS = dict(
    anger='분노',
    anxiety='불안',
    hurt='상처',
    joy='기쁨',
    neutral='중립',
    sadness='슬픔',
    surprise='당황',
)
# Emotions mapped to binary label 1 (depression) and to label 0 (non-depression)
DEPRESSION = ['anxiety', 'hurt', 'sadness']
NON_DEP = ['anger', 'joy', 'neutral', 'surprise']
# Emotion key -> binary label; depression keys are inserted first, then the rest
EMO2BIN = dict.fromkeys(DEPRESSION, 1)
EMO2BIN.update(dict.fromkeys(NON_DEP, 0))

# Dataset definition
class EmotionDataset(Dataset):
    """Image dataset yielding ``(image, binary_label)`` pairs.

    All ``*.json`` files directly under ``label_root`` are parsed once at
    construction time. Each record supplies a file name (key ``file_name``,
    ``image`` or ``img``) and an emotion (key ``emotion`` or ``label``).
    A record is kept only if its emotion is a key of ``EMO2BIN`` and the file
    ``img_root/<emotion>/<file name>`` exists; everything else is skipped.

    Args:
        img_root: Directory containing one sub-directory per emotion.
        label_root: Directory containing the label JSON files.
        transforms: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, img_root: Path, label_root: Path, transforms=None):
        self.img_root = Path(img_root)
        self.label_root = Path(label_root)
        self.transforms = transforms
        # (image path, binary label) pairs collected by _load()
        self.samples: List[Tuple[str, int]] = []
        self._load()

    def _load(self):
        """Fill ``self.samples`` from the label files, in sorted file order."""
        for jf in sorted(self.label_root.glob('*.json')):
            with open(jf, 'r', encoding='utf-8') as f:
                try:
                    items = json.load(f)
                except (OSError, ValueError, RecursionError) as e:
                    # Best effort: one unreadable label file must not abort loading.
                    print(f"[WARN] JSON parsing error in {jf}: {e}")
                    items = []
            if isinstance(items, dict):
                # A file holding a single record object instead of a list of records.
                items = [items]
            elif not isinstance(items, list):
                print(f"[WARN] Unexpected JSON top-level type in {jf}: {type(items).__name__}")
                items = []
            for it in items:
                if not isinstance(it, dict):
                    continue
                fname = it.get('file_name') or it.get('image') or it.get('img') or ''
                emo = it.get('emotion') or it.get('label') or ''
                if not isinstance(fname, str) or not fname:
                    continue
                if not isinstance(emo, str) or emo not in EMO2BIN:
                    continue
                path = self.img_root / emo / fname
                if path.exists():
                    self.samples.append((str(path), EMO2BIN[emo]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the image is RGB and transformed if configured."""
        path, label = self.samples[idx]
        # The context manager closes the underlying file once the pixels are converted.
        with Image.open(path) as src:
            img = src.convert('RGB')
        if self.transforms:
            img = self.transforms(img)
        return img, label

# Preprocessing pipelines
def make_transforms():
    """Build the training and validation preprocessing pipelines.

    Both pipelines resize to 224x224, convert to a tensor and normalize with
    the same per-channel mean/std. The training pipeline additionally applies
    a random horizontal flip, colour jitter and a random rotation between the
    resize and the tensor conversion.

    Returns:
        Tuple ``(train_t, val_t)`` of ``transforms.Compose`` objects.
    """
    target_size = (224, 224)
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    # Random augmentations, used for training only.
    augmentations = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
        transforms.RandomRotation(10),
    ]

    train_steps = [transforms.Resize(target_size)]
    train_steps.extend(augmentations)
    train_steps.append(transforms.ToTensor())
    train_steps.append(transforms.Normalize(mean=list(channel_mean), std=list(channel_std)))

    val_steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(channel_mean), std=list(channel_std)),
    ]
    return transforms.Compose(train_steps), transforms.Compose(val_steps)

# Class weights and sampler
def compute_class_weights(dataset: Dataset):
    """Return per-class weights ``[w0, w1]`` as a float32 tensor.

    Each weight is ``total / (2 * n_c)``, where ``n_c`` is the number of
    samples labelled ``c`` in ``dataset.samples`` and ``total = n_0 + n_1``.
    ``total`` and every ``n_c`` are clamped to at least 1, so an empty or
    single-class dataset does not divide by zero. Labels other than 0 and 1
    are not counted.
    """
    n_neg = n_pos = 0
    for _, label in dataset.samples:
        if label == 0:
            n_neg += 1
        elif label == 1:
            n_pos += 1
    total = max(1, n_neg + n_pos)
    weights = [total / (2 * max(1, count)) for count in (n_neg, n_pos)]
    return torch.tensor(weights, dtype=torch.float32)

def make_weighted_sampler(dataset: Dataset):
    """Build a ``WeightedRandomSampler`` that draws each class with equal total weight.

    Every sample in ``dataset.samples`` is weighted by the inverse of its
    class count (counts are clamped to at least 1). The sampler draws
    ``len(dataset.samples)`` indices with replacement.
    """
    targets = [label for _, label in dataset.samples]
    per_class_count = np.bincount(targets, minlength=2)
    inverse_freq = 1. / np.maximum(per_class_count, 1)
    # Look up every sample's class weight in one indexing step.
    per_sample = torch.DoubleTensor(inverse_freq[targets].tolist())
    return WeightedRandomSampler(
        weights=per_sample,
        num_samples=len(targets),
        replacement=True,
    )

# Model initialization
def build_model(device):
    """Create a ResNet50 with ``IMAGENET1K_V2`` weights and a 2-output ``fc`` layer.

    Args:
        device: Device the model is moved to before it is returned.

    Returns:
        The model, placed on ``device``.
    """
    pretrained = models.ResNet50_Weights.IMAGENET1K_V2
    net = models.resnet50(weights=pretrained)
    # Swap the final fully connected layer for a fresh 2-way linear layer.
    net.fc = nn.Linear(net.fc.in_features, 2)
    net = net.to(device)
    return net

# Grad-CAM helpers
def overlay_cam_on_image(img_tensor, cam_tensor):
    """Undo the input normalization and blend the image 50/50 with a CAM map.

    Args:
        img_tensor: Normalized image; broadcast against (3, 1, 1) mean/std,
            so presumably shaped (3, H, W) - verify against caller.
        cam_tensor: CAM map, tiled 3x along its first dimension before
            blending; presumably shaped (1, H, W) - verify against caller.

    Returns:
        Tuple ``(img, heat)`` of CPU tensors, both clamped to [0, 1]: the
        de-normalized image and the average of that image and the tiled CAM.
    """
    channel_mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
    channel_std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]

    restored = img_tensor.cpu() * channel_std + channel_mean
    restored = restored.clamp(0, 1)

    cam_rgb = cam_tensor.cpu().repeat(3, 1, 1)
    blended = 0.5 * restored + 0.5 * cam_rgb
    return restored, blended.clamp(0, 1)

class GradCAM:
    """Grad-CAM heatmap generator for one layer of a model.

    A forward hook stores the target layer's output and a full backward hook
    stores the gradient with respect to that output. Calling the instance runs
    a forward/backward pass and combines the two into a per-sample map that is
    min-max normalized to [0, 1].

    The hooks stay registered until ``remove()`` is called; the instance can
    also be used as a context manager, which calls ``remove()`` on exit.

    Args:
        model: Model to explain.
        target_layer: Sub-module of ``model`` whose output is visualized. Its
            output must be 4-D, since dimensions 2 and 3 are reduced as the
            spatial dimensions.
    """

    def __init__(self, model: nn.Module, target_layer: nn.Module):
        self.model = model
        self.target_layer = target_layer
        self.activations = None  # target layer output of the latest forward pass
        self.gradients = None    # gradient w.r.t. that output from the latest backward pass
        self.hook_handles = []
        self._register()

    def _register(self):
        self.hook_handles.append(self.target_layer.register_forward_hook(self._forward_hook))
        self.hook_handles.append(self.target_layer.register_full_backward_hook(self._backward_hook))

    def _forward_hook(self, module, inp, out):
        self.activations = out.detach()

    def _backward_hook(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def __call__(self, input_tensor, class_idx=None):
        """Compute normalized Grad-CAM maps for ``input_tensor``.

        Args:
            input_tensor: Batch fed to ``self.model``.
            class_idx: Class index (int, sequence or tensor, one per sample or
                a single value for all samples) whose logit is explained.
                Defaults to each sample's arg-max class.

        Returns:
            Tensor with a single channel and the spatial size of the target
            layer's output, values in [0, 1]. A map that is constant over
            space is returned as all zeros.

        Raises:
            RuntimeError: If the hooks captured nothing (for example after
                ``remove()``), so no map can be computed.
        """
        # Drop stale captures so a pass that fails to trigger the hooks is detected.
        self.activations = None
        self.gradients = None
        self.model.zero_grad(set_to_none=True)
        # Gradients are required even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            logits = self.model(input_tensor)
            if class_idx is None:
                class_idx = logits.argmax(dim=1)
            # Keep both index tensors on the logits' device.
            class_idx = torch.as_tensor(class_idx, device=logits.device)
            rows = torch.arange(logits.size(0), device=logits.device)
            score = logits[rows, class_idx].sum()
            score.backward()
        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                'GradCAM hooks captured no activations/gradients; '
                'were the hooks removed or is target_layer unused in forward()?'
            )
        # Channel importance = spatially averaged gradient.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = torch.relu(cam)
        cam_min = cam.amin(dim=(2, 3), keepdim=True)
        cam_max = cam.amax(dim=(2, 3), keepdim=True)
        # Clamp the range itself: clamping only cam_max leaves 0/0 = NaN for a
        # constant non-zero map.
        span = (cam_max - cam_min).clamp(min=1e-6)
        return (cam - cam_min) / span

    def remove(self):
        """Detach all registered hooks; safe to call more than once."""
        for h in self.hook_handles:
            h.remove()
        self.hook_handles.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False


  from .autonotebook import tqdm as notebook_tqdm


In [None]:

def train_and_eval(base_dir='./data',
                   batch_size=32,
                   num_workers=4,
                   epochs=10,
                   lr=1e-4,
                   use_mixed_precision=False,
                   log_dir='./runs',
                   save_dir='./checkpoints'):
    """Train the binary ResNet50 classifier, validate every epoch and show a Grad-CAM sample.

    Data is read from ``<base_dir>/train/train_image`` + ``train_label`` and
    ``<base_dir>/vali/vali_image`` + ``vali_label``. After every epoch the
    model weights are saved to ``save_dir`` and loss/accuracy/macro-F1 are
    logged to a time-stamped run directory under ``log_dir``. After training,
    the classification report of the last epoch, the best macro-F1 and its
    checkpoint path are printed, and a Grad-CAM figure for the first
    validation image is displayed.

    Args:
        base_dir: Root data directory.
        batch_size: Batch size for both loaders.
        num_workers: Worker processes per loader.
        epochs: Number of training epochs.
        lr: AdamW learning rate.
        use_mixed_precision: Enable CUDA automatic mixed precision; ignored
            when no CUDA device is available.
        log_dir: Parent directory for the tensorboard run directory.
        save_dir: Directory for per-epoch checkpoints (created if missing).

    Raises:
        ValueError: If no training sample could be loaded.
    """
    seed_everything(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    on_cuda = device.type == 'cuda'
    # torch.cuda.amp and pinned memory only help on CUDA; requesting them on
    # CPU just produces warnings.
    use_amp = bool(use_mixed_precision) and on_cuda
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    train_img = Path(base_dir) / 'train' / 'train_image'
    train_lbl = Path(base_dir) / 'train' / 'train_label'
    val_img = Path(base_dir) / 'vali' / 'vali_image'
    val_lbl = Path(base_dir) / 'vali' / 'vali_label'
    ttr, vtr = make_transforms()
    train_ds = EmotionDataset(train_img, train_lbl, transforms=ttr)
    val_ds = EmotionDataset(val_img, val_lbl, transforms=vtr)
    if len(train_ds) == 0:
        # Fail early with the paths instead of an obscure sampler error.
        raise ValueError(f'No training samples found (images: {train_img}, labels: {train_lbl})')
    # NOTE(review): class imbalance is compensated twice - by the balanced
    # sampler and by the class-weighted loss. Confirm this is intended.
    sampler = make_weighted_sampler(train_ds)
    class_weights = compute_class_weights(train_ds).to(device)
    train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler, num_workers=num_workers, pin_memory=on_cuda)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=on_cuda)
    model = build_model(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    run_tag = time.strftime('%Y%m%d-%H%M%S')
    writer = SummaryWriter(log_dir=f"{log_dir}/resnet50_{run_tag}")
    global_step = 0
    best_f1 = -1.0
    best_path = None
    # Defined up front so the final report also works when epochs == 0.
    all_preds = []
    all_labels = []
    try:
        for epoch in range(1, epochs + 1):
            # ---- Training ----
            model.train()
            pbar = tqdm(train_loader, desc=f'Epoch {epoch}/{epochs} [Train]', leave=False)
            running_loss = 0.0
            correct = 0
            total = 0
            for imgs, ys in pbar:
                imgs = imgs.to(device, non_blocking=True)
                ys = ys.to(device, non_blocking=True)
                optimizer.zero_grad(set_to_none=True)
                with torch.cuda.amp.autocast(enabled=use_amp):
                    logits = model(imgs)
                    loss = criterion(logits, ys)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                batch_loss = loss.item()
                running_loss += batch_loss * imgs.size(0)
                preds = logits.argmax(1)
                correct += (preds == ys).sum().item()
                total += imgs.size(0)
                writer.add_scalar('loss/train', batch_loss, global_step)
                global_step += 1
                pbar.set_postfix(loss=f"{batch_loss:.4f}", acc=f"{(correct/total):.3f}")
            train_epoch_loss = running_loss / total if total > 0 else 0.0
            train_epoch_acc = correct / total if total > 0 else 0.0
            writer.add_scalar('epoch/train_loss', train_epoch_loss, epoch)
            writer.add_scalar('epoch/train_acc', train_epoch_acc, epoch)
            # ---- Validation ----
            model.eval()
            v_loss = 0.0
            v_correct = 0
            v_total = 0
            all_preds = []
            all_labels = []
            with torch.no_grad():
                pbar_v = tqdm(val_loader, desc=f'Epoch {epoch}/{epochs} [Val]', leave=False)
                for imgs, ys in pbar_v:
                    imgs = imgs.to(device, non_blocking=True)
                    ys = ys.to(device, non_blocking=True)
                    with torch.cuda.amp.autocast(enabled=use_amp):
                        logits = model(imgs)
                        loss = criterion(logits, ys)
                    v_loss += loss.item() * imgs.size(0)
                    preds = logits.argmax(1)
                    v_correct += (preds == ys).sum().item()
                    v_total += imgs.size(0)
                    all_preds.extend(preds.detach().cpu().tolist())
                    all_labels.extend(ys.detach().cpu().tolist())
            val_epoch_loss = v_loss / max(1, v_total)
            val_epoch_acc = v_correct / max(1, v_total)
            f1 = f1_score(all_labels, all_preds, average='macro') if v_total > 0 else 0.0
            writer.add_scalar('epoch/val_loss', val_epoch_loss, epoch)
            writer.add_scalar('epoch/val_acc', val_epoch_acc, epoch)
            writer.add_scalar('epoch/val_f1_macro', f1, epoch)
            print(f'[Epoch {epoch}] train_loss={train_epoch_loss:.4f} acc={train_epoch_acc:.4f} | val_loss={val_epoch_loss:.4f} acc={val_epoch_acc:.4f} f1_macro={f1:.4f}')
            # A checkpoint is written every epoch; the best one is only tracked by path.
            ckpt_path = Path(save_dir) / f'resnet50_epoch{epoch:02d}_f1{f1:.4f}.pt'
            torch.save(model.state_dict(), ckpt_path)
            if f1 > best_f1:
                best_f1 = f1
                best_path = ckpt_path
        print('=== Validation Classification Report (last epoch) ===')
        if all_labels:
            # labels=[0, 1] keeps the report valid when only one class occurs
            # in the labels/predictions.
            print(classification_report(all_labels, all_preds,
                                        labels=[0, 1],
                                        target_names=['non_depression(0)', 'depression(1)'],
                                        zero_division=0))
        else:
            print('(no validation samples)')
        print(f'Best F1 (macro): {best_f1:.4f}')
        print(f'Best checkpoint: {best_path}')
        # ---- Grad-CAM for the first validation image ----
        if len(val_loader) > 0:
            import matplotlib.pyplot as plt
            model.eval()
            sample_imgs, _ = next(iter(val_loader))
            sample_imgs = sample_imgs.to(device)
            cam = GradCAM(model, model.layer4)
            try:
                cams = cam(sample_imgs[:1])
            finally:
                # Always detach the hooks so the model is left unmodified.
                cam.remove()
            cams_up = torch.nn.functional.interpolate(cams, size=(224, 224), mode='bilinear', align_corners=False)
            img_denorm, img_heat = overlay_cam_on_image(sample_imgs[0].detach(), cams_up[0].detach())
            fig, axes = plt.subplots(1, 3, figsize=(10, 4))
            axes[0].imshow(img_denorm.permute(1, 2, 0).numpy())
            axes[0].set_title('Input')
            axes[0].axis('off')
            axes[1].imshow(cams_up[0, 0].cpu().numpy(), cmap='jet')
            axes[1].set_title('Grad-CAM')
            axes[1].axis('off')
            axes[2].imshow(img_heat.permute(1, 2, 0).numpy())
            axes[2].set_title('Overlay')
            axes[2].axis('off')
            plt.tight_layout()
            plt.show()
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()


In [None]:

# Run model training
run_config = dict(
    base_dir='./data',
    batch_size=32,
    num_workers=4,
    epochs=10,
    lr=1e-4,
    use_mixed_precision=False,
    log_dir='./runs',
    save_dir='./checkpoints',
)
train_and_eval(**run_config)
