#Text summarizer using seq2seq model in pytorch

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:
class Encoder(nn.Module):
    def __init__(self,input_size,emb_dim,enc_hid_dim , dec_hid_dim,dropout):
        super().__init__()
        self.embedding=nn.Embedding(input_size,emb_dim)
        self.rnn=nn.GRU(emb_dim,enc_hid_dim,bidirectional=True)
        self.fc=nn.Linear(enc_hid_dim*2 , dec_hid_dim)
        self.dropout=nn.Dropout(dropout)

    def forward(self,src):
        embedded=self.dropout(self.embedding(src))

        outputs,hidden=self.rnn(embedded)
        hidden=torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]),
                                          dim=1)))
        return outputs,hidden

In [3]:
class Attention(nn.Module):
    def __init__(self, enc_hid_dim,dec_hid_dim):
        super().__init__()
        self.attn=nn.Linear((enc_hid_dim*2)+dec_hid_dim,dec_hid_dim)
        self.v=nn.Linear(dec_hid_dim,1,bias=False)

    def forward (self, hidden,encoder_outputs):
        batch_size=encoder_outputs.shape[1]
        src_len=encoder_outputs.shape[0]
        # Repeat decoder hidden state src_len times
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

        encoder_outputs=encoder_outputs.transpose(0, 1)

        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        # energy: [batch_size, src_len, dec_hid_dim]

        attention = self.v(energy).squeeze(2)
        # attention: [batch_size, src_len]

        return F.softmax(attention, dim=1)

In [52]:
class Decoder(nn.Module):

    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()

        self.output_dim = output_dim
        self.attention = attention

        self.embedding = nn.Embedding(output_dim, emb_dim)

        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)

        self.fc_out = nn.Linear((enc_hid_dim * 2) + emb_dim + dec_hid_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):

        #input = [batch size]
        #hidden = [batch size, dec_hid_dim]
        #encoder_outputs = [src_len, batch size, enc_hid_dim * 2]

        batch_size = input.shape[0] # Get batch size

        input = input.unsqueeze(0)

        #input = [1, batch size]

        embedded = self.dropout(self.embedding(input))

        #embedded = [1, batch size, emb_dim]

        a = self.attention(hidden, encoder_outputs)

        #a = [batch size, src_len]

        encoder_outputs = encoder_outputs.permute(1, 0, 2)

        #encoder_outputs = [batch size, src_len, enc_hid_dim * 2]

        weighted = torch.bmm(a.unsqueeze(1), encoder_outputs).squeeze(1)

        #weighted = [batch size, enc_hid_dim * 2]

        rnn_input = torch.cat((embedded, weighted.unsqueeze(0)), dim = 2)

        #rnn_input = [1, batch size, (enc_hid_dim * 2) + emb_dim]

        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))

        #output = [seq_len, batch size, dec_hid_dim]
        #hidden = [seq_len, batch size, dec_hid_dim]

        #seq_len is 1 for the decoder

        #output = [1, batch size, dec_hid_dim]
        #hidden = [1, batch size, dec_hid_dim]

        # Ensure tensors have 2 dimensions [batch_size, feature_dim] before concatenation for fc_out
        embedded = embedded.squeeze(0).view(batch_size, -1) # Shape: [batch size, emb_dim]
        output = output.squeeze(0).view(batch_size, -1)     # Shape: [batch size, dec_hid_dim]
        weighted = weighted.squeeze(0).view(batch_size, -1) # Shape: [batch size, enc_hid_dim * 2]


        # Print shapes for debugging
        print("Shape of output before concat:", output.shape)
        print("Shape of weighted before concat:", weighted.shape)
        print("Shape of embedded before concat:", embedded.shape)

        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))

        #prediction = [batch size, output_dim]

        return prediction, hidden.squeeze(0), a # Removed .squeeze(1)

