In [56]:
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os

In [57]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [58]:
# Create the data folders used by the rest of the notebook; folders that
# already exist are left untouched.
os.makedirs('.\\data', exist_ok=True)
os.makedirs('.\\data\\train', exist_ok=True)
os.makedirs('.\\data\\test', exist_ok=True)

In [59]:
class ModifiedNet(torch.nn.Module):
    """Pretrained ResNet-50 with frozen weights and a new, trainable ``fc`` layer.

    Args:
        out_features: Number of outputs of the replacement ``fc`` layer.
    """

    def __init__(self, out_features):
        super().__init__()
        backbone = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")
        # Freeze every pretrained parameter; the layer created below is new and
        # therefore stays trainable.
        backbone.requires_grad_(False)
        # Replace the pretrained classifier with one sized for our labels.
        backbone.fc = torch.nn.Linear(backbone.fc.in_features, out_features)
        self.resnet = backbone

    def forward(self, x):
        """Return the wrapped network's output for the batch ``x``."""
        return self.resnet(x)

In [60]:
import os

# Class-folder names in label order: a sample's integer label is the index of
# its folder name in this list. Filled in while datasets are being built.
categories = []


class ModifiedDataSet(Dataset):
    """Image-folder dataset that eagerly loads and preprocesses every image.

    ``base_directory`` is scanned for sub-folders, one per class. Every file in
    a class folder is opened as an RGB image, resized, centre-cropped to
    224x224, converted to a tensor and normalised, then kept in memory as an
    ``(image_tensor, label)`` pair. ``label`` is the position of the folder
    name in the module-level ``categories`` list.

    Args:
        base_directory: Directory whose immediate sub-folders are the classes.
    """

    def __init__(self, base_directory):
        self.img_dir = []  # (preprocessed image tensor, integer label) pairs
        self.NUMBER_OF_OUT_FEATURES = 0  # number of class folders found

        # Created when the first image is met, so the pipeline is built once
        # per dataset instead of once per file.
        preprocess = None

        for folder_name in os.listdir(base_directory):
            folder_path = os.path.join(base_directory, folder_name)
            if not os.path.isdir(folder_path):
                continue

            self.NUMBER_OF_OUT_FEATURES += 1
            # Register each class name only once, so building the train and
            # test sets one after the other does not leave duplicate entries.
            if folder_name not in categories:
                categories.append(folder_name)
            label = categories.index(folder_name)

            for file_name in os.listdir(folder_path):
                file_path = os.path.join(folder_path, file_name)
                # Only regular files are loaded; nested folders are skipped.
                if not os.path.isfile(file_path):
                    continue
                if preprocess is None:
                    preprocess = self._build_preprocess()
                self.img_dir.append((self._load_image(file_path, preprocess), label))

    @staticmethod
    def _build_preprocess():
        """Return the resize/crop/tensor/normalise pipeline applied to each image."""
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _load_image(file_path, preprocess):
        """Open ``file_path`` as an RGB image and return ``preprocess`` applied to it."""
        # Imported here because the notebook otherwise imports PIL only in the
        # GUI cell, which runs after the datasets have been built.
        from PIL import Image

        # The context manager closes the file handle once the pixels are read.
        with Image.open(file_path) as raw_image:
            return preprocess(raw_image.convert("RGB"))

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.img_dir)

    def __getitem__(self, idx):
        """Return the ``(image_tensor, label)`` pair stored at position ``idx``."""
        return self.img_dir[idx]

    def get_num_labels(self):
        """Return how many class folders were found in ``base_directory``."""
        return self.NUMBER_OF_OUT_FEATURES


# Behavior Setting

When `use_saved_dataset = True`, the notebook tries to load the previously trained `model.pt`, as well as `train.pkl` and `test.pkl`, which contain the pickled datasets (the preprocessed tensor of every image together with its label). Any of these files that is missing is rebuilt and saved again.

If the training or testing data has changed and the model needs to be retrained, set `use_saved_dataset = False`.

In [61]:
import pickle as pkl

# Reuse the cached train.pkl / test.pkl / model.pt when they exist; set this to
# False to rebuild the datasets and retrain the model.
use_saved_dataset = True

# Declared here; assigned by load_dataset() below.
train_dataset: ModifiedDataSet
test_dataset: ModifiedDataSet

