Inspired by
 - The Annotated Transformer
 - Ajay Halthor's Youtube videos on    Transformers

In [None]:
!pip install datasets
!pip install -U huggingface_hub

In [None]:
from torchtext.vocab import build_vocab_from_iterator
import spacy
import os.path
import torch
import datasets as ds

START_TOKEN = "<SOS>"
END_TOKEN = "<EOS>"
PADDING_TOKEN = "<PAD>"
UNKNOWN_TOKEN = "<UNK>"
MAX_VOCAB_SIZE = 20000


def load_tokenizers():
    try:
        print("Loading spacy...")
        spacy_en = spacy.load("en_core_web_sm")
    except IOError:
        print("Downloading spacy...")
        # os.system("python -m spacy download en_core_web_sm")
        spacy.cli.download("en_core_web_sm")
        spacy_en = spacy.load("en_core_web_sm")

    return spacy_en

def tokenize(text, nlp):
    return [token.text for token in nlp.tokenizer(text)]

def yield_tokens(data_iterator, nlp, field):
    for data in data_iterator:
        yield tokenize(data[field], nlp)

def build_vocabulary(dataset, spacy_en):

    print("Building Article Vocabulary ...")
    article_vocab = build_vocab_from_iterator(
        yield_tokens(dataset['train'], spacy_en, 'article'),
        min_freq=2,
        specials=[START_TOKEN, END_TOKEN, PADDING_TOKEN, UNKNOWN_TOKEN],
        max_tokens=MAX_VOCAB_SIZE
    )


    print("Building Highlight Vocabulary ...")
    highlight_vocab = build_vocab_from_iterator(
        yield_tokens(dataset['train'], spacy_en, 'highlights'),
        min_freq=2,
        specials=[START_TOKEN, END_TOKEN, PADDING_TOKEN, UNKNOWN_TOKEN],
        max_tokens=MAX_VOCAB_SIZE
    )

    article_vocab.set_default_index(article_vocab[UNKNOWN_TOKEN])
    highlight_vocab.set_default_index(highlight_vocab[UNKNOWN_TOKEN])

    return article_vocab, highlight_vocab


def load_vocab(spacy_en, dataset):
    if not os.path.exists("vocab.pt"):
        print("Vocabs Not Found")
        article_vocab, highlight_vocab = build_vocabulary(dataset, spacy_en)
        torch.save((article_vocab, highlight_vocab), "vocab.pt")
    else:
        print("Loading Vocabs...")
        article_vocab, highlight_vocab = torch.load("vocab.pt")
    print("Finished.\nVocabulary sizes:")
    print(f"Article Vocab Len: {len(article_vocab)}")
    print(f"Highlight Vocab Len: {len(highlight_vocab)}")
    return article_vocab, highlight_vocab

# print("Initialization...")
# spacy_en = load_tokenizers()
# load_vocab(spacy_en)

In [None]:
import numpy as np
import torch
import math
from torch import nn
import torch.nn.functional as F

def get_device():
    return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def scaled_dot_product(query, key, value, mask=None):
    d_k = query.size(-1) # 64
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # 30 x 8 x 200 x 200
    if mask is not None:
        scores = scores.permute(1, 0, 2, 3) + mask # 8 x 30 x 200 x 200
        scores = scores.permute(1, 0, 2, 3) # 30 x 8 x 200 x 200 (via broadcasting)
    return torch.matmul(scores.softmax(dim=-1), value) # 30 x 8 x 200 x 64

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_sequence_length):
        super().__init__()
        self.max_sequence_length = max_sequence_length
        self.d_model = d_model

    def forward(self):
        even_i = torch.arange(0, self.d_model, 2).float()
        denominator = torch.pow(10000, even_i/self.d_model)
        position = (torch.arange(self.max_sequence_length)
                          .reshape(self.max_sequence_length, 1))
        evens = torch.sin(position / denominator)
        odds = torch.cos(position / denominator)
        stacked = torch.stack([evens, odds], dim=2)
        pos_enc = torch.flatten(stacked, start_dim=1, end_dim=2)
        return pos_enc

