# Classifier - Using YOLO + VLM outputs

## Install required packages

In [None]:
!pip install transformers

## Prepare BERT tokenizer and encoder

In [None]:
from transformers import BertTokenizer, BertModel

# Tokenizer and encoder are loaded from the same pretrained checkpoint.
_BERT_CHECKPOINT = 'bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(_BERT_CHECKPOINT)
model = BertModel.from_pretrained(_BERT_CHECKPOINT, output_hidden_states=True)

# Switch to inference mode (the encoder is only used to produce embeddings here).
model.eval()



## Read VLM outputs and embed them with BERT

In [None]:
import json
import torch
INPUT_JSON = "outputs/combined_rgb/llava_test_prompt2.json"
INPUT_ADDITION_JSON = "outputs/combined_rgb/llava_test_prompt2_addition.json"
OUTPUT_JSON = "outputs/combined_rgb/llava_test_vectorized.json"

# Primary set of VLM descriptions, stored under the "outputs" key.
with open(INPUT_JSON, "r") as primary_file:
    data = json.load(primary_file)
outputs = data["outputs"]

# Supplementary set of VLM descriptions, same layout.
with open(INPUT_ADDITION_JSON) as addition_file:
    data_addition = json.load(addition_file)
outputs_addition = data_addition["outputs"]

# Sizes of the two mappings before merging.
print(len(outputs))
print(len(outputs_addition))

# Merge both mappings; when a key appears in both, the entry from the addition file wins.
outputs = {**outputs, **outputs_addition}

print(len(outputs))

# Maps each file name to the BERT pooled-output embedding (a plain list of floats)
# of its first VLM description.
vectorized_result = {}

for file_name, descriptions in outputs.items():
    # Only the first description of each entry is embedded.
    text = descriptions[0]
    # Truncate to BERT's 512-token limit: without truncation a longer description
    # overruns the position-embedding table and the forward pass raises.
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    with torch.no_grad():
        output = model(**inputs)

    # squeeze() drops the batch dimension of size 1 so a flat vector is stored.
    embedding = output.pooler_output.squeeze().tolist()
    vectorized_result[file_name] = embedding


# Persist the embeddings so later cells can reload them without re-running BERT.
with open(OUTPUT_JSON, "w") as f:
    json.dump(vectorized_result, f, indent=2)

print(f"Tokenized data saved to {OUTPUT_JSON}")


## VLM Distance Metrics

In [None]:
import json
import numpy as np
import re
from collections import defaultdict
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances

# --- CONFIG ---
VECTOR_JSON = "outputs/combined_rgb/llava_test_vectorized.json"

# --- LOAD VECTOR DATA ---
with open(VECTOR_JSON, "r") as vector_file:
    data = json.load(vector_file)

# --- GROUP VECTORS BY CLASS ---
# The class id is the two-digit group right before ".png" in the file name;
# entries whose name does not end that way are left out.
class_suffix = re.compile(r'_(\d{2})\.png$')
class_vectors = defaultdict(list)

for file_name, vec in data.items():
    found = class_suffix.search(file_name)
    if found is None:
        continue
    class_vectors[found.group(1)].append(np.array(vec, dtype=np.float32))

# --- COMPUTE METRICS ---
# metrics:   class id -> dict of summary statistics
# centroids: class id -> mean embedding of that class
metrics = {}
centroids = {}

for class_id, vectors in class_vectors.items():
    arr = np.stack(vectors)
    sample_count = arr.shape[0]
    norms = np.linalg.norm(arr, axis=1)

    # Intra-class cosine similarity: mean over all distinct pairs (strict upper
    # triangle of the similarity matrix). A class with one sample is reported as 1.0.
    avg_cos_sim = 1.0
    if sample_count > 1:
        cos_sim_matrix = cosine_similarity(arr)
        upper_tri_indices = np.triu_indices_from(cos_sim_matrix, k=1)
        avg_cos_sim = cos_sim_matrix[upper_tri_indices].mean()

    centroid = arr.mean(axis=0)
    centroids[class_id] = centroid

    # Key order matters: the report prints the entries in insertion order.
    metrics[class_id] = dict(
        num_samples=sample_count,
        avg_vector_norm=float(norms.mean()),
        max_vector_norm=float(norms.max()),
        min_vector_norm=float(norms.min()),
        intra_class_variance=float(arr.var(axis=0).mean()),
        avg_cosine_similarity=float(avg_cos_sim),
    )

# --- INTER-CLASS VARIANCE CALCULATION ---
# Fix a deterministic class order and stack the centroids row by row in that order.
class_ids = sorted(centroids)
centroid_matrix = np.stack([centroids[class_id] for class_id in class_ids])

