In [1]:
import nltk
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from datasets import load_from_disk
# Load dataset (replace with your actual dataset loading code)
import torch
from transformers import AutoTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
from torch.utils.data import DataLoader
import nltk
nltk.download('punkt')

dataset = load_from_disk("bbc_dataset") 


  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to
[nltk_data]     /Users/jameelamer/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Constants
MAX_SENTENCES = 8
SEQ_LENGTH = 128
BATCH_SIZE = 4

In [3]:
# 1. Preprocessing Function
def preprocess_function(examples):
    tokenized_articles = []
    tokenized_masks = []
    labels_list = []

    for article, summary in zip(examples["Article"], examples["extractive_summary"]):
        # Clean inputs
        article = " ".join(article) if isinstance(article, list) else article
        summary = " ".join(summary) if isinstance(summary, list) else summary
        if (summary==""):
            return {"input_ids": [], "attention_mask": [], "labels": []} 
        # Tokenize sentences
        sentences = nltk.sent_tokenize(article)[:MAX_SENTENCES]
        num_sentences = len(sentences)
        
        # Tokenize all sentences
        tokenized = tokenizer(
            sentences,
            padding="max_length",
            truncation=True,
            max_length=SEQ_LENGTH,
            return_tensors="pt"
        )
        
        # Create labels (1 if sentence is in summary)
        labels = [1 if sent in summary else 0 for sent in sentences]
        labels += [0] * (MAX_SENTENCES - num_sentences)  # Pad labels
        
        # Pad tensors
        padded_input_ids = torch.zeros((MAX_SENTENCES, SEQ_LENGTH), dtype=torch.long)
        padded_attention_mask = torch.zeros((MAX_SENTENCES, SEQ_LENGTH), dtype=torch.long)
        
        padded_input_ids[:num_sentences] = tokenized["input_ids"]
        padded_attention_mask[:num_sentences] = tokenized["attention_mask"]
        
        tokenized_articles.append(padded_input_ids)
        tokenized_masks.append(padded_attention_mask)
        labels_list.append(torch.tensor(labels, dtype=torch.float))

    return {
        "input_ids": tokenized_articles,
        "attention_mask": tokenized_masks,
        "labels": labels_list
    }

# 2. Apply Preprocessing

val_dataset = dataset["validation"].map(preprocess_function, batched=True)
train_dataset = dataset["train"].map(preprocess_function, batched=True)


Map: 100%|██████████| 445/445 [00:00<00:00, 797.20 examples/s]
Map: 100%|██████████| 1335/1335 [00:01<00:00, 1044.37 examples/s]


In [4]:
train_dataset.column_names

['Title',
 'Article',
 'Summary',
 'Category',
 'extractive_summary',
 'input_ids',
 'attention_mask',
 'labels']

In [5]:
# Remove unnecessary columns
columns_to_remove = ["filename", "Article", "Summary", "__index_level_0__","extractive_summary"]
train_dataset = train_dataset.remove_columns([col for col in columns_to_remove if col in train_dataset.column_names])
val_dataset = val_dataset.remove_columns([col for col in columns_to_remove if col in val_dataset.column_names])

# Set format for PyTorch
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

In [6]:
# Custom Collate Function
def collate_fn(batch):
    return {
        "input_ids": torch.stack([item["input_ids"] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
        "labels": torch.stack([item["labels"] for item in batch])
    }


# Custom Trainer Class
class SentenceTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):  # Add **kwargs
        # Reshape inputs: [batch, sentences, seq_len] -> [batch*sentences, seq_len]
        batch_size, num_sentences, seq_len = inputs["input_ids"].shape
        flat_inputs = {
            "input_ids": inputs["input_ids"].view(-1, seq_len),
            "attention_mask": inputs["attention_mask"].view(-1, seq_len)
        }
        
        # Forward pass
        outputs = model(**flat_inputs)
        logits = outputs.logits.view(batch_size, num_sentences)
        
        # BCEWithLogitsLoss for multi-label classification
        loss_fct = torch.nn.BCEWithLogitsLoss()
        loss = loss_fct(logits, inputs["labels"])
        
        return (loss, outputs) if return_outputs else loss


In [7]:
# Initialize Model
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=1,  # Binary classification per sentence
    problem_type="multi_label_classification"
)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# Training Arguments
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=3,
    evaluation_strategy="steps",
    eval_steps=15000,
    save_steps=30000,
    logging_dir="./logs",
    logging_steps=100,
    learning_rate=2e-5,
    warmup_steps=500,
    weight_decay=0.01,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
)



