In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

In [None]:
def sdf_circle(center, radius, grid):
    """Signed distance field of a circle sampled on a square pixel grid.

    Args:
        center: ``(cx, cy)`` position of the circle centre in pixel coordinates.
        radius: Circle radius in pixels.
        grid: Side length of the grid; samples are taken at integer
            coordinates ``0 .. grid - 1`` along both axes.

    Returns:
        Array of shape ``(grid, grid)`` indexed as ``[y, x]``. Values are
        negative inside the circle, zero on its boundary and positive outside.
    """
    center_x, center_y = center
    axis = np.arange(grid)
    # Default 'xy' indexing: columns carry x, rows carry y.
    cols, rows = np.meshgrid(axis, axis)
    offset_x = cols - center_x
    offset_y = rows - center_y
    dist_to_center = np.sqrt(offset_x**2 + offset_y**2)
    return dist_to_center - radius

def sdf_rectangle(center, width, height, grid):
    """Signed distance field of an axis-aligned rectangle on a square pixel grid.

    Args:
        center: ``(cx, cy)`` position of the rectangle centre in pixel coordinates.
        width: Extent of the rectangle along x, in pixels.
        height: Extent of the rectangle along y, in pixels.
        grid: Side length of the grid; samples are taken at integer
            coordinates ``0 .. grid - 1`` along both axes.

    Returns:
        Array of shape ``(grid, grid)`` indexed as ``[y, x]``. Values are
        negative inside the rectangle (minus the distance to the nearest
        edge), zero on the boundary and positive outside (Euclidean distance
        to the nearest edge or corner).
    """
    cx, cy = center
    x, y = np.meshgrid(np.arange(grid), np.arange(grid))

    # Per-axis signed offset from the rectangle's faces: negative while the
    # sample lies between the two faces of that axis, positive beyond them.
    qx = np.abs(x - cx) - width / 2
    qy = np.abs(y - cy) - height / 2

    # Outside: only the positive offsets contribute (this handles both the
    # edge regions and the rounded corner regions). Zero for interior samples.
    outside = np.sqrt(np.maximum(qx, 0) ** 2 + np.maximum(qy, 0) ** 2)

    # Inside: both offsets are negative and the larger one (closest to zero)
    # is the distance to the nearest edge. Zero for exterior samples.
    inside = np.minimum(np.maximum(qx, qy), 0)

    return outside + inside

def sdf_triangle(vertices, grid):
    """Signed distance field of a triangle sampled on a square pixel grid.

    Args:
        vertices: Three ``(x, y)`` points in pixel coordinates, in either
            winding order.
        grid: Side length of the grid; samples are taken at integer
            coordinates ``0 .. grid - 1`` along both axes.

    Returns:
        Array of shape ``(grid, grid)`` indexed as ``[y, x]``. Values are
        negative inside the triangle, zero on its boundary and positive
        outside. For a zero-area (collinear) triangle there is no interior,
        so the unsigned distance to the segments is returned.
    """
    x, y = np.meshgrid(np.arange(grid), np.arange(grid))
    v0, v1, v2 = vertices

    def edge_sdf(px, py, ax, ay, bx, by):
        """Unsigned distance from each sample (px, py) to the segment a-b."""
        pa_x, pa_y = px - ax, py - ay
        ba_x, ba_y = bx - ax, by - ay
        length_sq = ba_x**2 + ba_y**2
        if length_sq == 0:
            # Zero-length edge: the segment collapses to the point a, and the
            # projection below would divide by zero.
            return np.sqrt(pa_x**2 + pa_y**2)
        # Projection parameter of p onto a-b, clamped to stay on the segment.
        h = np.clip((pa_x * ba_x + pa_y * ba_y) / length_sq, 0, 1)
        return np.sqrt((pa_x - h * ba_x)**2 + (pa_y - h * ba_y)**2)

    def edge_side(px, py, ax, ay, bx, by):
        """Cross product (p - a) x (b - a); its sign tells which side of a-b p is on."""
        return (px - ax) * (by - ay) - (py - ay) * (bx - ax)

    dist = np.minimum(
        np.minimum(edge_sdf(x, y, *v0, *v1), edge_sdf(x, y, *v1, *v2)),
        edge_sdf(x, y, *v2, *v0),
    )

    # A collinear triangle encloses nothing, so nothing should be negative.
    if edge_side(v2[0], v2[1], *v0, *v1) == 0:
        return dist

    side0 = edge_side(x, y, *v0, *v1)
    side1 = edge_side(x, y, *v1, *v2)
    side2 = edge_side(x, y, *v2, *v0)

    # A sample is inside exactly when it lies on the same side of all three
    # edges. Checking both signs makes the test independent of winding order.
    all_non_pos = (side0 <= 0) & (side1 <= 0) & (side2 <= 0)
    all_non_neg = (side0 >= 0) & (side1 >= 0) & (side2 >= 0)
    inside = all_non_pos | all_non_neg

    return np.where(inside, -dist, dist)

