<a href="https://colab.research.google.com/github/achanhon/coursdeeplearningcolab/blob/master/rl_sample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install gymnasium
!pip install swig
!pip install "gymnasium[box2d]"
!pip install schedulefree

Collecting gymnasium
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium
Successfully installed farama-notifications-0.0.4 gymnasium-1.0.0
Collecting swig
  Downloading swig-4.2.1-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl.metadata (3.6 kB)
Downloading swig-4.2.1-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m25.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: swig
Successfully installed swig-4.2.1
Collecti

In [2]:
import torch
import gymnasium
import schedulefree

# Discount factor: multiplies the bootstrapped next-state value in train()'s loss.
GAMMA=0.99

class MemoryBuffer:
    """Fixed-capacity ring buffer of transitions ``(s, a, r, s_, f)``.

    Actions are stored one-hot and the end-of-episode flag is stored inverted
    (``1 - f``), i.e. as a mask that is 0 on the last transition of an episode
    and 1 otherwise. Once the buffer is full the oldest slots are overwritten.

    Args:
        capacity: number of transitions kept (default 10000).
        state_dim: length of a state vector (default 8).
        n_actions: number of discrete actions (default 4).
    """

    def __init__(self, capacity=10000, state_dim=8, n_actions=4):
        if capacity < 1:
            raise ValueError("capacity must be at least 1")

        self.i = 0          # index of the next slot to write
        self.full = False   # becomes True the first time the cursor wraps

        self.s = torch.zeros(capacity, state_dim)
        self.a = torch.zeros(capacity, n_actions)
        self.r = torch.zeros(capacity)
        self.s_ = torch.zeros(capacity, state_dim)
        self.f = torch.zeros(capacity)

    def push(self, s, a, r, s_, f):
        """Store one transition at the write cursor and advance it.

        Args:
            s: state tensor before the action.
            a: integer index of the action taken.
            r: scalar reward.
            s_: state tensor after the action.
            f: truthy when the transition ends the episode.
        """
        slot = self.i
        self.s[slot] = s
        # Clear the row first: after a wrap-around the slot still holds the
        # one-hot of an older action and would otherwise end up with several 1s.
        self.a[slot] = 0
        self.a[slot, a] = 1
        self.r[slot] = r
        self.s_[slot] = s_
        self.f[slot] = 1 - f

        self.i += 1
        if self.i >= self.r.shape[0]:
            self.full = True
            self.i = 0

    def getBatch(self, B=64):
        """Sample ``B`` stored transitions uniformly, with replacement.

        Returns:
            Tuple ``(s, a, r, s_, f)`` of tensors whose first dimension is ``B``.

        Raises:
            ValueError: if nothing has been pushed yet.
        """
        filled = self.r.shape[0] if self.full else self.i
        if filled == 0:
            raise ValueError("cannot sample from an empty MemoryBuffer")
        # A LongTensor index (rather than a Python list of 0-d tensors) is
        # interpreted the same way for every batch size.
        idx = torch.randint(filled, (B,))
        return (self.s[idx], self.a[idx], self.r[idx], self.s_[idx], self.f[idx])

def trial(env, agent, T, memory, max_steps=1000):
    """Run one episode, storing every transition in ``memory``.

    Args:
        env: environment exposing the Gymnasium ``reset()`` / ``step(a)`` API.
        agent: object with ``sample(state, T)`` returning an action index.
        T: value forwarded unchanged to ``agent.sample``.
        memory: object with ``push(s, a, r, s_, f)``.
        max_steps: hard cap on the number of environment steps (default 1000).

    Returns:
        The sum of the rewards collected. A number is returned even when the
        step cap is reached before the environment ends the episode.
    """
    totalR = 0
    s, info = env.reset()
    s = torch.tensor(s, dtype=torch.float32)
    for _ in range(max_steps):
        a = agent.sample(s, T)
        s_, r, terminated, truncated, info = env.step(a)
        s_ = torch.tensor(s_, dtype=torch.float32)

        # Only a genuine termination is stored as "episode over": a time-limit
        # truncation does not make s_ a terminal state, so its value should
        # still be bootstrapped by the learner.
        memory.push(s, a, r, s_, terminated)
        totalR += r
        if terminated or truncated:
            break
        s = s_
    return totalR