def _load_or_build_dataset(pkl_path, data_directory, use_cache):
    """Return the dataset for ``data_directory``, using ``pkl_path`` as its cache.

    Args:
        pkl_path: Pickle file the dataset is cached in.
        data_directory: Folder handed to ``ModifiedDataSet`` when rebuilding.
        use_cache: When true and ``pkl_path`` exists, load it instead of
            rebuilding.

    Returns:
        The loaded or freshly built ``ModifiedDataSet``. A freshly built
        dataset is also written to ``pkl_path``.
    """
    if use_cache and os.path.exists(pkl_path):
        # NOTE(security): unpickling can execute arbitrary code, so only cache
        # files written by this notebook should ever be loaded here.
        with open(pkl_path, "rb") as f:
            dataset = pkl.load(f)
        print(f"Successfully loaded {pkl_path}")
        return dataset

    dataset = ModifiedDataSet(base_directory=data_directory)
    with open(pkl_path, "wb") as f:
        pkl.dump(dataset, f)
    return dataset


def load_dataset(saved):
    """Fill the global ``train_dataset`` and ``test_dataset``.

    Args:
        saved: When true, reuse ``train.pkl`` / ``test.pkl`` if they exist.
            When false, or when a file is missing, the dataset is rebuilt from
            the image folders and the pickle is (re)written.
    """
    global train_dataset, test_dataset
    # The training set goes first so its folders are registered before the
    # test set's whenever both are rebuilt.
    train_dataset = _load_or_build_dataset("train.pkl", ".\\data\\train", saved)
    test_dataset = _load_or_build_dataset("test.pkl", ".\\data\\test", saved)


load_dataset(use_saved_dataset)

Successfully loaded train.pkl
Successfully loaded test.pkl


In [62]:
# Declared here; created in train_and_test() or loaded from model.pt further down.
model: ModifiedNet
learning_rate = 0.001  # learning rate handed to the SGD optimizer
batch_size = 64  # samples per batch for both data loaders
epochs = 50  # number of train/test passes run by train_and_test()
loss_fn: torch.nn.CrossEntropyLoss
optimizer: torch.optim.SGD

# Declared here; assigned by load_dataloader() when the matching dataset is non-empty.
train_dataloader: DataLoader
test_dataloader: DataLoader

def load_dataloader():
    """Create the global train/test ``DataLoader`` objects.

    A loader is only (re)assigned when its dataset holds at least one sample;
    otherwise the corresponding global is left as it was.
    """
    global train_dataloader, test_dataloader
    if len(train_dataset) > 0:
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    if len(test_dataset) > 0:
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)


load_dataloader()

