<a href="https://colab.research.google.com/github/TejasGupta-27/TextSummary-TransformerScratch-/blob/main/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
!unzip bbc-fulltext.zip

Archive:  bbc-fulltext.zip
   creating: bbc/
   creating: bbc/entertainment/
  inflating: bbc/entertainment/289.txt  
  inflating: bbc/entertainment/262.txt  
  inflating: bbc/entertainment/276.txt  
  inflating: bbc/entertainment/060.txt  
  inflating: bbc/entertainment/074.txt  
  inflating: bbc/entertainment/048.txt  
  inflating: bbc/entertainment/114.txt  
  inflating: bbc/entertainment/100.txt  
  inflating: bbc/entertainment/128.txt  
  inflating: bbc/entertainment/316.txt  
  inflating: bbc/entertainment/302.txt  
  inflating: bbc/entertainment/303.txt  
  inflating: bbc/entertainment/317.txt  
  inflating: bbc/entertainment/129.txt  
  inflating: bbc/entertainment/101.txt  
  inflating: bbc/entertainment/115.txt  
  inflating: bbc/entertainment/049.txt  
  inflating: bbc/entertainment/075.txt  
  inflating: bbc/entertainment/061.txt  
  inflating: bbc/entertainment/277.txt  
  inflating: bbc/entertainment/263.txt  
  inflating: bbc/entertainment/288.txt  
  inflating: bbc/ente

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import math
import re
from collections import Counter

class Tokenizer:
    def __init__(self, vocab_size=50000):
        self.vocab_size = vocab_size
        self.word_to_idx = {}
        self.idx_to_word = {}
        self.vocab_freq = Counter()

        # Special tokens
        self.pad_token = '<pad>'
        self.unk_token = '<unk>'
        self.bos_token = '<bos>'
        self.eos_token = '<eos>'

        # Add special tokens to vocabulary
        for token in [self.pad_token, self.unk_token, self.bos_token, self.eos_token]:
            self.add_token(token)

    def preprocess_text(self, text):
        """Basic text preprocessing"""
        text = text.lower()
        text = re.sub(r'[^\w\s]', ' ', text)
        return text.split()

    def add_token(self, token):
        if token not in self.word_to_idx:
            idx = len(self.word_to_idx)
            self.word_to_idx[token] = idx
            self.idx_to_word[idx] = token

    def build_vocab(self, texts):
        """Build vocabulary from list of texts"""
        for text in texts:
            words = self.preprocess_text(text)
            self.vocab_freq.update(words)

        # Sort by frequency and take top vocab_size words
        most_common = self.vocab_freq.most_common(self.vocab_size - 4)  # -4 for special tokens
        for word, _ in most_common:
            self.add_token(word)

    def encode(self, text, max_length=None, padding=True):
        words = self.preprocess_text(text)
        if max_length is not None:
            words = words[:max_length-2]  # -2 for BOS and EOS tokens

        # Convert words to indices
        indices = [self.word_to_idx.get(word, self.word_to_idx[self.unk_token]) for word in words]
        indices = [self.word_to_idx[self.bos_token]] + indices + [self.word_to_idx[self.eos_token]]

        if padding and max_length is not None:
            indices += [self.word_to_idx[self.pad_token]] * (max_length - len(indices))

        return torch.tensor(indices)

    def decode(self, indices, skip_special_tokens=True):
        words = []
        for idx in indices:
            word = self.idx_to_word[idx.item()]
            if skip_special_tokens and word in [self.pad_token, self.bos_token, self.eos_token]:
                continue
            words.append(word)
        return ' '.join(words)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def attention(self, Q, K, V, mask=None):
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention = F.softmax(scores, dim=-1)
        return torch.matmul(attention, V)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # Linear projections and reshape
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Apply attention
        output = self.attention(Q, K, V, mask)

        # Reshape and apply final linear projection
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(output)

