## Interactive Classification Tool

In [1]:
import atexit
import base64
import io
import os
import threading
import time

import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch.nn.functional as F
import ipywidgets
from IPython.display import display, clear_output
import cv2
import numpy as np
from PIL import Image

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Name of the task; every dataset folder is named "<TASK>_<dataset>".
TASK = 'thumbs'

# Class names, one sub-folder per class inside each dataset folder.
# The list MUST be in sorted order: ImageFolder assigns label indices from the
# alphabetically sorted folder names, and the rest of the notebook maps a
# predicted index back to a name with CATEGORIES[index]. An unsorted list
# would silently swap the displayed labels and scores.
CATEGORIES = sorted(['thumbs_up', 'thumbs_down'])

# Independent image collections (e.g. one for training, one for evaluation).
DATASETS = ['A', 'B']

# Mini-batch size used by the training/evaluation DataLoader.
BATCH_SIZE = 8

# Set up transforms: resize to the network input size, convert to a tensor and
# normalize with the per-channel mean/std that torchvision documents for its
# pretrained models.
TRANSFORMS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

**Create datasets**

In [None]:
# File extensions counted as images when checking whether a dataset has data.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')

datasets = {}
for name in DATASETS:
    dataset_path = f"{TASK}_{name}"

    # Always make sure every category folder exists. Doing this only when the
    # dataset root is missing would leave os.listdir() below failing on a
    # partially created dataset.
    for category in CATEGORIES:
        os.makedirs(os.path.join(dataset_path, category), exist_ok=True)

    # Check if there are any image files in the dataset (case-insensitive, so
    # "IMG.JPG" counts as well).
    has_images = any(
        file.lower().endswith(IMAGE_EXTENSIONS)
        for category in CATEGORIES
        for file in os.listdir(os.path.join(dataset_path, category))
    )

    if not has_images:
        print(f"No valid images found in {dataset_path}. Skipping dataset.")
        continue

    try:
        datasets[name] = ImageFolder(dataset_path, transform=TRANSFORMS)
    except (FileNotFoundError, RuntimeError) as e:
        # Depending on the torchvision version, ImageFolder raises when a
        # class folder (or the whole dataset) contains no usable image.
        print(f"Error loading dataset {name}: {str(e)}")
        continue
    print(f"Dataset {name} initialized with {len(datasets[name])} images")

# Initialize active dataset: the first configured dataset that actually loaded
# (DATASETS[0] may have been skipped while a later one succeeded).
active_name = next((name for name in DATASETS if name in datasets), None)
if active_name is not None:
    dataset = datasets[active_name]
    print(f"Active dataset: {active_name} with {len(dataset)} images")
else:
    dataset = None
    print("No datasets were initialized due to lack of images.")

No valid images found in thumbs_A. Skipping dataset.
No valid images found in thumbs_B. Skipping dataset.
No datasets were initialized due to lack of images.


**Create camera widget**

In [None]:
class Camera:
    """Grab frames from a cv2.VideoCapture device on a background thread.

    The most recent successfully read frame, resized to ``size``, is kept in
    ``self.frame`` and returned by :meth:`read`.

    Args:
        index: Device index passed to ``cv2.VideoCapture`` (default 0).
        size: ``(width, height)`` every frame is resized to (default 224x224).
    """

    def __init__(self, index=0, size=(224, 224)):
        self.cap = cv2.VideoCapture(index)
        self.size = size
        self.frame = None      # latest resized frame, or None before the first read
        self.stopped = False   # set to True to ask the reader thread to exit
        self._thread = None    # handle of the reader thread once started

    def start(self):
        """Start the reader thread (no-op if it is already running); return self."""
        if self._thread is None or not self._thread.is_alive():
            self._thread = threading.Thread(target=self.update, daemon=True)
            self._thread.start()
        return self

    def update(self):
        """Reader loop: keep replacing ``self.frame`` until ``stopped`` is set."""
        while not self.stopped:
            ret, frame = self.cap.read()
            if ret:
                self.frame = cv2.resize(frame, self.size)
            time.sleep(0.01)  # Small delay to reduce CPU usage

    def read(self):
        """Return the latest frame, or None if nothing has been captured yet."""
        return self.frame

    def stop(self):
        """Ask the reader thread to exit after its current iteration."""
        self.stopped = True

    def release(self):
        """Stop the reader thread, wait for it to exit, then free the device."""
        self.stop()
        worker = self._thread
        # Wait for the reader so the capture is not released while a read() is
        # still in flight. The timeout keeps shutdown from hanging on a stuck
        # device; never join from inside the reader thread itself.
        if worker is not None and worker is not threading.current_thread():
            worker.join(timeout=1.0)
        self.cap.release()

