# In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.neighbors import KNeighborsClassifier

# --- Shape Drawing Function ---
def draw_shape(shape, img_size=32):
    """Render a filled white shape on a black square canvas.

    All geometry is expressed as fractions of ``img_size`` (quarter, half and
    three-quarter points), so every shape scales with the canvas. For the
    default size of 32 those points are 8, 16 and 24 pixels.

    Args:
        shape: One of 'circle', 'square' or 'triangle'.
        img_size: Side length of the square canvas in pixels.

    Returns:
        A ``(img_size, img_size)`` uint8 array: background 0, shape pixels 255.

    Raises:
        ValueError: If ``shape`` is not one of the supported names. Silently
            returning a blank canvas would hide typos and corrupt a dataset.
    """
    img = np.zeros((img_size, img_size), dtype=np.uint8)
    lo, mid, hi = img_size // 4, img_size // 2, 3 * img_size // 4
    if shape == 'circle':
        cv2.circle(img, (mid, mid), lo, 255, -1)  # thickness -1 => filled
    elif shape == 'square':
        cv2.rectangle(img, (lo, lo), (hi, hi), 255, -1)
    elif shape == 'triangle':
        # Apex at top-centre, base along the three-quarter row.
        pts = np.array([[mid, lo], [lo, hi], [hi, hi]], np.int32)
        cv2.drawContours(img, [pts], 0, 255, -1)
    else:
        raise ValueError(
            f"unknown shape {shape!r}; expected 'circle', 'square' or 'triangle'"
        )
    return img

# --- Dataset Generation ---
# Each sample is a clean shape image plus zero-mean Gaussian pixel noise,
# flattened into a feature vector; y holds the index of the shape in `shapes`.
shapes = ['circle', 'square', 'triangle']
SAMPLES_PER_CLASS = 100
NOISE_SIGMA = 10
rng = np.random.default_rng(42)  # seeded so the dataset and results are reproducible

X, y = [], []
for label, shape in enumerate(shapes):
    clean = draw_shape(shape)  # deterministic, so render once per class
    for _ in range(SAMPLES_PER_CLASS):
        noise = rng.normal(0, NOISE_SIGMA, clean.shape)
        # Add in floating point and clip to the valid pixel range. Casting the
        # noise directly to uint8 is undefined for negative values and in
        # practice wraps them (e.g. -3 -> 253), turning roughly half of the
        # background into near-white pixels instead of Gaussian noise.
        noisy_img = np.clip(clean.astype(np.float64) + noise, 0, 255).astype(np.uint8)
        X.append(noisy_img.flatten())
        y.append(label)

X = np.array(X)
y = np.array(y)

# --- Train/Test Split ---
# Stratify on y so both splits keep the same class proportions as the full
# dataset; an unstratified random split can under-represent a class in the
# test set and skew per-class results.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# --- KNN Training and Accuracy Comparison ---
# Score each candidate k with 5-fold cross-validation on the TRAINING split
# only. Picking k by its accuracy on the test split would leak the test set
# into model selection and make the reported accuracy optimistically biased.
# accuracies[i] is the mean CV accuracy for k_values[i].
k_values = range(1, 11)
accuracies = [
    cross_val_score(KNeighborsClassifier(n_neighbors=k), X_train, y_train, cv=5).mean()
    for k in k_values
]

# --- Best k Evaluation ---
# np.argmax returns the first maximum, so ties resolve to the smallest k.
best_k = k_values[np.argmax(accuracies)]
print(f"Best accuracy: {max(accuracies):.2f} with k={best_k}")

# --- Final Confusion Matrix ---
# Refit at the chosen k on the whole training split, then evaluate once on
# the held-out test split.
final_model = KNeighborsClassifier(n_neighbors=best_k)
final_model.fit(X_train, y_train)
final_preds = final_model.predict(X_test)
print(f"Test accuracy with k={best_k}: {accuracy_score(y_test, final_preds):.2f}")
# Pin the label order so the matrix is always len(shapes) x len(shapes) and
# row/column i corresponds to shapes[i], even if a class is missing from this
# test split or from the predictions. Rows are true labels, columns predicted.
cm = confusion_matrix(y_test, final_preds, labels=list(range(len(shapes))))
print(f"Confusion Matrix (rows = true, cols = predicted; order: {', '.join(shapes)}):\n", cm)

# --- Plot Accuracy vs k ---
# Draw onto the current pyplot axes: one marker per evaluated k, joined by a line.
ax = plt.gca()
ax.plot(k_values, accuracies, marker='o', color='blue')
ax.set_title("KNN Accuracy vs K")
ax.set_xlabel("Number of Neighbors (k)")
ax.set_ylabel("Accuracy")
ax.grid(True)
plt.show()
