In [None]:
# Hymenoptera 데이터셋에 대한 고급 이미지 분류 실험

# 1. 필요한 라이브러리 임포트 및 환경 설정
from google.colab import drive
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, models, datasets
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
import pandas as pd
from datetime import datetime
import timm

# Mount Google Drive and basic setup
drive.mount('/content/drive')

# Install the required packages
# NOTE(review): this cell runs the install AFTER `import timm` (and the other
# third-party imports) at the top of the cell, so on a fresh runtime the import
# fails before this line is reached. Move the install into its own cell placed
# before the imports.
%pip install torch torchvision tqdm pandas matplotlib timm

# Project directory setup: switch the working directory to the project folder on Drive.
project_path = '/content/drive/MyDrive/Final'
os.chdir(project_path)

# Directory layout: output folders (checkpoints, CSV results, plots) and the
# ImageFolder roots of the train/val splits, all under project_path.
dirs = {
    'model': os.path.join(project_path, 'model'),
    'result': os.path.join(project_path, 'result'),
    'plots': os.path.join(project_path, 'plots'),
    'data': {
        'train': os.path.join(project_path, 'hymenoptera_data', 'train'),
        'val': os.path.join(project_path, 'hymenoptera_data', 'val'),
    },
}

# Create the output directories. exist_ok=True makes re-running the cell a
# no-op and avoids the check-then-create race of a separate exists() test.
for dir_path in (dirs['model'], dirs['result'], dirs['plots']):
    os.makedirs(dir_path, exist_ok=True)

# Device setup: use the CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# 2. Data preprocessing and augmentation settings
def _imagenet_normalize():
    """Return a new Normalize transform using the ImageNet channel mean/std."""
    return transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])


data_transforms = dict(
    # Training: random crop/flip/rotation, colour jitter and small translations,
    # then tensor conversion and normalization.
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        _imagenet_normalize(),
    ]),
    # Validation: deterministic resize to 256 then a 224x224 center crop.
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        _imagenet_normalize(),
    ]),
)

# Mixup data augmentation
def mixup_data(x, y, alpha=1.0):
    """Blend every sample in a batch with a randomly permuted partner (mixup).

    Args:
        x: Batch tensor; dimension 0 is the batch dimension.
        y: Label tensor aligned with ``x`` along dimension 0.
        alpha: Concentration of the Beta(alpha, alpha) distribution the mixing
            weight is drawn from. ``alpha <= 0`` disables mixing (``lam == 1``).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y`` and
        ``y_b = y[perm]``. The matching loss is
        ``lam * loss(y_a) + (1 - lam) * loss(y_b)``.
    """
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0

    # Create the permutation on the batch's own device rather than on the
    # module-level `device`, so the helper works standalone and never indexes
    # across devices.
    index = torch.randperm(x.size(0), device=x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index.to(y.device)]
    return mixed_x, y_a, y_b, lam

# Load the datasets (ImageFolder derives class labels from sub-folder names)
train_dataset = datasets.ImageFolder(dirs['data']['train'], transform=data_transforms['train'])
val_dataset = datasets.ImageFolder(dirs['data']['val'], transform=data_transforms['val'])

# ImageFolder numbers classes independently per root, so a missing or extra
# class folder in one split would silently shift the label indices.
if train_dataset.class_to_idx != val_dataset.class_to_idx:
    raise ValueError(
        f"train/val class folders differ: {train_dataset.class_to_idx} "
        f"vs {val_dataset.class_to_idx}"
    )

# Print dataset information
print("\n데이터셋 정보:")
print("클래스:", train_dataset.classes)
print("클래스 수:", len(train_dataset.classes))
print("학습 데이터 수:", len(train_dataset))
print("검증 데이터 수:", len(val_dataset))

# Create the data loaders; pinned host memory speeds up copies to a CUDA device.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2,
                          pin_memory=device.type == 'cuda')
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2,
                        pin_memory=device.type == 'cuda')

# 3. Model setup and training
def setup_model(model_name, num_classes):
    """Create a pretrained timm model with a new ``num_classes``-way head.

    Args:
        model_name: ``'vit'`` for ``vit_base_patch16_224``, or any timm model
            name starting with ``'efficientnet'``.
        num_classes: Number of output classes for the replacement head.

    Returns:
        The model, moved to the module-level ``device``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported families.
    """
    if model_name == 'vit':
        net = timm.create_model('vit_base_patch16_224', pretrained=True)
        net.head = nn.Linear(net.head.in_features, num_classes)
        return net.to(device)

    if not model_name.startswith('efficientnet'):
        raise ValueError(f'Unknown model: {model_name}')

    net = timm.create_model(model_name, pretrained=True)
    in_dim = net.classifier.in_features
    # EfficientNet gets dropout in front of the new linear classifier.
    net.classifier = nn.Sequential(nn.Dropout(0.5), nn.Linear(in_dim, num_classes))
    return net.to(device)

