In [1]:
!pip install -qU evaluate --no-deps

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25h

## Prepare the data
- Loading the [twitter sentiment analysis data](https://www.kaggle.com/datasets/jp797498e/twitter-entity-sentiment-analysis) from kaggle
- Remove duplicates
- Remove null rows
- Remove non-useful label and columns
- Map the labels from Object to Int
- Extract a balanced 5000 entries from the train since I have a compute constraint
- Transform the two splits: train and validation. To a DatasetDict object
- Clean the text from urls, tags, lowercase it, etc.
- Finally Tokenize your data

In [39]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("jp797498e/twitter-entity-sentiment-analysis")

print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/


In [40]:
import pandas as pd
import os

column_names = ['id', 'entity', 'sentiment', 'text']

train_df = pd.read_csv(os.path.join(path, "twitter_training.csv"), header=None, names=column_names)
val_df = pd.read_csv(os.path.join(path, "twitter_validation.csv"), header=None, names=column_names)

train_df.head()

Unnamed: 0,id,entity,sentiment,text
0,2401,Borderlands,Positive,im getting on borderlands and i will murder yo...
1,2401,Borderlands,Positive,I am coming to the borders and I will kill you...
2,2401,Borderlands,Positive,im getting on borderlands and i will kill you ...
3,2401,Borderlands,Positive,im coming on borderlands and i will murder you...
4,2401,Borderlands,Positive,im getting on borderlands 2 and i will murder ...


In [41]:
# Remove duplicate rows
train_df = train_df.drop_duplicates()
val_df = val_df.drop_duplicates()
train_df.isna().sum(), val_df.isna().sum()

(id             0
 entity         0
 sentiment      0
 text         326
 dtype: int64,
 id           0
 entity       0
 sentiment    0
 text         0
 dtype: int64)

In [42]:
# Remove rows with any missing values
train_df = train_df.dropna()

In [43]:
train_df.isna().sum()

id           0
entity       0
sentiment    0
text         0
dtype: int64

In [44]:
train_df.sentiment.value_counts()

sentiment
Negative      21698
Positive      19713
Neutral       17708
Irrelevant    12537
Name: count, dtype: int64

In [45]:
val_df.sentiment.value_counts()

sentiment
Neutral       285
Positive      277
Negative      266
Irrelevant    172
Name: count, dtype: int64

In [46]:
# Remove rows with 'Irrelevant' sentiment
train_df = train_df[train_df['sentiment'] != 'Irrelevant']
# Remove rows with 'Irrelevant' sentiment from validation DataFrame
val_df = val_df[val_df['sentiment'] != 'Irrelevant']

In [47]:
train_df.sentiment.value_counts(True)

sentiment
Negative    0.367022
Positive    0.333446
Neutral     0.299531
Name: proportion, dtype: float64

In [48]:
# Define the mapping dictionary
sentiment_mapping = {'Negative': 0, 'Neutral': 1, 'Positive': 2}

# Apply the mapping to create a new 'label' column
train_df['label'] = train_df['sentiment'].map(sentiment_mapping)
val_df['label'] = val_df['sentiment'].map(sentiment_mapping)
train_df.label.value_counts()

label
0    21698
2    19713
1    17708
Name: count, dtype: int64

In [49]:
# Determine the number of classes
num_classes = train_df['label'].nunique()

# Calculate the number of samples per class
samples_per_class = 5000 // num_classes

# Sample the data
balanced_train_df = (
    train_df.groupby('label', group_keys=False)
    .apply(lambda x: x.sample(n=samples_per_class, random_state=42))
    .reset_index(drop=True)
)

  .apply(lambda x: x.sample(n=samples_per_class, random_state=42))


In [50]:
balanced_train_df.head()

Unnamed: 0,id,entity,sentiment,text,label
0,13162,Xbox(Xseries),Negative,Damn!!!!!!,0
1,8568,NBA2K,Negative,Fix guarding my corner please @NBA2K I literal...,0
2,1922,CallOfDutyBlackopsColdWar,Negative,Can we keep warzone as the COD br and simply u...,0
3,10888,TomClancysGhostRecon,Negative,had to jump underground to get a double kill..,0
4,4778,Google,Negative,"My google hangouts is not working, my laptop w...",0


In [51]:
balanced_train_df.label.value_counts(True)

label
0    0.333333
1    0.333333
2    0.333333
Name: proportion, dtype: float64

In [52]:
val_df.label.value_counts(True)

label
1    0.344203
2    0.334541
0    0.321256
Name: proportion, dtype: float64

In [53]:
from datasets import Dataset, DatasetDict

# Select only the 'text' and 'label' columns
train_dataset = Dataset.from_pandas(balanced_train_df[['text', 'label']])
val_dataset = Dataset.from_pandas(val_df[['text', 'label']])

# Combine into a DatasetDict
ds_split = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset
})

