In [16]:
import os
import sys
from os.path import join
import json
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops
# Import necessary libraries
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from transformers import GPT2Config, GPT2Model
from tqdm import tqdm, trange
# from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup, AdamW

In [19]:
import sys
sys.path.append("/n/home12/binxuwang/Github/DiffusionReasoning")
from GPT_models.GPT_RAVEN_model_lib import MultiIdxGPT2Model, sample_next_token, seqtsr2imgtsr, seqtsr2attrtsr, completion_eval, preprocess_ids
from rule_new_utils import infer_rule_from_sample_batch, compute_rule_statistics

In [None]:
@torch.no_grad()
def sample_next_token(model, prefix_inputs, max_length=81, strategy="greedy", device="cuda", temperature=1.0, cond=None):
    prefix_inputs = prefix_inputs.to(device)
    model.eval().to(device)
    prefix_length = prefix_inputs.size(1)
    for i in range(max_length - prefix_length):
        outputs, logits1, logits2, logits3 = model(prefix_inputs, y=cond)
        if strategy == "greedy":
            next_token1 = torch.argmax(logits1[:, -1, :], dim=-1, keepdim=True)
            next_token2 = torch.argmax(logits2[:, -1, :], dim=-1, keepdim=True)
            next_token3 = torch.argmax(logits3[:, -1, :], dim=-1, keepdim=True)
        elif strategy == "sample":
            next_token1 = torch.multinomial(F.softmax(logits1[:, -1, :] / temperature, dim=-1), num_samples=1)
            next_token2 = torch.multinomial(F.softmax(logits2[:, -1, :] / temperature, dim=-1), num_samples=1)
            next_token3 = torch.multinomial(F.softmax(logits3[:, -1, :] / temperature, dim=-1), num_samples=1)
        else:
            raise ValueError("Invalid strategy")
        next_token = torch.cat([next_token1, next_token2, next_token3], dim=-1)
        prefix_inputs = torch.cat([prefix_inputs, next_token[:,None,:]], dim=1)
    return prefix_inputs

In [None]:
@torch.no_grad()
def completion_eval(eval_samples, model, cond=None, device='cuda', num_mask=9, batch_size=512, 
                    strategy="greedy", temperature=1.0, return_stats=False):
    eval_samples = eval_samples.to(device)
    if batch_size is None:
        batch_size = eval_samples.size(0)
    eval_complete = []
    for idx in trange(0, eval_samples.size(0), batch_size):
        eval_batch = eval_samples[idx:idx+batch_size]
        cond_batch = cond[idx:idx+batch_size] if cond is not None else None
        eval_complete_batch = sample_next_token(model, eval_batch[:,:-num_mask,:], temperature=temperature,
                                          max_length=81, strategy=strategy, device=device, cond=cond_batch).cpu()
        eval_complete.append(eval_complete_batch)
        
    eval_complete = torch.cat(eval_complete, dim=0)
    # eval_complete = sample_next_token(model, eval_samples[:,:-num_mask,:], 
    #                                   max_length=81, strategy=strategy, device=device).cpu()
    # eval_complete_attr = seqtsr2attrtsr(eval_complete, h=3, w=3, p=3, R=3)
    # note we need to denormalize, offset by - 1
    eval_complete = eval_complete - 1
    eval_complete_img = seqtsr2imgtsr(eval_complete, h=3, w=3, p=3, R=3)
    C3_list, C2_list, rule_col_list = infer_rule_from_sample_batch(eval_complete_img)
    C3_count, C2_count, anyvalid_count, total = compute_rule_statistics(C3_list, C2_list, rule_col_list)
    # final_row = np.array(rule_col_list, dtype=object)[:,-1]
    # anyvalid_count = sum([len(x) > 0 for x in final_row])
    print(f"Completion: C3: {C3_count / total:.3f} [{C3_count}/{total}],  valid: {anyvalid_count / total / 3:.3f} [{anyvalid_count}/{total*3}]")
    if return_stats:
        return eval_complete, C3_list, C2_list, rule_col_list, {"C3": C3_count, "C2": C2_count, "anyvalid": anyvalid_count, "total": total}
    else:
        return eval_complete, C3_list, C2_list, rule_col_list


