In [None]:
!git clone https://github.com/cydonia999/VGGFace2-pytorch.git

In [None]:
%cd VGGFace2-pytorch

In [None]:
!pip install facenet-pytorch

In [None]:
#import 부분

from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch
import torch.nn as nn
import pandas as pd
import os
import json
import random
import pickle
from functools import partial
from models.resnet import resnet50
from facenet_pytorch import MTCNN
from torch.utils.data import WeightedRandomSampler
#정확도 및 mae 임포트
from torchmetrics import Accuracy, MeanAbsoluteError
import numpy as np

In [None]:
# Report the installed PyTorch version and whether CUDA is usable.
print("PyTorch 버전:", torch.__version__)
print("CUDA 사용 가능:", torch.cuda.is_available())

if not torch.cuda.is_available():
    print("CUDA를 사용할 수 없습니다.")
else:
    # Only query the device once we know CUDA is available.
    print("현재 CUDA 디바이스 인덱스:", torch.cuda.current_device())
    print("CUDA 디바이스 이름:", torch.cuda.get_device_name(0))

In [None]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# loss는 그냥 비율 그대로 1:1 로 하고 진행할것 
# MTCNN 전처리 + custom transform 
# augment_transform 진행 ( 나이대별 성별별... )
# 증강 수준은 그대로 진행도 될 듯함 gpu꼭 쓰게 하고... 
# WeightedRandomSampler적용 추가 버전임

In [None]:
# Dataset pipeline: reads the label files, crops faces and applies per-sample transforms.

