In [1]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [5]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

# BasicBlock 클래스 정의
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * out_channels)
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = torch.relu(out)
        return out

# SmallResNet 클래스 정의
class SmallResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=1):
        super(SmallResNet, self).__init__()
        self.in_channels = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=1)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=1)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out

def SmallResNet18():
    return SmallResNet(BasicBlock, [2, 2, 2, 2])

# FocalLoss 클래스 정의
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        targets = targets.float()  # 라벨을 float 형으로 변환
        BCE_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt)**self.gamma * BCE_loss
        return F_loss

# CustomDataset 클래스 정의
class CustomDataset(Dataset):
    def __init__(self, dir_path_ball, dir_path_background, transform=None):
        self.dir_path_ball = dir_path_ball
        self.dir_path_background = dir_path_background
        self.transform = transform

        self.image_paths = []
        self.labels = []

        for filename in os.listdir(self.dir_path_ball):
            if filename.endswith('.jpg'):
                self.image_paths.append(os.path.join(self.dir_path_ball, filename))
                self.labels.append(1)

        for filename in os.listdir(self.dir_path_background):
            if filename.endswith('.png'):
                self.image_paths.append(os.path.join(self.dir_path_background, filename))
                self.labels.append(0)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image = Image.open(self.image_paths[index])
        label = self.labels[index]

        if self.transform:
            image = self.transform(image)

        return image, label

# 이미지 변환 정의
transform = transforms.Compose([
    transforms.Resize((30, 30)),  # 이미지 크기 조정
    transforms.ToTensor(),  # 텐서로 변환
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 정규화
])

# 데이터셋 경로
dir_path_ball = r'C:\Users\lwj01\HowFastTennisBallIs\novak_sinner_over_30\cropped_ball\augmentation'
dir_path_background = r'C:\Users\lwj01\HowFastTennisBallIs\novak_sinner_over_30\cropped_baseground'

