In [1]:
import torch
from board import move
from board import create_action_mask
import torch.nn as nn
from torch.distributions.categorical import Categorical
import sys
from model import BoardGFLowNet
from board import random_board, get_reward, move, create_action_mask
import wandb



  from .autonotebook import tqdm as notebook_tqdm


In [12]:
def loss_fn(predicted_logZ: torch.Tensor, 
            reward: torch.Tensor, 
            forward_probabilities: list):
    
    log_Pf = sum(list(map(torch.log, forward_probabilities)))
    inner = predicted_logZ + log_Pf - torch.log(reward) 
    return inner ** 2

def sample_move(boards: torch.Tensor, 
                logits: torch.Tensor, 
                at_step_limit: bool):
    
    batch_size, _, _ = boards.shape
    last_logits = logits[:, -1, :]
    
    if(at_step_limit):
        mask = torch.ones(6) * -1e20
        mask[1] = 0
        mask = mask.expand((batch_size, 6))
    else:
        mask = create_action_mask(boards)
    
    last_logits = torch.softmax(mask + last_logits, dim=1)
    new_moves = Categorical(probs=last_logits).sample()
    new_moves = torch.Tensor(new_moves).type(torch.LongTensor)
    return new_moves, last_logits[torch.arange(batch_size), new_moves]

def train(
    lr=1e-4,
    decoder_layers=3,
    encoder_layers=3,
    embed_dim=32,
    d_ff=32,
    n_heads=8,
    batch_size=16,
    side_len=3,
    max_steps=20,
    total_batches=1000,
    checkpoint_freq=10,
    ):
    
    gfn = BoardGFLowNet(side_len, embed_dim, d_ff, n_heads, encoder_layers, decoder_layers, 6)
    optimizer = torch.optim.Adam(gfn.parameters(), lr=lr)

    

    for batch in range(total_batches):
    
        boards = random_board(batch_size, side_len) 
        finished = torch.zeros((batch_size, 1))
        moves = torch.zeros(batch_size, 1).type(torch.LongTensor)
        forward_probabilities = torch.ones(batch_size, 1)

        predicted_logZ, _ = gfn(boards, moves)
        print(predicted_logZ.shape)

        

        batch_loss = 0
        batch_reward = 0
        batch_matching = 0
        

        for i in range(max_steps):
            _, logits = gfn(boards, moves)
            print('logits.shape:',logits.shape)
            print('boards.shape:', boards.shape)
            
            new_move, move_prob = sample_move(boards, logits, i == max_steps-1)
            
            print(new_move, move_prob)
            print(move_prob.unsqueeze(1).shape)
            print(forward_probabilities.shape)
            forward_probabilities = torch.cat([forward_probabilities, move_prob.unsqueeze(1)], dim=1)
            print(forward_probabilities)
            moves = torch.cat([moves, new_move.unsqueeze(1)], dim=1)
            print(moves)
            boards = boards.clone()
            boards = move(boards, new_move)
        
        sys.exit()

        reward, matching = get_reward(boards)
        loss = loss_fn(predicted_logZ, reward, forward_probabilities)
        loss.backward(retain_graph=False)

        batch_reward += reward
        batch_matching += matching
        batch_loss += loss
        
        optimizer.step()
        optimizer.zero_grad()
        batch_reward = batch_reward.item() / batch_size
        batch_matching = batch_matching.item() / batch_size
        batch_loss = batch_loss.item() / batch_size
        print(f'Batch {batch}, loss: {batch_loss}, reward: {batch_reward}, Matching: {batch_matching}')
        if((batch+1) % checkpoint_freq == 0):
            torch.save(gfn.state_dict(), f'checkpoints/model_step_{batch}.pt')

        

In [13]:
b, m = train()

torch.Size([16, 1])
logits.shape: torch.Size([16, 1, 6])
boards.shape: torch.Size([16, 3, 3])
tensor([2, 3, 5, 4, 5, 2, 3, 3, 1, 3, 5, 3, 3, 3, 3, 5]) tensor([0.3442, 0.7876, 0.3355, 0.1824, 0.3165, 0.4116, 0.5805, 0.5841, 0.0446,
        0.4295, 0.2068, 0.6123, 0.4897, 0.4941, 0.6066, 0.4458],
       grad_fn=<IndexBackward0>)
torch.Size([16, 1])
torch.Size([16, 1])
tensor([[1.0000, 0.3442],
        [1.0000, 0.7876],
        [1.0000, 0.3355],
        [1.0000, 0.1824],
        [1.0000, 0.3165],
        [1.0000, 0.4116],
        [1.0000, 0.5805],
        [1.0000, 0.5841],
        [1.0000, 0.0446],
        [1.0000, 0.4295],
        [1.0000, 0.2068],
        [1.0000, 0.6123],
        [1.0000, 0.4897],
        [1.0000, 0.4941],
        [1.0000, 0.6066],
        [1.0000, 0.4458]], grad_fn=<CatBackward0>)
tensor([[0, 2],
        [0, 3],
        [0, 5],
        [0, 4],
        [0, 5],
        [0, 2],
        [0, 3],
        [0, 3],
        [0, 1],
        [0, 3],
        [0, 5],
        [0, 3]

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [5]:
boards = random_board(100, 4)

In [8]:
boards[-1]

tensor([[ 0,  4,  5,  1],
        [ 6,  2, 11,  3],
        [ 8, 14, 13,  7],
        [ 9, 12, 10, 15]])

In [4]:
torch.arange(16).reshape(4,4).expand(100, 4,4)

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        ...,

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]])