In [None]:
@torch.no_grad()
def sample_next_token(model, prefix_inputs, max_length=81, strategy="greedy", device="cuda", 
                      temperature=1.0, beam_width=3, cond=None):
    prefix_inputs = prefix_inputs.to(device)
    model.eval().to(device)
    prefix_length = prefix_inputs.size(1)
    
    if strategy == "beam_search":
        # Initialize beams with the prefix inputs and their scores
        beams = [(prefix_inputs, 0)]  # List of (sequence, score)
        
        for i in range(max_length - prefix_length):
            new_beams = []
            for seq, score in beams:
                outputs, logits1, logits2, logits3 = model(seq, y=cond)
                # Compute probabilities and take top beam_width tokens
                logits_combined = torch.cat([logits1[:, -1, :], logits2[:, -1, :], logits3[:, -1, :]], dim=-1)
                probs = F.softmax(logits_combined / temperature, dim=-1)
                top_probs, top_indices = torch.topk(probs, beam_width, dim=-1)
                
                for j in range(beam_width):
                    next_token = top_indices[:, j].unsqueeze(-1)
                    new_seq = torch.cat([seq, next_token[:, None, :]], dim=1)
                    new_score = score + torch.log(top_probs[:, j]).item()
                    new_beams.append((new_seq, new_score))
            
            # Sort all new beams and keep the top beam_width
            new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
            beams = new_beams
        
        # Return the sequence with the highest score
        best_seq = max(beams, key=lambda x: x[1])[0]
        return best_seq
    
    else:
        # Greedy or sampling strategies (existing code)
        for i in range(max_length - prefix_length):
            outputs, logits1, logits2, logits3 = model(prefix_inputs, y=cond)
            if strategy == "greedy":
                next_token1 = torch.argmax(logits1[:, -1, :], dim=-1, keepdim=True)
                next_token2 = torch.argmax(logits2[:, -1, :], dim=-1, keepdim=True)
                next_token3 = torch.argmax(logits3[:, -1, :], dim=-1, keepdim=True)
            elif strategy == "sample":
                next_token1 = torch.multinomial(F.softmax(logits1[:, -1, :] / temperature, dim=-1), num_samples=1)
                next_token2 = torch.multinomial(F.softmax(logits2[:, -1, :] / temperature, dim=-1), num_samples=1)
                next_token3 = torch.multinomial(F.softmax(logits3[:, -1, :] / temperature, dim=-1), num_samples=1)
            else:
                raise ValueError("Invalid strategy")
            next_token = torch.cat([next_token1, next_token2, next_token3], dim=-1)
            prefix_inputs = torch.cat([prefix_inputs, next_token[:, None, :]], dim=1)
        return prefix_inputs

In [4]:
tabdir = "/n/home12/binxuwang/Github/DiffusionReasoning/Tables"
figdir = "/n/home12/binxuwang/Github/DiffusionReasoning/Figures_newrule"

GPT_exproot = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/DL_Projects/GPT2_raven"

In [5]:
!ls {GPT_exproot}

GPT2_base_pilot
GPT2_base_pilot_fixed
GPT2_base_RAVEN_cond_heldout0-20240701-225339
GPT2_base_RAVEN_uncond_heldout0-20240515-021155
GPT2_base_RAVEN_uncond_heldout0-20240630-023945
GPT2_base_RAVEN_uncond_heldout0_stream0_016M-20240819-010518
GPT2_base_RAVEN_uncond_heldout0_stream0_016M-20240820-020027
GPT2_base_RAVEN_uncond_heldout0_stream0_016M-20240820-023934
GPT2_base_RAVEN_uncond_heldout0_stream0_16M-20240819-032725
GPT2_base_RAVEN_uncond_heldout0_stream0_16M-20240820-021649
GPT2_base_RAVEN_uncond_heldout0_stream0_16M-20240820-024013
GPT2_base_RAVEN_uncond_heldout0_stream1_6M-20240818-012450
GPT2_base_RAVEN_uncond_heldout0_stream1_6M-20240818-013524
GPT2_base_RAVEN_uncond_heldout0_stream1_6M-20240818-014017
GPT2_base_RAVEN_uncond_heldout0_stream16M-20240819-010517
GPT2_base_RAVEN_uncond_heldout0_stream1_6M-20240819-010517
GPT2_base_RAVEN_uncond_heldout0_stream16M-20240820-020037
GPT2_base_RAVEN_uncond_heldout0_stream1_6M-20240820-020037
GPT2_base_RAVEN_uncond_heldout0_stream1_6M-202