# 데이터셋과 데이터로더 생성
train_dataset = CustomDataset(
    dir_path_ball=dir_path_ball,
    dir_path_background=dir_path_background,
    transform=transform
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

# 모델, 손실 함수, 옵티마이저 설정
model = SmallResNet18().to(device)
criterion = FocalLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 학습 과정
num_epochs = 1
for epoch in range(num_epochs):
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        outputs = torch.squeeze(outputs)
        loss = criterion(outputs, labels) 
        loss.backward()
        optimizer.step()

        print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
        
print("Training complete.")


Epoch [1/1], Batch [1/1180], Loss: 0.1257
Epoch [1/1], Batch [2/1180], Loss: 0.0694
Epoch [1/1], Batch [3/1180], Loss: 0.0065
Epoch [1/1], Batch [4/1180], Loss: 0.0055
Epoch [1/1], Batch [5/1180], Loss: 0.0031
Epoch [1/1], Batch [6/1180], Loss: 0.0028
Epoch [1/1], Batch [7/1180], Loss: 0.0015
Epoch [1/1], Batch [8/1180], Loss: 0.0052
Epoch [1/1], Batch [9/1180], Loss: 0.0011
Epoch [1/1], Batch [10/1180], Loss: 0.0007
Epoch [1/1], Batch [11/1180], Loss: 0.0024
Epoch [1/1], Batch [12/1180], Loss: 0.0044
Epoch [1/1], Batch [13/1180], Loss: 0.0067
Epoch [1/1], Batch [14/1180], Loss: 0.0037
Epoch [1/1], Batch [15/1180], Loss: 0.0018
Epoch [1/1], Batch [16/1180], Loss: 0.0057
Epoch [1/1], Batch [17/1180], Loss: 0.0015
Epoch [1/1], Batch [18/1180], Loss: 0.0015
Epoch [1/1], Batch [19/1180], Loss: 0.0020
Epoch [1/1], Batch [20/1180], Loss: 0.0007
Epoch [1/1], Batch [21/1180], Loss: 0.0004
Epoch [1/1], Batch [22/1180], Loss: 0.0030
Epoch [1/1], Batch [23/1180], Loss: 0.0010
Epoch [1/1], Batch [

Epoch [1/1], Batch [190/1180], Loss: 0.0056
Epoch [1/1], Batch [191/1180], Loss: 0.0015
Epoch [1/1], Batch [192/1180], Loss: 0.0024
Epoch [1/1], Batch [193/1180], Loss: 0.0035
Epoch [1/1], Batch [194/1180], Loss: 0.0020
Epoch [1/1], Batch [195/1180], Loss: 0.0015
Epoch [1/1], Batch [196/1180], Loss: 0.0022
Epoch [1/1], Batch [197/1180], Loss: 0.0015
Epoch [1/1], Batch [198/1180], Loss: 0.0016
Epoch [1/1], Batch [199/1180], Loss: 0.0003
Epoch [1/1], Batch [200/1180], Loss: 0.0014
Epoch [1/1], Batch [201/1180], Loss: 0.0010
Epoch [1/1], Batch [202/1180], Loss: 0.0004
Epoch [1/1], Batch [203/1180], Loss: 0.0002
Epoch [1/1], Batch [204/1180], Loss: 0.0016
Epoch [1/1], Batch [205/1180], Loss: 0.0016
Epoch [1/1], Batch [206/1180], Loss: 0.0010
Epoch [1/1], Batch [207/1180], Loss: 0.0004
Epoch [1/1], Batch [208/1180], Loss: 0.0005
Epoch [1/1], Batch [209/1180], Loss: 0.0003
Epoch [1/1], Batch [210/1180], Loss: 0.0012
Epoch [1/1], Batch [211/1180], Loss: 0.0010
Epoch [1/1], Batch [212/1180], L

Epoch [1/1], Batch [377/1180], Loss: 0.0000
Epoch [1/1], Batch [378/1180], Loss: 0.0002
Epoch [1/1], Batch [379/1180], Loss: 0.0000
Epoch [1/1], Batch [380/1180], Loss: 0.0000
Epoch [1/1], Batch [381/1180], Loss: 0.0000
Epoch [1/1], Batch [382/1180], Loss: 0.0000
Epoch [1/1], Batch [383/1180], Loss: 0.0000
Epoch [1/1], Batch [384/1180], Loss: 0.0000
Epoch [1/1], Batch [385/1180], Loss: 0.0000
Epoch [1/1], Batch [386/1180], Loss: 0.0000
Epoch [1/1], Batch [387/1180], Loss: 0.0000
Epoch [1/1], Batch [388/1180], Loss: 0.0000
Epoch [1/1], Batch [389/1180], Loss: 0.0000
Epoch [1/1], Batch [390/1180], Loss: 0.0000
Epoch [1/1], Batch [391/1180], Loss: 0.0000
Epoch [1/1], Batch [392/1180], Loss: 0.0000
Epoch [1/1], Batch [393/1180], Loss: 0.0000
Epoch [1/1], Batch [394/1180], Loss: 0.0000
Epoch [1/1], Batch [395/1180], Loss: 0.0000
Epoch [1/1], Batch [396/1180], Loss: 0.0000
Epoch [1/1], Batch [397/1180], Loss: 0.0001
Epoch [1/1], Batch [398/1180], Loss: 0.0000
Epoch [1/1], Batch [399/1180], L

Epoch [1/1], Batch [564/1180], Loss: 0.0000
Epoch [1/1], Batch [565/1180], Loss: 0.0000
Epoch [1/1], Batch [566/1180], Loss: 0.0000
Epoch [1/1], Batch [567/1180], Loss: 0.0000
Epoch [1/1], Batch [568/1180], Loss: 0.0000
Epoch [1/1], Batch [569/1180], Loss: 0.0000
Epoch [1/1], Batch [570/1180], Loss: 0.0001
Epoch [1/1], Batch [571/1180], Loss: 0.0000
Epoch [1/1], Batch [572/1180], Loss: 0.0000
Epoch [1/1], Batch [573/1180], Loss: 0.0000
Epoch [1/1], Batch [574/1180], Loss: 0.0000
Epoch [1/1], Batch [575/1180], Loss: 0.0000
Epoch [1/1], Batch [576/1180], Loss: 0.0000
Epoch [1/1], Batch [577/1180], Loss: 0.0000
Epoch [1/1], Batch [578/1180], Loss: 0.0000
Epoch [1/1], Batch [579/1180], Loss: 0.0000
Epoch [1/1], Batch [580/1180], Loss: 0.0000
Epoch [1/1], Batch [581/1180], Loss: 0.0000
Epoch [1/1], Batch [582/1180], Loss: 0.0000
Epoch [1/1], Batch [583/1180], Loss: 0.0002
Epoch [1/1], Batch [584/1180], Loss: 0.0001
Epoch [1/1], Batch [585/1180], Loss: 0.0000
Epoch [1/1], Batch [586/1180], L

Epoch [1/1], Batch [751/1180], Loss: 0.0000
Epoch [1/1], Batch [752/1180], Loss: 0.0000
Epoch [1/1], Batch [753/1180], Loss: 0.0000
Epoch [1/1], Batch [754/1180], Loss: 0.0000
Epoch [1/1], Batch [755/1180], Loss: 0.0000
Epoch [1/1], Batch [756/1180], Loss: 0.0000
Epoch [1/1], Batch [757/1180], Loss: 0.0000
Epoch [1/1], Batch [758/1180], Loss: 0.0000
Epoch [1/1], Batch [759/1180], Loss: 0.0000
Epoch [1/1], Batch [760/1180], Loss: 0.0002
Epoch [1/1], Batch [761/1180], Loss: 0.0000
Epoch [1/1], Batch [762/1180], Loss: 0.0000
Epoch [1/1], Batch [763/1180], Loss: 0.0000
Epoch [1/1], Batch [764/1180], Loss: 0.0000
Epoch [1/1], Batch [765/1180], Loss: 0.0000
Epoch [1/1], Batch [766/1180], Loss: 0.0000
Epoch [1/1], Batch [767/1180], Loss: 0.0000
Epoch [1/1], Batch [768/1180], Loss: 0.0000
Epoch [1/1], Batch [769/1180], Loss: 0.0000
Epoch [1/1], Batch [770/1180], Loss: 0.0000
Epoch [1/1], Batch [771/1180], Loss: 0.0000
Epoch [1/1], Batch [772/1180], Loss: 0.0000
Epoch [1/1], Batch [773/1180], L

Epoch [1/1], Batch [938/1180], Loss: 0.0000
Epoch [1/1], Batch [939/1180], Loss: 0.0000
Epoch [1/1], Batch [940/1180], Loss: 0.0000
Epoch [1/1], Batch [941/1180], Loss: 0.0000
Epoch [1/1], Batch [942/1180], Loss: 0.0001
Epoch [1/1], Batch [943/1180], Loss: 0.0000
Epoch [1/1], Batch [944/1180], Loss: 0.0000
Epoch [1/1], Batch [945/1180], Loss: 0.0000
Epoch [1/1], Batch [946/1180], Loss: 0.0000
Epoch [1/1], Batch [947/1180], Loss: 0.0000
Epoch [1/1], Batch [948/1180], Loss: 0.0000
Epoch [1/1], Batch [949/1180], Loss: 0.0000
Epoch [1/1], Batch [950/1180], Loss: 0.0000
Epoch [1/1], Batch [951/1180], Loss: 0.0000
Epoch [1/1], Batch [952/1180], Loss: 0.0000
Epoch [1/1], Batch [953/1180], Loss: 0.0002
Epoch [1/1], Batch [954/1180], Loss: 0.0000
Epoch [1/1], Batch [955/1180], Loss: 0.0000
Epoch [1/1], Batch [956/1180], Loss: 0.0000
Epoch [1/1], Batch [957/1180], Loss: 0.0000
Epoch [1/1], Batch [958/1180], Loss: 0.0000
Epoch [1/1], Batch [959/1180], Loss: 0.0000
Epoch [1/1], Batch [960/1180], L

Epoch [1/1], Batch [1122/1180], Loss: 0.0000
Epoch [1/1], Batch [1123/1180], Loss: 0.0000
Epoch [1/1], Batch [1124/1180], Loss: 0.0000
Epoch [1/1], Batch [1125/1180], Loss: 0.0000
Epoch [1/1], Batch [1126/1180], Loss: 0.0000
Epoch [1/1], Batch [1127/1180], Loss: 0.0000
Epoch [1/1], Batch [1128/1180], Loss: 0.0000
Epoch [1/1], Batch [1129/1180], Loss: 0.0000
Epoch [1/1], Batch [1130/1180], Loss: 0.0000
Epoch [1/1], Batch [1131/1180], Loss: 0.0000
Epoch [1/1], Batch [1132/1180], Loss: 0.0000
Epoch [1/1], Batch [1133/1180], Loss: 0.0000
Epoch [1/1], Batch [1134/1180], Loss: 0.0000
Epoch [1/1], Batch [1135/1180], Loss: 0.0000
Epoch [1/1], Batch [1136/1180], Loss: 0.0000
Epoch [1/1], Batch [1137/1180], Loss: 0.0000
Epoch [1/1], Batch [1138/1180], Loss: 0.0000
Epoch [1/1], Batch [1139/1180], Loss: 0.0000
Epoch [1/1], Batch [1140/1180], Loss: 0.0000
Epoch [1/1], Batch [1141/1180], Loss: 0.0000
Epoch [1/1], Batch [1142/1180], Loss: 0.0000
Epoch [1/1], Batch [1143/1180], Loss: 0.0000
Epoch [1/1

In [6]:
torch.save(model, 'model_v5.pt')