In [None]:
def TrainSVM(path1, C_par, gamma_par):
    import numpy as np
    from sklearn import svm
    import matplotlib.pyplot as plt
    from sklearn.model_selection import train_test_split
    import matplotlib as mpl
    from scipy import io
    from scipy.io import savemat
    import joblib
    
#     path1 = 'C:/Users/Borod/Research/A. Matlab Programs/hBN project/programs2/Contrast2.mat'
    data1 = io.loadmat(path1)
    data1 = data1['AllCon']
    RGB_data = data1[0:3, :]
    label_data = data1[4, :]
    
    RGB_data = np.transpose(RGB_data)
    train_data, test_data, train_label, test_label = train_test_split(RGB_data, label_data, random_state=1, train_size = 0.7, test_size = 0.3)
    
    svm_model = svm.SVC(C=C_par, kernel='rbf', gamma=gamma_par, decision_function_shape='ovr')
    svm_model.fit(train_data, train_label)
    
#     print('TrainCase Score: ', svm_model.score(train_data, train_label))
#     print('TestCase Score: ', svm_model.score(test_data, test_label))
    ScoreDic = {"TrainScore": svm_model.score(train_data, train_label), "TestScore": svm_model.score(test_data, test_label)}
    savemat(path1+"Score.mat", ScoreDic)
    
    R_min, R_max = RGB_data[:,0].min(), RGB_data[:,0].max()
    G_min, G_max = RGB_data[:,1].min(), RGB_data[:,1].max()
    B_min, B_max = RGB_data[:,2].min(), RGB_data[:,2].max()
    R_ = np.linspace(R_min, R_max, 50)
    G_ = np.linspace(G_min, G_max, 50)
    B_ = np.linspace(B_min, B_max, 50)
    Rs, Gs, Bs = np.meshgrid(R_, G_, B_, indexing='ij')
    assert np.all(Rs[:,0,0] == R_)
    assert np.all(Gs[0,:,0] == G_)
    assert np.all(Bs[0,0,:] == B_)
    
    grid_test = np.stack((Rs.flat, Gs.flat, Bs.flat), axis=1)
    grid_hat = svm_model.predict(grid_test)
    grid_hat = grid_hat.reshape(Rs.shape)
    
    
    mdic1 = {"Rs": Rs, "Gs" : Gs, "Bs" : Bs, "grid_hat" : grid_hat, "RGB_data" : RGB_data, "label_data" : label_data}
#     C:/Users/Borod/Research/A. Matlab Programs/hBN project/programs3/
    savemat(path1 + "SVM_results.mat", mdic1)
    joblib.dump(svm_model, path1 + 'svm_model.m')

In [1]:
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import numpy as np
import pandas as pd
from sklearn import svm
import pickle

