In [None]:
import random
import os

# Dataset: https://www.kaggle.com/datasets/innominate817/hagrid-sample-30k-384p

# Class names indexed by integer label (load_data below assigns 0 to fist images and 1 to palm images).
classes = ['Fist', 'Palm']

# Build a shuffled list of (image_path, label) pairs from the palm and fist image folders.
def load_data(palm_folder, fist_folder, seed=None):
    """Collect (image_path, label) pairs for the palm and fist folders and shuffle them.

    Palm images are labelled 1 and fist images 0, matching the indices of ``classes``.

    Args:
        palm_folder: directory whose files are palm images.
        fist_folder: directory whose files are fist images.
        seed: optional seed for a private ``random.Random`` so the shuffle (and any
            train/test split taken from it) is reproducible. ``None`` keeps the
            original behaviour of shuffling with the global ``random`` module.

    Returns:
        A shuffled list of ``(path, label)`` tuples.
    """
    data = []
    for folder, label in ((palm_folder, 1), (fist_folder, 0)):
        # os.listdir order is filesystem-dependent; sort it so a seeded shuffle is reproducible.
        for img_name in sorted(os.listdir(folder)):
            img_path = os.path.join(folder, img_name)
            # Skip subdirectories and other non-files; they would crash Image.open later.
            if os.path.isfile(img_path):
                data.append((img_path, label))

    rng = random if seed is None else random.Random(seed)
    rng.shuffle(data)
    return data

# Seed Python's RNG so the shuffle inside load_data, and hence the train/test split below, is
# reproducible (assuming the directory listing itself is stable between runs).
random.seed(42)
data = load_data('../Hagrid_Sample_30k_384p/train_val_palm', '../Hagrid_Sample_30k_384p/train_val_fist')
len(data)

In [None]:
# Split the (already shuffled) list: the first 80% trains the model, the remainder is held out.
splitting_ratio = 0.8
split_at = int(splitting_ratio * len(data))
train_data, test_data = data[:split_at], data[split_at:]
len(train_data), len(test_data)

In [None]:
# Loading Libraries.
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as tf
from torchinfo import summary
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Device-agnostic setup: prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
DEVICE

In [None]:
# Both pipelines resize to the 224x224 input the model's classifier is sized for, then convert to a tensor.
# The training pipeline additionally applies random horizontal flips and rotations of up to 45 degrees.
train_tf = tf.Compose([
    tf.Resize((224, 224)),
    tf.RandomHorizontalFlip(p=0.5),
    tf.RandomRotation(degrees=45),
    tf.ToTensor(),
])
test_tf = tf.Compose([tf.Resize((224, 224)), tf.ToTensor()])

class PalmFistDataset(Dataset):
    """Map-style dataset over ``(image_path, label)`` pairs.

    Args:
        data: sequence of ``(image_path, label)`` tuples, e.g. the output of ``load_data``.
        tf: optional callable applied to the RGB PIL image (e.g. a torchvision transform
            pipeline). When ``None`` the PIL image is returned unchanged.
    """

    def __init__(self, data, tf=None):
        super().__init__()
        self.data = data
        self.tf = tf

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        # Image.open keeps the file handle open until the image is closed; convert() produces
        # an in-memory copy, so close the file here instead of leaving it to the garbage collector.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.tf is not None:
            img = self.tf(img)
        return (img, label)

# Training samples get the augmenting pipeline; held-out samples only get resize + tensor conversion.
train_dataset = PalmFistDataset(data=train_data, tf=train_tf)
test_dataset = PalmFistDataset(data=test_data, tf=test_tf)
(train_dataset, test_dataset)

In [None]:
# Visualizing one random training sample (with its random augmentation applied).
# randrange excludes the upper bound; randint(0, len(...)) could pick len(...) and raise IndexError.
img, label = train_dataset[random.randrange(len(train_dataset))]
plt.title(classes[label])
plt.imshow(img.permute(1, 2, 0).numpy())

In [None]:
# Dataloaders: reshuffle training batches every epoch, keep the test order fixed.
BATCH_SIZE = 32
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
len(train_dataloader), len(test_dataloader)

