# Rouge-L score prediction via regression

The goal is to use spacy in order to identify word tags in sentences, and use the resulting parsing in order to find the sentence that best summarizes the text it is from.

In [1]:
import os

os.chdir("..")

In [2]:
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.linear_model import Lasso, Ridge, ElasticNet
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

from src.metrics import single_rouge_score

In [3]:
from src.load_data import load_data

train_df, validation_df, test_df = load_data()
train_df.describe()

Unnamed: 0,text,titles
count,21401,21401
unique,21401,21401
top,Thierry Mariani sur la liste du Rassemblement ...,L'information n'a pas été confirmée par l'inté...
freq,1,1


In [4]:
# Count tokens appearing in relevant & irrelevant sentences in order to balance the scores of each tag
relevant_tag_count = 0
irrelevant_tag_count = 0

relevant_tag_counts: dict[str, int] = {}
irrelevant_tag_counts: dict[str, int] = {}


def text_to_sentences(text: str) -> list[str]:
    return [s.strip() for s in text.split(".")]

## SpaCy POS tagging

Using SpaCy to parse the text and identify the parts of speech in the text. The parts of speech are then used to identify the most important words in the text.

ISSUE: too long.

In [4]:
#!python -m spacy download fr_dep_news_trf

import spacy

tagger = spacy.load("fr_dep_news_trf")


def extract_tags(text: str, counter: dict[str, int]) -> int:
    """Add the found tags to the argument counter dictionary."""

    tags = tagger(text)
    for tag in tags:
        name = tag.pos_

        if name in counter:
            counter[name] += 1
        else:
            counter[name] = 1
    return len(tags)

For every text - target pair, we identify the sentence with the best Rouge-L score relative to the target, and we count for each token how much it appears in best sentences vs the other ones.

In [None]:
nrows = train_df.shape[0]

for _, (text, target) in tqdm(train_df.iterrows(), total=nrows):

    sentences = text_to_sentences(text)
    rouge_scores = [single_rouge_score(target, sentence)
                    for sentence in sentences]

    # Extract the index of the best sentence score
    best_sentence_index = rouge_scores.index(max(rouge_scores))

    # Count all tokens for all sentences. etc.
    for i, sentence in enumerate(sentences):
        if i == best_sentence_index:
            relevant_tag_count += extract_tags(sentence, relevant_tag_counts)
        else:
            irrelevant_tag_count += extract_tags(sentence,
                                                 irrelevant_tag_counts)

# Not viable: too long !

## Sentence & Paragraph Embeddings

We embed paragraphs and sentences using pretrained models. We then use a regressor from `scikit-learn` to predict the Rouge-L score of each sentence, and thus pick the best summarizing one using the max Rouge-L score.

In [5]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

In [6]:
def load_or_compute(name: str, df: pd.DataFrame):

    embeddings_filename = f"data/{name}_embeddings.npy"
    scores_filename = f"data/{name}_scores.npy"

    if os.path.exists(embeddings_filename) and os.path.exists(scores_filename):
        return np.load(embeddings_filename), np.load(scores_filename)

    # Else: do and save
    scores: list[float] = []
    final_embeddings = []

    nrows = train_df.shape[0]

    for _, (text, target) in tqdm(train_df.iterrows(), total=nrows):

        # Extract sentences
        sentences = text_to_sentences(text)

        # Compute Rouge-L scores relative to the target
        rouge_scores = [single_rouge_score(
            target, sentence) for sentence in sentences]
        scores.extend(rouge_scores)

        # Compute embeddings
        sentence_embeddings = np.array(model.encode(sentences))
        paragraph_embeddings = np.array(model.encode(text))

        # Stack sentence embeddings with their respective paragraph embedding,
        # into the global sentence_embeddings list
        repeated_paragraph = np.tile(paragraph_embeddings, (len(sentences), 1))
        embeddings = np.concatenate(
            (sentence_embeddings, repeated_paragraph), axis=1)

        final_embeddings.extend(embeddings)

    np.save(embeddings_filename, final_embeddings)
    np.save(scores_filename, scores)
    return np.array(final_embeddings), np.array(scores)

In [7]:
train_embed, train_scores = load_or_compute("train", train_df)

In [8]:
valid_embed, valid_scores = load_or_compute("valid", validation_df)

In [9]:
# Visualize shapes
print(train_embed.shape, train_scores.shape)
print(valid_embed.shape, valid_scores.shape)

