## Part 1: Generate Synthetic Dataset

In [None]:
import os
import random
import pickle
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import Counter
# Parameters
# Root folders holding the source crops; the loader below looks for the
# sub-folders '0'..'9' and 'none' inside each of them.
TRAIN_DIR = "dataset_class_train"
VALI_DIR = "dataset_class_valid"
# Row count every crop is padded to; also the default strip height.
WINDOW_HEIGHT = 18
# Number of white columns added at the left and at the right end of a strip.
LEFT_PAD = 2
# Label given to crops that come from the 'none' folder.
BLANK_CLASS = 10
# Upper bound of the randomly drawn target width of a strip.
MAX_WIDTH = 200

def count_labels_across_strips(dataset):
    """Print how often each class label occurs over all strips.

    Args:
        dataset: Iterable of ``(strip_image, labels)`` pairs, where ``labels``
            is a sequence of 3-tuples whose first item is the class id.
    """
    tally = Counter(
        class_id
        for _, strip_labels in dataset
        for class_id, _, _ in strip_labels
    )

    print("\n🔎 Digit Counts in Generated Strips:")
    for class_id in sorted(tally):
        shown = "blank" if class_id == BLANK_CLASS else str(class_id)
        print(f"  Class {shown}: {tally[class_id]}")


def pad_to_height(img_np, target_height=WINDOW_HEIGHT):
    """Centre a 2-D array vertically on a white canvas of ``target_height`` rows.

    Missing rows are filled with 255; when the number of missing rows is odd
    the extra row goes to the bottom. The width is left untouched.
    """
    rows, _ = img_np.shape
    missing = target_height - rows
    above = missing // 2
    below = missing - above
    return np.pad(img_np, ((above, below), (0, 0)), constant_values=255)


def load_and_trim_digit_images(base_dir):
    """Load every class image under ``base_dir`` as a left-trimmed, padded array.

    The sub-folders ``'0'``..``'9'`` and ``'none'`` are scanned (missing ones are
    skipped). Each ``.png``/``.jpg``/``.jpeg`` file is converted to 8-bit
    grayscale, all columns left of the first column containing a pixel < 255
    are dropped, and the result is padded with ``pad_to_height``. Images that
    are entirely white (255) are ignored.

    Args:
        base_dir: Directory containing the per-class sub-folders.

    Returns:
        List of ``(image_array, label)`` tuples; the label is the digit for a
        numeric folder name and ``BLANK_CLASS`` for ``'none'``.
    """
    class_dirs = [str(i) for i in range(10)] + ['none']
    dataset = []
    for cls in tqdm(class_dirs, desc=f"Loading {base_dir}"):
        cls_path = os.path.join(base_dir, cls)
        if not os.path.isdir(cls_path):
            continue
        label = int(cls) if cls.isdigit() else BLANK_CLASS
        for fname in os.listdir(cls_path):
            if not fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            # Image.open is lazy and keeps the file open; the context manager
            # releases the handle as soon as the pixels have been copied out.
            with Image.open(os.path.join(cls_path, fname)) as img:
                img_np = np.array(img.convert("L"))
            cols = np.where(np.any(img_np < 255, axis=0))[0]
            if cols.size == 0:
                continue
            # Only the left margin is removed; columns to the right are kept.
            trimmed = img_np[:, cols[0]:]
            dataset.append((pad_to_height(trimmed), label))
    return dataset

def find_similar_edge(target_col, candidates, threshold=10):
    """Pick a candidate whose first column resembles ``target_col``.

    ``candidates`` is shuffled in place, then scanned; the first entry whose
    left-most column has a mean absolute difference to ``target_col`` below
    ``threshold`` is returned. If none qualifies a random entry is returned.

    Args:
        target_col: 1-D array of pixel values to match against.
        candidates: Mutable list of ``(image, label, idx)`` tuples. Reordered
            by this call.
        threshold: Exclusive upper bound on the mean absolute difference.

    Returns:
        An ``(image, label, idx)`` tuple taken from ``candidates``.

    Raises:
        IndexError: If ``candidates`` is empty.
    """
    random.shuffle(candidates)
    # Widen before subtracting: on uint8 data ``a - b`` wraps around
    # (15 - 20 == 251), which made darker edges look maximally different.
    target = np.asarray(target_col, dtype=np.float64)
    for candidate_img, candidate_label, idx in candidates:
        edge = np.asarray(candidate_img[:, 0], dtype=np.float64)
        if np.mean(np.abs(edge - target)) < threshold:
            return candidate_img, candidate_label, idx
    return random.choice(candidates)

def build_realistic_strip(digit_bank, none_bank, left_pad=LEFT_PAD, height=WINDOW_HEIGHT):
    """Assemble one strip by chaining samples taken out of the two banks.

    The strip starts with ``left_pad`` white columns and one sample popped from
    a randomly chosen non-empty bank. Further samples are picked with
    ``find_similar_edge`` and joined through a one-column transition until a
    random target width (60..``MAX_WIDTH``) is reached, the next sample would
    not fit, or no bank can be used. At most 8 non-blank samples are appended
    after the first one has been placed. Every sample placed is removed from
    its bank, so both lists are mutated.

    Returns:
        ``(strip, labels)`` where ``labels`` holds ``(label, start, end)`` for
        every non-blank sample; ``start`` is its first column and ``end`` the
        column right after its last one.
    """
    target_width = random.randint(60, MAX_WIDTH)
    max_digits = 8
    placed_digits = 0

    pieces = [np.full((height, left_pad), 255, dtype=np.uint8)]
    labels = []
    cursor = left_pad

    # Seed the strip with the last element of a randomly chosen bank.
    non_empty = [bank for bank in [digit_bank, none_bank] if bank]
    seed_bank = random.choice(non_empty)
    prev_img, prev_label, prev_idx = seed_bank.pop()
    pieces.append(prev_img)
    seed_start = cursor
    cursor += prev_img.shape[1]
    if prev_label != BLANK_CLASS:
        labels.append((prev_label, seed_start, cursor))
        placed_digits += 1

    while cursor < target_width and (digit_bank or none_bank):
        # Digits are only eligible while the per-strip digit cap is not hit.
        eligible = []
        if placed_digits < max_digits and digit_bank:
            eligible.append(digit_bank)
        if none_bank:
            eligible.append(none_bank)
        if not eligible:
            break

        source_bank = random.choice(eligible)
        next_img, next_label, next_idx = find_similar_edge(prev_img[:, -1], source_bank)

        # Stop when the transition column plus the sample would overshoot.
        if cursor + 1 + next_img.shape[1] > target_width:
            break

        # One blended column between the two neighbouring edges.
        edge_pair = np.stack([prev_img[:, -1], next_img[:, 0]])
        transition_col = np.median(edge_pair, axis=0).astype(np.uint8).reshape(-1, 1)
        pieces.append(transition_col)
        cursor += 1

        first_col = cursor
        pieces.append(next_img)
        cursor += next_img.shape[1]
        if next_label != BLANK_CLASS:
            labels.append((next_label, first_col, cursor))
            placed_digits += 1

        # Drop the consumed sample from its bank (in place).
        source_bank[:] = [item for item in source_bank if item[2] != next_idx]
        prev_img = next_img

    # White margin on the right, same width as the left one.
    pieces.append(np.full((height, left_pad), 255, dtype=np.uint8))
    strip = np.concatenate(pieces, axis=1)
    return strip, labels

