# Classifier - Using YOLO outputs only

## Read YOLO outputs

In [None]:
from json_to_lists import read_yolo_json, get_labels

# Load the YOLO detections from the combined-RGB test run.
# `names` and `datas` are whatever read_yolo_json returns; downstream cells
# turn `datas` into the float32 network input and `labels` into the targets.
# NOTE(review): presumably one entry per image, with labels derived from the
# file names -- confirm against json_to_lists.
names, datas = read_yolo_json("./outputs/combined_rgb/yolo_test.json")
labels = get_labels(names)

## Create NN

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class class_nn(nn.Module):
    """Three-layer fully connected classifier with per-bit sigmoid outputs.

    The defaults reproduce the original fixed architecture
    (120 -> 64 -> 32 -> 12), so ``class_nn()`` and previously saved
    state dicts (keys ``fc1``/``fc2``/``fc3``) keep working.

    Args:
        in_features: Size of each input vector.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        out_features: Number of independent output bits.
    """

    def __init__(self, in_features=120, hidden1=64, hidden2=32, out_features=12):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-bit probabilities in [0, 1] for input(s) *x*."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Sigmoid (not softmax): each output bit is an independent 0-1 value.
        return torch.sigmoid(self.fc3(x))

# Module-level classifier instance (120 inputs -> 12 sigmoid outputs, see
# class_nn); the training and evaluation cells below use this global.
model = class_nn()

## Prepare training and validation sets

In [None]:
from collections import Counter
from torch.utils.data import DataLoader, TensorDataset, random_split

# Features and targets as float32 tensors, paired sample-by-sample.
inputs = torch.tensor(datas, dtype=torch.float32)
targets = torch.tensor(labels, dtype=torch.float32)
dataset = TensorDataset(inputs, targets)

# Sequential (unshuffled) split: the first `split` fraction of the samples is
# used for training, the remainder for validation.
split = 0.8
num_samples = len(inputs)
train_size = int(num_samples * split)
test_size = num_samples - train_size

train_dataset = torch.utils.data.Subset(dataset, range(0, train_size))
val_dataset = torch.utils.data.Subset(dataset, range(train_size, num_samples))

# Both loaders use the same batching options.
loader_options = {"batch_size": 32, "shuffle": True}
train_loader = DataLoader(train_dataset, **loader_options)
val_loader = DataLoader(val_dataset, **loader_options)

def count_label_combinations(dataset):
    """Tally how often each distinct label vector occurs in *dataset*.

    *dataset* is iterated as ``(features, label)`` pairs; each label tensor is
    cast to ints and turned into a tuple so it can be used as a Counter key.
    Returns a ``Counter`` mapping label tuple -> number of occurrences.
    """
    return Counter(tuple(target.int().tolist()) for _, target in dataset)

train_combos = count_label_combinations(train_dataset)
val_combos = count_label_combinations(val_dataset)

# Report the training-set size, then the label distribution of each split.
print(train_size)
for heading, combos in (
    ("Train label combinations:", train_combos),
    ("\nValidation label combinations:", val_combos),
):
    print(heading)
    for combo, count in combos.items():
        print(f"{combo}: {count}")

## Train NN

In [None]:
import torch.optim as optim
criterion = nn.BCELoss()  # Binary Cross Entropy Loss for 12-bit outputs
optimizer = optim.Adam(model.parameters(), lr=0.0005)

epochs = 50