# "Inter-class variance" here is the mean Euclidean distance over all distinct
# pairs of class centroids (strict upper triangle of the distance matrix).
euclidean_dists = euclidean_distances(centroid_matrix)
upper_tri_indices = np.triu_indices_from(euclidean_dists, k=1)
pairwise_centroid_dists = euclidean_dists[upper_tri_indices]
avg_inter_class_variance = pairwise_centroid_dists.mean()

# Cosine distance between centroids, kept for reference.
cosine_sim = cosine_similarity(centroid_matrix)
cosine_dists = 1 - cosine_sim

# --- REPORT ---
# Per-class statistics: float metrics are printed with four decimals, other values verbatim.
print("\nClass-wise Vector Metrics:\n")
for class_id in class_ids:
    stats = metrics[class_id]
    print(f"Class {class_id}:")
    for key, val in stats.items():
        print(f"  {key}: {val:.4f}" if isinstance(val, float) else f"  {key}: {val}")
    print()

print(f"\nAverage Inter-Class Variance (Euclidean distance between centroids): {avg_inter_class_variance:.4f}")

# Optional: print cosine distance matrix
print("\nCosine Distance Matrix (1 - similarity):")
# Each row starts with "  <class id>  ". The header must be indented by exactly that
# width so the column titles sit above their 6-character value columns (the previous
# fixed 7-space indent was one column too wide).
label_width = max((len(cid) for cid in class_ids), default=0)
header = " " * (label_width + 4) + " ".join([f"{cid:>6}" for cid in class_ids])
print(header)
for i, cid in enumerate(class_ids):
    row = "  " + cid.ljust(label_width) + "  " + " ".join([f"{cosine_dists[i, j]:6.3f}" for j in range(len(class_ids))])
    print(row)


## Read YOLO Outputs

In [None]:
from json_to_lists import read_yolo_json, get_labels

# read_yolo_json yields the sample names and their YOLO data; get_labels derives one
# label per name (both helpers live in json_to_lists — see that module for details).
yolo_names, yolo_data = read_yolo_json("./outputs/combined_rgb/yolo_test.json")
yolo_labels = get_labels(yolo_names)

# Name-keyed lookups so YOLO data and labels can be joined with the VLM embeddings.
yolo_data_dict = {sample: features for sample, features in zip(yolo_names, yolo_data)}
yolo_label_dict = {sample: label for sample, label in zip(yolo_names, yolo_labels)}

## Concatenate YOLO and VLM Outputs

In [None]:
# Samples are taken in the iteration order of vectorized_result; the three lists
# stay index-aligned (same position = same sample).
concat_names = list(vectorized_result)

# Feature row = YOLO data for the sample followed by its VLM text embedding.
concat_data = [
    yolo_data_dict[sample] + vectorized_result[sample]
    for sample in concat_names
]
concat_labels = [yolo_label_dict[sample] for sample in concat_names]