def create_full_dataset(base_dir):
    """Turn every image found under ``base_dir`` into labelled strips.

    All images are loaded, split into a digit bank and a blank bank (each
    entry tagged with a unique index), and strips are built until both banks
    are empty. Every call to ``build_realistic_strip`` removes at least one
    sample, so the loop terminates.

    Args:
        base_dir: Directory passed on to ``load_and_trim_digit_images``.

    Returns:
        List of ``(strip_image, labels)`` tuples as produced by
        ``build_realistic_strip``.
    """
    all_images = load_and_trim_digit_images(base_dir)

    # Tag each sample with a unique index so it can be removed from its bank.
    digit_bank, none_bank = [], []
    for idx, (img, lbl) in enumerate(all_images):
        bank = none_bank if lbl == BLANK_CLASS else digit_bank
        bank.append((img, lbl, idx))

    dataset = []
    remaining = len(digit_bank) + len(none_bank)

    # The context manager closes the bar even if strip building raises.
    with tqdm(total=remaining, desc=f"Creating {base_dir} dataset") as pbar:
        while digit_bank or none_bank:
            strip_img, labels = build_realistic_strip(digit_bank, none_bank)
            dataset.append((strip_img, labels))

            # Progress is the number of samples that left the banks; tracking
            # the count avoids rebuilding full index sets for every strip.
            still_left = len(digit_bank) + len(none_bank)
            pbar.update(remaining - still_left)
            remaining = still_left

    return dataset


# Visualization Function
def visualize_strips(dataset, num_samples=5):
    """Show randomly chosen strips with their labelled spans outlined.

    Args:
        dataset: Sequence of ``(image, labels)`` tuples; each label is
            ``(label, start, end)`` in column coordinates.
        num_samples: Number of strips to show. Capped at ``len(dataset)`` so
            that small or empty datasets do not make ``random.sample`` raise.
    """
    sample_count = min(num_samples, len(dataset))
    samples = random.sample(dataset, sample_count)
    for idx, (img, labels) in enumerate(samples, 1):
        plt.figure(figsize=(10, 2))
        plt.imshow(img, cmap='gray')
        plt.title(f"Sample #{idx}")
        ax = plt.gca()
        for label, start, end in labels:
            # Red frame over the span, class id printed above it.
            rect = plt.Rectangle((start, 0), end - start, img.shape[0],
                                 edgecolor='red', facecolor='none', linewidth=2)
            ax.add_patch(rect)
            ax.text((start + end) / 2, -1, str(label), color='blue', fontsize=12, ha='center', va='bottom')
        plt.axis('off')
        plt.show()

def export_labeled_images(dataset, output_dir, prefix="strip"):
    """Render every strip with its label boxes and save it as a PNG.

    Files are written to ``output_dir`` (created if missing) and named
    ``{prefix}_{index:04d}.png``.
    """
    os.makedirs(output_dir, exist_ok=True)
    for strip_no, (img, spans) in enumerate(dataset):
        fig, ax = plt.subplots(figsize=(12, 2))
        ax.imshow(img, cmap='gray')
        strip_height = img.shape[0]
        for label, start, end in spans:
            box = plt.Rectangle((start, 0), end - start, strip_height,
                                edgecolor='red', facecolor='none', linewidth=1)
            ax.add_patch(box)
            centre = (start + end) / 2
            ax.text(centre, -1, str(label), color='blue',
                    fontsize=10, ha='center', va='bottom')
        ax.axis('off')
        plt.tight_layout()
        target = os.path.join(output_dir, f"{prefix}_{strip_no:04d}.png")
        plt.savefig(target, dpi=100)
        plt.close(fig)


if __name__ == "__main__":
    # Build both splits first, then persist them.
    train_data = create_full_dataset(TRAIN_DIR)
    val_data = create_full_dataset(VALI_DIR)

    for pkl_name, split in (("train_dataset.pkl", train_data),
                            ("val_dataset.pkl", val_data)):
        with open(pkl_name, "wb") as f:
            pickle.dump(split, f)

    print(f"\n✅ Dataset creation complete.")
    print(f"Train set strips: {len(train_data)}")
    print(f"Validation set strips: {len(val_data)}")
    for split in (train_data, val_data):
        count_labels_across_strips(split)
    # 🎨 Export all labeled strips to folders for visual inspection
    #export_labeled_images(train_data, "labeled_strips_train", prefix="train")
    #export_labeled_images(val_data, "labeled_strips_val", prefix="val")
    # Show 5 random samples from training dataset
    print("\n🎨 Showing 5 random training samples:")
    visualize_strips(train_data, num_samples=5)


## Part 2: Train CNN on INT8 Inputs + Calibrate Thresholds

In [None]:
import torch
import random
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import pickle
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from collections import Counter
from collections import defaultdict
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score
import json
from sklearn.metrics import precision_recall_fscore_support
import seaborn as sns



# ====== Parameters ======
K_size = 64        # default number of 3x3 convolution filters (CNNModel's `k`)
NUM_CLASSES = 11   # digits 0-9 plus the blank class
BATCH_SIZE = 128   # batch size of the training DataLoader
EPOCH_NUM = 1000   # upper bound on training epochs (early stopping may end sooner)
# Window size; not read by the code in this cell, SlidingWindowDataset
# carries the same values as its own defaults.
INPUT_WIDTH = 14
INPUT_HEIGHT = 18
# Feature-map size after the 3x3 convolution without padding
# (18 - 2 = 16 rows, 14 - 2 = 12 columns); sizes the fully connected layer.
FC_WIDTH = 12
FC_HEIGHT = 16
BLANK_CLASS = 10   # class id used for "no digit" windows
# Prefer CUDA when available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"📟 Using device: {device}")