In [5]:
class Seq2Seq(nn.Module):
    def __init__(self,encoder,decoder,device):
        super().__init__()
        self.encoder=encoder
        self.decoder=decoder
        self.device=device

    def forward(self,src,trg,teacher_forcing_ratio=0.5):
        # src shape: [src_len, batch_size]
        # trg shape: [trg_len, batch_size]

        batch_size=src.shape[1] # Get batch size from input tensor
        trg_len=trg.shape[0]
        trg_vocab_size=self.decoder.output_dim

        outputs=torch.zeros(trg_len,batch_size,trg_vocab_size).to(self.device)

        encoder_outputs,hidden=self.encoder(src)
        # encoder_outputs shape: [src_len, batch_size, enc_hid_dim * 2]
        # hidden shape: [batch_size, dec_hid_dim]

        # First input to decoder is <sos> token
        input = trg[0,:] # shape: [batch_size]

        for t in range(1, trg_len):
            output, hidden, _ = self.decoder(input, hidden, encoder_outputs)
            # output shape: [batch_size, output_dim]
            # hidden shape: [batch_size, dec_hid_dim]

            outputs[t] = output # Assign output to the correct slice

            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = trg[t] if teacher_force else top1

        return outputs

In [6]:
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import re
from collections import Counter
import pickle

In [7]:
class vocabullary:
    def __init__(self,freq_threshold=5):
        self.itos = {0:"<PAD>", 1:"<SOS>", 2:"<eos>", 3:"<unk>"}
        self.stoi = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    def build_vocabullary(self,sentence_list):
        frequencies = Counter()
        idx = 4

        for sentence in sentence_list:
            for word in self.tokenizer(sentence):
                frequencies[word] += 1

                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

In [8]:
class vocabullary:
    def __init__(self,freq_threshold=5):
        self.itos = {0:"<PAD>", 1:"<SOS>", 2:"<eos>", 3:"<unk>"}
        self.stoi = {"<pad>": 0, "<SOS>": 1, "<eos>": 2, "<unk>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    def build_vocabullary(self,sentence_list):
        frequencies = Counter()
        idx = 4

        for sentence in sentence_list:
            for word in self.tokenizer(sentence):
                frequencies[word] += 1

                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def tokenizer(self, text):
        text = text.lower()
        text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
        return text.split()

    def numericalize(self, text):
        tokenized_text = self.tokenizer(text)
        return [self.stoi[token] if token in self.stoi else self.stoi["<unk>"]
                for token in tokenized_text]

    def save(self, filepath):
        with open(filepath, 'wb') as f:
            pickle.dump({'stoi': self.stoi, 'itos': self.itos, 'freq_threshold': self.freq_threshold}, f)

    @staticmethod
    def load(filepath):
        with open(filepath, 'rb') as f:
            state = pickle.load(f)
            vocab = vocabullary(state['freq_threshold'])
            vocab.stoi = state['stoi']
            vocab.itos = state['itos']
            return vocab

In [9]:
class SummarizationDataset(Dataset):
    def __init__(self, csv_file, text_column, summary_column,
                 text_vocab=None, summary_vocab=None,
                 max_text_len=400, max_summary_len=100,
                 freq_threshold=5):


        # Use header=None because the file doesn't have headers
        self.df=pd.read_csv(csv_file, sep='\t', header=None)
        # The first column is the index, the second is column 0

        self.df = self.df.rename(columns={0: 'summary'})
        self.text_column=text_column
        self.summary_column=summary_column
        self.max_text_len=max_text_len
        self.max_summary_len=max_summary_len


        # Build or use provided vocabularies
        if text_vocab is None:
            self.text_vocab = vocabullary(freq_threshold)
            # Accessing the index for text and converting to string
            self.text_vocab.build_vocabullary(self.df.index.astype(str).tolist())
        else:
            self.text_vocab = text_vocab

        if summary_vocab is None:
            self.summary_vocab = vocabullary(freq_threshold)
            # Accessing the renamed column 'summary' for summary and converting to string
            self.summary_vocab.build_vocabullary(self.df[self.summary_column].astype(str).tolist())
            # Explicitly add <SOS> and <eos> to summary_vocab if not present
            if "<SOS>" not in self.summary_vocab.stoi:
                sos_index = 1
                self.summary_vocab.stoi["<SOS>"] = sos_index
                self.summary_vocab.itos[sos_index] = "<SOS>"
            if "<eos>" not in self.summary_vocab.stoi:
                eos_index = 2
                self.summary_vocab.stoi["<eos>"] = eos_index
                self.summary_vocab.itos[eos_index] = "<eos>"
        else:
            self.summary_vocab = summary_vocab


    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):

        if "<SOS>" not in self.summary_vocab.stoi:
            sos_index = 1
            self.summary_vocab.stoi["<SOS>"] = sos_index
            self.summary_vocab.itos[sos_index] = "<SOS>"
        if "<eos>" not in self.summary_vocab.stoi:
            eos_index = 2
            self.summary_vocab.stoi["<eos>"] = eos_index
            self.summary_vocab.itos[eos_index] = "<eos>"


        # Accessing text from the index using iloc and converting to string
        text=str(self.df.index[idx])

        summary=str(self.df.iloc[idx][self.summary_column])


        text_numericalized = self.text_vocab.numericalize(text)[:self.max_text_len]
        summary_numericalized = self.summary_vocab.numericalize(summary)[:self.max_summary_len]


        # Add <sos> and <eos> tokens to summary
        summary_numericalized = [self.summary_vocab.stoi["<SOS>"]] + \
                                 summary_numericalized + \
                                 [self.summary_vocab.stoi["<eos>"]]

        return torch.tensor(text_numericalized), torch.tensor(summary_numericalized)

