In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm.auto import tqdm

from handtrackermodule import handDetector
from model import model
from model import model as HandModel  # alias stays usable after `model` is rebound to the instance

In [None]:
# Training hyper-parameters.
EPOCHS = 10
LR = 0.001
BATCH_SIZE = 16
SEED = 0  # seeds torch's global RNG: DataLoader shuffling and (presumably) weight init

# Pick the best available accelerator in one chain (CUDA, then Apple-silicon MPS,
# then CPU) so a later check cannot silently overwrite an earlier one.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

torch.manual_seed(SEED)

In [None]:
# Hand-landmark extractor applied to each image batch during training.
hands = handDetector(maxHands=1)

# ImageFolder yields (image, class index) pairs, one class per sub-directory of
# "dataset". Without a transform the images are PIL objects, which the DataLoader's
# default collate cannot batch, so convert them to uint8 (C, H, W) tensors.
# NOTE(review): stacking into a batch also requires every image to have the same size,
# and the input format handDetector.findPosition expects is not visible here — TODO confirm.
trainDataset = ImageFolder("dataset", transform=torchvision.transforms.PILToTensor())
trainDataLoader = DataLoader(trainDataset, batch_size=BATCH_SIZE, shuffle=True)

# Instantiate through the `HandModel` alias: the instance is bound to `model` (later cells
# use that name), which shadows the imported class, so re-running this cell would
# otherwise call the instance instead of constructing a new one.
model = HandModel().to(device=DEVICE)
# NOTE(review): BCELoss expects probabilities in [0, 1] and fits a two-class dataset only;
# use BCEWithLogitsLoss if the model emits logits, CrossEntropyLoss for >2 classes.
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

In [None]:
# Train for EPOCHS passes over the dataset, recording the mean batch loss of each epoch.
loss_arr = []
epoch_arr = []
model.train()  # ensure training-mode behaviour for any dropout / batch-norm layers
for epoch in tqdm(range(EPOCHS)):
    epoch_arr.append(epoch)
    total_loss = 0.0
    for img, label in trainDataLoader:
        landmarks = hands.findPosition(img)
        # The model's parameters live on DEVICE, so its input must too; as_tensor also
        # turns a list/array of coordinates into the float tensor nn layers require.
        # NOTE(review): assumes findPosition returns rectangular numeric data — TODO confirm.
        landmarks = torch.as_tensor(landmarks, dtype=torch.float32, device=DEVICE)
        pred = model(landmarks)
        # ImageFolder labels are int64 class indices on the CPU; BCELoss needs float
        # targets on the prediction's device and with the prediction's shape.
        target = label.to(device=DEVICE, dtype=torch.float32).view_as(pred)
        loss = loss_fn(pred, target)
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    total_loss /= len(trainDataLoader)
    # tqdm.write prints without corrupting the progress bar, unlike print().
    tqdm.write(f"Epoch {epoch} : Loss = {total_loss}")
    loss_arr.append(total_loss)

In [None]:
# Training curve: epoch on the x-axis, mean loss for that epoch on the y-axis.
fig, ax = plt.subplots()
ax.plot(epoch_arr, loss_arr, marker="o")
ax.set(title="Training loss", xlabel="Epoch", ylabel="Mean loss")
plt.show()

In [None]:
# Persist the trained model. torch.save(model, ...) pickles the whole module, so loading
# it later requires the same model class to be importable; model.state_dict() is the
# more portable alternative. torch.save does not create missing directories, so make
# sure the output directory exists first.
model_path = Path("models") / "model.pth"
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model, model_path)