In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)                               ## fix the global RNG so the torch.multinomial sampling in later cells is reproducible

## Character vocabulary: the position of a character is its token id in the model's embedding table,
## so this order presumably has to match the one used when bigram_model.pt was trained (confirm
## against the training code). len(vocab) (93) sizes the model, so it must equal the checkpoint's size.
## The backslash is written "\\" : a bare "\]" is an invalid escape sequence (SyntaxWarning on
## Python 3.12+); both spellings produce the same two characters "\" and "]".
vocab = list("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!#$%&'()*+,-./:;<=>?@[\\]^_`{|}\n")

class BigramLanguageModel(nn.Module):
    """Bigram character model: next-token logits depend only on the current token.

    Each token id selects one row of a ``(vocab_size, vocab_size)`` embedding table,
    and that row is used directly as the (unnormalised) logits over the next token.

    Args:
        vocab_size: number of distinct tokens; must match the checkpoint being loaded.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Row i holds the next-token logits for token i.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None, return_loss=False):
        """Look up next-token logits for every position of ``idx``.

        Args:
            idx: integer token ids, typically ``(B, T)``; must be 2-D when ``targets`` is given.
            targets: optional integer ids with ``B*T`` elements. When given, the logits are
                flattened to ``(B*T, C)`` and a cross-entropy loss against ``targets`` is computed.
            return_loss: if True, return ``(logits, loss)`` (``loss`` is None when ``targets``
                is None). Defaults to False so callers doing ``logits = model(idx)`` keep
                receiving a bare tensor.

        Returns:
            ``logits`` of shape ``idx.shape + (C,)``, or ``(B*T, C)`` when ``targets`` is given;
            ``(logits, loss)`` instead when ``return_loss`` is True.
        """
        logits = self.token_embedding_table(idx)

        loss = None
        if targets is not None:
            B, T, C = logits.shape
            # cross_entropy expects (N, C) logits and (N,) class indices.
            logits = logits.view(B * T, C)
            # reshape (unlike view) also accepts non-contiguous target tensors.
            targets = targets.reshape(B * T)
            loss = F.cross_entropy(logits, targets)

        if return_loss:
            return logits, loss
        return logits

In [2]:
model = BigramLanguageModel(len(vocab))               ## create an instance of the BigramLanguageModel
## load the trained weights. torch.load is pickle-based, so weights_only=True refuses anything but
## tensors/plain containers (no arbitrary code from an untrusted file); map_location="cpu" lets a
## checkpoint saved from a GPU load on a CPU-only machine.
state_dict = torch.load("bigram_model.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()                                          ## set the model to evaluation mode

BigramLanguageModel(
  (token_embedding_table): Embedding(93, 93)
)

In [3]:
flag = "HTB"                                          ## known prefix: the flag is expected to start with HTB
## encode each character as its vocab index, then add a leading batch axis -> shape (1, len(flag))
idxs = torch.tensor(
    [vocab.index(ch) for ch in flag],
    dtype=torch.long,
)[None, :]

In [4]:
logits = model(idx=idxs)                              ## next-char logits at every position: shape (1, len(flag), len(vocab))

In [5]:
next_idx = torch.multinomial(
    F.softmax(logits[:, -1, :], dim=-1),              ## probabilities for the character after the last one
    num_samples=1,
).item()                                              ## draw one index and unwrap it to a Python int
next_char = vocab[next_idx]                           ## decode the sampled index back to a character

In [6]:
flag = flag + next_char                               ## append the sampled character (note: re-running this cell appends it again)

In [7]:
print(f"{flag}")                                      ## sanity check: the first sampled character should be "{"

HTB{


In [None]:
MAX_NEW_CHARS = 256                                   ## safety cap: stochastic sampling is not guaranteed to ever produce "}"

with torch.no_grad():                                 ## inference only: don't build an autograd graph on every step
    for _ in range(MAX_NEW_CHARS):                    ## repeat the single-step prediction until we get }
        idxs = torch.tensor([vocab.index(c) for c in flag], dtype=torch.long).unsqueeze(0)
        logits = model(idxs)
        next_idx = torch.multinomial(torch.softmax(logits[:, -1], dim=1), num_samples=1).item()
        next_char = vocab[next_idx]
        flag += next_char
        if next_char == "}":                          ## stop if the closing bracket } is predicted
            break
    else:                                             ## loop ran out without a break: no closing bracket
        print(f"warning: no '}}' generated after {MAX_NEW_CHARS} characters; flag is probably incomplete")

print(flag)