# ====== Model ======
class CNNModel(nn.Module):
    """One 3x3 convolution (no padding) + ReLU feeding a single linear layer.

    The linear layer expects ``k * FC_HEIGHT * FC_WIDTH`` features, i.e. the
    flattened convolution output for the window size this notebook uses.
    """

    def __init__(self, num_classes=NUM_CLASSES, k=K_size):
        super().__init__()
        flat_features = k * FC_HEIGHT * FC_WIDTH
        self.conv1 = nn.Conv2d(1, k, kernel_size=3, stride=1, padding=0)  # VALID
        self.relu1 = nn.ReLU()
        self.fc1 = nn.Linear(flat_features, num_classes)

    def forward(self, x):
        """Return the class scores for a batch of single-channel windows."""
        activations = self.relu1(self.conv1(x))
        # No cropping is applied: the whole feature map is flattened.
        batch = activations.size(0)
        return self.fc1(activations.reshape(batch, -1))

# ====== Dataset Loader =======
class SlidingWindowDataset(Dataset):
    """Window-level samples cut out of the pickled strips.

    Each strip is fed column by column through a fixed-size buffer. A positive
    sample is stored right after the column at a label's ``end`` index has been
    shifted in; windows seen while no digit is active may be stored as blank
    (negative) samples. Pixels are stored as int8 (``pixel - 128``) and
    converted to float32 when an item is fetched.
    """

    def __init__(self, pkl_path, input_width=14, input_height=18, fc_width=12, fc_height=16,
                 blank_class=10, pos_ratio=0.0, neg_ratio=0.5, limit_negatives=False):
        """Build all samples eagerly and print the class distribution.

        Args:
            pkl_path: Pickle file holding a list of ``(image, labels)`` pairs,
                labels being ``(label, start, end)`` tuples.
            input_width: Window width in columns.
            input_height: Window height in rows.
            fc_width: Not used by this constructor.
            fc_height: Not used by this constructor.
            blank_class: Label attached to negative windows.
            pos_ratio: Not used by this constructor.
            neg_ratio: A window is kept as negative only if its overlap with
                every labelled span, relative to that span's width, is at
                most this value.
            limit_negatives: If true, keep at most as many negatives as there
                are positives, preferring those with the most non-white columns.
        """
        # NOTE(review): pickle.load executes arbitrary code from the file —
        # only point this at pickles produced by a trusted source.
        with open(pkl_path, 'rb') as f:
            self.data = pickle.load(f)

        self.samples = []
        self.negatives = []
        self.class_counts = defaultdict(int)

        for img, labels in self.data:
            h, w = img.shape
            # The buffer starts out all white (255) for every strip.
            buffer = np.full((input_height, input_width), 255, dtype=np.uint8)
            x = 0
            active_label = None
            reset_buffer = False  # assigned but never read below
            label_idx = 0
            labels = sorted(labels, key=lambda x: x[1])  # Sort by start column

            while x < w:
                # Shift window left, insert new column
                buffer[:, :-1] = buffer[:, 1:]
                buffer[:, -1] = img[:, x]
                x += 1  # from here on, x - 1 is the column just inserted

                # Check for start of a new digit
                if label_idx < len(labels):
                    label, start, end = labels[label_idx]

                    if start <= x - 1 <= end:  # inside digit
                        active_label = (label, start, end)

                    if active_label and x == end + 1:  # one column after end
                        # Shift 0..255 pixels to the signed int8 range.
                        crop = np.clip(buffer.astype(np.int16) - 128, -128, 127).astype(np.int8)
                        self.samples.append((torch.tensor(crop, dtype=torch.int8).unsqueeze(0), label))
                        self.class_counts[label] += 1

                        # Reset buffer but preserve last two columns
                        last_two = buffer[:, -2:].copy()
                        buffer.fill(255)
                        buffer[:, -2:] = last_two

                        active_label = None
                        label_idx += 1
                        continue  # skip overlap logic on reset

                # If not in active digit — treat as candidate negative
                if not active_label:
                    # Overlap of the column span [x - input_width, x] with each
                    # labelled span, as a fraction of that span's width.
                    overlaps = [max(0, min(x, e) - max(x - input_width, s)) / (e - s)
                                for _, s, e in labels]
                    max_overlap = max(overlaps + [0])
                    if max_overlap <= neg_ratio:
                        crop = np.clip(buffer.astype(np.int16) - 128, -128, 127).astype(np.int8)
                        self.negatives.append((torch.tensor(crop, dtype=torch.int8).unsqueeze(0), blank_class))

        # Optionally filter negatives
        if limit_negatives and self.negatives:
            total_positives = len(self.samples)
            max_negatives = total_positives
            def score_negative_patch(img_tensor):
                # 127 is white after the -128 shift; count columns with any ink.
                img = img_tensor.numpy()[0]
                non_white_cols = np.sum(np.any(img != 127, axis=0))
                return non_white_cols
            # Highest score first, i.e. the least empty windows are kept.
            scored = sorted(self.negatives, key=lambda x: -score_negative_patch(x[0]))
            selected = scored[:max_negatives]
            self.samples.extend(selected)
            self.class_counts[blank_class] = len(selected)
        else:
            self.samples.extend(self.negatives)
            self.class_counts[blank_class] = len(self.negatives)

        # NOTE(review): with no samples at all `total` is 0 and the percentage
        # below divides by zero — confirm whether empty pickles can occur.
        total = sum(self.class_counts.values())
        print("\n📊 Class Distribution:")
        for k in sorted(self.class_counts):
            name = str(k) if k != blank_class else "blank"
            print(f"  Class {name}: {self.class_counts[k]} ({(self.class_counts[k] / total)*100:.2f}%)")
        print(f"  Total samples: {total}\n")

    def __len__(self):
        """Return the number of stored samples (positives plus kept negatives)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(window as float32 tensor, label)`` for sample ``idx``."""
        img, label = self.samples[idx]
        return img.to(torch.float32), label


