In [4]:
import torch
import torch.nn as nn
import instruction_test_data as itd
import numpy as np
import card_embedding as ce
import nesting
from multi_head_attention import MultiHeadAttention
from positional_embedding import PositionalEmbedding
# Seed torch before any module below is built so random parameter initialisation is reproducible.
torch.manual_seed(1)

# Embedding width passed to every module constructed in this notebook.
d = 8
shared_embedding_holder = ce.SharedEmbeddingHolder(d, device="cuda")
positional_embedding = PositionalEmbedding(d, device="cuda")

In [5]:
instruction_embedding = ce.InstructionEmbedding(shared_embedding_holder, d, device="cuda")
instruction_embeddings = instruction_embedding(itd.instructions_batch)

# Show the embeddings, then their shape, each on its own line.
print(instruction_embeddings, instruction_embeddings.shape, sep="\n")

tensor([[ 5.6120,  3.2177,  4.4789,  1.9952,  0.7815, -7.9421, -9.0347,  8.1897],
        [ 2.2885, -2.0347, -2.3284,  0.6772, -9.2289, -9.8569, -3.0202,  0.2285],
        [ 0.1685,  0.2371,  0.5718,  1.6196,  1.0329,  0.2250,  2.0036,  0.9421],
        [ 0.1818,  1.4895,  0.3771,  1.5049,  1.0562,  0.0494,  0.0668,  0.8926],
        [-3.0103, -0.3263, -2.4020, -2.2322, -0.9752, -0.1198, -7.6821,  1.9554]],
       device='cuda:0', grad_fn=<SumBackwardAutogradNestedTensor1>)
torch.Size([5, 8])


In [None]:
from itertools import chain


