In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
import torchtext
from torchtext.vocab import build_vocab_from_iterator
import nltk
from nltk.corpus import reuters
import spacy
import numpy as np
import random



In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [3]:
# Quick look at the corpus: every topic category, then the first five file ids.
print(reuters.categories())
print(reuters.fileids()[:5])

['acq', 'alum', 'barley', 'bop', 'carcass', 'castor-oil', 'cocoa', 'coconut', 'coconut-oil', 'coffee', 'copper', 'copra-cake', 'corn', 'cotton', 'cotton-oil', 'cpi', 'cpu', 'crude', 'dfl', 'dlr', 'dmk', 'earn', 'fuel', 'gas', 'gnp', 'gold', 'grain', 'groundnut', 'groundnut-oil', 'heat', 'hog', 'housing', 'income', 'instal-debt', 'interest', 'ipi', 'iron-steel', 'jet', 'jobs', 'l-cattle', 'lead', 'lei', 'lin-oil', 'livestock', 'lumber', 'meal-feed', 'money-fx', 'money-supply', 'naphtha', 'nat-gas', 'nickel', 'nkr', 'nzdlr', 'oat', 'oilseed', 'orange', 'palladium', 'palm-oil', 'palmkernel', 'pet-chem', 'platinum', 'potato', 'propane', 'rand', 'rape-oil', 'rapeseed', 'reserves', 'retail', 'rice', 'rubber', 'rye', 'ship', 'silver', 'sorghum', 'soy-meal', 'soy-oil', 'soybean', 'strategic-metal', 'sugar', 'sun-meal', 'sun-oil', 'sunseed', 'tea', 'tin', 'trade', 'veg-oil', 'wheat', 'wpi', 'yen', 'zinc']
['test/14826', 'test/14828', 'test/14829', 'test/14832', 'test/14833']


In [4]:
# Print the raw text of one article to see what a document looks like.
print(reuters.raw('test/14826'))

ASIAN EXPORTERS FEAR DAMAGE FROM U.S.-JAPAN RIFT
  Mounting trade friction between the
  U.S. And Japan has raised fears among many of Asia's exporting
  nations that the row could inflict far-reaching economic
  damage, businessmen and officials said.
      They told Reuter correspondents in Asian capitals a U.S.
  Move against Japan might boost protectionist sentiment in the
  U.S. And lead to curbs on American imports of their products.
      But some exporters said that while the conflict would hurt
  them in the long-run, in the short-term Tokyo's loss might be
  their gain.
      The U.S. Has said it will impose 300 mln dlrs of tariffs on
  imports of Japanese electronics goods on April 17, in
  retaliation for Japan's alleged failure to stick to a pact not
  to sell semiconductors on world markets at below cost.
      Unofficial Japanese estimates put the impact of the tariffs
  at 10 billion dlrs and spokesmen for major electronics firms
  said they would virtually halt exports

In [5]:
# Build parallel lists of (document, reference summary) from the first 2000
# Reuters files. The reference summary is simply the first three sentences of
# the document, so it is always a verbatim prefix of the corresponding source.
docs, summaries = [], []

for file_id in reuters.fileids()[:2000]:
    corpus = reuters.raw(file_id)
    sentences = nltk.sent_tokenize(corpus)
    if len(sentences) <= 5:
        # Skip articles with five sentences or fewer.
        continue
    docs.append(" ".join(sentences))  # full text
    summaries.append(" ".join(sentences[:3]))  # first 3 sentences as summary

In [6]:
spacy_en = spacy.load("en_core_web_sm")


def tokenize(text):
    """Split ``text`` with the spaCy tokenizer and return lower-cased token strings."""
    lowered = []
    for token in spacy_en.tokenizer(text):
        lowered.append(token.text.lower())
    return lowered

In [33]:
def yeild_tokens(data):
    """Lazily yield the token list of every text in ``data``, in order.

    The (misspelled) name is kept because the vocabulary-building cell calls it.
    """
    yield from (tokenize(doc) for doc in data)

In [8]:
# A single vocabulary shared by source documents and target summaries.
vocab = build_vocab_from_iterator(
    yeild_tokens(docs + summaries),
    specials=["<unk>", "<pad>", "<sos>", "<eos>"],
)
# Tokens that are not in the vocabulary map to the <unk> index.
vocab.set_default_index(vocab["<unk>"])

In [9]:
# Integer ids of the special tokens, looked up once and reused everywhere.
UNK_IDX = vocab["<unk>"]
PAD_IDX = vocab["<pad>"]
SOS_IDX = vocab["<sos>"]
EOS_IDX = vocab["<eos>"]