class DataProcessing(Dataset):
    """Image dataset yielding a face tensor together with age and gender targets.

    Labels are read from ``{label_dir}/{mode}_{category}.json`` (EUC-KR encoded)
    and images from ``{image_dir}/{category}/{filename}``. Rows whose image file
    does not exist, or whose age is ``age_max`` (60) or above, are dropped.

    Args:
        image_dir: Root directory containing one sub-directory per category.
        label_dir: Directory containing the per-category JSON label files.
        categories: Category names; their order defines the emotion label ids.
        transform: Callable applied to the face tensor in non-train modes, and
            in train mode for samples that are not augmented.
        emotion_return: When True, ``__getitem__`` also returns the emotion id.
        mode: Split name; used as the label-file prefix, and augmentation is
            only considered when it equals ``'train'``.
        augment_transform: Callable applied instead of ``transform`` to
            training samples that fall into the augmented age bands.
        mtcnn: Optional face cropper that is called with the PIL image and
            returns a tensor, or ``None`` when no face is found.
    """

    def __init__(self, image_dir, label_dir, categories, transform=None, emotion_return=False,
                 mode='train', augment_transform=None, mtcnn=None):
        self.datalist = []  # one dict per usable sample
        self.transform = transform
        self.augment_transform = augment_transform
        self.mtcnn = mtcnn
        self.label_map = {cat: idx for idx, cat in enumerate(categories)}
        self.emotion_return = emotion_return
        self.mode = mode
        # Age range used both for filtering and for min-max normalisation.
        self.age_min = 10
        self.age_max = 60

        for category in categories:
            json_path = os.path.join(label_dir, f'{self.mode}_{category}.json')
            img_folder = os.path.join(image_dir, category)

            with open(json_path, 'r', encoding='euc-kr') as f:
                label_data = json.load(f)

            for row in label_data:
                img_path = os.path.join(img_folder, row['filename'])

                # Rows without a matching image file are skipped silently.
                if not os.path.isfile(img_path):
                    continue

                age = row.get('age')

                # Samples at or above the upper age bound are excluded.
                if age is not None and age >= self.age_max:
                    continue

                if age is not None:
                    # Min-max normalise the age.
                    age_norm = (age - self.age_min) / (self.age_max - self.age_min)
                else:
                    # Missing age (not expected in the data): fall back to 0.0.
                    age_norm = 0.0

                self.datalist.append({
                    'img_path': img_path,
                    'category': category,
                    'age': age_norm,
                    'raw_age': age,
                    'gender': row.get('gender'),
                })

    def __len__(self):
        """Return the number of usable samples."""
        return len(self.datalist)

    @staticmethod
    def _to_unit_range(face_img, post_processed):
        """Map an MTCNN crop onto the [0, 1] range that ``ToTensor`` produces.

        Args:
            face_img: Tensor returned by the face cropper.
            post_processed: True when the cropper already applied
                facenet-pytorch's fixed standardisation ``(x - 127.5) / 128``;
                False when the values are raw 0-255 pixel intensities.
        """
        if post_processed:
            # Undo the fixed standardisation to get back to 0-255 intensities.
            face_img = face_img * 128.0 + 127.5
        return (face_img / 255.0).clamp(0.0, 1.0)

    def __getitem__(self, idx):
        """Load one sample and return ``(image, [emotion,] age, gender)``.

        Returns:
            ``(face_img, age, gender)``, or ``(face_img, emotion_label, age,
            gender)`` when ``emotion_return`` is True. ``age`` is the normalised
            age as float32; ``gender`` is 1 for ``'남'`` and 0 for anything else.
        """
        data_item = self.datalist[idx]

        # Close the file handle deterministically; convert() returns a loaded copy.
        with Image.open(data_item['img_path']) as img:
            image = img.convert('RGB')

        if self.mtcnn is not None:
            face_img = self.mtcnn(image)

            if face_img is None:
                # No face detected: fall back to the resized full frame in [0, 1].
                face_img = transforms.ToTensor()(image.resize((224, 224)))
            else:
                # facenet-pytorch's MTCNN returns either standardised (~[-1, 1]) or
                # raw 0-255 values. The downstream transforms (ColorJitter, ImageNet
                # Normalize) and the fallback path above all work on [0, 1], so bring
                # the crop onto that scale. Croppers that do not expose
                # `post_process` are left untouched because their range is unknown.
                post_processed = getattr(self.mtcnn, 'post_process', None)
                if post_processed is not None:
                    face_img = self._to_unit_range(face_img, bool(post_processed))
        else:
            face_img = transforms.ToTensor()(image)

        age_norm = data_item['age']

        # Normalised-age bands that receive augmentation; with the 10-60 range
        # these correspond to raw ages 10-19, 40-49 and 50-59.
        augment_flag = (
            (0 <= age_norm <= 0.18)
            or (0.6 <= age_norm <= 0.78)
            or (0.8 <= age_norm <= 0.98)
        )

        if self.mode == 'train' and augment_flag and self.augment_transform is not None:
            face_img = self.augment_transform(face_img)
        elif self.transform is not None:
            face_img = self.transform(face_img)

        age = torch.tensor(data_item['age'], dtype=torch.float32)
        # Any gender value other than '남' (including a missing one) maps to 0.
        gender = torch.tensor(1 if data_item['gender'] == '남' else 0, dtype=torch.long)

        if self.emotion_return:
            emotion_label = torch.tensor(self.label_map[data_item['category']], dtype=torch.long)
            return face_img, emotion_label, age, gender

        return face_img, age, gender
       

In [None]:
def get_sampling_weights(dataset):
    """
    Assign each sample in the dataset a weight that offsets class imbalance.
    - teens, male: 3.0
    - teens, female: 2.0
    - forties: 1.5
    - fifties, male: 2.0
    - fifties, female: 1.5
    - everything else: 1.0 (twenties, thirties, and samples without an age)

    Reads `dataset.datalist`, using the 'raw_age' and 'gender' keys of each
    entry, and returns a list of floats in the same order.
    """
    def weight_for(entry):
        raw_age = entry.get('raw_age')
        if raw_age is None:
            return 1.0

        is_male = entry.get('gender') == '남'
        decade = (raw_age // 10) * 10  # 10, 20, 30, ...

        if decade == 10:
            return 3.0 if is_male else 2.0
        if decade == 40:
            return 1.5
        if decade == 50:
            return 2.0 if is_male else 1.5
        return 1.0

    return [weight_for(entry) for entry in dataset.datalist]

In [None]:
# Face cropper handed to the datasets, which call it on every loaded image.
mtcnn = MTCNN(
    image_size=224,
    margin=20,
    min_face_size=20,
    thresholds=[0.6, 0.7, 0.7],
    device=device,
)

In [None]:
# Default transform: channel-wise normalisation only (no resize step here).
transform = transforms.Compose([
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # per-channel RGB mean
        std=[0.229, 0.224, 0.225],   # per-channel RGB standard deviation
    ),
])

