In [1]:
import torch
from torch import nn
import matplotlib.pyplot as plt

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

Uzyskiwanie danych:

In [None]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
import os

data_dir = "/content/drive/MyDrive/images"
os.listdir(data_dir)


['serve', 'backhand', 'ready_position', 'forehand']

In [None]:
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])


In [None]:
from torchvision.datasets import ImageFolder

dataset = ImageFolder(
    root=data_dir,
    transform=transform
)


In [None]:
dataset.class_to_idx


{'backhand': 0, 'forehand': 1, 'ready_position': 2, 'serve': 3}

In [None]:
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2
)


In [None]:
images, labels = next(iter(dataloader))

print(images.shape)  # (batch_size, 3, 224, 224)
print(labels.shape)  # (batch_size,)


torch.Size([32, 3, 224, 224])
torch.Size([32])


Podział na grupy treningowe, walidacyjne i testowe:



In [None]:
from torch.utils.data import random_split

torch.manual_seed(42)

total_size = len(dataset)

train_size = int(0.7 * total_size)
val_size   = int(0.15 * total_size)
test_size  = total_size - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size]
)

In [None]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False)

Tworzenie modelu:

In [None]:
class TenisVisionCNN (nn.Module):
  def __init__(self, output_shape: int=4):
    super().__init__()

    self.conv_block_1 = nn.Sequential(
        nn.Conv2d(in_channels=3,
                  out_channels=32,
                  kernel_size=3,
                  padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2)
    )

    self.conv_block_2 = nn.Sequential(
        nn.Conv2d(in_channels=32,
                  out_channels=64,
                  kernel_size=3,
                  padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2)
    )

    self.conv_block_3 = nn.Sequential(
        nn.Conv2d(in_channels=64,
                  out_channels=128,
                  kernel_size=3,
                  padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2)
    )

    self.classifier = nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features=128*28*28, # policzone używając dummy torch (kod poniżej)
                  out_features=output_shape)
    )

  def forward(self, x: torch.Tensor):
    x = self.conv_block_1(x)
    x = self.conv_block_2(x)
    x = self.conv_block_3(x)
    x = self.classifier(x)
    return x


Sprawdzanie wartości na wyjściu trzeciego bloku, aby sprawdzić wejści do classifier:

In [None]:
model = TenisVisionCNN()

In [None]:
dummy = torch.rand(size=(1, 3, 224, 224))
x = dummy
x = model.conv_block_1(x)
x = model.conv_block_2(x)
x = model.conv_block_3(x)
x.shape

torch.Size([1, 128, 28, 28])

Trenowanie modelu:

In [None]:
model = TenisVisionCNN().to(device)

In [None]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(),
                             lr=0.001)

In [None]:
def accuracy_fn(y_true, y_pred):
    correct = (y_true == y_pred).sum().item()
    acc = correct / len(y_true) * 100
    return acc


Wprowadzony zostaje early stopping:

In [None]:
patience = 10        # ile epok czekać bez poprawy
best_val_loss = float("inf")
epochs_no_improve = 0


In [None]:
epochs = 100

for epoch in range(epochs):
    print(f"Epoch: {epoch}\n-------")

    # === TRAIN ===
    model.train()
    train_loss = 0

    for X, y in train_loader:
        X, y = X.to(device), y.to(device)

        y_pred = model(X)
        loss = loss_fn(y_pred, y)

        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_loss /= len(train_loader)

    # === VALIDATION ===
    model.eval()
    val_loss, val_acc = 0, 0

    with torch.inference_mode():
        for X_val, y_val in val_loader:
            X_val, y_val = X_val.to(device), y_val.to(device)

            val_pred = model(X_val)
            val_loss += loss_fn(val_pred, y_val).item()
            val_acc += accuracy_fn(
                y_true=y_val,
                y_pred=val_pred.argmax(dim=1)
            )

    val_loss /= len(val_loader)
    val_acc  /= len(val_loader)


    # dane na bierząco
    print(
            f"Train loss: {train_loss:.4f} | "
            f"Val loss: {val_loss:.4f} | "
            f"Val acc: {val_acc:.2f}%"
        )

    # === EARLY STOPPING ===
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), "best_model.pth")
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print(f"Early stopping at epoch {epoch}")
        break



Epoch: 0
-------
Train loss: 1.3516 | Val loss: 1.2010 | Val acc: 52.73%
Epoch: 1
-------
Train loss: 1.1106 | Val loss: 1.0517 | Val acc: 54.69%
Epoch: 2
-------
Train loss: 0.8904 | Val loss: 0.7445 | Val acc: 70.31%
Epoch: 3
-------
Train loss: 0.7518 | Val loss: 0.6870 | Val acc: 71.88%
Epoch: 4
-------
Train loss: 0.6422 | Val loss: 0.6455 | Val acc: 75.39%
Epoch: 5
-------
Train loss: 0.5600 | Val loss: 0.5631 | Val acc: 77.73%
Epoch: 6
-------
Train loss: 0.4917 | Val loss: 0.5163 | Val acc: 76.56%
Epoch: 7
-------
Train loss: 0.3908 | Val loss: 0.5173 | Val acc: 79.69%
Epoch: 8
-------
Train loss: 0.3336 | Val loss: 0.4969 | Val acc: 79.30%
Epoch: 9
-------
Train loss: 0.2768 | Val loss: 0.4144 | Val acc: 84.77%
Epoch: 10
-------
Train loss: 0.2389 | Val loss: 0.3706 | Val acc: 85.16%
Epoch: 11
-------
Train loss: 0.1688 | Val loss: 0.3737 | Val acc: 87.50%
Epoch: 12
-------
Train loss: 0.1328 | Val loss: 0.3800 | Val acc: 87.11%
Epoch: 13
-------
Train loss: 0.0858 | Val loss:

KeyboardInterrupt: 