In [1]:
from text2brick.models import GraphLegoWorldData
import numpy as np
import torch
from text2brick.dataset.dataset import MNISTDataset
from text2brick.gym import LegoEnv


In [2]:
# Initial 4x4 world grid, one string per row (top row first). Cells holding 1 appear to be
# occupied: compare with the graph_to_table() output further down.
table = np.array([[int(cell) for cell in row] for row in ("0000", "1100", "0110", "1100")])

In [3]:
# dataset = MNISTDataset()
# array, _, _, _ = dataset.sample(sample_index=1)

In [4]:
# env = LegoEnv(array.shape[0])
# print(env)

In [5]:
# obs, reward, done, info = env.step((0, 14), array) # Valid
# obs, reward, done, info = env.step((0, 14), array) # Overlap
# obs, reward, done, info = env.step((0, 27), array) # Half of the brick is out of the world 
# obs, reward, done, info = env.step((3, 14), array) # Invalid
# obs, reward, done, info = env.step((1, 14), array) # Valid
# print(obs)

In [6]:
# Build the brick graph from the initial grid, then try two placements given as (x, y).
graph = GraphLegoWorldData(img=table)
for brick_x, brick_y in [(3, 0), (2, 0)]:
    # In the run below, (3, 0) is rejected ("out of the world") and (2, 0) becomes node 3.
    graph.add_brick(brick_x, brick_y)
graph.print_graph()
graph.save_as_ldraw()

Outside: Brick at (3, 0) is out of the world
Number of nodes: 4
Number of edges: 3
Node 0: {'x': 0, 'y': 0, 'validity': True}
Node 1: {'x': 1, 'y': 1, 'validity': True}
Node 2: {'x': 0, 'y': 2, 'validity': True}
Node 3: {'x': 2, 'y': 0, 'validity': True}
Edge (0, 1): {}
Edge (1, 2): {}
Edge (1, 3): {}


In [13]:
# get_nodes() yields (node_id, attribute_dict) pairs, as the printed output below shows.
# Use distinct names for the collection and the loop item. The original
# `for node, data in data` rebound the iterable's own name on every turn.
nodes = graph.get_nodes()

for node, attrs in nodes:
    print(node, attrs)

0 {'x': 0, 'y': 0, 'validity': True}
1 {'x': 1, 'y': 1, 'validity': True}
2 {'x': 0, 'y': 2, 'validity': True}
3 {'x': 2, 'y': 0, 'validity': True}


In [8]:
# Subgraph for argument 3; in the run below it contains nodes 0-2 and the edges between them.
sub = graph.subgraph(3)
for node_id, node_attrs in sub.nodes(data=True):
    print(f"Node {node_id}: {node_attrs}")
for src, dst in sub.edges():
    print(f"Edge ({src}, {dst}):")

sub

Node 0: {'x': 0, 'y': 0, 'validity': True}
Node 1: {'x': 1, 'y': 1, 'validity': True}
Node 2: {'x': 0, 'y': 2, 'validity': True}
Edge (0, 1):
Edge (1, 2):


<networkx.classes.graph.Graph at 0x2e886e48c70>

In [9]:
# Rasterize the graph back into a grid. Compared with `table`, the last row changed after the
# brick was added at (2, 0), so y=0 appears to map to the bottom row.
graph.graph_to_table()

array([[0, 0, 0, 0],
       [1, 1, 0, 0],
       [0, 1, 1, 0],
       [1, 1, 1, 1]])

In [10]:
# Convert the brick graph to a torch data object. It gets its own name instead of reusing
# `data`, which earlier cells used for the node list and for a node's attribute dict.
graph_data = graph.graph_to_torch()
# Stack the per-node x / y values into one tensor. This assumes graph_data.x and
# graph_data.y are equally shaped 1-D tensors; TODO confirm.
node_coords = torch.stack([graph_data.x, graph_data.y])
graph_data.edge_index

tensor([[0, 1, 1],
        [1, 2, 3]])

