In [2]:
from blockulib.classess import PlayingLoop
import blockulib.models as blom
import blockulib
import torch

In [3]:
class ModelBasedLoop(PlayingLoop):
    """Self-play loop that chooses every move from a model's scores of the candidate boards.

    On construction the network is built, its weights are loaded from a
    checkpoint, and it is switched to eval mode. Calling the instance plays
    a number of games in parallel and returns the boards of each game.
    """

    def __init__(self, model_path = "models/conv_model.pth", architecture = blom.ConvModel):
        """Build the network, load its weights and create the block generator.

        Args:
            model_path: Path of the ``state_dict`` checkpoint given to ``torch.load``.
            architecture: Zero-argument callable returning the module that
                receives the loaded weights.
        """
        self.model = architecture()
        # The freshly built module lives on the CPU; mapping the checkpoint
        # there keeps loading working on hosts without CUDA even when the
        # checkpoint was saved from a GPU. get_model_pred moves the model to
        # the inference device later.
        state_dict = torch.load(model_path, map_location = "cpu")
        self.model.load_state_dict(state_dict)
        self.generator = blockulib.BlockGenerator()
        self.model.eval()

    def __call__(self, num_games = 1, batch_size = 4096, temperature = 1.0, top_k: int = None):
        """Play ``num_games`` games in parallel until every game is marked finished.

        Args:
            num_games: Number of games to play side by side.
            batch_size: Maximum number of candidate boards per model forward pass.
            temperature: Forwarded to ``blockulib.logits_to_choices``.
            top_k: Forwarded to ``blockulib.logits_to_choices``.

        Returns:
            A list with one entry per game; each entry is the list of boards
            of that game, starting with the all-zero 9x9 board.
        """
        pos_list = [[torch.zeros(9, 9)] for _ in range(num_games)]
        state = [True] * num_games
        active_games = num_games

        while active_games > 0:
            # Indices of the games still running; len(new_index) == active_games.
            new_index = [game for game in range(num_games) if state[game]]
            boards = [pos_list[game][-1].clone() for game in new_index]

            pos, ind = blockulib.possible_moves(boards, self.generator)
            logits = self.get_model_pred(pos, batch_size = batch_size).squeeze(1)
            decisions = blockulib.logits_to_choices(logits, ind, active_games, temperature = temperature, top_k = top_k)

            for slot, game in enumerate(new_index):
                if decisions[slot] is None:
                    # A None decision ends this game.
                    state[game] = False
                    active_games -= 1
                else:
                    # decisions[slot] is a row index into pos.
                    pos_list[game].append(pos[decisions[slot]])

        return pos_list

    def get_model_pred(self, data, batch_size, device = None):
        """Run the model over ``data`` in chunks and return the concatenated outputs on the CPU.

        Args:
            data: Tensor whose first dimension enumerates the boards to score;
                a channel dimension is inserted at position 1 before the
                forward pass.
            batch_size: Maximum number of boards per forward pass.
            device: Device to run on; defaults to CUDA when available, else CPU.

        Returns:
            The model outputs for all rows of ``data`` concatenated along
            dimension 0. For empty input an empty ``(0, 1)`` tensor is
            returned.
        """
        if data.shape[0] == 0:
            # Callers apply .squeeze(1) to the result. A (0, 1) tensor then
            # becomes an empty 1-D vector, whereas torch.tensor([[]]) has
            # shape (1, 0) and would stay 2-D.
            # NOTE(review): assumes the model emits one value per board
            # (output shape (N, 1)) -- confirm against the architecture.
            return torch.empty((0, 1))
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(device)

        predictions = []
        with torch.no_grad():
            for start in range(0, data.shape[0], batch_size):
                batch = data[start:start + batch_size].to(device).unsqueeze(1)
                predictions.append(self.model(batch).cpu())

        return torch.cat(predictions)

In [4]:
class Probe(ModelBasedLoop):
    """Rolls given positions forward for a limited number of model-chosen moves."""

    def __call__(self, pos_tensor, depth = 15, batch_size = 4096, temperature = 1.0, top_k: int = None):
        """Play every position in ``pos_tensor`` forward for at most ``depth`` moves.

        Args:
            pos_tensor: Tensor whose first dimension enumerates the starting
                boards. It is not modified; the rollout works on a copy.
            depth: Maximum number of moves played per game.
            batch_size: Maximum number of candidate boards per model forward pass.
            temperature: Forwarded to ``blockulib.logits_to_choices``.
            top_k: Forwarded to ``blockulib.logits_to_choices``.

        Returns:
            A pair ``(game_length, last_logit)`` of 1-D tensors with one entry
            per starting board: the number of moves played, and the model
            output for the last board chosen (``-inf`` if no move was played).
        """
        # Work on a copy so probing the same tensor repeatedly always starts
        # from the same positions.
        pos_tensor = pos_tensor.clone()
        num_games = pos_tensor.shape[0]
        state = [True] * num_games
        game_length = torch.zeros(num_games)
        last_logit = torch.full((num_games, ), float('-inf'))
        active_games = num_games

        for _ in range(depth):
            if active_games == 0:
                # Nothing left to roll out (also covers an empty pos_tensor).
                break

            # Indices of the games still running; len(new_index) == active_games.
            new_index = [game for game in range(num_games) if state[game]]
            boards = [pos_tensor[game].clone() for game in new_index]

            pos, ind = blockulib.possible_moves(boards, self.generator)
            logits = self.get_model_pred(pos, batch_size = batch_size).squeeze(1)
            decisions = blockulib.logits_to_choices(logits, ind, active_games, temperature = temperature, top_k = top_k)

            for slot, game in enumerate(new_index):
                choice = decisions[slot]
                if choice is None:
                    # A None decision ends this game.
                    state[game] = False
                    active_games -= 1
                else:
                    pos_tensor[game] = pos[choice]
                    game_length[game] += 1.
                    # logits has one entry per row of pos, so the score of the
                    # chosen board is found at the decision index, not at the
                    # game slot.
                    last_logit[game] = logits[choice]

        return game_length, last_logit

In [5]:
# Probe defines no __init__ of its own, so this uses ModelBasedLoop's defaults:
# checkpoint "models/conv_model.pth" loaded into blom.ConvModel.
probe = Probe()

In [6]:
# Read the saved dictionary of tensors from disk.
tensor_dir = "data/tensors/"
x_dict = torch.load(tensor_dir + "x.pth")

# Take entries 3056..3059 of the 'x' tensor and repeat each of them ten times
# in a row along dimension 0.
x = x_dict['x'][3056:3060].repeat_interleave(10, dim=0)
print(x.size())

torch.Size([40, 9, 9])


In [None]:
class DeepSearch(ModelBasedLoop):
    def __init__(self):
        super().__init__(...)
        self.probe = Probe()
    
    def get_model_pred(self, data, batch_size, device = None):
        with torch.no_grad():E
            for i in range(0, data.shape[0], batch_size):
                batch = data[i:i+batch_size].to(device)
                batch = batch.unsqueeze(1)
                output = self.model(batch)
                predictions.append(output.cpu())