class NewsDataset(Dataset):
    def __init__(self, data_dir, tokenizer, max_content_length=512, max_summary_length=128, max_headline_length=64):
        self.data_dir = Path(data_dir)
        self.tokenizer = tokenizer
        self.max_content_length = max_content_length
        self.max_summary_length = max_summary_length
        self.max_headline_length = max_headline_length

        # Collect all files and build vocabulary
        self.samples = []
        self.categories = set()
        all_texts = []

        for category_dir in self.data_dir.iterdir():
            if category_dir.is_dir():
                category = category_dir.name
                self.categories.add(category)

                for file_path in category_dir.glob('*.txt'):
                    with open(file_path, 'r', encoding='utf-8') as f:
                        text = f.read()
                    all_texts.append(text)
                    self.samples.append((file_path, category))

        # Build vocabulary
        self.tokenizer.build_vocab(all_texts)

        self.category_to_idx = {cat: idx for idx, cat in enumerate(sorted(self.categories))}
        self.num_categories = len(self.categories)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file_path, category = self.samples[idx]

        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()

        # Split text into headline and content
        lines = text.strip().split('\n')
        headline = lines[0].strip()
        content = '\n'.join(lines[1:]).strip()

        # Create summary (first 3 sentences)
        sentences = content.split('.')[:3]
        summary = '. '.join(sentences) + '.'

        # Tokenize
        content_tokens = self.tokenizer.encode(content, max_length=self.max_content_length)
        summary_tokens = self.tokenizer.encode(summary, max_length=self.max_summary_length)
        headline_tokens = self.tokenizer.encode(headline, max_length=self.max_headline_length)

        return {
            'content': content_tokens,
            'summary': summary_tokens,
            'headline': headline_tokens,
            'category': torch.tensor(self.category_to_idx[category]),
            'raw_content': content,
            'raw_headline': headline,
            'raw_summary': summary
        }

In [7]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self attention
        attn_output = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))

        # Feed forward
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.cross_attention = MultiHeadAttention(d_model, num_heads)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Self attention
        self_attn_output = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn_output))

        # Cross attention
        cross_attn_output = self.cross_attention(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(cross_attn_output))

        # Feed forward
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

class MultiTaskNewsProcessor(nn.Module):
    def __init__(self, vocab_size, num_categories, d_model=512, num_heads=8,
                 num_encoder_layers=6, num_decoder_layers=6, d_ff=2048, dropout=0.1):
        super().__init__()

        # Token embedding and positional encoding
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)

        # Encoder
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_encoder_layers)
        ])

        # Decoders for summary and headline
        self.summary_decoder_layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ])

        self.headline_decoder_layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ])

        # Output projections
        self.summary_projection = nn.Linear(d_model, vocab_size)
        self.headline_projection = nn.Linear(d_model, vocab_size)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, num_categories)
        )

        self.dropout = nn.Dropout(dropout)

    def encode(self, src, src_mask=None):
        # Embed and encode input
        x = self.dropout(self.positional_encoding(self.token_embedding(src)))

        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x, src_mask)

        return x

    def decode_summary(self, tgt, memory, src_mask=None, tgt_mask=None):
        x = self.dropout(self.positional_encoding(self.token_embedding(tgt)))

        for decoder_layer in self.summary_decoder_layers:
            x = decoder_layer(x, memory, src_mask, tgt_mask)

        return self.summary_projection(x)

    def decode_headline(self, tgt, memory, src_mask=None, tgt_mask=None):
        x = self.dropout(self.positional_encoding(self.token_embedding(tgt)))

        for decoder_layer in self.headline_decoder_layers:
            x = decoder_layer(x, memory, src_mask, tgt_mask)

        return self.headline_projection(x)

    def forward(self, src, summary_tgt, headline_tgt, src_mask=None, summary_mask=None, headline_mask=None):
        # Encode input
        encoder_output = self.encode(src, src_mask)

        # Classification
        pooled = encoder_output.mean(dim=1)
        classification_logits = self.classifier(pooled)

        # Generate summary and headline
        summary_output = self.decode_summary(summary_tgt, encoder_output, src_mask, summary_mask)
        headline_output = self.decode_headline(headline_tgt, encoder_output, src_mask, headline_mask)

        return classification_logits, summary_output, headline_output

In [8]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
from pathlib import Path
import time
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
import random

class TrainingConfig:
    def __init__(self):
        # Model parameters
        self.vocab_size = 50000
        self.d_model = 512
        self.num_heads = 8
        self.num_encoder_layers = 6
        self.num_decoder_layers = 6
        self.d_ff = 2048
        self.dropout = 0.1

        # Training parameters
        self.batch_size = 16
        self.num_epochs = 20
        self.learning_rate = 0.0001
        self.warmup_steps = 4000
        self.max_grad_norm = 1.0
        self.label_smoothing = 0.1

        # Data parameters
        self.max_content_length = 512
        self.max_summary_length = 128
        self.max_headline_length = 64

        # Paths
        self.save_dir = Path("checkpoints")
        self.save_dir.mkdir(exist_ok=True)

