In [1]:
import sys
import logging.config
import yaml
import importlib

from os.path import join as path_join
from IPython.display import display, Markdown

# Make the project sources under ../algos importable; the membership guard keeps
# repeated runs of this cell from stacking duplicate sys.path entries.
src_path = path_join("..", "algos")
if src_path not in sys.path:
    sys.path.insert(0, src_path)

# Configure logging from a YAML file resolved relative to the current working directory.
# NOTE(review): yaml.FullLoader is more permissive than a plain logging config needs;
# yaml.safe_load would suffice if the file uses no Python-specific tags — confirm and switch.
with open("logging_config.yml") as f:
    logging_config = yaml.load(f, Loader=yaml.FullLoader)
logging.config.dictConfig(logging_config)
logger = logging.getLogger("Jupyter")

def display_md(content: str) -> None:
    """Render ``content`` as Markdown in the notebook's output area.

    Args:
        content: Markdown source text to render.
    """
    rendered = Markdown(content)
    display(rendered)

In [2]:
from snek import base, snek1d

# Reload so edits to the snek sources take effect without restarting the kernel.
# base is reloaded before snek1d — presumably snek1d builds on it (TODO confirm).
for _stale_module in (base, snek1d):
    importlib.reload(_stale_module)

import math
import numpy as np
import torch
import pandas as pd
import random

from algo_battle.domain import Richtung, FeldZustand, ArenaDefinition

In [3]:
def generate_state(length: int) -> snek1d.Snek1DState:
    """Build a ``Snek1DState`` from ``length`` randomly generated movements.

    Each movement is built from a tuple of four random integers in [0, 100]
    (``random.randint`` is inclusive), a direction from ``Richtung.zufall()``,
    and a field state chosen from ``base.field_states``.

    The integers and field state come from the global ``random`` module, so they
    are reproducible only if that module has been seeded.
    NOTE(review): ``Richtung.zufall()``'s source of randomness isn't visible here.

    Args:
        length: Number of movements in the generated state.

    Returns:
        A new ``snek1d.Snek1DState`` wrapping the list of movements.
    """
    def random_movement() -> snek1d.Movement:
        # Arguments are evaluated left to right: four integers, then direction, then field state.
        numbers = tuple(random.randint(0, 100) for _ in range(4))
        return snek1d.Movement(numbers, Richtung.zufall(), random.choice(base.field_states))

    return snek1d.Snek1DState([random_movement() for _ in range(length)])

In [4]:
# Seed the RNGs before anything random happens so Restart & Run All reproduces the
# outputs below. `random` drives generate_state's randint/choice calls, and torch
# presumably drives Snek1DModel's weight initialisation.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

kernel_size = 10
# NOTE(review): presumably one output per movement direction — confirm against Snek1DModel.
out_features = 4

# The first argument, Movement.size(), is presumably the per-movement feature count (TODO confirm).
model = snek1d.Snek1DModel(snek1d.Movement.size(), kernel_size, out_features)

In [5]:
state = generate_state(128)
# NOTE(review): the meaning of the `None` argument to as_tensor isn't visible here — confirm.
state_tensor = state.as_tensor(None)
prediction = model(state_tensor)
# FIXME(review): for a single state the model returns all ones with a SoftmaxBackward
# grad_fn. A softmax only yields 1.0 everywhere when taken over an axis of size 1, which
# suggests Snek1DModel normalises over the batch axis (dim 0) instead of the output features.
prediction

tensor([[1., 1., 1., 1.]], grad_fn=<SoftmaxBackward>)


In [6]:
# Two independently generated states joined along dim 0; the model output has one row per state.
state_batch = torch.cat([generate_state(128).as_tensor(None) for _ in range(2)])
# NOTE(review): in the saved output each column sums to 1 across the two rows, i.e. the
# softmax normalises over the batch axis rather than per sample — likely a bug in Snek1DModel.
model(state_batch)

tensor([[0.5724, 0.5887, 0.4634, 0.4763],
        [0.4276, 0.4113, 0.5366, 0.5237]], grad_fn=<SoftmaxBackward>)