In [None]:
# Visualizing again, this time the first sample of one shuffled training batch.
batch_imgs, batch_labels = next(iter(train_dataloader))
img, label = batch_imgs[0], batch_labels[0]
plt.title(classes[label.item()])
plt.imshow(img.permute(1, 2, 0).numpy())

In [None]:
class PalmFistModel(nn.Module):
    """Small CNN that outputs the probability that an image shows a palm (label 1).

    Five stages of conv(3x3, same padding) -> BatchNorm -> ReLU -> Dropout(0.2) -> MaxPool(2)
    shrink a 224x224 input to 7x7 (224 / 2**5 = 7), which is why the classifier's Linear layer
    expects 7*7*32 features. The output has shape (batch, 1) and passes through a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels) for each conv stage, in order.
        stages = [(3, 16), (16, 16), (16, 32), (32, 32), (32, 32)]
        layers = []
        for in_ch, out_ch in stages:
            layers += [
                nn.Conv2d(in_ch, out_ch, 3, 1, 1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.MaxPool2d(2),
            ]
        # A single flat Sequential keeps the layer indices (and state_dict keys) identical
        # to listing every layer by hand, so saved checkpoints still load.
        self.conv_block = nn.Sequential(*layers)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(7 * 7 * 32, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        features = self.conv_block(x)
        return self.classifier(features)

model = PalmFistModel()
model.to(DEVICE)  # nn.Module.to moves the parameters in place and returns the same module
model

In [None]:
# Layer-by-layer summary for a batch of 32 three-channel 224x224 images.
summary(model, input_size=(32, 3, 224, 224))

In [None]:
# This class is responsible for training and testing the model.
class RunModel:
    """Train and evaluate a binary classifier with ``nn.BCELoss`` and Adam.

    The model must output probabilities (e.g. end in ``nn.Sigmoid``) with one value per
    sample, shape ``(batch, 1)`` or ``(batch,)``. Dataloaders must yield ``(images, labels)``
    with 0/1 labels of shape ``(batch,)``.

    Args:
        model: network to train; each batch is moved to the device of its parameters.
        train_dataloader: batches used for optimisation in ``train``.
        test_dataloader: batches used for evaluation in ``test``.
        lr: Adam learning rate.
        epochs: number of train+test passes performed by ``run``.

    Attributes:
        results: per-epoch history under 'train_loss', 'test_loss', 'train_acc', 'test_acc'.
    """

    def __init__(self, model, train_dataloader, test_dataloader, lr, epochs):
        self.model = model
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.epochs = epochs

        self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=lr)
        self.label_loss_fn = nn.BCELoss()

        self.results = {
            'train_loss': [],
            'test_loss': [],
            'train_acc': [],
            'test_acc': [],
        }

    def _device(self):
        """Device holding the model's parameters; every batch is sent there."""
        return next(self.model.parameters()).device

    def _run_epoch(self, dataloader, training):
        """Make one pass over ``dataloader``, stepping the optimizer after each batch if ``training``.

        Returns:
            ``(loss, accuracy)``, both averaged over samples rather than over batches.

        Raises:
            ValueError: if the dataloader yields no samples.
        """
        device = self._device()
        total_loss = 0.0
        correct = 0
        seen = 0

        for img, label in dataloader:
            img = img.to(device).type(torch.float32)
            label = label.to(device).type(torch.float32)

            # (batch, 1) -> (batch,) so the prediction matches the label shape BCELoss expects.
            pred = self.model(img).reshape(-1)
            loss = self.label_loss_fn(pred, label)

            batch_size = label.size(0)
            # BCELoss returns the batch mean; weight it by batch size so a short final batch
            # does not count as much as a full one.
            total_loss += loss.item() * batch_size
            correct += (pred.round() == label).sum().item()
            seen += batch_size

            if training:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        if seen == 0:
            raise ValueError('dataloader yielded no samples')
        return total_loss / seen, correct / seen

    def train(self):
        """Run one training epoch; returns ``(mean loss, accuracy)`` over the training samples."""
        self.model.train()
        return self._run_epoch(self.train_dataloader, training=True)

    def test(self):
        """Evaluate without gradients; returns ``(mean loss, accuracy)`` over the test samples."""
        self.model.eval()
        with torch.inference_mode():
            return self._run_epoch(self.test_dataloader, training=False)

    def run(self):
        """Alternate ``train`` and ``test`` for ``self.epochs`` epochs, recording and printing metrics."""
        for epoch in range(self.epochs):
            train_loss, train_accuracy = self.train()
            test_loss, test_accuracy = self.test()
            self.results['train_loss'].append(train_loss)
            self.results['test_loss'].append(test_loss)
            self.results['train_acc'].append(train_accuracy)
            self.results['test_acc'].append(test_accuracy)

            # Report epochs 1-based so the final line reads "Epoch N of N".
            print('-'*50 + f' Epoch {epoch + 1} of {self.epochs}' + '-'*50)
            print(f'Train Loss: {train_loss:.3f} Test Loss: {test_loss:.3f}')
            print(f'Train Accuracy: {train_accuracy:.3f} Test Accuracy: {test_accuracy:.3f}')

