In [None]:
import torch
import cv2
from torchvision import transforms, models
from jetcam.csi_camera import CSICamera
import numpy as np
import ipywidgets as widgets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg

In [None]:
# Labels the classifier can report. The model's output layer is sized from this
# list and a predicted index is looked up in it, so the order matters.
# Replace these placeholders with the actual class names.
class_names = [
    'bike',
    'railroads',
    'road',
    'stop',
    'trafficlights',
]

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ResNet-18 built without pretrained weights; its final fully connected layer is
# replaced by one with a single output per entry in class_names.
model = models.resnet18(pretrained=False)
num_features = model.fc.in_features
model.fc = torch.nn.Linear(num_features, len(class_names))

# Load the saved weights from disk, mapping their storage onto the chosen device.
model.load_state_dict(torch.load('model.pth', map_location=device))

# Switch to evaluation mode, then move the network onto the chosen device.
model.eval()
model = model.to(device)

In [None]:
# Per-channel normalization statistics.
# NOTE(review): these look like the standard ImageNet mean/std — confirm they
# match what the model was trained with.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to each frame before inference, in order:
# array -> PIL image -> 224x224 resize -> tensor -> per-channel normalization.
preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# Frame size requested from the camera and reused for the preview widget.
FRAME_WIDTH = 224
FRAME_HEIGHT = 224

camera = CSICamera(width=FRAME_WIDTH, height=FRAME_HEIGHT, capture_fps=65)
output_widget = widgets.Image(format='jpeg', width=FRAME_WIDTH, height=FRAME_HEIGHT)

# NOTE(review): presumably this starts continuous capture so that camera.value
# keeps being refreshed with the latest frame — confirm against jetcam.
camera.running = True


In [None]:
import time

# Run inference on whatever device the model's weights already live on. Looking
# this up once keeps the CUDA-availability query out of the per-frame loop and
# guarantees the input tensor and the model always share a device.
inference_device = next(model.parameters()).device

# Pause between frames, in seconds, so the loop does not spin flat out.
FRAME_DELAY_S = 0.05

print("Starting live classification.")

# Image widget that shows the annotated frames; created and displayed here so
# re-running this cell always produces a fresh output area.
output_widget = widgets.Image(format='jpeg', width=224, height=224)
display(output_widget)

# The loop runs until the kernel is interrupted (Kernel -> Interrupt), which
# raises KeyboardInterrupt inside the cell. The camera is intentionally left
# running so this cell can be re-run; set camera.running = False when finished.
try:
    while True:
        frame = camera.value  # numpy array (height, width, channels)

        # The frame is handled as BGR (it is later encoded with bgr8_to_jpeg);
        # the preprocessing pipeline is fed an RGB copy.
        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Preprocess, add a batch dimension, and move to the model's device.
        input_tensor = transform(image).unsqueeze(0).to(inference_device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = model(input_tensor)

        # Index of the highest-scoring class for the single image in the batch.
        label = class_names[outputs.argmax(dim=1).item()]

        # Draw the label on a copy so the array obtained from the camera is
        # left untouched.
        annotated_frame = frame.copy()
        cv2.putText(annotated_frame, f'Prediction: {label}', (10, 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        # Push the annotated frame to the widget as JPEG bytes.
        output_widget.value = bgr8_to_jpeg(annotated_frame)
        time.sleep(FRAME_DELAY_S)
except KeyboardInterrupt:
    # A kernel interrupt is the expected way to stop; exit without a traceback.
    print("Live classification stopped.")
    