In [1]:
!wget http://nlp.stanford.edu/data/glove.6B.zip

--2024-05-05 16:18:11--  http://nlp.stanford.edu/data/glove.6B.zip
Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://nlp.stanford.edu/data/glove.6B.zip [following]
--2024-05-05 16:18:11--  https://nlp.stanford.edu/data/glove.6B.zip
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip [following]
--2024-05-05 16:18:11--  https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip
Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22
Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 862182613 (822M) [application/zip]
Saving to: 'glove.6B.zip'


202

In [2]:
!unzip glove*.zip

Archive:  glove.6B.zip
  inflating: glove.6B.50d.txt        
  inflating: glove.6B.100d.txt       
  inflating: glove.6B.200d.txt       
  inflating: glove.6B.300d.txt       


In [3]:
import torch
from torch.utils.data import Dataset
from nltk.tokenize import sent_tokenize, word_tokenize
from sklearn.datasets import fetch_20newsgroups

In [4]:
class News20Dataset(Dataset):
  def __init__(self, word_map_path, max_sent_length=150, max_doc_length=40, is_train=True):
    self.max_sent_length = max_sent_length
    self.max_doc_length = max_doc_length
    self.split = 'train' if is_train else 'test'

    self.data = fetch_20newsgroups(
        subset=self.split,
        categories=['sci.crypt', 'sci.electronics', 'sci.med', 'sci.space'],
        shuffle=False,
        remove=('headers', 'footers', 'quotes')
    )

    # Load vocabulary from word map file
    self.vocab = pd.read_csv(
        filepath_or_buffer=word_map_path,
        header=None,
        sep=' ',
        quoting=csv.QUOTE_NONE,
        usecols=[0]
    ).values[:50000]

    # Create vocabulary list
    self.vocab = ['', ''] + [word[0] for word in self.vocab]

  def transform(self, text):
    doc = [
        [self.vocab.index(word) if word in self.vocab else 1 for word in word_tokenize(text=sent)]
        for sent in sent_tokenize(text=text)
    ]
    doc = [sent[:self.max_sent_length] for sent in doc][:self.max_doc_length]
    num_sents = min(len(doc), self.max_doc_length)
    if num_sents == 0:
      return None, -1, None

    num_words = [min(len(sent), self.max_sent_length) for sent in doc][:self.max_doc_length]

    return doc, num_sents, num_words

  def __getitem__(self, i):
    label = self.data['target'][i]
    text = self.data['data'][i]

    doc, num_sents, num_words = self.transform(text)

    if num_sents == -1:
      return None

    return doc, label, num_sents, num_words

  def __len__(self):
    return len(self.data['data'])

  @property
  def vocab_size(self):
    return len(self.vocab)

  @property
  def num_classes(self):
    return 4


In [5]:
def collate_fn(batch):
  batch = filter(lambda x: x is not None, batch)
  docs, labels, doc_lengths, sent_lengths = list(zip(*batch))

  bsz = len(labels)
  batch_max_doc_length = max(doc_lengths)
  batch_max_sent_length = max([max(sl) if sl else 0 for sl in sent_lengths])

  # Initialize tensors for documents and sentence lengths
  docs_tensor = torch.zeros([bsz, batch_max_doc_length, batch_max_sent_length]).long()
  sent_lengths_tensor = torch.zeros([bsz, batch_max_doc_length]).long()

  # Fill in tensors with data from batch
  for doc_idx, doc in enumerate(docs):
    doc_length = doc_lengths[doc_idx]
    sent_lengths_tensor[doc_idx, :doc_length] = torch.LongTensor(sent_lengths[doc_idx])
    for sent_idx, sent in enumerate(doc):
      sent_length = sent_lengths[doc_idx][sent_idx]
      docs_tensor[doc_idx, sent_idx, :sent_length] = torch.LongTensor(sent)

  return docs_tensor, torch.LongTensor(labels), torch.LongTensor(doc_lengths), sent_lengths_tensor

In [6]:
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler

In [7]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [8]:
class MyDataLoader(DataLoader):
  def __init__(self, dataset, batch_size):
    self.n_samples = len(dataset)
    self.sampler = RandomSampler(dataset)

    self.init_kwargs = {
        'dataset': dataset,
        'batch_size': batch_size,
        'pin_memory': True,  # Pin memory for faster GPU transfers if available
        'collate_fn': collate_fn,  # Collate function for batching data
        'shuffle': False  # Disable shuffling to maintain consistency during evaluation
    }

    # Initialize DataLoader using superclass constructor
    super().__init__(sampler=self.sampler, **self.init_kwargs)


In [9]:
import pandas as pd
import csv

In [10]:
dataset = News20Dataset('glove.6B.100d.txt', is_train=True)
data_loader = MyDataLoader(dataset, 64)

In [11]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence

In [12]:
class WordAttention(nn.Module):
  def __init__(self, vocab_size, embed_dim, gru_hidden_dim, gru_num_layers, att_dim, use_layer_norm, dropout):
    super(WordAttention, self).__init__()
    self.embeddings = nn.Embedding(vocab_size, embed_dim)
    self.gru = nn.GRU(
        embed_dim,
        gru_hidden_dim,
        num_layers=gru_num_layers,
        batch_first=True,
        bidirectional=True,
        dropout=dropout
    )
    self.use_layer_norm = use_layer_norm
    if use_layer_norm:
      self.layer_norm = nn.LayerNorm(2 * gru_hidden_dim, elementwise_affine=True)
    self.dropout = nn.Dropout(dropout)

    self.attention = nn.Linear(2 * gru_hidden_dim, att_dim)

    self.context_vector = nn.Linear(att_dim, 1, bias=False)

  def init_embeddings(self, embeddings):
    self.embeddings.weight = nn.Parameter(embeddings)

  def freeze_embeddings(self, freeze=False):
    self.embeddings.weight.requires_grad = not freeze

  def forward(self, sents, sent_lengths):
    # Sort sentences by length for efficient processing
    sent_lengths, sent_perm_idx = sent_lengths.sort(dim=0, descending=True)
    sents = sents[sent_perm_idx]

    # Pass sentences through embedding layer and apply dropout
    sents = self.embeddings(sents)
    sents = self.dropout(sents)

    # Pack sequences for dynamic sequence handling
    packed_words = pack_padded_sequence(sents, lengths=sent_lengths.tolist(), batch_first=True)

    valid_bsz = packed_words.batch_sizes

    # Pass packed sequences through bidirectional GRU layer
    packed_words, _ = self.gru(packed_words)

    # Optionally apply layer normalization
    if self.use_layer_norm:
      normed_words = self.layer_norm(packed_words.data)
    else:
      normed_words = packed_words

    # Compute attention weights
    att = torch.tanh(self.attention(normed_words.data))
    att = self.context_vector(att).squeeze(1)
    val = att.max()
    att = torch.exp(att - val)
    att, _ = pad_packed_sequence(PackedSequence(att, valid_bsz), batch_first=True)
    att_weights = att / torch.sum(att, dim = 1, keepdim=True)

    # Apply attention to GRU outputs and aggregate attended representations
    sents, _ = pad_packed_sequence(packed_words, batch_first=True)
    sents = sents * att_weights.unsqueeze(2)
    sents = sents.sum(dim=1)

    # Restore original order of sentences
    _, sent_unperm_idx = sent_perm_idx.sort(dim=0, descending=False)
    sents = sents[sent_unperm_idx]
    att_weights = att_weights[sent_unperm_idx]

    return sents, att_weights


