In [1]:
import numpy as np 
import matplotlib.pyplot as plt 
from adjustText import adjust_text

from PIL import Image
from pdb import set_trace as st 
import random
import json
import os 
from tqdm import tqdm
import string
from src.Aux import * 
from src.Triangles import Triangle
from src.Illustrate import * 
from src.text_in_bbox import *

In [2]:
# Directory holding the raw EMNIST "byclass" IDX files.
# Override with the EMNIST_DIR environment variable instead of editing this path.
emnist_loc = os.environ.get("EMNIST_DIR", "/home/yfrid/Desktop/stem-whiteboard/dataset/mnist/EMNIST/raw/")
emnist_images = read_idx_ubyte(os.path.join(emnist_loc, "emnist-byclass-train-images-idx3-ubyte"))
emnist_labels = read_idx_ubyte(os.path.join(emnist_loc, "emnist-byclass-train-labels-idx1-ubyte"))
# 62 characters: digits, then A-Z, then a-z.
# Presumably EMNIST_Handler maps byclass label i to emnist_chars[i] — TODO confirm.
emnist_chars = string.digits + string.ascii_uppercase + string.ascii_lowercase
EMNIST = EMNIST_Handler(emnist_images, emnist_labels, emnist_chars)

Reading image data: 697932 images of 28x28 pixels.
Reading label data: 697932 labels.


  self.model.load_state_dict(torch.load(weights_path, map_location=self.device))


In [3]:
# The lowercase Latin alphabet, 'a' through 'z'.
# A Greek alternative (α, β, γ, δ, θ) was sketched here earlier but is not in use.
lower_letters = "".join(chr(code) for code in range(ord("a"), ord("z") + 1))

def _problem_text(description):
    """Return the problem statement for ``description``.

    Concatenates, in order:
    - each known segment as ``"<mark> = <length> <unit>. "``;
    - each known angle as ``"∢<mark> = <value>°. "``;
    - every string in ``description["questions"]``, verbatim.
    """
    parts = [f'{segment["mark"]} = {segment["length"]} {segment["unit"]}. '
             for segment in description["segments"] if segment["known"]]
    parts += [f'∢{angle["mark"]} = {angle["value"]}°. '
              for angle in description["angles"] if angle["known"]]
    parts += description["questions"]
    return "".join(parts)


def _render_to_canvas(fig, ax, canvas_size=384, dpi=64):
    """Rasterize ``fig`` with a tight crop and paste it onto a white square canvas.

    ``fig`` is closed afterwards. The crop is placed against the right edge of
    the canvas when ``find_bbox(ax)`` reports 10 as its first x coordinate,
    otherwise against the left edge. Likewise it is placed at the bottom when
    the first y coordinate is 10, otherwise at the top.

    Returns:
        A ``(canvas_size, canvas_size, 3)`` uint8 array.
    Raises:
        ValueError: if the cropped figure is larger than the canvas.
    """
    xs, ys = find_bbox(ax)
    with BytesIO() as buf:
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=dpi)
        plt.close(fig)
        buf.seek(0)
        with Image.open(buf) as img:
            rendered = np.array(img)
    # NOTE(review): COLOR_RGB2BGR yields 3 channels (dropping alpha if the PNG
    # has one), but in BGR order. The canvas is later displayed with matplotlib,
    # which assumes RGB, so red and blue end up swapped — confirm this is intended.
    cut = cv2.cvtColor(rendered, cv2.COLOR_RGB2BGR)
    cut_h, cut_w = cut.shape[:2]
    if cut_h > canvas_size or cut_w > canvas_size:
        raise ValueError(f"rendered figure is {cut_w}x{cut_h} px, larger than "
                         f"the {canvas_size}x{canvas_size} canvas")
    # uint8 white background; the previous np.int8 canvas could not represent 255.
    canvas = np.full((canvas_size, canvas_size, 3), 255, dtype=np.uint8)
    left = canvas_size - cut_w if xs[0] == 10 else 0
    top = canvas_size - cut_h if ys[0] == 10 else 0
    canvas[top:top + cut_h, left:left + cut_w] = cut
    return canvas


