<a href="https://colab.research.google.com/github/achanhon/coursdeeplearningcolab/blob/master/Untitled27.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch

class Game:
  """Three-lane scrolling road game.

  State attributes:
    p    -- player position as an absolute (row, column) pair.
    t    -- current time step; also the left edge of the visible window.
    f    -- True once the state is final (collision, or t reached 50).
    road -- 3 x 60 float16 grid, 1 marks an obstacle.
  """

  def __init__(self,other=None):
    """Create a fresh random game, or restore one from a (p, t, f, road) tuple."""
    # Action index -> (row delta, column delta):
    # 0 forward, 1 up, 2 down, 3 stay, 4 jump two columns forward.
    self.A = [(0,1),(-1,0),(1,0),(0,0),(0,2)]
    if other is None:
      self.t = 0
      self.f=False
      self.p = (1,0)
      # Each cell holds an obstacle with probability 0.05 (stored as float16).
      self.road = (torch.rand(3,60)<0.05).half()
      # The first two columns (around the start position) are obstacle-free.
      self.road[:,0:2]=0
    else:
      self.p,self.t,self.f,self.road = other

  def copy(self):
    """Return a new Game with the same state (the road tensor is shared, not cloned)."""
    return Game((self.p,self.t,self.f,self.road))

  def getVisibleState(self):
    """Return (player position relative to t, 3 x 7 road window starting at t, f)."""
    return (self.p[0],self.p[1]-self.t), self.road[:,self.t:self.t+7], self.f

  def getVisibleStateString(self):
    """Render the visible window as text: "o" obstacle, "x" player, " " empty."""
    # getVisibleState returns three values; the final flag is not drawn.
    p, road, _ = self.getVisibleState()
    rows = []
    for r in range(3):
      cells = []
      for c in range(7):
        if r == p[0] and c == p[1]:
          cells.append("x")
        elif road[r][c] == 1:
          cells.append("o")
        else:
          cells.append(" ")
      rows.append("".join(cells) + "\n")
    return "----------\n" + "".join(rows) + "----------"

  def update(self,a):
    """Apply action index ``a`` and return the reward.

    Returns -1 on a collision (the state becomes final and p/t are left
    unchanged), otherwise 0.06. The state also becomes final once t reaches 50.

    Raises:
      AssertionError: if called on a final state.
    """
    if self.f:
      # Explicit raise instead of ``assert`` so the guard survives ``python -O``
      # and the exception carries the message (the former ``print(...)``
      # message printed it and attached None to the exception).
      raise AssertionError("update a final state")
    dr,dc = self.A[a]
    p = self.p

    # Replace forbidden actions with "move forward one column":
    # no stay/up/down at the left edge of the window, no leaving rows 0..2,
    # and no moving past column t+5.
    if p[1]==self.t and dc<1:
      dr,dc = 0,1
    if p[0]+dr<0:
      dr,dc = 0,1
    if p[0]+dr>2:
      dr,dc=0,1
    if p[1]+dc>self.t+5:
      dr,dc = 0,1

    # Collision with the destination cell.
    if self.road[p[0]+dr][p[1]+dc]==1:
      self.f = True
      return -1
    # A two-column jump also collides with the cell it passes over.
    if dc==2 and self.road[p[0]][p[1]+1]==1:
      self.f = True
      return -1

    self.p = (p[0]+dr,p[1]+dc)
    self.t = self.t+1
    if self.t==50:
      self.f = True
    return 0.06