In [13]:
class SentenceAttention(nn.Module):
  def __init__(self, vocab_size, embed_dim, word_gru_hidden_dim, sent_gru_hidden_dim,
              word_gru_num_layers, sent_gru_num_layers, word_att_dim, sent_att_dim, use_layer_norm, dropout):
      super(SentenceAttention, self).__init__()

      # Word-level attention module
      self.word_attention = WordAttention(vocab_size, embed_dim, word_gru_hidden_dim, word_gru_num_layers,
                                          word_att_dim, use_layer_norm, dropout)

      # Bidirectional sentence-level GRU
      self.gru = nn.GRU(2 * word_gru_hidden_dim, sent_gru_hidden_dim, num_layers=sent_gru_num_layers,
                        batch_first=True, bidirectional=True, dropout=dropout)

      # Optionally apply layer normalization
      self.use_layer_norm = use_layer_norm
      if use_layer_norm:
          self.layer_norm = nn.LayerNorm(2 * sent_gru_hidden_dim, elementwise_affine=True)
      self.dropout = nn.Dropout(dropout)

      # Sentence-level attention
      self.sent_attention = nn.Linear(2 * sent_gru_hidden_dim, sent_att_dim)

      # Sentence context vector u_s to take dot product with
      self.sentence_context_vector = nn.Linear(sent_att_dim, 1, bias=False)

  def forward(self, docs, doc_lengths, sent_lengths):
      # Sort documents by decreasing order in length
      doc_lengths, doc_perm_idx = doc_lengths.sort(dim=0, descending=True)
      docs = docs[doc_perm_idx]
      sent_lengths = sent_lengths[doc_perm_idx]

      # Make a long batch of sentences by removing pad-sentences
      packed_sents = pack_padded_sequence(docs, lengths=doc_lengths.tolist(), batch_first=True)

      # effective batch size at each timestep
      valid_bsz = packed_sents.batch_sizes

      # Make a long batch of sentence lengths by removing pad-sentences
      packed_sent_lengths = pack_padded_sequence(sent_lengths, lengths=doc_lengths.tolist(), batch_first=True)

      # Word attention module
      sents, word_att_weights = self.word_attention(packed_sents.data, packed_sent_lengths.data)

      # Optionally apply dropout
      sents = self.dropout(sents)

      # Sentence-level GRU over sentence embeddings
      packed_sents, _ = self.gru(PackedSequence(sents, valid_bsz))

      # Optionally apply layer normalization
      if self.use_layer_norm:
          normed_sents = self.layer_norm(packed_sents.data)
      else:
          normed_sents = packed_sents

      # Sentence attention
      att = torch.tanh(self.sent_attention(normed_sents))
      att = self.sentence_context_vector(att).squeeze(1)

      # Normalize attention weights
      val = att.max()
      att = torch.exp(att - val)
      att, _ = pad_packed_sequence(PackedSequence(att, valid_bsz), batch_first=True)
      sent_att_weights = att / torch.sum(att, dim=1, keepdim=True)

      # Restore as documents by repadding
      docs, _ = pad_packed_sequence(packed_sents, batch_first=True)

      # Compute document vectors
      docs = docs * sent_att_weights.unsqueeze(2)
      docs = docs.sum(dim=1)

      # Restore as documents by repadding
      word_att_weights, _ = pad_packed_sequence(PackedSequence(word_att_weights, valid_bsz), batch_first=True)

      # Restore the original order of documents (undo the first sorting)
      _, doc_unperm_idx = doc_perm_idx.sort(dim=0, descending=False)
      docs = docs[doc_unperm_idx]
      word_att_weights = word_att_weights[doc_unperm_idx]
      sent_att_weights = sent_att_weights[doc_unperm_idx]

      return docs, word_att_weights, sent_att_weights