In [6]:
# Open the camera and begin grabbing frames on its background thread.
camera = Camera()
camera = camera.start()

# Create image preview: an image widget that is fed JPEG-encoded frames.
camera_widget = ipywidgets.Image(format='jpeg')
def update_image():
    """Push the latest camera frame into ``camera_widget`` as JPEG bytes.

    Does nothing when no frame has been captured yet or when encoding fails.
    """
    frame = camera.read()
    if frame is None:
        return
    # cv2.imencode expects BGR input, which is what the capture delivers.
    # Converting to RGB first would swap red and blue in the preview.
    ok, buffer = cv2.imencode('.jpg', frame)
    if ok:
        camera_widget.value = buffer.tobytes()

In [7]:
# Create widgets
# Which dataset folder ("<TASK>_<dataset>") new images are added to.
dataset_widget = ipywidgets.Dropdown(description='dataset', options=DATASETS)
# Which class sub-folder new images are added to.
category_widget = ipywidgets.Dropdown(description='category', options=CATEGORIES)
# Running number used as the file name of the next saved image.
count_widget = ipywidgets.IntText(description='count')
# Button that stores the current camera frame.
save_widget = ipywidgets.Button(description='add')

In [8]:
# Function to update datasets
def update_datasets():
    """Rebuild ``datasets`` from disk and re-select the active ``dataset``.

    Every folder "<TASK>_<name>" that exists is reloaded with ImageFolder.
    Folders that cannot be loaded or contain no images are reported and
    skipped. The global ``dataset`` becomes the entry selected in
    ``dataset_widget``, or None when that entry could not be loaded.
    """
    global dataset, datasets
    datasets.clear()  # Clear existing datasets
    for name in DATASETS:
        dataset_path = f"{TASK}_{name}"
        if not os.path.exists(dataset_path):
            continue
        try:
            new_dataset = ImageFolder(dataset_path, transform=TRANSFORMS)
        except (FileNotFoundError, RuntimeError) as e:
            # Depending on the torchvision version, ImageFolder raises when a
            # class folder (or the whole dataset) holds no usable image.
            print(f"Error loading dataset {name}: {str(e)}")
            continue
        if len(new_dataset) > 0:
            datasets[name] = new_dataset
            print(f"Dataset {name} updated with {len(new_dataset)} images")
        else:
            print(f"Dataset {name} is empty. Skipping.")

    # The selected dataset may be missing even though another one loaded, so
    # look it up defensively instead of indexing (which would raise KeyError).
    selected = dataset_widget.value
    dataset = datasets.get(selected)
    if dataset is not None:
        print(f"Active dataset: {selected} with {len(dataset)} images")
    elif datasets:
        print(f"Selected dataset {selected} has no usable images yet.")
    else:
        print("No valid datasets available.")

In [9]:
# Function to save image
def save(c):
    """Store the current camera frame as the next image of the selected class.

    The file is written to "<TASK>_<dataset>/<category>/<count>.jpg". If that
    name is already taken (for example after a kernel restart reset the
    counter), the counter is advanced to the next free number instead of
    overwriting an existing image.

    Args:
        c: The button that triggered the callback (unused).
    """
    frame = camera.read()
    if frame is None:
        print("No camera frame available yet.")
        return

    dataset_name = dataset_widget.value
    category = category_widget.value
    dataset_path = f"{TASK}_{dataset_name}"
    category_path = os.path.join(dataset_path, category)

    # Ensure the category directory exists
    os.makedirs(category_path, exist_ok=True)

    # Skip numbers that already exist on disk so earlier images are kept.
    index = count_widget.value
    path = os.path.join(category_path, f"{index}.jpg")
    while os.path.exists(path):
        index += 1
        path = os.path.join(category_path, f"{index}.jpg")

    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(path, frame):
        print(f"Failed to save image to {path}")
        return

    count_widget.value = index + 1
    print(f"Image saved to {path}")
    update_datasets()

