<a href="https://colab.research.google.com/github/Gummadirajulavamshi/NLP/blob/main/project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import math
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader


class Attention(nn.Module):
    """General attention mechanism described for HAN.
    Given encoder outputs H (batch, time, hidden), compute attention weighted sum.
    """

    def __init__(self, hidden_size: int):
        super().__init__()

        self.proj = nn.Linear(hidden_size, hidden_size, bias=True)
        self.context_vector = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, h: torch.Tensor, mask: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:

        u = torch.tanh(self.proj(h))
        scores = self.context_vector(u).squeeze(-1)
        if mask is not None:

            scores = scores.masked_fill(~mask, -1e9)

        alpha = F.softmax(scores, dim=1)  # (batch, time)

        alpha_unsq = alpha.unsqueeze(-1)  # (batch, time, 1)
        attended = torch.sum(h * alpha_unsq, dim=1)  # (batch, hidden)
        return attended, alpha


class WordEncoder(nn.Module):
    """Encodes words within a sentence using Embedding + BiLSTM + Attention

    Inputs expected per batch: sentences padded to max_words with word indices
    """

    def __init__(self, vocab_size: int, embed_dim: int, word_hidden_size: int, pretrained_embeddings: torch.Tensor = None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        if pretrained_embeddings is not None:
            self.embedding.weight.data.copy_(pretrained_embeddings)
        self.bilstm = nn.LSTM(embed_dim, word_hidden_size, num_layers=1, bidirectional=True, batch_first=True)
        self.attention = Attention(word_hidden_size * 2)

    def forward(self, sentences: torch.Tensor, word_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

        embeds = self.embedding(sentences)


        packed = pack_padded_sequence(embeds, word_lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, _ = self.bilstm(packed)
        out, _ = pad_packed_sequence(packed_out, batch_first=True)  # (bs, max_time, hidden*2)


        max_time = out.size(1)
        mask = torch.arange(max_time, device=word_lengths.device).expand(len(word_lengths), max_time) < word_lengths.unsqueeze(1)


        sentence_rep, word_alphas = self.attention(out, mask)
        return sentence_rep, word_alphas


class SentenceEncoder(nn.Module):
    """Encodes sentences within a document using BiLSTM + Attention
    Input: sentence representations (batch, max_sents, sent_hidden)
    Also takes sentence lengths (num valid sentences per document)
    """

    def __init__(self, sent_hidden_size: int, doc_hidden_size: int):
        super().__init__()
        self.bilstm = nn.LSTM(sent_hidden_size, doc_hidden_size, num_layers=1, bidirectional=True, batch_first=True)
        self.attention = Attention(doc_hidden_size * 2)

    def forward(self, sentences_rep: torch.Tensor, sent_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

        packed = pack_padded_sequence(sentences_rep, sent_lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, _ = self.bilstm(packed)
        out, _ = pad_packed_sequence(packed_out, batch_first=True)  # (batch, max_sents, hidden*2)

        max_sents = out.size(1)
        mask = torch.arange(max_sents, device=sent_lengths.device).expand(len(sent_lengths), max_sents) < sent_lengths.unsqueeze(1)

        doc_rep, sent_alphas = self.attention(out, mask)
        return doc_rep, sent_alphas


class HAN(nn.Module):
    """Full Hierarchical Attention Network
    Combines WordEncoder and SentenceEncoder and a final classifier.
    """

    def __init__(self, vocab_size: int, embed_dim: int, word_hidden_size: int, sent_hidden_size: int, num_classes: int, pretrained_embeddings: torch.Tensor = None):
        super().__init__()
        self.word_encoder = WordEncoder(vocab_size, embed_dim, word_hidden_size, pretrained_embeddings)
        self.sent_encoder = SentenceEncoder(word_hidden_size * 2, sent_hidden_size)
        self.classifier = nn.Linear(sent_hidden_size * 2, num_classes)

    def forward(self, docs: torch.Tensor, word_lengths: torch.Tensor, sent_lengths: torch.Tensor) -> Tuple[torch.Tensor, dict]:

        batch, max_sents, max_words = docs.size()


        flat_sentences = docs.view(batch * max_sents, max_words)
        flat_word_lengths = word_lengths.view(batch * max_sents)


        nonzero_mask = flat_word_lengths > 0
        filtered_sentences = flat_sentences[nonzero_mask]
        filtered_word_lengths = flat_word_lengths[nonzero_mask]


        word_encoder_output_dim = self.word_encoder.bilstm.hidden_size * 2
        sentence_reps = torch.zeros(batch * max_sents, word_encoder_output_dim, device=docs.device)
        word_alphas = torch.zeros(batch * max_sents, max_words, device=docs.device)

        if filtered_sentences.size(0) > 0:
            processed_sentence_reps, processed_word_alphas_raw = self.word_encoder(filtered_sentences, filtered_word_lengths)
            sentence_reps[nonzero_mask] = processed_sentence_reps


            current_alpha_len = processed_word_alphas_raw.size(1)
            if current_alpha_len < max_words:

                pad = torch.zeros(processed_word_alphas_raw.size(0), max_words - current_alpha_len, device=docs.device)
                processed_word_alphas = torch.cat([processed_word_alphas_raw, pad], dim=1)
            elif current_alpha_len > max_words:

                processed_word_alphas = processed_word_alphas_raw[:, :max_words]
            else:
                processed_word_alphas = processed_word_alphas_raw

            word_alphas[nonzero_mask] = processed_word_alphas



        sentence_reps = sentence_reps.view(batch, max_sents, -1)  # (batch, max_sents, sent_hidden)
        word_alphas = word_alphas.view(batch, max_sents, -1) # reshape word_alphas for output


        doc_rep, sent_alphas = self.sent_encoder(sentence_reps, sent_lengths)

        logits = self.classifier(doc_rep)
        return logits, {"word_alphas": word_alphas, "sent_alphas": sent_alphas}




class SyntheticDataset(Dataset):
    """Generates random "documents" for demonstration. Replace with real dataset.

    Each document is represented as a tensor of shape (max_sents, max_words) containing token ids.
    """

    def __init__(self, num_docs: int, max_sents: int, max_words: int, vocab_size: int, num_classes: int):
        super().__init__()
        self.num_docs = num_docs
        self.max_sents = max_sents
        self.max_words = max_words
        self.vocab_size = vocab_size
        self.num_classes = num_classes

        import random
        self.docs = []
        self.labels = []
        for _ in range(num_docs):
            sents = random.randint(1, max_sents)
            doc = []
            word_counts = []
            for _s in range(sents):
                w = random.randint(1, max_words)
                word_counts.append(w)
                sent = [random.randint(1, vocab_size - 1) for _ in range(w)]

                sent += [0] * (max_words - w)
                doc.append(sent)

            for _ in range(max_sents - sents):
                doc.append([0] * max_words)
                word_counts.append(0)
            self.docs.append(doc)
            self.labels.append(random.randint(0, num_classes - 1))

    def __len__(self):
        return self.num_docs

    def __getitem__(self, idx):

        doc = torch.tensor(self.docs[idx], dtype=torch.long)
        word_lengths = (doc != 0).sum(dim=1).long()
        sent_length = (word_lengths > 0).sum().long()
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return doc, word_lengths, sent_length, label


def collate_fn(batch):

    docs = torch.stack([item[0] for item in batch], dim=0)
    word_lengths = torch.stack([item[1] for item in batch], dim=0)
    sent_lengths = torch.stack([item[2] for item in batch], dim=0)
    labels = torch.stack([item[3] for item in batch], dim=0)
    return docs, word_lengths, sent_lengths, labels



def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    total = 0
    correct = 0
    for docs, word_lengths, sent_lengths, labels in dataloader:
        docs = docs.to(device)
        word_lengths = word_lengths.to(device)
        sent_lengths = sent_lengths.to(device)
        labels = labels.to(device)

        batch, max_sents, max_words = docs.size()
        flat_word_lengths = word_lengths.view(batch * max_sents)

        optimizer.zero_grad()
        logits, alphas = model(docs, flat_word_lengths, sent_lengths)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * labels.size(0)
        total += labels.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()

    return total_loss / total, correct / total


def eval_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    total = 0
    correct = 0
    with torch.no_grad():
        for docs, word_lengths, sent_lengths, labels in dataloader:
            docs = docs.to(device)
            word_lengths = word_lengths.to(device)
            sent_lengths = sent_lengths.to(device)
            labels = labels.to(device)

            batch, max_sents, max_words = docs.size()
            flat_word_lengths = word_lengths.view(batch * max_sents)

            logits, alphas = model(docs, flat_word_lengths, sent_lengths)
            loss = criterion(logits, labels)

            total_loss += loss.item() * labels.size(0)
            total += labels.size(0)
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()

    return total_loss / total, correct / total


if __name__ == "__main__":

    VOCAB_SIZE = 5000
    EMBED_DIM = 100
    WORD_HIDDEN = 50
    SENT_HIDDEN = 50
    NUM_CLASSES = 4

    MAX_SENTS = 10
    MAX_WORDS = 20

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = HAN(VOCAB_SIZE, EMBED_DIM, WORD_HIDDEN, SENT_HIDDEN, NUM_CLASSES).to(device)

    dataset = SyntheticDataset(num_docs=1000, max_sents=MAX_SENTS, max_words=MAX_WORDS, vocab_size=VOCAB_SIZE, num_classes=NUM_CLASSES)
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1, 6):
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = eval_model(model, val_loader, criterion, device)
        print(f"Epoch {epoch}: Train loss {train_loss:.4f} acc {train_acc:.4f} | Val loss {val_loss:.4f} acc {val_acc:.4f}")

    docs, word_lengths, sent_lengths, labels = next(iter(val_loader))
    docs = docs.to(device)
    flat_word_lengths = word_lengths.view(docs.size(0) * docs.size(1)).to(device)
    logits, alphas = model(docs, flat_word_lengths, sent_lengths.to(device))
    preds = logits.argmax(dim=1)
    print("Preds:", preds[:8].tolist())
    print("True:", labels[:8].tolist())

Epoch 1: Train loss 1.3876 acc 0.2430 | Val loss 1.3794 acc 0.2990
Epoch 2: Train loss 1.3787 acc 0.3370 | Val loss 1.3664 acc 0.4410
Epoch 3: Train loss 1.3498 acc 0.4030 | Val loss 1.2682 acc 0.4700
Epoch 4: Train loss 1.1736 acc 0.4890 | Val loss 0.9561 acc 0.6870
Epoch 5: Train loss 0.8078 acc 0.7370 | Val loss 0.4924 acc 0.8550
Preds: [1, 2, 2, 1, 1, 1, 0, 3]
True: [1, 2, 2, 1, 1, 1, 3, 3]