In [10]:
class MyCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        texts = [item[0] for item in batch]
        summaries = [item[1] for item in batch]

        # Pad sequences
        texts_padded = pad_sequence(texts, batch_first=False, padding_value=self.pad_idx)
        summaries_padded = pad_sequence(summaries, batch_first=False, padding_value=self.pad_idx)

        return texts_padded, summaries_padded

def get_loader(csv_file, text_column, summary_column,
               batch_size=32, num_workers=4,
               text_vocab=None, summary_vocab=None,
               shuffle=True, pin_memory=True):
    pass

In [13]:
#create dataloader for training and validation
dataset=SummarizationDataset(
    csv_file='/content/eng-fra.txt',
    text_column=0,
    summary_column=1 # Use integer index for the second column (French translation as summary)
)

pad_idx=dataset.summary_vocab.stoi["<pad>"]

loader=DataLoader(
    dataset=dataset,
    batch_size=32,
    num_workers=0, # Setting num_workers to 0 to avoid pickling issues
    shuffle=True,
    pin_memory=True,
    collate_fn=MyCollate(pad_idx=pad_idx)
)

In [12]:
with open('/content/eng-fra.txt', 'r', encoding='utf-8') as f:
    for i in range(5): # Displaying the first 5 lines
        print(f.readline())

Go.	Va !

Run!	Cours !

Run!	Courez !

Wow!	Ça alors !

Fire!	Au feu !



In [11]:
import zipfile
import os

zip_file_path = '/content/data.zip'
extracted_file_path = '/content/eng-fra.txt'
file_to_extract = 'data/eng-fra.txt'

with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    # Check if the file to extract exists in the zip archive
    if file_to_extract in zip_ref.namelist():
        zip_ref.extract(file_to_extract, '/content/')

        os.rename(os.path.join('/content/', file_to_extract), extracted_file_path)
        print(f"Successfully extracted '{file_to_extract}' to '{extracted_file_path}'")
    else:
        print(f"File '{file_to_extract}' not found in the zip archive.")

Successfully extracted 'data/eng-fra.txt' to '/content/eng-fra.txt'


#Training

In [14]:
def train_epoch(model, iterator, optimizer, criterion, clip, device):
    model.train()
    epoch_loss = 0

    for i, (src, trg) in enumerate(iterator):
        src, trg = src.to(device), trg.to(device)

        optimizer.zero_grad()
        output = model(src, trg)

        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)

        trg = trg[1:].reshape(-1)

        loss = criterion(output, trg)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()

        if i % 100 == 0:
            # The loss here is the average loss over the batch
            print(f'  Batch {i}/{len(iterator)}, Loss: {loss.item():.4f}')

    return epoch_loss / len(iterator)

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import time
import math