In [54]:
ds_split

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 4998
    })
    validation: Dataset({
        features: ['text', 'label', '__index_level_0__'],
        num_rows: 828
    })
})

In [None]:
import re
def clean_text(examples):
    for i in range(len(examples['text'])):
        #remove urls
        examples['text'][i] = re.sub(r'https?://\S+|www\.\S+', "", examples['text'][i])
        # remove html tags
        examples['text'][i] = re.sub(r'<[^>]+>', '', examples['text'][i])
        # handle spaces
        examples['text'][i] = re.sub(r"\s+", " ", examples['text'][i]).strip()
        # convert to lower case
        examples['text'][i] = examples['text'][i].lower()
processed_ds = ds_split.map(clean_text, batched=True)

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")

def preprocess_function(examples):
    return tokenizer(examples['text'], padding="max_length", truncation=True)

tokenized_ds = ds_split.map(preprocess_function, batched=True)

In [21]:
tokenized_ds

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 4998
    })
    validation: Dataset({
        features: ['text', 'label', '__index_level_0__', 'input_ids', 'attention_mask'],
        num_rows: 828
    })
})

## Fine-tune DistillBERT
I used DistilBERT for training because its lightweight (doesn't need too much compute) and it gives good results with few epochs.
steps:
- Load the model from huggingface hub
- Define a compute_metrics function in order to use it for evaluation (accuracy, precision, recall and f1)
- Define my training arguments and trainer
- Start training for only 5 epochs
    - We see the model for only 5 epochs get 77% on f1 
- Evaluate on the validation data
- Finally push the model to my huggingface hub for later uses
    - You can find it in my profile in huggingface benhima-mohamed-amine.

In [46]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

model = AutoModelForSequenceClassification.from_pretrained("distilbert/distilbert-base-uncased", num_labels=3)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
import evaluate
import numpy as np

accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)["accuracy"]
    precision = precision_metric.compute(predictions=predictions, references=labels, average="macro")["precision"]
    recall = recall_metric.compute(predictions=predictions, references=labels, average="macro")["recall"]
    f1 = f1_metric.compute(predictions=predictions, references=labels, average="macro")["f1"]

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }


In [27]:
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()

hf_token = user_secrets.get_secret("HUGGINGFACE_TOKEN")
login(token = hf_token)

In [50]:
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=1,
    lr_scheduler_type="linear",
    warmup_ratio=0.2,
    run_name="distellbert_sentiment_analysis",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    fp16=True,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds['train'],
    eval_dataset=tokenized_ds['validation'],
    compute_metrics=compute_metrics,
)

In [51]:
trainer.train()



Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.2615,0.780899,0.683575,0.688769,0.684708,0.683659
2,0.7083,0.672236,0.730676,0.733844,0.733575,0.728369
3,0.378,0.63728,0.756039,0.756206,0.756974,0.756115
4,0.3745,0.638266,0.772947,0.772604,0.774199,0.772881
5,0.8133,0.648752,0.777778,0.777364,0.779566,0.776698




TrainOutput(global_step=785, training_loss=0.6221721377911841, metrics={'train_runtime': 684.0321, 'train_samples_per_second': 36.533, 'train_steps_per_second': 1.148, 'total_flos': 3310419327989760.0, 'train_loss': 0.6221721377911841, 'epoch': 5.0})

In [57]:
trainer.evaluate()



{'eval_loss': 0.6487522125244141,
 'eval_accuracy': 0.7777777777777778,
 'eval_precision': 0.7773635046889577,
 'eval_recall': 0.779565640937482,
 'eval_f1': 0.7766975901308637,
 'eval_runtime': 7.1803,
 'eval_samples_per_second': 115.316,
 'eval_steps_per_second': 3.621,
 'epoch': 5.0}

In [None]:
trainer.model.push_to_hub(
    repo_id="mohamed-amine-benhima/distellbert-base-uncased-commonshare-5epochs",
    use_temp_dir=True 
)

## Topic Modeling

In [None]:
from nltk.corpus import stopwords
import string
import spacy
import re

# Load spacy model
nlp = spacy.load("en_core_web_sm")
stop_words = set(stopwords.words("english"))
punct_table = str.maketrans("", "", string.punctuation)