def train(agent,T,memory,nbstep):
    """Run ``nbstep`` gradient steps on the squared TD error and return the mean loss.

    Each step samples a batch from ``memory`` and minimises
    ``(R + GAMMA * F * V(S_) - Q(S, A))**2`` where ``Q(S, A)`` is selected with
    the one-hot action matrix ``A`` and ``F`` is the stored continuation mask.

    Args:
        agent: module exposing ``parameters()``, ``Q(states)`` and ``V(states, T)``.
        T: value forwarded unchanged to ``agent.V``.
        memory: object whose ``getBatch()`` returns ``(S, A, R, S_, F)``.
        nbstep: number of optimisation steps.

    Returns:
        Mean of the per-step losses as a Python float.
    """
    # A fresh optimizer (and therefore fresh optimizer state) is built per call.
    optimizer = schedulefree.AdamWScheduleFree(agent.parameters(), lr=0.001)
    # Schedule-free optimizers must be put in train mode before stepping and
    # back in eval mode before the parameters are used for evaluation/acting.
    optimizer.train()

    meanloss = torch.zeros(nbstep)
    for step in range(nbstep):
        S, A, R, S_, F = memory.getBatch()

        Q = agent.Q(S)
        QA = (Q * A).sum(1)
        # NOTE(review): gradients also flow through the bootstrap target Q_.
        # Wrap the next line in torch.no_grad() if a semi-gradient (standard
        # Q-learning style) target is what is intended -- confirm with author.
        Q_ = agent.V(S_,T)
        loss = ((GAMMA * Q_ * F + R - QA)**2).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        meanloss[step] = loss.detach()

    optimizer.eval()
    return float(meanloss.mean())


In [3]:
def leakyRelu(x, negative_slope=0.2):
    """Leaky ReLU: ``x`` for positive inputs, ``negative_slope * x`` otherwise.

    The previous ``torch.minimum(x, 0.2 * x)`` form did the opposite (it scaled
    the positive side and passed the negative side through).

    Args:
        x: input tensor.
        negative_slope: slope applied to negative inputs (default 0.2).

    Returns:
        Tensor of the same shape as ``x``.
    """
    return torch.nn.functional.leaky_relu(x, negative_slope)

class Block(torch.nn.Module):
    """Residual block over a 24-wide feature vector.

    A 24 -> 8 -> 24 -> 16 stack of linear layers (each followed by
    ``leakyRelu``) produces an update that is added to the last 16 columns of
    the input; the first 8 columns are passed through unchanged.
    """

    def __init__(self):
        super().__init__()

        self.f1 = torch.nn.Linear(24, 8)
        self.f2 = torch.nn.Linear(8, 24)
        self.f3 = torch.nn.Linear(24, 16)

    def forward(self, x):
        """Return ``x`` with a learned update added to columns 8..23.

        Args:
            x: tensor of shape ``(batch, 24)``.
        """
        update = leakyRelu(self.f1(x))
        update = leakyRelu(self.f2(update))
        update = leakyRelu(self.f3(update))

        # Zero update for the first 8 columns. new_zeros allocates on the same
        # device and with the same dtype as the activations, so the block keeps
        # working after the module/input is moved off the default device.
        untouched = update.new_zeros(x.shape[0], 8)
        return x + torch.cat([untouched, update], dim=1)