def create_masks(src, tgt):
    # Source padding mask
    src_padding_mask = (src != 0).unsqueeze(1).unsqueeze(2)

    # Target padding mask
    tgt_padding_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)

    # Target future mask (for causal attention)
    seq_length = tgt.size(1)
    future_mask = torch.triu(torch.ones((seq_length, seq_length)), diagonal=1).bool()
    future_mask = future_mask.unsqueeze(0)

    tgt_mask = tgt_padding_mask & ~future_mask.to(tgt.device)

    return src_padding_mask, tgt_mask

class NoamScheduler:
    def __init__(self, optimizer, d_model, warmup_steps):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.current_step = 0

    def step(self):
        self.current_step += 1
        rate = self.get_rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate

    def get_rate(self):
        return self.d_model ** (-0.5) * min(self.current_step ** (-0.5),
                                          self.current_step * self.warmup_steps ** (-1.5))

class Trainer:
    def __init__(self, model, config, device):
        self.model = model
        self.config = config
        self.device = device

        # Initialize optimizer and scheduler
        self.optimizer = optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
        self.scheduler = NoamScheduler(self.optimizer, config.d_model, config.warmup_steps)

        # Loss functions
        self.classification_criterion = nn.CrossEntropyLoss()
        self.generation_criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=config.label_smoothing)

        # Metrics tracking
        self.best_val_loss = float('inf')
        self.patience = 0
        self.max_patience = 3

    def save_checkpoint(self, epoch, val_loss):
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_loss': val_loss,
            'config': self.config
        }
        path = self.config.save_dir / f'checkpoint_epoch_{epoch}.pt'
        torch.save(checkpoint, path)

    def load_checkpoint(self, path):
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch']

    def train_epoch(self, train_loader):
        self.model.train()
        total_loss = 0
        total_class_acc = 0
        num_batches = 0

        progress_bar = tqdm(train_loader, desc="Training")

        for batch in progress_bar:
            # Move batch to device
            src = batch['content'].to(self.device)
            summary_tgt = batch['summary'].to(self.device)
            headline_tgt = batch['headline'].to(self.device)
            category = batch['category'].to(self.device)

            # Create masks
            src_mask, summary_mask = create_masks(src, summary_tgt[:, :-1])
            _, headline_mask = create_masks(src, headline_tgt[:, :-1])

            # Forward pass
            class_logits, summary_logits, headline_logits = self.model(
                src,
                summary_tgt[:, :-1],
                headline_tgt[:, :-1],
                src_mask,
                summary_mask,
                headline_mask
            )

            # Calculate losses
            classification_loss = self.classification_criterion(class_logits, category)

            summary_loss = self.generation_criterion(
                summary_logits.view(-1, summary_logits.size(-1)),
                summary_tgt[:, 1:].contiguous().view(-1)
            )

            headline_loss = self.generation_criterion(
                headline_logits.view(-1, headline_logits.size(-1)),
                headline_tgt[:, 1:].contiguous().view(-1)
            )

            # Combined loss
            loss = classification_loss + summary_loss + headline_loss

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
            self.optimizer.step()
            self.scheduler.step()

            # Calculate metrics
            class_preds = torch.argmax(class_logits, dim=1)
            class_acc = (class_preds == category).float().mean()

            # Update progress bar
            total_loss += loss.item()
            total_class_acc += class_acc.item()
            num_batches += 1

            progress_bar.set_postfix({
                'loss': total_loss / num_batches,
                'class_acc': total_class_acc / num_batches
            })

        return total_loss / num_batches, total_class_acc / num_batches

    @torch.no_grad()
    def evaluate(self, val_loader):
        self.model.eval()
        total_loss = 0
        total_class_acc = 0
        num_batches = 0

        progress_bar = tqdm(val_loader, desc="Evaluating")

        for batch in progress_bar:
            # Move batch to device
            src = batch['content'].to(self.device)
            summary_tgt = batch['summary'].to(self.device)
            headline_tgt = batch['headline'].to(self.device)
            category = batch['category'].to(self.device)

            # Create masks
            src_mask, summary_mask = create_masks(src, summary_tgt[:, :-1])
            _, headline_mask = create_masks(src, headline_tgt[:, :-1])

            # Forward pass
            class_logits, summary_logits, headline_logits = self.model(
                src,
                summary_tgt[:, :-1],
                headline_tgt[:, :-1],
                src_mask,
                summary_mask,
                headline_mask
            )

            # Calculate losses
            classification_loss = self.classification_criterion(class_logits, category)

            summary_loss = self.generation_criterion(
                summary_logits.view(-1, summary_logits.size(-1)),
                summary_tgt[:, 1:].contiguous().view(-1)
            )

            headline_loss = self.generation_criterion(
                headline_logits.view(-1, headline_logits.size(-1)),
                headline_tgt[:, 1:].contiguous().view(-1)
            )

            loss = classification_loss + summary_loss + headline_loss

            # Calculate metrics
            class_preds = torch.argmax(class_logits, dim=1)
            class_acc = (class_preds == category).float().mean()

            total_loss += loss.item()
            total_class_acc += class_acc.item()
            num_batches += 1

            progress_bar.set_postfix({
                'val_loss': total_loss / num_batches,
                'val_class_acc': total_class_acc / num_batches
            })

        return total_loss / num_batches, total_class_acc / num_batches

    def train(self, train_loader, val_loader):
        print("Starting training...")

        for epoch in range(self.config.num_epochs):
            print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}")

            # Training phase
            train_loss, train_acc = self.train_epoch(train_loader)
            print(f"Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.4f}")

            # Validation phase
            val_loss, val_acc = self.evaluate(val_loader)
            print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")

            # Save checkpoint if best model
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.save_checkpoint(epoch + 1, val_loss)
                self.patience = 0
            else:
                self.patience += 1

            # Early stopping
            if self.patience >= self.max_patience:
                print("Early stopping triggered")
                break