(386159, 768) (386159,)
(386159, 768) (386159,)


In [35]:
# Try different regression models
ridge_reg = Ridge(alpha=1.0)
lasso_reg = Lasso(alpha=1.0)
elastic_reg = ElasticNet(alpha=1.0, l1_ratio=0.5)

In [36]:
ridge_reg.fit(train_embed, train_scores)
print(ridge_reg.score(valid_embed, valid_scores))

0.3495076616838435


In [37]:
lasso_reg.fit(train_embed, train_scores)
print(lasso_reg.score(valid_embed, valid_scores))

-2.220446049250313e-16


In [38]:
elastic_reg.fit(train_embed, train_scores)
print(elastic_reg.score(valid_embed, valid_scores))

-2.220446049250313e-16


We can then test one of the regressors on the validation data in order to pick a sentence for each paragraph.

In [10]:
def pick_sentences(embeddings: np.ndarray, df: pd.DataFrame):
    """Given a dataset as DataFrame and precomputed embeddings, pick and test a sentence for each
    text to be summarized."""

    nrows = df.shape[0]
    best_sentences: list[str] = []

    # Embedding span pointers
    start = 0

    for _, (text, *_) in tqdm(df.iterrows(), total=nrows):

        # Extract sentences
        sentences = text_to_sentences(text)

        # Get embeddings for the current sentences
        sent_embeddings = embeddings[start: start + len(sentences)]

        # Predict the best sentence
        best_sentence_index = np.argmax(ridge_reg.predict(sent_embeddings))

        best_sentences.append(sentences[best_sentence_index])

        start += len(sentences)  # Move the pointer to the next span

    return best_sentences


def avg_score(sentences: list[str], targets: list[str]) -> float:
    scores = [
        single_rouge_score(target, sentence)
        for target, sentence in zip(targets, sentences)
    ]
    return float(np.mean(scores))

In [44]:
picked_sentences = pick_sentences(valid_embed, validation_df)
avg_score(picked_sentences, validation_df["titles"].tolist())

100%|██████████| 1500/1500 [00:00<00:00, 3855.43it/s]


0.11570915085148492

The best regressor gives an average Rouge-L score of 0.1157 on the validation set, which is not very good.

### Regression using a neural network

Instead, we will use a neural network with fully connected layers in order to predict the Rouge-L score of each sentence relative to their paragraph target.

In [11]:
class ScoreNN(nn.Module):
    """Rouge-L predictor"""

    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()

        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x: Tensor) -> Tensor:
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5) # Avoid overfitting
        x = self.fc2(x)
        return x

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ScoreNN(input_size=train_embed.shape[1], hidden_size=256).to(device)

In [14]:
# Prepare training
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Convert numpy arrays to torch tensors
train_embed_t = torch.from_numpy(train_embed).float().to(device)
train_scores_t = torch.from_numpy(train_scores).float().to(device)

valid_embed_t = torch.from_numpy(valid_embed).float().to(device)
valid_scores_t = torch.from_numpy(valid_scores).float().to(device)

# Training loop
epochs = 100
batch_size = 10_000

for epoch in range(epochs):

    # Shuffle the data
    indices = torch.randperm(train_embed_t.size(0))

    for i in tqdm(range(0, train_embed_t.size(0), batch_size)):
        batch_indices = indices[i: i + batch_size]

        # Forward pass
        outputs = model(train_embed_t[batch_indices]).squeeze()

        # Compute loss
        loss = criterion(outputs, train_scores_t[batch_indices])

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validation loss
    # Issue: ça prend bcp trop de place à la fois ?
    # COmput
    # valid_outputs = model(valid_embed_t)
    # valid_loss = criterion(valid_outputs, valid_scores_t)
    # Compute validation loss by batch
    valid_loss = 0
    for i in range(0, valid_embed_t.size(0), batch_size):
        valid_outputs = model(valid_embed_t[i: i + batch_size]).squeeze()
        valid_loss += criterion(
            valid_outputs, valid_scores_t[i: i + batch_size]
        ).item()

    print(f"Epoch {epoch}, Loss: {loss.item()}, Validation loss: {valid_loss}")
    if epoch % 10 == 0:
        # Save the model
        torch.save(model.state_dict(), "data/score_nn.pth")

100%|██████████| 56/56 [00:00<00:00, 194.14it/s]


Epoch 0, Loss: 0.002407216001302004, Validation loss: 0.13831067248247564


100%|██████████| 56/56 [00:00<00:00, 247.87it/s]


