In [1]:
!pip install pandas #csv
import pandas as pd
!pip install gradio #gui
import gradio as gr

Collecting gradio
  Downloading gradio-5.14.0-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.8-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.7.0 (from gradio)
  Downloading gradio_client-1.7.0-py3-none-any.whl.metadata (7.1 kB)
Collecting markupsafe~=2.0 (from gradio)
  Downloading MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.9.3 (from gradio)
  Downloading ruff-0.9.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.meta

**Function to read a CSV file**

In [2]:
def read_csv(file):
    """Load an uploaded CSV file into a DataFrame.

    Parameters
    ----------
    file : str | os.PathLike | object with a ``.name`` attribute
        Path to the CSV file, or a wrapper object (such as the value Gradio
        passes for an upload in some versions) whose ``.name`` holds the path.

    Returns
    -------
    pandas.DataFrame
        The parsed CSV.

    Raises
    ------
    ValueError
        If ``file`` is None (nothing was uploaded).
    """
    import os

    if file is None:
        raise ValueError("No CSV file was provided.")
    #plain paths are used directly (a pathlib.Path's .name would be just the base filename);
    #anything else is expected to expose the path as .name
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    return pd.read_csv(path)

**Scatter plot class (renders the data as an SVG)**

In [3]:
class ScatterPlot:
    """Render a pandas DataFrame as a standalone SVG scatter plot.

    Two numeric columns give the point positions; a third column selects the
    marker shape (circle, rectangle or triangle polygon, assigned round-robin
    in order of first appearance of each category value).

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot.
    x, y : str
        Column names plotted on the horizontal and vertical axis.
    category : str
        Column whose values select the marker shape.
    n_ticks : int, optional
        Number of equal intervals each axis is split into; ``n_ticks + 1``
        tick marks are drawn, running from the column minimum to its maximum.
    """

    def __init__(self, df, x, y, category, n_ticks=8):
        self.df = df
        self.x = x
        self.y = y
        self.category = category
        self.n_ticks = n_ticks

        self.svg_width, self.svg_height = 500, 500
        #space left around the plot area for ticks, labels and the legend
        self.svg_margin = 80

        self.shapes = ["circle", "rectangle", "polygon"]
        #assign shapes round-robin to the categories in order of first appearance
        self.category2shape = {cat: self.shapes[i % len(self.shapes)] for i, cat in enumerate(self.df[self.category].unique())}

        self.svg_elements = []

    #create svg with the plot elements -> called after csv has been uploaded
    def generate_svg(self):
        """Build the SVG markup for the plot, followed by the click-interaction script.

        Returns
        -------
        str
            An ``<svg>`` element with grid, axes, ticks, points and legend,
            followed by a ``<script>`` block.
        """
        #start from an empty element list so calling this twice does not draw everything twice
        self.svg_elements = []
        self.calculate_min_max()
        self.draw_grid()
        self.draw_axes()
        self.draw_ticks()
        self.plot_points()
        self.create_legend()

        svg_content = f'<svg width="{self.svg_width}" height="{self.svg_height}" style="background-color: white;">' + "".join(self.svg_elements) + '\n</svg>'

        #js script to implement click interaction with the svg and its elements
        #NOTE(review): the interaction is still a stub, and an HTML output component may not
        #execute inline <script> tags at all -- confirm with the Gradio version in use
        svg_content += """
        <script>

        //variables to keep track of the selected point and the nearest points
        let selected_origin = null;
        let nearest_points = [];

        //function to switch origin to selected point on left mouse click (TODO: recentering)
        function switch_origin(event) {
          const clicked_point = event.target;
          //SVG element tag names: rectangles are drawn as 'rect', not 'rectangle'
          if (clicked_point.tagName === 'circle' || clicked_point.tagName === 'rect' || clicked_point.tagName === 'polygon') {
            const point_id = clicked_point.getAttribute('id');
          }
        }

        //TODO: function to get the 5 nearest points

        //function to compute the euclidean distance between 2 points
        function euclidean_dist(p1, p2) {
          return Math.sqrt(Math.pow(p1[0] - p2[0], 2) + Math.pow(p1[1] - p2[1], 2));
        }

        </script>
        """

        return svg_content

    def calculate_min_max(self):
        """Store the data range of the x and y columns (pandas skips NaN values)."""
        self.x_min, self.x_max = self.df[self.x].min(), self.df[self.x].max()
        self.y_min, self.y_max = self.df[self.y].min(), self.df[self.y].max()

    def _tick_position(self, i):
        """Return the SVG (x, y) coordinates of the i-th tick along each axis.

        Ticks split the plot area (the canvas minus the margin on both sides)
        into ``n_ticks`` equal intervals, which is the same span ``normalize``
        maps the data onto, so tick labels line up with the plotted points.
        """
        plot_width = self.svg_width - 2 * self.svg_margin
        plot_height = self.svg_height - 2 * self.svg_margin
        x_tick = self.svg_margin + i * plot_width / self.n_ticks
        y_tick = self.svg_margin + i * plot_height / self.n_ticks
        return x_tick, y_tick

    def draw_axes(self):
        """Draw the x and y axis lines along the plot margins, titled with the column names."""
        from html import escape

        left, right = self.svg_margin, self.svg_width - self.svg_margin
        top, bottom = self.svg_margin, self.svg_height - self.svg_margin

        #draw axes along the margin
        self.svg_elements.append(f'<line x1="{left}" y1="{bottom}" x2="{right}" y2="{bottom}" stroke="black"/>')
        self.svg_elements.append(f'<line x1="{left}" y1="{top}" x2="{left}" y2="{bottom}" stroke="black"/>')

        #axis titles: the plotted columns are chosen automatically, so tell the reader which ones they are
        x_title_x, x_title_y = (left + right) / 2, bottom + 40
        y_title_x, y_title_y = left - 60, (top + bottom) / 2
        self.svg_elements.append(f'<text x="{x_title_x}" y="{x_title_y}" font-size="12" text-anchor="middle">{escape(str(self.x))}</text>')
        self.svg_elements.append(f'<text x="{y_title_x}" y="{y_title_y}" font-size="12" text-anchor="middle" transform="rotate(-90 {y_title_x} {y_title_y})">{escape(str(self.y))}</text>')

    def draw_grid(self):
        """Draw dashed grid lines through the interior tick positions."""
        #i = 0 and i = n_ticks coincide with the axes / plot border, so only interior lines are drawn
        for i in range(1, self.n_ticks):
            x_tick, y_tick = self._tick_position(i)
            self.svg_elements.append(f'<line x1="{x_tick}" y1="{self.svg_margin}" x2="{x_tick}" y2="{self.svg_height - self.svg_margin}" stroke="lightgray" stroke-dasharray="2,2"/>')
            self.svg_elements.append(f'<line x1="{self.svg_margin}" y1="{y_tick}" x2="{self.svg_width - self.svg_margin}" y2="{y_tick}" stroke="lightgray" stroke-dasharray="2,2"/>')

    def draw_ticks(self):
        """Draw n_ticks + 1 tick marks and value labels on each axis, from min to max."""
        #tick mark length
        tick_length = 5

        for i in range(self.n_ticks + 1):
            x_tick, y_tick = self._tick_position(i)

            #x values grow left to right from x_min
            x_tick_value = self.x_min + (self.x_max - self.x_min) * i / self.n_ticks
            #SVG y grows downward, so the topmost tick (i = 0) carries y_max
            y_tick_value = self.y_min + (self.y_max - self.y_min) * (self.n_ticks - i) / self.n_ticks

            #append tick marks and tick values to the svg
            self.svg_elements.append(f'<line x1="{x_tick}" y1="{self.svg_height - self.svg_margin}" x2="{x_tick}" y2="{self.svg_height - self.svg_margin + tick_length}" stroke="black"/>')
            self.svg_elements.append(f'<text x="{x_tick}" y="{self.svg_height - self.svg_margin + 15}" font-size="10" text-anchor="middle">{round(x_tick_value, 2)}</text>')
            self.svg_elements.append(f'<line x1="{self.svg_margin - tick_length}" y1="{y_tick}" x2="{self.svg_margin}" y2="{y_tick}" stroke="black"/>')
            #y labels sit just left of the tick marks, leaving room for the rotated y-axis title
            self.svg_elements.append(f'<text x="{self.svg_margin - tick_length - 3}" y="{y_tick + 3}" font-size="10" text-anchor="end">{round(y_tick_value, 2)}</text>')

    def plot_points(self):
        """Draw one marker per row, shaped by its category and positioned within the plot area."""
        columns = (self.df[self.x], self.df[self.y], self.df[self.category])
        for x_value, y_value, category in zip(*columns):
            #a row with a missing coordinate cannot be placed; skip it instead of emitting "nan" attributes
            if pd.isna(x_value) or pd.isna(y_value):
                continue

            #map the data values onto the plot area (SVG y is flipped so larger values sit higher)
            x_pos = self.normalize(x_value, self.x_min, self.x_max, self.svg_margin, self.svg_width - self.svg_margin)
            y_pos = self.normalize(y_value, self.y_min, self.y_max, self.svg_height - self.svg_margin, self.svg_margin)
            #extract the category shape from the category name
            shape = self.category2shape[category]

            if shape == "circle":
                self.svg_elements.append(f'<circle cx="{x_pos}" cy="{y_pos}" r="5" fill="blue"/>')
            elif shape == "rectangle":
                self.svg_elements.append(f'<rect x="{x_pos-4}" y="{y_pos-4}" width="8" height="8" fill="blue"/>')
            elif shape == "polygon":
                self.svg_elements.append(f'<polygon points="{x_pos-5},{y_pos+5} {x_pos+5},{y_pos+5} {x_pos},{y_pos-5}" fill="blue"/>')

    def create_legend(self):
        """Draw one marker and label per category, stacked vertically in the right margin."""
        from html import escape

        #position legend column
        legend_x, legend_y = self.svg_width - 50, self.svg_margin

        for i, (category, shape) in enumerate(self.category2shape.items()):
            #vertical separation of object classes
            y_pos = legend_y + i * 20

            #draw shape and class
            if shape == "circle":
                self.svg_elements.append(f'<circle cx="{legend_x}" cy="{y_pos}" r="5" fill="blue"/>')
            elif shape == "rectangle":
                self.svg_elements.append(f'<rect x="{legend_x-5}" y="{y_pos-5}" width="10" height="10" fill="blue"/>')
            elif shape == "polygon":
                self.svg_elements.append(f'<polygon points="{legend_x-5},{y_pos+5} {legend_x+5},{y_pos+5} {legend_x},{y_pos-5}" fill="blue"/>')
            #category values come from the uploaded CSV; escape them so '<' or '&' cannot break or inject markup
            self.svg_elements.append(f'<text x="{legend_x + 15}" y="{y_pos + 3}" font-size="10">{escape(str(category))}</text>')

    #linearly map values to svg scope
    def normalize(self, value, data_min, data_max, svg_min, svg_max):
        """Linearly map ``value`` from [data_min, data_max] onto [svg_min, svg_max].

        When the data range is empty (a constant column or a single row) the
        value is placed at the midpoint instead of dividing by zero.
        """
        if data_max == data_min:
            return (svg_min + svg_max) / 2
        #y = m + x * k
        return svg_min + (value - data_min) * (svg_max - svg_min) / (data_max - data_min)