class SentenceEmbedding(nn.Module):
    def __init__(self, max_sequence_length, d_model, language_to_index, tokenizer, START_TOKEN, END_TOKEN, PADDING_TOKEN):
        super().__init__()
        self.vocab_size = len(language_to_index)
        self.max_sequence_length = max_sequence_length
        self.embedding = nn.Embedding(self.vocab_size, d_model)
        self.language_to_index = language_to_index
        self.tokenizer = tokenizer
        self.position_encoder = PositionalEncoding(d_model, max_sequence_length)
        self.dropout = nn.Dropout(p=0.1)
        self.START_TOKEN = START_TOKEN
        self.END_TOKEN = END_TOKEN
        self.PADDING_TOKEN = PADDING_TOKEN

    def batch_tokenize(self, batch, start_token, end_token):

        def convert_tokens_to_index(sentence, start_token, end_token):
            # sentence_word_indicies = [self.language_to_index[token] for token in list(sentence)]
            tokens = tokenize(sentence, self.tokenizer)
            max_tokens = self.max_sequence_length
            if start_token:
                max_tokens -= 1
            if end_token:
                max_tokens -= 1
            if len(tokens) > max_tokens: # Truncation
                tokens = tokens[:max_tokens]

            sentence_word_indicies = [self.language_to_index[tok] for tok in tokens]
            if start_token:
                sentence_word_indicies.insert(0, self.language_to_index[self.START_TOKEN])
            if end_token:
                sentence_word_indicies.append(self.language_to_index[self.END_TOKEN])
            while len(sentence_word_indicies) < self.max_sequence_length:
                sentence_word_indicies.append(self.language_to_index[self.PADDING_TOKEN])
            return torch.tensor(sentence_word_indicies)

        indexed_tokens = []
        for sentence_num in range(len(batch)):
           indexed_tokens.append(convert_tokens_to_index(batch[sentence_num], start_token, end_token))
        indexed_tokens = torch.stack(indexed_tokens)
        # print(indexed_tokens)
        return indexed_tokens.to(get_device())

    def forward(self, x, start_token, end_token): # sentence
        x = self.batch_tokenize(x, start_token, end_token)
        x = self.embedding(x)
        pos = self.position_encoder().to(get_device())
        x = self.dropout(x + pos)
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        # self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads # 512 / 8 = 64
        self.w_query = nn.Linear(d_model, d_model) # 512 x 512
        self.w_key = nn.Linear(d_model, d_model) # 512 x 512
        self.w_value = nn.Linear(d_model, d_model) # 512 x 512
        self.linear_layer = nn.Linear(d_model, d_model) # 512 x 512

    def forward(self, query, key, value, mask = None): # 30 x 200 x 512
        batch_size, max_sequence_length, d_model = query.size()

        # Split q, k, and v between attention heads
        query = self.split_weights(self.w_query(query)) # 30 x 8 x 200 x 64
        key = self.split_weights(self.w_key(key)) # 30 x 8 x 200 x 64
        value = self.split_weights(self.w_value(value)) # 30 x 8 x 200 x 64

        values = scaled_dot_product(query, key, value, mask) # 30 x 8 x 200 x 64

        # Concat heads
        values = values.permute(0, 2, 1, 3) # 30 x 200 x 8 x 64
        values = values.reshape(batch_size, max_sequence_length, self.num_heads * self.d_k) # 30 x 200 x 512

        return self.linear_layer(values) # 30 x 200 x 512

    def split_weights(self, x):
        batch_size, max_sequence_length, d_model = x.size() # 30 x 200 x 512
        x = x.reshape(batch_size, max_sequence_length, self.num_heads, self.d_k) # 30 x 200 x 8 x 64
        return x.permute(0, 2, 1, 3) # 30 x 8 x 200 x 64


class LayerNormalization(nn.Module):
    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape=parameters_shape
        self.eps=eps
        # Learnable parameters
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, inputs):
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = inputs.mean(dim=dims, keepdim=True)
        var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        y = (inputs - mean) / std
        out = self.gamma * y + self.beta
        return out


class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, hidden, drop_prob=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x


class EncoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)

    def forward(self, x, self_attention_mask):
        residual = x.clone()
        x = self.attention(query=x, key=x, value=x, mask=self_attention_mask)
        x = self.dropout1(x)
        x = self.norm1(x + residual)
        residual = x.clone()
        x = self.ffn(x)
        x = self.dropout2(x)
        x = self.norm2(x + residual)
        return x

class SequentialEncoder(nn.Sequential):
    def forward(self, *inputs):
        x, self_attention_mask  = inputs
        for module in self._modules.values():
            x = module(x, self_attention_mask)
        return x

class Encoder(nn.Module):
    def __init__(self,
                 d_model,
                 ffn_hidden,
                 num_heads,
                 drop_prob,
                 num_layers,
                 max_sequence_length,
                 language_to_index,
                 tokenizer,
                 START_TOKEN,
                 END_TOKEN,
                 PADDING_TOKEN):
        super().__init__()
        self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, language_to_index, tokenizer, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.layers = SequentialEncoder(*[EncoderLayer(d_model, ffn_hidden, num_heads, drop_prob)
                                      for _ in range(num_layers)])

    def forward(self, x, self_attention_mask, start_token, end_token):
        x = self.sentence_embedding(x, start_token, end_token)
        x = self.layers(x, self_attention_mask)
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.layer_norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)

        self.cross_attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.layer_norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)

        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.layer_norm3 = LayerNormalization(parameters_shape=[d_model])
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, x, y, self_attention_mask, cross_attention_mask):
        residual = y.clone()
        y = self.self_attention(query=y, key=y, value=y, mask=self_attention_mask)
        y = self.dropout1(y)
        y = self.layer_norm1(y + residual)

        residual = y.clone()
        y = self.cross_attention(query=y, key=x, value=x, mask=cross_attention_mask)
        y = self.dropout2(y)
        y = self.layer_norm2(y + residual)

        residual = y.clone()
        y = self.ffn(y)
        y = self.dropout3(y)
        y = self.layer_norm3(y + residual)
        return y

class SequentialDecoder(nn.Sequential):
    def forward(self, *inputs):
        x, y, self_attention_mask, cross_attention_mask = inputs
        for module in self._modules.values():
            y = module(x, y, self_attention_mask, cross_attention_mask)
        return y

class Decoder(nn.Module):
    def __init__(self,
                 d_model,
                 ffn_hidden,
                 num_heads,
                 drop_prob,
                 num_layers,
                 max_sequence_length,
                 language_to_index,
                 tokenizer,
                 START_TOKEN,
                 END_TOKEN,
                 PADDING_TOKEN):
        super().__init__()
        self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, language_to_index, tokenizer, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.layers = SequentialDecoder(*[DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob) for _ in range(num_layers)])

    def forward(self, x, y, self_attention_mask, cross_attention_mask, start_token, end_token):
        y = self.sentence_embedding(y, start_token, end_token)
        y = self.layers(x, y, self_attention_mask, cross_attention_mask)
        return y


class Transformer(nn.Module):
    def __init__(self,
                d_model,
                ffn_hidden,
                num_heads,
                drop_prob,
                num_layers,
                max_sequence_length,
                summary_vocab_size,
                article_to_index,
                summary_to_index,
                tokenizer,
                START_TOKEN,
                END_TOKEN,
                PADDING_TOKEN
                ):
        super().__init__()
        self.encoder = Encoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_sequence_length, article_to_index, tokenizer, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_sequence_length, summary_to_index, tokenizer, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.linear = nn.Linear(d_model, summary_vocab_size)
        self.device = get_device()

    def forward(self,
                x,
                y,
                encoder_self_attention_mask=None,
                decoder_self_attention_mask=None,
                decoder_cross_attention_mask=None,
                enc_start_token=False,
                enc_end_token=False,
                dec_start_token=True,
                dec_end_token=False):
        x = self.encoder(x, encoder_self_attention_mask, start_token=enc_start_token, end_token=enc_end_token)
        out = self.decoder(x, y, decoder_self_attention_mask, decoder_cross_attention_mask, start_token=dec_start_token, end_token=dec_end_token)
        # print("Encoding and decoding complete")
        out = self.linear(out)
        # print("Vocab FFN complete")
        return out