In [10]:
# Vocabulary size, including the four special tokens.
len(vocab)

13229

In [11]:
# Print the first 20 vocabulary entries (fewer if the vocabulary is smaller).
# get_itos() is called once and sliced instead of being rebuilt on every
# iteration of the loop.
for token in vocab.get_itos()[:20]:
    print(token)


<unk>
<pad>
<sos>
<eos>

  
the
,
.
to
of
in
and
a
said
"
for
-
's
on
it


In [12]:
class SummaryDataset(Dataset):
    """Dataset of (document, summary) pairs encoded as padded index tensors.

    Each text is tokenized, cut to ``max_len - 2`` tokens, wrapped in
    <sos>/<eos> and right-padded with <pad> up to ``max_len`` entries.
    """

    def __init__(self, docs, summaries, vocab, max_len=100):
        self.docs = docs
        self.summaries = summaries
        self.vocab = vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.docs)

    def encode(self, sentence):
        """Return ``[<sos>, token ids..., <eos>]`` for ``sentence``."""
        # Two slots are reserved for the <sos> and <eos> markers.
        budget = self.max_len - 2
        ids = [SOS_IDX]
        ids.extend(self.vocab[token] for token in tokenize(sentence)[:budget])
        ids.append(EOS_IDX)
        return ids

    def pad(self, seq):
        """Return a new list: ``seq`` followed by <pad> ids up to ``max_len``."""
        missing = self.max_len - len(seq)
        return seq + [PAD_IDX] * missing

    def __getitem__(self, idx):
        """Return the ``(src, trg)`` tensors for pair number ``idx``."""
        src_ids = self.pad(self.encode(self.docs[idx]))
        trg_ids = self.pad(self.encode(self.summaries[idx]))
        return torch.tensor(src_ids), torch.tensor(trg_ids)

In [13]:
# Wrap the text pairs in the dataset (default max_len) and batch them,
# reshuffling the order each time the loader is iterated.
dataset = SummaryDataset(docs=docs, summaries=summaries, vocab=vocab)

data_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

In [14]:
class Encoder(nn.Module):
    """GRU encoder: embeds a batch of token ids and returns the final hidden state."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        # Sub-modules are created in the same order as before (embedding, then
        # GRU) so that seeded weight initialisation is unaffected.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.rnn = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout,
        )

    def forward(self, src):
        """Encode ``src`` ([batch_size, seq_len] token ids).

        Returns the GRU's final hidden state, [n_layers, batch_size, hidden_dim].
        """
        # [batch_size, seq_len] -> [batch_size, seq_len, embedding_dim]
        token_vectors = self.embedding(src)

        # The per-step outputs are discarded; only the last hidden state is used.
        _, final_hidden = self.rnn(token_vectors)
        return final_hidden

In [15]:
class Decoder(nn.Module):
    """Single-step GRU decoder that maps the previous token to next-token logits."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()

        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.rnn = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.fc_out = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, input, hidden):
        """Run one decoding step.

        input  : [batch_size] ids of the previous token, one per sequence
        hidden : [n_layers, batch_size, hidden_dim]

        Returns ``(prediction, hidden)`` where prediction is
        [batch_size, vocab_size] unnormalised scores and hidden is the updated
        GRU state with the same shape as the one passed in.
        """
        # Add a length-1 time axis: [batch_size] -> [batch_size, 1]
        step_ids = input.unsqueeze(1)

        # [batch_size, 1] -> [batch_size, 1, embedding_dim]
        step_vectors = self.embedding(step_ids)

        # rnn_out : [batch_size, 1, hidden_dim] because exactly one step is processed
        rnn_out, hidden = self.rnn(step_vectors, hidden)

        # Drop the time axis before projecting to the vocabulary.
        prediction = self.fc_out(rnn_out.squeeze(1))

        return prediction, hidden

In [16]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes step by step with optional teacher forcing."""

    def __init__(self, encoder, decoder, device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Return per-step vocabulary scores of shape [batch_size, trg_len, vocab_size].

        src : [batch_size, src_len] token ids
        trg : [batch_size, trg_len] token ids whose first column is <sos>
        teacher_forcing_ratio : probability, drawn once per time step, of
            feeding the ground-truth token instead of the model's own argmax.
        """
        n_rows, n_steps = src.shape[0], trg.shape[1]
        n_classes = self.decoder.fc_out.out_features

        # Time step 0 (the <sos> slot) is never written and stays all-zero.
        logits_per_step = torch.zeros(n_rows, n_steps, n_classes, device=self.device)

        # The encoder's final hidden state seeds the decoder.
        state = self.encoder(src)

        step_input = trg[:, 0]  # <sos> fed into decoder

        for t in range(1, n_steps):
            step_logits, state = self.decoder(step_input, state)
            logits_per_step[:, t, :] = step_logits

            # One draw per time step, shared by every sequence in the batch.
            if random.random() < teacher_forcing_ratio:
                step_input = trg[:, t]
            else:
                step_input = step_logits.argmax(1)

        return logits_per_step

