In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Subset
from torchvision import datasets, transforms, models
import numpy as np
from tqdm import tqdm 
import pandas as pd
import random
from PIL import Image, ImageOps
from torchvision.transforms import functional as F

In [2]:
# 데이터 경로 설정
train_dir = "/test/final_exam/challenge/train"
augmented_dir = "/home/student/workspace/data"

os.makedirs(augmented_dir, exist_ok=True)


In [20]:
# 증강 함수 정의
def augment_images(class_path, save_path, min_count=300):
    images = [f for f in os.listdir(class_path) if f.endswith(".jpg") or f.endswith(".png")]
    
    # 원본 이미지 불러오기
    image_paths = [os.path.join(class_path, img) for img in images]
    augmented_images = []

    # 좌우 반전
    for img_path in image_paths:
        with Image.open(img_path) as img:
            flipped = ImageOps.mirror(img)
            augmented_images.append(flipped)
            augmented_images.append(img.copy())

    # 랜덤 회전 (시계방향 & 반시계방향 45도 이내)
    final_images = []
    for img in augmented_images:
        for _ in range(2):
            angle = random.uniform(-45, 45)
            rotated = img.rotate(angle)
            final_images.append(rotated)

    # 이미지 저장 (최소 100장 확보)
    while len(final_images) < min_count:
        for img in augmented_images:
            angle = random.uniform(-45, 45)
            rotated = img.rotate(angle)
            final_images.append(rotated)
            if len(final_images) >= min_count:
                break

    # 저장
    os.makedirs(save_path, exist_ok=True)
    for idx, img in enumerate(final_images):
        img.save(os.path.join(save_path, f"aug_{idx}.jpg"))


In [21]:
# 클래스별 증강 실행
for class_name in os.listdir(train_dir):
    class_path = os.path.join(train_dir, class_name)
    if os.path.isdir(class_path):
        save_path = os.path.join(augmented_dir, class_name)
        augment_images(class_path, save_path)

print("Data augmentation complete.")

Data augmentation complete.


In [3]:
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# 데이터셋 로드 및 변환 정의
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),  # 크기를 1024x1024로 리사이즈
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 정규화
])

# 증강된 데이터셋 경로
augmented_dir = "/home/student/workspace/data/train"

# 데이터셋 로드
dataset = datasets.ImageFolder(root=augmented_dir, transform=transform)

# Train/Validation Split
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# DataLoader 생성
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


In [4]:
import torch.nn as nn
from torchvision.models import resnet50

# ResNet 모델 정의
class ResNetScratch(nn.Module):
    def __init__(self, num_classes=300):  # 클래스 개수를 설정
        super(ResNetScratch, self).__init__()
        self.resnet = resnet50(weights=None)  # Scratch 학습
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)  # 출력층 변경

    def forward(self, x):
        return self.resnet(x)

# 모델 초기화
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNetScratch(num_classes=300).to(device)


In [5]:
import torch.optim as optim
from tqdm import tqdm

# 학습 및 검증 함수
def train_and_validate(model, train_loader, val_loader, epochs=20):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)  # AdamW 옵티마이저
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)  # 학습률 스케줄러

    best_val_accuracy = 0.0

    for epoch in range(epochs):
        # ===== 학습 단계 =====
        model.train()
        train_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Training]"):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # ===== 검증 단계 =====
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Validation]"):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_accuracy = 100 * correct / total
        print(f"Epoch {epoch+1}/{epochs}: Train Loss = {train_loss/len(train_loader):.4f}, "
              f"Val Loss = {val_loss/len(val_loader):.4f}, Val Accuracy = {val_accuracy:.2f}%")

        # Best Model 저장
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), 'best_resnet_scratch.pth')
            print(f"Best model saved with accuracy: {best_val_accuracy:.2f}%")

        scheduler.step()


In [8]:
# Step 1: 모델 학습
train_and_validate(model, train_loader, val_loader, epochs=5)

# Step 2: 모델 가중치 파일 확인
if not os.path.exists('./best_resnet_scratch.pth'):
    raise FileNotFoundError("Model weights file 'best_resnet_scratch.pth' not found. Please ensure training is completed.")

# Step 3: 테스트 단계 실행
model.load_state_dict(torch.load('./best_resnet_scratch.pth'))
predictions = test_model(model, test_loader)

# Step 4: 제출 파일 생성
submission = pd.read_csv('./sample_submission_path')
submission['Label'] = predictions
submission.to_csv('submission_resnet_scratch.csv', index=False)
print("Submission file saved as 'submission_resnet_scratch.csv'.")


Epoch 1/5 [Training]:   0% 2/4694 [00:05<3:47:42,  2.91s/it]


KeyboardInterrupt: 