# ====== Training Function ======
def train_model(model, train_loader, val_loader, l1_lambda=1e-5, l2_lambda=1e-1, label_smoothing=0.2):
    """Train ``model`` with Adam, an L1 penalty and early stopping.

    The loss is label-smoothed cross entropy plus ``l1_lambda`` times the L1
    norm of all parameters; ``l2_lambda`` is passed to Adam as weight decay.
    The learning rate is halved after 5 epochs without validation-loss
    improvement and training stops after 10 such epochs (or ``EPOCH_NUM``).
    The weights of the epoch with the lowest validation loss are restored
    before returning.

    Args:
        model: Network to train; must already live on ``device``.
        train_loader: DataLoader yielding ``(inputs, labels)`` for training.
        val_loader: DataLoader yielding ``(inputs, labels)`` for validation.
        l1_lambda: Weight of the L1 penalty.
        l2_lambda: Weight decay handed to the optimizer.
        label_smoothing: Smoothing factor of the cross-entropy loss.

    Returns:
        ``(model, preds, labels)``. ``preds`` and ``labels`` are the validation
        predictions and targets of the LAST epoch that ran, which is not
        necessarily the epoch whose weights were restored.
    """
    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=l2_lambda)
    # NOTE(review): `verbose` is reported as deprecated by newer PyTorch
    # releases — confirm against the installed version.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5, verbose=True)

    best_val_loss = float('inf')
    best_model = None
    patience = 10
    counter = 0
    # Defined up front so the return statement is valid even with zero epochs.
    preds, labels = [], []

    for epoch in range(EPOCH_NUM):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            # 🔥 Add L1 penalty
            l1_norm = sum(p.abs().sum() for p in model.parameters())
            loss = loss + l1_lambda * l1_norm
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            correct += (out.argmax(1) == y).sum().item()
            total += y.size(0)

        train_acc = 100 * correct / total

        model.eval()
        val_loss, correct, total = 0, 0, 0
        preds, labels = [], []
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                out = model(x)
                # Validation loss is the plain criterion, without the L1 term.
                loss = criterion(out, y)
                val_loss += loss.item()
                pred = out.argmax(1)
                correct += (pred == y).sum().item()
                total += y.size(0)
                preds.extend(pred.cpu().numpy())
                labels.extend(y.cpu().numpy())

        val_acc = 100 * correct / total
        avg_train_loss = total_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch+1}: Train Acc={train_acc:.2f}%, Val Acc={val_acc:.2f}%, "
              f"Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}, LR={optimizer.param_groups[0]['lr']:.6f}")

        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() hands out references to the live parameters; clone
            # them, otherwise the snapshot keeps changing with later epochs.
            best_model = {name: tensor.detach().clone()
                          for name, tensor in model.state_dict().items()}
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping.")
                break

    if best_model is not None:
        model.load_state_dict(best_model)
    return model, preds, labels

# ====== Manual Quantization & Evaluation ======
# ✅ Improved Quantization Function
def quantize_tensor(tensor, is_bias=False, is_fc_layer=False, k=K_size, percentile_clip=0.1):
    """Quantize a parameter tensor to integers.

    Four cases, selected by the two flags:

    * FC bias (``is_bias`` and ``is_fc_layer``): multiplied by
      ``k * 16 * 12 * 100``, rounded, clamped to the int32 range, returned as int64.
    * Other bias: symmetric abs-max scaling to +/-127, clamped to the int16
      range, returned as int32.
    * FC weight (``is_fc_layer`` only): symmetric scaling based on the
      ``percentile_clip`` / ``100 - percentile_clip`` percentiles, returned as int8.
    * Anything else: symmetric abs-max scaling, returned as int8.
    """

    def absmax_scale():
        # Scale 1.0 is used whenever all entries are equal.
        if tensor.max() != tensor.min():
            return max(abs(tensor.min()), abs(tensor.max())) / 127
        return 1.0

    if is_bias and is_fc_layer:
        # Scaled FC bias (we keep as is)
        scale_factor = (k * 16 * 12) * 100
        scaled = torch.round(tensor * scale_factor)
        return torch.clamp(scaled, -2**31, 2**31 - 1).to(torch.int64)

    if is_bias:
        # Conv bias: symmetric absmax
        rounded = torch.round(tensor / absmax_scale())
        return torch.clamp(rounded, -2**15, 2**15 - 1).to(torch.int32)

    if not is_fc_layer:
        # Default convolution-weight quantization
        rounded = torch.round(tensor / absmax_scale())
        return torch.clamp(rounded, -128, 127).to(torch.int8)

    # FC weights: clip at symmetric percentiles before scaling.
    values = tensor.detach().cpu().numpy()
    low = np.percentile(values, percentile_clip)
    high = np.percentile(values, 100 - percentile_clip)
    absmax = max(abs(low), abs(high))
    if absmax == 0:
        return torch.zeros_like(tensor, dtype=torch.int8)

    scale = absmax / 127
    clipped = torch.clamp(tensor, -absmax, absmax)
    return torch.clamp(torch.round(clipped / scale), -128, 127).to(torch.int8)


def compute_thresholds(model, loader, weight=None, bias=None, percentile=5, blank_class=10):
    """Derive a per-class score threshold from true-class score percentiles.

    The conv+ReLU output of ``model`` is cast to integers and pushed through an
    integer matrix product with the FC weights and bias. For every sample the
    score of its true class is collected; the threshold of a class is the
    ``percentile``-th percentile of those scores.

    Args:
        model: Model exposing ``conv1``, ``relu1`` and ``fc1``.
        loader: Iterable of ``(inputs, labels)`` batches.
        weight: Optional integer FC weight matrix (classes x features).
        bias: Optional integer FC bias vector. If either ``weight`` or ``bias``
            is missing, both are taken from ``model.fc1``.
        percentile: Percentile of the true-class scores used as threshold.
        blank_class: Class whose threshold is always 0.

    Returns:
        Dict mapping class index to an int threshold. Classes with fewer than
        5 collected scores get 0.
    """
    model.eval()
    per_class_scores = {i: [] for i in range(NUM_CLASSES)}

    # The fallback does not depend on the batch, so resolve it once.
    if weight is None or bias is None:
        weight = model.fc1.weight.detach().cpu().numpy().astype(np.int8)
        bias = model.fc1.bias.detach().cpu().numpy().astype(np.int64)
    # Accumulate in int64: summing thousands of int32 products can exceed the
    # int32 range and would wrap silently before the int64 bias is added.
    weight_t = weight.T.astype(np.int64)

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            out = model.relu1(model.conv1(x))
            fc_input = out.reshape(out.size(0), -1).cpu().numpy().astype(np.int64)

            result = fc_input @ weight_t + bias  # Quantized forward

            for score_vec, label in zip(result, y.cpu().numpy()):
                label = int(label)
                per_class_scores[label].append(score_vec[label])  # True class score only

    # Build per-class threshold dictionary
    thresholds = {}
    for cls in range(NUM_CLASSES):
        scores = per_class_scores[cls]
        if cls == blank_class or len(scores) < 5:
            # Blank always uses 0; so do classes with too few samples.
            thresholds[cls] = 0
        else:
            thresholds[cls] = int(np.percentile(scores, percentile))

    print(f"\n📏 Per-Class Quantized Thresholds (percentile={percentile}):")
    for c in range(NUM_CLASSES):
        print(f"  Class {c}: {thresholds[c]}")
    return thresholds