In [2]:
class RL(torch.nn.Module):
  """Small convolutional agent over encoded (3, 3, 7) game windows.

  ``forward`` returns, for a batch of B encoded states:
    qa -- (B, 5) action values,
    r  -- (B, 5) predicted immediate rewards, one per action,
    ss -- (B, 5, 3, 3, 7) predicted next encoded state, one per action.
  """

  def toTensor(self,x):
    """Encode a visible state ``(p, road, f)`` as a (3, 3, 7) float tensor.

    Plane 0 is the road window, plane 1 marks the player cell, plane 2 is
    all ones for a non-final state. A final state encodes as all zeros.
    """
    z = torch.zeros(3,3,7)
    p,road,f = x
    if f:
      return z
    z[2,:,:]=1
    z[0]=road
    z[1][p[0]][p[1]]=1
    return z

  def __init__(self):
    super(RL,self).__init__()

    # Every conv is followed by concatenating the 3 raw input planes,
    # hence input channel counts 3, 8+3, 16+3, 24+3 and finally 32+3 = 35.
    self.l1 = torch.nn.Conv2d(3,8,kernel_size=3,padding=1,bias=False)
    self.l11 = torch.nn.Conv2d(11,16,kernel_size=3,padding=1,bias=False)
    self.l2 = torch.nn.Conv2d(19,24,kernel_size=3,padding=1,bias=False)
    self.l22 = torch.nn.Conv2d(27,32,kernel_size=3,padding=1,bias=False)

    # Next-state head: 15 channels = 5 actions x 3 planes.
    self.next = torch.nn.Conv2d(35,15,kernel_size=3,padding=1,bias=False)

    # l3 collapses the whole 3x7 window to 1x1; l4 is pointwise, then max-pooled.
    self.l3 = torch.nn.Conv2d(35,32,kernel_size=(3,7),bias=False)
    self.l4 = torch.nn.Conv2d(35,32,kernel_size=1,bias=False)

    # 65 features = 32 (l3) + 32 (l4 pooled) + 1 (non-final flag).
    self.r = torch.nn.Linear(65,5,bias=False)
    self.qa = torch.nn.Linear(65,5,bias=False)

  def forward(self,x):
    """Run the network on a (B, 3, 3, 7) batch; see the class docstring for outputs.

    All layers are bias-free, so an all-zero (final-state) input yields
    exactly zero outputs.
    """
    z = torch.nn.functional.leaky_relu(self.l1(x))
    z = torch.cat([z,x],dim=1)
    z = torch.nn.functional.leaky_relu(self.l11(z))
    z = torch.cat([z,x],dim=1)
    z = torch.nn.functional.leaky_relu(self.l2(z))
    z = torch.cat([z,x],dim=1)
    z = torch.nn.functional.leaky_relu(self.l22(z))
    z = torch.cat([z,x],dim=1)

    # "Leaky clamp": mostly confined to [-0.1, 1.1] but keeps a small slope.
    ss = self.next(z)
    ss = torch.clamp(ss, -0.1,1.1) + 0.1*ss

    z1 = torch.nn.functional.leaky_relu(self.l3(z))
    z1 = z1.view(z.shape[0],32)
    z2 = torch.nn.functional.max_pool2d(self.l4(z),kernel_size=(3,7))
    z2 = z2.view(z.shape[0],32)
    # Plane 2 is the constant "non-final" flag; one cell is enough.
    tmp = x[:,2,0,0].view(z.shape[0],1)

    z = torch.cat([z1,z2,tmp],dim=1)

    qa = self.qa(z)
    qa = torch.clamp(qa,min=-1.1,max=2)+0.1*qa
    r = self.r(z)
    r = torch.clamp(r,min=-1.1,max=0.1)+0.1*r

    return qa,r,ss.view(x.shape[0],5,3,3,7)

  def toTensorS(self,X):
    """Encode a list of visible states as a (len(X), 3, 3, 7) batch."""
    if len(X) == 0:
      # torch.stack rejects an empty list; return an empty batch instead.
      return torch.zeros(0,3,3,7)
    return torch.stack([self.toTensor(x) for x in X])

  def P(self,QA):
    """Softmax policy over the action dimension of ``QA``."""
    return torch.nn.functional.softmax(QA,1)

  def Q(self,X):
    """State value: expectation of the action values under the softmax policy."""
    QA,_,_ = self.forward(X)
    return (QA * self.P(QA)).sum(1)

  def Qa(self,X,a):
    """Select, per batch row, the outputs for action ``a[i]``.

    Returns (action value, predicted reward, predicted next state) with
    shapes (B,), (B,) and (B, 3, 3, 7).
    """
    A = torch.nn.functional.one_hot(
        torch.as_tensor(a, dtype=torch.long, device=X.device), num_classes=5
    ).float()
    QA,R,XX = self.forward(X)
    return (QA*A).sum(1),(R*A).sum(1),(XX*A.view(X.shape[0],5,1,1,1)).sum(1)

  def policy(self,x):
    """Sample an action index for one visible state from the softmax policy."""
    device = next(self.parameters()).device
    # Acting needs no gradients; avoid building an autograd graph per step.
    with torch.no_grad():
      xt = self.toTensor(x).unsqueeze(0).to(device)
      qa, _, _ = self.forward(xt)
      p = self.P(qa)
      return int(torch.multinomial(p, 1).item())

In [3]:
import random

def smoothL1(x):
  """Mean of elementwise ``min(x**2, |x|)`` over all entries of ``x``.

  Quadratic for |x| <= 1 and linear beyond. Unlike torch's SmoothL1/Huber
  loss there is no 0.5 factor, so the slope jumps at |x| = 1.
  """
  magnitude = x.abs()
  squared = magnitude * magnitude
  return torch.minimum(squared, magnitude).mean()