def _run_train_epoch(model, criterion, optimizer, mixup):
    """Run one optimisation pass over ``train_loader``; return (mean loss, accuracy).

    Under mixup a blended image has two targets, so accuracy is the
    lam-weighted agreement of each prediction with both label sets.
    """
    model.train()
    running_loss = 0.0
    running_corrects = 0.0

    for inputs, labels in tqdm(train_loader, desc='Training'):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if mixup:
            inputs, labels_a, labels_b, lam = mixup_data(inputs, labels)

        optimizer.zero_grad()

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        if mixup:
            loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)
            correct = (lam * (preds == labels_a).sum().item()
                       + (1 - lam) * (preds == labels_b).sum().item())
        else:
            loss = criterion(outputs, labels)
            correct = (preds == labels).sum().item()

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        running_corrects += correct

    n_samples = len(train_dataset)
    return running_loss / n_samples, running_corrects / n_samples


def _run_eval_epoch(model, criterion):
    """Evaluate on ``val_loader`` without gradients; return (mean loss, accuracy)."""
    model.eval()
    running_loss = 0.0
    running_corrects = 0

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc='Validation'):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

    n_samples = len(val_dataset)
    return running_loss / n_samples, running_corrects / n_samples


def train_model(model, criterion, optimizer, scheduler, model_name, num_epochs=20, mixup=True):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Uses the module-level ``train_loader``/``val_loader``,
    ``train_dataset``/``val_dataset``, ``device`` and ``dirs``.

    Args:
        model: Network to train; batches are moved to ``device``, so the model
            is expected to live there too.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the model's parameters.
        scheduler: Optional LR scheduler, stepped once per epoch when truthy.
        model_name: Prefix of the checkpoint and CSV file names.
        num_epochs: Number of training epochs.
        mixup: If True, train on mixup-blended batches (see ``mixup_data``).

    Returns:
        dict of per-epoch float lists: 'train_loss', 'train_acc', 'val_loss',
        'val_acc'.

    Side effects:
        Each time validation accuracy improves, writes a new timestamped
        checkpoint to ``dirs['model']``. At the end, writes the history CSV to
        ``dirs['result']``.
    """
    best_acc = 0.0
    results = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # Training phase
        train_loss, train_acc = _run_train_epoch(model, criterion, optimizer, mixup)
        if scheduler:
            scheduler.step()
        results['train_loss'].append(train_loss)
        results['train_acc'].append(train_acc)
        print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')

        # Validation phase
        val_loss, val_acc = _run_eval_epoch(model, criterion)
        results['val_loss'].append(val_loss)
        results['val_acc'].append(val_acc)
        print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

        # Save the best model so far. 'acc' is a plain float so the checkpoint
        # loads without a device map.
        if val_acc > best_acc:
            best_acc = val_acc
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': val_loss,
                'acc': val_acc,
            }, os.path.join(dirs['model'], f'{model_name}_best_{timestamp}.pt'))

    # Save the final per-epoch history
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    pd.DataFrame(results).to_csv(os.path.join(dirs['result'], f'{model_name}_results_{timestamp}.csv'))

    return results

# 4. Run the model experiments
models_to_test = [
    'efficientnet_b0',
    'efficientnet_b3',
    'vit'
]

# One epoch budget drives both the training loop and the cosine schedule.
# With T_max smaller than the number of epochs, CosineAnnealingLR keeps
# following the cosine past T_max and climbs back toward the initial LR.
NUM_EPOCHS = 20

results = {}

for model_name in models_to_test:
    print(f"\n{model_name} 모델 학습 시작")
    model = setup_model(model_name, len(train_dataset.classes))
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

    model_results = train_model(
        model,
        criterion,
        optimizer,
        scheduler,
        model_name,
        num_epochs=NUM_EPOCHS,
        mixup=True
    )

    results[model_name] = model_results

# 5. Result visualization
def plot_training_results(results):
    """Draw validation-accuracy and validation-loss curves for every model.

    Args:
        results: Mapping of model name -> history dict containing per-epoch
            'val_acc' and 'val_loss' sequences.

    Saves the figure as ``training_results_<timestamp>.png`` in
    ``dirs['plots']`` and then displays it.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # (history key, panel title, y-axis label) for the left and right panels.
    # NOTE(review): the Hangul labels need a Korean-capable matplotlib font;
    # with default Colab fonts they may render as empty boxes — confirm.
    panels = [
        ('val_acc', '검증 정확도', '정확도'),
        ('val_loss', '검증 손실', '손실'),
    ]

    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
    for ax, (key, title, ylabel) in zip(axes, panels):
        for name, history in results.items():
            ax.plot(history[key], label=name)
        ax.set_title(title)
        ax.set_xlabel('에폭')
        ax.set_ylabel(ylabel)
        ax.legend()

    fig.tight_layout()
    fig.savefig(os.path.join(dirs['plots'], f'training_results_{stamp}.png'))
    plt.show()

# Run the visualization
plot_training_results(results)

# Print the final results. train_model keeps the checkpoint of the epoch with
# the best validation accuracy, so report that next to the last epoch's value.
print("\n=== 최종 결과 ===")
for model_name, history in results.items():
    if not history['val_acc']:
        # Zero-epoch runs have no history; avoid an IndexError on [-1].
        print(f"{model_name}: 검증 기록 없음")
        continue
    final_acc = history['val_acc'][-1]
    best_acc = max(history['val_acc'])
    print(f"{model_name}: 최종 검증 정확도 = {final_acc:.4f}, 최고 검증 정확도 = {best_acc:.4f}")
