In [1]:
!pip install torch



In [2]:
import numpy as np
import torch
import torch.nn as nn 

In [3]:
from collections import namedtuple, deque
from itertools import count
import random
from tqdm import tqdm

In [4]:
class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()

        self.downsample = torch.nn.MaxPool2d(2)

        self.block_1 = nn.Sequential(
            nn.Conv2d(6, 64, 3, padding="same"),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding="same"),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding="same"),
            nn.ReLU()
        )

        self.block_2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding="same"),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding="same"),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding="same"),
            nn.ReLU()
        )

        self.block_3 = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding="same"),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding="same"),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding="same"),
            nn.ReLU()
        )

        self.upsample_1 = nn.ConvTranspose2d(256, 128, 2, stride=2)

        self.up_block_1 = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding="same"),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding="same"),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding="same"),
            nn.ReLU()
        )

        self.upsample_2 = nn.ConvTranspose2d(128, 64, 2, stride=2)

        self.up_block_2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding="same"),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding="same"),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding="same"),
            nn.ReLU()
        )

        self.output_layer = nn.Conv2d(64, 6, 1, padding="same")


    def forward(self, x):
        h1 = self.block_1(x)
        d1 = self.downsample(h1)
        h2 = self.block_2(d1)
        d2 = self.downsample(h2)
        h3 = self.block_3(d2)

        u1 = self.upsample_1(h3)
        c1 = torch.cat( (h2, u1), dim=-3)
        h4 = self.up_block_1(c1)
        u2 = self.upsample_2(h4)
        c2 = torch.cat( (h1, u2), dim=-3)
        h5 = self.up_block_2(c2)
        output = self.output_layer(h5)
        return output

        

In [5]:
def sym_score(x):
    x = torch.argmax(x, dim=-3)
    reflected_x = torch.transpose(torch.flip(x, dims=[-1, -2]), -1, -2)
    return torch.sum(x != reflected_x)

In [6]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'terminal'))


In [7]:
class Memory_Buffer():
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [8]:
policy_net = Unet()
target_net = Unet()
target_net.load_state_dict(policy_net.state_dict())
replay_buffer = Memory_Buffer(10_000)
optimizer = torch.optim.Adam(policy_net.parameters())
loss_function = nn.MSELoss()

In [9]:
epsilon = 1.0
tau = 0.005

In [10]:
for episode in tqdm(range(100_000)):
    epsilon = max(epsilon - 1/100_000, 0.05)
    blank_planes = torch.zeros(5, 16, 16)
    empty_plane = torch.ones(1, 16, 16)
    state = torch.cat((empty_plane, blank_planes), dim=0)
    for step in range(64):
        if np.random.random() > epsilon:
            q_values = policy_net(state)
            # mask q_values
            q_values = q_values * (1 - state)
            action = torch.tensor(torch.unravel_index(torch.argmax(q_values), q_values.shape))
        else:
            x = np.random.randint(16)
            y = np.random.randint(16)
            plane = np.random.randint(5)
            if state[plane, x, y] == 1:
                plane = 5
            action = torch.tensor([torch.tensor(plane), torch.tensor(x), torch.tensor(y)])
        old_state = state.clone()
        state[:, action[1], action[2]] = 0
        state[action[0], action[1], action[2]] = 1

        reward = torch.tensor(sym_score(state) - sym_score(old_state))

        if step == 63:
            terminal = torch.tensor(1)
        else:
            terminal = torch.tensor(0)
        replay_buffer.push(old_state, action, state, reward, terminal)

        if step % 4 == 0 and len(replay_buffer) > 1_000:
            transitions = replay_buffer.sample(16)
            batch = Transition(*zip(*transitions))

            state_batch = torch.stack(batch.state)
            action_batch = torch.stack(batch.action)
            reward_batch = torch.stack(batch.reward)
            next_state_batch = torch.stack(batch.next_state)
            terminal_batch = torch.stack(batch.terminal)

            # compute targets
            with torch.no_grad():
                next_q = target_net(next_state_batch) * (1- next_state_batch)
                next_q_maxes = torch.max(next_q.view(16, -1), dim=1)[0]
                target = reward_batch + (next_q_maxes * (1 - terminal_batch))
                
            # compute TD error
            predicted_q_values = policy_net(state_batch)[torch.arange(16), action_batch[:, 0], action_batch[:, 1], action_batch[:, 2]]
            loss = loss_function(predicted_q_values, target)

            # update policy network
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
            optimizer.step()

            # update target network
            target_net_state_dict = target_net.state_dict()
            policy_net_state_dict = policy_net.state_dict()
            for key in policy_net_state_dict:
                target_net_state_dict[key] = policy_net_state_dict[key]*tau + target_net_state_dict[key]*(1-tau)
            target_net.load_state_dict(target_net_state_dict)

  reward = torch.tensor(sym_score(state) - sym_score(old_state))
  0%|          | 15/100000 [00:00<23:14, 71.68it/s]


NameError: name 'TAU' is not defined