In [1]:
import torch
from board import move
from board import create_action_mask
import torch.nn as nn
from torch.distributions.categorical import Categorical
import sys
from model import BoardGFLowNet
from board import random_board, get_reward, move, create_action_mask
import wandb



  from .autonotebook import tqdm as notebook_tqdm


In [56]:

def sample_move(boards: torch.Tensor, 
                logits: torch.Tensor, 
                at_step_limit: bool):
    
    batch_size, _, _ = boards.shape
    last_logits = logits[:, -1, :]
    
    if(at_step_limit):
        mask = torch.ones(6) * -1e20
        mask[1] = 0
        mask = mask.expand((batch_size, 6))
    else:
        mask = create_action_mask(boards)
    
    last_logits = torch.softmax(mask + last_logits, dim=1)
    new_moves = Categorical(probs=last_logits).sample()
    new_moves = torch.Tensor(new_moves).type(torch.LongTensor)
    return new_moves, last_logits[torch.arange(batch_size), new_moves]

def loss_fn(predicted_logZ: torch.Tensor, 
            reward: torch.Tensor, 
            forward_probabilities: torch.Tensor):
    
    log_Pf = torch.log(forward_probabilities).sum(dim=1)
    inner = predicted_logZ.squeeze() + log_Pf - torch.log(reward) 
    return inner ** 2

def train(
    lr=1e-4,
    decoder_layers=3,
    encoder_layers=3,
    embed_dim=32,
    d_ff=32,
    n_heads=8,
    batch_size=16,
    side_len=3,
    max_steps=20,
    total_batches=1000,
    checkpoint_freq=10,
    ):
    
    gfn = BoardGFLowNet(side_len, embed_dim, d_ff, n_heads, encoder_layers, decoder_layers, 6)
    optimizer = torch.optim.Adam(gfn.parameters(), lr=lr)

    for batch in range(total_batches):
    
        boards = random_board(batch_size, side_len) 
        finished = torch.zeros((batch_size, 1))
        moves = torch.zeros(batch_size, 1).type(torch.LongTensor)
        forward_probabilities = torch.ones(batch_size, 1)

        predicted_logZ, _ = gfn(boards, moves)
        batch_loss = 0
        batch_reward = 0
        batch_matching = 0
        
        for i in range(max_steps):
            _, logits = gfn(boards, moves)
            new_move, move_prob = sample_move(boards, logits, i == max_steps-1)

            for index in range(len(move_prob)):
                if(finished[index] == 1):
                    move_prob[index] = 1
            for index, _move in enumerate(new_move):
                if(_move == 1):
                    finished[index] = 1

            forward_probabilities = torch.cat([forward_probabilities, move_prob.unsqueeze(1)], dim=1)
            moves = torch.cat([moves, new_move.unsqueeze(1)], dim=1)
            boards = boards.clone()
            boards = move(boards, new_move, finished_mask=finished)
        
        reward, matching = get_reward(boards)
        loss = loss_fn(predicted_logZ, reward, forward_probabilities)
        loss = torch.sum(loss)
        reward = torch.sum(reward)
        matching = torch.sum(matching)
        loss.backward(retain_graph=False)
        batch_reward += reward
        batch_matching += matching
        batch_loss += loss
        optimizer.step()
        optimizer.zero_grad()
        batch_reward = batch_reward.item() / batch_size
        batch_matching = batch_matching.item() / batch_size
        batch_loss = batch_loss.item() / batch_size
        print(f'Batch {batch}, loss: {batch_loss}, reward: {batch_reward}, Matching: {batch_matching}')
        if((batch+1) % checkpoint_freq == 0):
            torch.save(gfn.state_dict(), f'checkpoints/model_step_{batch}.pt')

In [57]:
l = train()