In [None]:
# Augmentation pipeline: random flip, brightness/contrast jitter and a small
# rotation, followed by the same normalisation as the default transform.
augmentation_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
    transforms.RandomRotation(degrees=10),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
augment_transform = transforms.Compose(augmentation_steps)

In [None]:

# Emotion categories; each has its own image folder and label file.
categories = ['anger', 'happy', 'panic', 'sadness']

# `__file__` is not defined inside a notebook kernel, so fall back to the
# current working directory when this runs as a notebook cell.
try:
    base_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    base_dir = os.getcwd()

train_image_dir = os.path.join(base_dir, 'Data', 'img', 'train')
train_label_dir = os.path.join(base_dir, 'Data', 'label', 'train')
val_image_dir = os.path.join(base_dir, 'Data', 'img', 'val')
val_label_dir = os.path.join(base_dir, 'Data', 'label', 'val')

# The training set gets the augmentation pipeline; validation uses the plain transform.
train_data_load = DataProcessing(
    train_image_dir,
    train_label_dir,
    categories,
    transform=transform,
    augment_transform=augment_transform,
    mode='train',
    mtcnn=mtcnn,
)
val_data_load = DataProcessing(
    val_image_dir,
    val_label_dir,
    categories,
    transform=transform,
    mode='val',
    mtcnn=mtcnn,
)

In [None]:
# One sampling weight per training sample.
weights = get_sampling_weights(train_data_load)

# Sampler definition: draws len(weights) samples per pass, with replacement.
sampler = WeightedRandomSampler(
    weights=weights,
    num_samples=len(weights),
    replacement=True,
)

In [None]:
# The dataset runs the MTCNN cropper inside __getitem__. When that cropper lives
# on a CUDA device it cannot be used from DataLoader worker processes (CUDA
# cannot be initialised in a forked subprocess), so load in the main process in
# that case. On CPU, keep two workers (adjust to the system).
loader_workers = 0 if device.type == 'cuda' else 2

train_loader = DataLoader(
    train_data_load,
    batch_size=32,
    sampler=sampler,
    num_workers=loader_workers,
)

# Validation does not need shuffling; a fixed order keeps the batch-averaged
# validation loss reproducible from epoch to epoch.
val_loader = DataLoader(val_data_load, batch_size=32, shuffle=False)

In [None]:
# ==== Build the model: second custom head, v1.2 ====
model_v2_2_1_4 = resnet50()

# Replace the final layer with an MLP head of three hidden blocks
# (Linear -> BatchNorm -> ReLU -> Dropout) and a 3-unit output layer.
# Layers are created in the same order as a literal nn.Sequential(...) listing,
# so sub-module indices and parameter names are unchanged.
head_layers = []
in_dim = model_v2_2_1_4.fc.in_features
for out_dim, drop_p in ((256, 0.4), (128, 0.3), (64, 0.2)):
    head_layers.extend([
        nn.Linear(in_dim, out_dim),
        nn.BatchNorm1d(out_dim),
        nn.ReLU(),
        nn.Dropout(drop_p),
    ])
    in_dim = out_dim
head_layers.append(nn.Linear(in_dim, 3))

model_v2_2_1_4.fc = nn.Sequential(*head_layers)

In [None]:
# ================ 3. Load the pretrained weights ===============
# `__file__` is not defined inside a notebook kernel, so fall back to the
# current working directory when this runs as a notebook cell.
try:
    base_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    base_dir = os.getcwd()
weight_path = os.path.join(base_dir, 'resnet50_ft_weight.pkl')