In [14]:
class HierarchicalAttentionNetwork(nn.Module):
  def __init__(self,
               num_classes,
               vocab_size,
               embed_dim,
               word_gru_hidden_dim,
               sent_gru_hidden_dim,
               word_gru_num_layers,
               sent_gru_num_layers,
               word_att_dim,
               sent_att_dim,
               use_layer_norm,
               dropout):
    super(HierarchicalAttentionNetwork, self).__init__()
    self.sent_attention = SentenceAttention(
        vocab_size,
        embed_dim,
        word_gru_hidden_dim,
        sent_gru_hidden_dim,
        word_gru_num_layers,
        sent_gru_num_layers,
        word_att_dim,
        sent_att_dim,
        use_layer_norm,
        dropout
    )

    # Fully connected layer for classification
    self.fc = nn.Linear(2 * sent_gru_hidden_dim, num_classes)

    self.use_layer_norm = use_layer_norm
    self.dropout = dropout

  def forward(self, docs, doc_lengths, sent_lengths):
    # Compute document embeddings and attention weights
    doc_embed, word_att_weights, sent_att_weights = self.sent_attention(docs, doc_lengths, sent_lengths)

    # Pass document embeddings through the fully connected layer for classification
    scores = self.fc(doc_embed)

    return scores, word_att_weights, sent_att_weights


In [15]:
class MetricTracker(object):
  def __init__(self):
    self.reset()

  def reset(self):
    self.val = 0
    self.avg = 0
    self.sum = 0
    self.count = 0

  def update(self, summed_val, n=1):
    # Compute average value for the current batch or epoch
    self.val = summed_val / n
    # Update sum of all values encountered so far
    self.sum += summed_val
    # Increment count of total number of samples seen
    self.count += n
    # Compute running average of all values encountered so far
    self.avg = self.sum / self.count


In [16]:
class Evaluation:
  def __init__(self, config, model):
    self.config = config
    self.model = model
    self.device = next(self.model.parameters()).device

    # Initialize dataset and data loader for evaluation
    self.dataset = News20Dataset(config['vocab_path'], is_train=False)
    self.dataloader = torch.utils.data.DataLoader(self.dataset, batch_size=config['batch_size'], shuffle=False, collate_fn=collate_fn)

    # Initialize accuracy tracker
    self.accs = MetricTracker()
    self.best_acc = 0

  def eval(self):
    # Set model to evaluation mode
    self.model.eval()
    # Disable gradient calculation
    with torch.no_grad():
      # Reset accuracy tracker
      self.accs.reset()

      # Iterate over batches of data
      for (docs, labels, doc_lengths, sent_lengths) in self.dataloader:
        batch_size = labels.size(0)

        # Move data to device
        docs = docs.to(self.device)
        sent_lengths = sent_lengths.to(self.device)
        labels = labels.to(self.device)
        doc_lengths = doc_lengths.to(self.device)

        # Forward pass through the model
        scores, word_at_weights, sentence_att_weights = self.model(docs, doc_lengths, sent_lengths)

        # Compute predictions
        predictions = scores.max(dim=1)[1]

        # Calculate accuracy for the batch
        correct_predictions = torch.eq(predictions, labels).sum().item()
        acc = correct_predictions

        # Update accuracy tracker
        self.accs.update(acc, batch_size)

      # Update best accuracy if current average accuracy is higher
      self.best_acc = max(self.best_acc, self.accs.avg)

      # Print test average accuracy
      print('Test Average Accuracy: {acc.avg:.4f}'.format(acc=self.accs))