In [63]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch is moved to the global ``device``, run through the model, and
    used for one optimizer step. The loss and the number of samples processed
    so far are printed every 100 batches, starting with the first.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches; its ``dataset``
            length is used as the progress total.
        model: Module to train; it is switched to training mode.
        loss_fn: Callable taking ``(predictions, targets)`` and returning the loss.
        optimizer: Optimizer that updates the model's parameters.
    """
    size = len(dataloader.dataset)
    # Count samples directly rather than multiplying by the notebook-level
    # batch_size, so the progress figure is right for any loader.
    seen = 0

    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        # Clear the gradients so they do not accumulate into the next batch.
        optimizer.zero_grad()

        seen += len(X)
        if batch % 100 == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")

In [64]:
def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [65]:
def train_and_test():
    """Build a fresh model, train and evaluate it for ``epochs`` passes, then save it.

    Does nothing when the training dataset is empty. Otherwise the global
    ``loss_fn``, ``model`` and ``optimizer`` are (re)created, every epoch runs
    one training pass followed by one test pass, and the trained model is
    written to ``model.pt``.
    """
    global model, optimizer, loss_fn
    if len(train_dataset) == 0:
        return

    loss_fn = torch.nn.CrossEntropyLoss()
    model = ModifiedNet(train_dataset.get_num_labels()).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    for epoch_number in range(1, epochs + 1):
        print(f"Epoch {epoch_number}\n-------------------------------")
        train_loop(train_dataloader, model, loss_fn, optimizer)
        test_loop(test_dataloader, model, loss_fn)

    print("Done!")
    torch.save(model, "model.pt")


# Rebuild the label -> class-name table from the training folders on every
# run. Datasets restored from train.pkl / test.pkl do not repopulate
# `categories`, so without this the table stays empty whenever the datasets are
# cached but the model still has to be trained, and the GUI would then refuse
# to classify.
base_directory = ".\\data\\train"
categories = [
    folder_name
    for folder_name in os.listdir(base_directory)
    if os.path.isdir(os.path.join(base_directory, folder_name))
]

if use_saved_dataset and os.path.exists("model.pt"):
    # model.pt holds the whole pickled module (see torch.save in
    # train_and_test), so it must be loaded with weights_only=False; only load
    # a file this notebook produced, since unpickling can execute code.
    # map_location lets a model saved on the GPU load on a CPU-only machine.
    model = torch.load("model.pt", map_location=device, weights_only=False)
    model.to(device)
else:
    train_and_test()

# GUI

Helps capture training/testing data and shows live predictions. The cell below can be run repeatedly without errors.

If new data has been recorded, make sure to set `use_saved_dataset = False` and rerun all previous code cells to reload the data and retrain the network.

In [67]:
import tkinter as tk
from tkinter import messagebox
import cv2
from PIL import Image, ImageTk

# Open the camera at index 0; every frame shown or saved by the GUI comes from it.
cap = cv2.VideoCapture(0)

# Same resize / crop / tensor / normalise steps that ModifiedDataSet applies to
# its images, so live frames are preprocessed the same way.
preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

# True while live classification is switched on; toggled by classify_image().
should_classify = False

def classify_image():
    """Toggle live classification on or off.

    Switching off always succeeds. Switching on only happens when the training
    set, the test set and the category list are all non-empty; otherwise the
    flag stays off.
    """
    global should_classify
    if should_classify:
        should_classify = False
        return

    has_data = len(train_dataset) != 0 and len(test_dataset) != 0
    should_classify = has_data and len(categories) != 0


def update_info():
    """Refresh the webcam preview and, while enabled, the live prediction.

    Shows an error dialog and stops refreshing when the camera is not open.
    Otherwise it reschedules itself to run again after 10 ms.
    """
    if not should_classify:
        message.set("Stopped")
    if not cap.isOpened():
        messagebox.showerror("Error", "Could not open webcam")
        return

    ret, frame = cap.read()
    if ret:
        # The frame is converted from BGR to RGB before it is handed to PIL.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(rgb_frame)
        imgtk = ImageTk.PhotoImage(image=pil_image)
        image_label.config(image=imgtk)
        # Keep a reference on the widget so the image is not garbage-collected.
        image_label.image = imgtk
        if should_classify:
            # unsqueeze(0) adds the batch dimension the model expects.
            img = preprocess(pil_image).unsqueeze(0).to(device)
            # Inference only: skip autograd bookkeeping for every frame.
            with torch.no_grad():
                classification = model(img)
            result = torch.argmax(classification).item()
            message.set(f"Classifying as {categories[result]}")

    root.after(10, update_info)

import time

def _collect_sample(split):
    """Save the current webcam frame as a sample of the label typed in the GUI.

    The frame is written to the label's sub-folder of the ``split`` data
    directory, which is created if needed. Live classification is switched off
    first. An error dialog is shown and nothing is saved when the label is empty.

    Args:
        split: Name of the data sub-directory, ``"train"`` or ``"test"``.
    """
    global should_classify
    # Use the stripped text for the path as well as for the emptiness check,
    # so " cat" and "cat" end up in the same class folder.
    label = input_box.get().strip()
    if label == "":
        messagebox.showerror("Error", "Label is empty")
        return

    should_classify = False
    ret, frame = cap.read()
    if ret:
        target_dir = f".\\data\\{split}\\{label}"
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)
        # The capture time in milliseconds serves as the file name.
        cv2.imwrite(f"{target_dir}\\{time.time() * 1000}.jpg", frame)


def collect_train():
    """Save the current frame as a training sample for the entered label."""
    _collect_sample("train")


def collect_test():
    """Save the current frame as a testing sample for the entered label."""
    _collect_sample("test")

# Create the main application window
root = tk.Tk()
root.geometry("800x600")
root.title("Webcam Capture")

# Button that toggles live classification on and off
classify_button = tk.Button(root, text="Start/stop classify", command=classify_image, padx=10, pady=5)
classify_button.grid(column=0, row=0)


# Status text shown next to the toggle button; updated by update_info()
message = tk.StringVar()
message.set("Stopped")
message_label = tk.Label(root, textvariable=message)
message_label.grid(column=1, row=0, columnspan=2)

# Buttons that save the current frame as a training / testing sample
collect_data_button = tk.Button(root, text="Collect as training", command=collect_train, padx=10, pady=5)
collect_data_button.grid(column=1, row=1)

collect_data_button2 = tk.Button(root, text="Collect as testing", command=collect_test, padx=10, pady=5)
collect_data_button2.grid(column=2, row=1)

# Text box holding the label that collected samples are stored under
input_box = tk.Entry(root, width=30)
input_box.insert(0, "Place your label here")
input_box.grid(column=0, row=1)

# Create a label to display the captured image
image_label = tk.Label(root)
image_label.grid(column=0, row=2, columnspan=3)

# Start the periodic preview refresh, then hand control to the Tkinter event
# loop. The camera is released even if the loop exits with an error, so the
# device is not left locked for the next run of this cell.
try:
    update_info()
    root.mainloop()
finally:
    cap.release()