def main():
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)

    # Initialize config
    config = TrainingConfig()

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create tokenizer and datasets
    tokenizer = Tokenizer(vocab_size=config.vocab_size)

    # Load and split data
    dataset = NewsDataset(
        "/content/bbc",
        tokenizer,
        max_content_length=config.max_content_length,
        max_summary_length=config.max_summary_length,
        max_headline_length=config.max_headline_length
    )

    # Split dataset
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=4
    )

    # Initialize model
    model = MultiTaskNewsProcessor(
        vocab_size=config.vocab_size,
        num_categories=dataset.num_categories,
        d_model=config.d_model,
        num_heads=config.num_heads,
        num_encoder_layers=config.num_encoder_layers,
        num_decoder_layers=config.num_decoder_layers,
        d_ff=config.d_ff,
        dropout=config.dropout
    ).to(device)

    # Initialize trainer
    trainer = Trainer(model, config, device)

    # Start training
    trainer.train(train_loader, val_loader)

if __name__ == "__main__":
    main()

Using device: cuda




Starting training...

Epoch 1/20


Training: 100%|██████████| 112/112 [01:29<00:00,  1.25it/s, loss=21.6, class_acc=0.254]


Training Loss: 21.5841, Training Accuracy: 0.2539


Evaluating: 100%|██████████| 28/28 [00:07<00:00,  3.55it/s, val_loss=20.1, val_class_acc=0.312]


Validation Loss: 20.1169, Validation Accuracy: 0.3123

Epoch 2/20


Training: 100%|██████████| 112/112 [01:31<00:00,  1.23it/s, loss=18.9, class_acc=0.331]


Training Loss: 18.9095, Training Accuracy: 0.3309


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.47it/s, val_loss=17.8, val_class_acc=0.489]


Validation Loss: 17.8146, Validation Accuracy: 0.4887

Epoch 3/20


Training: 100%|██████████| 112/112 [01:33<00:00,  1.20it/s, loss=16.7, class_acc=0.508]


Training Loss: 16.6616, Training Accuracy: 0.5084


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.36it/s, val_loss=16.8, val_class_acc=0.507]


Validation Loss: 16.7862, Validation Accuracy: 0.5070

Epoch 4/20


Training: 100%|██████████| 112/112 [01:33<00:00,  1.19it/s, loss=15.6, class_acc=0.669]


Training Loss: 15.6044, Training Accuracy: 0.6685


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.36it/s, val_loss=16.2, val_class_acc=0.678]


Validation Loss: 16.2374, Validation Accuracy: 0.6775

Epoch 5/20


Training: 100%|██████████| 112/112 [01:34<00:00,  1.18it/s, loss=14.9, class_acc=0.791]