Epoch 1, Loss: 0.0025812748353928328, Validation loss: 0.13097039377316833


100%|██████████| 56/56 [00:00<00:00, 224.67it/s]


Epoch 2, Loss: 0.0022947280667722225, Validation loss: 0.1306729952339083


100%|██████████| 56/56 [00:00<00:00, 277.95it/s]


Epoch 3, Loss: 0.0019884526263922453, Validation loss: 0.13090011361055076


100%|██████████| 56/56 [00:00<00:00, 286.93it/s]


Epoch 4, Loss: 0.0026592600625008345, Validation loss: 0.13148380140773952


100%|██████████| 56/56 [00:00<00:00, 278.61it/s]


Epoch 5, Loss: 0.0024047396145761013, Validation loss: 0.1304116863757372


100%|██████████| 56/56 [00:00<00:00, 289.16it/s]


Epoch 6, Loss: 0.0030209263786673546, Validation loss: 0.1307068788446486


100%|██████████| 56/56 [00:00<00:00, 287.38it/s]


Epoch 7, Loss: 0.002326800487935543, Validation loss: 0.12994045997038484


100%|██████████| 56/56 [00:00<00:00, 280.14it/s]


Epoch 8, Loss: 0.001987290335819125, Validation loss: 0.1303096527699381


100%|██████████| 56/56 [00:00<00:00, 285.53it/s]


Epoch 9, Loss: 0.0025172356981784105, Validation loss: 0.13043765700422227


100%|██████████| 56/56 [00:00<00:00, 290.49it/s]


Epoch 10, Loss: 0.002401245292276144, Validation loss: 0.12953398237004876


100%|██████████| 56/56 [00:00<00:00, 280.89it/s]


Epoch 11, Loss: 0.0025387152563780546, Validation loss: 0.12920495122671127


100%|██████████| 56/56 [00:00<00:00, 289.28it/s]


Epoch 12, Loss: 0.002641338622197509, Validation loss: 0.1301825800910592


100%|██████████| 56/56 [00:00<00:00, 294.59it/s]


Epoch 13, Loss: 0.0020651048980653286, Validation loss: 0.13041356671601534


100%|██████████| 56/56 [00:00<00:00, 276.19it/s]


Epoch 14, Loss: 0.002787448698654771, Validation loss: 0.13021531491540372


100%|██████████| 56/56 [00:00<00:00, 291.76it/s]


Epoch 15, Loss: 0.002320406725630164, Validation loss: 0.12905360385775566


100%|██████████| 56/56 [00:00<00:00, 286.15it/s]


Epoch 16, Loss: 0.0025643613189458847, Validation loss: 0.12978668930009007


100%|██████████| 56/56 [00:00<00:00, 282.55it/s]


Epoch 17, Loss: 0.002540426794439554, Validation loss: 0.14037838159129024


100%|██████████| 56/56 [00:00<00:00, 279.67it/s]


Epoch 18, Loss: 0.0023353341966867447, Validation loss: 0.12782872538082302


100%|██████████| 56/56 [00:00<00:00, 224.36it/s]


Epoch 19, Loss: 0.0024974115658551455, Validation loss: 0.1299252058379352


100%|██████████| 56/56 [00:00<00:00, 282.59it/s]


Epoch 20, Loss: 0.002946168417111039, Validation loss: 0.12867978354915977


100%|██████████| 56/56 [00:00<00:00, 282.01it/s]


Epoch 21, Loss: 0.0023617560509592295, Validation loss: 0.12991167302243412


100%|██████████| 56/56 [00:00<00:00, 281.91it/s]


Epoch 22, Loss: 0.002249847399070859, Validation loss: 0.12724073510617018


100%|██████████| 56/56 [00:00<00:00, 285.72it/s]


Epoch 23, Loss: 0.002895925659686327, Validation loss: 0.12647008849307895


100%|██████████| 56/56 [00:00<00:00, 288.94it/s]


Epoch 24, Loss: 0.0023072545882314444, Validation loss: 0.12670217431150377


100%|██████████| 56/56 [00:00<00:00, 281.79it/s]


Epoch 25, Loss: 0.0026137668173760176, Validation loss: 0.12702876282855868


100%|██████████| 56/56 [00:00<00:00, 283.57it/s]


Epoch 26, Loss: 0.00310868164524436, Validation loss: 0.12721495470032096


100%|██████████| 56/56 [00:00<00:00, 285.03it/s]