In [17]:
# Source and target share one vocabulary, so both sides use the same size.
INPUT_DIM = OUTPUT_DIM = len(vocab)
ENC_EMBEDDING_DIM = DEC_EMBEDDING_DIM = 256
HIDDEN_DIM = 512
# Encoder and decoder depths match because the encoder's hidden state is passed
# straight in as the decoder's initial hidden state.
ENC_LAYERS = DEC_LAYERS = 3
DROPOUT = 0.5

enc = Encoder(
    vocab_size=INPUT_DIM,
    embedding_dim=ENC_EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    n_layers=ENC_LAYERS,
    dropout=DROPOUT,
)
dec = Decoder(
    vocab_size=OUTPUT_DIM,
    embedding_dim=DEC_EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    n_layers=DEC_LAYERS,
    dropout=DROPOUT,
)

model = Seq2Seq(encoder=enc, decoder=dec, device=device).to(device)

# Targets equal to <pad> are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [18]:
EPOCHS = 10

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0

    for src, trg in data_loader:
        src, trg = src.to(device), trg.to(device)

        optimizer.zero_grad()

        logits = model(src, trg, 0.5)  # [batch_size, trg_len, vocab_size]

        # Time step 0 is the <sos> slot, which the model never fills; drop it
        # and flatten so the loss sees [N, vocab_size] scores against [N] targets.
        flat_logits = logits[:, 1:].reshape(-1, logits.shape[-1])
        flat_targets = trg[:, 1:].reshape(-1)

        loss = criterion(flat_logits, flat_targets)
        loss.backward()

        # Clip the total gradient norm to 1 before the optimizer update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()

        epoch_loss += loss.item()

    # Report the mean of the per-batch losses for this epoch.
    print(f"EPOCH {epoch+1}/{EPOCHS} | LOSS {epoch_loss/len(data_loader):.4f}")

EPOCH 1/10 | LOSS 7.4433
EPOCH 2/10 | LOSS 6.6401
EPOCH 3/10 | LOSS 6.5612
EPOCH 4/10 | LOSS 6.5232
EPOCH 5/10 | LOSS 6.4900
EPOCH 6/10 | LOSS 6.4725
EPOCH 7/10 | LOSS 6.4517
EPOCH 8/10 | LOSS 6.4409
EPOCH 9/10 | LOSS 6.4235
EPOCH 10/10 | LOSS 6.4052