In [7]:
expname = "GPT2_medium_RAVEN_uncond_heldout0_stream0_16M-20240820-024019"
expdir = join(GPT_exproot, expname)

In [10]:
!ls {expdir}

ckpt  config.json  repr_classifier  samples  tensorboard_logs


In [13]:
config = json.load(open(join(expdir, 'config.json')))
config


{'batch_size': 64,
 'epoch_total': 457,
 'save_ckpt_every_step': 25000,
 'eval_model_every_step': 2500,
 'lr': 0.0001,
 'num_warmup_steps': 100,
 'n_embd': 768,
 'n_class': 0,
 'n_layer': 24,
 'n_head': 12,
 'is_sep_embed': True,
 'heldout_id': [1, 16, 20, 34, 37],
 'train_sample_num': 140000,
 'val_sample_num': 20000,
 'eval_temperature': 1.0,
 'is_class_cond': False}

In [21]:
def load_gpt2_raven(expname, ckpt_step=999999):
    expdir = join(GPT_exproot, expname)
    config = json.load(open(join(expdir, 'config.json')))
    gpt2_raven = MultiIdxGPT2Model(attribute_dims=(7,10,10), vocab_size=27, max_length=83, n_embd=config['n_embd'],
                               n_class=config['n_class'], n_head=config['n_head'], n_layer=config['n_layer'])
    gpt2_raven.load_state_dict(th.load(join(expdir, 'ckpt', f'gpt2_step{ckpt_step}.pth')))
    gpt2_raven.to('cuda').eval()
    return gpt2_raven

expname = "GPT2_medium_RAVEN_uncond_heldout0_stream0_16M-20240820-024019"
ckpt_step = 999999
gpt2_raven = load_gpt2_raven(expname, ckpt_step)

In [22]:
print("Ab initio generation, sampling: ")
# rnd_idx = np.random.choice(len(attr_seq_tsr_val), 512)
# eval_samples = attr_seq_tsr_val[rnd_idx,:,:]
eval_samples_empty = torch.zeros(512, 81, 3, dtype=th.long).to('cuda')
eval_complete, C3_list, C2_list, rule_col_list = completion_eval(eval_samples_empty, gpt2_raven, num_mask=81, 
                                                                 device='cuda', strategy="sample", batch_size=512)
th.cuda.empty_cache()

Ab initio generation, sampling: 


  0%|          | 0/1 [00:00<?, ?it/s]

Completion: C3: 0.357 [183/512],  valid: 0.625 [960/1536]


https://github.com/wouterkool/stochastic-beam-search/tree/stochastic-beam-search

https://github.com/evanthebouncy/stoicastic_beam