In [9]:
# Create Trainer
# Initialize Trainer with the fixed class
trainer = SentenceTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
)

# 3. Start training
trainer.train()

#  Save the final model
trainer.save_model("./bertsum_bbc_news/bertsum_finetuned_model")
tokenizer.save_pretrained("./bertsum_bbc_news/bertsum_finetuned_model")

Step,Training Loss,Validation Loss


('./bertsum_bbc_news/bertsum_finetuned_model/tokenizer_config.json',
 './bertsum_bbc_news/bertsum_finetuned_model/special_tokens_map.json',
 './bertsum_bbc_news/bertsum_finetuned_model/vocab.txt',
 './bertsum_bbc_news/bertsum_finetuned_model/added_tokens.json',
 './bertsum_bbc_news/bertsum_finetuned_model/tokenizer.json')

In [10]:
import torch

def generate_summary(model, tokenizer, text, device):
    model.eval()
    
    # Tokenize input text
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Move inputs to the correct device
    inputs = {key: val.to(device) for key, val in inputs.items()}
    
    with torch.no_grad():
        output = model(**inputs)  # Forward pass on MPS
        logits = output.logits.squeeze(-1)  # Extract logits
        
        # Ensure logits are moved to CPU before processing
        logits = logits.cpu()

        # Select sentences using thresholding
        predicted_labels = (logits > 0.5).int()
        
        sentences = text.split(". ")  # Sentence tokenization
        min_length = min(len(sentences), len(predicted_labels))
        # print("Logits:", logits)
        # print("Predicted Labels:", predicted_labels)
        # print("Sentences:", sentences)
        selected_sentences = [sentences[i] for i in range(min_length) if predicted_labels[i] == 1]        
        summary = " ".join(selected_sentences)
        return summary

# Detect MPS device on Mac
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)  # Move model to MPS


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [11]:
# Example Usage:
text = """
Jamieson issues warning to bigots\n\nScotland's justice minister has warned bigoted soccer fans that she wants to hit them "where it hurts most" by banning them from matches.\n\nCathy Jamieson said exclusion orders are one of a series of measures being considered in the Scottish Executive campaign against sectarianism. She praised Celtic and Rangers for their work in tackling the problem. However, the minister said stopping sectarian abuse associated with Old Firm matches is a key objective. Ms Jamieson was speaking ahead of the third round Scottish Cup clash between the Glasgow clubs at Parkhead on Sunday. The sectarianism long associated with sections of the support from both clubs has become a significant target for the executive. Last week Ms Jamieson and First Minister Jack McConnell met supporters' representatives from both clubs to discuss the issue.\n\nThey plan to hold an anti-sectarian summit next month with officials from the clubs, church leaders, senior police officers and local authority chiefs among those to be invited. Speaking on BBC Radio Scotland's Sunday Live programme, Ms Jamieson described Friday's meeting as "very productive" and said putting the squeeze on the bigots would be a key aim. Ms Jamieson stressed that sectarianism has not been confined to football but it can act as a "trigger" for tensions and violence. Clubs have taken action in the past to ban troublesome fans and supporters' groups expressed their desire to ensure that the game is no longer tainted by the problem.\n\nMs Jamieson said the executive should have a role in tackling the soccer troublemakers. She said: "We can't get away from the fact that in some instances some of the religious hatred that some people try to associate with football boils over into violence. "That is the kind of thing we want to stop and that's the kind of thing supporters' groups are very clear they don't want to be part of either, and they will work with us to try and deal with that."\n\nMs Jamieson praised the police for their action and said: "The police do want to identify whether there are particular individuals who are going over the top and inciting hatred or violence - they will crack down very effectively on them. "We have of course already indicated that we will consider the introduction of banning orders to give additional powers to where there are people who are going over the top, who have made inappropriate behaviour at football matches, to be able to stop them attending the games. "That's the kind of thing that will hit those kind of people where it hurts the most in not allowing them to attend the games," she said. Praising Celtic and Rangers for their efforts, she said: "I don't think there is any doubt that we have seen some positive moves from the clubs. "Both Rangers and Celtic football clubs have been involved in working with the executive to produce, for example, an educational pack for  """
summary = generate_summary(model, tokenizer, text, device)
print("Generated Summary:", summary)


Generated Summary: 

Scotland's justice minister has warned bigoted soccer fans that she wants to hit them "where it hurts most" by banning them from matches.

Cathy Jamieson said exclusion orders are one of a series of measures being considered in the Scottish Executive campaign against sectarianism