In [17]:
class Trainer:
  def __init__(self, config, model, optimizer, criterion, dataloader):
    self.config = config
    self.model = model
    self.optimizer = optimizer
    self.criterion = criterion
    self.dataloader = dataloader
    self.device = next(self.model.parameters()).device

    # Initialize metric trackers for loss and accuracy
    self.losses = MetricTracker()
    self.accs = MetricTracker()

    # Initialize evaluation instance for testing
    self.tester = Evaluation(self.config, self.model)

  def train(self):
    # Iterate over epochs
    for epoch in range(self.config['num_epochs']):
      # Train for one epoch
      result = self._train_epoch(epoch)
      # Print epoch-level results
      print('Epoch: [{0}]\t Avg Loss {loss:.4f}\t Avg Accuracy {acc:.3f}'.format(epoch, loss=result['loss'], acc=result['acc']))

      # Evaluate the model
      self.tester.eval()
      # Save the model if the current accuracy is the best so far
      if self.tester.best_acc == self.tester.accs.avg:
        print('Saving Model...')
        torch.save({
            'epoch': epoch,
            'model': self.model,
            'optimizer': self.optimizer
        }, 'best_model/model.pth.tar')

  def _train_epoch(self, epoch_idx):
    # Set model to training mode
    self.model.train()

    # Reset loss and accuracy trackers
    self.losses.reset()
    self.accs.reset()

    # Iterate over batches of training data
    for batch_idx, (docs, labels, doc_lengths, sent_lengths) in enumerate(self.dataloader):
        batch_size = labels.size(0)

        # Move data to device
        docs = docs.to(self.device)
        labels = labels.to(self.device)
        sent_lengths = sent_lengths.to(self.device)
        doc_lengths = doc_lengths.to(self.device)

        # Zero the gradients
        self.optimizer.zero_grad()

        # Forward pass through the model
        scores, word_att_weights, sentence_att_weights = self.model(docs, doc_lengths, sent_lengths)

        # Calculate the loss
        loss = self.criterion(scores, labels)
        loss.backward()

        # Clip gradients if specified
        if self.config['max_grad_norm'] is not None:
          torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['max_grad_norm'])

        # Update model parameters
        self.optimizer.step()

        # Compute accuracy for the batch
        predictions = scores.max(dim=1)[1]
        correct_predictions = torch.eq(predictions, labels).sum().item()
        acc = correct_predictions

        # Update loss and accuracy trackers
        self.losses.update(loss.item(), batch_size)
        self.accs.update(acc, batch_size)

        # Print batch-level results
        print('Epoch: [{0}][{1}/{2}]\t Loss {loss.val:.4f} ({loss.avg:.4f})\t Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(epoch_idx, batch_idx, len(self.dataloader), loss=self.losses, acc=self.accs))

    # Return epoch-level results
    log = {
        'loss': self.losses.avg,
        'acc': self.accs.avg
    }
    return log


In [18]:
def train(config, device):
  # Initialize training dataset and dataloader
  dataset = News20Dataset(config['vocab_path'], is_train=True)
  dataloader = MyDataLoader(dataset, batch_size=config['batch_size'])

  # Initialize the model
  model = HierarchicalAttentionNetwork(
        num_classes=dataset.num_classes,
        vocab_size=dataset.vocab_size,
        embed_dim=config['embed_dim'],
        word_gru_hidden_dim=config['word_gru_hidden_dim'],
        sent_gru_hidden_dim=config['sent_gru_hidden_dim'],
        word_gru_num_layers=config['word_gru_num_layers'],
        sent_gru_num_layers=config['sent_gru_num_layers'],
        word_att_dim=config['word_att_dim'],
        sent_att_dim=config['sent_att_dim'],
        use_layer_norm=config['use_layer_norm'],
        dropout=config['dropout']
    ).to(device)

  # Initialize optimizer
  optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()), lr=config['lr'])

  # Initialize loss function
  criterion = nn.CrossEntropyLoss(reduction='sum').to(device)

  # Initialize pretrained embeddings if specified
  if config['pretrain']:
    glove_pretrained = get_pretrained_weights(dataset.vocab, config['embed_dim'], device)
    model.sent_attention.word_attention.init_embeddings(glove_pretrained)

  # Freeze embeddings if specified
  model.sent_attention.word_attention.freeze_embeddings(config['freeze'])

  # Initialize trainer
  trainer = Trainer(config, model, optimizer, criterion, dataloader)

  # Start training
  trainer.train()


In [19]:
import os
import torch
from tqdm import tqdm
from pylab import *
from nltk.tokenize import word_tokenize, sent_tokenize

import matplotlib
import matplotlib.pyplot as plt

In [20]:
os.makedirs('best_model', exist_ok=True)

