In [1]:
import torch
import torch.nn as nn

GPT_CONFIG_124M = {
    "vocab_size": 50257,    
    "context_length": 1024, 
    "emb_dim": 768,         
    "n_heads": 12,          
    "n_layers": 12,         
    "drop_rate": 0.1,       
    "qkv_bias": True       
}

In [2]:
class multiheadv2(nn.Module):
    def __init__(self,d_in,d_out,context_length,dropout,attention_head,boolbias):
        super().__init__()
        self.head_dim=d_out//attention_head
        self.d_out=d_out
        self.attention_head=attention_head

        self.W_query = nn.Linear(d_in, d_out, bias=boolbias)
        self.W_key = nn.Linear(d_in, d_out, bias=boolbias)
        self.W_value = nn.Linear(d_in, d_out, bias=boolbias)

        self.out_proj = nn.Linear(d_out, d_out)  
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self,x):
        b,num_token,d_out=x.shape

        keys = self.W_key(x) 
        queries = self.W_query(x)
        values = self.W_value(x)

        keys=keys.view(b,num_token,self.attention_head,self.head_dim)
        queries=queries.view(b,num_token,self.attention_head,self.head_dim)
        values=values.view(b,num_token,self.attention_head,self.head_dim)

        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_score=queries @ keys.transpose(2,3)

        mask_bool = self.mask.bool()[:num_token, :num_token]
        attn_score.masked_fill_(mask_bool, -torch.inf)
        
        attn_weights = torch.softmax(attn_score / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)


        context_vec = (attn_weights @ values).transpose(1, 2) 
        
        context_vec = context_vec.contiguous().view(b, num_token, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec
            



In [3]:
class LayerNorm(nn.Module):
    def __init__(self,emb_dim):
        super().__init__()
        self.eps=1e-5
        self.scale_params=nn.Parameter(torch.ones(emb_dim))
        self.shift_params=nn.Parameter(torch.zeros(emb_dim))
    
    def forward(self,x):
        mean=x.mean(dim=-1,keepdim=True)
        var=x.var(dim=-1,keepdim=True,unbiased=False)
        norm=(x-mean)/torch.sqrt(var+self.eps)
        return norm*self.scale_params+ self.shift_params
    
        

In [4]:
class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(torch.sqrt(torch.tensor(2.0 / torch.pi)) *(x + 0.044715 * torch.pow(x, 3))
        ))
    

class feedforward(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.layers=nn.Sequential(
            nn.Linear(config['emb_dim'],config['emb_dim']*4),
            GELU(),
            nn.Linear(config['emb_dim']*4,config['emb_dim']),

        )
    
    def forward(self,x):
        return self.layers(x)

In [5]:
class TransformerBlock(nn.Module):
    def __init__ (self,config):
        super().__init__()
        self.attn=multiheadv2(d_in=config['emb_dim'],d_out=config['emb_dim'],context_length=config['context_length'],dropout=config['drop_rate'],attention_head=config['n_heads'],boolbias=config['qkv_bias'])
        self.Layernorm1=LayerNorm(emb_dim=config['emb_dim'])
        self.Layernorm2=LayerNorm(emb_dim=config['emb_dim'])
        self.feedforw=feedforward(config=config)
        self.dropout=nn.Dropout(config['drop_rate'])
    
    def forward(self,x):
        ## attnetion block
        skip=x
        x=self.Layernorm1(x)
        x=self.attn(x)
        x=self.dropout(x)
        x=x+skip

        ## feed forward nn block
        skip=x
        x=self.Layernorm2(x)
        x=self.feedforw(x)
        x=self.dropout(x)
        x=x+skip

        return x


In [6]:
class GPT_2(nn.Module):
    def __init__ (self,cfg):
        super().__init__()
        self.token_emb=nn.Embedding(cfg['vocab_size'],cfg["emb_dim"])
        self.pos_emb=nn.Embedding(cfg['context_length'],cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])
        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
        
        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(
            cfg["emb_dim"], cfg["vocab_size"], bias=False
        )
    
    def forward(self,inputidx):
        batch_size,seq=inputidx.shape
        tokens=self.token_emb(inputidx)
        pos_embeds = self.pos_emb(torch.arange(seq, device=inputidx.device))
        x=tokens+pos_embeds
        x=self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits
    



OPEN AI weights are in tensorflow format , we need to convert them to torch compatible first

In [7]:
from gptweightdownload import download_and_load_gpt2
settings, params = download_and_load_gpt2(model_size="124M", models_dir="gpt2")

In [8]:
def assign(left, right):
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")
    return torch.nn.Parameter(torch.tensor(right))

In [9]:
import numpy as np
import torch