In [91]:
import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_next_token_beam_search(
    model, prefix_inputs, max_length=81, beam_size=5, device="cuda", cond=None, strategy="topk_beam", temperature=1.0, return_best=True
):
    prefix_inputs = prefix_inputs.to(device)  # shape [1, seq_len, 3]
    batch_size = prefix_inputs.size(0)
    prefix_length = prefix_inputs.size(1)
    if batch_size != 1:
        raise NotImplementedError("Beam search for batch size >1 not implemented.")
    model.eval().to(device)
    # Initialize the beam with the prefix input and zero cumulative log probability
    beam = [(prefix_inputs, 0.0)]  # list of tuples (sequence tensor, cumulative_log_prob)
    current_beam_size = 1 # number of sequences in the beam, initialized to 1
    for _ in range(max_length - prefix_length):
        # Collect sequences and cumulative_log_probs from the current beam
        beam_sequences = torch.cat([seq for seq, _ in beam], dim=0)  # [beam_size, seq_len, 3]
        beam_log_probs = torch.tensor(
            [log_prob for _, log_prob in beam], device=device
        )  # [beam_size]

        # Run the model on all sequences in the beam
        outputs, logits1, logits2, logits3 = model(beam_sequences, y=cond)
        logits1 = logits1[:, -1, :] / temperature # [beam_size, vocab_size1]
        logits2 = logits2[:, -1, :] / temperature # [beam_size, vocab_size2]
        logits3 = logits3[:, -1, :] / temperature # [beam_size, vocab_size3]

        V1, V2, V3 = logits1.size(-1), logits2.size(-1), logits3.size(-1)
        total_vocab_size = V1 * V2 * V3

        # Compute log probabilities
        log_probs1 = F.log_softmax(logits1, dim=-1)  # [beam_size, V1]
        log_probs2 = F.log_softmax(logits2, dim=-1)  # [beam_size, V2]
        log_probs3 = F.log_softmax(logits3, dim=-1)  # [beam_size, V3]

        # Compute joint log probabilities for all combinations
        log_probs1_exp = log_probs1.unsqueeze(2).unsqueeze(3)  # [beam_size, V1, 1, 1]
        log_probs2_exp = log_probs2.unsqueeze(1).unsqueeze(3)  # [beam_size, 1, V2, 1]
        log_probs3_exp = log_probs3.unsqueeze(1).unsqueeze(2)  # [beam_size, 1, 1, V3]
        joint_log_probs = (
            log_probs1_exp + log_probs2_exp + log_probs3_exp
        )  # [beam_size, V1, V2, V3]

        # Flatten joint log probabilities
        joint_log_probs = joint_log_probs.view(current_beam_size, -1)  # [beam_size, V1*V2*V3]

        # Compute cumulative log probabilities
        cumulative_log_probs = (
            beam_log_probs.unsqueeze(1) + joint_log_probs
        )  # [beam_size, V1*V2*V3]

        # Flatten for selecting top candidates, combinatoric of [ old beams X new tokens ] 
        cumulative_log_probs_flat = cumulative_log_probs.view(-1)  # [beam_size * V1*V2*V3]

        if strategy == "topk_beam":
            # Get top beam_size sequences
            sampled_log_probs, sampled_indices = torch.topk(cumulative_log_probs_flat, k=beam_size)
        # elif strategy == "sample_beam":
        #     cumulative_probs_flat = torch.exp(cumulative_log_probs_flat - cumulative_log_probs_flat.max())
        #     cumulative_probs_flat /= cumulative_probs_flat.sum()
        #     sampled_indices = torch.multinomial(cumulative_probs_flat, num_samples=beam_size, replacement=False)
        #     sampled_log_probs = cumulative_log_probs_flat[sampled_indices]
        else:
            raise ValueError(f"Invalid strategy: {strategy}")
        
        new_beam = []
        for idx in range(beam_size):
            log_prob = sampled_log_probs[idx]
            flat_index = sampled_indices[idx]
            beam_idx = flat_index // total_vocab_size
            vocab_idx = flat_index % total_vocab_size
            seq = beam_sequences[beam_idx]  # [seq_len, 3]
            # Decode indices for the next token's attributes
            idx1 = vocab_idx // (V2 * V3)
            idx2 = (vocab_idx % (V2 * V3)) // V3
            idx3 = vocab_idx % V3
            # Create the next token
            next_token = torch.tensor(
                [[idx1.item(), idx2.item(), idx3.item()]], device=device
            ).unsqueeze(0)  # [1, 1, 3]
            # Append the next token to the sequence
            new_seq = torch.cat([seq.unsqueeze(0), next_token], dim=1)  # [1, seq_len+1, 3]
            new_beam.append((new_seq, log_prob))
        current_beam_size = len(new_beam)
        # Update the beam with new sequences
        beam = new_beam
    
    if return_best:
        # Return the sequence with the highest cumulative log probability
        best_seq = beam[0][0]  # [1, max_length, 3]
        return best_seq.squeeze(0)  # Remove the batch dimension
    else:
        seqs = torch.cat([seq for seq, _ in beam], dim=0)
        scores = torch.tensor([log_prob for _, log_prob in beam], device=device)
        return seqs, scores