Epoch 27, Loss: 0.0023413265589624643, Validation loss: 0.1264455569908023


100%|██████████| 56/56 [00:00<00:00, 280.96it/s]


Epoch 28, Loss: 0.002401032717898488, Validation loss: 0.12720447522588074


100%|██████████| 56/56 [00:00<00:00, 285.79it/s]


Epoch 29, Loss: 0.002377642784267664, Validation loss: 0.12715851934626698


100%|██████████| 56/56 [00:00<00:00, 283.83it/s]


Epoch 30, Loss: 0.002345010871067643, Validation loss: 0.1290011906530708


100%|██████████| 56/56 [00:00<00:00, 285.97it/s]


Epoch 31, Loss: 0.0027189995162189007, Validation loss: 0.12626598915085196


100%|██████████| 56/56 [00:00<00:00, 287.87it/s]


Epoch 32, Loss: 0.002662560436874628, Validation loss: 0.12450843979604542


100%|██████████| 56/56 [00:00<00:00, 245.31it/s]


Epoch 33, Loss: 0.0024279479403048754, Validation loss: 0.12450976110994816


100%|██████████| 56/56 [00:00<00:00, 275.37it/s]


Epoch 34, Loss: 0.0024849760811775923, Validation loss: 0.1266826873179525


100%|██████████| 56/56 [00:00<00:00, 277.22it/s]


Epoch 35, Loss: 0.0025734035298228264, Validation loss: 0.1236596463713795


100%|██████████| 56/56 [00:00<00:00, 279.50it/s]


Epoch 36, Loss: 0.0023942471016198397, Validation loss: 0.12574064754880965


100%|██████████| 56/56 [00:00<00:00, 284.25it/s]


Epoch 37, Loss: 0.0030283539090305567, Validation loss: 0.12255191872827709


100%|██████████| 56/56 [00:00<00:00, 289.93it/s]


Epoch 38, Loss: 0.002540476620197296, Validation loss: 0.12241687357891351


100%|██████████| 56/56 [00:00<00:00, 280.40it/s]


Epoch 39, Loss: 0.0021379089448601007, Validation loss: 0.12453726679086685


100%|██████████| 56/56 [00:00<00:00, 281.18it/s]


Epoch 40, Loss: 0.0021852210629731417, Validation loss: 0.12983582727611065


100%|██████████| 56/56 [00:00<00:00, 286.65it/s]


Epoch 41, Loss: 0.002673860639333725, Validation loss: 0.12292178184725344


100%|██████████| 56/56 [00:00<00:00, 280.79it/s]


Epoch 42, Loss: 0.0024802573025226593, Validation loss: 0.12654209695756435


100%|██████████| 56/56 [00:00<00:00, 277.12it/s]


Epoch 43, Loss: 0.002456067129969597, Validation loss: 0.12133944011293352


100%|██████████| 56/56 [00:00<00:00, 282.19it/s]


Epoch 44, Loss: 0.0023634659592062235, Validation loss: 0.12176828400697559


100%|██████████| 56/56 [00:00<00:00, 283.43it/s]


Epoch 45, Loss: 0.00230425875633955, Validation loss: 0.1206097254762426


100%|██████████| 56/56 [00:00<00:00, 261.65it/s]


Epoch 46, Loss: 0.002457602182403207, Validation loss: 0.12172748299781233


100%|██████████| 56/56 [00:00<00:00, 277.09it/s]


Epoch 47, Loss: 0.0026225238107144833, Validation loss: 0.12085074139758945


100%|██████████| 56/56 [00:00<00:00, 286.71it/s]


Epoch 48, Loss: 0.002500637201592326, Validation loss: 0.12276882235892117


100%|██████████| 56/56 [00:00<00:00, 279.88it/s]


Epoch 49, Loss: 0.002102787373587489, Validation loss: 0.12285074684768915


100%|██████████| 56/56 [00:00<00:00, 286.21it/s]


Epoch 50, Loss: 0.002120097167789936, Validation loss: 0.12453441321849823


100%|██████████| 56/56 [00:00<00:00, 285.50it/s]


Epoch 51, Loss: 0.001976009691134095, Validation loss: 0.11932769860140979


100%|██████████| 56/56 [00:00<00:00, 286.79it/s]


Epoch 52, Loss: 0.0017002035165205598, Validation loss: 0.12550639547407627


100%|██████████| 56/56 [00:00<00:00, 254.85it/s]