def compute_best_thresholds(model, loader, weight=None, bias=None, blank_class=10):
    """Pick, per class, the score threshold that maximises the F1 score.

    Scores come from an integer matrix product of the conv+ReLU output with
    the FC weights and bias. For each class every distinct score is tried as a
    threshold ("predict this class if score >= t") and the one with the best
    F1 against "sample truly belongs to this class" is kept. A score histogram
    is plotted for every class along the way.

    Args:
        model: Model exposing ``conv1``, ``relu1`` and ``fc1``.
        loader: Iterable of ``(inputs, labels)`` batches.
        weight: Optional integer FC weight matrix (classes x features).
        bias: Optional integer FC bias vector. If either ``weight`` or ``bias``
            is missing, both are taken from ``model.fc1``.
        blank_class: Class whose threshold is always 0.

    Returns:
        Dict mapping class index to an int threshold. Classes with fewer than
        10 scores, or for which no threshold reaches F1 > 0, get 0.
    """
    model.eval()
    all_scores = []  # Each item: (class, score, is_correct)

    # The fallback does not depend on the batch, so resolve it once.
    if weight is None or bias is None:
        weight = model.fc1.weight.detach().cpu().numpy().astype(np.int8)
        bias = model.fc1.bias.detach().cpu().numpy().astype(np.int64)
    # Accumulate in int64: summing thousands of int32 products can exceed the
    # int32 range and would wrap silently before the int64 bias is added.
    weight_t = weight.T.astype(np.int64)

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            out = model.relu1(model.conv1(x))
            fc_input = out.reshape(out.size(0), -1).cpu().numpy().astype(np.int64)

            result = fc_input @ weight_t + bias
            y_true = y.cpu().numpy()

            for scores, true in zip(result, y_true):
                for cls in range(NUM_CLASSES):
                    all_scores.append((cls, scores[cls], int(cls == true)))

    # Now: group scores by class
    thresholds = {}
    for cls in range(NUM_CLASSES):
        plot_score_distribution(all_scores, cls)
        if cls == blank_class:
            thresholds[cls] = 0
            continue

        cls_scores = [(score, correct) for c, score, correct in all_scores if c == cls]
        if len(cls_scores) < 10:
            thresholds[cls] = 0
            continue

        scores, labels = zip(*cls_scores)
        scores = np.array(scores)
        labels = np.array(labels)

        # Sweep thresholds and compute F1
        best_thresh, best_f1 = 0, 0
        for t in np.unique(scores):
            preds = (scores >= t).astype(int)
            f1 = f1_score(labels, preds, zero_division=0)
            if f1 > best_f1:
                best_f1 = f1
                best_thresh = t

        thresholds[cls] = int(best_thresh)

    print("\n📏 Calibrated Per-Class Thresholds (F1-optimized):")
    for c in range(NUM_CLASSES):
        print(f"  Class {c}: {thresholds[c]}")
    return thresholds


def plot_score_distribution(all_scores, cls_id):
    """Plot two overlaid score histograms for one class.

    ``all_scores`` holds ``(class, score, is_correct)`` tuples. Entries of
    class ``cls_id`` are split by their flag (1 vs. 0) and drawn as two
    50-bin histograms.
    """
    flagged_one, flagged_zero = [], []
    for entry_cls, score, flag in all_scores:
        if entry_cls != cls_id:
            continue
        if flag == 1:
            flagged_one.append(score)
        elif flag == 0:
            flagged_zero.append(score)

    for values, legend in ((flagged_one, "True Positives"),
                           (flagged_zero, "False Positives")):
        plt.hist(values, bins=50, alpha=0.6, label=legend)
    plt.title(f"Class {cls_id} Score Distribution")
    plt.xlabel("Score")
    plt.ylabel("Count")
    plt.legend()
    plt.show()


def evaluate_with_metrics(model, loader, thresholds, blank_class=10):
    """Evaluate ``model`` with per-class thresholds and print metrics.

    For every sample, class scores below their threshold are discarded; the
    prediction is the arg-max of the remaining scores, or ``blank_class`` when
    none remains. Prints an accuracy that ignores samples where both truth and
    prediction are blank, a digit-only accuracy, and precision / recall / F1
    per class.

    Args:
        model: Model called as ``model(x)`` to obtain class scores.
        loader: Iterable of ``(inputs, labels)`` batches.
        thresholds: Indexable by class index, giving the minimum accepted score.
        blank_class: Class id used for "no digit".

    Returns:
        ``(preds, labels)`` lists covering every sample of ``loader``.
    """
    model.eval()
    preds, labels = [], []
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in loader:
            out = model(x.to(device)).cpu().numpy()
            y = y.cpu().numpy()

            for score, true_label in zip(out, y):
                filtered = [s if s >= thresholds[i] else -np.inf for i, s in enumerate(score)]
                pred = np.argmax(filtered) if np.any(np.isfinite(filtered)) else blank_class

                preds.append(pred)
                labels.append(true_label)

                # Accuracy (exclude true blank==blank cases)
                if not (true_label == blank_class and pred == blank_class):
                    total += 1
                    if pred == true_label:
                        correct += 1

    acc = 100 * correct / total if total > 0 else 0
    print(f"\n📏 Adjusted Accuracy (excluding true blanks): {acc:.2f}%")

    # --- Digit-only accuracy ---
    digit_pairs = [(p, t) for p, t in zip(preds, labels) if t != blank_class]
    digit_correct = sum(int(p == t) for p, t in digit_pairs)
    digit_total = len(digit_pairs)
    digit_acc = 100 * digit_correct / digit_total if digit_total else 0
    print(f"🔢 Digit-Only Accuracy: {digit_acc:.2f}%")

    # --- Precision / Recall / F1 ---
    class_order = [*range(10), blank_class]
    precision, recall, f1, support = precision_recall_fscore_support(
        labels, preds, labels=class_order, zero_division=0
    )

    print("\n📊 Precision / Recall / F1 per class:")
    for cls in range(10):
        print(f"  Class {cls}: P={precision[cls]:.2f}, R={recall[cls]:.2f}, F1={f1[cls]:.2f}, Support={support[cls]}")
    # The metric arrays follow `class_order`, so the blank row is the last
    # one regardless of the numeric value of `blank_class`.
    blank_row = len(class_order) - 1
    print(f"  Blank:    P={precision[blank_row]:.2f}, R={recall[blank_row]:.2f}, F1={f1[blank_row]:.2f}, Support={support[blank_row]}")

    return preds, labels



def save_thresholds(thresholds, filename="thresholds.json"):
    """Write per-class thresholds to a JSON file as ``{"class_<i>": value}``.

    Args:
        thresholds: Either a mapping ``class index -> threshold`` or a
            sequence whose position is the class index.
        filename: Path of the JSON file to write.
    """
    # Enumerating a dict yields its keys, which would store the class index
    # instead of the threshold; mappings therefore go through .items().
    if hasattr(thresholds, "items"):
        pairs = thresholds.items()
    else:
        pairs = enumerate(thresholds)
    data = {f"class_{cls}": int(thr) for cls, thr in pairs}
    with open(filename, "w") as f:
        json.dump(data, f, indent=2)
    print(f"✅ Thresholds saved to {filename}")