save_widget.on_click(save)

In [10]:
# Create data collection widget
data_collection_widget = ipywidgets.VBox(children=[
    camera_widget,
    dataset_widget,
    category_widget,
    count_widget,
    save_widget,
])

# Create model: a pretrained ResNet-18 whose final layer is replaced by a
# fresh linear layer with one output per category.
model = torchvision.models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(in_features=512, out_features=len(CATEGORIES))
model = model.to(device)

# Create model widgets: a path field plus save/load buttons.
model_save_button = ipywidgets.Button(description='save model')
model_load_button = ipywidgets.Button(description='load model')
model_path_widget = ipywidgets.Text(value='my_model.pth', description='model path')

def load_model(c):
    """Load model weights from the path entered in ``model_path_widget``.

    Prints a message instead of raising when the file does not exist.

    Args:
        c: The button that triggered the callback (unused).
    """
    path = model_path_widget.value
    if not os.path.isfile(path):
        print(f"Model file not found: {path}")
        return
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    model.load_state_dict(torch.load(path, map_location=device))
    print(f"Model loaded from {path}")
model_load_button.on_click(load_model)

def save_model(c):
    """Write the model's state dict to the path in ``model_path_widget``.

    Args:
        c: The button that triggered the callback (unused).
    """
    weights = model.state_dict()
    target_path = model_path_widget.value
    torch.save(weights, target_path)
model_save_button.on_click(save_model)

# Path field on top, load/save buttons side by side underneath.
model_buttons_row = ipywidgets.HBox([model_load_button, model_save_button])
model_widget = ipywidgets.VBox([model_path_widget, model_buttons_row])



In [11]:
# Create live execution widgets
# Toggle that switches live inference on and off (off initially).
state_widget = ipywidgets.ToggleButtons(
    options=['stop', 'live'],
    description='state',
    value='stop',
)
# Text field showing the name of the most likely category.
prediction_widget = ipywidgets.Text(description='prediction')
# One vertical slider per category showing its score in [0, 1].
score_widgets = []
for category in CATEGORIES:
    score_widgets.append(
        ipywidgets.FloatSlider(min=0.0, max=1.0, description=category, orientation='vertical')
    )

def live():
    """Classify camera frames in a loop while ``state_widget`` is 'live'.

    Each iteration runs the latest frame through the model and writes the
    winning category to ``prediction_widget`` and the per-category softmax
    scores to ``score_widgets``.
    """
    # Inference mode: without this a model left in train mode would keep
    # updating its BatchNorm statistics from the live frames.
    model.eval()
    while state_widget.value == 'live':
        frame = camera.read()
        if frame is not None:
            # The camera delivers BGR; the transforms expect an RGB PIL image.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = TRANSFORMS(Image.fromarray(rgb)).unsqueeze(0).to(device)
            # No gradients are needed for prediction; skip building the graph.
            with torch.no_grad():
                output = F.softmax(model(image), dim=1)
            output = output.cpu().numpy().flatten()
            category_index = int(output.argmax())
            prediction_widget.value = CATEGORIES[category_index]
            for i, score in enumerate(output):
                # Hand plain Python floats to the widgets rather than numpy scalars.
                score_widgets[i].value = float(score)
        time.sleep(0.1)  # Small delay to reduce update frequency


def _on_state_change(change):
    """Start a background ``live`` loop when the toggle switches to 'live'."""
    if change['new'] == 'live':
        threading.Thread(target=live, daemon=True).start()

state_widget.observe(_on_state_change, names='value')

# Score sliders in a row, then the predicted label, then the live/stop toggle.
score_row = ipywidgets.HBox(score_widgets)
live_execution_widget = ipywidgets.VBox([score_row, prediction_widget, state_widget])