def train_epoch(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0

    for i, (src, trg) in enumerate(iterator):
        src, trg = src.to(device), trg.to(device)

        optimizer.zero_grad()
        output = model(src, trg)


        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1).to(torch.long)

        loss = criterion(output, trg)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()

        if i % 100 == 0:
            print(f'  Batch {i}/{len(iterator)}, Loss: {loss.item():.4f}')

    return epoch_loss / len(iterator)

def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0

    with torch.no_grad():
        for i, (src, trg) in enumerate(iterator):
            src, trg = src.to(device), trg.to(device)

            # Turn off teacher forcing
            output = model(src, trg, teacher_forcing_ratio=0)

            output_dim = output.shape[-1]
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1).to(torch.long)

            loss = criterion(output, trg)
            epoch_loss += loss.item()

    return epoch_loss / len(iterator)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

# Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
INPUT_DIM = 10000
OUTPUT_DIM = 10000
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
BATCH_SIZE = 64
N_EPOCHS = 10
CLIP = 1
LEARNING_RATE = 0.001


print("Loading training data...")

train_loader = loader
train_dataset = dataset
val_loader = loader
print("Loading validation data...")

val_dataset = dataset


# Update dimensions based on actual vocabulary size
INPUT_DIM = len(train_dataset.text_vocab)
OUTPUT_DIM = len(train_dataset.summary_vocab)

print(f"Input vocabulary size: {INPUT_DIM}")
print(f"Output vocabulary size: {OUTPUT_DIM}")

# Save vocabularies
train_dataset.text_vocab.save('text_vocab.pkl')
train_dataset.summary_vocab.save('summary_vocab.pkl')

# Initialize model
attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)
model = Seq2Seq(enc, dec, device).to(device)

# Initialize weights
def init_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)

model.apply(init_weights)

# Count parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

# Optimizer and loss
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

PAD_IDX = train_dataset.summary_vocab.stoi['<pad>']
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# Training loop
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    start_time = time.time()

    train_loss = train_epoch(model, train_loader, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, val_loader, criterion)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # Learning rate scheduling
    scheduler.step(valid_loss)

    # Save best model
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'seq2seqsummarizer_model.pt')
        print('  [Saved Best Model]')

    # Save checkpoint
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'valid_loss': valid_loss,
    }
    torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pt')

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

print("Training complete!")

Loading training data...
Loading validation data...
Input vocabulary size: 4
Output vocabulary size: 8415
The model has 23,676,639 trainable parameters
  Batch 0/4246, Loss: 9.0379
  Batch 100/4246, Loss: 5.8359
  Batch 200/4246, Loss: 5.6561
  Batch 300/4246, Loss: 5.5926
  Batch 400/4246, Loss: 5.9086
  Batch 500/4246, Loss: 5.2205
  Batch 600/4246, Loss: 5.4735
  Batch 700/4246, Loss: 5.6010
  Batch 800/4246, Loss: 5.2774
  Batch 900/4246, Loss: 5.4433
  Batch 1000/4246, Loss: 5.2191
  Batch 1100/4246, Loss: 5.4535
  Batch 1200/4246, Loss: 4.9235
  Batch 1300/4246, Loss: 4.9366
  Batch 1400/4246, Loss: 5.4727
  Batch 1500/4246, Loss: 4.8283
  Batch 1600/4246, Loss: 4.8509
  Batch 1700/4246, Loss: 4.9275
  Batch 1800/4246, Loss: 5.2844
  Batch 1900/4246, Loss: 4.9880
  Batch 2000/4246, Loss: 5.2888
  Batch 2100/4246, Loss: 5.0611
  Batch 2200/4246, Loss: 5.0510
  Batch 2300/4246, Loss: 5.0996
  Batch 2400/4246, Loss: 5.3060
  Batch 2500/4246, Loss: 4.6702
  Batch 2600/4246, Loss: 5.0

#Inference

