In [2]:
import torch
from procgen import ProcgenGym3Env
from torchinfo import summary
import core

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Training objects are created by later cells; loadAll/saveAll operate on these globals.
model = player = ppo = env = None
envKW = {}

# Checkpoint root. Run directories are built as modelPath + <run name>, so the
# trailing slash is required.
modelPath = "models/"
def loadAll(fname, loadEnv=True):
    """Restore a checkpoint written by saveAll from the directory ``modelPath + fname``.

    Loads the state dicts of the notebook-level ``model``, ``player`` and ``ppo``.
    When ``loadEnv`` is true, the saved environment kwargs and environment states
    are restored as well; otherwise ``player.reset()`` is called instead.

    Args:
        fname: name of the checkpoint sub-directory under ``modelPath``.
        loadEnv: whether to restore the environment (kwargs + per-env state).
    """
    # Without this declaration the assignments below create function locals and
    # the notebook's env/envKW silently keep their old values.
    global env, envKW
    prefix = modelPath + fname
    model.load_state_dict(torch.load(prefix + "/model.pth"))
    player.load_state_dict(torch.load(prefix + "/player.pth"))
    ppo.load_state_dict(torch.load(prefix + "/ppo.pth"))
    if loadEnv:
        savedKW = torch.load(prefix + "/envKW.pth")
        if env is None or savedKW != envKW:
            # Different configuration: a new environment is required.
            envKW = savedKW
            env = ProcgenGym3Env(**envKW)
            # NOTE(review): player and ppo were constructed with the previous env
            # object; confirm whether they must be re-pointed at this new env.
        # Same configuration: restore into the existing env so the env that
        # player/ppo already hold is the one that gets the saved state.
        env.callmethod("set_state", torch.load(prefix + "/env_states.pth"))
    else:
        player.reset()

def saveAll(fname):
    """Write a full training checkpoint into the directory ``modelPath + fname``.

    The directory is created if missing. Files are written in this order:
    model.pth, player.pth and ppo.pth (state dicts), envKW.pth (environment
    constructor kwargs), env_states.pth (result of the env's "get_state"
    method) and stats.pth (``ppo.all_stats``).
    """
    import os

    target_dir = modelPath + fname
    os.makedirs(target_dir, exist_ok=True)
    # Payloads are produced lazily so each one is captured immediately before
    # its own save, exactly like a straight-line sequence of torch.save calls.
    payloads = [
        ("model.pth", lambda: model.state_dict()),
        ("player.pth", lambda: player.state_dict()),
        ("ppo.pth", lambda: ppo.state_dict()),
        ("envKW.pth", lambda: envKW),
        ("env_states.pth", lambda: env.callmethod("get_state")),
        ("stats.pth", lambda: ppo.all_stats),
    ]
    for file_name, produce in payloads:
        torch.save(produce(), target_dir + "/" + file_name)


cuda:0


In [3]:
# num_models models, each paired with num_agents parallel environments; the
# combined count is passed to getKW as `num` (presumably the number of envs).
num_models = 2
num_agents = 16
envKW = core.getKW(
    num=num_models * num_agents,
    env_name="coinrun",
    distribution_mode="easy",
    paint_vel_info=True,
    use_backgrounds=False,
    restrict_themes=True,
)
env = ProcgenGym3Env(**envKW)
# Show the observation and action spaces the model has to match.
print(env.ob_space)
print(env.ac_space)

Dict(rgb=D256[64,64,3])
D15[]


In [4]:
from CVModels import CNNAgent, ViTValue, VectorModelValue
# Per-model backbone: a small Vision Transformer with a value head.
model = ViTValue(
    depth=4,
    num_heads=4,
    embed_dim=32,
    mlp_ratio=4,
    valueHeadLayers=1,
).to(device)
# Smaller ViT variant that was also tried:
# model = ViTValue(depth=3, num_heads=4, embed_dim=16, mlp_ratio=4, valueHeadLayers=1)

# Wrap into num_models copies (the summary lists one ViTValue per model).
model = VectorModelValue(model, n=num_models).to(device)
# CNN alternative:
# model = CNNAgent([64, 64, 3], 15, channels=16, layers=[1,1,1,1], scale=[1,1,1,1], vheadLayers=1).to(device)

model.train()
# NOTE(review): input appears to be (num_models, batch, C, H, W) - confirm against VectorModelValue.
summary(model, input_size=(2, 2, 3, 64, 64))