class LunarAgent(torch.nn.Module):
    """Q-network with a softmax policy over 4 discrete actions.

    An 8-feature state is padded with 16 zeros to a 24-wide vector, passed
    through seven ``Block`` modules, and mapped to 4 action values by a final
    linear layer.
    """

    def __init__(self):
        super().__init__()

        self.b1 = Block()
        self.b2 = Block()
        self.b3 = Block()
        self.b4 = Block()
        self.b5 = Block()
        self.b6 = Block()
        self.b7 = Block()

        self.A = torch.nn.Linear(24, 4)

    def forward(self, x):
        """Return action values of shape ``(batch, 4)`` for states ``(batch, 8)``."""
        # new_zeros keeps the padding on x's device/dtype instead of the
        # process-wide defaults used by torch.zeros.
        code = x.new_zeros(x.shape[0], 16)
        x = torch.cat([x, code], dim=1)

        for block in (self.b1, self.b2, self.b3, self.b4, self.b5, self.b6, self.b7):
            x = block(x)

        return self.A(x)

    def Q(self, x):
        """Action values for a batch of states (alias of ``forward``)."""
        return self.forward(x)

    def pi(self, Q, T):
        """Softmax policy over actions.

        ``T`` multiplies the action values before the softmax, so it acts as an
        inverse temperature: larger ``T`` gives a greedier distribution.
        """
        return torch.nn.functional.softmax(T * Q, dim=1)

    def V(self, x, T):
        """State value: expectation of ``Q(x, .)`` under the policy ``pi(., T)``."""
        q_values = self.Q(x)
        probs = self.pi(q_values, T)
        return (q_values * probs).sum(1)

    def sample(self, x, T):
        """Draw one action index for a single state ``x`` (no gradient tracking)."""
        with torch.no_grad():
            probs = self.pi(self.Q(x.view(1, -1)), T)
            return int(torch.multinomial(probs, num_samples=1))

In [None]:
# --- Experiment configuration ------------------------------------------------
SEED = 0                 # seeds torch and the environment for repeatable runs
N_ROUNDS = 10            # outer rounds; each starts from an empty replay buffer
N_WARMUP_EPISODES = 50   # episodes played to fill the buffer before training
N_EVAL_EPISODES = 100    # episodes per candidate value of T when comparing
T_GROWTH = 1.3           # factor applied to T when the larger value scores better

torch.manual_seed(SEED)

env = gymnasium.make("LunarLander-v3", continuous=False, gravity=-8.0,enable_wind=False)
env.reset(seed=SEED)     # seeds the env RNG; later reset() calls continue from it

T = 0.25
agent = LunarAgent()

try:
    for j in range(N_ROUNDS):
        # Collect experience with the current agent before any gradient step.
        memory = MemoryBuffer()
        for _ in range(N_WARMUP_EPISODES):
            trial(env,agent,T,memory)

        # Alternate one episode of acting with a burst of training; both the
        # number of iterations and the burst length grow with the round index.
        for i in range(20+j*5):
            episode_return = trial(env,agent,T,memory)
            mean_loss = train(agent,T,memory, nbstep=100+30*j)
            print("\t",episode_return,mean_loss)

        # Compare the current T against T*T_GROWTH on the same number of
        # episodes and keep the larger value if it collected more reward.
        tot,tot_ = 0,0
        for _ in range(N_EVAL_EPISODES):
            tot+=trial(env,agent,T,memory)
            tot_+=trial(env,agent,T*T_GROWTH,memory)
        print(T,tot/N_EVAL_EPISODES,tot_/N_EVAL_EPISODES)
        if tot<tot_:
            T = T_GROWTH*T
finally:
    # Release the environment even if training is interrupted.
    env.close()


	 -345.7655706821955 102.96977233886719
	 -97.91735229908537 98.11039733886719
	 -437.0885561939049 69.7258071899414
	 -433.4941312467837 19.52145767211914
	 -411.1011152998981 13.909363746643066
	 -203.8995195297377 9.448213577270508
	 -65.91881972040338 10.514863014221191
	 -311.38980717559537 10.197474479675293
	 -14.379819660352013 10.75926399230957
	 -85.24373970556735 9.943022727966309
	 -34.22665250400922 10.828673362731934
	 -66.65336622996215 9.050848007202148
	 -198.12926428258015 8.090961456298828
	 -115.49384909896432 8.015170097351074
	 -102.02303909749712 6.289062023162842
	 -62.383052718819876 6.777736663818359
	 -179.5925827965764 7.692903518676758
	 -326.34773929284347 7.591897487640381
	 -118.3567101672956 6.818421840667725
	 -311.04607371871026 7.7490925788879395
0.25 -143.20862136961628 -133.82640996549972
	 -217.16701896427233 7.110373020172119
	 -18.852689404791988 6.676052570343018
	 -146.80801904464295 7.142418384552002
	 -232.4611199049043 7.04820442199707
	 -1