# Clean tweet text 
def clean_tweet(example):
    text = example['text'] 
    text = re.sub(r"http\S+|www\S+", "", text) #remove URLs
    text = re.sub(r"@\w+", "", text) # remove mentions
    text = re.sub(r"#\w+", "", text) # remove hashtags
    return {"text": text}

# Tokenize and lemmatize cleaned text
def preprocess(example):
    doc = nlp(example["text"].lower().translate(punct_table)) # lower ==> remove punct ==> lemmetatization
    tokens = [token.lemma_ for token in doc if token.lemma_ not in stop_words and token.is_alpha] # filiter stop words and non alphabetic words
    return {"tokens": tokens}

# Apply cleaning and preprocessing
ds_cleaned = ds_split.map(clean_tweet)
processed_ds = ds_cleaned.map(preprocess)


In [56]:
processed_ds

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'tokens'],
        num_rows: 4998
    })
    validation: Dataset({
        features: ['text', 'label', '__index_level_0__', 'tokens'],
        num_rows: 828
    })
})

In [57]:
# view the first tokens col in the two rows in train
ds['train'][:2]['tokens']

[['damn'],
 ['fix',
  'guard',
  'corner',
  'please',
  'I',
  'literally',
  'ball',
  'guard',
  'get',
  'open',
  'haha']]

In [58]:
from gensim.corpora import Dictionary

# Create a dictionary from the tokens
dictionary = Dictionary(ds['train']['tokens'])

# Filter out rare and very common words
dictionary.filter_extremes(no_below=5, no_above=0.5)

# Create the corpus (BoW representation)
corpus = [dictionary.doc2bow(tokens) for tokens in ds['train']['tokens']]

### Hyper-parameters tuning

In [60]:
import itertools
import numpy as np
import pandas as pd
from gensim.models import LdaModel
from gensim.models.coherencemodel import CoherenceModel

# Define the parameter grid
param_grid = {
    'num_topics': [5, 10, 15],
    'alpha': ['symmetric', 0.1, 0.5, 0.9],
    'eta': ['auto', 0.1, 0.5, 0.9]
}

# Generate all combinations of parameters
param_combinations = list(itertools.product(*param_grid.values()))
param_names = list(param_grid.keys())

# Function to compute coherence score
def compute_coherence(corpus, dictionary, num_topics, alpha, eta):
    lda_model = LdaModel(
        corpus=corpus,
        id2word=dictionary,
        num_topics=num_topics,
        alpha=alpha,
        eta=eta,
        random_state=42,
        passes=10,
        per_word_topics=True
    )
    coherence_model = CoherenceModel(model=lda_model, texts=ds['train']['tokens'], dictionary=dictionary, coherence='c_v')
    return coherence_model.get_coherence()

# Initialize a list to store results
results = []

# Perform grid search
for params in param_combinations:
    param_dict = dict(zip(param_names, params))
    coherence = compute_coherence(corpus, dictionary, **param_dict)
    results.append(param_dict | {'coherence': coherence})

# Convert results to a DataFrame
results_df = pd.DataFrame(results)

# Find the best parameters
best_params = results_df.loc[results_df['coherence'].idxmax()]
print("Best Parameters:", best_params)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Best Parameters: num_topics           15
alpha         symmetric
eta                 0.5
coherence      0.393105
Name: 34, dtype: object


we have coherance of 0.4 which is normal for a subset of twitter dataset.

In [64]:
# Extract best parameters
best_num_topics = best_params['num_topics']
best_alpha = best_params['alpha']
best_eta = best_params['eta']

# Train the LDA model with best parameters
lda_model = LdaModel(
    corpus=corpus,
    id2word=dictionary,
    num_topics=best_num_topics,
    alpha=best_alpha,
    eta=best_eta,
    random_state=42,
    passes=10,
    per_word_topics=True
)

In [65]:
for idx, topic in lda_model.show_topics(num_topics=best_num_topics, num_words=10, formatted=False):
    print(f"Topic {idx}:")
    print(", ".join([word for word, _ in topic]))
    print()

Topic 0:
game, good, creed, assassin, new, world, black, access, look, youtube

Topic 1:
xbox, series, x, amazing, wow, console, game, one, look, incredible

Topic 2:
verizon, nvidia, nice, service, card, com, customer, crazy, lovely, top

Topic 3:
happy, skin, birthday, weekend, art, eamaddennfl, god, club, madden, reddit

