<a href="https://colab.research.google.com/github/achanhon/coursdeeplearningcolab/blob/master/rl_sample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install gymnasium
!pip install swig
!pip install "gymnasium[box2d]"

In [None]:
import random
import torch
import gymnasium

class StateSampler:
  """Sample state tensors from a replay memory with replacement.

  Every memory row is indexed positionally: field 0 supplies the sampling
  weight and field 2 is the tensor that gets returned. Field 0 is shifted
  by +1 before normalisation so that a row whose field 0 is zero can still
  be drawn (presumably field 0 is the step index inside its episode -
  TODO confirm against whatever fills the memory).
  """

  def __init__(self,memory):
    """Precompute the normalised sampling weights.

    Args:
      memory: non-empty sequence of rows; row[0] must be a non-negative
        number and row[2] a tensor (all row[2] tensors share one shape).

    Raises:
      ValueError: if `memory` is empty (the weights could not be
        normalised and `torch.multinomial` would fail later).
    """
    if len(memory) == 0:
      raise ValueError("StateSampler needs a non-empty memory")
    self.memory = memory
    w = torch.Tensor([row[0] for row in memory])+1
    self.w = w/float(w.sum())

  def get(self,n):
    """Return `n` sampled states stacked along a new leading dimension."""
    I = torch.multinomial(self.w, n, replacement=True)
    # A Python list cannot be indexed by a tensor: gather the rows one by one,
    # and read from the instance's own memory rather than a global name.
    return torch.stack([self.memory[i][2] for i in I.tolist()],dim=0)

class TransitionSampler:
  """Sample transitions from a replay memory with replacement.

  Rows are indexed positionally: field 1 supplies the sampling score and
  fields 2 onwards make up the transition that is returned. Sampling
  probabilities are softmax(0.1 * score), so rows with a higher score are
  drawn more often.
  """

  def __init__(self,memory):
    """Precompute the softmax sampling weights.

    Args:
      memory: non-empty sequence of rows; row[1] must be a number and
        row[2:] the transition fields (tensors or plain numbers).

    Raises:
      ValueError: if `memory` is empty.
    """
    if len(memory) == 0:
      raise ValueError("TransitionSampler needs a non-empty memory")
    self.memory = memory
    w = torch.Tensor([row[1] for row in memory])*0.1
    self.w = torch.nn.functional.softmax(w,dim=0)

  def get(self,n):
    """Return `n` sampled transitions as one stacked tensor per field.

    The result is a list with as many tensors as there are fields after
    the first two; element k has leading dimension `n` and stacks field
    k+2 of the sampled rows. Plain numbers are converted to float32
    tensors so that every field can be stacked and moved to a device.
    """
    I = torch.multinomial(self.w, n, replacement=True)
    # Read from the instance's own memory and gather rows one by one
    # (a Python list cannot be indexed by a tensor).
    picked = [self.memory[i][2:] for i in I.tolist()]
    return [
      torch.stack([torch.as_tensor(v,dtype=torch.float32) for v in column],dim=0)
      for column in zip(*picked)
    ]

def tokenf(f):
  """Encode one scalar as a 7-element float feature vector.

  Layout of the returned tensor:
    [0]   the raw value `f`
    [1]   its sign: 0 inside the dead zone [-0.001, 0.001], else -1 or +1
    [2:7] the five lowest bits, least significant first, of int(|f| * 32)

  Only five bits are kept, so magnitudes of 1.0 and above wrap around
  (e.g. |f| = 1.0 encodes as all-zero bits).

  Args:
    f: a Python float.

  Returns:
    A tensor of shape (7,).
  """
  out = torch.zeros(7)
  out[0]=f
  # out[1] is already 0, which is the value wanted inside the dead zone.
  if f>0.001:
    out[1]=1
  elif f<-0.001:
    out[1]=-1
  q = int(abs(f)*32)
  for i in range(5):
    out[i+2] = q%2
    q = q//2
  # Return the encoding, not the (by now exhausted) integer quotient.
  return out

def tokens(s):
  """Encode an 8-component observation as one flat feature vector.

  Components 0..5 are each expanded with `tokenf`; components 6 and 7 are
  appended unchanged as floats. The pieces are concatenated in order.

  Args:
    s: indexable observation with at least 8 numeric components.

  Returns:
    A 1-D tensor holding the six `tokenf` encodings followed by the two
    raw values.
  """
  # The per-scalar encoder defined in this file is `tokenf`; the former
  # call to `token` referred to a name that does not exist.
  out = [tokenf(float(s[i])) for i in range(6)]
  out = out+[torch.Tensor([float(s[6]),float(s[7])])]
  return torch.cat(out,dim=0)