def visualize_samples(dataset, title):
    """Show up to ten example windows per class in an 11 x 10 grid.

    ``dataset`` is iterated as ``(image, label)`` pairs until every one of the
    11 classes has ten examples or the data runs out. Row ``i`` of the grid
    belongs to class ``i`` (row 10 is the blank class).
    """
    class_buckets = {class_id: [] for class_id in range(11)}  # 0-9 + blank (10)
    for img, label in dataset:
        bucket = class_buckets[label]
        if len(bucket) < 10:
            bucket.append(img)
        if all(len(collected) == 10 for collected in class_buckets.values()):
            break

    fig, axes = plt.subplots(11, 10, figsize=(15, 15))
    fig.suptitle(title, fontsize=16)

    for class_id in range(11):
        examples = class_buckets[class_id]
        for col in range(10):
            cell = axes[class_id, col]
            if col < len(examples):
                cell.imshow(examples[col][0].numpy(), cmap='gray', vmin=-128, vmax=127)
            cell.axis('off')
        row_name = f"{class_id if class_id < 10 else 'blank'}"
        axes[class_id, 0].set_ylabel(row_name, fontsize=12, rotation=0, labelpad=20)

    plt.tight_layout(rect=[0, 0.03, 1, 0.97])
    plt.show()

# Function to plot confusion matrix
def plot_confusion_matrix(y_true, y_pred,class_names, title):
    """Draw the confusion matrix of ``y_pred`` vs. ``y_true`` as a heatmap.

    Class ids ``0 .. len(class_names) - 1`` form the rows (actual) and columns
    (predicted); ``class_names`` supplies the tick labels.
    """
    class_ids = list(range(len(class_names)))
    matrix = confusion_matrix(y_true, y_pred, labels=class_ids)

    plt.figure(figsize=(8, 6))
    sns.heatmap(matrix, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title(title)
    plt.show()
# ====== Main ======
if __name__ == "__main__":
    
    # neg_ratio=0.0: only windows with no overlap with any labelled span
    # are kept as blank samples.
    train_set = SlidingWindowDataset("train_dataset.pkl",pos_ratio=0.0, neg_ratio=0.0,limit_negatives=False)
    val_set = SlidingWindowDataset("val_dataset.pkl",pos_ratio=0.0, neg_ratio=0.0,limit_negatives=False)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=32)
    class_names = [str(i) for i in range(10)] + ["blank"]

    visualize_samples(train_set, "🎓 Train Set Samples per Class")
    visualize_samples(val_set, "🧪 Validation Set Samples per Class")

    model = CNNModel().to(device)
    model, preds, labels = train_model(model, train_loader, val_loader,l1_lambda=5e-5, l2_lambda=1e-2,label_smoothing=0.18)


    # Compute and plot Confusion Matrix Before Quantization
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        # NOTE: `labels` is reused as the loop variable here and shadows the
        # list returned by train_model above.
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)  # Move to GPU
            outputs = model(images)
            all_labels.extend(labels.cpu().numpy())

            _, predicted = torch.max(outputs, 1)  # Use argmax to select highest FC score
            all_preds.extend(predicted.cpu().numpy())
    plot_confusion_matrix(all_labels, all_preds, class_names,"Validation Confusion Matrix Before Quantization")

    quantized_state_dict = {}

    for name, param in model.named_parameters():
        # Only "fc1.bias" sets is_fc_layer, so fc1.weight is quantized through
        # quantize_tensor's default abs-max branch, not its percentile branch.
        is_fc_layer = "fc1.bias" in name  # ✅ Detect if this is an FC bias
        is_bias = "bias" in name  # ✅ Detect if this is a bias
        k = model.conv1.out_channels  # ✅ Get k value from Conv2D

        # ✅ Keep everything the same, only adjust FC bias scaling
        quantized_state_dict[name] = quantize_tensor(param, is_bias=is_bias, is_fc_layer=is_fc_layer, k=k)

    # ✅ Save Quantized Model
    torch.save(quantized_state_dict, "quantized_model.pth") 
    print("✅ Quantized model saved successfully with fixed FC bias scaling!")


    # Put the integer-valued tensors back into the model's parameters so the
    # evaluation below runs with the quantized values.
    model.load_state_dict(quantized_state_dict, strict=False)

    # Post-Quantization Validation
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)  # Move to GPU
            outputs = model(images)
            all_labels.extend(labels.cpu().numpy())

            _, predicted = torch.max(outputs, 1)  # Use argmax to select highest FC score
            all_preds.extend(predicted.cpu().numpy())
    plot_confusion_matrix(all_labels, all_preds, class_names,"Validation Confusion Matrix After Quantization without thresholds")
    # Thresholds are calibrated on the validation loader and then evaluated
    # on that same loader.
    thresholds = compute_best_thresholds(model, val_loader)
    q_preds, q_labels = evaluate_with_metrics(model, val_loader, thresholds)
    plot_confusion_matrix(q_labels,q_preds, class_names, "Validation Confusion Matrix After Quantization")

    save_thresholds(thresholds, "thresholds.json")
    print("✅ Per-Class Thresholds:")
    for cls, t in thresholds.items():
        print(f"  Class {cls}: {t}")


## Part 3: Export .mif for FPGA

In [None]:
# ========= Binary & Hex Conversion =========
def int_to_bin(value, bit_width):
    """Return ``value`` as a two's-complement bit string of exactly ``bit_width`` digits.

    Negative values are encoded in two's complement. Values outside the
    representable range wrap around (only the low ``bit_width`` bits are
    kept), so the result never carries a sign character and never exceeds the
    requested width — both would corrupt a fixed-width memory word.

    Args:
        value: Integer (Python or NumPy scalar) to encode.
        bit_width: Number of binary digits to produce.
    """
    # int() also normalises NumPy scalars; masking yields the two's
    # complement representation for negatives.
    masked = int(value) & ((1 << bit_width) - 1)
    return format(masked, f'0{bit_width}b')

def int_to_hex(value, bit_width):
    """Return ``value`` as an upper-case two's-complement hex string.

    The string has ``ceil(bit_width / 4)`` digits, so widths that are not a
    multiple of four are still zero-padded to a constant length. Negative
    values are encoded in two's complement and out-of-range values wrap around
    (only the low ``bit_width`` bits are kept), which keeps every field the
    same width.

    Args:
        value: Integer (Python or NumPy scalar) to encode.
        bit_width: Width of the value in bits.
    """
    # int() also normalises NumPy scalars; masking yields the two's
    # complement representation for negatives.
    masked = int(value) & ((1 << bit_width) - 1)
    hex_digits = (bit_width + 3) // 4  # round up to whole hex digits
    return format(masked, f'0{hex_digits}X')