def load_weights_into_gpt(gpt, params):
    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
    gpt.token_emb.weight = assign(gpt.token_emb.weight, params['wte'])
    
    for b in range(len(params["blocks"])):
        q_w, k_w, v_w = np.split(params["blocks"][b]["attn"]["c_attn"]["w"], 3, axis=-1)
        gpt.trf_blocks[b].attn.W_query.weight = assign(gpt.trf_blocks[b].attn.W_query.weight, q_w.T)
        gpt.trf_blocks[b].attn.W_key.weight = assign(gpt.trf_blocks[b].attn.W_key.weight, k_w.T)
        gpt.trf_blocks[b].attn.W_value.weight = assign(gpt.trf_blocks[b].attn.W_value.weight, v_w.T)

        q_b, k_b, v_b = np.split(params["blocks"][b]["attn"]["c_attn"]["b"], 3, axis=-1)
        gpt.trf_blocks[b].attn.W_query.bias = assign(gpt.trf_blocks[b].attn.W_query.bias, q_b)
        gpt.trf_blocks[b].attn.W_key.bias = assign(gpt.trf_blocks[b].attn.W_key.bias, k_b)
        gpt.trf_blocks[b].attn.W_value.bias = assign(gpt.trf_blocks[b].attn.W_value.bias, v_b)

        gpt.trf_blocks[b].attn.out_proj.weight = assign(gpt.trf_blocks[b].attn.out_proj.weight, params["blocks"][b]["attn"]["c_proj"]["w"].T)
        gpt.trf_blocks[b].attn.out_proj.bias = assign(gpt.trf_blocks[b].attn.out_proj.bias, params["blocks"][b]["attn"]["c_proj"]["b"])

        gpt.trf_blocks[b].feedforw.layers[0].weight = assign(gpt.trf_blocks[b].feedforw.layers[0].weight, params["blocks"][b]["mlp"]["c_fc"]["w"].T)
        gpt.trf_blocks[b].feedforw.layers[0].bias = assign(gpt.trf_blocks[b].feedforw.layers[0].bias, params["blocks"][b]["mlp"]["c_fc"]["b"])
        gpt.trf_blocks[b].feedforw.layers[2].weight = assign(gpt.trf_blocks[b].feedforw.layers[2].weight, params["blocks"][b]["mlp"]["c_proj"]["w"].T)
        gpt.trf_blocks[b].feedforw.layers[2].bias = assign(gpt.trf_blocks[b].feedforw.layers[2].bias, params["blocks"][b]["mlp"]["c_proj"]["b"])

        gpt.trf_blocks[b].Layernorm1.scale_params = assign(gpt.trf_blocks[b].Layernorm1.scale_params, params["blocks"][b]["ln_1"]["g"])
        gpt.trf_blocks[b].Layernorm1.shift_params = assign(gpt.trf_blocks[b].Layernorm1.shift_params, params["blocks"][b]["ln_1"]["b"])
        gpt.trf_blocks[b].Layernorm2.scale_params = assign(gpt.trf_blocks[b].Layernorm2.scale_params, params["blocks"][b]["ln_2"]["g"])
        gpt.trf_blocks[b].Layernorm2.shift_params = assign(gpt.trf_blocks[b].Layernorm2.shift_params, params["blocks"][b]["ln_2"]["b"])

    gpt.final_norm.scale_params = assign(gpt.final_norm.scale_params, params["g"])
    gpt.final_norm.shift_params = assign(gpt.final_norm.shift_params, params["b"])
    gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])



In [10]:
def generate(model, idx, max_token_size, context_Size, top_k=5, temperature=1.0):
    for _ in range(max_token_size):
        idx_split = idx[:, -context_Size:]

        with torch.no_grad():
            logits = model(idx_split)

        logits_last = logits[:, -1, :]
        
        logits_last = logits_last / temperature
        
       
        values, indices = torch.topk(logits_last, k=top_k, dim=-1)
        probs = torch.softmax(values, dim=-1)
        sampled_idx = torch.multinomial(probs, num_samples=1)
        maxi = torch.gather(indices, -1, sampled_idx)
       
    
        idx = torch.cat((idx, maxi), dim=1)
    
    return idx

import tiktoken
tokenizer = tiktoken.get_encoding("gpt2")
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded_tensor = torch.tensor(encoded).unsqueeze(0) 
    return encoded_tensor

def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0) 
    return tokenizer.decode(flat.tolist())


In [11]:
model = GPT_2(GPT_CONFIG_124M)
load_weights_into_gpt(model, params)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if torch.cuda.is_available():
   device = torch.device("cuda")
elif torch.backends.mps.is_available():
   device = torch.device("mps")
else:
   device = torch.device("cpu")

print(f"Using {device} device.")


model.to(device) 