def trial(env,agent,eps,seed=0,max_steps=3000):
  """Run one epsilon-greedy episode and record its transitions.

  The agent is switched to eval mode and moved to the CPU (both side
  effects persist after the call). At every step a uniformly random action
  among 4 is taken with probability `eps`, otherwise the action with the
  highest agent output.

  Args:
    env: environment exposing `reset(seed=...)` and `step(action)` with the
      5-tuple return unpacked below.
    agent: module mapping a (1, features) tensor to 4 action scores.
    eps: probability of taking a random action.
    seed: value forwarded to `env.reset`; defaults to the previously
      hard-coded 0.
    max_steps: step budget; defaults to the previously hard-coded 3000.

  Returns:
    `(totR, rows)`. `totR` is the summed reward. Each row is the list
    `[totR, s, a, r, s_, a_]` where `a` is the one-hot action taken, `a_`
    the one-hot action of the following step, and `a_` is all zeros on the
    last recorded step. If the step budget runs out before the episode
    ends, the partial episode is returned the same way instead of `None`.

  NOTE(review): rows here have 6 fields, while StateSampler and
  TransitionSampler in this file read field 1 as the score and field 2 as
  the state, i.e. they expect one extra leading field - confirm what the
  caller is meant to prepend.
  """
  s, _ = env.reset(seed=seed)
  totR,s,traj =0, tokens(s),[]
  agent.eval().cpu()
  # Acting needs no gradients; skip building the autograd graph.
  with torch.no_grad():
    for _ in range(max_steps):
      if random.random()<eps:
        a = int(random.random()*4)
      else:
        _,a = agent(s.view(1,-1)).max(1)
        a = int(a)

      s_, r, terminated, truncated, _ = env.step(a)
      s_ = tokens(s_)
      totR = totR+r
      traj.append((s,torch.eye(4)[a],r,s_))

      if terminated or truncated:
        break
      s = s_

  rows = [[totR,s,a,r,s_] for (s,a,r,s_) in traj]
  # Attach to each step the action chosen at the next step.
  for row,following in zip(rows,rows[1:]):
    row.append(following[2])
  if rows:
    rows[-1].append(torch.zeros(4))
  return totR,rows

def learnQ(agent,optimizer, batchS,batchT):
  """Do one optimisation step on a one-sided Bellman loss.

  The loss is mean(relu(Gamma*Q(s',a') + r - Q(s,a))**2) + 0.01*mean(Q):
  only transitions whose Q(s,a) is below its target are penalised, while
  the mean of all predicted Q-values (extra states included) is pulled
  down.

  Side effects: the agent is put in train mode and moved to the GPU when
  one is available (CPU otherwise); the optimizer takes one step.

  Args:
    agent: module mapping (batch, features) to (batch, actions) scores.
    optimizer: optimizer over the agent's parameters.
    batchS: extra states, shape (nS, features); they only take part in the
      forward pass and the regulariser.
    batchT: five tensors in the field order `trial` records - states,
      one-hot actions, rewards, next states, one-hot next actions. An
      all-zero next action makes the bootstrap term vanish, which is how
      the last step of an episode is encoded.
      NOTE(review): the previous code unpacked these as S,R,A,S_,R_ and
      never used the actions; confirm this field order with the sampler.

  Returns:
    The loss value as a Python float.
  """
  Gamma=0.999
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  # trial() leaves the agent in eval mode on the CPU; undo both so that
  # batch norm uses batch statistics and all tensors share one device.
  agent.train().to(device)
  S,A,R,S_,A_ = [obj.to(device) for obj in batchT]
  batchS = batchS.to(device)
  nS,nT = batchS.shape[0],S.shape[0]

  X = torch.cat([batchS,S,S_],dim=0)
  Q = agent(X)# group state for batch norm

  reg = Q.mean()

  # Rows of Q follow the concatenation order: extra states, then S, then S_.
  QS,QS_ = Q[nS:nS+nT],Q[nS+nT:]
  # The one-hot masks pick the value of the action actually taken.
  QA = (QS*A).sum(dim=1)
  QA_ = (QS_*A_).sum(dim=1)
  bellman = torch.nn.functional.relu(Gamma*QA_+R.reshape(-1)-QA)
  bellman = (bellman**2).mean()

  loss = bellman+0.01*reg
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  return float(loss)

