In [None]:
import mingpt
from mingpt.prisonerTrainer import PrisonerTrainer
from mingpt.model import GPT
from torchinfo import summary

In [None]:
# Build a tiny GPT ('gpt-pico') whose vocabulary has only two tokens — presumably the
# two prisoner's-dilemma actions (TODO confirm against PrisonerTrainer) — trained with PPO.
model_config = GPT.get_default_config()
_model_overrides = {
    "model_type": "gpt-pico",
    "vocab_size": 2,
    "block_size": 128,
    "alg_name": "ppo",
}
for _field, _value in _model_overrides.items():
    setattr(model_config, _field, _value)
model = GPT(model_config)
summary(model)

In [None]:
# Prisoner's dilemma self-play trainer; the algorithm name is taken from the model config.
# `gamma` is the discount factor. equilibriumDiscount(startCoop=...) reports a threshold
# discount after a cooperate (True) / defect (False) opening — presumably the minimum
# gamma at which cooperation is an equilibrium (TODO confirm in PrisonerTrainer).
train_config = PrisonerTrainer.get_default_config()
train_config.learning_rate = 5e-4  # other tunables: see PrisonerTrainer.get_default_config()
train_config.max_iters = 500
train_config.gamma = 0.50
train_config.alg_name = model_config.alg_name
trainer = PrisonerTrainer(train_config, model)

# Query each threshold once and reuse it rather than recomputing it for every print.
gamma_thresh_after_c = trainer.equilibriumDiscount(startCoop=True)
gamma_thresh_after_d = trainer.equilibriumDiscount(startCoop=False)
print(gamma_thresh_after_c, gamma_thresh_after_d)
# Threshold averaged over an equally weighted cooperate/defect opening.
print(0.5 * gamma_thresh_after_c + 0.5 * gamma_thresh_after_d)
trainer.run()

In [None]:
import json

# Load the training statistics from rewStats.json (presumably written by trainer.run()
# above — TODO confirm). The file is expected to hold exactly four top-level items,
# unpacked positionally; loss_list is loaded but not printed here.
with open("rewStats.json") as stats_file:
    rew_stats = json.load(stats_file)
iter_list, rew_dict, avg_rets, loss_list = rew_stats
print(iter_list, rew_dict, avg_rets)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Scatter each reward series in rew_dict against the training iterations in iter_list,
# then save the figure as "<alg_name>.jpg" and display it.
alg_name = train_config.alg_name
if alg_name == "reject":
    alg_name = "rejection sampling"  # human-readable name for the title and file name

# Threshold discounts after a cooperate / defect opening, computed once for the legend.
gamma_thresh_after_c = trainer.equilibriumDiscount(startCoop=True)
gamma_thresh_after_d = trainer.equilibriumDiscount(startCoop=False)

fig, ax = plt.subplots()
title = f"self play prisoner's dilemma with {alg_name}"
ax.set_title(title)
ax.set_xlabel("iteration")
ax.set_ylabel("reward")
for key, rewards in rew_dict.items():
    ax.scatter(iter_list, rewards, label=key)
ax.legend(
    loc="center left",
    bbox_to_anchor=(1, 0.5),  # anchor the legend just outside the right edge of the axes
    title=(
        f"g: {train_config.gamma}\n"
        f"g thresh after c: {gamma_thresh_after_c}\n"
        f"g thresh after d: {gamma_thresh_after_d}"
    ),
)
# bbox_inches='tight' keeps the outside-the-axes legend inside the saved image.
fig.savefig(alg_name + ".jpg", bbox_inches="tight")
plt.show()