In [44]:
def translate_sentence(model, sentence, text_vocab, summary_vocab, device, max_length=100):

    model.eval()
    # Tokenize and numericalize
    tokens = text_vocab.numericalize(sentence)
    tokens = torch.LongTensor(tokens).unsqueeze(1).to(device)
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(tokens)


    # Start with <sos> token
    trg_indexes = [summary_vocab.stoi['<SOS>']]
    attentions = []

    for _ in range(max_length):
        trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(device)

        with torch.no_grad():
            output, hidden, attention = model.decoder(trg_tensor, hidden, encoder_outputs)

        attentions.append(attention)
        pred_token = output.argmax(1).item()
        trg_indexes.append(pred_token)

        if pred_token == summary_vocab.stoi['<eos>']:
            break

    # Convert indexes to words
    trg_tokens = [summary_vocab.itos[i] for i in trg_indexes]

    # Remove <sos> and <eos>
    return trg_tokens[1:-1] if trg_tokens[-1] == '<eos>' else trg_tokens[1:], attentions

def summarize_text(text, model, text_vocab, summary_vocab, device):

    summary_tokens, attentions = translate_sentence(
        model, text, text_vocab, summary_vocab, device
    )
    summary = ' '.join(summary_tokens)
    return summary

In [45]:
def load_model(model_path, text_vocab_path, summary_vocab_path, device):
    """
    Load trained model and vocabularies
    """
    # Load vocabularies
    text_vocab = vocabullary.load(text_vocab_path)
    summary_vocab = vocabullary.load(summary_vocab_path)

    # Model configuration (must match training)
    INPUT_DIM = len(text_vocab)
    OUTPUT_DIM = len(summary_vocab)
    ENC_EMB_DIM = 256
    DEC_EMB_DIM = 256
    ENC_HID_DIM = 512
    DEC_HID_DIM = 512
    ENC_DROPOUT = 0.5
    DEC_DROPOUT = 0.5

    # Initialize model
    attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
    enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
    dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)
    model = Seq2Seq(enc, dec, device).to(device)

    # Load weights
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    return model, text_vocab, summary_vocab

In [54]:
# Example usage
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load model
    model, text_vocab, summary_vocab = load_model(
        model_path='seq2seqsummarizer_model.pt',
        text_vocab_path='text_vocab.pkl',
        summary_vocab_path='summary_vocab.pkl',
        device=device
    )

    # Example text
    text = """
Artificial intelligence is transforming the world in numerous ways.
Machine learning algorithms are now being used in healthcare to diagnose diseases,
in finance to detect fraud, and in transportation to power self-driving cars.
The technology continues to advance rapidly with new breakthroughs happening regularly.
"""

    # Generate summary
    summary = summarize_text(text, model, text_vocab, summary_vocab, device)

    print("Original Text:")
    print(text)
    print("\nGenerated Summary:")
    print(summary)

    # Add a new text example
    new_text = """
The quick brown fox jumps over the lazy dog. This sentence is often used to test
typewriters or keyboards because it contains all the letters of the alphabet.
It's a classic pangram.
"""

    print("\n\nNew Text Example:")
    print(new_text)
    print("\nGenerated Summary for New Text:")
    new_summary = summarize_text(new_text, model, text_vocab, summary_vocab, device)
    print(new_summary)

    # Batch inference example
    texts = [
        "Your first document here...",
        "Your second document here...",
        "Your third document here..."
    ]

    print("\n\nBatch Summarization:")
    for i, text in enumerate(texts, 1):
        summary = summarize_text(text, model, text_vocab, summary_vocab, device)
        print(f"\n{i}. Summary: {summary}")

Shape of output before concat: torch.Size([1, 512])
Shape of weighted before concat: torch.Size([1, 1024])
Shape of embedded before concat: torch.Size([1, 256])
Shape of output before concat: torch.Size([1, 512])
Shape of weighted before concat: torch.Size([1, 1024])
Shape of embedded before concat: torch.Size([1, 256])
Shape of output before concat: torch.Size([1, 512])
Shape of weighted before concat: torch.Size([1, 1024])
Shape of embedded before concat: torch.Size([1, 256])
Shape of output before concat: torch.Size([1, 512])
Shape of weighted before concat: torch.Size([1, 1024])
Shape of embedded before concat: torch.Size([1, 256])
Original Text:

Artificial intelligence is transforming the world in numerous ways.
Machine learning algorithms are now being used in healthcare to diagnose diseases,
in finance to detect fraud, and in transportation to power self-driving cars.
The technology continues to advance rapidly with new breakthroughs happening regularly.


Generated Summary:
je