In [None]:
from torch import nn
import datasets as ds
import numpy as np
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import torch
import gc
import numpy as np

START_TOKEN = "<SOS>"
END_TOKEN = "<EOS>"
PADDING_TOKEN = "<PAD>"
UNKNOWN_TOKEN = "<UNK>"

num_epochs = 10

d_model = 512
batch_size = 10
ffn_hidden = 1024
num_heads = 2
drop_prob = 0.1
num_layers = 1
max_sequence_length = 2000

TRAINING_DATASET_SIZE = 20000

full_dataset = ds.load_dataset('cnn_dailymail', '3.0.0')
small_training_set = full_dataset['train'].select(range(TRAINING_DATASET_SIZE))
# small_training_set = full_dataset['train']
dataset = {}
dataset['train'] = small_training_set


spacy_en = load_tokenizers()
article_vocab, summary_vocab = load_vocab(spacy_en, full_dataset)
summary_vocab_size = len(summary_vocab)

PERCENTILE = 97
article_lengths = [len(tokenize(article['article'], spacy_en)) for article in dataset['train']]
percentile = np.percentile(article_lengths, PERCENTILE)
print( f"{PERCENTILE}th percentile length: {percentile}" )

Downloading data:   0%|          | 0.00/313M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/304M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/155M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/30.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

Loading spacy...
Loading Vocabs...
Finished.
Vocabulary sizes:
Article Vocab Len: 20000
Highlight Vocab Len: 20000
97th percentile length: 1562.0


In [None]:
def create_masks(article_batch, summary_batch):
    num_sentences = len(article_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_sentences):
      article_sentence_length, summary_sentence_length = len(article_batch[idx]), len(summary_batch[idx])
      article_chars_to_padding_mask = np.arange(article_sentence_length + 1, max_sequence_length) # Account for end token
      summary_chars_to_padding_mask = np.arange(summary_sentence_length + 1, max_sequence_length) # Account for end token
      encoder_padding_mask[idx, :, article_chars_to_padding_mask] = True
      encoder_padding_mask[idx, article_chars_to_padding_mask, :] = True
      decoder_padding_mask_self_attention[idx, :, summary_chars_to_padding_mask] = True
      decoder_padding_mask_self_attention[idx, summary_chars_to_padding_mask, :] = True
      decoder_padding_mask_cross_attention[idx, :, article_chars_to_padding_mask] = True
      decoder_padding_mask_cross_attention[idx, summary_chars_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

class CNNDMDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]['article'], self.dataset[idx]['highlights']

In [None]:
print("Initialization...")
transformer = Transformer(d_model,
                          ffn_hidden,
                          num_heads,
                          drop_prob,
                          num_layers,
                          max_sequence_length,
                          summary_vocab_size,
                          article_vocab,
                          summary_vocab,
                          spacy_en,
                          START_TOKEN,
                          END_TOKEN,
                          PADDING_TOKEN)
transformer.train()
transformer.to(get_device())
batcheable_dataset = CNNDMDataset(dataset['train'])
train_loader = DataLoader(batcheable_dataset, batch_size)
iterator = iter(train_loader)
criterian = nn.CrossEntropyLoss(ignore_index=summary_vocab[PADDING_TOKEN], reduction='none')
for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)
optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)