Epoch 53, Loss: 0.002700225682929158, Validation loss: 0.12655793502926826


100%|██████████| 56/56 [00:00<00:00, 285.44it/s]


Epoch 54, Loss: 0.0018619224429130554, Validation loss: 0.1330595479812473


100%|██████████| 56/56 [00:00<00:00, 284.04it/s]


Epoch 55, Loss: 0.0021075033582746983, Validation loss: 0.11914737056940794


100%|██████████| 56/56 [00:00<00:00, 281.33it/s]


Epoch 56, Loss: 0.0023128490429371595, Validation loss: 0.11838676873594522


100%|██████████| 56/56 [00:00<00:00, 287.34it/s]


Epoch 57, Loss: 0.0024131140671670437, Validation loss: 0.11832827643956989


100%|██████████| 56/56 [00:00<00:00, 288.90it/s]


Epoch 58, Loss: 0.0025337168481200933, Validation loss: 0.12138060654979199


100%|██████████| 56/56 [00:00<00:00, 268.00it/s]


Epoch 59, Loss: 0.002414393937215209, Validation loss: 0.11920840945094824


100%|██████████| 56/56 [00:00<00:00, 235.67it/s]


Epoch 60, Loss: 0.0024007363244891167, Validation loss: 0.1196624698350206


100%|██████████| 56/56 [00:00<00:00, 285.19it/s]


Epoch 61, Loss: 0.0021133648697286844, Validation loss: 0.12057050585281104


100%|██████████| 56/56 [00:00<00:00, 276.83it/s]


Epoch 62, Loss: 0.0023718690499663353, Validation loss: 0.11667096940800548


100%|██████████| 56/56 [00:00<00:00, 279.28it/s]


Epoch 63, Loss: 0.002548248041421175, Validation loss: 0.11696140561252832


100%|██████████| 56/56 [00:00<00:00, 282.60it/s]


Epoch 64, Loss: 0.002100640209391713, Validation loss: 0.11845160799566656


100%|██████████| 56/56 [00:00<00:00, 283.42it/s]


Epoch 65, Loss: 0.0026340060867369175, Validation loss: 0.1239576784428209


100%|██████████| 56/56 [00:00<00:00, 258.23it/s]


Epoch 66, Loss: 0.0021010716445744038, Validation loss: 0.11596759257372469


100%|██████████| 56/56 [00:00<00:00, 280.66it/s]


Epoch 67, Loss: 0.0019950559362769127, Validation loss: 0.11875514674466103


100%|██████████| 56/56 [00:00<00:00, 283.79it/s]


Epoch 68, Loss: 0.002067903522402048, Validation loss: 0.11926870851311833


100%|██████████| 56/56 [00:00<00:00, 283.02it/s]


Epoch 69, Loss: 0.002151913708075881, Validation loss: 0.12083277641795576


100%|██████████| 56/56 [00:00<00:00, 279.64it/s]


Epoch 70, Loss: 0.002434749621897936, Validation loss: 0.11640919791534543


100%|██████████| 56/56 [00:00<00:00, 288.17it/s]


Epoch 71, Loss: 0.0025099138729274273, Validation loss: 0.11768730042967945


100%|██████████| 56/56 [00:00<00:00, 285.17it/s]


Epoch 72, Loss: 0.0018090546363964677, Validation loss: 0.11512696195859462


100%|██████████| 56/56 [00:00<00:00, 248.03it/s]


Epoch 73, Loss: 0.002254330553114414, Validation loss: 0.11452804424334317


100%|██████████| 56/56 [00:00<00:00, 284.86it/s]


Epoch 74, Loss: 0.002530826022848487, Validation loss: 0.11831432278268039


100%|██████████| 56/56 [00:00<00:00, 283.17it/s]


Epoch 75, Loss: 0.002265904564410448, Validation loss: 0.11535267811268568


100%|██████████| 56/56 [00:00<00:00, 278.78it/s]


Epoch 76, Loss: 0.0018601827323436737, Validation loss: 0.1155896025011316


100%|██████████| 56/56 [00:00<00:00, 284.93it/s]


Epoch 77, Loss: 0.0018274378962814808, Validation loss: 0.11848691187333316


100%|██████████| 56/56 [00:00<00:00, 265.77it/s]


Epoch 78, Loss: 0.0022748522460460663, Validation loss: 0.11401301226578653


100%|██████████| 56/56 [00:00<00:00, 281.25it/s]


Epoch 79, Loss: 0.002186279045417905, Validation loss: 0.11366287677083164