Batch 0, loss: 37.42610549926758, reward: 0.0006950950482860208, Matching: 1.4375
Batch 1, loss: 42.69380187988281, reward: 0.001965435454621911, Matching: 1.6875
Batch 2, loss: 57.75373077392578, reward: 0.0008981323335319757, Matching: 1.125
Batch 3, loss: 77.88799285888672, reward: 0.0010476823663339019, Matching: 1.6875
Batch 4, loss: 65.56709289550781, reward: 0.0007311212830245495, Matching: 1.5
Batch 5, loss: 59.6051139831543, reward: 0.0017600570572540164, Matching: 1.625
Batch 6, loss: 75.13094329833984, reward: 0.0018536229617893696, Matching: 1.375
Batch 7, loss: 102.02838897705078, reward: 0.003063319716602564, Matching: 2.0
Batch 8, loss: 62.86403274536133, reward: 0.0008039365638978779, Matching: 1.1875
Batch 9, loss: 85.043701171875, reward: 0.0010099445935338736, Matching: 1.4375
Batch 10, loss: 73.5926742553711, reward: 0.0043679517693817616, Matching: 1.8125
Batch 11, loss: 56.23155975341797, reward: 0.0015761922113597393, Matching: 1.1875
Batch 12, loss: 77.603675842

KeyboardInterrupt: 

In [24]:
def get_reward(boards: torch.Tensor):
    batch_size, _, side_len = boards.shape
    ground_truth = torch.arange(0, side_len**2).reshape(side_len,side_len).expand_as(boards)
    mismatch = boards - ground_truth
    
    match = mismatch == 0
    mismatch = mismatch != 0
    num_mismatch = mismatch.flatten(1).count_nonzero(1)
    num_match = match.flatten(1).count_nonzero(1)
    reward = torch.exp(-num_mismatch) 
    return reward, num_match

In [25]:
get_reward(b)

(tensor([0.0009, 0.0009, 0.0009, 0.0003, 0.0001, 0.0001, 0.0003, 0.0001, 0.0025,
         0.0025, 0.0003, 0.0003, 0.0001, 0.0001, 0.0001, 0.0009]),
 tensor([2, 2, 2, 1, 0, 0, 1, 0, 3, 3, 1, 1, 0, 0, 0, 2]))

In [5]:
boards = random_board(100, 4)

In [8]:
boards[-1]

tensor([[ 0,  4,  5,  1],
        [ 6,  2, 11,  3],
        [ 8, 14, 13,  7],
        [ 9, 12, 10, 15]])

In [4]:
torch.arange(16).reshape(4,4).expand(100, 4,4)

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        ...,

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])

In [82]:
DIR_UP = 2
DIR_DOWN = 3
DIR_RIGHT = 4
DIR_LEFT = 5


In [99]:
b, m