In [82]:
DIR_UP = 2
DIR_DOWN = 3
DIR_RIGHT = 4
DIR_LEFT = 5


In [99]:
b, m

(tensor([[[2, 4, 3],
          [8, 0, 1],
          [6, 7, 5]],
 
         [[0, 3, 2],
          [7, 5, 8],
          [6, 1, 4]],
 
         [[4, 7, 3],
          [0, 8, 5],
          [2, 6, 1]],
 
         [[5, 6, 1],
          [8, 3, 2],
          [7, 0, 4]],
 
         [[5, 3, 1],
          [6, 7, 0],
          [2, 4, 8]],
 
         [[3, 7, 2],
          [6, 1, 8],
          [5, 4, 0]],
 
         [[1, 4, 3],
          [8, 5, 0],
          [6, 7, 2]],
 
         [[4, 2, 7],
          [8, 6, 3],
          [0, 1, 5]],
 
         [[1, 2, 5],
          [0, 3, 7],
          [4, 6, 8]],
 
         [[3, 0, 1],
          [5, 4, 2],
          [6, 8, 7]],
 
         [[4, 5, 1],
          [2, 3, 0],
          [6, 7, 8]],
 
         [[3, 1, 4],
          [6, 5, 7],
          [0, 8, 2]],
 
         [[6, 0, 2],
          [5, 1, 8],
          [4, 7, 3]],
 
         [[5, 4, 3],
          [1, 6, 0],
          [7, 8, 2]],
 
         [[5, 3, 8],
          [1, 6, 0],
          [7, 4, 2]],
 
         [

In [101]:
move(b,m)

tensor([[[2, 4, 3],
         [8, 7, 1],
         [6, 0, 5]],

        [[0, 3, 2],
         [7, 5, 8],
         [6, 1, 4]],

        [[0, 7, 3],
         [4, 8, 5],
         [2, 6, 1]],

        [[5, 6, 1],
         [8, 3, 2],
         [7, 0, 4]],

        [[5, 3, 0],
         [6, 7, 1],
         [2, 4, 8]],

        [[3, 7, 2],
         [6, 1, 8],
         [5, 4, 0]],

        [[1, 4, 3],
         [8, 5, 0],
         [6, 7, 2]],

        [[4, 2, 7],
         [8, 6, 3],
         [0, 1, 5]],

        [[1, 2, 5],
         [4, 3, 7],
         [0, 6, 8]],

        [[3, 0, 1],
         [5, 4, 2],
         [6, 8, 7]],

        [[4, 5, 1],
         [2, 3, 8],
         [6, 7, 0]],

        [[3, 1, 4],
         [6, 5, 7],
         [0, 8, 2]],

        [[6, 0, 2],
         [5, 1, 8],
         [4, 7, 3]],

        [[5, 4, 0],
         [1, 6, 3],
         [7, 8, 2]],

        [[5, 3, 0],
         [1, 6, 8],
         [7, 4, 2]],

        [[2, 5, 0],
         [1, 7, 4],
         [3, 6, 8]]])

In [4]:

b[1],g[1]

NameError: name 'g' is not defined

In [55]:
new_moves, last_logits = sample_move(b, l, False)

In [56]:
last_logits

tensor([0.6787, 0.2406, 0.4327, 0.5036, 0.1199, 0.5295, 0.1566, 0.1077, 0.4292,
        0.4017, 0.1154, 0.2412, 0.3354, 0.3584, 0.0835, 0.2243],
       grad_fn=<IndexBackward0>)

In [57]:
new_moves

tensor([3, 2, 2, 2, 1, 3, 4, 4, 2, 3, 1, 1, 3, 2, 1, 2])

In [30]:
new_moves, last_logits

(tensor([3, 3, 2, 1, 3, 3, 4, 2, 5, 5, 3, 4, 5, 5, 3, 3]),
 tensor([[0.0000, 0.2310, 0.5036, 0.0000, 0.2654, 0.0000],
         [0.0000, 0.2310, 0.5036, 0.0000, 0.2654, 0.0000],
         [0.0000, 0.1742, 0.4327, 0.0000, 0.0000, 0.3931],
         [0.0000, 0.0928, 0.2406, 0.5572, 0.1094, 0.0000],
         [0.0000, 0.2310, 0.5036, 0.0000, 0.2654, 0.0000],
         [0.0000, 0.2310, 0.5036, 0.0000, 0.2654, 0.0000],
         [0.0000, 0.1199, 0.0000, 0.5459, 0.1338, 0.2004],
         [0.0000, 0.1742, 0.4327, 0.0000, 0.0000, 0.3931],
         [0.0000, 0.1117, 0.2370, 0.5295, 0.1218, 0.0000],
         [0.0000, 0.1117, 0.2370, 0.5295, 0.1218, 0.0000],
         [0.0000, 0.2310, 0.5036, 0.0000, 0.2654, 0.0000],
         [0.0000, 0.1199, 0.0000, 0.5459, 0.1338, 0.2004],
         [0.0000, 0.1117, 0.2370, 0.5295, 0.1218, 0.0000],
         [0.0000, 0.1117, 0.2370, 0.5295, 0.1218, 0.0000],
         [0.0000, 0.2310, 0.5036, 0.0000, 0.2654, 0.0000],
         [0.0000, 0.2310, 0.5036, 0.0000, 0.2654, 0.0000