# ========= Save MIF File =========
def save_mif(filename, data, word_size, values_per_line, format_type="HEX"):
    """Write a 2-D integer array as a memory initialization (.mif) file.

    Each row of ``data`` becomes one memory word: its values are encoded with
    ``word_size`` bits each and concatenated. ``format_type`` ``"BIN"`` selects
    binary digits, anything else hex digits. The header declares
    ``WIDTH = word_size * values_per_line`` and ``DEPTH = data.shape[0]``;
    addresses are written in hex.
    """
    depth = data.shape[0]

    def pack_row(row):
        encode = int_to_bin if format_type == "BIN" else int_to_hex
        return "".join(encode(int(val), word_size) for val in row)

    header = [
        f"WIDTH={word_size * values_per_line};\n",
        f"DEPTH={depth};\n",
        "ADDRESS_RADIX=HEX;\n",
        f"DATA_RADIX={format_type};\n",
        "CONTENT BEGIN\n",
    ]
    with open(filename, "w") as f:
        f.writelines(header)
        for address in range(depth):
            f.write(f"{address:X} : {pack_row(data[address])};\n")
        f.write("END;\n")

# ========= Export Conv2D Weights =========
def export_conv2d_weights_to_mif(model, filename="CON_W.mif"):
    """Export the ``conv1`` kernels to a HEX .mif file, two filters per word.

    Every filter is flattened row by row and cast to int8; consecutive filters
    are paired into one 18-value line. An odd filter count is completed with
    nine zeros.
    """
    kernels = model.state_dict()["conv1.weight"].numpy().squeeze(1)  # Shape: (K, 3, 3)

    # Row-major flattening keeps the order W00 .. W22 inside each filter.
    flat_filters = [kernel.flatten().astype(np.int8) for kernel in kernels]
    filter_count = len(flat_filters)

    rows = []
    for first in range(0, filter_count, 2):
        second = first + 1
        if second < filter_count:
            partner = flat_filters[second]
        else:
            partner = np.zeros(9, dtype=np.int8)
        rows.append(np.concatenate((flat_filters[first], partner)))  # 18 values per line

    data = np.array(rows, dtype=np.int8)
    save_mif(filename, data, word_size=8, values_per_line=18, format_type="HEX")
    print(f"✅ Conv2D weights exported to '{filename}' as HEX")

# ========= Export Biases =========
def export_biases_to_mif(model):
    """Export the conv and FC biases to ``CON_B.mif`` and ``FCM_B.mif``.

    Conv biases are written as 16-bit HEX values, two per line; FC biases as
    45-bit BIN values, one per line.
    """
    state = model.state_dict()

    conv_pairs = state["conv1.bias"].numpy().astype(np.int16).reshape(-1, 2)
    save_mif("CON_B.mif", conv_pairs, word_size=16, values_per_line=2, format_type="HEX")

    fc_column = state["fc1.bias"].numpy().astype(np.int64).reshape(-1, 1)
    save_mif("FCM_B.mif", fc_column, word_size=45, values_per_line=1, format_type="BIN")
    print("✅ Conv2D & FC biases saved.")

# ========= Export FC Weights =========
def export_fc_weights_to_mif(model):
    """Export the ``fc1`` weights to ``FCM_W.mif`` as int8 HEX, 128 per line.

    The weight matrix is flattened and zero-padded at the end up to the next
    multiple of 128 before being split into lines.
    """
    flat = model.state_dict()["fc1.weight"].numpy().astype(np.int8).flatten()
    # Number of zeros needed to reach a multiple of 128 values.
    shortfall = (-len(flat)) % 128
    lines = np.pad(flat, (0, shortfall), 'constant').reshape(-1, 128)
    save_mif("FCM_W.mif", lines, word_size=8, values_per_line=128, format_type="HEX")
    print("✅ FC weights exported to FCM_W.mif")

# ========= Export ReLU Default Output =========
def export_relu_default_output_dynamic(model):
    """Export conv+ReLU output for a constant-127 window to ``REL_O.mif``.

    A ``1 x 1 x 18 x 14`` input filled with 127 is passed through ``conv1`` and
    ReLU; the result is cast to int32 and written as 19-bit binary values, the
    maps of two consecutive filters per line.
    """
    blank_window = torch.full((1, 1, 18, 14), 127, dtype=torch.float32)
    conv_output = model.conv1(blank_window)
    relu_output = torch.relu(conv_output).detach().numpy().astype(np.int32)

    k = conv_output.shape[1]
    assert k % 2 == 0, "Number of filters (k) must be even."

    word_size = 19
    values_per_line = 384  # 2 filters' worth
    total_width = word_size * values_per_line
    num_lines = k // 2

    header = (
        f"WIDTH={total_width};\n"
        f"DEPTH={num_lines};\n"
        "ADDRESS_RADIX=HEX;\n"
        "DATA_RADIX=BIN;\n"
        "CONTENT BEGIN\n"
    )
    with open("REL_O.mif", "w") as f:
        f.write(header)

        for first_filter in range(0, k, 2):
            # Feature maps of two neighbouring filters, flattened in order.
            pair_values = relu_output[0, first_filter:first_filter + 2, :, :].flatten()
            bits = "".join(int_to_bin(v, word_size) for v in pair_values)
            f.write(f"{first_filter // 2:X} : {bits};\n")

        f.write("END;\n")

    print(f"✅ REL_O.mif exported successfully with WIDTH={total_width}, DEPTH={num_lines}, Filters={k}")

# ========= Execute All =========
# Rebuild the network on the CPU and fill it with the saved quantized tensors.
model = CNNModel(num_classes=11, k=K_size)
quantized_weights = torch.load("quantized_model.pth", map_location="cpu")
model.load_state_dict(quantized_weights)

# Write every memory image, in the same order as before.
for export_step in (
    export_conv2d_weights_to_mif,
    export_biases_to_mif,
    export_fc_weights_to_mif,
    export_relu_default_output_dynamic,
):
    export_step(model)


## Part 4: Test and export result

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

def _write_fc_feature_maps(fc_array, path):
    """Write each filter's feature map to *path* as rows of fixed-width floats."""
    with open(path, "w") as f:
        for k, fmap in enumerate(fc_array):
            f.write(f"Filter {k}:\n")
            for row in fmap:
                f.write(" ".join(f"{v:5.1f}" for v in row) + "\n")
            f.write("\n")


