<a href="https://colab.research.google.com/github/achanhon/coursdeeplearningcolab/blob/master/rl_sample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
!pip install gymnasium
!pip install swig
!pip install "gymnasium[box2d]"
!pip install schedulefree



In [9]:
import torch
import gymnasium
import schedulefree

# Discount factor: multiplies the bootstrapped next-state value in train()'s squared TD error.
GAMMA=0.99

class MemoryBuffer:
    """Fixed-capacity ring buffer of transitions stored in preallocated tensors.

    Each slot holds a state, a one-hot encoded action, a reward, the next
    state and a "continue" flag equal to ``1 - f`` (so it is 0 when the
    transition ended the episode and 1 otherwise).

    Args:
        capacity: number of transitions kept before old ones are overwritten.
        state_dim: length of a state vector.
        n_actions: number of discrete actions (width of the one-hot encoding).

    Raises:
        ValueError: if ``capacity`` is not strictly positive.
    """

    def __init__(self, capacity=50000, state_dim=8, n_actions=4):
        if capacity <= 0:
            raise ValueError("capacity must be a positive integer")

        self.i = 0  # index of the next slot to write
        self.full = False  # becomes True once the write index has wrapped

        self.s = torch.zeros(capacity, state_dim)
        self.a = torch.zeros(capacity, n_actions)
        self.r = torch.zeros(capacity)
        self.s_ = torch.zeros(capacity, state_dim)
        self.f = torch.zeros(capacity)

    def push(self, s, a, r, s_, f):
        """Store one transition, overwriting the oldest one when full.

        Args:
            s: state before the action.
            a: integer index of the action taken.
            r: reward received.
            s_: state after the action.
            f: truthy when the episode ended on this transition.
        """
        slot = self.i
        self.s[slot] = s
        # Clear the slot before setting the one-hot bit: after a wrap-around
        # it still holds the action of the transition being overwritten.
        self.a[slot].zero_()
        self.a[slot, a] = 1
        self.r[slot] = r
        self.s_[slot] = s_
        self.f[slot] = 1 - f

        self.i += 1
        if self.i >= self.r.shape[0]:
            self.full = True
            self.i = 0

    def getBatch(self, B=64):
        """Draw ``B`` stored transitions uniformly at random, with replacement.

        Only slots that have been written are sampled. If nothing has been
        pushed yet every index is 0, so the zero-initialised slot 0 is
        returned ``B`` times.

        Returns:
            Tuple ``(s, a, r, s_, f)`` of tensors whose first dimension is B.
        """
        filled = self.r.shape[0] if self.full else self.i
        idx = (torch.rand(B) * filled).long()
        return (self.s[idx], self.a[idx], self.r[idx], self.s_[idx], self.f[idx])

def trial(env, agent, T, memory, max_steps=1000):
    """Play one episode, recording every transition, and return its total reward.

    Args:
        env: environment exposing ``reset()`` -> ``(obs, info)`` and
            ``step(a)`` -> ``(obs, reward, terminated, truncated, info)``.
        agent: object whose ``sample(state, T)`` returns an action.
        T: value forwarded unchanged to ``agent.sample``.
        memory: buffer whose ``push(s, a, r, s_, done)`` stores a transition.
        max_steps: upper bound on the number of environment steps.

    Returns:
        Sum of the rewards collected. This is returned even when the step
        limit is reached before the environment reports the episode as over.
    """
    totalR = 0
    s, info = env.reset()
    s = torch.tensor(s, dtype=torch.float32)
    for _ in range(max_steps):
        a = agent.sample(s, T)
        s_, r, terminated, truncated, info = env.step(a)
        s_ = torch.tensor(s_, dtype=torch.float32)
        done = terminated or truncated

        memory.push(s, a, r, s_, done)
        totalR += r
        if done:
            break
        s = s_
    # Reached on both exits of the loop, so callers that accumulate the
    # result never receive None.
    return totalR

def train(agent, T, memory, nbstep):
    """Run ``nbstep`` gradient steps on the squared TD error and return the mean loss.

    Each step draws a batch from ``memory`` and minimises
    ``(GAMMA * V(s') * F + R - Q(s, a))**2``, where ``F`` is the flag stored
    by the buffer and ``Q(s, a)`` is selected with the one-hot action.

    Args:
        agent: module exposing ``Q(states)`` and ``V(states, T)``.
        T: value forwarded unchanged to ``agent.V``.
        memory: buffer whose ``getBatch()`` returns ``(S, A, R, S_, F)``.
        nbstep: number of optimisation steps.

    Returns:
        Mean of the per-step losses as a Python float.
    """
    # A fresh optimizer is created on every call, so its state is not
    # carried over between calls.
    optimizer = schedulefree.AdamWScheduleFree(agent.parameters(), lr=0.001)
    # Schedule-free optimizers must be switched to train mode before stepping.
    optimizer.train()

    meanloss = torch.zeros(nbstep)
    for step in range(nbstep):
        S, A, R, S_, F = memory.getBatch()

        QA = (agent.Q(S) * A).sum(1)
        # NOTE(review): the bootstrap term is not detached, so gradients also
        # flow through V(S_) - confirm this is intended.
        Q_ = agent.V(S_, T)
        loss = ((GAMMA * Q_ * F + R - QA) ** 2).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        meanloss[step] = loss.detach()

    # Switch to eval mode so the weights left in the agent are the ones the
    # optimizer intends for use outside of training.
    optimizer.eval()
    return float(meanloss.mean())


