In [1]:
import os
import sys
import numpy as np
import torch

# Make the repository root importable so `models` and `env` resolve.
# NOTE: '..' is resolved against the current working directory, so this assumes
# the notebook is launched from its own folder.
PROJECT_ROOT = os.path.abspath('..')
# Guard the append so re-running this cell does not pile up duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from models.transformer import BinPackingTransformer
from env.env import BinPacking3DEnv

In [2]:
# Hyperparameters for the transformer, gathered in one place so they are easy to tune.
transformer_config = dict(
    d_model=128,
    n_head=8,
    n_layers=3,
    d_feedforward=512,
)

# Instantiate the model from the configuration above.
transformer = BinPackingTransformer(**transformer_config)

In [3]:
# Show the model's repr to inspect the sub-modules it was built with.
transformer

BinPackingTransformer(
  (ems_list_embedding): Embedding(
    (linear): Linear(in_features=6, out_features=128, bias=True)
  )
  (buffer_embedding): Embedding(
    (linear): Linear(in_features=3, out_features=128, bias=True)
  )
  (transformer_blocks): ModuleList(
    (0-2): 3 x TransformerBlock(
      (self_attn_ems_list): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (self_attn_buffer): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (norm1_ems_list): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (norm1_buffer): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp_ems_list): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): ReLU()
        (2): Linear(in_features=512, out_features=128, bias=True)
      )
      (mlp_buffer): Sequential(
        (0): Lin

In [4]:
# Build the environment whose reset() supplies the dummy input used to test the model.
packing_bin_size = (5, 5, 5)
packing_items = [(2, 3, 1), (2, 2, 3), (1, 1, 2), (3, 2, 2)]

env = BinPacking3DEnv(
    bin_size=packing_bin_size,
    items=packing_items,
    buffer_size=2,
    num_rotations=2,
    max_ems=100,
)

# NOTE: the name is misspelled ("obervation"), but later cells read it under this
# exact spelling, so it is kept unchanged here.
obervation = env.reset()

In [5]:
# EMS list: pull the 'ems_list' entry out of the reset observation and check its shape.
ems_list = obervation['ems_list']

ems_list.shape

(100, 6)

In [7]:
# EMS mask: pull the 'ems_mask' entry out of the reset observation and check its shape.
# NOTE(review): the recorded output shows a mask of length 101 while `ems_list` has
# 100 rows — confirm whether the extra entry is intentional.
ems_mask = obervation['ems_mask']

ems_mask.shape

(101,)

In [20]:
# Buffer: pull the 'buffer' entry out of the reset observation and check its shape.
buffer = obervation['buffer']

buffer.shape

(2, 3)

In [22]:
# Prepend a batch axis of size 1 to each array:
#   ems_list -> [1, max_ems, 6]
#   buffer   -> [1, buffer_size, 3]
ems_list_np, buffer_np = (
    unbatched[np.newaxis, ...] for unbatched in (ems_list, buffer)
)

tuple(batched.shape for batched in (ems_list_np, buffer_np))

((1, 100, 6), (1, 2, 3))

In [24]:
# Copy both batched arrays into float32 torch tensors.
ems_list_tensor, buffer_tensor = (
    torch.tensor(batched, dtype=torch.float32)
    for batched in (ems_list_np, buffer_np)
)

tuple(t.shape for t in (ems_list_tensor, buffer_tensor))

(torch.Size([1, 100, 6]), torch.Size([1, 2, 3]))