for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0.0
    train_seen = 0
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        outputs = model(x_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight it by the batch size so
        # that a smaller final batch does not skew the epoch average.
        batch_n = x_batch.size(0)
        train_loss += loss.item() * batch_n
        train_seen += batch_n
    # Per-sample mean, so train and validation losses are directly comparable
    # even though the two loaders yield different numbers of batches.
    train_loss /= max(train_seen, 1)

    # ---- validation pass (no gradient tracking) ----
    model.eval()
    val_loss = 0.0
    val_seen = 0
    with torch.no_grad():
        for x_val, y_val in val_loader:
            val_outputs = model(x_val)
            batch_n = x_val.size(0)
            val_loss += criterion(val_outputs, y_val).item() * batch_n
            val_seen += batch_n
    val_loss /= max(val_seen, 1)  # guard against an empty validation split

    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")


## Save trained model

In [None]:
from pathlib import Path
# Make sure the output directory exists (no error if it already does), then
# store only the model's parameters (state dict), not the whole module.
Path("classifier_weights").mkdir(exist_ok=True)
torch.save(model.state_dict(), 'classifier_weights/yolo_only.pth')

## Load trained model

In [None]:
# Rebuild the architecture and restore the parameters written by the
# "Save trained model" cell; this replaces the global `model`.
model = class_nn()  # instantiate the model
# NOTE(review): torch.load deserialises the file; only load checkpoints you
# trust, and consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load('classifier_weights/yolo_only.pth'))
model.eval()  # set the model to evaluation mode

## Evaluate trained model

In [None]:
# Module-level accumulators: every evaluate_accuracy() call appends the
# post-processed predictions and their labels here (used for the confusion
# matrix further down). They are NOT cleared between calls.
all_preds = []
all_labels = []


def evaluate_accuracy(loader):
    """Return the exact-match accuracy of the global ``model`` over *loader*.

    A sample counts as correct only if every thresholded output bit equals
    the label. Samples whose label is all zeros are skipped. As a side
    effect, the kept predictions and labels are appended (as int tuples) to
    the module-level ``all_preds`` / ``all_labels`` lists.

    Args:
        loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        Fraction of kept samples predicted exactly, or ``0.0`` when no sample
        was kept (empty loader or only all-zero labels).
    """
    model.eval()
    exact_matches = 0
    total_samples = 0
    with torch.no_grad():
        for x_batch, y_batch in loader:
            outputs = model(x_batch)
            predicted = (outputs > 0.5).float()  # TODO: threshold could be tuned

            # Error state: if the last bit is 1, all other bits must be 0.
            error_mask = predicted[:, -1] == 1
            predicted[error_mask, :-1] = 0

            # Ignore samples without any label bit set.
            non_empty_mask = y_batch.sum(dim=1) != 0
            predicted = predicted[non_empty_mask]
            y_batch = y_batch[non_empty_mask]

            matches = (predicted == y_batch).all(dim=1)  # full match per sample
            exact_matches += matches.sum().item()
            total_samples += y_batch.size(0)

            all_preds.extend(tuple(row) for row in predicted.int().tolist())
            all_labels.extend(tuple(row) for row in y_batch.int().tolist())

    # Avoid ZeroDivisionError when nothing survived the filter.
    if total_samples == 0:
        return 0.0
    return exact_matches / total_samples

# Exact-match accuracy on both splits.
# NOTE(review): each call also appends to the module-level all_preds /
# all_labels lists, so after these two calls the lists hold training AND
# validation samples -- confirm that is what the confusion matrix should show.
train_acc = evaluate_accuracy(train_loader)
val_acc = evaluate_accuracy(val_loader)

print(f"Train Accuracy: {train_acc:.2%}")
print(f"Validation Accuracy: {val_acc:.2%}")

## Generate confusion matrix

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd


def gen_confusion_matrix(all_preds, all_labels):
    """Build a state-level confusion matrix from predicted and true states.

    The known states are the distinct tuples in *all_labels* (sorted). Any
    prediction that is not one of them is counted in one extra trailing
    "other" bucket.

    Args:
        all_preds: Sequence of predicted state tuples.
        all_labels: Sequence of true state tuples, aligned with *all_preds*.

    Returns:
        ``(cm, all_states)`` where ``cm`` always has shape
        ``(len(all_states) + 1, len(all_states) + 1)``; the last row/column is
        the "other" bucket.
    """
    # Unique states seen in the TRUE labels only.
    all_states = sorted(set(all_labels))
    print(all_states)
    state_to_index = {state: idx for idx, state in enumerate(all_states)}
    other_index = len(all_states)

    # Convert state tuples to indices; unseen predictions go to "other".
    y_true_idx = [state_to_index[state] for state in all_labels]
    y_pred_idx = [state_to_index.get(state, other_index) for state in all_preds]

    # Pin the label set so the matrix shape does not depend on whether any
    # prediction actually landed in the "other" bucket.
    cm = confusion_matrix(
        y_true_idx, y_pred_idx, labels=list(range(other_index + 1))
    )
    return cm, all_states


def plot_state_confusion_matrix(cm, state_labels, title="Confusion Matrix of States- YOLO", save_path=None, normalize=True):
    """Draw (and optionally save) a heatmap of a state confusion matrix.

    Args:
        cm: Square confusion matrix (NumPy array) of counts.
        state_labels: State tuples labelling the first ``len(state_labels)``
            rows/columns of *cm*. If *cm* has one extra row/column it is
            labelled "other".
        title: Figure title.
        save_path: If truthy, the figure is written there at 300 dpi.
        normalize: If True, show each row as fractions of its total;
            otherwise show raw counts.
    """
    # Convert tuples like (1, 0, 1) to strings: "101"
    state_strs = [''.join(map(str, state)) for state in state_labels]
    # Only add the catch-all label when the matrix really has that extra
    # row/column; otherwise the DataFrame index would not match cm's shape.
    if cm.shape[0] > len(state_strs):
        state_strs.append("other")

    if normalize:
        # Normalize rows to sum to 1. Rows with no samples are divided by 1
        # so they stay all-zero instead of turning into NaN.
        cm_normalized = cm.astype('float')
        row_sums = cm_normalized.sum(axis=1, keepdims=True)
        row_sums[row_sums == 0] = 1.0
        cm_normalized = cm_normalized / row_sums
        data_to_plot = pd.DataFrame(cm_normalized, index=state_strs, columns=state_strs)
        fmt = ".2f"
    else:
        # Use raw counts
        data_to_plot = pd.DataFrame(cm, index=state_strs, columns=state_strs)
        fmt = "d"

    plt.figure(figsize=(10, 8))
    sns.heatmap(data_to_plot, annot=True, fmt=fmt, cmap="Reds", cbar=True)
    plt.title(title)
    plt.ylabel("True State")
    plt.xlabel("Predicted State")
    plt.xticks(rotation=45, ha="right")
    plt.yticks(rotation=0)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300)
    plt.show()


# Build the confusion matrix from everything accumulated in all_preds /
# all_labels by the evaluation cell, then plot it row-normalized and save it.
cm, all_states = gen_confusion_matrix(all_preds, all_labels)
plot_state_confusion_matrix(cm, all_states, save_path="Confusion_matrix_Yolo")