Topic 4:
great, stream, beautiful, g, see, video, thank, new, gaming, use

Topic 5:
unk, pubg, player, u, ban, wtf, interesting, play, mobile, csgo

Topic 6:
league, legend, play, kill, well, apex, come, omg, ai, get

Topic 7:
I, love, game, one, play, rhandlerr, get, thank, good, like

Topic 8:
call, duty, gta, pubg, include, car, v, c, cute, india

Topic 9:
I, get, like, play, good, go, back, much, time, come

Topic 10:
good, fuck, look, big, go, fifa, long, another, back, game

Topic 11:
dead, red, redemption, go, shit, epic, awesome, guy, fucking, game

Topic 12:
I, get, buy, game, wait, excited, go, shit, fix, year

Topic 13:
johnson, google, people, microsoft,

In [66]:
!pip install -q pyLDAvis

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m38.6/38.6 MB[0m [31m45.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tsfresh 0.21.0 requires scipy>=1.14.0; python_version >= "3.10", but you have scipy 1.13.1 which is incompatible.
cesium 0.12.4 requires numpy<3.0,>=2.0, but you have numpy 1.26.4 which is incompatible.
dopamine-rl 4.1.2 requires gymnasium>=1.0.0, but you have gymnasium 0.29.0 which is incompatible.
imbalanced-learn 0.13.0 requires scikit-learn<2,>=1.3.2, but you have scikit-learn 1.2.2 which is incompatible.
plotnine 0.14.5 requires matplotlib>=3.8.0, but you have matplotlib 3.7.2 which is incompatible.
mlxtend 0.23.4 requires scikit-learn>=1.3.1, b

### Visualization

In [69]:
import pyLDAvis
import pyLDAvis.gensim

# Visualize the topics
pyLDAvis.enable_notebook()

LDAvis_prepared = pyLDAvis.gensim.prepare(lda_model, corpus, dictionary)

LDAvis_prepared

## NER

In [76]:
import spacy
from spacy import displacy

# Load the English model
nlp = spacy.load("en_core_web_sm")

# Process and display entities for 10 samples from the training set
for i in range(35, 45):
    text = ds_split["train"][i]["text"]  # get the i-th text
    doc = nlp(text)  # apply NER and other NLP pipeline components
    
    print(f"Text {i+1}:")
    for ent in doc.ents:
        print(f"  - Entity: {ent.text}, Label: {ent.label_}")
    
    displacy.render(doc, style="ent", jupyter=True)

Text 36:
  - Entity: seven years, Label: DATE


Text 37:
  - Entity: Towski, Label: PERSON


Text 38:
  - Entity: an hour, Label: TIME


Text 39:


Text 40:


Text 41:
  - Entity: 3080, Label: CARDINAL


Text 42:


Text 43:
  - Entity: Lady Gaga, Label: PERSON


Text 44:
  - Entity: 2, Label: CARDINAL


Text 45:
  - Entity: tonight, Label: TIME
  - Entity: Trump, Label: PERSON
  - Entity: Johnson &Johnson, Label: ORG
  - Entity: Drs, Label: PERSON


## Extractive Summary
I used BERTSum for Extractive Summary because it gives very good results, as shown below.

Since I want anyone who has this notebook to be able to run the code, some functions and classes I wrote inside cells.
For typical real world project i am going to move them to seperate .py files

[Source and Credits](https://github.com/ereverter/bertsum-hf/tree/main)

### Extractive Summarization with BERT

In this approach, BERT is adapted for extractive summarization by treating it as a sentence-level classification task:

- Sentence Segmentation: The input document is divided into individual sentences.
- Input Formatting: Each sentence is preceded by a [CLS] token and followed by a [SEP] token. This structure allows BERT to process the entire document while distinguishing between sentences.
- Sentence Representation: After passing through BERT, the embeddings corresponding to the [CLS] tokens serve as representations for their respective sentences.
- Sentence Scoring: A classifier (e.g., a linear layer) is applied to these sentence embeddings to assign a relevance score to each sentence, indicating its importance for the summary.
- Summary Generation: The top-k sentences with the highest scores are selected and ordered as they appear in the original document to form the extractive summary.
- This method leverages BERT's contextual understanding to identify and extract the most salient sentences from a document, providing a concise summary without generating new text.


Generated by ChatGPT and checked by me.

In [94]:
'''
Auxiliary functions for the fine tuning of the extractive summarization task.
'''
import json
import evaluate
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize, word_tokenize
import torch
import subprocess
import random
import os
import numpy as np
import itertools

# Most basic tokenization possible #
def tokenize_text_to_sentences(text):
    return sent_tokenize(text)

# Auxiliary functions for inference #
def prepare_sample(sample,
                   tokenizer,
                   max_src_ntokens=200, 
                   min_src_ntokens=5,
                   max_nsents=100,
                   max_length=512,
                   return_tensors=True):
    """
    Prepare sample to run inference.
    `sample` is of the form [sentence1, sentence2, ...]
    """
    inputs = {}

    # Prepare the right input for BERT
    src = tokenizer(
        sample,
        max_length=max_src_ntokens,
        truncation=True,
        stride=0,
        return_token_type_ids=False,
        return_attention_mask=False
    )
    
    # Ignore senteces that are too short
    # *Assumption*: if sentence is short it is not relevant
    idxs = [i for i, sentence in enumerate(src['input_ids']) if (len(sentence) > min_src_ntokens)]

    # Trim sentences to a maximum. Note they are already trimmed by the tokenizer
    src = [src['input_ids'][i] for i in idxs]
    sample = [sample[i] for i in idxs]
    src = src[:max_nsents]

    # Flatten into a single sequence (sents will be separated by [SEP] and [CLS] tokens already)
    src = list(itertools.chain(*src))
    if len(src) > max_length:
        src = src[:max_length-1] + [tokenizer.sep_token_id] # Truncate to 512 tokens

    # Intercalate 0s and 1s to differentiate between sentences
    _segs = [-1] + [i for i, t in enumerate(src) if t == tokenizer.sep_token_id]
    segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]
    segment_ids = []
    for i, s in enumerate(segs):
        if (i % 2 == 0):
            segment_ids += s * [0]
        else:
            segment_ids += s * [1]

    # Get [CLS] positions, trim labels
    cls_ids = [i for i, t in enumerate(src) if t == tokenizer.cls_token_id]
    sample = sample[:len(cls_ids)]

    # Store data
    del _segs, segs, idxs
    inputs['input_ids'] = torch.tensor(src).unsqueeze(0) if return_tensors else src
    inputs['mask'] = torch.tensor([1] * len(src)).unsqueeze(0) if return_tensors else [1] * len(src)
    inputs['segment_ids'] = torch.tensor(segment_ids).unsqueeze(0) if return_tensors else segment_ids
    inputs['cls_ids'] = torch.tensor(cls_ids).unsqueeze(0) if return_tensors else cls_ids
    inputs['mask_cls'] = torch.tensor([1] * len(cls_ids)).unsqueeze(0) if return_tensors else [1] * len(cls_ids)
    inputs['sample'] = sample

    return inputs

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [78]:
from datasets import load_dataset

checkpoint = 'eReverter/bert-finetuned-cnn_dailymail'
dataset = 'eReverter/cnn_dailymail_extractive'

In [None]:
data_dict = load_dataset('eReverter/cnn_dailymail_extractive')
data_dict

In [None]:
'''
Simple BERT-based summarizer for extractive summarization.
'''
import torch
import torch.nn as nn
from transformers import BertModel, BertConfig, BertPreTrainedModel

class Classifier(nn.Module):
    """
    Simple classifier to predict the probability of each sentence to be included in the summary.
    """
    def __init__(self, hidden_size, **kwargs):
        super(Classifier, self).__init__(**kwargs)
        self.linear = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, mask_cls):
        h = self.linear(x).squeeze(-1)
        sent_scores = self.sigmoid(h) * mask_cls.float()
        return sent_scores
    
class BertSummarizerConfig(BertConfig):
    """
    Configuration class to store the configuration of a `BertSummarizer`.
    Inherits from `BertConfig` and loads the BERT checkpoint.
    """
    def __init__(self, checkpoint=None, **kwargs):
        super(BertSummarizerConfig, self).__init__(**kwargs)
        self.checkpoint = checkpoint

class BertSummarizer(BertPreTrainedModel):
    """
    Architecture to fine tune BERT for extractive summarization.
    BERT is used to encode the sentences.
    Afterward, a simple linear layer is used to predict the probability of each sentence to be included in the summary.
    """
    config_class = BertSummarizerConfig
    base_model_prefix = 'bert'
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel.from_pretrained(config.checkpoint) # Load pretrained bert
        self.encoder = Classifier(self.bert.config.hidden_size) # Add a linear layer on top of BERT for classification

        # Initialize encoder weights
        for p in self.encoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                nn.init.uniform_(p, -1.0, 1.0)
                
    def forward(self, input_ids=None, segment_ids=None, cls_ids=None, mask=None, mask_cls=None, labels=None, src_ids=None, tgt_ids=None):
        """
        The last hidden state of the BERT is used to encode the sentences.
        The first token of each sentence is used as a representation of the sentence.
        The representation of each sentence is then used to predict the probability of each sentence to be included in the summary.
        """
        top_vec = self.bert(input_ids, token_type_ids=segment_ids, attention_mask=mask).last_hidden_state
        sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), cls_ids]
        sents_vec = sents_vec * mask_cls[:, :, None].float()
        sent_scores = self.encoder(sents_vec, mask_cls).squeeze(-1)
        return  {'logits': sent_scores,
                 'mask_cls': mask_cls}