In [11]:
from torch.utils.data import DataLoader
from text2brick.dataset import CustomDatasetGraph

# Graph dataset fed to the DataLoader below; its batches unpack to (x, edge_index, next_node).
dataset = CustomDatasetGraph()


In [12]:
train_data = DataLoader(dataset, batch_size=1, shuffle=True)

# Preview the first two batches of ONE shuffled pass. The original called
# next(iter(train_data)) per sample, which starts a new, reshuffled pass every time, so the
# "two" samples could be the same item. Within one pass each batch is a different index, and
# zip() just stops early if the dataset has fewer than two items.
for i, (x, edge_index, next_node) in zip(range(2), train_data):
    print(f"Sample {i + 1}:")
    print(f"Node Features (x): {x}")
    print(f"Edge Index: {edge_index}")
    print(f"Next Node: {next_node}")
    print("-" * 30)

Sample 1:
Node Features (x): tensor([[[ 5,  7,  5,  7,  5,  7,  9,  6,  8, 10,  6,  8, 10,  8, 10, 12,  9,
          11, 10, 12, 14, 10, 12, 14, 11, 13, 15, 12, 14, 16, 14, 16, 14, 16,
          15, 17, 19, 15, 17, 19, 16, 18, 20, 16, 18, 20, 16, 18, 20, 19],
         [ 0,  0,  1,  1,  2,  2,  2,  3,  3,  3,  4,  4,  4,  5,  5,  5,  6,
           6,  7,  7,  7,  8,  8,  8,  9,  9,  9, 10, 10, 10, 11, 11, 12, 12,
          13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18]]])
