In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader

# Load ResNet50 with its ImageNet-pretrained weights.
# `pretrained=True` is deprecated in torchvision >= 0.13; the explicit
# IMAGENET1K_V1 weights enum is what that legacy flag maps to.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Replace the final fully connected layer with a 2-output head (class 0 or 1).
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# The model is moved to its device BEFORE the optimizer is constructed, as the
# torch.optim documentation recommends, so the optimizer is guaranteed to hold
# the parameters that actually live on `device`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Preprocessing applied to every image before it reaches the network.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet50 takes 224x224 inputs by default
    transforms.ToTensor(),
    # Per-channel mean/std commonly used with ImageNet-pretrained backbones.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Optimizer and loss function.
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()


In [12]:
import os
import pandas as pd
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim

# Pick the GPU when CUDA is available, otherwise the CPU. This re-binds the
# `device` name that the previous cell already created with "cuda:0".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CustomDataset: binary image dataset built from two directories of PNG files.
class CustomDataset(Dataset):
    """Binary classification dataset read from two folders of ``.png`` images.

    Every ``.png`` file in ``dir_path_ball`` gets label 1 (tennis ball) and
    every ``.png`` file in ``dir_path_background`` gets label 0 (background).

    Args:
        dir_path_ball: Directory holding the label-1 images.
        dir_path_background: Directory holding the label-0 images.
        transform: Optional callable applied to each loaded PIL image.

    Attributes:
        image_paths: Full path of every collected image, ball images first.
        labels: Integer label (1 or 0) aligned index-for-index with
            ``image_paths``.
    """

    def __init__(self, dir_path_ball, dir_path_background, transform=None):
        self.dir_path_ball = dir_path_ball
        self.dir_path_background = dir_path_background
        self.transform = transform

        # Parallel lists: image_paths[i] belongs to labels[i].
        self.image_paths = []
        self.labels = []

        # Ball images (label 1) first, then background images (label 0).
        for directory, label in ((self.dir_path_ball, 1),
                                 (self.dir_path_background, 0)):
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping reproducible across runs and machines.
            for filename in sorted(os.listdir(directory)):
                if filename.endswith('.png'):  # only .png files are used
                    self.image_paths.append(os.path.join(directory, filename))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of collected images."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A ``(image, label)`` tuple. ``image`` is an RGB PIL image, or the
            result of ``transform`` when one was given; ``label`` is 0 or 1.
        """
        label = self.labels[index]

        # PNGs may be RGBA, palette or grayscale. Forcing RGB guarantees the
        # 3 channels a 3-value mean/std normalisation and an RGB backbone
        # expect. The context manager closes the file handle, which
        # Image.open otherwise keeps open lazily; convert() loads the pixel
        # data before the handle is released.
        with Image.open(self.image_paths[index]) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label


# Source folders: background crops become label 0 and ball crops label 1
# once they are handed to CustomDataset below.
dir_path_background = r'C:\Users\lwj01\HowFastTennisBallIs\novak_sinner_over_30\cropped_baseground'
dir_path_ball = r'C:\Users\lwj01\HowFastTennisBallIs\novak_sinner_over_30\cropped_ball\augmentation'

# Build the training set with the preprocessing pipeline defined earlier.
train_dataset = CustomDataset(dir_path_ball, dir_path_background, transform=transform)

# Reshuffled mini-batches of 512 samples.
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)


# Training loop.
num_epochs = 1

# Make sure training-mode layers (BatchNorm, Dropout) are active even if the
# model was switched to eval mode by an earlier run of the notebook.
model.train()

for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        # The model lives on `device`; each batch has to be moved there too,
        # otherwise the forward pass fails with a device mismatch on GPU.
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # One line per mini-batch: the loss of the batch just processed.
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

print("Training complete.")


Epoch [1/1], Loss: 0.8339
Epoch [1/1], Loss: 0.1221
Epoch [1/1], Loss: 0.0336
Epoch [1/1], Loss: 0.0071
Epoch [1/1], Loss: 0.0035
Epoch [1/1], Loss: 0.0010
Epoch [1/1], Loss: 0.0006
Epoch [1/1], Loss: 0.0003
Epoch [1/1], Loss: 0.0002
Epoch [1/1], Loss: 0.0001
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1], Loss: 0.0000
Epoch [1/1],

KeyboardInterrupt: 

In [13]:
# Serialize the whole model object (not only its state_dict) to disk.
torch.save(obj=model, f='model_v3.pt')