def train(agent, optimizer,memory,n):
  """Run `n` rounds of 20 `learnQ` steps on batches drawn from `memory`.

  Both samplers are built once from `memory`; every step draws 256 extra
  states and 256 transitions. The mean loss of each round is printed.

  Args:
    agent: the Q-network being trained.
    optimizer: optimizer over the agent's parameters.
    memory: replay rows accepted by StateSampler and TransitionSampler.
    n: number of rounds.

  Returns:
    The mean loss of the last round, or None when `n` is 0.
  """
  ss = StateSampler(memory)
  ts = TransitionSampler(memory)
  mean_loss = None
  for _ in range(n):
    losses = []
    for _ in range(20):
      batchS = ss.get(256)
      batchT = ts.get(256)
      losses.append(learnQ(agent,optimizer,batchS,batchT))
    mean_loss = sum(losses)/20
    print(mean_loss)
  return mean_loss

In [None]:
class ScaledLinear(torch.nn.Module):
  """Bias-free linear layer followed by 1-D batch norm and a ReLU."""

  def __init__(self, Din, Dout):
    """Create the Din -> Dout projection and the batch norm over Dout."""
    super().__init__()
    # The linear layer carries no bias of its own.
    self.l = torch.nn.Linear(Din, Dout, bias=False)
    self.bn = torch.nn.BatchNorm1d(Dout)

  def forward(self, x):
    """Apply linear -> batch norm -> ReLU to `x`."""
    projected = self.l(x)
    normalised = self.bn(projected)
    return torch.nn.functional.relu(normalised)

class LunarAgent(torch.nn.Module):
  """Network mapping a 44-feature input to 4 output scores.

  Five ScaledLinear stages are followed by a plain linear head. The raw
  input is concatenated back in before the third and the fifth stage.
  """

  def __init__(self):
    """Build the five hidden stages and the 4-way output head."""
    super().__init__()
    d_in = 44

    self.l1 = ScaledLinear(d_in,256)
    self.l2 = ScaledLinear(256,128)
    # Stages 3 and 5 also receive the raw input, hence the wider fan-in.
    self.l3 = ScaledLinear(d_in+128,256)
    self.l4 = ScaledLinear(256,256)
    self.l5 = ScaledLinear(d_in+256,512)
    self.final =torch.nn.Linear(512,4)

  def forward(self,x):
    """Return the (batch, 4) scores for a (batch, 44) input `x`."""
    h = self.l2(self.l1(x))
    h = self.l4(self.l3(torch.cat([x,h],dim=1)))
    h = self.l5(torch.cat([x,h],dim=1))
    return self.final(h)


In [None]:
env = gymnasium.make("LunarLander-v3", continuous=False,enable_wind=False)


def _mean_return(env, agent, eps, episodes):
    """Average the total reward of `episodes` runs of trial()."""
    return sum(trial(env,agent,eps)[0] for _ in range(episodes))/episodes


def _collect(env, agent, eps, memory):
    """Run one episode, append its rows to `memory`, return its total reward.

    The samplers read field 1 as the score and field 2 as the state, so the
    step index is prepended to every row that trial() returns.
    """
    totR,traj = trial(env,agent,eps)
    memory.extend((t,*row) for t,row in enumerate(traj))
    return totR


T = 0.3
agent = LunarAgent()
# NOTE(review): no optimizer was ever created although train() needs one;
# Adam with default settings is a placeholder - confirm the intended choice.
optimizer = torch.optim.Adam(agent.parameters())

# SEUIL is the threshold (mean return) that an iteration has to beat.
SEUIL = _mean_return(env,agent,T,100)

for j in range(20):
    # A fresh replay memory per outer iteration, filled with 5+10+35 episodes.
    # The per-episode seeds of the earlier draft are not forwarded here -
    # TODO confirm whether distinct seeds are wanted.
    memory = []
    for _ in range(5+10+35):
        _collect(env,agent,T,memory)

    for _ in range(10+j*2):
        v = _collect(env,agent,T,memory)
        # trial() leaves the agent in eval mode on the CPU; restore the
        # training setup before optimising.
        if torch.cuda.is_available():
            agent.cuda()
        agent.train()
        l = train(agent,optimizer,memory,100+10*j)
        print("\t",v,l)

    tot = _mean_return(env,agent,T,100)
    print(j,T,tot)
    if SEUIL<tot:
        # NOTE(review): T is used as the random-action probability in
        # trial(); growing it by 20% on every improvement can push it above
        # 1 - confirm this schedule is intended.
        T = 1.2*T
        SEUIL = tot