In [118]:
from numpy.random import gumbel

In [122]:
gumbel(loc=np.zeros(10))

array([-0.66904392, -0.30546029,  2.58577679,  1.21405472,  1.49607145,
        0.80941592,  0.10930291,  1.53330812, -0.77075769, -0.66470453])

In [150]:
import torch
import torch.nn.functional as F
from numpy.random import gumbel # 
# torch version gumbel 
from torch.distributions.gumbel import Gumbel

def sample_gumbel(shape, device, eps=1e-20):
    """Sample Gumbel noise."""
    U = torch.rand(shape, device=device)
    return -torch.log(-torch.log(U + eps) + eps)


@torch.no_grad()
def sample_next_token_stochastic_beam_search(
    model, prefix_inputs, max_length=81, beam_size=5, device="cuda", cond=None, strategy="gumbel_topk_beam", temperature=1.0, return_best=True
):
    prefix_inputs = prefix_inputs.to(device)  # shape [1, seq_len, 3]
    batch_size = prefix_inputs.size(0)
    prefix_length = prefix_inputs.size(1)
    if batch_size != 1:
        raise NotImplementedError("Beam search for batch size >1 not implemented.")
    model.eval().to(device)
    # Initialize the beam with the prefix input and zero cumulative log probability
    beam = [(prefix_inputs, 0.0, 0.0)]  # list of tuples (sequence tensor, cumulative_log_prob)
    current_beam_size = 1 # number of sequences in the beam, initialized to 1
    for _ in range(max_length - prefix_length):
        # Collect sequences and cumulative_log_probs from the current beam
        beam_sequences = torch.cat([seq for seq, _,  _ in beam], dim=0)  # [beam_size, seq_len, 3]
        beam_log_probs = torch.tensor(
            [log_prob for _, log_prob, _ in beam], device=device
        )  # [beam_size]
        beam_gumbel_log_probs = torch.tensor(
            [gumbellog_prob for _, _, gumbellog_prob in beam], device=device
        )  # [beam_size]
        # Run the model on all sequences in the beam
        outputs, logits1, logits2, logits3 = model(beam_sequences, y=cond)
        logits1 = logits1[:, -1, :] / temperature # [beam_size, vocab_size1]
        logits2 = logits2[:, -1, :] / temperature # [beam_size, vocab_size2]
        logits3 = logits3[:, -1, :] / temperature # [beam_size, vocab_size3]

        V1, V2, V3 = logits1.size(-1), logits2.size(-1), logits3.size(-1)
        total_vocab_size = V1 * V2 * V3

        # Compute log probabilities
        log_probs1 = F.log_softmax(logits1, dim=-1)  # [beam_size, V1]
        log_probs2 = F.log_softmax(logits2, dim=-1)  # [beam_size, V2]
        log_probs3 = F.log_softmax(logits3, dim=-1)  # [beam_size, V3]

        # Compute joint log probabilities for all combinations
        log_probs1_exp = log_probs1.unsqueeze(2).unsqueeze(3)  # [beam_size, V1, 1, 1]
        log_probs2_exp = log_probs2.unsqueeze(1).unsqueeze(3)  # [beam_size, 1, V2, 1]
        log_probs3_exp = log_probs3.unsqueeze(1).unsqueeze(2)  # [beam_size, 1, 1, V3]
        joint_log_probs = (
            log_probs1_exp + log_probs2_exp + log_probs3_exp
        )  # [beam_size, V1, V2, V3]

        # Flatten joint log probabilities
        joint_log_probs = joint_log_probs.view(current_beam_size, -1)  # [beam_size, V1*V2*V3]

        # Compute cumulative log probabilities
        cumulative_log_probs = (
            beam_log_probs.unsqueeze(1) + joint_log_probs
        )  # [beam_size, V1*V2*V3]

        gumbel_phi_ss = np.random.gumbel(cumulative_log_probs.cpu().numpy())
        gumbel_phi_ss = th.tensor(gumbel_phi_ss, device=device)
        # gumbel_phi_ss = cumulative_log_probs + sample_gumbel(cumulative_log_probs.shape, device)
        z = th.max(gumbel_phi_ss, dim=1).values
        exp_neg_g_phi_s = th.exp(-beam_gumbel_log_probs) # [beam_size]
        exp_neg_z = th.exp(-z) # [beam_size]
        exp_neg_g_phi_ss = th.exp(-gumbel_phi_ss) # [beam_size, V1*V2*V3]
        gumbel_hat_phi_ss = -th.log(exp_neg_g_phi_s[:, None] - exp_neg_z[:, None] + exp_neg_g_phi_ss) # [beam_size, V1*V2*V3]
        # Flatten for selecting top candidates, combinatoric of [ old beams X new tokens ] 
        cumulative_log_probs_flat = cumulative_log_probs.view(-1)  # [beam_size * V1*V2*V3]
        gumbel_hat_phi_ss_flat = gumbel_hat_phi_ss.view(-1) # [beam_size * V1*V2*V3]
        if strategy == "gumbel_topk_beam":
            sampled_gumble_hat_log_probs, sampled_indices = torch.topk(gumbel_hat_phi_ss_flat, k=beam_size)
            sampled_log_probs = cumulative_log_probs_flat[sampled_indices]
        # if strategy == "topk_beam":
        #     # Get top beam_size sequences
        #     sampled_log_probs, sampled_indices = torch.topk(cumulative_log_probs_flat, k=beam_size)
        # elif strategy == "sample_beam":
        #     cumulative_probs_flat = torch.exp(cumulative_log_probs_flat - cumulative_log_probs_flat.max())
        #     cumulative_probs_flat /= cumulative_probs_flat.sum()
        #     sampled_indices = torch.multinomial(cumulative_probs_flat, num_samples=beam_size, replacement=False)
        #     sampled_log_probs = cumulative_log_probs_flat[sampled_indices]
        else:
            raise ValueError(f"Invalid strategy: {strategy}")
        
        new_beam = []
        for idx in range(beam_size):
            flat_index = sampled_indices[idx]
            log_prob = sampled_log_probs[idx]
            gumbel_hat_log_prob = sampled_gumble_hat_log_probs[idx]
            beam_idx = flat_index // total_vocab_size
            vocab_idx = flat_index % total_vocab_size
            seq = beam_sequences[beam_idx]  # [seq_len, 3]
            # Decode indices for the next token's attributes
            idx1 = vocab_idx // (V2 * V3)
            idx2 = (vocab_idx % (V2 * V3)) // V3
            idx3 = vocab_idx % V3
            # Create the next token
            next_token = torch.tensor(
                [[idx1.item(), idx2.item(), idx3.item()]], device=device
            ).unsqueeze(0)  # [1, 1, 3]
            # Append the next token to the sequence
            new_seq = torch.cat([seq.unsqueeze(0), next_token], dim=1)  # [1, seq_len+1, 3]
            new_beam.append((new_seq, log_prob, gumbel_hat_log_prob))
        current_beam_size = len(new_beam)
        # Update the beam with new sequences
        beam = new_beam
    
    if return_best:
        # Return the sequence with the highest cumulative log probability
        best_seq = beam[0][0]  # [1, max_length, 3]
        return best_seq.squeeze(0)  # Remove the batch dimension
    else:
        seqs = torch.cat([seq for seq, _, _ in beam], dim=0)
        scores = torch.tensor([log_prob for _, log_prob, _ in beam], device=device)
        gumbel_hat_log_probs = torch.tensor([gumbel_hat_log_prob for _, _, gumbel_hat_log_prob in beam], device=device)
        return seqs, scores, gumbel_hat_log_probs


