In [1]:
from encoder_layer import Encoder_block
from positional_encoding import Positional_Encoding
import torch
from torch import nn
from transformers import AutoTokenizer

torch.Size([1, 3, 512])


In [2]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
ckpt = torch.load("checkpoints/checkpoint_epoch_5.pt", map_location=device)

# recreate model
encoder = Encoder_block(d_model=512, d_ff=2048, num_heads=8).to(device)
d_model = 512
num_classes = 3  # B, I, O
# Define classifier on top of encoder output
classifier = nn.Linear(d_model, num_classes).to(device)  # 3 = BIO tags

encoder.load_state_dict(ckpt['model_state_dict'])
classifier.load_state_dict(ckpt['classifier_state_dict'])

encoder.eval()
classifier.eval()

print("Loaded checkpoint from epoch:", ckpt['epoch'])

Loaded checkpoint from epoch: 5


In [4]:
d_model = 512  # main model dimension
num_heads = 8  # number of heads
d_ff = 2048    # feedforward hidden dimension
seq_len = 256  # max input length
vocab_size = 30000
embedding_layer = nn.Embedding(vocab_size, d_model).to(device)
encoder_layer = Encoder_block(d_model=512, d_ff=2048, num_heads=8).to(device)
pos_encoding = Positional_Encoding(seq_len, d_model).to(device)

def prepare_batch_encoder_input(input_ids):
    # input_ids = (batch_size, seq_len) already tensor
    x = embedding_layer(input_ids)      # (batch, seq_len, d_model)
    x = pos_encoding(x)                 # (batch, seq_len, d_model)
    return x

In [5]:
def tokenize_batch(batch):

    # convert list of tokens to string
    texts = [" ".join(tokens) for tokens in batch["document"]]
    
    tokenized = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=256
    )

    # add word_ids manually (for each sample)
    tokenized["word_ids"] = [
        tokenized.word_ids(i) for i in range(len(texts))
    ]

    return tokenized

In [6]:
def encode_text(text):
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=256
    )

    input_ids = encoded["input_ids"].to(device)

    # Step 1: convert input_ids → embeddings
    x = embedding_layer(input_ids)
    x = pos_encoding(x)

    # Step 2: pass through trained encoder
    with torch.no_grad():
        encoder_output, attn_weights = encoder(x, mask=None)

    return encoder_output, attn_weights, input_ids

In [7]:
features, attention, input_ids = encode_text("Transformers are amazing.")

In [8]:
print(features.shape)

torch.Size([1, 256, 512])


In [9]:
def tokens_to_words(input_ids):
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    return tokens

In [10]:
def extract_keywords_better(input_ids, attn_weights, top_k=5):
    
    # CASE 1 — attn_weights is a list of tensors (layers × heads)
    if isinstance(attn_weights, list):
        attn_stack = torch.stack(attn_weights)   # (L, H, S, S)
        cls_scores = attn_stack[:, :, 0, :].mean(dim=(0, 1))  # avg layers+heads
    
    # CASE 2 — attn_weights is a single tensor (heads × S × S)
    else:
        # attn_weights = (H, S, S)
        cls_scores = attn_weights[:, 0, :].mean(dim=0)  # avg heads

    # Remove CLS token (index 0)
    token_scores = cls_scores[1:]  # first token is CLS

    # Get top tokens
    top_indices = torch.topk(token_scores, top_k).indices

    # Get wordpiece IDs → convert to tokens
    top_token_ids = input_ids[0][1:][top_indices]   # skip CLS
    
    top_tokens = tokenizer.convert_ids_to_tokens(top_token_ids.tolist())
    return top_tokens

In [11]:
encoder_output, attn_weights, input_ids = encode_text("Transformers are amazing.")
attn_weights = attn_weights.squeeze(0)

keywords = extract_keywords_better(input_ids, attn_weights, top_k=4)

print("Extracted Keywords:", keywords)

Extracted Keywords: ['are', 'transformers', '.', '[SEP]']