Training Loss: 14.8662, Training Accuracy: 0.7913


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.31it/s, val_loss=15.4, val_class_acc=0.897]


Validation Loss: 15.3652, Validation Accuracy: 0.8968

Epoch 6/20


Training: 100%|██████████| 112/112 [01:34<00:00,  1.19it/s, loss=14.1, class_acc=0.879]


Training Loss: 14.0942, Training Accuracy: 0.8795


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.36it/s, val_loss=15.1, val_class_acc=0.916]


Validation Loss: 15.0780, Validation Accuracy: 0.9164

Epoch 7/20


Training: 100%|██████████| 112/112 [01:34<00:00,  1.18it/s, loss=13.3, class_acc=0.934]


Training Loss: 13.2554, Training Accuracy: 0.9336


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.32it/s, val_loss=14.7, val_class_acc=0.947]


Validation Loss: 14.7123, Validation Accuracy: 0.9471

Epoch 8/20


Training: 100%|██████████| 112/112 [01:35<00:00,  1.18it/s, loss=12.4, class_acc=0.955]


Training Loss: 12.3914, Training Accuracy: 0.9548


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.30it/s, val_loss=14.6, val_class_acc=0.945]


Validation Loss: 14.5816, Validation Accuracy: 0.9454

Epoch 9/20


Training: 100%|██████████| 112/112 [01:34<00:00,  1.18it/s, loss=11.6, class_acc=0.973]


Training Loss: 11.5838, Training Accuracy: 0.9727


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.34it/s, val_loss=14.4, val_class_acc=0.894]


Validation Loss: 14.3791, Validation Accuracy: 0.8941

Epoch 10/20


Training: 100%|██████████| 112/112 [01:34<00:00,  1.18it/s, loss=10.7, class_acc=0.989]


Training Loss: 10.7275, Training Accuracy: 0.9894


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.29it/s, val_loss=14, val_class_acc=0.936]


Validation Loss: 14.0372, Validation Accuracy: 0.9365

Epoch 11/20


Training: 100%|██████████| 112/112 [01:34<00:00,  1.19it/s, loss=9.89, class_acc=0.991]


Training Loss: 9.8894, Training Accuracy: 0.9905


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.37it/s, val_loss=13.6, val_class_acc=0.944]


Validation Loss: 13.6334, Validation Accuracy: 0.9437

Epoch 12/20


Training: 100%|██████████| 112/112 [01:34<00:00,  1.18it/s, loss=9.02, class_acc=0.993]


Training Loss: 9.0222, Training Accuracy: 0.9927


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.33it/s, val_loss=13.4, val_class_acc=0.939]


Validation Loss: 13.3923, Validation Accuracy: 0.9387

Epoch 13/20


Training: 100%|██████████| 112/112 [01:34<00:00,  1.18it/s, loss=8.17, class_acc=0.993]


Training Loss: 8.1707, Training Accuracy: 0.9933


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.30it/s, val_loss=13.1, val_class_acc=0.948]


Validation Loss: 13.0873, Validation Accuracy: 0.9481

Epoch 14/20


Training: 100%|██████████| 112/112 [01:34<00:00,  1.18it/s, loss=7.33, class_acc=0.995]


Training Loss: 7.3252, Training Accuracy: 0.9950


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.37it/s, val_loss=12.9, val_class_acc=0.923]


Validation Loss: 12.9043, Validation Accuracy: 0.9231

Epoch 15/20


Training: 100%|██████████| 112/112 [01:34<00:00,  1.18it/s, loss=6.45, class_acc=0.999]


Training Loss: 6.4454, Training Accuracy: 0.9989


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.34it/s, val_loss=12.4, val_class_acc=0.944]


Validation Loss: 12.4455, Validation Accuracy: 0.9437

Epoch 16/20


Training: 100%|██████████| 112/112 [01:33<00:00,  1.19it/s, loss=5.59, class_acc=1]


Training Loss: 5.5871, Training Accuracy: 1.0000


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.39it/s, val_loss=11.9, val_class_acc=0.962]


Validation Loss: 11.9117, Validation Accuracy: 0.9615

Epoch 17/20


Training: 100%|██████████| 112/112 [01:33<00:00,  1.19it/s, loss=4.79, class_acc=0.999]


Training Loss: 4.7944, Training Accuracy: 0.9994


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.32it/s, val_loss=11.6, val_class_acc=0.939]