Using mps device.


GPT_2(
  (token_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.1, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (attn): multiheadv2(
        (W_query): Linear(in_features=768, out_features=768, bias=True)
        (W_key): Linear(in_features=768, out_features=768, bias=True)
        (W_value): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (Layernorm1): LayerNorm()
      (Layernorm2): LayerNorm()
      (feedforw): feedforward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (attn): multiheadv2(
        (W_query): Linear(in_features=768,

## after weights

In [12]:
def compute_checksum(model):
    return sum(p.sum().item() for p in model.parameters())

checksum = compute_checksum(model)
print(f"Model weight checksum: {checksum}")


Model weight checksum: -47459.91598956287


In [13]:
torch.manual_seed(42) 

token_ids = generate(
    model=model,
    idx=text_to_token_ids("Can a big, important country that has a major reserve currency like the US go broke—and, if so, what would that look like?", tokenizer).to(device),  
    max_token_size=100,  
    context_Size=GPT_CONFIG_124M["context_length"],
    top_k=80,  
    temperature=1.1
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))


Output text:
 Can a big, important country that has a major reserve currency like the US go broke—and, if so, what would that look like?

I'm not sure these things are hard to come by, so I assume they could be.

There are lots of factors. I would love to see it.

I'm trying to pull the deck from one point to give it a different direction?

Are you concerned to see it slip back into the gold standard of the US's gold's gold standard?

I'll open with that as well.

What do you think of the current state of the


## Finetuning

In [14]:
import pandas as pd

df = pd.read_csv("SMSSpamCollection", sep="\t", header=None, names=["Label", "Text"])
df

Unnamed: 0,Label,Text
0,ham,"Go until jurong point, crazy.. Available only ..."
1,ham,Ok lar... Joking wif u oni...
2,spam,Free entry in 2 a wkly comp to win FA Cup fina...
3,ham,U dun say so early hor... U c already then say...
4,ham,"Nah I don't think he goes to usf, he lives aro..."
...,...,...
5567,spam,This is the 2nd time we have tried 2 contact u...
5568,ham,Will ü b going to esplanade fr home?
5569,ham,"Pity, * was in mood for that. So...any other s..."
5570,ham,The guy did some bitching but I acted like i'd...


making a balanced dataset

In [15]:
print(df["Label"].value_counts())

def create_balanced_dataset(df):
    
    num_spam = df[df["Label"] == "spam"].shape[0]
    
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
    
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])

    return balanced_df

balanced_df = create_balanced_dataset(df)
print(balanced_df["Label"].value_counts())

Label
ham     4825
spam     747
Name: count, dtype: int64
Label
ham     747
spam    747
Name: count, dtype: int64


In [16]:
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})

In [17]:
from sklearn.model_selection import train_test_split
train_df, temp_df = train_test_split(balanced_df, test_size=0.3, random_state=42, stratify=balanced_df['Label'])
validation_df, test_df = train_test_split(temp_df, test_size=1/3, random_state=42, stratify=temp_df['Label'])

print(f"Train: {len(train_df)}, Val: {len(validation_df)}, Test: {len(test_df)}")


Train: 1045, Val: 299, Test: 150


In [18]:
train_df.to_csv("train.csv", index=None)
validation_df.to_csv("validation.csv", index=None)
test_df.to_csv("test.csv", index=None)

In [19]:
import torch
from torch.utils.data import Dataset


class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)

        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]
        ]

        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
            
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]


        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long)
        )

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        max_length = 0
        for encoded_text in self.encoded_texts:
            encoded_length = len(encoded_text)
            if encoded_length > max_length:
                max_length = encoded_length
        return max_length

In [20]:
import tiktoken
tokenizer = tiktoken.get_encoding("gpt2")


train_dataset = SpamDataset(
    csv_file="train.csv",
    max_length=None,
    tokenizer=tokenizer
)

print(train_dataset.max_length)


val_dataset = SpamDataset(
    csv_file="validation.csv",
    max_length=train_dataset.max_length,
    tokenizer=tokenizer
)
test_dataset = SpamDataset(
    csv_file="test.csv",
    max_length=train_dataset.max_length,
    tokenizer=tokenizer
)

print(test_dataset.max_length)

104
104


In [21]:
from torch.utils.data import DataLoader

num_workers = 0
batch_size = 8

torch.manual_seed(123)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)

val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
)

In [22]:
text_2 = (
    "Is the following text 'spam'? Answer with 'yes' or 'no':"
    " 'You are a winner you have been specially"
    " selected to receive $1000 cash or a $2000 award.'"
)


token_ids = generate(
    model=model,
    idx=text_to_token_ids(text_2, tokenizer).to(device),  
    max_token_size=100,  
    context_Size=GPT_CONFIG_124M["context_length"],
    top_k=80,  
    temperature=1.1
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))