# Seed torch's RNGs before training (dropout, DataLoader shuffling from here on).
# NOTE(review): the model's initial weights were drawn earlier, when `model` was created,
# so they are not covered by this seed.
TORCH_SEED = 42
LEARNING_RATE = 0.0001
EPOCHS = 1000

torch.manual_seed(TORCH_SEED)
torch.cuda.manual_seed(TORCH_SEED)

run = RunModel(model, train_dataloader, test_dataloader, lr=LEARNING_RATE, epochs=EPOCHS)
run.run()

In [None]:
# Rendering results.
# Plot only the epochs that actually finished, so an interrupted run.run() can still be plotted.
completed = len(run.results['train_loss'])
epochs = range(1, completed + 1)
fig, axes = plt.subplots(2, 1, sharex=True)

axes[0].plot(epochs, run.results['train_loss'], label='Train Loss')
axes[0].plot(epochs, run.results['test_loss'], label='Test Loss')
axes[0].set_ylabel('Loss')

axes[1].plot(epochs, run.results['train_acc'], label='Train Acc')
axes[1].plot(epochs, run.results['test_acc'], label='Test Acc')
axes[1].set_ylabel('Accuracy')
axes[1].set_xlabel('Epoch')

axes[0].legend(loc='upper left')
axes[1].legend(loc='upper left')

fig.tight_layout()
# plt.show() renders under the inline backend; fig.show() warns there instead.
plt.show()

In [None]:
# Predict the class of a single image and (optionally) display it with the predicted label.
def pred(model, image, show=True, class_names=None):
    """Classify one image tensor with ``model`` and optionally plot it.

    Args:
        model: network returning one sigmoid probability per sample (e.g. ``PalmFistModel``).
        image: a single ``(C, H, W)`` image tensor, e.g. ``test_dataset[i][0]``.
        show: when True (default), draw the image titled with the predicted class.
        class_names: names indexed by label; defaults to the notebook's ``classes``.

    Returns:
        The predicted class name.
    """
    if class_names is None:
        class_names = classes
    # Send the input to wherever the model's weights live, so a model loaded on CPU
    # still works when the notebook's DEVICE is 'cuda'.
    device = next(model.parameters()).device
    model.eval()
    with torch.inference_mode():
        prob = model(image.unsqueeze(dim=0).to(device))
    label_name = class_names[int(prob.round().cpu().item())]
    if show:
        plt.title(label_name)
        # .cpu() because .numpy() cannot read GPU tensors.
        plt.imshow(image.cpu().permute(1, 2, 0).numpy())
    return label_name

%timeit pred(model, test_dataset[random.randint(0, len(test_dataset)-1)][0])

In [None]:
# Saving / Loading Model.
# NOTE(review): the (commented-out) save writes 'PalmFistClassificationModel1.pth' while the load
# below reads 'PalmFistClassificationModel.pth' — confirm which checkpoint is intended.
#torch.save(model.state_dict(), 'PalmFistClassificationModel1.pth')
model1 = PalmFistModel()
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
model1.load_state_dict(torch.load('PalmFistClassificationModel.pth', map_location=DEVICE))
model1 = model1.to(DEVICE).eval()