**Gradio callback: read the uploaded CSV and build the table and the scatter plot**

In [4]:
def upload_file(file):
    """Gradio callback: render the uploaded CSV as a scrollable table and a scatter plot.

    The first two numeric columns become the x and y axes and the first
    non-numeric column selects the marker shape.

    Parameters
    ----------
    file :
        The value Gradio passes for the file input; ``None`` when the user
        clicks submit without uploading anything.

    Returns
    -------
    tuple[str, str]
        HTML for a scrollable table of the data, and the SVG plot markup.

    Raises
    ------
    gr.Error
        When no file was uploaded or the CSV lacks the required columns, so
        the reason is shown to the user instead of a generic error.
    """
    if file is None:
        raise gr.Error("Please upload a CSV file before clicking 'Submit'.")

    #read CSV file (shared helper, so upload handling lives in one place)
    df = read_csv(file)

    #convert table to html object to enable scrolling
    df_html = df.to_html(index=False) #remove indexing from table
    scrollable_table = f'<div style="max-height: 300px; overflow:auto;">{df_html}</div>'

    #extract numerical columns for the axes
    numerical_columns = df.select_dtypes(include=["number"]).columns.tolist()
    if len(numerical_columns) < 2:
        raise gr.Error(f"The CSV needs at least two numeric columns for the axes; found {len(numerical_columns)}.")
    x_col, y_col = numerical_columns[:2]

    #extract category column for the category name and shape
    category_columns = df.select_dtypes(exclude=["number"]).columns.tolist()
    if not category_columns:
        raise gr.Error("The CSV needs at least one non-numeric column to use as the category.")
    category = category_columns[0]

    #build the scatter plot for the uploaded data
    scatter_plot = ScatterPlot(df, x_col, y_col, category)

    #generate a scatter plot svg from the scatter plot class
    svg_content = scatter_plot.generate_svg()

    return scrollable_table, svg_content

**Create and launch the Gradio interface**

In [5]:
#the uploaded CSV is the only input; restrict the file picker to .csv files
demo = gr.Interface(
    fn=upload_file,
    inputs=gr.File(file_types=[".csv"], label="CSV file"),
    #upload_file returns two HTML strings: the scrollable data table and the SVG scatter plot
    outputs=[gr.HTML(label="Data"), gr.HTML(label="Scatter plot")],
    title="Scatter plot",
    description="Upload a CSV file by dragging or clicking. Hit 'Submit' to show the scatter plot <br>Left click: Recenter with respect to the selected point <br>Right click: Show the 5 nearest points to the selected point <br>Click 'Clear' to upload a new CSV file"
)

demo.launch()

Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://f92640b748748ea31b.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


