# In [None]:
import os
import time
import random
import optuna
import timm
import torch
import albumentations as A
import pandas as pd
import numpy as np
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import StratifiedKFold

import cv2
import json
import gc 
from datetime import datetime

# Fix the random seeds
def set_seed(SEED=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        SEED: Integer seed applied to every random number generator.
    """
    os.environ['PYTHONHASHSEED'] = str(SEED)
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # cudnn.benchmark=True lets cuDNN auto-tune and choose convolution
    # algorithms per run, which makes results differ between runs even with
    # identical seeds. A seed-fixing helper must turn it off and request
    # deterministic kernels instead.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Directory holding the training images, and the CSV that lists them
# (read elsewhere in this file as rows of image file name + 'target').
train_img_path = "../data/train/"
train_csv_path = "../data/train.csv"

# Directory holding the test images, and the sample-submission CSV whose 'ID'
# column names them (it is also used as the template for the prediction file).
test_img_path = "../data/test/"
sample_path = "../data/sample_submission.csv"



# Dataset class definition
class ImageDataset(Dataset):
    """Image dataset backed by a table whose rows are (file name, target)."""

    def __init__(self, data, path, transform=None):
        """
        Args:
            data: DataFrame, or path to a CSV file. Each row must contain
                exactly two values: the image file name and its target.
            path: Directory that contains the image files.
            transform: Optional callable invoked as
                ``transform(image=img)['image']`` (albumentations style).
        """
        if isinstance(data, (str, os.PathLike)):
            self.df = pd.read_csv(data).values
        else:
            self.df = data.values

        self.path = path
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, target)``.

        The image is always delivered to the transform as an H x W x 3 RGB
        array, whatever mode the file was stored in.
        """
        name, target = self.df[idx]
        # Open inside a context manager so the file handle is released right
        # away instead of whenever the Image object gets garbage collected.
        # convert('RGB') replaces the manual channel fix-up: besides grayscale
        # and RGBA it also handles palette ('P'), 'LA' and 'CMYK' files, which
        # the shape-based checks turned into wrong colours or 2-channel arrays.
        with Image.open(os.path.join(self.path, name)) as pil_img:
            img = np.array(pil_img.convert('RGB'))

        if self.transform is not None:
            img = self.transform(image=img)['image']
        return img, target

def get_transforms(img_size):
    """Build the albumentations pipelines used for training and for evaluation.

    Args:
        img_size: Value passed to ``A.Resize`` as both height and width.

    Returns:
        Tuple ``(trn_transform, tst_transform)``. The second pipeline only
        resizes, normalizes and converts the image to a tensor.
    """

    def output_stage():
        # Built fresh on every call so the two pipelines never share
        # transform instances.
        return [
            A.Resize(height=img_size, width=img_size),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]

    # Stage 1: at most one spatial distortion.
    distortion_choice = A.OneOf(
        [
            A.GridDistortion(
                num_steps=5, distort_limit=0.3, interpolation=1, border_mode=4,
                value=None, mask_value=None, always_apply=True, p=1,
            ),
            A.ElasticTransform(
                always_apply=True, p=1, alpha=1.0, sigma=50.0, alpha_affine=50.0,
                interpolation=0, border_mode=1, value=(0, 0, 0), mask_value=None,
                approximate=False,
            ),
            A.OpticalDistortion(always_apply=True, p=1, distort_limit=(-0.3, -0.1)),
            A.OpticalDistortion(always_apply=True, p=1, distort_limit=(0.1, 0.3)),
        ],
        p=0.85,
    )

    # Stage 2: two of the four colour/intensity perturbations.
    colour_choice = A.SomeOf(
        [
            A.RandomBrightnessContrast(
                brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=1,
            ),
            A.HueSaturationValue(
                hue_shift_limit=15, sat_shift_limit=25, val_shift_limit=20, p=1,
            ),
            A.MultiplicativeNoise(p=1, multiplier=(1, 1.5), per_channel=True),
            A.Equalize(p=1, mode='cv', by_channels=True),
        ],
        n=2,
        p=0.85,
    )

    # Stage 3: always rotate, drawing the angle from one of four ranges.
    angle_ranges = [(10, 30), (150, 170), (190, 210), (330, 350)]
    rotation_choice = A.OneOf(
        [
            A.Rotate(limit=angle_range, border_mode=cv2.BORDER_CONSTANT, p=1)
            for angle_range in angle_ranges
        ],
        p=1,
    )

    # Stage 4: optional blur / downscale.
    degradation_choice = A.OneOf(
        [
            A.Blur(blur_limit=(3, 4), p=1),
            A.MotionBlur(blur_limit=(3, 5), p=1),
            A.Downscale(scale_min=0.455, scale_max=0.5, interpolation=2, p=1),
        ],
        p=0.5,
    )

    training_steps = [
        distortion_choice,
        colour_choice,
        rotation_choice,
        A.CoarseDropout(
            p=0.5, max_holes=40, max_height=15, max_width=15,
            min_holes=8, min_height=8, min_width=8,
        ),
        A.Equalize(p=0.5, mode='cv', by_channels=True),
        degradation_choice,
        A.GaussNoise(var_limit=(100, 800), per_channel=True, p=0.5),
        A.RandomRotate90(),
        A.CLAHE(p=0.5),
        A.MotionBlur(p=0.2),
        A.RandomBrightnessContrast(p=0.2),
    ]

    trn_transform = A.Compose(training_steps + output_stage())
    tst_transform = A.Compose(output_stage())

    return trn_transform, tst_transform




def calculate_class_metrics(y_true, y_pred, num_classes=17):
    """Compute per-class (one-vs-rest) and overall accuracy / F1 scores.

    Args:
        y_true: Ground-truth class indices (any array-like of integers).
        y_pred: Predicted class indices, same length as ``y_true``.
        num_classes: Number of classes; indices ``0 .. num_classes - 1`` each
            get their own entry.

    Returns:
        Dict with one ``'class_<i>'`` entry per class holding ``'accuracy'``,
        ``'f1_score'`` and ``'support'`` (count of true samples of that
        class), plus an ``'overall'`` entry holding ``'accuracy'`` and the
        macro-averaged ``'f1_score'``.
    """
    # Coerce to arrays: with plain lists, `y_true == class_idx` evaluates to a
    # single bool instead of an element-wise mask.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    class_metrics = {}
    for class_idx in range(num_classes):
        # Binary one-vs-rest labels for this class.
        y_true_binary = (y_true == class_idx)
        y_pred_binary = (y_pred == class_idx)

        class_metrics[f'class_{class_idx}'] = {
            'accuracy': accuracy_score(y_true_binary, y_pred_binary),
            # zero_division=0 yields the same 0.0 as the default for classes
            # with no true/predicted samples, without a warning on each call.
            'f1_score': f1_score(
                y_true_binary, y_pred_binary, average='binary', zero_division=0
            ),
            # Plain int so the value can be serialised (e.g. with json).
            'support': int(np.sum(y_true_binary)),
        }

    # Overall metrics
    class_metrics['overall'] = {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1_score': f1_score(y_true, y_pred, average='macro', zero_division=0),
    }

    return class_metrics
        
        
        
def train_one_epoch(loader, model, optimizer, loss_fn, device):
    """Train ``model`` for one pass over ``loader``.

    Args:
        loader: Iterable yielding ``(image, targets)`` batches.
        model: Network to optimise; switched to training mode.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(preds, targets)`` returning the batch loss.
        device: Device the batches are moved to.

    Returns:
        Dict with the mean batch loss (``'loss'``), accuracy (``'acc'``) and
        macro F1 (``'f1'``) over all samples seen in the epoch.
    """
    model.train()
    running_loss = 0
    predicted = []
    expected = []

    progress = tqdm(loader, leave=False)
    for batch_images, batch_targets in progress:
        batch_images = batch_images.to(device)
        batch_targets = batch_targets.to(device)

        # Clear gradients left over from the previous step.
        model.zero_grad(set_to_none=True)
        logits = model(batch_images)
        batch_loss = loss_fn(logits, batch_targets)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        running_loss += loss_value
        predicted.extend(logits.argmax(dim=1).detach().cpu().numpy())
        expected.extend(batch_targets.detach().cpu().numpy())
        progress.set_description(f"Loss: {loss_value:.4f}")

    return {
        "loss": running_loss / len(loader),
        "acc": accuracy_score(expected, predicted),
        "f1": f1_score(expected, predicted, average='macro'),
    }

def valid_one_epoch(loader, model, loss_fn, device):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        loader: Iterable yielding ``(image, targets)`` batches.
        model: Network to evaluate; switched to eval mode.
        loss_fn: Callable ``loss_fn(preds, targets)`` returning the batch loss.
        device: Device the batches are moved to.

    Returns:
        Dict with the mean batch loss (``'loss'``) and the result of
        ``calculate_class_metrics`` over all samples (``'metrics'``).
    """
    model.eval()
    loss_sum = 0
    predicted = []
    expected = []

    with torch.no_grad():
        for batch_images, batch_targets in loader:
            batch_images = batch_images.to(device)
            batch_targets = batch_targets.to(device)

            logits = model(batch_images)
            loss_sum += loss_fn(logits, batch_targets).item()

            predicted.extend(logits.argmax(dim=1).detach().cpu().numpy())
            expected.extend(batch_targets.detach().cpu().numpy())

    mean_loss = loss_sum / len(loader)
    class_report = calculate_class_metrics(np.array(expected), np.array(predicted))

    return {"loss": mean_loss, "metrics": class_report}




def evaluate_best_params(params, device):
    """Run 5-fold stratified cross-validation with the given hyperparameters.

    For every fold a fresh pretrained model is trained with early stopping on
    the validation macro F1. The per-class metrics of each fold's best epoch
    are printed, averaged across folds, and written to a timestamped CSV file
    in the current working directory.

    Args:
        params: Mapping with the keys 'model_name', 'batch_size', 'lr',
            'weight_decay' and 'optimizer' (name of a ``torch.optim`` class).
        device: Device the model and the batches are moved to.

    Returns:
        List with one ``{'fold': int, 'metrics': dict}`` entry per fold, where
        ``metrics`` is the ``calculate_class_metrics`` result of the best epoch.

    Raises:
        ValueError: If ``params['model_name']`` has no known input size.
    """
    print("\nEvaluating best parameters with k-fold validation...")

    input_sizes = {
        "swinv2_tiny_window8_256": 256,
        "tf_efficientnet_b3.ns_jft_in1k": 300,
    }
    model_name = params['model_name']
    if model_name not in input_sizes:
        # Previously this fell through and crashed later with an unbound local.
        raise ValueError(f"Unsupported model_name: {model_name!r}")
    img_size = input_sizes[model_name]

    trn_transform, val_transform = get_transforms(img_size)

    # -------------------------------------------------------------------------
    # K-fold setup
    n_splits = 5
    num_classes = 17
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

    # Two views over the same rows: training folds are augmented, validation
    # folds only get resize/normalize. Scoring validation data through the
    # training augmentations makes the reported metrics noisy and pessimistic.
    data = pd.read_csv(train_csv_path)
    train_view = ImageDataset(data, train_img_path, transform=trn_transform)
    val_view = ImageDataset(data, train_img_path, transform=val_transform)

    X = np.arange(len(data))
    y = data['target'].values

    # Best-epoch metrics of each fold
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), 1):
        print(f"\n{'='*20} Fold {fold} {'='*20}")

        train_loader = DataLoader(
            Subset(train_view, train_idx),
            batch_size=params['batch_size'],
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )

        val_loader = DataLoader(
            Subset(val_view, val_idx),
            batch_size=params['batch_size'],
            shuffle=False,
            num_workers=4,
            pin_memory=True
        )

        # Fresh model and optimizer for every fold
        model = timm.create_model(
            model_name,
            pretrained=True,
            num_classes=num_classes
        ).to(device)

        optimizer_class = getattr(torch.optim, params['optimizer'])
        optimizer = optimizer_class(model.parameters(),
                                    lr=params['lr'],
                                    weight_decay=params['weight_decay'])

        loss_fn = nn.CrossEntropyLoss()

        # ---------------------------------------------------------------------
        best_val_f1 = 0
        best_epoch_metrics = None
        patience = 5
        patience_counter = 0
        EPOCH = 30

        for epoch in range(EPOCH):
            train_one_epoch(train_loader, model, optimizer, loss_fn, device)
            val_ret = valid_one_epoch(val_loader, model, loss_fn, device)

            current_val_f1 = val_ret['metrics']['overall']['f1_score']

            # Always keep the first epoch's metrics so best_epoch_metrics can
            # never stay None, even if the F1 score never rises above 0.
            if best_epoch_metrics is None or current_val_f1 > best_val_f1:
                best_val_f1 = current_val_f1
                best_epoch_metrics = val_ret['metrics']
                patience_counter = 0
            else:
                patience_counter += 1

            if not epoch % 5:
                print(f"Epoch {epoch+1}: Val F1 = {current_val_f1:.4f}")

            if patience_counter >= patience:
                print("Early stopping!")
                break

        # Report the best epoch of this fold
        print(f"\nBest Results for Fold {fold}:")
        print("\nClass-wise Metrics:")
        for class_idx in range(num_classes):
            metrics = best_epoch_metrics[f'class_{class_idx}']
            print(f"\nClass {class_idx}:")
            print(f"Accuracy: {metrics['accuracy']:.4f}")
            print(f"F1 Score: {metrics['f1_score']:.4f}")
            print(f"Support: {metrics['support']}")

        fold_results.append({
            'fold': fold,
            'metrics': best_epoch_metrics
        })

        # Release this fold's model before the next one is created so that
        # memory use does not grow fold after fold.
        del model, optimizer
        gc.collect()
        torch.cuda.empty_cache()

    # Mean and standard deviation across folds, per class
    print("\n" + "="*50)
    print("Average Performance Across All Folds:")

    for class_idx in range(num_classes):
        key = f'class_{class_idx}'
        accuracies = [entry['metrics'][key]['accuracy'] for entry in fold_results]
        f1_scores = [entry['metrics'][key]['f1_score'] for entry in fold_results]

        print(f"\nClass {class_idx}:")
        print(f"Accuracy: {np.mean(accuracies):.4f} (±{np.std(accuracies):.4f})")
        print(f"F1 Score: {np.mean(f1_scores):.4f} (±{np.std(f1_scores):.4f})")

    # One row per (fold, class) for the CSV report
    rows = [
        {
            'fold': entry['fold'],
            'class': class_idx,
            'accuracy': entry['metrics'][f'class_{class_idx}']['accuracy'],
            'f1_score': entry['metrics'][f'class_{class_idx}']['f1_score'],
            'support': entry['metrics'][f'class_{class_idx}']['support'],
        }
        for entry in fold_results
        for class_idx in range(num_classes)
    ]
    results_df = pd.DataFrame(
        rows, columns=['fold', 'class', 'accuracy', 'f1_score', 'support']
    )

    now = datetime.now()
    csv_name = (
        f'best_params_fold_class_metrics_'
        f'{now.month:02}{now.day:02}{now.hour:02}{now.minute:02}.csv'
    )
    results_df.to_csv(csv_name, index=False)
    # Print the file that was actually written (the name carries a timestamp).
    print(f"\nDetailed metrics saved to '{csv_name}'")

    return fold_results



def generate_pseudo_labels(model, loader, device):
    """Predict a class and its softmax confidence for every sample in ``loader``.

    Args:
        model: Network used for inference; switched to eval mode.
        loader: Iterable yielding ``(image, _)`` batches; the second element
            is ignored.
        device: Device the image batches are moved to.

    Returns:
        Tuple ``(pseudo_labels, confidences)`` of NumPy arrays: the arg-max
        class of each sample and the softmax probability of that class.
    """
    model.eval()
    predicted_classes = []
    top_probabilities = []

    with torch.no_grad():
        for batch_images, _ in tqdm(loader, desc="Generating pseudo labels"):
            logits = model(batch_images.to(device))
            class_probs = torch.softmax(logits, dim=1)
            best_prob, best_class = class_probs.max(dim=1)
            predicted_classes.extend(best_class.cpu().numpy())
            top_probabilities.extend(best_prob.cpu().numpy())

    return np.array(predicted_classes), np.array(top_probabilities)

def objective(trial, device):
    """Optuna objective: mean best validation macro F1 over 5 stratified folds.

    Samples one hyperparameter set from ``trial``, trains a fresh pretrained
    model on each fold with early stopping, and averages each fold's best
    validation macro F1.

    Args:
        trial: Optuna trial used to sample the hyperparameters.
        device: Device (``torch.device`` or device string) for model and data.

    Returns:
        Mean over the folds of the best validation macro F1.

    Raises:
        ValueError: If the sampled model name has no known input size.
    """
    on_cuda = torch.device(device).type == 'cuda'
    try:
        # torch.cuda.memory_allocated only accepts CUDA devices, so skip the
        # memory report when running on CPU.
        if on_cuda:
            print(f"GPU Memory before trial: {torch.cuda.memory_allocated(device)/(1024**2):.2f}MB")

        # Hyperparameter search space
        params = {
            # 'model_name': trial.suggest_categorical('model_name', ['swinv2_tiny_window8_256', 'efficientnet_b0', 'resnet18']),
            'model_name': trial.suggest_categorical('model_name', ['tf_efficientnet_b3.ns_jft_in1k']),
            'batch_size': trial.suggest_categorical('batch_size', [8, 16, 32]),
            'lr': trial.suggest_float('lr', 1e-5, 1e-3, log=True),
            'weight_decay': trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
            # NOTE(review): 'dropout_rate' is sampled but nothing in this file
            # applies it to the model, so it has no effect on the score.
            'dropout_rate': trial.suggest_float('dropout_rate', 0.0, 0.3),
            'optimizer': trial.suggest_categorical('optimizer', ['Adam', 'AdamW']),
        }

        # Input resolution is tied to the chosen backbone
        input_sizes = {
            "swinv2_tiny_window8_256": 256,
            "tf_efficientnet_b3.ns_jft_in1k": 300,
        }
        if params['model_name'] not in input_sizes:
            raise ValueError(f"Unsupported model_name: {params['model_name']!r}")
        params['img_size'] = input_sizes[params['model_name']]

        trn_transform, val_transform = get_transforms(params['img_size'])

        # ---------------------------------------------------------------------
        # K-fold setup
        n_splits = 5
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

        # Two views over the same rows: augmented for training folds, plain
        # resize/normalize for validation folds. Validating on augmented
        # images would make the objective noisy and bias the search.
        data = pd.read_csv(train_csv_path)
        train_view = ImageDataset(data, train_img_path, transform=trn_transform)
        val_view = ImageDataset(data, train_img_path, transform=val_transform)

        X = np.arange(len(data))
        y = data['target'].values

        fold_scores = []

        for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), 1):
            print(f"\nFold {fold}")

            train_loader = DataLoader(
                Subset(train_view, train_idx),
                batch_size=params['batch_size'],
                shuffle=True,
                num_workers=4,
                pin_memory=True
            )

            val_loader = DataLoader(
                Subset(val_view, val_idx),
                batch_size=params['batch_size'],
                shuffle=False,
                num_workers=4,
                pin_memory=True
            )

            model = timm.create_model(
                params['model_name'],
                pretrained=True,
                num_classes=17
            ).to(device)

            # Pass the sampled weight_decay on to the optimizer; it was tuned
            # but never used, while evaluate_best_params does apply it.
            optimizer_class = getattr(torch.optim, params['optimizer'])
            optimizer = optimizer_class(
                model.parameters(),
                lr=params['lr'],
                weight_decay=params['weight_decay'],
            )

            loss_fn = nn.CrossEntropyLoss()

            # -----------------------------------------------------------------
            best_val_f1 = 0
            patience = 5
            patience_counter = 0
            EPOCH = 30

            for epoch in range(EPOCH):
                train_ret = train_one_epoch(train_loader, model, optimizer, loss_fn, device)
                val_ret = valid_one_epoch(val_loader, model, loss_fn, device)
                current_val_f1 = val_ret['metrics']['overall']['f1_score']

                if not epoch % 10:
                    print(f"Epoch {epoch+1}: Train F1 = {train_ret['f1']:.4f}, Val F1 = {current_val_f1:.4f}")

                if current_val_f1 > best_val_f1:
                    best_val_f1 = current_val_f1
                    patience_counter = 0
                else:
                    patience_counter += 1

                if patience_counter >= patience:
                    print("Early stopping!")
                    break

                # NOTE(review): this ends the fold as soon as any epoch,
                # including the very first, scores below 0.5. It looks like a
                # deliberate cheap pruning rule; confirm it is not cutting off
                # slow-starting (low learning rate) trials.
                if current_val_f1 < 0.5:
                    print("Early stopping!")
                    break

            fold_scores.append(best_val_f1)

            # Release this fold's model before the next one is created.
            del model, optimizer
            gc.collect()
            torch.cuda.empty_cache()

        if on_cuda:
            print(f"GPU Memory after trial: {torch.cuda.memory_allocated(device)/(1024**2):.2f}MB")

        mean_f1 = np.mean(fold_scores)
        return mean_f1

    finally:
        # Collect garbage first so tensors that have just become unreachable
        # are actually freed before the CUDA cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()
        
        
        
        
        
def main():
    """Run the whole pipeline: search, evaluate, train, pseudo-label, predict.

    Steps:
        1. Optuna hyperparameter search (50 trials), journaled to 'study.log'.
        2. Save the best parameters to a timestamped JSON file.
        3. K-fold evaluation of the best parameters.
        4. Train a final model on the full training set for 30 epochs.
        5. Pseudo-label the test set, keep predictions with confidence above
           0.9 and fine-tune on them for 20 epochs.
        6. Write the final test predictions to a timestamped CSV file.

    Raises:
        ValueError: If the best model name has no known input size.
    """
    set_seed()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Compare the device *type*: comparing a torch.device against the string
    # "cuda" is not a reliable test.
    if device.type == "cuda":
        torch.cuda.empty_cache()

    # Hyperparameter optimisation with Optuna.
    # load_if_exists only resumes a study when the name is fixed; without
    # study_name a new randomly named study is created on every run and the
    # trials already journaled in study.log are never reused.
    study = optuna.create_study(
        study_name="hparam_search",
        direction='maximize',
        storage=optuna.storages.JournalStorage(
            optuna.storages.JournalFileStorage("study.log")  # persisted to a log file
        ),
        load_if_exists=True
    )
    study.optimize(lambda trial: objective(trial, device), n_trials=50)

    best_params = study.best_params
    print("Best parameters:", best_params)

    # Save the best parameters
    now = datetime.now()
    stamp = f"{now.month:02}{now.day:02}{now.hour:02}{now.minute:02}"
    params_file = f"best_params_{stamp}.json"
    with open(params_file, "w") as f:
        json.dump(best_params, f)
    print(f"Best parameters saved to {params_file}")

    # To reuse previously saved parameters instead of searching again:
    # with open(params_file, "r") as f:
    #     best_params = json.load(f)

    # K-fold evaluation with the best parameters
    evaluate_best_params(best_params, device)

    # Train the final model with the best hyperparameters
    input_sizes = {
        "swinv2_tiny_window8_256": 256,
        "tf_efficientnet_b3.ns_jft_in1k": 300,
    }
    if best_params['model_name'] not in input_sizes:
        raise ValueError(f"Unsupported model_name: {best_params['model_name']!r}")
    img_size = input_sizes[best_params['model_name']]

    trn_transform, tst_transform = get_transforms(img_size)

    # Full training set for training, test set for inference
    trn_dataset = ImageDataset(train_csv_path, train_img_path, transform=trn_transform)
    tst_dataset = ImageDataset(sample_path, test_img_path, transform=tst_transform)

    trn_loader = DataLoader(
        trn_dataset,
        batch_size=best_params['batch_size'],
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )

    tst_loader = DataLoader(
        tst_dataset,
        batch_size=best_params['batch_size'],
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    model = timm.create_model(
        best_params['model_name'],
        pretrained=True,
        num_classes=17
    ).to(device)

    # Use the tuned weight_decay here as well, matching evaluate_best_params.
    optimizer_class = getattr(torch.optim, best_params['optimizer'])
    optimizer = optimizer_class(
        model.parameters(),
        lr=best_params['lr'],
        weight_decay=best_params['weight_decay'],
    )
    loss_fn = nn.CrossEntropyLoss()

    # -------------------------------------------------------------------------
    # Initial training
    print("Training initial model...")
    for epoch in range(30):
        ret = train_one_epoch(trn_loader, model, optimizer, loss_fn, device)
        print(f"Epoch {epoch+1}: F1 = {ret['f1']:.4f}")

    # Pseudo labelling
    print("\nGenerating pseudo labels...")
    pseudo_labels, confidences = generate_pseudo_labels(model, tst_loader, device)

    # Keep only high-confidence predictions (threshold: 0.9)
    confidence_threshold = 0.9
    high_confidence_mask = confidences > confidence_threshold

    test_data = pd.read_csv(sample_path)
    pseudo_df = pd.DataFrame({
        'ID': test_data['ID'][high_confidence_mask],
        'target': pseudo_labels[high_confidence_mask]
    })

    # Additional training on the pseudo-labelled test images.
    # NOTE(review): this phase trains on the pseudo-labelled samples only,
    # without the original labelled data — confirm that is intended.
    print(f"\nFound {len(pseudo_df)} high-confidence pseudo labels")
    if len(pseudo_df) > 0:
        pseudo_dataset = ImageDataset(pseudo_df, test_img_path, transform=trn_transform)
        pseudo_loader = DataLoader(
            pseudo_dataset,
            batch_size=best_params['batch_size'],
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )

        # ---------------------------------------------------------------------
        print("Training with pseudo labels...")
        for epoch in range(20):
            ret = train_one_epoch(pseudo_loader, model, optimizer, loss_fn, device)
            print(f"Pseudo Label Epoch {epoch+1}: F1 = {ret['f1']:.4f}")

    # Final predictions
    print("\nGenerating final predictions...")
    final_preds = []
    model.eval()
    with torch.no_grad():
        for image, _ in tqdm(tst_loader):
            image = image.to(device)
            preds = model(image)
            final_preds.extend(preds.argmax(dim=1).cpu().numpy())

    pred_df = pd.DataFrame({
        'ID': test_data['ID'],
        'target': final_preds
    })

    # The file name reuses the timestamp taken when the parameters were saved,
    # so both outputs of one run share the same suffix.
    predictions_file = f"predictions_{stamp}.csv"
    pred_df.to_csv(predictions_file, index=False)
    print(f"\nPredictions saved to '{predictions_file}'")

if __name__ == "__main__":
    main()