In [12]:
# Create training widgets
# Adam with its default hyper-parameters over every model parameter.
optimizer = torch.optim.Adam(params=model.parameters())

# Number of passes over the dataset when training.
epochs_widget = ipywidgets.IntText(value=1, description='epochs')
eval_button = ipywidgets.Button(description='evaluate')
train_button = ipywidgets.Button(description='train')
# Running averages shown while a pass is in progress.
loss_widget = ipywidgets.FloatText(description='loss')
accuracy_widget = ipywidgets.FloatText(description='accuracy')
# Fraction of the current pass that has been processed.
progress_widget = ipywidgets.FloatProgress(description='progress', min=0.0, max=1.0)

In [13]:
def train_eval(is_training):
    """Train the model, or evaluate it once, on the active dataset.

    The datasets are reloaded from disk first. When training, the loop runs
    for ``epochs_widget.value`` epochs and updates the weights with
    ``optimizer``; when evaluating, a single pass is made without gradient
    tracking. Progress, running loss and running accuracy are written to
    their widgets after every batch. The model is always returned to eval
    mode afterwards so later inference does not run in train mode.

    Args:
        is_training: True to train, False to evaluate only.
    """
    global dataset
    update_datasets()  # Refresh datasets before training/evaluation

    if dataset is None or len(dataset) == 0:
        print("No data available for training/evaluation.")
        return

    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    num_batches = len(train_loader)

    # train(True) enables training behavior, train(False) is the same as eval().
    model.train(is_training)
    try:
        for epoch in range(epochs_widget.value):
            running_loss = 0.0
            correct = 0
            total = 0
            for i, (images, labels) in enumerate(train_loader):
                images, labels = images.to(device), labels.to(device)

                if is_training:
                    optimizer.zero_grad()

                # Only build the autograd graph when it will be used.
                with torch.set_grad_enabled(is_training):
                    outputs = model(images)
                    loss = F.cross_entropy(outputs, labels)

                if is_training:
                    loss.backward()
                    optimizer.step()

                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

                progress_widget.value = (i + 1) / num_batches
                loss_widget.value = running_loss / (i + 1)
                accuracy_widget.value = correct / total

            print(f"Epoch {epoch+1}/{epochs_widget.value} - Loss: {loss_widget.value:.4f}, Accuracy: {accuracy_widget.value:.4f}")

            # Evaluation needs only one pass over the data.
            if not is_training:
                break
    finally:
        # Leave the model in inference mode even if the loop raised.
        model.eval()

In [14]:
def _on_train_clicked(c):
    """Button callback: run a training pass."""
    train_eval(is_training=True)


def _on_eval_clicked(c):
    """Button callback: run an evaluation pass."""
    train_eval(is_training=False)


train_button.on_click(_on_train_clicked)
eval_button.on_click(_on_eval_clicked)

# Epoch count and live metrics stacked above the train/evaluate buttons.
train_eval_buttons_row = ipywidgets.HBox([train_button, eval_button])
train_eval_widget = ipywidgets.VBox([
    epochs_widget,
    progress_widget,
    loss_widget,
    accuracy_widget,
    train_eval_buttons_row,
])

In [15]:
# Combine all widgets
# Top row: data collection next to live prediction; training and model
# controls are stacked underneath.
top_row = ipywidgets.HBox([data_collection_widget, live_execution_widget])
all_widget = ipywidgets.VBox([top_row, train_eval_widget, model_widget])

# Display the widget
display(all_widget)

VBox(children=(HBox(children=(VBox(children=(Image(value=b'', format='jpeg'), Dropdown(description='dataset', …

In [16]:
# Start updating the camera feed
def update_camera():
    """Refresh the preview widget until the camera has been stopped."""
    # Stop refreshing once the camera is stopped/released instead of spinning
    # forever on a dead capture.
    while not camera.stopped:
        update_image()
        time.sleep(0.03)  # Update at approximately 30 FPS

camera_thread = threading.Thread(target=update_camera, daemon=True)
camera_thread.start()

# Clean up
def cleanup():
    """Stop the camera reader and free the capture device."""
    camera.release()

# Register the cleanup function to be called when the kernel is shut down
atexit.register(cleanup)

<function __main__.cleanup()>