# NOTE(review): unpickling executes arbitrary code; only load this file from a
# trusted source.
# encoding='latin1' lets Python 3 read pickles that were written by Python 2
# (presumably the case for this checkpoint -- confirm); it is harmless for
# pickles written by Python 3.
with open(weight_path, 'rb') as f:
    state_dict = pickle.load(f, encoding='latin1')

# Convert NumPy arrays to tensors; other values are passed through unchanged.
state_dict = {
    key: torch.from_numpy(value) if isinstance(value, np.ndarray) else value
    for key, value in state_dict.items()
}

# strict=False tolerates keys that do not line up (the replaced `fc` head has
# no counterpart in the checkpoint). Report what was skipped so that an
# accidentally empty load does not go unnoticed.
load_result = model_v2_2_1_4.load_state_dict(state_dict, strict=False)
print(f'Pretrained weights loaded: {len(load_result.missing_keys)} missing keys, '
      f'{len(load_result.unexpected_keys)} unexpected keys')

# 5. Move the model to the target device.
model_v2_2_1_4 = model_v2_2_1_4.to(device)

In [None]:
# ==== Training setup ====
num_epochs = 10  # adjustable

# MSE on the age output, cross-entropy on the gender logits.
criterion_age = nn.MSELoss()
criterion_gender = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(params=model_v2_2_1_4.parameters(), lr=1e-4)

In [None]:
def evaluate(model, data_loader, device, criterion_age, criterion_gender):
    """Run the model over ``data_loader`` without gradients and collect metrics.

    Output column 0 is treated as the age prediction and columns 1-2 as the
    gender logits. The per-batch loss is the sum of the age and gender losses.

    Returns:
        A tuple ``(avg_loss, accuracy, mae)``: the mean of the per-batch losses,
        the binary accuracy of the arg-max gender prediction, and the mean
        absolute error of the age prediction against the targets the loader
        yields.
    """
    model.eval()
    running_loss = 0
    gender_accuracy = Accuracy(task='binary').to(device)
    age_mae = MeanAbsoluteError().to(device)

    with torch.no_grad():
        for batch_images, batch_ages, batch_genders in data_loader:
            batch_images = batch_images.to(device)
            batch_ages = batch_ages.to(device)
            batch_genders = batch_genders.to(device)

            logits = model(batch_images)
            age_pred = logits[:, 0]
            gender_logits = logits[:, 1:3]

            batch_loss = (
                criterion_age(age_pred, batch_ages)
                + criterion_gender(gender_logits, batch_genders)
            )
            running_loss += batch_loss.item()

            gender_accuracy.update(torch.argmax(gender_logits, dim=1), batch_genders)
            age_mae.update(age_pred, batch_ages)

    mean_loss = running_loss / len(data_loader)
    return mean_loss, gender_accuracy.compute(), age_mae.compute()

In [None]:
# ==== Empty per-epoch metric lists, filled during training ====
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []
train_maes, val_maes = [], []

In [None]:
# ================ Early Stopping ======================
class EarlyStopping:
    """Stop training once the validation loss has stopped improving.

    Call the instance once per epoch with the validation loss and the model.
    A strictly lower loss resets the counter and snapshots the model weights;
    otherwise the counter is incremented, and ``early_stop`` becomes True once
    it reaches ``patience``.

    Args:
        patience: Number of consecutive non-improving epochs to tolerate.
        verbose: When True, print a message on every call.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, patience=5, verbose=False):
        self.patience = patience      # epochs to wait without improvement
        self.verbose = verbose        # whether to print progress messages
        self.counter = 0              # consecutive epochs without improvement
        # float('inf') instead of np.Inf, which was removed in NumPy 2.0.
        self.best_loss = float('inf') # lowest validation loss seen so far
        self.early_stop = False       # set to True when training should stop
        self.best_model_state = None  # snapshot of the best model weights

    def __call__(self, val_loss, model):
        """Record ``val_loss`` and update the stopping state."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            # state_dict() hands back references to the live parameter tensors,
            # which the optimizer keeps updating in place. Clone them so the
            # snapshot really holds the weights from this epoch.
            self.best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            self.counter = 0
            if self.verbose:
                print(f'Validation loss improved to {val_loss:.4f}. Saving model.')
        else:
            self.counter += 1
            if self.verbose:
                print(f'No improvement for {self.counter} epochs.')
            if self.counter >= self.patience:
                if self.verbose:
                    print('Early stopping triggered.')
                self.early_stop = True