In [None]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = BertSummarizer.from_pretrained(checkpoint)

In [88]:
sample = data_dict['test'][24]['src']
sample

["(CNN)Since Iran's Islamic Revolution in 1979, women have been barred from attending most sports events involving men.",
 'But the situation appears set to improve in the coming months after a top Iranian sports official said that the ban will be lifted for some events.',
 'A plan to allow "women and families" to enter sports stadiums will come into effect in the next year, Deputy Sports Minister Abdolhamid Ahmadi said Saturday, according to state-run media.',
 "But it isn't clear exactly which games women will be able to attend.",
 'According to the state-run Press TV, Ahmadi said the restrictions would be lifted for indoor sports events.',
 'The rules won\'t change for all matches because some sports are mainly related to men and "families are not interested in attending" them, Press TV cited him as saying.',
 "Iranian authorities imposed the ban on women attending men's sports events after the revolution, deeming that mixed crowds watching games together was un-Islamic.",
 "During 

In [89]:
model_inputs = prepare_sample(sample, tokenizer)
updated_sample = model_inputs.pop('sample')

In [90]:
outputs = model(**model_inputs)
outputs

{'logits': tensor([[9.9696e-01, 8.5784e-01, 9.9671e-01, 2.9444e-02, 9.4092e-01, 6.1778e-01,
          9.5534e-01, 3.4963e-03, 2.1245e-01, 4.4468e-02, 2.4277e-04, 1.8779e-04,
          6.4061e-04, 8.5546e-04, 1.6864e-04, 1.1845e-02, 1.8611e-04]],
        grad_fn=<SqueezeBackward1>),
 'mask_cls': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [91]:
len(outputs['logits'][0]) == len(updated_sample)

True

In [92]:
# Select top 3 sentences for the summary
summary = ' '.join([updated_sample[i] for i in outputs['logits'].topk(3).indices.detach().cpu().numpy()[0]])
summary

'(CNN)Since Iran\'s Islamic Revolution in 1979, women have been barred from attending most sports events involving men. A plan to allow "women and families" to enter sports stadiums will come into effect in the next year, Deputy Sports Minister Abdolhamid Ahmadi said Saturday, according to state-run media. Iranian authorities imposed the ban on women attending men\'s sports events after the revolution, deeming that mixed crowds watching games together was un-Islamic.'

In [93]:
wikipedia_text = """
Wine is an alcoholic drink typically made from fermented grapes. 
Yeast consumes the sugar in the grapes and converts it to ethanol and carbon dioxide, 
releasing heat in the process. 
Different varieties of grapes and strains of yeasts are major factors in different styles of wine. 
These differences result from the complex interactions between the biochemical development of the grape, 
the reactions involved in fermentation, the grape's growing environment (terroir), and the wine production process. 
Many countries enact legal appellations intended to define styles and qualities of wine. 
These typically restrict the geographical origin and permitted varieties of grapes, 
as well as other aspects of wine production. 
Wines can be made by fermentation of other fruit crops such as plum, cherry, pomegranate, blueberry, 
currant and elderberry.
"""

sample = tokenize_text_to_sentences(wikipedia_text)
model_inputs = prepare_sample(sample, tokenizer)
updated_sample = model_inputs.pop('sample')
outputs = model(**model_inputs)
summary = ' '.join([updated_sample[i] for i in np.sort(outputs['logits'].topk(3).indices.detach().cpu().numpy()[0])])
summary

'\nWine is an alcoholic drink typically made from fermented grapes. Yeast consumes the sugar in the grapes and converts it to ethanol and carbon dioxide, releasing heat in the process. Different varieties of grapes and strains of yeasts are major factors in different styles of wine.'

The results are very good, knowing I didn't train it