In [19]:
def evaluate_model(model, data_loader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``data_loader``.

    The model is put in eval mode and run without gradient tracking. It is
    called as ``model(src, trg, 0)``; for the Seq2Seq model in this notebook
    the third argument is the teacher-forcing ratio, so decoding is fully
    driven by the model's own predictions.

    Raises ZeroDivisionError when ``data_loader`` has length 0.
    """
    model.eval()
    total = 0

    with torch.no_grad():
        for src, trg in data_loader:
            src, trg = src.to(device), trg.to(device)

            logits = model(src, trg, 0)

            # Skip time step 0 (the <sos> slot) and flatten batch and time.
            flat_logits = logits[:, 1:].reshape(-1, logits.shape[-1])
            flat_targets = trg[:, 1:].reshape(-1)

            total += criterion(flat_logits, flat_targets).item()

    return total / len(data_loader)

In [20]:
# NOTE: data_loader is the same loader the training loop iterated over, so this
# value is a loss on the training data, not a held-out estimate.
evaluate_model(model, data_loader, criterion, device)

6.323334493135151

In [21]:
# List the first five Reuters file ids again.
print(reuters.fileids()[:5])

['test/14826', 'test/14828', 'test/14829', 'test/14832', 'test/14833']


In [37]:
# Print the raw text of the article with file id 'test/14828'.
print(reuters.raw('test/14828'))

CHINA DAILY SAYS VERMIN EAT 7-12 PCT GRAIN STOCKS
  A survey of 19 provinces and seven cities
  showed vermin consume between seven and 12 pct of China's grain
  stocks, the China Daily said.
      It also said that each year 1.575 mln tonnes, or 25 pct, of
  China's fruit output are left to rot, and 2.1 mln tonnes, or up
  to 30 pct, of its vegetables. The paper blamed the waste on
  inadequate storage and bad preservation methods.
      It said the government had launched a national programme to
  reduce waste, calling for improved technology in storage and
  preservation, and greater production of additives. The paper
  gave no further details.
  




In [None]:
def summarize_text(model, text, vocab, max_len=50, device="cpu"):
    """Generate a summary of ``text`` by sampling from the decoder's top-5 tokens.

    Args:
        model: model exposing ``encoder`` and ``decoder`` sub-modules, where
            ``decoder(token_ids, hidden)`` returns ``(scores, hidden)``.
        text: raw input string.
        vocab: vocabulary supporting ``vocab[token]`` and ``get_itos()``.
        max_len: length the encoded source is truncated/padded to, and also the
            maximum number of decoding steps.
        device: device on which the input tensors are created; it must be the
            device the model lives on.

    Returns:
        The generated tokens joined by single spaces, with the <sos>, <eos>
        and <pad> markers removed.

    Sampling uses ``torch.multinomial``, so repeated calls can return different
    text unless the torch RNG is seeded.
    """
    model.eval()

    # 1. Tokenize & numericalize input
    tokens = tokenize(text.lower())[:max_len - 2]
    src = [SOS_IDX] + [vocab[tok] for tok in tokens] + [EOS_IDX]
    src += [PAD_IDX] * (max_len - len(src))
    src_tensor = torch.tensor(src, device=device).unsqueeze(0)  # (1, max_len)

    trg_indexes = [SOS_IDX]  # start with <sos>; holds plain Python ints only

    with torch.no_grad():
        # 2. Encode
        hidden = model.encoder(src_tensor)

        # 3. Decode one token at a time
        for _ in range(max_len):
            trg_tensor = torch.tensor([trg_indexes[-1]], device=device)  # (1,)
            output, hidden = model.decoder(trg_tensor, hidden)

            probs = torch.softmax(output, dim=1)  # scores -> probabilities
            # Keep the 5 most likely tokens (or all of them for a tiny vocabulary).
            topk_probs, topk_idx = torch.topk(probs, k=min(5, probs.shape[1]))
            # multinomial renormalises topk_probs and returns a (1, 1) position
            # into the top-k list; gather maps it back to a vocabulary id.
            choice = torch.multinomial(topk_probs, 1)
            # .item() turns the one-element tensor into an int so that the list,
            # the <eos> comparison and the itos lookup never see tensors.
            pred_token = topk_idx.gather(1, choice).item()

            trg_indexes.append(pred_token)

            if pred_token == EOS_IDX:
                break

    # 4. Convert indices back to words; the index-to-string table is built once.
    itos = vocab.get_itos()
    specials = {"<sos>", "<eos>", "<pad>"}
    summary_tokens = [itos[i] for i in trg_indexes[1:]]  # skip <sos>

    return " ".join(tok for tok in summary_tokens if tok not in specials)


In [35]:
# Try the model on a short two-sentence passage about India's economy.
test_text = """
India's economy grew by 7.8% in the last quarter, driven by strong consumer spending and government investments.
Experts say this makes India the fastest-growing major economy in the world, despite global headwinds.
"""

# max_len=40 caps both the encoded input length and the number of decoding steps.
summary = summarize_text(model, test_text, vocab, max_len=40, device=device)
print("Generated Summary:", summary)


Generated Summary: u.s. & > 
   > > & 
   the 
   
   
   said 
   
   
   of of 
   
   , 
   
   
   to to 
   
   
   
   , 
   , 
   the 
   of , , 
  


In [38]:
# Second try: the text of Reuters article 'test/14828' that was printed earlier.
test_text = """
CHINA DAILY SAYS VERMIN EAT 7-12 PCT GRAIN STOCKS
  A survey of 19 provinces and seven cities
  showed vermin consume between seven and 12 pct of China's grain
  stocks, the China Daily said.
      It also said that each year 1.575 mln tonnes, or 25 pct, of
  China's fruit output are left to rot, and 2.1 mln tonnes, or up
  to 30 pct, of its vegetables. The paper blamed the waste on
  inadequate storage and bad preservation methods.
      It said the government had launched a national programme to
  reduce waste, calling for improved technology in storage and
  preservation, and greater production of additives. The paper
  gave no further details.
"""

# max_len=40 caps both the encoded input length and the number of decoding steps.
summary = summarize_text(model, test_text, vocab, max_len=40, device=device)
print("Generated Summary:", summary)

Generated Summary: & & says & 
   
   to 
   to the the 
   of 
   of 
   the , of 
   of of 
   
   of to to 
   , , , the 
   , 
   
   of 
   the 
  