print("Beginning training...")
for epoch in range(num_epochs):
    # print(f"Epoch: {epoch}")
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        print(f"Epoch: {epoch + 1}, Batch Num: {batch_num + 1} / {TRAINING_DATASET_SIZE / batch_size}")
        transformer.train()
        article_batch, summary_batch = batch
        # print("Creating masks")
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(article_batch, summary_batch)
        optim.zero_grad()
        # print("Forward pass")
        summary_predictions = transformer(article_batch,
                                        summary_batch,
                                        encoder_self_attention_mask.to(get_device()),
                                        decoder_self_attention_mask.to(get_device()),
                                        decoder_cross_attention_mask.to(get_device()),
                                        enc_start_token=False,
                                        enc_end_token=False,
                                        dec_start_token=True,
                                        dec_end_token=True)
        # print("Get labels")
        labels = transformer.decoder.sentence_embedding.batch_tokenize(summary_batch, start_token=False, end_token=True)
        # print("Calculate loss")
        loss = criterian(
            summary_predictions.view(-1, summary_vocab_size).to(get_device()),
            labels.view(-1).to(get_device())
        ).to(get_device())
        valid_indicies = torch.where(labels.view(-1) == summary_vocab[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        print(f"Loss: {loss}")
        # print("Backwards pass")
        loss.backward()
        optim.step()
    gc.collect()
print("Training complete!")
torch.save(transformer.state_dict(), "transformer_model.pt")
print("Model successfully saved")

Initialization...
Beginning training...
Epoch: 1, Batch Num: 1 / 2000.0
Loss: 9.906740188598633
Epoch: 1, Batch Num: 2 / 2000.0
Loss: 9.68807315826416
Epoch: 1, Batch Num: 3 / 2000.0
Loss: 9.504085540771484
Epoch: 1, Batch Num: 4 / 2000.0


KeyboardInterrupt: 

In [None]:
# Loading model
transformer = Transformer(d_model,
                          ffn_hidden,
                          num_heads,
                          drop_prob,
                          num_layers,
                          max_sequence_length,
                          summary_vocab_size,
                          article_vocab,
                          summary_vocab,
                          spacy_en,
                          START_TOKEN,
                          END_TOKEN,
                          PADDING_TOKEN)
transformer.load_state_dict(torch.load("transformer_model.pt"))
transformer.eval()
transformer.to(get_device())

index_to_summary = summary_vocab.get_itos()

# Inference
def summarize(article):
  with torch.no_grad():
    article = (article,)
    summary = (START_TOKEN,)

    for word_counter in range(max_sequence_length):
      encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(article, summary)
      predictions = transformer(article,
                                summary,
                                encoder_self_attention_mask.to(get_device()),
                                decoder_self_attention_mask.to(get_device()),
                                decoder_cross_attention_mask.to(get_device()),
                                enc_start_token=False,
                                enc_end_token=False,
                                dec_start_token=True,
                                dec_end_token=False)
      next_token_prob_distribution = predictions[0][word_counter]
      # print(next_token_prob_distribution)
      next_token_index = torch.argmax(next_token_prob_distribution).item()
      # print(next_token_index)
      next_token = index_to_summary[next_token_index]
      summary = (summary[0] + next_token, )
      if next_token == END_TOKEN:
        break
  return summary[0]

In [None]:
# Prepare test data
test_articles = full_dataset['test'].select(range(5))

summaries = [summarize(article['article']) for article in test_articles]
print("Train summary")
print(summarize(dataset['train'][0]['article']))

print("Test summary")
print(summarize(test_articles[0]['article']))


Train summary
<SOS><UNK><UNK><UNK>istheinininininin.in.in.....................................................................................................................................................................................................................................................................................................................................................................................<UNK><UNK><UNK>........<UNK><UNK><UNK><UNK><UNK><UNK>...ininin<UNK><UNK>....inin<UNK><UNK><UNK><UNK>inininininin.....<UNK><UNK><UNK><UNK><UNK>in.....ininininin...ininininin..........inin......ininin........ininininin..........inininininininin..<UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK>inininininininininin<UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><UNK><

In [1]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
import datasets

# Load the CNN/DailyMail dataset
cnn_dm_dataset = datasets.load_dataset('cnn_dailymail', '3.0.0')

# Initialize the LED tokenizer and model
tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-base')

# Function to summarize a batch of articles
def summarize(article):
    inputs = tokenizer.encode("summarize: " + article, return_tensors="pt", max_length=2000, truncation=True)
    outputs = model.generate(inputs,
                            max_length=150,
                            min_length=50)
    print(tokenizer.decode(outputs[0]))

# Summarize batch of articles
num_articles = 5
for x in range(num_articles):
    print("Article")
    print(cnn_dm_dataset['test'][x]['article'])
    print("Summary")
    summarize(cnn_dm_dataset['test'][x]['article'])

ModuleNotFoundError: No module named 'datasets'