def _plot_debug_windows(debug_windows, debug_scores):
    """Show every recorded window next to a bar chart of its class scores."""
    # Imported lazily so the detector itself does not need matplotlib.
    import matplotlib.pyplot as plt

    num_windows = len(debug_windows)
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(
        num_windows, 2, figsize=(10, num_windows * 2), squeeze=False
    )
    for i, (win, scores) in enumerate(zip(debug_windows, debug_scores)):
        ax_img, ax_bar = axes[i]
        ax_img.imshow(win, cmap='gray', vmin=0, vmax=255)
        ax_img.axis('off')
        ax_img.set_title(f"Step {i}")

        bars = ax_bar.bar(range(len(scores)), scores)
        ax_bar.set_ylim(min(scores) - 1, max(scores) + 1)
        ax_bar.set_title("Scores")
        for bar, score in zip(bars, scores):
            ax_bar.text(bar.get_x() + bar.get_width() / 2.0, bar.get_height(),
                        f"{score:.0f}", ha='center', va='bottom', fontsize=8, rotation=90)

    plt.tight_layout()
    plt.show()


def emulate_fpga_sliding_detection(
    image, model, thresholds,
    input_width=14, input_height=18, keep_cols=2,
    fc_height=16, fc_width=12,
    blank_class=10, debug=False,
    save_fc=False, save_dir="fc_debug",
    device="cpu"
):
    """Slide a column-fed window over *image* and collect class detections.

    The image is vertically centred in a strip of ``input_height`` rows
    (padding value 255). Columns are then shifted one at a time into a
    ``input_height x input_width`` buffer that starts out filled with 255.
    After every shift the buffer is offset by -128, converted to int8 and run
    through ``model.conv1`` -> ``model.relu1`` -> ``model.fc1``. When the
    arg-max class is not ``blank_class`` and its score reaches the class
    threshold, a detection is recorded and the buffer is cleared to 255
    except for its last ``keep_cols`` columns.

    Args:
        image: 2-D grayscale array no taller than ``input_height``.
        model: Module exposing ``conv1``, ``relu1`` and ``fc1``; it is put in
            eval mode and moved to ``device``.
        thresholds: Mapping from class index to the minimum score required
            for a detection; classes missing from the mapping use 0.
        input_width: Width of the sliding buffer in columns.
        input_height: Height of the sliding buffer in rows.
        keep_cols: Number of trailing buffer columns kept after a detection.
        fc_height: Rows of the ReLU output fed to ``fc1``.
        fc_width: Columns of the ReLU output fed to ``fc1``.
        blank_class: Class index that never produces a detection.
        debug: If true, plot each window with its scores once finished.
        save_fc: If true, dump the FC input of every step to a text file.
        save_dir: Directory receiving the ``fc_step_NNN.txt`` dumps.
        device: Torch device used for inference.

    Returns:
        List of ``(class_index, start_col, end_col)`` tuples, where
        ``end_col`` is the number of columns consumed so far and
        ``start_col`` is ``end_col - input_width``.

    Raises:
        ValueError: If *image* is not 2-D or is taller than ``input_height``.
    """
    import os

    image = np.asarray(image)
    if image.ndim != 2:
        raise ValueError(
            f"image must be a 2-D grayscale array, got shape {image.shape}"
        )
    h = image.shape[0]
    if h > input_height:
        raise ValueError(
            f"image height {h} exceeds input_height {input_height}"
        )

    if save_fc:
        os.makedirs(save_dir, exist_ok=True)

    pad_top = (input_height - h) // 2
    pad_bottom = input_height - h - pad_top
    padded_img = np.pad(image, ((pad_top, pad_bottom), (0, 0)), constant_values=255)

    buffer = np.full((input_height, input_width), 255, dtype=np.uint8)
    detections = []
    debug_windows = []
    debug_scores = []

    # First buffer column that survives a reset. Computed explicitly because
    # a negative-index slice with keep_cols == 0 would select every column.
    keep_from = max(input_width - keep_cols, 0)

    model.eval()
    model = model.to(device)

    for step in range(padded_img.shape[1]):
        # Shift the buffer left by one column and append the next column.
        buffer[:, :-1] = buffer[:, 1:]
        buffer[:, -1] = padded_img[:, step]
        x = step + 1  # number of columns consumed so far

        # Re-centre unsigned pixels to the signed int8 range used by the model.
        crop = np.clip(buffer.astype(np.int16) - 128, -128, 127).astype(np.int8)
        window_tensor = (
            torch.tensor(crop, dtype=torch.int8)
            .unsqueeze(0).unsqueeze(0)
            .to(torch.float32).to(device)
        )

        # Conv+ReLU output is both the FC input and the optional debug dump.
        with torch.no_grad():
            relu_out = model.relu1(model.conv1(window_tensor))  # [1, K, H, W]
            fc_input = relu_out[:, :, :fc_height, :fc_width]
            logits = model.fc1(fc_input.reshape(1, -1)).squeeze(0).cpu().numpy()

        # Plain int so dictionary lookups and returned tuples are not numpy scalars.
        predicted_class = int(np.argmax(logits))
        threshold = thresholds.get(predicted_class, 0)

        if save_fc:
            _write_fc_feature_maps(
                fc_input.squeeze(0).cpu().numpy(),
                os.path.join(save_dir, f"fc_step_{step:03d}.txt"),
            )

        if debug:
            debug_windows.append(buffer.copy())
            debug_scores.append(logits.copy())

        # Detection condition (do not reset for blank!)
        if predicted_class != blank_class and logits[predicted_class] >= threshold:
            detections.append((predicted_class, x - input_width, x))
            kept = buffer[:, keep_from:].copy()
            buffer.fill(255)
            buffer[:, keep_from:] = kept

    if debug and debug_windows:
        _plot_debug_windows(debug_windows, debug_scores)

    return detections



# Example usage: run the FPGA sliding-window emulation on one sample strip.
if __name__ == "__main__":
    # This example relies on names created earlier in the notebook:
    #   model      - a CNNModel instance
    #   device     - the torch device to run on
    #   thresholds - dict mapping class index to its minimum detection score
    #                (classes missing from the dict fall back to 0)

    my_image = plt.imread("643_object_0.png")
    if my_image.max() <= 1.0:
        # Image values are in [0, 1] (as matplotlib returns for PNG files);
        # round before casting so v/255*255 cannot truncate to v-1.
        my_image = np.rint(my_image * 255).astype(np.uint8)
    # NOTE(review): the detector expects a 2-D grayscale array; an RGB/RGBA
    # file would need converting to a single channel first.

    model = model.to(device)
    # map_location lets a checkpoint saved on another device load onto `device`.
    model.load_state_dict(torch.load("quantized_model.pth", map_location=device))
    model.eval()

    detections = emulate_fpga_sliding_detection(
        image=my_image,
        model=model,
        thresholds=thresholds,
        debug=True,
        save_fc=True,
        save_dir="fc_debug",
        device=device
    )

    # Each entry is (class_index, start_col, end_col).
    print(f"{detections}")