def plot_geometry_improved(description):
    """Draw a geometry problem and its text into a new *current* pyplot figure.

    Steps:
    1. Draw the segments, special marks, angles and vertices of ``description``
       on a 6x6 in / 64 dpi axes.
    2. Spread the labels with ``adjust_text``.
    3. Rasterize the figure onto a white 384x384 canvas.
    4. Draw the problem text (known values, then the questions) into the box
       returned by ``find_bbox(img_cv, "image")``. The text is either
       handwritten (``EMNIST.text_mat``) or typeset (``render_wrapped_text``);
       the choice is a coin flip from the stdlib ``random`` module.

    The final image is left in the current pyplot figure, created implicitly by
    ``plt.imshow``. The caller saves and closes it (``plt.savefig`` / ``plt.close``).

    Side effects:
        Sets ``description["index_lookup"]``. Overwrites each vertex's ``x``/``y``
        with coordinates normalized to [0, 1] over the original bounding box,
        rounded to 2 decimals.

    Returns:
        The description returned by the plot_* helpers, with the updates above.
    """
    dpi = 64
    fig, ax = plt.subplots(figsize=(6, 6), dpi=dpi)
    # Coin flip: passed to every plot_* helper, and selects the text renderer below.
    handwritten = bool(random.getrandbits(1))

    # Axis limits: the vertices' bounding box plus 10% of its larger side on every edge.
    all_x = [v["x"] for v in description["vertices"]]
    all_y = [v["y"] for v in description["vertices"]]
    x_min, x_max = min(all_x), max(all_x)
    y_min, y_max = min(all_y), max(all_y)
    x_range, y_range = x_max - x_min, y_max - y_min
    padding = max(x_range, y_range) * 0.1
    ax.set_xlim(x_min - padding, x_max + padding)
    ax.set_ylim(y_min - padding, y_max + padding)

    # ---- segments, specials, angles
    # vertex mark -> index in description["vertices"], for easier lookup by the helpers
    description["index_lookup"] = {d["mark"]: i for i, d in enumerate(description["vertices"])}
    description, segment_labels = plot_segments(description, ax, EMNIST, handwritten)
    description = plot_specials(description, ax, EMNIST, handwritten)
    description, angle_labels = plot_angles(description, ax, EMNIST, handwritten)

    # ---- problem text (built after the helpers above, which may update the description)
    text = _problem_text(description)

    # ---- vertices
    vertex_labels = plot_vertices(description, ax, EMNIST, handwritten)

    # Normalize the stored coordinates only after every artist has been drawn.
    for vertex in description["vertices"]:
        vertex["x"] = np.around((vertex["x"] - x_min) / x_range, 2)
        vertex["y"] = np.around((vertex["y"] - y_min) / y_range, 2)

    # ---- label placement
    all_labels = vertex_labels + segment_labels + angle_labels
    # Every non-Text child of the axes is treated as an obstacle.
    # NOTE(review): this also includes the axes background patch and spines — confirm intended.
    avoid_objects = [artist for artist in ax.get_children()
                     if not isinstance(artist, plt.Text)]
    # NOTE(review): add_objects / expand_points / force_points / lim follow the
    # pre-1.0 adjustText API; verify the installed version accepts them.
    adjust_text(all_labels, ax=ax,
                add_objects=avoid_objects,
                expand_points=(1.5, 1.5),    # Distance from points
                expand_text=(1.2, 1.2),      # Distance between labels
                expand_objects=(1.2, 1.2),   # Distance from lines/objects
                arrowprops=dict(arrowstyle='-', color='gray', alpha=0.5, lw=0.5),
                force_points=(0.5, 0.5),     # Force to avoid points
                force_text=(0.5, 0.5),       # Force to avoid text overlap
                force_objects=(0.3, 0.3),    # Force to avoid other objects
                lim=1000)                    # Maximum iterations

    ax.set_aspect('equal')
    ax.axis('off')
    fig.tight_layout()

    # 384 px = 6 in x 64 dpi, the untrimmed figure size.
    img_cv = _render_to_canvas(fig, ax, canvas_size=384, dpi=dpi)

    # ---- problem text, placed in the box reported by find_bbox on the canvas
    xs, ys = find_bbox(img_cv, "image")
    tokens = tokenize_with_equations(text)
    bbox = (xs[0], ys[0], xs[1] - xs[0], ys[1] - ys[0])  # x, y, width, height
    if handwritten:
        img_text = EMNIST.text_mat(tokens, bbox[2])
        img_text = np.repeat(np.expand_dims(img_text, 2), 3, axis=2)
        img_text = np.array(np.round(255 * img_text), dtype=np.int16)
        # matplotlib ignores cmap/vmin/vmax for 3-channel input.
        plt.imshow(img_cv, cmap='gray')
        plt.imshow(img_text, cmap='gray',
                   # [left, right, bottom, top]; bottom > top because the y axis is flipped below
                   extent=[bbox[0], bbox[0] + bbox[2],
                           bbox[1] + min(bbox[3], img_text.shape[0]), bbox[1]],
                   vmin=0,
                   vmax=1,
                   alpha=1)  # fully opaque
        plt.xlim(0, img_cv.shape[1])
        plt.ylim(img_cv.shape[0], 0)
    else:
        img_text = render_wrapped_text(tokens, bbox[2], bbox[3])
        combined = overlay_text_on_image(img_cv, img_text, bbox)
        plt.imshow(combined)
    plt.axis('off')
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)

    return description