Output text:
 Is the following text 'spam'? Answer with 'yes' or 'no': 'You are a winner you have been specially selected to receive $1000 cash or a $2000 award.'

But if you're a winner, you can accept upto $5 million cash. If you offer $5000 cash after your choice of prize will be ignored, just like with no choices?


Here you can tell me more about what happens when you don't give it to me.<|endoftext|>In 2016, I bought my own white house! Because I like to buy my own houses in an environment that has no lights and a decent food place because this place is where the great life


In [23]:
for param in model.parameters():
    param.requires_grad = False



torch.manual_seed(123)

num_classes = 2
model.out_head = torch.nn.Linear(in_features=GPT_CONFIG_124M["emb_dim"], out_features=num_classes)

for param in model.trf_blocks[-1].parameters():
    param.requires_grad = True

for param in model.final_norm.parameters():
    param.requires_grad = True

In [24]:
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)

            with torch.no_grad():
                logits = model(input_batch)[:, -1, :]  
            predicted_labels = torch.argmax(logits, dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples

In [25]:
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss



In [26]:
def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches

In [27]:
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss

In [28]:

def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                            eval_freq, eval_iter):
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    for epoch in range(num_epochs):
        model.train() 

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad() 
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward() 
            optimizer.step() 
            examples_seen += input_batch.shape[0] 
            global_step += 1

            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
        val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

    return train_losses, val_losses, train_accs, val_accs, examples_seen

In [31]:
import time

start_time = time.time()

torch.manual_seed(123)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
model.to(device) 

num_epochs = 12
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
    model, train_loader, val_loader, optimizer, device,
    num_epochs=num_epochs, eval_freq=50, eval_iter=5,
)

end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")



Ep 1 (Step 000000): Train loss 0.569, Val loss 0.558
Ep 1 (Step 000050): Train loss 0.498, Val loss 0.501
Ep 1 (Step 000100): Train loss 0.406, Val loss 0.456
Training accuracy: 77.50% | Validation accuracy: 85.00%
Ep 2 (Step 000150): Train loss 0.364, Val loss 0.430
Ep 2 (Step 000200): Train loss 0.371, Val loss 0.401
Ep 2 (Step 000250): Train loss 0.394, Val loss 0.377
Training accuracy: 75.00% | Validation accuracy: 80.00%
Ep 3 (Step 000300): Train loss 0.439, Val loss 0.365
Ep 3 (Step 000350): Train loss 0.432, Val loss 0.387
Training accuracy: 75.00% | Validation accuracy: 82.50%
Ep 4 (Step 000400): Train loss 0.477, Val loss 0.409
Ep 4 (Step 000450): Train loss 0.547, Val loss 0.474
Ep 4 (Step 000500): Train loss 0.357, Val loss 0.357
Training accuracy: 92.50% | Validation accuracy: 85.00%
Ep 5 (Step 000550): Train loss 0.460, Val loss 0.392
Ep 5 (Step 000600): Train loss 0.538, Val loss 0.354
Training accuracy: 72.50% | Validation accuracy: 87.50%
Ep 6 (Step 000650): Train loss 

KeyboardInterrupt: 

In [32]:
train_accuracy = calc_accuracy_loader(train_loader, model, device)
val_accuracy = calc_accuracy_loader(val_loader, model, device)
test_accuracy = calc_accuracy_loader(test_loader, model, device)

print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
print(f"Test accuracy: {test_accuracy*100:.2f}%")

Training accuracy: 94.23%
Validation accuracy: 92.64%
Test accuracy: 94.00%


In [33]:
def classify_review(text, model, tokenizer, device, max_length=None, pad_token_id=50256):
    model.eval()

    input_ids = tokenizer.encode(text)
    supported_context_length = model.pos_emb.weight.shape[0]
    # Note: In the book, this was originally written as pos_emb.weight.shape[1] by mistake
    # It didn't break the code but would have caused unnecessary truncation (to 768 instead of 1024)

    input_ids = input_ids[:min(max_length, supported_context_length)]


    input_ids += [pad_token_id] * (max_length - len(input_ids))
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)

    with torch.no_grad():
        logits = model(input_tensor)[:, -1, :]  
    predicted_label = torch.argmax(logits, dim=-1).item()

    return "spam" if predicted_label == 1 else "not spam"

In [34]:
text_1 = (
    "You are a winner you have been specially"
    " selected to receive $1000 cash or a $2000 award."
)

print(classify_review(
    text_1, model, tokenizer, device, max_length=train_dataset.max_length
))


spam


In [35]:
torch.save(model.state_dict(), "review_classifier.pth")