def generate_dataset(grid_size=32, num_samples=100, seed=None):
    """Generate a dataset of SDF images for randomly sized and placed shapes.

    Each sample is a circle, rectangle or triangle chosen uniformly at random
    and rendered with the matching ``sdf_*`` function.

    Args:
        grid_size: Side length of every generated SDF image.
        num_samples: Number of samples to generate.
        seed: Optional integer seed. When given, a private
            ``np.random.RandomState(seed)`` is used so the dataset is
            reproducible. When ``None`` (default) the global ``np.random``
            state is used, exactly as before.

    Returns:
        Tuple ``(dataset, labels)`` where ``dataset`` has shape
        ``(num_samples, grid_size, grid_size)`` and ``labels`` is a string
        array of shape ``(num_samples,)`` holding the shape name per sample.
    """
    # np.random (module) and RandomState expose the same choice/uniform API.
    rng = np.random if seed is None else np.random.RandomState(seed)

    dataset = []
    labels = []

    for _ in range(num_samples):
        shape_type = rng.choice(['circle', 'rectangle', 'triangle'])
        if shape_type == 'circle':
            radius = rng.uniform(5, 12)
            # Keep the whole circle inside the grid.
            center = rng.uniform(radius, grid_size - radius, 2)
            sdf = sdf_circle(center, radius, grid_size)
        elif shape_type == 'rectangle':
            width = rng.uniform(5, 12)
            height = rng.uniform(5, 12)
            # Conservative margin: the larger half-extent is used on both axes.
            margin = max(width, height) / 2
            center = rng.uniform(margin, grid_size - margin, 2)
            sdf = sdf_rectangle(center, width, height, grid_size)
        elif shape_type == 'triangle':
            vertices = rng.uniform(5, grid_size - 5, (3, 2))
            sdf = sdf_triangle(vertices, grid_size)

        dataset.append(sdf)
        labels.append(shape_type)

    # Explicit shape/dtype so an empty dataset still has the documented layout.
    images = np.array(dataset, dtype=float).reshape(len(dataset), grid_size, grid_size)
    return images, np.array(labels, dtype=str)

# Build the SDF dataset, then convert it to a float32 tensor with an added
# channel axis at position 1.
dataset, labels = generate_dataset()
tensor_dataset = torch.tensor(dataset, dtype=torch.float32).unsqueeze(1)


In [None]:
# Show the first sample's SDF as a heat map, labelled with its shape class.
plt.imshow(tensor_dataset[0][0], cmap='hot', interpolation='nearest')
plt.colorbar(label='signed distance (px)')
plt.title(f'SDF sample 0: {labels[0]}')
plt.show()
# Tensor.type is a method: without the call, the bound-method repr is printed
# instead of the tensor type string.
print(tensor_dataset[0][0].type())