Layer (type:depth-idx)                             Output Shape              Param #
VectorModelValue                                   [2, 2, 15]                --
├─ModuleList: 1-1                                  --                        --
│    └─ViTValue: 2-1                               [2, 15]                   --
│    │    └─VisionTransformer: 3-1                 --                        61,199
│    │    └─ValueHead: 3-2                         [2, 1]                    33
│    └─ViTValue: 2-2                               [2, 15]                   --
│    │    └─VisionTransformer: 3-3                 --                        61,199
│    │    └─ValueHead: 3-4                         [2, 1]                    33
Total params: 122,464
Trainable params: 122,464
Non-trainable params: 0
Total mult-adds (M): 1.81
Input size (MB): 0.20
Forward/backward pass size (MB): 12.11
Params size (MB): 0.42
Estimated Total Size (MB): 12.73

In [5]:
# model.load_state_dict(torch.load(modelPath + "vitNegT8BigFin" + "/model.pth"))

In [6]:
from PPO import VectorPPO
from ProcgenPlayer import VectorPlayer

# Reward shaping passed to VectorPlayer.
# NOTE(review): 10.0 presumably mirrors the CoinRun coin reward - confirm.
# With rewardScale = 8.0 this gives terminateReward = -0.25.
rewardScale = 8.0
terminateReward = 1 - 10.0 / rewardScale
livingReward = 0
gamma = 0.99  # discount factor shared by the sanity print and VectorPPO

# Discounted value of surviving forever is livingReward / (1 - gamma).
# If terminateReward > discountedSumLiving the agent will prefer to run into obstacles.
print("terminateReward", terminateReward, "livingReward", livingReward, "discountedSumLiving", livingReward / (1 - gamma))

# Pass the livingReward variable (not a literal) so the printed value and the
# value the player actually uses cannot drift apart.
player = VectorPlayer(
    env,
    num_agents=num_agents,
    num_models=num_models,
    epsilon=0.01,
    epsilon_decay=0.99,
    rewardScale=rewardScale,
    livingReward=livingReward,
    terminateReward=terminateReward,
)
ppo = VectorPPO(
    model,
    env,
    num_agents=num_agents,
    num_models=num_models,
    player=player,
    gamma=gamma,
    weight_decay=0.0,
    warmup_steps=10,
    train_steps=1000,
)

terminateReward -0.25 livingReward 0 discountedSumLiving 0.0




In [7]:
# ppo.runGame()
# loss = ppo.train(debug=True)
# print(loss)
# import torchviz
# torchviz.make_dot(torch.sum(loss), params=dict(model.named_parameters()))

In [8]:
# loadAll("vitNegT8BigFin")

In [9]:
# Training loop: every iteration collects rollouts (ppo.runGame) and runs one
# PPO update (ppo.train). Every 10th iteration the newest entry of
# ppo.all_stats is printed. To inspect every metric, iterate over
# ppo.all_stats[-1].items().
for i in range(200):
    ppo.runGame()
    ppo.train()
    if i % 10 != 0:
        continue
    latest = ppo.all_stats[-1]
    print(
        f"episodeLength {latest['game/episodeLength']} episodeReward {latest['game/episodeReward']} "
        f"\nepoch {latest['epoch']} steps {latest['steps']} "
        f"\nloss {latest['ppo/loss/total']} policy {latest['ppo/loss/policy']} "
        f"\nvalue {latest['ppo/loss/value']} entropy {latest['ppo/policy/entropy']} "
        f"\nstale {latest['game/staleSteps']}              "
    )

episodeLength [245.3859649122807, 281.6842105263158] episodeReward [-0.03508771929824561, 0.0043859649122807015] 
epoch 0 steps 4096 
loss tensor([0.0078, 0.0067], dtype=torch.float64) policy tensor([ 9.0856e-17, -6.3068e-16], dtype=torch.float64) 
value tensor([0.0078, 0.0067], dtype=torch.float64) entropy tensor([2.7060, 2.7061]) 
stale 0              
episodeLength [172.94117647058823, 294.6470588235294] episodeReward [0.08823529411764706, 0.10294117647058823] 
epoch 10 steps 45056 
loss tensor([0.0006, 0.0026], dtype=torch.float64) policy tensor([-0.0006,  0.0018], dtype=torch.float64) 
value tensor([0.0012, 0.0008], dtype=torch.float64) entropy tensor([2.6895, 2.6932]) 
stale 959              
episodeLength [123.23529411764706, 225.52941176470588] episodeReward [0.04411764705882353, 0.18382352941176472] 
epoch 20 steps 86016 
loss tensor([0.0035, 0.0012], dtype=torch.float64) policy tensor([0.0028, 0.0008], dtype=torch.float64) 
value tensor([0.0006, 0.0004], dtype=torch.float64) 

KeyboardInterrupt: 

In [None]:
# saveAll("vitNegT8BigFin400")