Validation Loss: 11.5875, Validation Accuracy: 0.9387

Epoch 18/20


Training: 100%|██████████| 112/112 [01:34<00:00,  1.19it/s, loss=4.25, class_acc=0.993]


Training Loss: 4.2516, Training Accuracy: 0.9933


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.34it/s, val_loss=11.4, val_class_acc=0.924]


Validation Loss: 11.4331, Validation Accuracy: 0.9241

Epoch 19/20


Training: 100%|██████████| 112/112 [01:34<00:00,  1.18it/s, loss=3.84, class_acc=0.999]


Training Loss: 3.8379, Training Accuracy: 0.9994


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.29it/s, val_loss=11, val_class_acc=0.95]


Validation Loss: 10.9938, Validation Accuracy: 0.9504

Epoch 20/20


Training: 100%|██████████| 112/112 [01:34<00:00,  1.18it/s, loss=3.59, class_acc=0.995]


Training Loss: 3.5914, Training Accuracy: 0.9950


Evaluating: 100%|██████████| 28/28 [00:08<00:00,  3.37it/s, val_loss=10.9, val_class_acc=0.939]


Validation Loss: 10.9176, Validation Accuracy: 0.9392


In [12]:
import torch

def test_model(model, tokenizer, device):
    """
    Test the model with a sample input and print the outputs for classification, summarization, and headline generation.
    """
    model.eval()  # Set the model to evaluation mode

    # Example input content
    sample_content = """
    Artificial Intelligence (AI) has become a transformative force across industries, enabling automation, enhancing decision-making,
    and creating opportunities for innovation. This article explores how AI is reshaping the technology landscape and its implications for the future.
    """

    # Tokenize the input content
    tokenized_content = tokenizer.encode(sample_content, max_length=512).tolist() # Convert the tensor to list of integers
    src_tensor = torch.tensor([tokenized_content]).to(device)  # Add batch dimension and move to device

    # Create a dummy target sequence for testing (start token)
    dummy_summary_tgt = torch.tensor([[tokenizer.word_to_idx[tokenizer.bos_token]]]).to(device)  # BOS token for summary

    dummy_headline_tgt = torch.tensor([[tokenizer.word_to_idx[tokenizer.bos_token]]]).to(device)   # BOS token for headline

    # Create masks
    src_mask = (src_tensor != 0).unsqueeze(1).unsqueeze(2)

    # Perform inference
    with torch.no_grad():
        class_logits, summary_logits, headline_logits = model(
            src_tensor,
            dummy_summary_tgt,
            dummy_headline_tgt,
            src_mask,
            None,
            None
        )

    # Decode outputs
    predicted_class = torch.argmax(class_logits, dim=1).item()
    predicted_summary_ids = torch.argmax(summary_logits, dim=-1).squeeze(0).tolist()
    predicted_headline_ids = torch.argmax(headline_logits, dim=-1).squeeze(0).tolist()

    # Convert IDs back to tokens
    predicted_summary = tokenizer.decode(predicted_summary_ids, skip_special_tokens=True)
    predicted_headline = tokenizer.decode(predicted_headline_ids, skip_special_tokens=True)

    # Display results
    print("Input Content:")
    print(sample_content)
    print("\nPredicted Category:")
    print(predicted_class)
    print("\nPredicted Summary:")
    print(predicted_summary)
    print("\nPredicted Headline:")
    print(predicted_headline)

if __name__ == "__main__":
    # Initialize model, tokenizer, and device
    config = TrainingConfig()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = Tokenizer(vocab_size=config.vocab_size)  # Replace with actual tokenizer initialization
    model = MultiTaskNewsProcessor(
        vocab_size=config.vocab_size,
        num_categories=5,  # Example: Adjust based on your dataset
        d_model=config.d_model,
        num_heads=config.num_heads,
        num_encoder_layers=config.num_encoder_layers,
        num_decoder_layers=config.num_decoder_layers,
        d_ff=config.d_ff,
        dropout=config.dropout
    ).to(device)

    # Load a pre-trained checkpoint if available
    checkpoint_path = config.save_dir / "/content/checkpoints/checkpoint_epoch_19.pt"  # Replace with your checkpoint path
    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        print("Checkpoint loaded successfully!")

    # Run the test
    test_model(model, tokenizer, device)


  checkpoint = torch.load(checkpoint_path)


Checkpoint loaded successfully!


AttributeError: 'int' object has no attribute 'item'