In [None]:
# ===================== Training ===========================
early_stopping = EarlyStopping(patience=5, verbose=True)

for epoch in range(num_epochs):
    model_v2_2_1_4.train()
    epoch_loss = 0

    for images, ages, genders in train_loader:

        images = images.to(device)
        ages = ages.to(device)
        genders = genders.to(device)

        # Output column 0 is the age prediction; columns 1-2 are the gender logits.
        outputs = model_v2_2_1_4(images)
        predicted_age = outputs[:, 0]
        predicted_gender_logits = outputs[:, 1:3]

        # The two task losses are summed with equal (1:1) weight.
        loss_age = criterion_age(predicted_age, ages)
        loss_gender = criterion_gender(predicted_gender_logits, genders)
        loss = loss_age + loss_gender

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # === Evaluation ===
    train_loss, train_acc, train_mae = evaluate(model_v2_2_1_4, train_loader, device, criterion_age, criterion_gender)
    val_loss, val_acc, val_mae = evaluate(model_v2_2_1_4, val_loader, device, criterion_age, criterion_gender)

    # === Record history ===
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_acc.item())
    val_accuracies.append(val_acc.item())
    train_maes.append(train_mae.item())
    val_maes.append(val_mae.item())

    # ==== Report =====
    print(f'Epoch [{epoch+1}/{num_epochs}]')
    print(f'Train Loss : {train_loss : .4f}, Train Gender Accuracy : {train_acc : .4f}, Train AGE MAE : {train_mae : .4f}')
    print(f'Validation Loss : {val_loss : .4f}, Validation Gender Accuracy : {val_acc : .4f}, Validation AGE MAE : {val_mae : .4f}')

    early_stopping(val_loss, model_v2_2_1_4)

    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

# Restore the best weights. No snapshot exists when the loop never ran or the
# validation loss never improved (e.g. it was NaN every epoch); in that case
# keep the current weights instead of crashing on load_state_dict(None).
if early_stopping.best_model_state is not None:
    model_v2_2_1_4.load_state_dict(early_stopping.best_model_state)
else:
    print('No best model state was recorded; keeping the current weights.')

In [None]:
# `__file__` is not defined inside a notebook kernel, so fall back to the
# current working directory when this runs as a notebook cell.
try:
    base_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    base_dir = os.getcwd()
pth_save_path = os.path.join(base_dir, 'pth_pkl', 'model_raw_weights_v2_2_1_4.pth')

try:
    # Create the output directory first; torch.save does not create it.
    os.makedirs(os.path.dirname(pth_save_path), exist_ok=True)
    torch.save(model_v2_2_1_4.state_dict(), pth_save_path)
    print(f'모델 저장 완료 → {pth_save_path}')
except Exception as e:
    # Best effort: report the failure without interrupting the notebook.
    print(f'모델 저장 실패: {e}')

In [None]:
# Bundle the per-epoch metric lists into a single record.
history = dict(
    train_losses=train_losses,
    val_losses=val_losses,
    train_accuracies=train_accuracies,
    val_accuracies=val_accuracies,
    train_maes=train_maes,
    val_maes=val_maes,
)

In [None]:
# `__file__` is not defined inside a notebook kernel, so fall back to the
# current working directory when this runs as a notebook cell.
try:
    base_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    base_dir = os.getcwd()
pkl_save_path = os.path.join(base_dir, 'pth_pkl', 'model_raw_v2_2_1_4.pkl')

try:
    # Create the output directory first; open() does not create it.
    os.makedirs(os.path.dirname(pkl_save_path), exist_ok=True)
    with open(pkl_save_path, "wb") as f:
        pickle.dump(history, f)
    print(f'학습 기록이 성공적으로 저장되었습니다 : {pkl_save_path}')
except Exception as e:
    # Best effort: report the failure without interrupting the notebook.
    print(f'학습 기록 저장 실패: {e}')