def convert_numpy_types(obj):
    """Recursively convert NumPy values into plain Python objects for ``json.dump``.

    - NumPy booleans become ``bool``.
    - Any NumPy integer (signed or unsigned) becomes ``int``.
    - Any NumPy float becomes ``float``.
    - Arrays become (nested) lists.
    - Dict values are converted; keys are left as-is.
    - Lists and tuples become lists.
    - Anything else is returned unchanged.

    Uses the abstract ``np.integer`` / ``np.floating`` types rather than concrete
    aliases such as ``np.float_``, which was removed in NumPy 2.0.
    """
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        # tolist() already yields Python scalars; recurse for object arrays.
        return convert_numpy_types(obj.tolist())
    if isinstance(obj, dict):
        return {k: convert_numpy_types(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_numpy_types(v) for v in obj]
    return obj



In [4]:
def save_json(description, output_dir, idx):
    """Write ``description`` as indented JSON to ``<output_dir>/labels/<idx:06d>.json``.

    Creates the ``labels`` sub-directory when it is missing and overwrites any
    existing file for the same index.
    """
    labels_dir = f"{output_dir}/labels"
    os.makedirs(labels_dir, exist_ok=True)
    target = f"{labels_dir}/{idx:06d}.json"
    with open(target, "w") as handle:
        # A 2-space indent keeps the label files human-readable.
        json.dump(description, handle, indent=2)
        
# Seed both RNGs so runs are reproducible.
# - numpy draws the triangle parameters below.
# - plot_geometry_improved draws its handwritten/typeset coin flip from the stdlib
#   `random` module, which a numpy seed alone does not cover.
SEED = 1
np.random.seed(SEED)
random.seed(SEED)

# Override with the GEOMETRY_OUTPUT_DIR environment variable instead of editing this path.
output_dir = os.environ.get("GEOMETRY_OUTPUT_DIR", "/home/yfrid/Desktop/stem-whiteboard/dataset/3V")
# save_json creates labels/ itself, but plt.savefig needs images/ to exist already.
os.makedirs(f"{output_dir}/images", exist_ok=True)

# (type passed to Triangle.third_vertice, sample indices), generated in this order.
# NOTE(review): the RNGs are re-seeded on every run, so starting at a later index
# replays the same random stream as a run that started at 0. Confirm that this
# cannot duplicate samples from an earlier run.
BATCHES = [
    ("scalene", range(0)),          # currently disabled (empty range)
    ("right", range(1468, 2000)),   # starts at index 1468
]

for triangle_type, indices in BATCHES:
    for idx in tqdm(indices):
        tri = Triangle(np.random.choice(range(5, 15)),            # random limb A
                       np.random.choice(range(5, 15)),            # random limb B
                       rotation=np.random.randint(360),           # random rotation
                       mirror=np.random.choice([True, False]))    # random mirroring
        tri.third_vertice(triangle_type)
        tri.set_3V_question(num_questions=np.random.choice(range(1, 4)))
        tri.rotate()
        new = plot_geometry_improved(tri.description)
        # plot_geometry_improved leaves the final image in the current pyplot figure.
        plt.savefig(f"{output_dir}/images/{idx:06d}.png", dpi=64, facecolor='white')
        save_json(convert_numpy_types(new), output_dir, idx)

        plt.close()

0it [00:00, ?it/s]
  0%|                                                   | 0/532 [00:00<?, ?it/s]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
  0%|▏                                          | 2/532 [00:00<02:40,  3.30it/s]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
  1%|▏                                          | 3/532 [00:00<02:44,  3.21it/s]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
  1%|▎                                          | 4/532 [00:01<02:36,  3.37it/s]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
  1%|▌                                          | 7/532 [00:02<04:07,  2.12it/s]Looks like you are using a tranform that doesn't support FancyArrowPatch, using ax.annotate instead. The arrows might strike through texts. I

In [5]:
# Inspect the description of the last generated triangle. It contains:
# - vertices, with x/y already normalized to [0, 1] by plot_geometry_improved;
# - segments and angles, with their known values;
# - specials, the generated questions, and the vertex index lookup.
tri.description

{'vertices': [{'mark': 'A', 'x': 0.33, 'y': 0.0},
  {'mark': 'B', 'x': 1.0, 'y': 1.0},
  {'mark': 'C', 'x': 0.0, 'y': 0.37}],
 'segments': [{'mark': 'AB', 'known': True, 'length': 14.0, 'unit': 'cm'},
  {'mark': 'AC', 'known': False},
  {'mark': 'BC', 'known': True, 'length': 15.23, 'unit': 'cm'}],
 'angles': [{'mark': 'ABC', 'known': False},
  {'mark': 'BCA', 'known': False},
  {'mark': 'CAB', 'known': True, 'value': 90, 'unit': 'deg'}],
 'specials': [],
 'questions': ['Find Segment AC. ', 'Find Angle ∢ABC. '],
 'index_lookup': {'A': 0, 'B': 1, 'C': 2}}