In [1]:
# mount drive https://datascience.stackexchange.com/questions/29480/uploading-images-folder-from-my-system-into-google-colab
# login with your google account and type authorization code to mount on your googlbie drive.
from google.colab import drive
drive._mount('/gdrive')
root = '/gdrive/My Drive/CS492I/project'

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


#import

In [2]:
!pip install pytorch_lightning



In [3]:
from torch import nn
from torch.optim import Adagrad
import torch.nn.functional as F
import torch
import pytorch_lightning as pl
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from io import BufferedRWPair
from torch.utils.data import Dataset
import pandas as pd
import json

from pathlib import Path
from collections import Counter
from torch.utils.data import DataLoader

Config.py

In [4]:
from easydict import EasyDict

args = EasyDict()

args.vocab_size = 50000
args.embed_dim = 128
args.hidden_dim = 256
args.batch_size = 8
args.trg_max_len = 50
args.learning_rate = 0.15
args.accum_init = 0.15
args.pad_id = 0
args.seed = 123
args.epochs = 10
args.max_grad_norm = 2.0

# vocab.py

In [17]:
'''
Reference https://github.com/jiminsun/pointer-generator/blob/master/data/vocab.py
'''
pad_token = '<pad>'
unk_token = '<unk>'
start_decode = '<start>'
stop_decode = '<stop>'

class Vocab(object):
  def __init__(self):
    self._word_to_id = {}
    self._id_to_word = []
    self._count = 0

  @classmethod
  def from_file(cls, filename):
    vocab = cls()
    with open(filename, 'r') as f:
      vocab._word_to_id = json.load(f)
    vocab._id_to_word = [w for w, id_ in sorted(vocab._word_to_id, key=vocab._word_to_id.get, reverse=True)]
    vocab._count = len(vocab._id_to_word)
    return vocab

  @classmethod
  def from_counter(cls, counter, vocab_size, min_freq=1, specials=[pad_token, unk_token, start_decode, stop_decode]):
    vocab = cls()
    word_and_freq = sorted(counter.items(), key=lambda tup: tup[0])
    word_and_freq.sort(key=lambda tup: tup[1], reverse=True)

    for w in specials:
      vocab._word_to_id[w] = vocab._count
      vocab._id_to_word.append(w)
      vocab._count += 1

    for word, freq in word_and_freq:
      if freq < min_freq or vocab._count == vocab_size:
        break
      vocab._word_to_id[word] = vocab._count
      vocab._id_to_word.append(word)
      vocab._count += 1
    
    return vocab
  
  def save(self, filename):
    with open(filename, 'w') as f:
      json.dump(self._word_to_id)
  
  def __len__(self):
    return self._count
  
  def unk(self):
    return self._word_to_id.get(unk_token)
  
  def pad(self):
    return self._word_to_id.get(pad_token)
  
  def start(self):
    return self._word_to_id.get(start_decode)
  
  def stop(self):
    return self._word_to_id.get(stop_decode)

  def word2id(self, word):
    unk_id = self._word_to_id.get(word, self.unk())
    if word in self._word_to_id:
      return self._word_to_id[word]
    else:
      return unk_id
  
  def id2word(self, word_id):
    if word_id >= self.__len__():
      raise ValueError(f"Id not found in vocab: {word_id}")
    return self.id_to_word[word_id]
  
  def extend(self, oovs):
    return self._id_to_word + list(oovs)
  
  def tokens2ids(self, tokens):
    return [self.word2id(t) for t in tokens]
  
  def tokens2ids_ext(self, tokens):
    ids = []
    oovs = []
    unk_id = self.unk()
    for t in tokens:
      t_id = self.word2id(t)
      if t_id == unk_id:
        if t not in oovs:
          oovs.append(t)
        ids.append(len(self) + oovs.index(t))
      else:
        ids.append(t_id)
    return ids, oovs

#model.py

## Encoder

In [18]:
"""
B : batch size
E : embedding size
H : encoder hidden state dimension
L : sequence length
T : target sequence length
"""