Edge Index: tensor([[[ 0,  1,  2,  3,  4,  5,  5,  6,  6,  7,  8,  9, 11, 12, 13, 14, 14,
          15, 16, 17, 17, 18, 19, 20, 21, 22, 22, 23, 23, 24, 25, 25, 26, 26,
          28, 29, 30, 31, 32, 33, 33, 34, 35, 36, 37, 38, 38, 39, 39, 40, 41,
          42, 43, 44, 45, 47, 48],
         [ 2,  3,  4,  5,  7,  7,  8,  8,  9, 10, 11, 12, 13, 14, 16, 16, 17,
          17, 18, 18, 19, 21, 22, 23, 24, 24, 25, 25, 26, 27, 27, 28, 28, 29,
          30, 31, 32, 33, 34, 34, 35, 37, 38, 39, 40, 40, 41, 41, 42, 

In [1]:
from text2brick.gym import Text2Brick_v1
import torch
import numpy as np
from PIL import Image
from torchsummary import summary
from text2brick.gym import Text2Brick_v1, BrickPlacementGNN, CNN, SNN, PositionHead2D, MLP

# Random placeholder images, channel-first (C, H, W) as numpy_to_pil expects. A seeded
# generator makes them identical on every Restart & Run All.
IMAGE_SHAPE = (3, 224, 224)
rng = np.random.default_rng(0)
image_target = rng.random(IMAGE_SHAPE)  # floats in [0, 1)
image_environment = rng.random(IMAGE_SHAPE)

def numpy_to_pil(image_np):
    """Convert a float image array with values in [0, 1] to a PIL image.

    Parameters
    ----------
    image_np : array-like
        Channel-first ``(C, H, W)`` array, or a 2-D ``(H, W)`` grayscale array.
        Values are expected in ``[0, 1]``. Values outside that range are clipped
        instead of being passed to a non-saturating ``uint8`` cast.

    Returns
    -------
    PIL.Image.Image
        An ``H x W`` image. A single-channel ``(1, H, W)`` input yields a grayscale image.
    """
    image_np = np.asarray(image_np)
    # Clip before scaling: float -> uint8 casts do not saturate, so e.g. 1.2 * 255 could
    # wrap around to a dark pixel. Truncation toward zero matches the original cast.
    image_u8 = (np.clip(image_np, 0.0, 1.0) * 255).astype(np.uint8)
    if image_u8.ndim == 3:
        # PIL wants channel-last (H, W, C).
        image_u8 = image_u8.transpose(1, 2, 0)
        if image_u8.shape[-1] == 1:
            # A single channel must be handed to PIL as a plain (H, W) array.
            image_u8 = image_u8[..., 0]
    return Image.fromarray(image_u8)

# PIL versions of both placeholder images (target first, then environment); the SNN and
# Text2Brick_v1 cells below consume these PIL objects.
image_target_pil, image_environment_pil = (
    numpy_to_pil(image) for image in (image_target, image_environment)
)

# image_target_pil.show()
# image_environment_pil.show()

In [2]:
# Toy graph: 10 nodes with 2 features each, plus the edges (0, 1), (1, 2), (2, 3) laid out as
# a (2, num_edges) tensor. Presumably row 0 holds sources and row 1 targets (PyG style); TODO confirm.
# A dedicated seeded generator keeps the features reproducible without touching torch's global RNG.
node_features = torch.randn(10, 2, generator=torch.Generator().manual_seed(0))
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
assert int(edge_index.max()) < node_features.shape[0], "edge_index references a missing node"

# Initialize the model
#model = Text2Brick_v1(image_target=image_target_pil)

In [3]:
# Run the GNN on the toy graph; in the run below it returns a (1, 64) tensor.
# Call the module rather than .forward() so that registered hooks run. This assumes
# BrickPlacementGNN is a torch.nn.Module; TODO confirm.
gnn = BrickPlacementGNN()
gnn_result = gnn(node_features, edge_index)
gnn_result.shape

torch.Size([1, 64])

In [4]:
# Compare the environment image against the target; in the run below the result is (1, 13, 13).
# Call the module rather than .forward() so that registered hooks run. This assumes SNN is a
# torch.nn.Module; TODO confirm.
snn = SNN(image_target=image_target_pil)
snn_result = snn(image_environment_pil)
snn_result.shape

Using cache found in C:\Users\ebern/.cache\torch\hub\pytorch_vision_v0.10.0


torch.Size([1, 13, 13])

In [5]:
# Size the MLP from the actual SNN / GNN outputs ((1, 13, 13) and (1, 64) in the runs above)
# instead of hard-coding them, so it keeps working if either component's output changes.
mlp = MLP(tuple(snn_result.shape), tuple(gnn_result.shape))

In [6]:
# Fuse the SNN and GNN outputs. Call the module rather than .forward() so that registered
# hooks run. This assumes MLP is a torch.nn.Module; TODO confirm.
mlp_output = mlp(snn_result, gnn_result)

torch.Size([1, 169])


In [18]:
# Width of the fused MLP output (16 in the run below); presumably PositionHead2D's
# mlp_output_dim must equal this last dimension.
mlp_output.shape 

torch.Size([1, 16])

In [19]:
# Take the head's input width from the MLP output (16 in the run above) instead of hard-coding it.
head = PositionHead2D(mlp_output_dim=mlp_output.shape[-1])

In [20]:
# Predict a 2-D position from the fused features (an integer pair in the run below).
# Call the module rather than .forward() so that registered hooks run. This assumes
# PositionHead2D is a torch.nn.Module; TODO confirm.
head(mlp_output)

tensor([[4, 5]])

In [3]:
# End-to-end model built around the target image. Construction prints "INIT SNN" / "INIT GNN"
# and a torch-hub cache message, as shown in the output below.
full_model = Text2Brick_v1(image_target=image_target_pil)

INIT SNN
INIT GNN


Using cache found in C:\Users\ebern/.cache\torch\hub\pytorch_vision_v0.10.0


In [4]:
# Full pipeline on the environment image and toy graph (an integer position pair in the run below).
# Call the module rather than .forward() so that registered hooks run. This assumes
# Text2Brick_v1 is a torch.nn.Module; TODO confirm.
full_model(image_environment_pil, node_features, edge_index)

torch.Size([1, 169])


tensor([[3, 6]])