# In [None]:
import os
import gc
import shutil
from PIL import Image, UnidentifiedImageError
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm

# Per-channel normalization statistics. These are the commonly used ImageNet
# values (see the torchvision pretrained-model documentation).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _build_transforms(image_size):
    """Return the ``(train, validation)`` preprocessing pipelines."""
    size = (image_size, image_size)
    train_tf = transforms.Compose([
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),  # random rotation in [-15, +15] degrees
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
        transforms.ToTensor(),  # PIL image -> tensor
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ])
    # Validation uses no augmentation, only resize + normalization.
    val_tf = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ])
    return train_tf, val_tf


def _run_epoch(model, loader, criterion, device, desc, optimizer=None, scaler=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given the model is put in training mode and updated
    (``scaler`` must then be given too); otherwise the model is evaluated with
    gradients disabled. The loss is averaged per sample, so a smaller final
    batch does not skew the epoch mean.
    """
    training = optimizer is not None
    model.train(training)
    use_amp = device.type == 'cuda'
    loss_sum, correct, seen = 0.0, 0, 0

    progress = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(training):
        for images, labels in progress:
            images, labels = images.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()

            with autocast(enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)

            if training:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

            # Read the loss once: every .item() forces a device sync.
            batch_loss = loss.item()
            batch_count = labels.size(0)
            loss_sum += batch_loss * batch_count
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            seen += batch_count

            if training:
                progress.set_postfix(loss=batch_loss, acc=f"{correct / seen:.4f}")

    return loss_sum / seen, correct / seen


# Training entry point
def train_model(
    *,
    base_dir="/workspace/dataset",
    model_path="/workspace/model/Trainmodel1.pth",
    batch_size=64,
    num_epochs=20,
    learning_rate=1e-3,
    patience=3,
    delta=0.001,
    num_workers=8,
    image_size=512,
):
    """Fine-tune a pretrained MobileNetV3-Large classifier with early stopping.

    Images are read with ``ImageFolder`` from ``<base_dir>/train`` and
    ``<base_dir>/val``. ``classifier[3]`` of the model is replaced by a linear
    layer sized to the number of training classes. Mixed precision is used
    when CUDA is available. After each epoch the model's ``state_dict`` is
    written to ``model_path`` if the validation loss improved by more than
    ``delta``; training stops after ``patience`` consecutive epochs without
    such an improvement.

    Args:
        base_dir: Dataset root containing ``train`` and ``val`` sub-folders.
        model_path: File the best ``state_dict`` is saved to.
        batch_size: Number of images per batch.
        num_epochs: Maximum number of epochs.
        learning_rate: Learning rate for the Adam optimizer.
        patience: Epochs without improvement tolerated before stopping.
        delta: Minimum decrease in validation loss that counts as improvement.
        num_workers: Worker processes per ``DataLoader``.
        image_size: Side length images are resized to (square).

    Raises:
        ValueError: If a validation class is missing from the training set or
            is assigned a different label index than in the training set.
    """
    train_dir = os.path.join(base_dir, "train")
    val_dir = os.path.join(base_dir, "val")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    on_cuda = device.type == 'cuda'

    # Data preparation
    transform_train, transform_val = _build_transforms(image_size)
    train_dataset = datasets.ImageFolder(train_dir, transform=transform_train)
    val_dataset = datasets.ImageFolder(val_dir, transform=transform_val)

    # Each ImageFolder derives its own label indices from its sub-folder names;
    # if they disagree, validation labels would silently refer to other classes.
    mismatched = [
        name
        for name, index in val_dataset.class_to_idx.items()
        if train_dataset.class_to_idx.get(name) != index
    ]
    if mismatched:
        raise ValueError(
            "Validation classes do not map to the same label indices as the "
            f"training classes: {sorted(mismatched)}"
        )

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=on_cuda)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            num_workers=num_workers, pin_memory=on_cuda)

    # Model, loss function and optimizer
    num_classes = len(train_dataset.classes)
    model = models.mobilenet_v3_large(weights="DEFAULT")
    model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scaler = GradScaler(enabled=on_cuda)

    # Create the checkpoint directory up front so an unwritable path fails
    # before any training time is spent.
    save_dir = os.path.dirname(model_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Training / validation loop
    best_val_loss = float("inf")
    patience_counter = 0

    for epoch in range(num_epochs):
        avg_train_loss, train_accuracy = _run_epoch(
            model, train_loader, criterion, device,
            desc=f"[Epoch {epoch+1}/{num_epochs}] Train",
            optimizer=optimizer, scaler=scaler,
        )
        avg_val_loss, val_accuracy = _run_epoch(
            model, val_loader, criterion, device,
            desc=f"[Epoch {epoch+1}/{num_epochs}] Val  ",
        )

        print(f"\nEpoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | "
              f"Train Acc: {train_accuracy:.4f} | Val Acc: {val_accuracy:.4f}")

        # Early stopping + checkpointing of the best model
        if avg_val_loss < best_val_loss - delta:
            print(f"Val Loss 개선 ({best_val_loss:.4f} -> {avg_val_loss:.4f}). 모델을 저장합니다.")
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), model_path)
        else:
            patience_counter += 1
            print(f"Val Loss가 개선되지 않았습니다. (Patience: {patience_counter}/{patience})")
            if patience_counter >= patience:
                print(f"\n조기 종료: Epoch {epoch+1}에서 훈련 중단")
                break
    print("훈련이 완료되었습니다.")

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train_model()