class Encoder(nn.Module):

    def __init__(self, input_dim=args.embed_dim, hidden_dim=args.hidden_dim):
        """
        Args:
            input_dim: source embedding dimension
        """
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=1, bidirectional=True, batch_first=True)
        self.reduce_h = nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
        self.reduce_c = nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
    
    def forward(self, src, src_lens):
        """
        Args:
            src: source token embeddings    [B x L x E]
            src_lens: source text length    [B]
        Returns:
            enc_hidden: sequence of encoder hidden states                  [B x L x 2H]
            (final_h, final_c): Tuple for decoder state initialization     [B x L x H]
        """

        x = pack_padded_sequence(src, src_lens.cpu(), batch_first=True, enforce_sorted=False)
        output, (h, c) = self.lstm(x) # [B x L x 2H], [2 x B x H], [2 x B x H]
        enc_hidden, _ = pad_packed_sequence(output, batch_first=True)

        # Concatenate bidirectional lstm states
        h = torch.cat((h[0], h[1]), dim=-1)  # [B x 2H]
        c = torch.cat((c[0], c[1]), dim=-1)  # [B x 2H]

        # Project to decoder hidden state size
        final_hidden = torch.relu(self.reduce_h(h))  # [B x H]
        final_cell = torch.relu(self.reduce_c(c))  # [B x H]

        return enc_hidden, (final_hidden, final_cell)

##Attention

In [19]:
class Attention(nn.Module):
    def __init__(self, hidden_dim=args.hidden_dim):
        super().__init__()
        self.v = nn.Linear(hidden_dim * 2, 1, bias=False)                       # v
        self.enc_proj = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias=False)   # W_h
        self.dec_proj = nn.Linear(hidden_dim, hidden_dim * 2, bias=True)        # W_s, b_attn
  

    def forward(self, dec_input, enc_hidden, enc_pad_mask):
        """
        Args:
            dec_input: decoder hidden state             [B x H]
            enc_hidden: encoder hidden states           [B x L x 2H]
            enc_pad_mask: encoder padding masks         [B x L]
        Returns:
            attn_dist: attention dist'n over src tokens [B x L]
        """
        enc_feature = self.enc_proj(enc_hidden)               # [B X L X 2H]
        dec_feature = self.dec_proj(dec_input).unsqueeze(1)   # [B X 1 X 2H]
        temp = torch.tanh(enc_feature + dec_feature)
        scores = self.v(torch.tanh(enc_feature + dec_feature)).squeeze(-1)  # [B X L]
        scores = scores.float().masked_fill_(
            enc_pad_mask,
            float('-inf')
        ).type_as(scores)  # FP16 support: cast to float and back
        
        attn_dist = F.softmax(scores, dim=-1) # [B X L]

        return attn_dist

##AttentionDecoderLayer

In [20]:
class AttentionDecoderLayer(nn.Module):
  def __init__(self, input_dim, hidden_dim, trg_vocab_size):
    super().__init__()
    self.lstm = nn.LSTMCell(input_size=input_dim, hidden_size=hidden_dim)
    self.attention = Attention(hidden_dim)
    self.l1 = nn.Linear(hidden_dim*3, hidden_dim, bias=True)    # V
    self.l2 = nn.Linear(hidden_dim, trg_vocab_size, bias=True)  # V'
  
  def forward(self, dec_input, dec_hidden, dec_cell, enc_hidden, enc_pad_mask):
    """
    Args:
        dec_input: decoder input embedding at timestep t    [B x E]
        prev_h: decoder hidden state from prev timestep     [B x H]
        prev_c: decoder cell state from prev timestep       [B x H]
        enc_hidden: encoder hidden states                   [B x L x 2H]
        enc_pad_mask: encoder masks for attn computation    [B x L]
    Returns:
        vocab_dist: predicted vocab dist'n at timestep t    [B x V]
        attn_dist: attention dist'n at timestep t           [B x L]
        context_vec: context vector at timestep t           [B x 2H]
        hidden: hidden state at timestep t                  [B x H]
        cell: cell state at timestep t                      [B x H]
    """
    h, c = self.lstm(dec_input, (dec_hidden, dec_cell))  # [B X H], [B X H]
    attn_dist = self.attention(h, enc_hidden, enc_pad_mask)  # [B X 1 X L]
    context_vec = torch.bmm(attn_dist.unsqueeze(1), enc_hidden).squeeze(1)  # [B X 2H] <- [B X 1 X 2H] = [B X 1 X L] @ [B X L X 2H]
    output = self.l1(torch.cat([h, context_vec], dim = -1)) # [B X H]
    vocab_dist = F.softmax(self.l2(output), dim=-1)              # [B X V]
    return vocab_dist, attn_dist, context_vec, h, c

## PointGenerator