class InstructionDataEmbedding(nn.Module):
    """Embed the per-type instruction data of a flattened batch and restore item order.

    Each of the six slots of ``instruction_data`` (attack, discard, card amount,
    return-to-deck type, filter, player target) is embedded by its own sub-module.
    The resulting rows are re-ordered by ``instruction_data_indices`` and summed
    with an embedding of ``instruction_data_types``.
    """

    def __init__(
        self,
        shared_embedding_holder: ce.SharedEmbeddingHolder,
        dimension_out: int,
        device=None,
        dtype=None,
    ):
        """
        Args:
            shared_embedding_holder: Shared embeddings handed to the card-amount,
                return-to-deck, filter and player-target sub-embeddings; its
                ``position_embedding`` is also kept on this module.
            dimension_out: Passed as the output dimension to every sub-module.
            device: Torch device for all parameters and created tensors.
            dtype: Torch dtype for all parameters and created tensors.
        """
        # Initialise nn.Module before assigning anything (conventional order).
        super().__init__()
        self.factory_kwargs = {"device": device, "dtype": dtype}
        self.dimension_out = dimension_out
        self.attack_data_embedding = ce.AttackDataEmbedding(
            dimension_out, **self.factory_kwargs
        )
        self.discard_data_embedding = ce.DiscardDataEmbedding(
            dimension_out, **self.factory_kwargs
        )
        self.card_amount_data_embedding = ce.CardAmountDataEmbedding(
            shared_embedding_holder, dimension_out, **self.factory_kwargs
        )
        self.return_to_deck_type_data_embedding = ce.ReturnToDeckTypeDataEmbedding(
            shared_embedding_holder, dimension_out, **self.factory_kwargs
        )
        self.filter_embedding = ce.FilterEmbedding(
            shared_embedding_holder, dimension_out, **self.factory_kwargs
        )
        self.player_target_data_embedding = ce.PlayerTargetDataEmbedding(
            shared_embedding_holder, dimension_out, **self.factory_kwargs
        )
        # NOTE(review): 6 rows with padding_idx=0 leaves 5 non-padding type ids for
        # the six data slots handled in forward() — confirm the type-id encoding.
        self.instruction_data_type_embedding = nn.Embedding(
            6, dimension_out, padding_idx=0, **self.factory_kwargs
        )
        # NOTE(review): positional args are interpreted by the project's
        # MultiHeadAttention; this sub-module is not used by forward() yet.
        self.data_multi_head_attention = MultiHeadAttention(
            dimension_out,
            dimension_out,
            dimension_out,
            max(dimension_out // 16, 4),
            4,
            **self.factory_kwargs,
        )
        # NOTE(review): not used by forward() yet.
        self.position_embedding = shared_embedding_holder.position_embedding

    @staticmethod
    def _embed_if_present(items, embed, to_input=None):
        # Run `embed` on `items` (converted by `to_input` if given) when the slot is
        # non-empty; an empty slot yields [] so it contributes no rows when flattened.
        if not items:
            return []
        return embed(items if to_input is None else to_input(items))

    def forward(
        self,
        instruction_indices,
        instruction_data_types,
        instruction_data_type_indices,
        instruction_data,
        instruction_data_indices,
        batch_size: int,
    ) -> torch.Tensor:
        """Embed every instruction-data item and return the rows in key order.

        Args:
            instruction_indices: Currently unused.
            instruction_data_types: Type ids fed to ``instruction_data_type_embedding``.
            instruction_data_type_indices: Currently unused.
            instruction_data: Six per-type sequences (attack, discard, card amount,
                return-to-deck type, filter, player target); empty slots are skipped.
            instruction_data_indices: Six sequences of sortable keys, one key per
                item of the matching ``instruction_data`` slot.
            batch_size: Currently unused.

        Returns:
            The key-ordered data embeddings plus the type embeddings.

        Raises:
            ValueError: if the number of keys differs from the number of embedded rows.
        """
        # NOTE(review): data rows are re-ordered by instruction_data_indices, but the
        # type embeddings are added in the order instruction_data_types arrives in —
        # confirm both orders agree.
        type_embeddings = self.instruction_data_type_embedding(instruction_data_types)

        def as_tensor(values):
            return torch.tensor(values, **self.factory_kwargs)

        per_type_embeddings = (
            self._embed_if_present(
                instruction_data[0], self.attack_data_embedding, torch.stack
            ),
            self._embed_if_present(
                instruction_data[1], self.discard_data_embedding, as_tensor
            ),
            self._embed_if_present(
                instruction_data[2], self.card_amount_data_embedding, torch.stack
            ),
            self._embed_if_present(
                instruction_data[3], self.return_to_deck_type_data_embedding, torch.stack
            ),
            # The filter embedding receives the raw items without conversion.
            self._embed_if_present(instruction_data[4], self.filter_embedding),
            self._embed_if_present(
                instruction_data[5], self.player_target_data_embedding, as_tensor
            ),
        )
        sorted_data = self.sort_tensors_with_respect_to_index(
            per_type_embeddings, instruction_data_indices
        )
        return sorted_data + type_embeddings

    def sort_tensors_with_respect_to_index(self, tensors, indices):
        """Flatten per-type rows and stack them ordered by their matching key.

        Args:
            tensors: Sequence of per-type groups; each group is iterable over rows
                (a tensor iterates over its first dimension) or is an empty list.
            indices: Sequence of per-type key sequences parallel to ``tensors``.

        Returns:
            All rows stacked in ascending key order, or an empty
            ``(0, dimension_out)`` tensor when there are no rows.

        Raises:
            ValueError: if the number of keys differs from the number of rows.
        """
        flat_indices = list(chain.from_iterable(indices))
        flat_rows = list(chain.from_iterable(tensors))
        # zip() would silently drop the surplus on a mismatch and misalign rows.
        if len(flat_indices) != len(flat_rows):
            raise ValueError(
                f"got {len(flat_indices)} indices for {len(flat_rows)} embedded rows"
            )
        if not flat_rows:
            # torch.stack([]) raises; return an empty batch instead.
            # assumes each embedded row has width dimension_out — TODO confirm
            return torch.empty(0, self.dimension_out, **self.factory_kwargs)
        # Sort on the key only: comparing tensors on tied keys would raise.
        ordered = sorted(zip(flat_indices, flat_rows), key=lambda pair: pair[0])
        return torch.stack([row for _, row in ordered])

In [28]:
# Flatten the nested condition batch into the parallel structures passed to the module below.
[
    condition_types,
    condition_indices,
    instruction_data_types,
    instruction_data_type_indices,
    instruction_data,
    instruction_data_indices,
] = nesting.flatten_instructions("ConditionType", itd.conditions_batch, device="cuda")
print(instruction_data_indices)

instruction_data_embedding = InstructionDataEmbedding(shared_embedding_holder, d, device="cuda")
tensors = instruction_data_embedding(
    instruction_indices=condition_indices,
    instruction_data_types=instruction_data_types,
    instruction_data_type_indices=instruction_data_type_indices,
    instruction_data=instruction_data,
    instruction_data_indices=instruction_data_indices,
    batch_size=len(itd.conditions_batch),
)
print(tensors)

([], [], [(0, 0, 0), (1, 0, 0), (2, 0, 0), (2, 1, 0)], [], [], [])
([], [], tensor([[ 0.2699,  1.5510, -1.4474, -1.6604,  1.2613,  0.2293, -0.8266, -0.2678],
        [ 0.2699,  1.5510, -1.4474, -1.6604,  1.2613,  0.2293, -0.8266, -0.2678],
        [ 0.3690, -0.1970, -0.0166,  2.0454, -0.3100,  1.0262,  0.3317,  2.1031],
        [ 0.2699,  1.5510, -1.4474, -1.6604,  1.2613,  0.2293, -0.8266, -0.2678]],
       device='cuda:0', grad_fn=<AddBackward0>), [], [], [])


In [29]:
# NOTE(review): this cell was written against an earlier forward() that returned the
# six per-type groups as a tuple (see the recorded output of the previous cell). The
# InstructionDataEmbedding class defined in this notebook returns one stacked tensor,
# so chain.from_iterable(tensors) now yields individual scalar elements, not rows.
# It also orders by instruction_data_type_indices, whereas the class orders by
# instruction_data_indices — confirm which ordering is intended.
for _, tensor in sorted(
    zip(
        chain.from_iterable(instruction_data_type_indices),
        chain.from_iterable(tensors),
    ),
    key=lambda index_and_row: index_and_row[0],
):
    print(tensor)

tensor([ 0.2699,  1.5510, -1.4474, -1.6604,  1.2613,  0.2293, -0.8266, -0.2678],
       device='cuda:0', grad_fn=<UnbindBackward0>)
tensor([ 0.2699,  1.5510, -1.4474, -1.6604,  1.2613,  0.2293, -0.8266, -0.2678],
       device='cuda:0', grad_fn=<UnbindBackward0>)
tensor([ 0.3690, -0.1970, -0.0166,  2.0454, -0.3100,  1.0262,  0.3317,  2.1031],
       device='cuda:0', grad_fn=<UnbindBackward0>)
tensor([ 0.2699,  1.5510, -1.4474, -1.6604,  1.2613,  0.2293, -0.8266, -0.2678],
       device='cuda:0', grad_fn=<UnbindBackward0>)


In [6]:
condition_embedding = ce.ConditionEmbedding(shared_embedding_holder, d, device="cuda")

# NOTE(review): the recorded run of this cell raised
# "RuntimeError: stack expects a non-empty TensorList", which torch.stack raises for an
# empty list — check how empty per-type data is handled inside ce.ConditionEmbedding.
condition_embeddings = condition_embedding(itd.conditions_batch)

# Printed separately so the embeddings are shown even if querying the size fails.
print(condition_embeddings)
print(condition_embeddings.size())

RuntimeError: stack expects a non-empty TensorList