In [7]:
import torch
import numpy as np
import os
import cma
from eval import play
from gan.config import SmallModelConfig
from gan.small_models import Generator
from gan.utils import tensor_to_level_str,check_playable
from gan.level_visualizer import LevelVisualizer
from gan.env import Env
from play_rl.wrappers import GridGame
from play_rl.policy import Policy
from PIL import Image

In [8]:
config = SmallModelConfig()
config.set_env()
config.cuda = False
if config.cuda:
    device = torch.device(
        "cuda" if torch.cuda.is_available else "cpu")
    print("device : cuda")
else:
    device = torch.device("cpu")
    print("device : cpu")
generator = Generator(
    out_dim=config.input_shape[0],
    shapes=config.model_shapes,
    z_shape=(config.latent_size,),
    filters=config.generator_filters,
).to(device)
env_def = Env(config.env_name,config.env_version)
level_visualizer = LevelVisualizer(env=env_def)

model_path = os.path.join(
    "/root/mnt/GVGAI-GAN/gan/checkpoints/none-866", "models_1160.tar")
load_model = torch.load(model_path)
generator.load_state_dict(load_model["generator"])
x = torch.randn(config.latent_size).to(device)


device : cpu


In [9]:
def eval(level_str):
    wall = 0
    enemy = 0
    for i, s in enumerate(level_str):
        for j, c in enumerate(s):
            if c == 'w':
                wall += 1
            if c in ['1', '2', '3']:
                enemy += 1
    return wall + enemy*10


env = GridGame(config.env_name, 200, env_def.state_shape)
actor = Policy(env.observation_space.shape, env.action_space)

def fitness(x: torch.Tensor):
    x = np.array(x)
    latent = torch.FloatTensor(x).view(1, -1, 1, 1)
    level = generator(latent)
    level_str = tensor_to_level_str('zelda', level)
    playable = check_playable(level_str[0])
    if not playable:
        return 100
    reward,step = play(level_str[0],env=env,actor=actor)
    # ev = eval(level_str)
    return -(reward*100+step)

def show(x):
    p_level = torch.nn.Softmax2d()(generator(x))
    level_strs = tensor_to_level_str(
        config.env_name, p_level)
    p_level_img = np.array(level_visualizer.draw_level(level_strs[0]))
    image = Image.fromarray(p_level_img)
    image.show()

es = cma.CMAEvolutionStrategy(x.tolist(), 0.2)
es.optimize(fitness)
best = np.array(es.best.get()[0])
print("INIT: ", np.array(x.tolist()))
print("BEST: ", best)
print("Fitness: ", fitness(best))

show(x)
show(torch.FloatTensor(best))


Connecting to host 127.0.0.1 at port 47197 ...
Client connected to server [OK]
(7_w,14)-aCMA-ES (mu_w=4.3,w_1=36%) in dimension 32 (seed=524868, Thu Aug 25 05:27:42 2022)
Iterat #Fevals   function value  axis ratio  sigma  min&max std  t[m:s]
    1     14 -2.900000000000000e+02 1.0e+00 1.86e-01  2e-01  2e-01 0:20.9
    2     28 -2.680000000000000e+02 1.1e+00 1.80e-01  2e-01  2e-01 0:42.7
    3     42 -1.150000000000000e+02 1.1e+00 1.75e-01  2e-01  2e-01 1:08.7
    4     56 -1.840000000000000e+02 1.1e+00 1.72e-01  2e-01  2e-01 1:35.4
    5     70 -1.760000000000000e+02 1.1e+00 1.71e-01  2e-01  2e-01 2:19.7
    6     84 -2.940000000000000e+02 1.1e+00 1.73e-01  2e-01  2e-01 3:01.4
    7     98 -1.910000000000000e+02 1.2e+00 1.78e-01  2e-01  2e-01 3:42.9
    8    112 -2.920000000000000e+02 1.2e+00 1.80e-01  2e-01  2e-01 4:24.4
    9    126 -1.720000000000000e+02 1.2e+00 1.82e-01  2e-01  2e-01 4:55.3
   10    140 -2.000000000000000e+02 1.3e+00 1.85e-01  2e-01  2e-01 5:19.7
   11    154 -1.4

  if(path is ''):
  if(path is ''):
  if(path is ''):
  if(path is ''):
  if(path is ''):
  if(path is ''):


KeyboardInterrupt: 