In [12]:
print("input_ids shape:", input_ids.shape)
print("attn_weights shape:", attn_weights.shape)

input_ids shape: torch.Size([1, 256])
attn_weights shape: torch.Size([8, 256, 256])


In [13]:
STOPWORDS = {"is", "are", "the", "a", "an", "of", "to", "and"}

def filter_keywords(words):
    return [w for w in words if w.lower() not in STOPWORDS]

In [14]:
keywords = extract_keywords_better(input_ids, attn_weights, top_k=4)
# keywords = merge_subwords(keywords)
keywords = filter_keywords(keywords)[:5]  # pick top 2 after filtering
print(keywords)

['transformers', '.', '[SEP]']


In [15]:
import nltk
nltk.download("averaged_perceptron_tagger_eng")

from nltk import pos_tag

def prefer_nouns_adjectives(words):
    tags = pos_tag(words, tagset="universal")  # works in new NLTK
    ranked = sorted(tags, key=lambda x: 0 if x[1] in ("NOUN","ADJ") else 1)
    return [w for w, t in ranked]

[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     C:\Users\kumar\AppData\Roaming\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


In [18]:
def merge_subwords(tokens):
    merged = []
    current = ""

    for t in tokens:
        if t.startswith("##"):
            current += t[2:]            # append subword
        else:
            if current:                # save previous word
                merged.append(current)
            current = t                # start a new word

    if current:
        merged.append(current)

    return merged

In [20]:
keywords = extract_keywords_better(input_ids, attn_weights, top_k=2)
keywords = merge_subwords(keywords)
keywords = filter_keywords(keywords)
keywords[:3]

['transformers']

In [3]:
import torch, os
from torch import nn
from transformers import AutoTokenizer
import re
from encoder_layer import Encoder_block
from positional_encoding import Positional_Encoding

# ----------------------------- CONFIG -----------------------------
d_model = 512
num_heads = 8
d_ff = 2048
seq_len = 256
vocab_size = 30000
num_classes = 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ------------------------- LOAD TOKENIZER -------------------------
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# -------------------------- BUILD MODEL --------------------------
embedding_layer = nn.Embedding(vocab_size, d_model)
pos_encoding = Positional_Encoding(seq_len, d_model)
encoder = Encoder_block(d_model=d_model, d_ff=d_ff, num_heads=num_heads)
classifier = nn.Linear(d_model, num_classes)   # B, I, O

# Move everything to device
embedding_layer.to(device)
pos_encoding.to(device)
encoder.to(device)
classifier.to(device)

# -------------------------- LOAD CHECKPOINT ----------------------
# Change this to your best checkpoint
CHECKPOINT_PATH = "checkpoints/checkpoint_epoch_5.pt"   # or whichever you prefer

checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

# embedding_layer.load_state_dict(torch.load("checkpoints/embedding_layer.pt") if os.path.exists("checkpoints/embedding_layer.pt") 
#                                else print("Warning: No separate embedding checkpoint, assuming it was saved inside encoder"))
# If you only saved encoder + classifier (as in your training script), load like this:
encoder.load_state_dict(checkpoint['model_state_dict'])
classifier.load_state_dict(checkpoint['classifier_state_dict'])

embedding_layer.eval()
encoder.eval()
classifier.eval()

# -------------------------- ALIGNMENT HELPERS --------------------
def align_tags_with_tokens(predicted_ids, word_ids):
    """
    Convert subword-level predictions back to word-level BIO tags.
    Rule: If any subword of a word is predicted as B or I → the whole word gets that tag.
          Priority: B > I > O
    """
    word_level_tags = []
    current_word_id = None
    current_best_tag = 0  # O

    for word_id, pred_id in zip(word_ids, predicted_ids):
        if word_id is None:
            continue
        if word_id != current_word_id:
            if current_word_id is not None:
                word_level_tags.append(current_best_tag)
            current_word_id = word_id
            current_best_tag = pred_id
        else:
            # Same word → keep the "strongest" tag (B > I > O)
            if pred_id == 1:        # B has highest priority
                current_best_tag = 1
            elif pred_id == 2 and current_best_tag == 0:
                current_best_tag = 2

    # Don't forget the last word
    if current_word_id is not None:
        word_level_tags.append(current_best_tag)

    return word_level_tags

id2tag = {0: "O", 1: "B", 2: "I"}

# -------------------------- PREDICTION FUNCTION ------------------
def extract_keyphrases(text: str, top_k=None):
    # 1. Tokenize with word_ids
    encoded = tokenizer(
        text,
        truncation=True,
        max_length=256,
        return_offsets_mapping=True,
        return_special_tokens_mask=True
    )
    
    input_ids = torch.tensor([encoded["input_ids"]]).to(device)  # (1, seq_len)
    attention_mask = torch.tensor([encoded["attention_mask"]]).to(device)

    # word_ids for alignment
    word_ids = encoded.word_ids()

    # 2. Forward pass
    with torch.no_grad():
        x = embedding_layer(input_ids)              # (1, seq, d_model)
        x = pos_encoding(x)
        encoder_out, _ = encoder(x, mask=None)      # (1, seq, d_model)
        logits = classifier(encoder_out)            # (1, seq, 3)
        
        predictions = torch.argmax(logits, dim=-1)   # (1, seq)
        predictions = predictions[0].cpu().tolist() # list of length seq

    # 3. Align subword predictions → word-level
    word_level_preds = align_tags_with_tokens(predictions, word_ids)

    # 4. Extract phrases using offsets
    tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"])
    offsets = encoded["offset_mapping"]

    keyphrases = []
    current_phrase = []
    current_start = None

    for token, offset, pred in zip(tokens, offsets, predictions):
        if encoded["special_tokens_mask"][encoded["input_ids"].index(tokenizer.convert_tokens_to_ids(token))]:
            continue  # skip [CLS], [SEP], etc.

        word_id = encoded.word_ids()[encoded["input_ids"].index(tokenizer.convert_tokens_to_ids(token))]
        if word_id is None:
            continue

        word_pred = word_level_preds[word_id]

        if word_pred == 1:  # B
            if current_phrase:
                keyphrases.append(" ".join(current_phrase))
            current_phrase = [token.replace("##", "")]
            current_start = offset[0]
        elif word_pred == 2 and current_phrase:  # I
            current_phrase.append(token.replace("##", ""))
        else:  # O
            if current_phrase:
                keyphrases.append(" ".join(current_phrase))
                current_phrase = []

    if current_phrase:
        keyphrases.append(" ".join(current_phrase))

    # Optional: deduplication + lowercasing for cleaner output
    keyphrases = list(dict.fromkeys(keyphrases))  # preserve order
    keyphrases = [kp.strip() for kp in keyphrases if kp.strip()]

    if top_k:
        return keyphrases[:top_k]
    return keyphrases

# ----------------------------- USAGE EXAMPLE -----------------------------
if __name__ == "__main__":
    sample_text = """
    Self-supervised learning has become an important paradigm in natural language processing.
    Models like BERT and RoBERTa achieve state-of-the-art performance by pre-training on large corpora.
    Keyphrase extraction remains a challenging task in information retrieval.
    """

    print("Input text:")
    print(sample_text.strip())
    print("\nPredicted keyphrases:")
    kps = extract_keyphrases(sample_text, top_k=10)
    for i, kp in enumerate(kps, 1):
        print(f"{i}. {kp}")

Input text:
Self-supervised learning has become an important paradigm in natural language processing.
    Models like BERT and RoBERTa achieve state-of-the-art performance by pre-training on large corpora.
    Keyphrase extraction remains a challenging task in information retrieval.

Predicted keyphrases:
1. self - supervised learning has
2. paradigm in natural language processing
3. state -
4. pre - training on large corp ora
5. key
6. ph
7. rase extraction remains


In [7]:
projection = nn.Linear(768, 512)

projection.load_state_dict(checkpoint['projection_state_dict'])
projection.to(device)
projection.eval()

KeyError: 'projection_state_dict'

In [20]:
# predict_and_visualize.py

from encoder_layer import Encoder_block
from positional_encoding import Positional_Encoding
from transformers import BertTokenizerFast, BertModel
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("distilbert-base-uncased")
bert_emb = DistilBertModel.from_pretrained("distilbert-base-uncased").get_input_embeddings()

d_model = 512
num_heads = 8
d_ff = 2048

encoder = Encoder_block(d_model=d_model, num_heads=num_heads, d_ff=d_ff)
classifier = nn.Linear(d_model, 3)

checkpoint = torch.load("restored_model.pt", map_location="cpu")
encoder.load_state_dict(checkpoint['encoder_state_dict'])
classifier.load_state_dict(checkpoint['classifier_state_dict'])
projection.load_state_dict(ckpt["projection_state_dict"])

device = "cuda" if torch.cuda.is_available() else "cpu"
encoder.to(device)
classifier.to(device)
bert_emb.to(device)
projection.to(device)

encoder.eval()
classifier.eval()

def extract_keyphrases(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
    input_ids = inputs["input_ids"].to(device)

    with torch.no_grad():
        x = bert_emb(input_ids)
        x = projection(x)  
        encoder_out, attn_weights = encoder(x, mask=inputs["attention_mask"].to(device))
        
        # attn_weights will be list of (batch, heads, seq, seq) or whatever your function returns
        logits = classifier(encoder_out)
        preds = torch.argmax(logits, dim=-1)[0].cpu().numpy()

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    phrases = []
    curr = []
    for tok, pred in zip(tokens, preds):
        if tok in ["[CLS]", "[SEP]", "[PAD]"]: continue
        tok = tok.replace("##", "")
        if pred == 1:  # B
            if curr: phrases.append(" ".join(curr))
            curr = [tok]
        elif pred == 2 and curr:
            curr.append(tok)
        else:
            if curr: phrases.append(" ".join(curr)); curr = []
    if curr: phrases.append(" ".join(curr))
    return phrases

# Test
text = "Transformer networks have revolutionized natural language processing with self-attention mechanisms."
print(extract_keyphrases(text))
# Output milega: ['transformer networks', 'natural language processing', 'self-attention mechanisms']

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DistilBertTokenizer'. 
The class this function is called from is 'BertTokenizerFast'.


['transform', 'er', 'networks', 'have', 'ized', 'language', 'processing', 'with', 'attention', 'mechanisms', '.']


In [5]:
ckpt = torch.load("checkpoint_epoch_15.pt", map_location="cpu")
print(ckpt.keys())

dict_keys(['epoch', 'model_state_dict', 'classifier_state_dict', 'optimizer_state_dict', 'loss'])


In [8]:
projection = nn.Linear(768, 512)
encoder = Encoder_block(d_model=512, num_heads=8, d_ff=2048)
classifier = nn.Linear(512, 3)

In [10]:
ckpt = torch.load("checkpoint_epoch_15.pt", map_location="cpu")
optimizer = torch.optim.AdamW(
    list(projection.parameters()) +
    list(encoder.parameters()) +
    list(classifier.parameters()),
    lr=1e-4
)

optimizer.load_state_dict(ckpt["optimizer_state_dict"])

In [11]:
torch.save({
    "projection_state_dict": projection.state_dict(),
    "encoder_state_dict": encoder.state_dict(),
    "classifier_state_dict": classifier.state_dict()
}, "restored_model.pt")

In [12]:
ckpt = torch.load("restored_model.pt")
projection.load_state_dict(ckpt["projection_state_dict"])
encoder.load_state_dict(ckpt["encoder_state_dict"])
classifier.load_state_dict(ckpt["classifier_state_dict"])

<All keys matched successfully>