In [21]:
def get_pretrained_weights(corpus_vocab, embed_dim, device):
  save_dir = os.path.join('glove_pretrained.pt')
  if os.path.exists(save_dir):
    return torch.load(save_dir, map_location=device)

  corpus_set = set(corpus_vocab)
  pretrained_vocab = set()
  glove_pretrained = torch.zeros(len(corpus_vocab), embed_dim)
  with open(os.path.join('glove.6B.100d.txt'), 'rb') as f:
    for l in tqdm(f):
      line = l.decode().split()
      if line[0] in corpus_set:
        pretrained_vocab.add(line[0])
        glove_pretrained[corpus_vocab.index(line[0])] = torch.from_numpy(np.array(line[1:]).astype(float))

    var = float(torch.var(glove_pretrained))
    for oov in corpus_set.difference(pretrained_vocab):
      glove_pretrained[corpus_vocab.index(oov)]  = torch.empty(100).float().uniform_(-var, var)
    print('weight size: ', glove_pretrained.size())
    torch.save(glove_pretrained, save_dir)
  return glove_pretrained

In [22]:
def map_sentence_to_color(words, scores, sent_score):
  sentencemap = matplotlib.cm.get_cmap('binary')
  wordmap = matplotlib.cm.get_cmap('OrRd')
  result = '<p><span style="margin:5px; padding:5px; background-color: {}">'\
    .format(matplotlib.colors.rgb2hex(sentencemap(sent_score)[:3]))
  template = '<span class="barcode"; style="color: black; background-color: {}">{}</span>'
  for word, score in zip(words, scores):
    color = matplotlib.colors.rgb2hex(wordmap(scores)[:3])
    result += template.format(color, '&nbsp' + word + '&nbsp')
  result += '</span><p>'
  return result

In [24]:
config = {
  "batch_size": 64,
    "num_epochs": 25,
    "lr": 3e-3,
    "max_grad_norm": 5.0,
    "embed_dim": 100,
    "word_gru_hidden_dim": 100,
    "sent_gru_hidden_dim": 100,
    "word_gru_num_layers": 1,
    "sent_gru_num_layers": 1,
    "word_att_dim": 200,
    "sent_att_dim": 200,
    "vocab_path": "glove.6B.100d.txt",
    "pretrain": True,
    "freeze": False,
    "use_layer_norm": True,
    "dropout": 0.1
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train(config, device)



Epoch: [0][0/38]	 Loss 1.3888 (1.3888)	 Accuracy 0.186 (0.186)
Epoch: [0][1/38]	 Loss 1.4145 (1.4021)	 Accuracy 0.270 (0.230)
Epoch: [0][2/38]	 Loss 1.3841 (1.3959)	 Accuracy 0.302 (0.254)
Epoch: [0][3/38]	 Loss 1.3328 (1.3803)	 Accuracy 0.377 (0.285)
Epoch: [0][4/38]	 Loss 1.4327 (1.3910)	 Accuracy 0.159 (0.259)
Epoch: [0][5/38]	 Loss 1.4163 (1.3952)	 Accuracy 0.246 (0.257)
Epoch: [0][6/38]	 Loss 1.3443 (1.3879)	 Accuracy 0.306 (0.264)
Epoch: [0][7/38]	 Loss 1.3703 (1.3857)	 Accuracy 0.177 (0.253)
Epoch: [0][8/38]	 Loss 1.3579 (1.3826)	 Accuracy 0.295 (0.258)
Epoch: [0][9/38]	 Loss 1.2962 (1.3739)	 Accuracy 0.516 (0.284)
Epoch: [0][10/38]	 Loss 1.3446 (1.3712)	 Accuracy 0.190 (0.275)
Epoch: [0][11/38]	 Loss 1.2011 (1.3570)	 Accuracy 0.403 (0.286)
Epoch: [0][12/38]	 Loss 1.2740 (1.3506)	 Accuracy 0.355 (0.291)
Epoch: [0][13/38]	 Loss 1.1137 (1.3334)	 Accuracy 0.619 (0.315)
Epoch: [0][14/38]	 Loss 1.0743 (1.3158)	 Accuracy 0.651 (0.338)
Epoch: [0][15/38]	 Loss 1.0156 (1.2968)	 Accuracy 