In [5]:
import torch
import torch.nn as nn
from config import AppConfig, MODEL_CONFIG
from modules.GPT2_Model import GPTModel
import tiktoken

In [2]:
# Overlay the settings of the selected model variant onto the base config.
# The shared config dict is updated in place, exactly as before.
gpt_config = MODEL_CONFIG.GPT_CONFIG_124M
gpt_config.update(MODEL_CONFIG.GPT_MODEL_CONFIGS[MODEL_CONFIG.MODEL_NAME_TO_USE])
model = GPTModel(gpt_config)

# Classification fine-tuning setup: start from a fully frozen model ...
model.requires_grad_(False)

# ... replace the output head with a classification head. The new Linear layer
# is created after the freeze, so its parameters keep the default
# requires_grad=True ...
model.out_head = nn.Linear(
    gpt_config["emb_dim"],
    MODEL_CONFIG.TRAINING_CONFIG["num_classes"],
)

# ... and re-enable gradients only for the last transformer block and the
# final normalization module that feeds the output head.
for trainable_module in (model.trf_blocks[-1], model.final_norm):
    trainable_module.requires_grad_(True)

In [3]:
# Print the model's module tree to inspect its layers after the head swap.
print(model)

GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.0, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=True)
        (W_key): Linear(in_features=768, out_features=768, bias=True)
        (W_value): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_shortcut): Dropout(p=0.0, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=7

In [None]:
# Load the fine-tuned spam-classifier weights into the model built above.
# - map_location="cpu": the notebook runs inference on the CPU, so tensors that
#   were saved from a GPU run must be remapped or loading fails on machines
#   without CUDA.
# - weights_only=True: restricts unpickling to tensors/primitive containers,
#   which is all a state dict needs, instead of executing arbitrary pickled code.
model_state_dict = torch.load(
    f"{AppConfig.SPAM_TRAINED_MODEL_DIR}/Spam Classifier.pth",
    map_location="cpu",
    weights_only=True,
)
model.load_state_dict(model_state_dict)

# Tokenizer using the "gpt2" encoding; it turns raw text into the token ids
# fed to the model.
tokenizer = tiktoken.get_encoding("gpt2")


In [13]:
# Device on which classify_spam creates its input tensor.
device =  "cpu"
# Seed PyTorch's global random number generator for reproducibility.
torch.manual_seed(123)

<torch._C.Generator at 0x276ce4fff90>

Step 1: Prepare inputs to the model

Step 2: Truncate sequences if they are too long

Step 3: Pad the sequence with the padding token up to `max_length`

Step 4: Add batch dimension

Step 5: Model inference without gradient tracking

Step 6: Logits of the last output token

Step 7: Return the classified result

In [14]:
def classify_spam(text, model, tokenizer, device, max_length=None, pad_token_id=50256):
    """Classify a single text as ``"spam"`` or ``"not spam"``.

    The text is tokenized, truncated and right-padded to a fixed target
    length, passed through ``model`` without gradient tracking, and the class
    is read from the logits at the last sequence position.

    Args:
        text: Raw input string to classify.
        model: Module exposing ``pos_emb.weight`` (its first dimension is used
            as the supported context length) and returning logits that can be
            indexed as ``[:, -1, :]``.
        tokenizer: Object with an ``encode(text)`` method returning token ids.
        device: Device on which the input tensor is created. The model is not
            moved by this function.
        max_length: Target sequence length. ``None`` (the default) uses the
            model's supported context length; larger values are clamped to it.
            NOTE(review): presumably the model was fine-tuned on inputs padded
            to a fixed length; passing that same length here keeps inference
            consistent with training. Confirm against the training setup.
        pad_token_id: Token id appended to reach the target length.

    Returns:
        ``"spam"`` if the arg-max class index is 1, otherwise ``"not spam"``.

    Raises:
        ValueError: If ``max_length`` is given but smaller than 1.
    """
    supported_context_length = model.pos_emb.weight.shape[0]
    # Note: In the book, this was originally written as pos_emb.weight.shape[1] by mistake
    # It didn't break the code but would have caused unnecessary truncation (to 768 instead of 1024)

    # Resolve one target length used for BOTH truncation and padding, so the
    # padded input can never exceed what the positional embedding supports.
    if max_length is None:
        target_length = supported_context_length
    elif max_length < 1:
        raise ValueError(f"max_length must be a positive integer, got {max_length!r}")
    else:
        target_length = min(max_length, supported_context_length)

    model.eval()

    # Prepare inputs to the model: truncate sequences if they are too long ...
    input_ids = list(tokenizer.encode(text))[:target_length]

    # ... then pad on the right up to the target length
    input_ids = input_ids + [pad_token_id] * (target_length - len(input_ids))
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)  # add batch dimension

    # Model inference without gradient tracking
    with torch.no_grad():
        logits = model(input_tensor)[:, -1, :]  # Logits of the last output token
    predicted_label = torch.argmax(logits, dim=-1).item()

    # Return the classified result
    return "spam" if predicted_label == 1 else "not spam"

In [21]:
text_1 = (
    "You are a winner you have been specially"
    " selected to receive $1000 cash or a $2000 award."
)

print(classify_spam(
    text_1, model, tokenizer, device, max_length=120
))

spam


In [22]:
text_2 = (
    "Hey, just wanted to check if we're still on"
    " for dinner tonight? Let me know!"
)

print(classify_spam(
    text_2, model, tokenizer, device, max_length=120
))

not spam


In [23]:
text_3 = (
    "Congratulations! You won a free ticket."
    "  Claim now!"
)

print(classify_spam(
    text_3, model, tokenizer, device, max_length=120
))

spam