100%|██████████| 56/56 [00:00<00:00, 264.55it/s]


Epoch 80, Loss: 0.0020795082673430443, Validation loss: 0.11459545651450753


100%|██████████| 56/56 [00:00<00:00, 288.93it/s]


Epoch 81, Loss: 0.0021765129640698433, Validation loss: 0.11453242332208902


100%|██████████| 56/56 [00:00<00:00, 283.38it/s]


Epoch 82, Loss: 0.002153156092390418, Validation loss: 0.11815170536283404


100%|██████████| 56/56 [00:00<00:00, 279.61it/s]


Epoch 83, Loss: 0.002115547889843583, Validation loss: 0.11347096448298544


100%|██████████| 56/56 [00:00<00:00, 285.18it/s]


Epoch 84, Loss: 0.0020220945589244366, Validation loss: 0.1134787619812414


100%|██████████| 56/56 [00:00<00:00, 281.38it/s]


Epoch 85, Loss: 0.0018593784188851714, Validation loss: 0.11336413933895528


100%|██████████| 56/56 [00:00<00:00, 273.76it/s]


Epoch 86, Loss: 0.0025816280394792557, Validation loss: 0.11755299591459334


100%|██████████| 56/56 [00:00<00:00, 287.84it/s]


Epoch 87, Loss: 0.0020760302431881428, Validation loss: 0.1129884417168796


100%|██████████| 56/56 [00:00<00:00, 286.96it/s]


Epoch 88, Loss: 0.0019816041458398104, Validation loss: 0.11529214505571872


100%|██████████| 56/56 [00:00<00:00, 281.70it/s]


Epoch 89, Loss: 0.0020434739999473095, Validation loss: 0.11638345313258469


100%|██████████| 56/56 [00:00<00:00, 287.70it/s]


Epoch 90, Loss: 0.0019075676100328565, Validation loss: 0.11601393122691661


100%|██████████| 56/56 [00:00<00:00, 290.83it/s]


Epoch 91, Loss: 0.0020694537088274956, Validation loss: 0.11226545448880643


100%|██████████| 56/56 [00:00<00:00, 266.96it/s]


Epoch 92, Loss: 0.0025587202981114388, Validation loss: 0.11198939883615822


100%|██████████| 56/56 [00:00<00:00, 253.93it/s]


Epoch 93, Loss: 0.0022761973086744547, Validation loss: 0.12202195194549859


100%|██████████| 56/56 [00:00<00:00, 283.11it/s]


Epoch 94, Loss: 0.0022714214865118265, Validation loss: 0.11129449913278222


100%|██████████| 56/56 [00:00<00:00, 278.15it/s]


Epoch 95, Loss: 0.0021404719445854425, Validation loss: 0.11145158379804343


100%|██████████| 56/56 [00:00<00:00, 274.05it/s]


Epoch 96, Loss: 0.0020843802485615015, Validation loss: 0.1171906212111935


100%|██████████| 56/56 [00:00<00:00, 279.13it/s]


Epoch 97, Loss: 0.002122072735801339, Validation loss: 0.115965063450858


100%|██████████| 56/56 [00:00<00:00, 260.30it/s]


Epoch 98, Loss: 0.0022745865862816572, Validation loss: 0.1126572418725118


100%|██████████| 56/56 [00:00<00:00, 229.29it/s]


Epoch 99, Loss: 0.00217641843482852, Validation loss: 0.1152457541320473


We can then use the same strategy as before to pick the best summarizing sentence.

In [17]:

nrows = validation_df.shape[0]
best_sentences: list[str] = []

# Embedding span pointers
start = 0

with torch.no_grad():
    for _, (text, *_) in tqdm(validation_df.iterrows(), total=nrows):

        # Extract sentences
        sentences = text_to_sentences(text)

        # Get embeddings for the current sentences
        sent_embeddings = valid_embed_t[start: start + len(sentences)]

        # Predict the best sentence
        best_sentence_index = np.argmax(
            model(sent_embeddings).squeeze().cpu().numpy())

        best_sentences.append(sentences[best_sentence_index])

        start += len(sentences)  # Move the pointer to the next span

print(len(best_sentences))

100%|██████████| 1500/1500 [00:00<00:00, 2050.24it/s]

1500





In [18]:
avg_score(best_sentences, validation_df["titles"].tolist())

0.12151541233359918

Yet again, the score is not that great. It could be improved by using different embeddings and a larger model.