In [157]:
seqs, scores, gumbel_hat_log_probs = sample_next_token_stochastic_beam_search(gpt2_raven, eval_samples_empty[0:1, :0], max_length=81, 
                beam_size=256, device='cuda', strategy="gumbel_topk_beam", temperature=1.0, return_best=False)
print(seqs.shape, scores.shape, gumbel_hat_log_probs.shape)

eval_complete = seqs.cpu() - 1
eval_complete_img = seqtsr2imgtsr(eval_complete, h=3, w=3, p=3, R=3)
C3_list, C2_list, rule_col_list = infer_rule_from_sample_batch(eval_complete_img)
C3_count, C2_count, anyvalid_count, total = compute_rule_statistics(C3_list, C2_list, rule_col_list)
# final_row = np.array(rule_col_list, dtype=object)[:,-1]
# anyvalid_count = sum([len(x) > 0 for x in final_row])
print(f"Completion: C3: {C3_count / total:.3f} [{C3_count}/{total}],  valid: {anyvalid_count / total / 3:.3f} [{anyvalid_count}/{total*3}]") 
print(C3_list)

torch.Size([256, 81, 3]) torch.Size([256]) torch.Size([256])
C3: 84/256 (0.33), C3 + C2: 142/256 (0.55), AnyValid: 476/768 (0.62)
Completion: C3: 0.328 [84/256],  valid: 0.620 [476/768]
[[], [], [], [], [22], [], [], [], [], [8], [], [5], [], [], [], [], [], [], [], [], [], [23], [], [], [], [], [], [], [], [8], [], [], [], [], [], [], [], [], [10], [], [], [], [], [25], [], [], [], [26], [], [], [], [], [], [], [26], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [22], [], [], [], [], [], [], [], [], [11], [5], [], [], [], [], [22], [], [], [], [], [6], [15], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [11], [3], [23], [], [], [], [5], [4], [], [], [], [], [11], [6], [], [], [], [], [22], [13], [], [], [], [22], [], [], [], [14], [18], [0], [], [], [3], [11], [], [], [], [], [11], [], [18], [], [12], [], [], [], [], [], [], [], [2], [11], [23], [], [], [28], [28], [4], [15], [2], [], [14], [], [18], [19],

In [156]:
eval_samples_empty.shape

torch.Size([512, 81, 3])

In [33]:
with th.no_grad():
    outputs, logits1, logits2, logits3 = gpt2_raven(eval_samples_empty[0:1, :0], y=None)
# print(outputs)
print(logits1.shape, logits2.shape, logits3.shape)

torch.Size([1, 1, 8]) torch.Size([1, 1, 11]) torch.Size([1, 1, 11])


### Top 1 from beam search

In [93]:
with th.no_grad():
    best_seq = sample_next_token_beam_search(gpt2_raven, eval_samples_empty[0:1, :0], max_length=81, beam_size=10, device='cuda')
print(best_seq.shape)

torch.Size([81, 3])


In [43]:
with th.no_grad():
    best_seq2 = sample_next_token_beam_search(gpt2_raven, eval_samples_empty[0:1, :0], max_length=81, beam_size=5, device='cuda')
print(best_seq.shape)

torch.Size([81, 3])


In [44]:
th.all(best_seq == best_seq2)

tensor(True, device='cuda:0')

In [95]:
eval_complete = best_seq.cpu().unsqueeze(0) - 1
eval_complete_img = seqtsr2imgtsr(eval_complete, h=3, w=3, p=3, R=3)
C3_list, C2_list, rule_col_list = infer_rule_from_sample_batch(eval_complete_img)
C3_count, C2_count, anyvalid_count, total = compute_rule_statistics(C3_list, C2_list, rule_col_list)
# final_row = np.array(rule_col_list, dtype=object)[:,-1]
# anyvalid_count = sum([len(x) > 0 for x in final_row])
print(f"Completion: C3: {C3_count / total:.3f} [{C3_count}/{total}],  valid: {anyvalid_count / total / 3:.3f} [{anyvalid_count}/{total*3}]") 
print(C3_list)

C3: 1/1 (1.00), C3 + C2: 1/1 (1.00), AnyValid: 3/3 (1.00)
Completion: C3: 1.000 [1/1],  valid: 1.000 [3/3]
[[35]]


### Top N seq from beam search

In [106]:
with th.no_grad():
    best_seqs, best_scores = sample_next_token_beam_search(gpt2_raven, eval_samples_empty[0:1, :0], max_length=81, beam_size=10, device='cuda', return_best=False)
print(best_seqs.shape, best_scores.shape)

torch.Size([10, 81, 3]) torch.Size([10])


In [107]:
eval_complete = best_seqs.cpu() - 1
eval_complete_img = seqtsr2imgtsr(eval_complete, h=3, w=3, p=3, R=3)
C3_list, C2_list, rule_col_list = infer_rule_from_sample_batch(eval_complete_img)
C3_count, C2_count, anyvalid_count, total = compute_rule_statistics(C3_list, C2_list, rule_col_list)
print(f"Completion: C3: {C3_count / total:.3f} [{C3_count}/{total}],  valid: {anyvalid_count / total / 3:.3f} [{anyvalid_count}/{total*3}]") 
print(C3_list)
print(list(best_scores.cpu().numpy()))


C3: 10/10 (1.00), C3 + C2: 10/10 (1.00), AnyValid: 30/30 (1.00)
Completion: C3: 1.000 [10/10],  valid: 1.000 [30/30]
[[35], [13], [35], [0], [26], [35], [35], [35], [35], [13]]
[-21.662336, -21.714895, -21.776728, -23.271214, -23.496033, -23.498383, -31.572052, -31.655994, -32.499695, -34.300903]


### Sample the beam

In [113]:
with th.no_grad():
    seqs, scores = sample_next_token_beam_search(gpt2_raven, eval_samples_empty[0:1, :0], max_length=81, beam_size=20, device='cuda', 
                                             strategy="gumbel_topk_beam", temperature=1.0, return_best=False)
print(seqs.shape, scores.shape)

torch.Size([20, 81, 3]) torch.Size([20])


In [114]:
eval_complete = seqs.cpu() - 1
eval_complete_img = seqtsr2imgtsr(eval_complete, h=3, w=3, p=3, R=3)
C3_list, C2_list, rule_col_list = infer_rule_from_sample_batch(eval_complete_img)
C3_count, C2_count, anyvalid_count, total = compute_rule_statistics(C3_list, C2_list, rule_col_list)
print(f"Completion: C3: {C3_count / total:.3f} [{C3_count}/{total}],  valid: {anyvalid_count / total / 3:.3f} [{anyvalid_count}/{total*3}]")

C3: 0/20 (0.00), C3 + C2: 0/20 (0.00), AnyValid: 4/60 (0.07)
Completion: C3: 0.000 [0/20],  valid: 0.067 [4/60]


### Scratch

In [None]:
!ls {expdir}/ckpt

gpt2_final.pth	     gpt2_step399999.pth  gpt2_step724999.pth
gpt2_init.pth	     gpt2_step424999.pth  gpt2_step749999.pth
gpt2_step124999.pth  gpt2_step449999.pth  gpt2_step74999.pth
gpt2_step149999.pth  gpt2_step474999.pth  gpt2_step774999.pth
gpt2_step174999.pth  gpt2_step499999.pth  gpt2_step799999.pth
gpt2_step199999.pth  gpt2_step49999.pth   gpt2_step824999.pth
gpt2_step224999.pth  gpt2_step524999.pth  gpt2_step849999.pth
gpt2_step249999.pth  gpt2_step549999.pth  gpt2_step874999.pth
gpt2_step24999.pth   gpt2_step574999.pth  gpt2_step899999.pth
gpt2_step274999.pth  gpt2_step599999.pth  gpt2_step924999.pth
gpt2_step299999.pth  gpt2_step624999.pth  gpt2_step949999.pth
gpt2_step324999.pth  gpt2_step649999.pth  gpt2_step974999.pth
gpt2_step349999.pth  gpt2_step674999.pth  gpt2_step999999.pth
gpt2_step374999.pth  gpt2_step699999.pth  gpt2_step99999.pth


In [None]:
ckpt_step = 999999
gpt2_raven = MultiIdxGPT2Model(attribute_dims=(7,10,10), vocab_size=27, max_length=83, n_embd=config['n_embd'],
                               n_class=config['n_class'], n_head=config['n_head'], n_layer=config['n_layer'])
gpt2_raven.load_state_dict(th.load(join(expdir, 'ckpt', f'gpt2_step{ckpt_step}.pth')))
gpt2_raven.to('cuda').eval()