In [21]:
class PointerGenerator(nn.Module):
    def __init__(self, src_vocab, trg_vocab):
        super().__init__()
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        embed_dim = args.embed_dim
        self.src_embedding = nn.Embedding(len(src_vocab), embed_dim, padding_idx=src_vocab.pad())
        self.trg_embedding = nn.Embedding(len(trg_vocab), embed_dim, padding_idx=trg_vocab.pad())


        hidden_dim = args.hidden_dim
        self.encoder = Encoder(input_dim=embed_dim, hidden_dim=hidden_dim)
        self.decoder = AttentionDecoderLayer(input_dim=embed_dim, hidden_dim=hidden_dim, trg_vocab_size=len(trg_vocab))

        self.w_h = nn.Linear(hidden_dim * 2, 1, bias=False)
        self.w_s = nn.Linear(hidden_dim, 1, bias=False)
        self.w_x = nn.Linear(embed_dim, 1, bias=True)


    def forward(self, enc_input, enc_input_ext, enc_pad_mask, enc_len, max_oov_len, dec_input=None):
        """
        Predict summary using reference summary as decoder inputs. If dec_input is not provided, then teacher forcing is disabled.
        Args:
            enc_input: source text id sequence                      [B x L]
            enc_input_ext: source text id seq w/ extended vocab     [B x L]
            enc_pad_mask: source text padding mask. [PAD] -> True   [B x L]
            enc_len: source text length                             [B]
            dec_input: target text id sequence                      [B x T]
            max_oov_len: max number of oovs in src                  [1]
        Returns:
            final_dists: predicted dist'n using extended vocab      [B x V_x x T]
            attn_dists: attn dist'n from each t                     [B x L x T]
            coverages: coverage vectors from each t                 [B x L x T]
        """
        batch_size = enc_input.size(0)
        enc_emb = self.src_embedding(enc_input)             # [B X L X E]
        enc_hidden, (h,c) = self.encoder(enc_emb, enc_len)  # [B X L X 2H], [B X L X H], [B X L X H]
        teacher_forcing = False

        if not dec_input is None:
            teacher_forcing = True
            dec_emb = self.trg_embedding(dec_input)             # [B X T X E]
        else:
            dec_prev_emb = [self.trg_embedding(self.trg_vocab().start()) for _ in range(batch_size)]  # [B X E]


        final_dists = []

        for t in range(args.trg_max_len):
            if teacher_forcing:
                input_t = dec_emb[:, t, :]
            else:
                input_t = dec_prev_emb
            vocab_dist, attn_dist, context_vec, h, c = self.decoder(
                dec_input=input_t, # [B x E]
                dec_hidden=h,
                dec_cell=c,
                enc_hidden=enc_hidden,
                enc_pad_mask=enc_pad_mask
            )
            p_gen = torch.sigmoid(self.w_h(context_vec) + self.w_s(h) + self.w_x(input_t))
            weighted_vocab_dist = p_gen * vocab_dist
            weighted_attn_dist = (1.0 - p_gen) * attn_dist
            B = vocab_dist.size(0)
            extended_vocab_dist = torch.cat([weighted_vocab_dist, torch.zeros(B, max_oov_len, device=vocab_dist.device)], dim=-1)
            final_dist = extended_vocab_dist.scatter_add(dim=-1, index=enc_input_ext, src=weighted_attn_dist) # [B X V_]
            final_dists.append(final_dist)
            if (not teacher_forcing):
                highest_prob = torch.argmax(final_dist, dim=1)                              # [B]
                highest_prob[highest_prob >= len(self.trg_vocab)] = self.trg_vocab.unk()
                dec_prev_emb = self.trg_embedding(B)        #[B X E]
        return final_dists

##SummarizationModel

In [22]:
class SummarizationModel(pl.LightningModule):
    def __init__(self, src_vocab, trg_vocab):
        super().__init__()
        self.vocab = trg_vocab
        self.model = PointerGenerator(src_vocab, trg_vocab)
        self.num_step = 0
    
    def training_step(self, batch, batch_idx):
        output = self.model.forward(
            enc_input=batch.enc_input,
            enc_input_ext=batch.enc_input_ext,
            enc_pad_mask=batch.enc_pad_mask,
            enc_len=batch.enc_len,
            dec_input=batch.dec_input,
            max_oov_len=batch.max_oov_len)
        
        dec_target = batch.dec_target
        loss = F.nll_loss(torch.log(output), dec_target, ignore_index=args.pad_id, reduction='mean')
        self.logger.log_metrics({"train_loss": loss}, self.num_step)
        self.num_step += 1
        return loss
    
    def validation_step(self, batch, batch_idx):
        output = self.model.forward(
            enc_input=batch.enc_input,
            enc_input_ext=batch.enc_input_ext,
            enc_pad_mask=batch.enc_pad_mask,
            enc_len=batch.enc_len,
            dec_input=batch.dec_input,
            max_oov_len=batch.max_oov_len)
        
        dec_target = batch.dec_target
        loss = F.nll_loss(
            torch.log(output), dec_target, ignore_index=args.pad_id, reduction='mean')
        self.log('val_loss', loss, on_step=True, on_epoch=False, prog_bar=False, logger=True)
        self.logger.log_metrics({'val_loss': loss}, self.num_step)
        return loss

    def test_step(self, batch, batch_idx):
        output = self.model.forward(
            enc_input=batch.enc_input,
            enc_input_ext=batch.enc_input_ext,
            enc_pad_mask=batch.enc_pad_mask,
            enc_len=batch.enc_len,
            max_oov_len=batch.max_oov_len
        )
        # TODO: FIXHERE
        result = {}
        result['target'] = output
        result['source'] = [' '.join(w) for w in batch.src_text]
        result['real_target'] = [' '.join(w) for w in batch.tgt_text]
        return result
    
    def configure_optimizers(self):
        return Adagrad(self.parameters(), lr=args.learning_rate, initial_accumulator_value=args.accum_init)