(tensor([[[2, 4, 3],
          [8, 0, 1],
          [6, 7, 5]],
 
         [[0, 3, 2],
          [7, 5, 8],
          [6, 1, 4]],
 
         [[4, 7, 3],
          [0, 8, 5],
          [2, 6, 1]],
 
         [[5, 6, 1],
          [8, 3, 2],
          [7, 0, 4]],
 
         [[5, 3, 1],
          [6, 7, 0],
          [2, 4, 8]],
 
         [[3, 7, 2],
          [6, 1, 8],
          [5, 4, 0]],
 
         [[1, 4, 3],
          [8, 5, 0],
          [6, 7, 2]],
 
         [[4, 2, 7],
          [8, 6, 3],
          [0, 1, 5]],
 
         [[1, 2, 5],
          [0, 3, 7],
          [4, 6, 8]],
 
         [[3, 0, 1],
          [5, 4, 2],
          [6, 8, 7]],
 
         [[4, 5, 1],
          [2, 3, 0],
          [6, 7, 8]],
 
         [[3, 1, 4],
          [6, 5, 7],
          [0, 8, 2]],
 
         [[6, 0, 2],
          [5, 1, 8],
          [4, 7, 3]],
 
         [[5, 4, 3],
          [1, 6, 0],
          [7, 8, 2]],
 
         [[5, 3, 8],
          [1, 6, 0],
          [7, 4, 2]],
 
         [

In [101]:
move(b,m)

tensor([[[2, 4, 3],
         [8, 7, 1],
         [6, 0, 5]],

        [[0, 3, 2],
         [7, 5, 8],
         [6, 1, 4]],

        [[0, 7, 3],
         [4, 8, 5],
         [2, 6, 1]],

        [[5, 6, 1],
         [8, 3, 2],
         [7, 0, 4]],

        [[5, 3, 0],
         [6, 7, 1],
         [2, 4, 8]],

        [[3, 7, 2],
         [6, 1, 8],
         [5, 4, 0]],

        [[1, 4, 3],
         [8, 5, 0],
         [6, 7, 2]],

        [[4, 2, 7],
         [8, 6, 3],
         [0, 1, 5]],

        [[1, 2, 5],
         [4, 3, 7],
         [0, 6, 8]],

        [[3, 0, 1],
         [5, 4, 2],
         [6, 8, 7]],

        [[4, 5, 1],
         [2, 3, 8],
         [6, 7, 0]],

        [[3, 1, 4],
         [6, 5, 7],
         [0, 8, 2]],

        [[6, 0, 2],
         [5, 1, 8],
         [4, 7, 3]],

        [[5, 4, 0],
         [1, 6, 3],
         [7, 8, 2]],

        [[5, 3, 0],
         [1, 6, 8],
         [7, 4, 2]],

        [[2, 5, 0],
         [1, 7, 4],
         [3, 6, 8]]])

In [4]:

b[1],g[1]

NameError: name 'g' is not defined

In [55]:
new_moves, last_logits = sample_move(b, l, False)

In [56]:
last_logits

tensor([0.6787, 0.2406, 0.4327, 0.5036, 0.1199, 0.5295, 0.1566, 0.1077, 0.4292,
        0.4017, 0.1154, 0.2412, 0.3354, 0.3584, 0.0835, 0.2243],
       grad_fn=<IndexBackward0>)

In [57]:
new_moves

tensor([3, 2, 2, 2, 1, 3, 4, 4, 2, 3, 1, 1, 3, 2, 1, 2])

In [30]:
new_moves, last_logits

(tensor([3, 3, 2, 1, 3, 3, 4, 2, 5, 5, 3, 4, 5, 5, 3, 3]),
 tensor([[0.0000, 0.2310, 0.5036, 0.0000, 0.2654, 0.0000],
         [0.0000, 0.2310, 0.5036, 0.0000, 0.2654, 0.0000],
         [0.0000, 0.1742, 0.4327, 0.0000, 0.0000, 0.3931],
         [0.0000, 0.0928, 0.2406, 0.5572, 0.1094, 0.0000],
         [0.0000, 0.2310, 0.5036, 0.0000, 0.2654, 0.0000],
         [0.0000, 0.2310, 0.5036, 0.0000, 0.2654, 0.0000],
         [0.0000, 0.1199, 0.0000, 0.5459, 0.1338, 0.2004],
         [0.0000, 0.1742, 0.4327, 0.0000, 0.0000, 0.3931],
         [0.0000, 0.1117, 0.2370, 0.5295, 0.1218, 0.0000],
         [0.0000, 0.1117, 0.2370, 0.5295, 0.1218, 0.0000],
         [0.0000, 0.2310, 0.5036, 0.0000, 0.2654, 0.0000],
         [0.0000, 0.1199, 0.0000, 0.5459, 0.1338, 0.2004],
         [0.0000, 0.1117, 0.2370, 0.5295, 0.1218, 0.0000],
         [0.0000, 0.1117, 0.2370, 0.5295, 0.1218, 0.0000],
         [0.0000, 0.2310, 0.5036, 0.0000, 0.2654, 0.0000],
         [0.0000, 0.2310, 0.5036, 0.0000, 0.2654, 0.0000