def explore(agent, buffer, nbruns):
    """Play ``nbruns`` games with ``agent.policy`` and record every transition.

    Each transition ``(state, action, reward, next_state)`` (visible states
    as returned by ``Game.getVisibleState``) is appended to ``buffer`` in
    place. The agent is moved to the CPU first.

    Returns the average total reward per game, or 0.0 when ``nbruns`` <= 0.
    """
    agent.cpu()
    if nbruns <= 0:
        # Nothing played: avoid a ZeroDivisionError in the average below.
        return 0.0
    totalreward = 0
    # Rollouts only need sampled actions, so skip building autograd graphs.
    with torch.no_grad():
        for _ in range(nbruns):
            game = Game()
            # Safety cap on episode length; Game.update marks the state
            # final no later than t == 50.
            for _ in range(1000):
                state = game.getVisibleState()
                action = agent.policy(state)
                reward = game.update(action)
                totalreward += reward
                buffer.append((state, action, reward, game.getVisibleState()))

                if game.f:
                    break
    return totalreward / nbruns


def training(agent, buffer, nbsteps, verbose, lr=0.00005, gamma=0.9, batchsize=64):
    """Run up to ``nbsteps`` minibatch updates of ``agent`` on ``buffer``.

    ``buffer`` holds ``(state, action, reward, next_state)`` transitions. It
    is shuffled in place and consumed from its end, ``batchsize``
    transitions per step. Training stops early once fewer than ``batchsize``
    remain. The loss is the sum of three terms:
      - reward prediction error,
      - next-state prediction error,
      - TD error ``gamma * Q(next) + r - Q(state, action)``,
    each measured with ``smoothL1``.

    Args:
      agent: the RL model to train (moved to CUDA when available).
      buffer: list of transitions; mutated (shuffled and popped).
      nbsteps: maximum number of optimizer steps.
      verbose: print running loss averages every 20 steps.
      lr, gamma, batchsize: Adam learning rate, discount factor, batch size.

    Returns:
      The transitions consumed in this call followed by the unused rest, so
      the caller keeps the full replay buffer.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        agent.cuda()
    # NOTE(review): a fresh Adam optimizer is created on every call, so its
    # moment estimates restart each time training() runs. Confirm intended.
    optimizer = torch.optim.Adam(agent.parameters(), lr=lr)

    used = []
    random.shuffle(buffer)
    meanR, meanRL, meanpred = 0, 0, 0
    for step in range(nbsteps):
        if len(buffer) < batchsize:
            break

        batch = [buffer.pop() for _ in range(batchsize)]
        used.extend(batch)
        X = agent.toTensorS([t[0] for t in batch])
        A = [t[1] for t in batch]
        R = torch.tensor([float(t[2]) for t in batch])
        XX = agent.toTensorS([t[3] for t in batch])
        if use_cuda:
            X, XX, R = X.cuda(), XX.cuda(), R.cuda()

        QA, Rpred, XXpred = agent.Qa(X, A)
        # NOTE(review): QXX is not detached, so gradients also flow through
        # the bootstrap target (residual-gradient style). Semi-gradient TD
        # would use QXX.detach(). Confirm which is intended.
        QXX = agent.Q(XX)

        assert QA.shape == QXX.shape
        assert QA.shape == R.shape
        assert Rpred.shape == R.shape
        assert XXpred.shape == XX.shape

        rloss = smoothL1(R - Rpred)
        XXloss = smoothL1(XX - XXpred)
        RLloss = smoothL1(gamma * QXX + R - QA)

        meanR += float(rloss)
        meanRL += float(RLloss)
        meanpred += float(XXloss)
        # Report averages over each window of 20 steps.
        if step % 20 == 19 and verbose:
            print("\t", step, meanR / 20, meanRL / 20, meanpred / 20)
            meanR, meanRL, meanpred = 0, 0, 0

        loss = RLloss + rloss + XXloss
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(agent.parameters(), 10)
        optimizer.step()

    return used + buffer

In [None]:
# Driver: warm up the replay buffer with rollouts of the untrained agent,
# then alternate exploration and training, capping the buffer at 100000
# transitions, and finally report the average score over 1000 games.
buffer = []
agent = RL()
explore(agent, buffer, 100)  # warmup: fill the buffer before any training
for i in range(400):
  score = explore(agent, buffer, 10)
  report = i % 12 == 0
  if report:
    print(i, "score", score, len(buffer))
  buffer = training(agent, buffer, 100, report)
  if len(buffer) > 100000:
    # Keep a random subset of 100000 transitions.
    random.shuffle(buffer)
    del buffer[100000:]

print("final score", explore(agent, buffer, 1000))