In [10]:
def leakyRelu(x):
    """Leaky ReLU with slope 0.2 on the negative side.

    Positive inputs pass through unchanged and negative inputs are scaled by
    0.2. (Taking the element-wise minimum of ``x`` and ``0.2 * x`` does the
    opposite: it shrinks positives and keeps negatives.)

    Args:
        x: input tensor.

    Returns:
        Tensor of the same shape as ``x``.
    """
    return torch.nn.functional.leaky_relu(x, negative_slope=0.2)

class Block(torch.nn.Module):
    """Residual block on 24-wide features that only updates the last 16.

    Three linear layers (24 -> 8 -> 24 -> 16), each followed by
    ``leakyRelu``, produce a 16-wide update. The update is zero-padded on
    the left to width 24 and added to the input, so the first 8 input
    features pass through unchanged.
    """

    def __init__(self):
        super().__init__()

        self.f1 = torch.nn.Linear(24, 8)
        self.f2 = torch.nn.Linear(8, 24)
        self.f3 = torch.nn.Linear(24, 16)

    def forward(self, x):
        """Return ``x`` plus the padded update; ``x`` has 24 features on its last dim."""
        f = leakyRelu(self.f1(x))
        f = leakyRelu(self.f2(f))
        f = leakyRelu(self.f3(f))

        # Pad 8 zeros in front of the last dimension. Padding the activation
        # itself keeps the zeros on the same device and dtype as the data,
        # unlike a separately allocated torch.zeros tensor.
        f = torch.nn.functional.pad(f, (8, 0))
        return x + f


class LunarAgent(torch.nn.Module):
    """Q-network with a softmax policy over 4 actions.

    An 8-feature state is zero-padded to 24 features, passed through seven
    ``Block`` modules and projected to 4 action values by a linear layer.
    """

    def __init__(self):
        super().__init__()

        # Kept as individual attributes so parameter names stay b1..b7 / A.
        self.b1 = Block()
        self.b2 = Block()
        self.b3 = Block()
        self.b4 = Block()
        self.b5 = Block()
        self.b6 = Block()
        self.b7 = Block()

        self.A = torch.nn.Linear(24, 4)

    def forward(self, x):
        """Map a batch of states of shape (batch, 8) to action values (batch, 4)."""
        # Append 16 zero features on the right. Padding the input itself keeps
        # the zeros on the same device and dtype as the data.
        x = torch.nn.functional.pad(x, (0, 16))

        for block in (self.b1, self.b2, self.b3, self.b4, self.b5, self.b6, self.b7):
            x = block(x)

        return self.A(x)

    def Q(self, x):
        """Action values for a batch of states (alias of ``forward``)."""
        return self.forward(x)

    def pi(self, Q, T):
        """Softmax over ``T * Q`` along dim 1; a larger ``T`` gives a sharper policy."""
        return torch.nn.functional.softmax(T * Q, dim=1)

    def V(self, x, T):
        """Expected action value under the softmax policy, one value per state."""
        Q = self.Q(x)
        pi = self.pi(Q, T)
        return (Q * pi).sum(1)

    def sample(self, x, T):
        """Sample one action index for a single state, without tracking gradients."""
        with torch.no_grad():
            pi = self.pi(self.Q(x.view(1, -1)), T)
            return int(torch.multinomial(pi, num_samples=1))

In [11]:
# Training driver: alternate data collection and optimisation while T,
# the factor multiplying the Q-values inside the softmax policy, grows.
env = gymnasium.make("LunarLander-v3")

T = 0.25
agent = LunarAgent()
# Create the buffer before the loop: the reset below only happens once
# T > 1.5, so the early rounds would otherwise use an undefined `memory`.
memory = MemoryBuffer()

for j in range(10):
    T = 1.3 * T
    if T > 1.5:
        # Start from a fresh buffer and refill it with 30 episodes played
        # by the current agent.
        memory = MemoryBuffer()
        for _ in range(30):
            trial(env, agent, T, memory)

    # One episode followed by a burst of gradient steps; both the number of
    # rounds and the steps per round grow with j.
    for i in range(20 + j * 5):
        v = trial(env, agent, T, memory)
        loss = train(agent, T, memory, nbstep=100 + 20 * j)
        if i % 5 == 0:
            print("\t", v, loss)

    # Report the average return over 30 episodes at the current T.
    tot = 0
    for _ in range(30):
        tot += trial(env, agent, T, memory)
    print(T, tot / 30)


	 -211.9193267100821 140.1090087890625
	 -347.62495157637306 15.792109489440918
	 -81.39771774836467 10.606745719909668
	 -361.10400774616386 9.039100646972656
0.35 -143.03772432824124
	 -267.6278156205011 16.52506446838379
	 -61.314386529326846 9.003512382507324
	 -119.36474019174743 6.102186679840088
	 -117.87790837391309 5.420865058898926
	 -96.81846960902439 6.345512390136719
0.48999999999999994 -111.75985960088568
	 -219.13100086167472 10.215813636779785
	 -57.59028893787283 7.810074806213379
	 -92.70094242732058 6.986778736114502
	 -76.13936443465332 7.206244945526123
	 -80.36913743460899 6.9672932624816895
	 -87.0052326514288 5.021262168884277
0.6859999999999998 -94.8553551446476
	 -63.96503796850151 5.250551700592041
	 -90.19385198202752 4.089931964874268
	 -85.80880746247459 3.295361042022705
	 -80.34869231495858 3.590965509414673
	 -92.27028194398527 2.588238477706909
	 -48.425613225193715 2.8561806678771973
	 -88.20360997034535 3.207838535308838
0.9603999999999997 -63.671461

KeyboardInterrupt: 