## Create NN

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class class_nn(nn.Module):
    """Three-layer fully connected network with per-output sigmoid activations.

    The defaults reproduce the original fixed architecture (888 -> 64 -> 32 -> 12),
    so ``class_nn()`` builds the same layers and previously saved state dicts still
    load. The sizes are exposed so the network can follow a change in the length of
    the concatenated feature vector or in the number of output bits.

    Args:
        input_dim: Length of each input feature vector.
        hidden_dims: Sizes of the two hidden layers, as ``(first, second)``.
        num_outputs: Number of independent outputs produced per sample.
    """

    def __init__(self, input_dim=888, hidden_dims=(64, 32), num_outputs=12):
        super().__init__()
        first_hidden, second_hidden = hidden_dims
        # Attribute names are kept as fc1/fc2/fc3 so state_dict keys stay unchanged.
        self.fc1 = nn.Linear(input_dim, first_hidden)
        self.fc2 = nn.Linear(first_hidden, second_hidden)
        self.fc3 = nn.Linear(second_hidden, num_outputs)

    def forward(self, x):
        """Map a batch of shape ``(batch, input_dim)`` to values in [0, 1] of shape ``(batch, num_outputs)``."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Sigmoid keeps every output in [0, 1]; each output is treated as one bit.
        x = torch.sigmoid(self.fc3(x))
        return x

# Instantiate the classifier that the training cell optimizes.
# NOTE: this rebinds `model`, which the BERT cell assigned to the BertModel encoder.
# Re-running the VLM embedding cell after this point would call the classifier
# instead of BERT, so re-run the BERT cell first in that case.
model = class_nn()

## Prepare training and validation sets

In [None]:
from collections import Counter
from torch.utils.data import DataLoader, TensorDataset, random_split

# Feature rows and label rows as float tensors; TensorDataset pairs them by index.
inputs = torch.tensor(concat_data, dtype=torch.float32)
targets = torch.tensor(concat_labels, dtype=torch.float32)

dataset = TensorDataset(inputs, targets)

# Fraction of samples used for training; the remainder is held out for validation.
split = 0.8
train_size = int(len(inputs) * split)
test_size = len(inputs) - train_size

# The split is positional: the first train_size samples train, the rest validate.
# NOTE(review): if the samples are ordered (e.g. by class or source file), the two
# subsets will have different label distributions; the label-combination counts
# printed below make that visible. Consider a seeded random split — confirm intent.
train_dataset = torch.utils.data.Subset(dataset, range(train_size))
val_dataset = torch.utils.data.Subset(dataset, range(train_size, train_size + test_size))

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Validation data is not shuffled: order has no effect on learning, and a fixed
# batch composition keeps the reported validation loss reproducible between epochs.
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

def count_label_combinations(dataset):
    """Tally how often each distinct label vector occurs in ``dataset``.

    ``dataset`` is iterated as ``(features, label)`` pairs. Each label tensor is
    cast to int and turned into a tuple so it can serve as a Counter key.

    Returns:
        Counter mapping label tuple -> number of samples carrying that label.
    """
    return Counter(
        tuple(label.int().tolist())
        for _, label in dataset
    )

train_combos = count_label_combinations(train_dataset)
val_combos = count_label_combinations(val_dataset)

# Print results: the training-set size first, then one "label tuple: count" line
# per distinct label combination in each subset.
print(train_size)
for heading, combos in (
    ("Train label combinations:", train_combos),
    ("\nValidation label combinations:", val_combos),
):
    print(heading)
    for combo, count in combos.items():
        print(f"{combo}: {count}")

## Train NN

In [None]:
import torch.optim as optim
# Binary cross-entropy on the sigmoid outputs: each output bit is an independent
# binary target.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)

epochs = 1000

for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0.0
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        # Named batch_pred (not `outputs`) so the loop does not overwrite the
        # notebook-level `outputs` dict that holds the VLM descriptions.
        batch_pred = model(x_batch)
        loss = criterion(batch_pred, y_batch)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight it by batch size so the epoch
        # figure is a true per-sample average even when the last batch is smaller.
        train_loss += loss.item() * x_batch.size(0)
    train_loss /= max(len(train_loader.dataset), 1)

    # ---- validation pass (no gradient tracking) ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x_val, y_val in val_loader:
            val_outputs = model(x_val)
            val_loss += criterion(val_outputs, y_val).item() * x_val.size(0)
    val_loss /= max(len(val_loader.dataset), 1)

    # Both numbers are per-sample means, so they are directly comparable even though
    # the two loaders have a different number of batches.
    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")


## Save trained model

In [None]:
from pathlib import Path
# Snapshot the learned parameters, make sure the target folder exists, then write
# the weights file.
trained_state = model.state_dict()
Path("classifier_weights").mkdir(exist_ok=True)
torch.save(trained_state, 'classifier_weights/combined.pth')

## Load trained model

In [None]:
# Build a fresh classifier and restore the saved parameters into it.
model = class_nn()
saved_state = torch.load('classifier_weights/combined.pth')
model.load_state_dict(saved_state)

# Inference mode for the evaluation cells that follow.
model.eval()

## Evaluate trained model

In [None]:
# Module-level accumulators: every evaluate_accuracy call appends the predicted and
# true label tuples of the samples it scored (read later by the confusion matrix).
all_preds = []
all_labels = []
def evaluate_accuracy(loader, threshold=0.5, net=None):
    """Return the exact-match accuracy of a classifier over ``loader``.

    A sample counts as correct only when every output bit matches its label.
    Samples whose label vector is all zeros are excluded from the score. As a side
    effect, the prediction and label tuples of every scored sample are appended to
    the module-level ``all_preds`` / ``all_labels`` lists.

    Args:
        loader: Iterable of ``(x_batch, y_batch)`` tensor pairs.
        threshold: Outputs strictly greater than this value are predicted as 1.
        net: Network to evaluate; defaults to the notebook-level ``model``.

    Returns:
        Fraction of scored samples predicted exactly right, or 0.0 when no sample
        was scored (empty loader, or only all-zero labels).
    """
    net = model if net is None else net
    net.eval()
    exact_matches = 0
    total_samples = 0
    with torch.no_grad():
        for x_batch, y_batch in loader:
            outputs = net(x_batch)
            predicted = (outputs > threshold).float()

            # The last bit marks the error state: when it is set, all other bits
            # are forced to 0.
            error_mask = predicted[:, -1] == 1
            predicted[error_mask, :-1] = 0

            # Drop samples whose label has no bit set.
            non_empty_mask = ~(y_batch.sum(axis=1) == 0)
            predicted = predicted[non_empty_mask]
            y_batch = y_batch[non_empty_mask]

            matches = (predicted == y_batch).all(dim=1)  # full match per sample

            exact_matches += matches.sum().item()
            total_samples += y_batch.size(0)

            all_preds.extend([tuple(p.int().tolist()) for p in predicted])
            all_labels.extend([tuple(y.int().tolist()) for y in y_batch])

    # Avoid ZeroDivisionError when nothing was scored.
    if total_samples == 0:
        return 0.0
    return exact_matches / total_samples

train_acc = evaluate_accuracy(train_loader)

# evaluate_accuracy appends to the module-level all_preds / all_labels lists.
# Discard the training-set entries before scoring the validation set, so the
# confusion matrix built from these lists reflects held-out data only instead of
# a mix of training and validation predictions.
all_preds.clear()
all_labels.clear()
val_acc = evaluate_accuracy(val_loader)

print(f"Train Accuracy: {train_acc:.2%}")
print(f"Validation Accuracy: {val_acc:.2%}")

## Generate confusion matrix

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd


def gen_confusion_matrix(all_preds, all_labels):
    """Build a confusion matrix over label states, plus one trailing "other" class.

    The known states are the distinct tuples found in ``all_labels``, in sorted
    order. Any predicted tuple that is not a known state is counted under an extra
    index placed after the known states.

    Args:
        all_preds: Predicted state tuples, aligned with ``all_labels``.
        all_labels: True state tuples.

    Returns:
        ``(cm, all_states)`` where ``cm`` is always a square matrix of size
        ``len(all_states) + 1`` (the last row/column is the "other" class) and
        ``all_states`` is the sorted list of known states.
    """
    # Known states come from the true labels only.
    all_states = sorted(set(all_labels))
    print(all_states)
    state_to_index = {state: idx for idx, state in enumerate(all_states)}
    other_index = len(all_states)

    # Convert state tuples to indices; unknown predicted states map to other_index.
    y_true_idx = [state_to_index.get(state, other_index) for state in all_labels]
    y_pred_idx = [state_to_index.get(state, other_index) for state in all_preds]

    # Pass the full label set explicitly. Otherwise the matrix only covers indices
    # that actually occur, so its size would shrink by one whenever no prediction
    # falls into "other" and no longer match the n + 1 tick labels used for plotting.
    cm = confusion_matrix(y_true_idx, y_pred_idx, labels=list(range(other_index + 1)))
    return cm, all_states


def plot_state_confusion_matrix(cm, state_labels, title="Confusion Matrix of States- YOLO", save_path=None, normalize=True):
    """Draw a confusion matrix of label states as an annotated heatmap.

    Args:
        cm: Square NumPy confusion matrix (rows = true state, columns = predicted).
            It may have one more row/column than ``state_labels``; that extra
            entry is labelled "other".
        state_labels: State tuples naming the leading rows/columns of ``cm``.
        title: Figure title.
        save_path: When given, the figure is also written there at 300 dpi.
        normalize: If True, show each row as fractions of its row total;
            otherwise show raw counts.
    """
    # Local import: this function needs NumPy directly, and the plotting cell's own
    # imports do not bring it in.
    import numpy as np

    # Convert tuples like (1, 0, 1) to strings: "101"
    state_strs = [''.join(map(str, state)) for state in state_labels]
    # Label the trailing "other" class only when the matrix actually carries it,
    # so the tick labels always match the matrix size.
    if cm.shape[0] == len(state_strs) + 1:
        state_strs.append("other")

    if normalize:
        # Normalize rows to sum to 1. Rows with no samples (e.g. "other", which never
        # occurs as a true state) are left at 0 instead of producing 0/0 = NaN.
        counts = cm.astype('float')
        row_sums = counts.sum(axis=1, keepdims=True)
        cm_normalized = np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums != 0)
        cm_normalized = pd.DataFrame(cm_normalized, index=state_strs, columns=state_strs)
        fmt = ".2f"
        data_to_plot = cm_normalized
    else:
        # Use raw counts
        data_to_plot = pd.DataFrame(cm, index=state_strs, columns=state_strs)
        fmt = "d"

    plt.figure(figsize=(10, 8))
    sns.heatmap(data_to_plot, annot=True, fmt=fmt, cmap="Reds", cbar=True)
    plt.title(title)
    plt.ylabel("True State")
    plt.xlabel("Predicted State")
    plt.xticks(rotation=45, ha="right")
    plt.yticks(rotation=0)
    plt.tight_layout()

    # Save before show(): showing the figure may release it.
    if save_path:
        plt.savefig(save_path, dpi=300)
    plt.show()


# Build the state-level confusion matrix from the collected predictions and labels,
# then render it (row-normalized by default) and save the figure to disk.
cm, all_states = gen_confusion_matrix(all_preds, all_labels)
plot_state_confusion_matrix(
    cm,
    all_states,
    save_path="Confusion_matrix_Yolo",
)