In [83]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from datasets import Dataset
from datasets import DatasetDict
from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision import transforms
from torchvision.transforms import ToTensor

In [84]:
# Root of the image folder (one sub-directory per exercise class).
# The original machine-specific path stays as the default so existing runs are
# unchanged; set the POSTURE_DATA_DIR environment variable to point the
# notebook at the data on any other machine.
DEFAULT_DATA_DIR = "C:/Users/Aaditya Khanal/OneDrive/Desktop/posture-recognition/image_dataset"
data_dir = os.environ.get("POSTURE_DATA_DIR", DEFAULT_DATA_DIR)

dataset = load_dataset("imagefolder", data_dir=data_dir)

# Label map
# `labels` is the ordered list of class names; the two dicts translate between
# class index (stored as a string key in id2label) and class name.
labels = dataset["train"].features["label"].names
id2label = {str(i): label for i, label in enumerate(labels)}
label2id = {label: i for i, label in enumerate(labels)}
print(id2label)

Resolving data files:   0%|          | 0/13853 [00:00<?, ?it/s]

{'0': 'barbell biceps curl', '1': 'bench press', '2': 'chest fly machine', '3': 'deadlift', '4': 'decline bench press', '5': 'hammer curl', '6': 'hip thrust', '7': 'incline bench press', '8': 'lat pulldown', '9': 'lateral raises', '10': 'leg extension', '11': 'leg raises', '12': 'plank', '13': 'pull up', '14': 'push up', '15': 'romanian deadlift', '16': 'russian twist', '17': 'shoulder press', '18': 'squat', '19': 't bar row', '20': 'tricep dips', '21': 'tricep pushdown'}


In [85]:
# Hold out 10% of the images for validation.
# The split is seeded so that re-running this cell always yields the same
# partition; without a seed, a re-run after training would move images the
# model has already been trained on into the validation set.
SPLIT_SEED = 42
dataset_split = dataset["train"].train_test_split(test_size=0.1, seed=SPLIT_SEED)
train_data = dataset_split["train"]
val_data = dataset_split["test"]

In [86]:
# Training configuration.
SEED = 42
# Seed PyTorch's RNG so weight initialisation of the new classifier head and
# DataLoader shuffling are repeatable across runs (GPU kernels may still
# introduce small nondeterminism).
torch.manual_seed(SEED)

batch_size = 32
num_epochs = 10
# Derived from the label map rather than `labels`: the training/validation
# loops rebind the name `labels` to a batch tensor, so len(labels) would be
# wrong if this cell were re-run after them.
num_classes = len(id2label)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [87]:
from torchvision.transforms import Lambda


def _to_rgb(image):
    """Return ``image`` converted to RGB mode (force RGB mode)."""
    return image.convert("RGB")


# Per-channel mean / std used for normalisation (the values torchvision
# documents for its ImageNet-pretrained models).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline applied to every image:
# RGB conversion -> resize to 224x224 -> tensor -> channel normalisation.
_pipeline_steps = [
    Lambda(_to_rgb),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
]
transform = transforms.Compose(_pipeline_steps)


In [88]:
class CustomHFDataset(torch.utils.data.Dataset):
    """Map-style dataset exposing indexed records as ``(image, label)`` pairs.

    Each record is fetched by index from ``hf_dataset`` and must provide the
    keys ``"image"`` and ``"label"``. When ``transform`` is truthy it is
    applied to the image only; the label is returned exactly as stored.
    """

    def __init__(self, hf_dataset, transform=None):
        # Keep references only; nothing is loaded or transformed up front.
        self.transform = transform
        self.dataset = hf_dataset

    def __len__(self):
        # One sample per underlying record.
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        image, label = record["image"], record["label"]
        processed = self.transform(image) if self.transform else image
        return processed, label

# Wrap both splits so that every sample goes through the same preprocessing.
train_dataset = CustomHFDataset(train_data, transform=transform)
val_dataset = CustomHFDataset(val_data, transform=transform)

# Use the configured `batch_size` instead of a second hard-coded 32, so that
# changing the config cell actually takes effect. Only the training loader is
# shuffled; the validation loader keeps the dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

In [89]:
# ResNet-18 initialised with the ImageNet weights. `pretrained=True` is
# deprecated in torchvision; the explicit weights enum below selects the same
# checkpoint that `pretrained=True` loaded.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Replace the final fully connected layer so the network outputs one logit
# per class in this dataset.
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

In [90]:
# Loss: cross-entropy between the model's logits and the integer class labels.
criterion = nn.CrossEntropyLoss()

# Optimiser: Adam over every model parameter (the whole network is updated).
learning_rate = 1e-4
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [91]:
# Training loop: one pass over `train_loader` per epoch, reporting the mean
# per-sample loss and the training accuracy after each epoch.
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample losses seen this epoch
    correct = 0         # number of correct predictions (plain Python int)
    seen = 0            # number of samples processed this epoch

    # The batch labels are named `targets` so they do not overwrite the
    # module-level `labels` list of class names.
    for images, targets in train_loader:
        images, targets = images.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_n = targets.size(0)
        # criterion returns the batch mean; weight it by the batch size so a
        # smaller final batch does not skew the epoch average.
        running_loss += loss.item() * batch_n
        preds = outputs.argmax(dim=1)
        correct += (preds == targets).sum().item()
        seen += batch_n

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    epoch_loss = running_loss / max(seen, 1)
    epoch_acc = correct / max(seen, 1)
    print(f'Epoch {epoch+1}: Loss {epoch_loss:.4f}, Accuracy {epoch_acc:.4f}')

Epoch 1: Loss 0.3593, Accuracy 0.9300
Epoch 2: Loss 0.0194, Accuracy 0.9970
Epoch 3: Loss 0.0057, Accuracy 0.9995
Epoch 4: Loss 0.0022, Accuracy 0.9999
Epoch 5: Loss 0.0012, Accuracy 1.0000
Epoch 6: Loss 0.0008, Accuracy 1.0000
Epoch 7: Loss 0.0006, Accuracy 1.0000
Epoch 8: Loss 0.0004, Accuracy 1.0000
Epoch 9: Loss 0.0003, Accuracy 1.0000
Epoch 10: Loss 0.0003, Accuracy 1.0000


In [92]:
# Validation: measure accuracy on the held-out split with the model in eval
# mode and gradient tracking disabled.
model.eval()
correct = 0  # number of correct predictions (plain Python int)
seen = 0     # number of validation samples processed
with torch.no_grad():
    # The batch labels are named `targets` so they do not overwrite the
    # module-level `labels` list of class names.
    for images, targets in val_loader:
        images, targets = images.to(device), targets.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)
        correct += (preds == targets).sum().item()
        seen += targets.size(0)

# max(..., 1) keeps an empty validation loader from raising ZeroDivisionError.
val_acc = correct / max(seen, 1)
print(f'Validation Accuracy: {val_acc:.4f}')

Validation Accuracy: 0.9906