class SVMTrainerApp:
    def __init__(self, root):
        self.root = root
        self.root.title("hBN Thickness SVM Trainer")
        
        # Frame for image display
        self.image_frame = tk.Frame(root)
        self.image_frame.pack(side=tk.LEFT, padx=10, pady=10)
        
        # Canvas for displaying the image
        self.canvas = tk.Canvas(self.image_frame, width=800, height=600)
        self.canvas.pack()
        
        # Load image button
        self.btn_load_image = tk.Button(root, text="Load Image", command=self.load_image)
        self.btn_load_image.pack(pady=10)
        
        # Frame for thickness buttons
        self.thickness_frame = tk.Frame(root)
        self.thickness_frame.pack(side=tk.RIGHT, padx=10, pady=10)
        
        # Thickness buttons
        thickness_labels = ["0-5 nm", "5-10 nm", "10-15 nm", "15-20 nm", "20-25 nm", "25-30 nm", "30-40 nm",
                            "40-50 nm", "50-60 nm", "60-70 nm", "70-80 nm", "80-100 nm", "larger than 100", "Background"]
        self.thickness_values = [0, 5, 10, 15, 20, 25, 30, 40, 50, 60, 70, 80, 100, -1]  # -1 for background
        
        self.thickness_buttons = []
        for label, value in zip(thickness_labels, self.thickness_values):
            button = tk.Button(self.thickness_frame, text=label, command=lambda v=value: self.set_thickness(v))
            button.pack(pady=2)
            self.thickness_buttons.append(button)
        
        # Train button
        self.btn_train = tk.Button(root, text="Train SVM", command=self.train_svm)
        self.btn_train.pack(pady=20)
        
        # Variables to store image and selections
        self.image = None
        self.photo = None
        self.thickness = None
        self.box_coords = []
        self.data = []
        
        # Bind events for drawing box
        self.canvas.bind("<ButtonPress-1>", self.start_box)
        self.canvas.bind("<B1-Motion>", self.draw_box)
        self.canvas.bind("<ButtonRelease-1>", self.end_box)

    def load_image(self):
        """Load and display an image."""
        file_path = filedialog.askopenfilename(filetypes=[("Image files", "*.jpg *.png")])
        if file_path:
            self.image = Image.open(file_path).convert('RGB')
            self.photo = ImageTk.PhotoImage(self.image)
            self.canvas.create_image(0, 0, image=self.photo, anchor=tk.NW)
            self.canvas.config(scrollregion=self.canvas.bbox(tk.ALL))
    
    def set_thickness(self, thickness):
        """Set the current thickness label for drawing."""
        self.thickness = thickness
    
    def start_box(self, event):
        """Start drawing a box."""
        if self.image and self.thickness is not None:
            self.box_coords = [(event.x, event.y)]
    
    def draw_box(self, event):
        """Update the box as the mouse moves."""
        if self.image and self.thickness is not None and self.box_coords:
            x0, y0 = self.box_coords[0]
            x1, y1 = event.x, event.y
            self.canvas.delete("box")
            self.canvas.create_rectangle(x0, y0, x1, y1, outline="red", tag="box")
    
    def end_box(self, event):
        """Finish drawing the box and collect RGB data."""
        if self.image and self.thickness is not None and self.box_coords:
            x0, y0 = self.box_coords[0]
            x1, y1 = event.x, event.y
            box = (min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1))
            self.collect_rgb_data(box)
            self.canvas.delete("box")
            self.box_coords = []
    
    def collect_rgb_data(self, box):
        """Collect RGB data from the selected box and store it with the thickness label."""
        if self.image:
            cropped_image = self.image.crop(box)
            img_array = np.array(cropped_image)
            avg_rgb = np.mean(img_array, axis=(0, 1))
            self.data.append((*avg_rgb, self.thickness))
    
    def train_svm(self):
        """Train an SVM model on the collected data."""
        if not self.data:
            messagebox.showerror("Error", "No data collected for training.")
            return
        
        df = pd.DataFrame(self.data, columns=["R", "G", "B", "Thickness"])
        file_path = filedialog.asksaveasfilename(defaultextension=".csv", filetypes=[("CSV files", "*.csv")])
        if file_path:
            df.to_csv(file_path, index=False)
            messagebox.showinfo("Success", f"Training data saved to {file_path}.")
        
        # Train SVM model
        X = df[["R", "G", "B"]]
        y = df["Thickness"]
        model = svm.SVC()
        model.fit(X, y)
        
        # Save trained model
        model_path = filedialog.asksaveasfilename(defaultextension=".pkl", filetypes=[("Pickle files", "*.pkl")])
        if model_path:
            with open(model_path, "wb") as f:
                pickle.dump(model, f)
            messagebox.showinfo("Success", f"SVM model saved to {model_path}.")

if __name__ == "__main__":
    root = tk.Tk()
    app = SVMTrainerApp(root)
    root.mainloop()


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(
Exception in Tkinter callback
Traceback (most recent call last):
  File "c:\Users\bills\AppData\Local\Programs\Python\Python310\lib\tkinter\__init__.py", line 1921, in __call__
    return self.func(*args)
  File "C:\Users\bills\AppData\Local\Temp\ipykernel_21184\3074662469.py", line 117, in train_svm
    model.fit(X, y)
  File "C:\Users\bills\AppData\Roaming\Python\Python310\site-packages\sklearn\svm\_base.py", line 192, in fit
    X, y = self._validate_data(
  File "C:\Users\bills\AppData\Roaming\Python\Python310\site-packages\sklearn\base.py", line 584, in _validate_data
    X, y = check_X_y(X, y, **check_params)
  File "C:\Users\bills\AppData\Roaming\Python\Python310\site-packages\sklearn\utils\validation.py", line 1106, in check_X_y
    X = check_array(
  File "C:\Users\bills\AppData\Roaming\Python\Python310\site-packages\sklearn\utils\validation.py", line 921, in check_array
    _assert_all_finite(
  File "