In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import os
import ast
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
import matplotlib.patches as patches


In [None]:
class BicycleStateDataset(Dataset):
    """Dataset of (full image, bbox crop, label) triples read from a feather file.

    Each row is expected to provide ``filename`` (image path, relative to
    ``img_dir`` when one is given), ``bbox_2d`` (x1, y1, x2, y2 in pixel
    coordinates of the full image) and ``target`` (integer class label).
    When ``_class`` is given, a ``class`` column is also required.
    """

    def __init__(self, feather_file, img_dir="", _class=None, transform=None):
        """
        Args:
            feather_file: Path to the feather file holding the annotations.
            img_dir: Optional directory prepended to every ``filename``.
            _class: If not None, keep only rows whose ``class`` equals this value.
            transform: Optional callable applied to both the full image and the crop.
        """
        self.data = self._prepare_frame(pd.read_feather(feather_file), _class)
        self.img_dir = img_dir
        self.transform = transform

    @staticmethod
    def _prepare_frame(frame, _class=None):
        """Drop incomplete rows, optionally filter by class, and reindex to 0..n-1."""
        # NOTE: dropna() discards a row if *any* column is missing, not only
        # the columns this dataset reads.
        frame = frame.dropna()
        if _class is not None:
            frame = frame[frame["class"] == _class]
        # Reset last so the index is contiguous whether or not a filter applied.
        return frame.reset_index(drop=True)

    @staticmethod
    def _parse_bbox(bbox):
        """Return ``bbox`` as an ``(x1, y1, x2, y2)`` tuple of floats.

        Accepts any 4-element sequence/array, or its string form such as
        ``"[1, 2, 3, 4]"`` or ``"[1. 2. 3. 4.]"`` (e.g. after a CSV round-trip).

        Raises:
            ValueError: if the box does not contain exactly four numeric values.
        """
        raw = bbox
        if isinstance(bbox, str):
            # Tolerate both comma- and whitespace-separated list/array reprs.
            bbox = bbox.strip().strip("[]()").replace(",", " ").split()
        coords = tuple(float(v) for v in bbox)
        if len(coords) != 4:
            raise ValueError(f"bbox_2d must have 4 values, got {len(coords)}: {raw!r}")
        return coords

    def _image_path(self, row):
        """Full path of the image referenced by ``row``."""
        if self.img_dir:
            return os.path.join(self.img_dir, row["filename"])
        return row["filename"]

    def _load_image(self, row):
        """Open the row's image as RGB, closing the underlying file right away."""
        with Image.open(self._image_path(row)) as img:
            return img.convert("RGB")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, roi, label)`` for row ``idx``.

        ``roi`` is the ``bbox_2d`` crop of the full image; both go through
        ``transform`` when one is set. ``label`` is a scalar ``torch.long`` tensor.
        """
        row = self.data.iloc[idx]
        image = self._load_image(row)
        roi = image.crop(self._parse_bbox(row["bbox_2d"]))

        if self.transform:
            image = self.transform(image)
            roi = self.transform(roi)

        label = torch.tensor(int(row["target"]), dtype=torch.long)
        return image, roi, label

    def show_image_with_bbox(self, idx, color="red", width=3):
        """
        Display the untransformed image with its bounding box drawn on top.

        Args:
            idx: Index of the sample to display
            color: Color of the bounding box (default: "red")
            width: Width of the bounding box lines (default: 3)
        """
        row = self.data.iloc[idx]
        image = self._load_image(row)
        x1, y1, x2, y2 = self._parse_bbox(row["bbox_2d"])

        fig, ax = plt.subplots(1, figsize=(10, 8))
        ax.imshow(image)
        ax.add_patch(patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=width, edgecolor=color, facecolor='none'
        ))
        ax.set_title(f"Image: {row['filename']}\nLabel: {row['target']}", fontsize=12)
        ax.axis('off')

        plt.tight_layout()
        plt.show()

    def get_class_frequencies(self):
        """Map each ``target`` value to the number of rows carrying it."""
        return self.data.groupby('target').size().to_dict()



# ---------- Model ----------
class DualStreamBicycleClassifier(nn.Module):
    """Two-branch classifier fusing whole-scene and ROI-crop features.

    Two independent copies of a torchvision backbone encode the full image and
    the bounding-box crop. Each pooled feature vector is projected to
    ``embed_dim``, and the concatenation goes through a small MLP head.

    NOTE: the backbone must expose its classification head as ``.fc`` and have
    it as the last child (true for the torchvision ResNet family). Other
    architectures will fail here.
    """

    def __init__(self, backbone_name="resnet18", pretrained=True, num_classes=2,
                 embed_dim=256, dropout=0.3):
        """
        Args:
            backbone_name: Name of a constructor in ``torchvision.models``.
            pretrained: Forwarded to the backbone constructor.
            num_classes: Number of output logits.
            embed_dim: Width of each branch projection and of the head's hidden layer.
            dropout: Dropout probability in the classification head.
        """
        super().__init__()
        # Separate instances so the scene and ROI branches learn independent weights.
        backbone_scene = getattr(models, backbone_name)(pretrained=pretrained)
        backbone_roi = getattr(models, backbone_name)(pretrained=pretrained)

        # Drop the final fc layer. For ResNets the remaining stack ends in global
        # average pooling: [B, fc.in_features, 1, 1] (512 for resnet18).
        self.scene_encoder = nn.Sequential(*list(backbone_scene.children())[:-1])
        self.roi_encoder = nn.Sequential(*list(backbone_roi.children())[:-1])

        self.fc_scene = nn.Linear(backbone_scene.fc.in_features, embed_dim)
        self.fc_roi = nn.Linear(backbone_roi.fc.in_features, embed_dim)

        # Head input is the concatenation of both projections; derive it from
        # embed_dim instead of hard-coding 512.
        self.classifier = nn.Sequential(
            nn.Linear(2 * embed_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, full_img, roi):
        """Return ``[B, num_classes]`` logits for a batch of images and their ROI crops."""
        scene_feat = self.scene_encoder(full_img).flatten(1)
        roi_feat = self.roi_encoder(roi).flatten(1)

        scene_feat = F.relu(self.fc_scene(scene_feat))
        roi_feat = F.relu(self.fc_roi(roi_feat))

        fused = torch.cat([scene_feat, roi_feat], dim=1)
        return self.classifier(fused)





In [None]:
# Configuration: data locations and preprocessing constants in one place.
NUSCENES_IMG_DIR = "data/nuscenes/datasets/v1.0-trainval/"
TRAIN_FEATHER = "classifier_training/train_data/cnn_classifier_train_data.feather"
VAL_FEATHER = "classifier_training/val_data/cnn_classifier_val_data.feather"
TARGET_CLASS = "bicycle"

IMG_SIZE = (224, 224)
# Standard ImageNet channel statistics, used to normalize inputs for the
# ImageNet-pretrained torchvision backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Applied identically to the full image and to the bbox crop.
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

train_dataset = BicycleStateDataset(TRAIN_FEATHER, NUSCENES_IMG_DIR, _class=TARGET_CLASS, transform=transform)
val_dataset = BicycleStateDataset(VAL_FEATHER, NUSCENES_IMG_DIR, _class=TARGET_CLASS, transform=transform)

In [None]:
### HAND LABELING

class InteractiveLabelingInterface:
    """ipywidgets tool for hand-labeling the rows of a dataset one image at a time.

    Each row of ``dataset.data`` is drawn with its ``bbox_2d`` outlined in red.
    The buttons record label 0, 1 or 2, skip the image without recording
    anything, or undo the most recent label. Collect results with ``get_results()``.

    ``dataset`` must provide ``data`` (a DataFrame with ``filename``,
    ``bbox_2d`` and ``target`` columns), ``img_dir`` and ``__len__``.
    """

    # Column order of the DataFrame returned by get_results().
    RESULT_COLUMNS = ['index', 'filename', 'original_label', 'new_label', 'bbox_2d']

    def __init__(self, dataset):
        self.dataset = dataset
        self.current_idx = 0  # iloc position of the image currently shown
        self.labels = []      # labels in the order they were assigned
        self.indices = []     # iloc position matching each entry of self.labels

        # Create buttons
        self.btn_0 = self._make_button('Label 0', 'info')
        self.btn_1 = self._make_button('Label 1', 'success')
        self.btn_2 = self._make_button('Label 2', 'warning')
        self.btn_skip = self._make_button('Skip', 'danger')
        self.btn_back = self._make_button('← Back', '')

        # Create progress label
        self.progress_label = widgets.HTML(value=self.get_progress_text())

        # Bind click events (the clicked-button argument is unused)
        self.btn_0.on_click(lambda b: self.label_image(0))
        self.btn_1.on_click(lambda b: self.label_image(1))
        self.btn_2.on_click(lambda b: self.label_image(2))
        self.btn_skip.on_click(lambda b: self.skip_image())
        self.btn_back.on_click(lambda b: self.go_back())

        # Output widget for the image
        self.output = widgets.Output()

    @staticmethod
    def _make_button(description, button_style):
        """Create one of the fixed-size (120x50 px) control buttons."""
        return widgets.Button(
            description=description,
            button_style=button_style,
            layout=widgets.Layout(width='120px', height='50px'),
        )

    @staticmethod
    def _parse_bbox(bbox):
        """Return ``bbox`` as an ``(x1, y1, x2, y2)`` tuple of floats.

        Accepts a 4-element sequence/array or its string repr (comma- or
        whitespace-separated). Raises ValueError if there are not four values.
        """
        raw = bbox
        if isinstance(bbox, str):
            bbox = bbox.strip().strip("[]()").replace(",", " ").split()
        coords = tuple(float(v) for v in bbox)
        if len(coords) != 4:
            raise ValueError(f"bbox_2d must have 4 values, got {len(coords)}: {raw!r}")
        return coords

    def _refresh(self):
        """Update the progress counter and redraw the current image."""
        self.progress_label.value = self.get_progress_text()
        self.display_image()

    def get_progress_text(self):
        """HTML progress line: images labeled so far and the position being shown."""
        total = len(self.dataset)
        labeled = len(self.labels)
        # Clamp so the counter reads "total/total" rather than "total+1/total"
        # once every image has been shown.
        shown = min(self.current_idx + 1, total)
        return f"<h3>Progress: {labeled}/{total} images labeled ({shown}/{total} shown)</h3>"

    def display_image(self):
        """Redraw the output widget with the current image and its bbox, or a done message."""
        with self.output:
            clear_output(wait=True)

            if self.current_idx >= len(self.dataset):
                print("✅ All images labeled!")
                print(f"\nLabeled {len(self.labels)} images")
                return

            row = self.dataset.data.iloc[self.current_idx]

            img_path = row["filename"]
            if self.dataset.img_dir:
                img_path = os.path.join(self.dataset.img_dir, img_path)

            # Close the file handle as soon as the pixels are decoded.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
            x1, y1, x2, y2 = self._parse_bbox(row["bbox_2d"])

            fig, ax = plt.subplots(1, figsize=(12, 8))
            ax.imshow(image)
            ax.add_patch(patches.Rectangle(
                (x1, y1), x2 - x1, y2 - y1,
                linewidth=3, edgecolor='red', facecolor='none'
            ))
            ax.set_title(f"Image: {row['filename']}\nOriginal Label: {row['target']}", fontsize=12)
            ax.axis('off')

            plt.tight_layout()
            plt.show()

    def label_image(self, label):
        """Record ``label`` for the current image and advance; ignored past the end."""
        if self.current_idx < len(self.dataset):
            self.indices.append(self.current_idx)
            self.labels.append(label)
            self.current_idx += 1
            self._refresh()

    def skip_image(self):
        """Advance without recording a label; ignored past the end."""
        if self.current_idx < len(self.dataset):
            self.current_idx += 1
            self._refresh()

    def go_back(self):
        """Undo the most recent label and return to that image.

        Only labels are undone: any images skipped after that label are shown
        again as the user moves forward.
        """
        if self.labels:
            self.labels.pop()
            self.current_idx = self.indices.pop()
            self._refresh()

    def get_results(self):
        """Return the recorded labels as a DataFrame with columns ``RESULT_COLUMNS``.

        There is one row per labeled image, in labeling order. ``index`` is the
        iloc position in ``dataset.data``, which may differ from the source
        file's row number if the dataset dropped or filtered rows. The columns
        are present even when nothing has been labeled yet.
        """
        records = []
        for idx, label in zip(self.indices, self.labels):
            row = self.dataset.data.iloc[idx]
            records.append({
                'index': idx,
                'filename': row['filename'],
                'original_label': row['target'],
                'new_label': label,
                'bbox_2d': row['bbox_2d'],
            })
        return pd.DataFrame(records, columns=self.RESULT_COLUMNS)

    def start(self):
        """Display the progress line, buttons and image area, then show the first image."""
        display(self.progress_label)
        display(widgets.HBox([self.btn_0, self.btn_1, self.btn_2, self.btn_skip, self.btn_back]))
        display(self.output)
        self.display_image()


# Launch the labeling UI on the training split. Clicks are recorded on
# `labeler`; read them back later with `labeler.get_results()`.
labeler = InteractiveLabelingInterface(dataset=train_dataset)
labeler.start()

In [None]:
# Get the labeled data as a DataFrame
results_df = labeler.get_results()

# Save to CSV
results_df.to_csv('hand_labeled_train_data.csv', index=False)

# Or save to feather format
results_df.to_feather('hand_labeled_train_data.feather')
