In [2]:
from torch.nn import functional as F
from collections import OrderedDict
import torch.nn as nn
import torch
import os

In [3]:
# Locate the checkpoint next to the notebook: the working directory is treated as the repo root.
git_home = os.getcwd()
file = git_home + '/bigram_model.pt'

In [4]:
## This code is was given as part of the architecture.py ##

vocab = list("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!#$%&'()*+,-./:;<=>?@[\]^_`{|}\n")

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):

        logits = self.token_embedding_table(idx) 
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits


#model.eval()
#input_tensor = torch.randint(len(vocab), (1,10))
#logits = model(input_tensor)

In [5]:
# load model
model = BigramLanguageModel(len(vocab)) # run model
state_dict = torch.load(file) # load state_dict from the .pt file 
model.load_state_dict(state_dict) # load the given state_dict into the model 

<All keys matched successfully>

In [6]:
# set manual seed for reproducibility
torch.manual_seed(1337) # this seed value was given in the architecture.py file

# generate text with the manual seed and the state_dict given to the model

generated_text = ["{"] # added this because we know that the flag will have {. 
idx = torch.tensor([[vocab.index(generated_text[-1])]])
for i in range(50):
    logits = model(idx)
    probs = F.softmax(logits.squeeze(), dim=0)
    index = torch.multinomial(probs, num_samples=1).item()
    token = vocab[index]
    generated_text.append(token)
    idx = torch.tensor([[index]])
    
# join tokens to form the generated text
generated_text = "".join(generated_text)
print(generated_text)

{Pr0t3c7_L1fe}
HTB{Pr0t3c7_L1fe}
HTB{Pr0t3c7_L1fe}



**૮₍˶ •. • ⑅₎ა ♡ glockachu**