#data.py

##CommitDataset

In [23]:
class CommitDataset(Dataset):
    def __init__(self, src_vocab: Vocab, trg_vocab: Vocab, file_path):
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        self.df = pd.read_pickle(file_path)
    
    def __getitem__(self, index):
        item = EasyDict()
        row = self.df.iloc[index]
        src = json.loads(row["commit_messsage"])
        trg = json.loads(row["diff"])
        trg[0:0] = [self.trg_vocab.start()]
        trg.append(self.trg_vocab.stop())
        item.src_ids = self.src_vocab.tokens2ids(src)
        item.src_ids_ext, item.oovs = self.src_vocab.tokens2ids_ext(src)
        item.trg_ids = self.trg_vocab.tokens2ids(trg)
        
        return item
    
    def __len__(self):
        return len(self.df)

##CommitCollate

In [24]:
def CommitCollate(batchdata):
    size = len(batchdata)
    max_enc_len, max_dec_len,max_oov_len = 0,0,0
    enc_len_list = [len(batchdata[i]['src_ids']) for i in range(size)]
    for i in range(size):
        max_enc_len = max(len(batchdata[i]['src_ids']),max_enc_len)
        max_dec_len = max(len(batchdata[i]['trg_ids']),max_dec_len)
        max_oov_len = max(len(batchdata[i]['oovs']),max_oov_len)
    
    for i in range(len(batchdata)):
        batchdata[i]['src_ids'] += [0]*(max_enc_len-len(batchdata[i]['src_ids']))
        batchdata[i]['src_ids_ext'] += [0]*(max_enc_len-len(batchdata[i]['src_ids_ext']))
        batchdata[i]['trg_ids'] += [0]*(max_dec_len-len(batchdata[i]['trg_ids']))
    batch = EasyDict()
    batch.enc_input = torch.LongTensor([batchdata[i]['src_ids'] for i in range(size)])
    batch.enc_input_ext = torch.LongTensor([batchdata[i]['src_ids_ext'] for i in range(size)])
    batch.enc_pad_mask = (batch.enc_input == 0)
    batch.enc_len = torch.LongTensor(enc_len_list)
    batch.dec_input = torch.LongTensor([batchdata[i]['trg_ids'] for i in range(size)])
    batch.max_oov_len = max_oov_len
    return batch

# train.py

In [25]:
def train(root):
    pl.seed_everything(args.seed)

    src_counter = Counter()
    trg_counter = Counter()
    train_path = Path(root) / 'train.pkl'
    validation_path = Path(root) / 'validation.pkl'
    test_path = Path(root) / 'test.pkl'
    train_df = pd.read_pickle(train_path)

    for msg in train_df["diff"]:
        m = json.loads(msg)
        src_counter.update(m)

    for msg in train_df["commit_messsage"]:
        m = json.loads(msg)
        trg_counter.update(m)

    src_vocab = Vocab.from_counter(
        counter=src_counter, 
        vocab_size=args.vocab_size
    )

    trg_vocab = Vocab.from_counter(
        counter=trg_counter, 
        vocab_size=args.vocab_size
    )

    model = SummarizationModel(src_vocab, trg_vocab)

    trainer = pl.Trainer(
        gpus=torch.cuda.device_count(),
        max_epochs=args.epochs,
        gradient_clip_val=args.max_grad_norm
    )

    train_loader = DataLoader(
        CommitDataset(src_vocab, trg_vocab, train_path),
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=CommitCollate
    )
    val_loader = DataLoader(
        CommitDataset(src_vocab, trg_vocab, validation_path),
        batch_size=args.batch_size,
        collate_fn=CommitCollate,
        shuffle=False
    )

    trainer.fit(model, train_loader, val_loader)

train(root)

Global seed set to 123
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name  | Type             | Params
-------------------------------------------
0 | model | PointerGenerator | 26.3 M
-------------------------------------------
26.3 M    Trainable params
0         Non-trainable params
26.3 M    Total params
105.225   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

RuntimeError: ignored