<a href="https://colab.research.google.com/github/LeoLionel/komma/blob/main/komma.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup & dependencies

In [None]:
# Download the training data to the colab local memory
!gdown --id 1IAxYMM2dIdx3_HcwyQkfBSboBOzDrOKa
!gdown --id 1TtcC9X6NBly4JS26E-1pAt9rkOHoXz0R

In [None]:
data_folder = '/content/'
save_folder = '/content/'

In [None]:
import time
import pickle
import json
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
torch.manual_seed(123)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using ', device)

Using  cuda


# Dataset class

When instantiated, the `SentenceDataV2` class loads the data from two provided files. The first contains the initial word vectors (fastText) together with the index-to-word and word-to-index dictionaries. The second contains a numpy array with the input sequences and the output labels for those sequences.

In [None]:
from torch.utils.data import Dataset

class SentenceDataV2(Dataset):
    """Dataset of padded (input sequence, label sequence) pairs.

    Two files are read on construction:

    * ``pickle_path``: a pickle holding ``(wordvecs, ix_to_word, word_to_ix)``
      where ``wordvecs`` is a numpy array of shape
      (num of vectors, dim of vectors) and the other two are dictionaries.
    * ``sentence_data_path``: a numpy array of shape (n, 2, m) with
      n xy pairs, x the input sequence, y the sequence labels and
      m the maximum sequence length (x & y are padded).

    ``part`` selects ``'train'`` (the first ``split[0]`` pairs) or
    ``'validation'`` (the last ``split[1]`` pairs); anything else raises
    ``ValueError``.
    """

    def __init__(self, pickle_path, sentence_data_path, part = 'train', 
                 split = (4_500_000, 160_000)):

        # NOTE(review): pickle.load can execute arbitrary code, only open
        # files from a trusted source here.
        with open(pickle_path, 'rb') as f:
            wordvecs, ix_to_word, word_to_ix = pickle.load(f)

        xy_pairs = np.load(sentence_data_path)

        if part == 'train':
            self.data = xy_pairs[:split[0]]
        elif part == 'validation':
            self.data = xy_pairs[-split[1]:]
        else:
            raise ValueError('Choose "train" or "validation" as "part" for the dataset')

        self.wordvecs = wordvecs
        self.ix_to_word = ix_to_word
        self.word_to_ix = word_to_ix

        # Indices of the output (label) classes. These are independent of
        # the vocabulary indices stored in word_to_ix.
        self.output_class_ix = {'<pad>':      0,
                                '<eos>':      1,
                                '<comma>':    2,
                                '<no_comma>': 3,
                                }
        # The weights are computed over all pairs in the file, not only
        # over the selected part.
        self.output_class_weights = get_output_class_weights(xy_pairs)

    def __getitem__(self, index):
        """Return ``(x, y, length)`` with the padding stripped from x and y."""
        x = self.data[index][0]
        y = self.data[index][1]
        # The length is measured on the label sequence, so the *label* pad
        # index has to be used, not the pad index of the vocabulary.
        length = int(np.count_nonzero(y != self.output_class_ix['<pad>']))
        return x[:length], y[:length], length

    def __len__(self):
        """Number of xy pairs in the selected part."""
        return len(self.data)

def get_output_class_weights(xy_pairs):
    """Calculate the normalized inverse frequency of the three non-pad
    output labels, needed for the weighted cross-entropy loss.

    `xy_pairs` is indexed as xy_pairs[:, 1, :] to obtain the label
    sequences (see the class SentenceDataV2 for its layout).

    Returns a list of four weights. Index 0 (padding) is always 0, the
    other three sum to 1. A label that never occurs gets weight 0 instead
    of triggering a division by zero."""
    ys = xy_pairs[:, 1, :]
    counts = [np.count_nonzero(ys == i) for i in range(4)]
    inverse = [1 / c if c else 0.0 for c in counts[1:]]
    total = sum(inverse)
    if total == 0:
        # No non-pad label at all: nothing to weight
        return [0, 0.0, 0.0, 0.0]
    return [0] + [x / total for x in inverse]

def num_commas (data_set: Dataset) -> int:
    "Return how many sentences in data_set carry at least one comma label"
    target = data_set.output_class_ix['<comma>']
    total = 0
    # Every item of the data set is a triple (x, y, length); only the
    # label sequence y matters here.
    for _, labels, _ in data_set:
        total = total + np.any(labels == target)
    return total

In [None]:
class SentenceData(Dataset):
    """Dataset class for the older, smaller data sets, which were stored
    on disk as a single pickle holding
    (xy_pairs, wordvecs, ix_to_word, word_to_ix).

    `part` picks the first split[0] pairs ('train'), the last split[1]
    pairs ('validation') or every pair whose labels contain the comma
    index ('just comma')."""

    def __init__(self, load_path, part = 'train', split = (500_000, 30_000)):

        with open(load_path, 'rb') as handle:
            xy_pairs, wordvecs, ix_to_word, word_to_ix = pickle.load(handle)

        # self.data is a list of pairs of numpy arrays in every case
        if part == 'just comma':
            self.data = [(x, y) for x, y in xy_pairs
                         if np.any(y == word_to_ix['<comma>'])]
        elif part == 'validation':
            self.data = xy_pairs[-split[1]:]
        elif part == 'train':
            self.data = xy_pairs[:split[0]]
        else:
            raise ValueError('Choose "train", "validation" or "just comma" as "part" for the dataset')

        self.word_to_ix = word_to_ix  # dict
        self.ix_to_word = ix_to_word  # dict
        self.wordvecs = wordvecs.astype(np.float32) # numpy array

    def __getitem__(self, index):
        "Return the pair at `index` together with the length of its labels"
        pair = self.data[index]
        inputs, labels = pair
        return inputs, labels, len(labels)

    def __len__(self):
        "Number of pairs in the selected part"
        return len(self.data)


In [None]:
from torch.utils.data import DataLoader

# The custom collate function to be used with the DataLoader class
def pad_and_collate(triples):
    """Turn a list of (x, y, length) triples into one padded batch.

    x and y are padded with zeros up to the longest length in the batch.
    Returns (xs, ys, ls): two int64 tensors of shape
    (batch size, longest length) and the list of the original lengths.

    The tensors are always int64, independent of the dtype of the stored
    arrays and of whether any padding was necessary, so that they can be
    used as embedding indices and as cross-entropy targets."""
    xs, ys, ls = zip(*triples)
    max_len = max(ls)
    # Stack to a single numpy array first: calling torch.tensor on a list
    # of numpy arrays converts element by element and is very slow.
    xs = np.stack([pad(x, max_len) for x in xs])
    ys = np.stack([pad(y, max_len) for y in ys])
    ls = list(ls)
    return (torch.tensor(xs, dtype = torch.long),
            torch.tensor(ys, dtype = torch.long),
            ls)

def pad(xs: np.ndarray, n: int) -> np.ndarray:
    """Pad a numpy array with zeros up to length n.

    The dtype of `xs` is kept. If `xs` already has n or more elements it
    is returned unchanged (the same object, no copy)."""
    m = len(xs)
    if m >= n:
        return xs
    xs = np.asarray(xs)
    return np.concatenate([xs, np.zeros(n - m, dtype = xs.dtype)])

# Comma Position Model
A simple sequence-tagging model using a bidirectional RNN. It takes a batch of (padded) sequences together with their lengths and returns the activations of the last linear layer. For prediction, softmax still needs to be applied to the outputs; for loss calculation, the cross-entropy loss applies it internally.

In [None]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class CommaPositionModel(nn.Module): 
    """Sequence tagger: embedding -> optional linear shrink layer ->
    bidirectional GRU -> linear output layer with four classes.

    Arguments:
        wordvecs: initial word vectors of shape (num of vectors, dim of
            vectors). They are copied, converted to float32 and trained
            further (freeze = False). Index 0 is the padding index.
        shrink_emb_size: if truthy, size of a linear layer that reduces
            the embedding dimension before the RNN; otherwise no such
            layer is used.
        rnn_layers: number of stacked GRU layers.
        rnn_hidden_size: hidden size of the GRU per direction.

    Attributes:
        class_weights: optional per-class weights for the loss. Assign a
            tensor (or list) of `out_classes` values after construction;
            `loss_fn` reads the current value on every call.
    """

    def __init__(self, wordvecs, shrink_emb_size, rnn_layers, rnn_hidden_size):
  
        super().__init__()
        self.out_classes = 4 
        # the four prediction classes are:
        # 0: dummy padding class (ignored in loss calculation)
        # 1: eos after this sequence token
        # 2: comma after this sequence token
        # 3: no comma after this sequence token

        # The weights for the classes depend on the training data, and are used
        # for the cross entropy loss. Initialize accordingly!
        # Kept as a plain attribute (not a buffer), so it is not part of
        # the state dict and saved weights stay loadable.
        self.class_weights = None  

        # Explicit float32: float64 word vectors would not match the
        # float32 parameters of the following layers.
        self.embedding = nn.Embedding.from_pretrained(
            torch.tensor(wordvecs, dtype = torch.float32),
            freeze = False,
            padding_idx = 0)
        wordvec_size = wordvecs.shape[1]
        if shrink_emb_size:
            self.shrink_emb = nn.Linear(in_features = wordvec_size,
                                        out_features = shrink_emb_size,
                                        bias = True)
        else: 
            self.shrink_emb = None
        
        rnn_in_size = shrink_emb_size if shrink_emb_size else wordvec_size
        self.rnn = nn.GRU(num_layers = rnn_layers,
                          input_size = rnn_in_size,
                          hidden_size = rnn_hidden_size,
                          bidirectional = True,
                          batch_first = True)
        
        self.out_layer = nn.Linear(in_features = 2 * rnn_hidden_size,
                                   out_features = self.out_classes,
                                   bias = True) 

    def loss_fn(self, pred, ys):
        """Weighted cross entropy between the logits `pred` of shape
        (N, out_classes) and the class indices `ys` of shape (N,).

        Targets equal to 0 (padding) are ignored. The weights are taken
        from `self.class_weights` at call time. A loss module created in
        __init__ would keep the weights it was constructed with and never
        see a later assignment to `class_weights`."""
        weight = self.class_weights
        if weight is not None:
            weight = torch.as_tensor(weight, dtype = pred.dtype,
                                     device = pred.device)
        return nn.functional.cross_entropy(pred, ys, weight = weight,
                                           ignore_index = 0)
                                                    
    def forward(self, xs, ls):
        """Return the logits of shape (batch, padded length, out_classes).

        xs is a batch of padded index sequences, ls the lengths of the
        unpadded sequences."""
        padded_len = xs.size(1)
        xs = self.embedding(xs)
        if self.shrink_emb is not None:
            xs = self.shrink_emb(xs)
        xs = pack_padded_sequence(xs, ls, batch_first = True,
                                  enforce_sorted = False)
            
        # Calculate the output of the RNN. If h_0 is not provided,
        # it defaults to zero. (Or h_0 and c_0 in case of an LSTM.)
        xs, _ = self.rnn(xs)

        # Unpack and drop the sequence lengths. total_length keeps the
        # time dimension equal to the one of the input batch, even if the
        # batch is padded beyond its longest sequence.
        xs, _ = pad_packed_sequence(xs, batch_first = True,
                                    total_length = padded_len)

        # Output layer
        return self.out_layer(xs)

In [None]:
def batch_loss (xs, ys, ls, model):
    """Return the loss of `model` on one batch.

    `xs` are the padded input sequences, `ys` the matching output labels
    and `ls` the sequence lengths. The model output and the labels are
    flattened over batch and time before `model.loss_fn` is applied."""
    logits = model(xs, ls)
    flat_logits = logits.reshape(-1, model.out_classes)
    flat_labels = ys.reshape(-1)
    return model.loss_fn(flat_logits, flat_labels)

# Validation loss, test accuracy and training loop

Some helper functions to calculate the loss over the validation set and the accuracy of the model during training time.

In [None]:
# Batch size used whenever the model is only evaluated (no gradients)
PRED_BATCH_SIZE = 1_000

def get_validation_loss(validation_set, model):
    """Calculate the loss over the whole validation set.

    Returns the mean of the batch losses, each batch weighted by its
    number of sequences, so that a smaller final batch does not count as
    much as a full one. Raises ValueError if the set is empty.

    The set is traversed in order and completely: no shuffling (the
    result is deterministic) and no dropped last batch (sets smaller than
    PRED_BATCH_SIZE work as well)."""
    loader = DataLoader(validation_set, batch_size = PRED_BATCH_SIZE,
                        shuffle = False, drop_last = False, 
                        collate_fn = pad_and_collate)       
    loss_sum = 0.0
    num_sequences = 0
    with torch.no_grad():
        for xs, ys, ls in loader:
            xs = xs.to(device)
            ys = ys.to(device)
            loss_sum += batch_loss(xs, ys, ls, model).item() * len(ls)
            num_sequences += len(ls)
    if num_sequences == 0:
        raise ValueError('Cannot calculate the loss over an empty validation set')
    return loss_sum / num_sequences


def get_accuracy(data_set, model):
    """Calculate the accuracy of the model on the training or validation set.

    A sequence counts as correct only if every non-pad position is
    classified correctly. Returns three percentages:
    (all sequences, sequences with comma, sequences without comma).
    A group without any sequence yields 0 instead of a division by zero."""
    loader = DataLoader(data_set, batch_size = PRED_BATCH_SIZE, 
                        shuffle = False, collate_fn = pad_and_collate) 
    pad_ix = data_set.output_class_ix['<pad>']  
    comma_ix = data_set.output_class_ix['<comma>']
    n = 0 # num of sequences with each word classified correct 
    k = 0 # num of sequences with comma in data_set
    m = 0 # num of sequences with comma where each word is classified correct 
    with torch.no_grad():
        for xs, ys, ls in loader:
            xs = xs.to(device)
            ys = ys.to(device)

            pred = model(xs, ls)
            pred = nn.functional.softmax(pred, dim = 2)
            pred = torch.argmax(pred, dim = 2)
            # Set the predictions at padded positions to the pad index (0),
            # so that they always agree with the padded labels
            mask = (ys != pad_ix) 
            pred = pred * mask

            predicted_right = torch.all(ys == pred, dim = 1)
            n += torch.sum(predicted_right).item()

            with_comma = torch.any(ys == comma_ix, dim = 1)
            k += torch.sum(with_comma).item()

            m += torch.sum(predicted_right[with_comma]).item()

    total = len(data_set)
    # Fraction of all sequences where each word is classified correct  
    p1 = n / total if total else 0.0
    # Fraction of sequences with comma where each word is classified correct 
    p2 = m / k if k else 0.0
    # Fraction of sequences without comma where each word is classified correct
    p3 = (n - m ) / (total - k + 1e-8)
    return p1 * 100, p2 * 100, p3 * 100

In [None]:
def print_run_summary(model, optimizer, dataset_name, tr_set, val_set, 
                      out_file=None):
    """Print a summary of the model, dataset and relevant global variables.

    The summary goes to `out_file` if one is given, otherwise to standard
    output. Reads the globals SHRINK_EMB_SIZE, HIDDEN_SIZE, RNN_LAYERS and
    BATCH_SIZE."""
    def mprint(*s):
        # print falls back to sys.stdout when file is None
        print(*s, file=out_file if out_file else None)

    # UTC shifted by a fixed offset of two hours
    t = time.gmtime(time.time() + 2*60*60)
    date_str = time.strftime("%Y-%m-%d %H:%M:%S", t)
    mprint(date_str,'\n')
    mprint('Dataset name:', dataset_name)
    ls = len(tr_set), len(val_set)
    mprint('Training / validation set size: {} / {}'.format(*ls))
    mprint('word vector size:', tr_set.wordvecs.shape[1])
    mprint() 

    mprint('SHRINK_EMB_SIZE =', SHRINK_EMB_SIZE)
    mprint('HIDDEN_SIZE =', HIDDEN_SIZE)
    mprint('RNN_LAYERS =', RNN_LAYERS)
    
    mprint(model)
    if model.class_weights is None:
        # The model is created without class weights
        mprint('class weights:', None)
    else:
        # detach/cpu: the weights may live on the GPU, where a direct
        # conversion to numpy is not possible
        raw = torch.as_tensor(model.class_weights).detach().cpu().tolist()
        weights = [round(x, 2) for x in raw]
        mprint('class weights:', weights)
    mprint()

    mprint(optimizer)
    mprint('BATCH_SIZE =', BATCH_SIZE)
    mprint()

In [None]:
def train(model, optimizer, training_set, validation_set, epochs: int):  
    """ Train the model for num of `epochs`. Loss, accuracy and model weights 
    are saved to global variables or to disk.

    Every 100 batches the mean training loss of these batches and the
    validation loss are appended to the global lists TRAIN_LOSS and
    VALID_LOSS. After each epoch the accuracies on both sets are appended
    to TRAIN_ACCUR and VALID_ACCUR. The weights are written to
    SAVE_WEIGHTS_PATH after the first epoch and after every later epoch
    that reaches the best 'with comma' validation accuracy of this call.
    Reads the globals BATCH_SIZE and device. """
    train_loader = DataLoader(training_set, batch_size = BATCH_SIZE, 
                              shuffle = True, collate_fn = pad_and_collate)  
    print('start training')
    tick = time.time()
    best_comma_accuracy = None
    for ep_ix in range(1, epochs + 1):
        # Start each epoch with an empty sum. Otherwise the batches after
        # the last full hundred of an epoch leak into the first average
        # of the next epoch.
        loss_sum = 0
        for batch_ix, (xs, ys, ls) in enumerate(train_loader, start = 1):
            xs = xs.to(device)
            ys = ys.to(device)
            optimizer.zero_grad()    
            loss = batch_loss(xs, ys, ls, model)
            loss_sum += loss.item()
            # Backpropagation  
            loss.backward()
            optimizer.step()

            if batch_ix % 100 == 0:
                tr_loss = loss_sum / 100
                loss_sum = 0
                val_loss = get_validation_loss(validation_set, model)
                TRAIN_LOSS.append(tr_loss)
                VALID_LOSS.append(val_loss)
                loss_info = tr_loss, val_loss, batch_ix*BATCH_SIZE, len(training_set)
                print("loss: {0:.3f} {1:.3f} [{2}/{3}]".format(*loss_info))

        # Time since the start of the training, not of this epoch
        tock = time.time()
        print('Epoch {0} finished after {1:.1f}s'.format(ep_ix, tock-tick))
        print('Test accuracies (all / with comma / without comma):')  
        ps_t = get_accuracy(training_set, model)
        print( 'train: {0:.2f}% / {1:.2f}% / {2:.2f}%'.format(*ps_t) )   
        ps_v = get_accuracy(validation_set, model)  
        print( 'valid: {0:.2f}% / {1:.2f}% / {2:.2f}%'.format(*ps_v) ) 
        
        # Compare with the best epoch so far. Comparing with the previous
        # epoch only could overwrite better weights with worse ones.
        if best_comma_accuracy is None or ps_v[1] > best_comma_accuracy:
            torch.save(model.state_dict(), SAVE_WEIGHTS_PATH )
            best_comma_accuracy = ps_v[1]
        TRAIN_ACCUR.append(ps_t)
        VALID_ACCUR.append(ps_v)

# Load data, initialize model & train

In [None]:
# Both files of a data set share the data set name as prefix
dataset_name = 'dataset-v2-all-ml35-unk2-4660k'
ds_path1 = data_folder + dataset_name + '-wv_dicts.pickle' 
ds_path2 = data_folder + dataset_name + '-xy_pairs.npy' 

# (number of training pairs, number of validation pairs)
ds_split = (4_500, 2_000) # Choose (4_500_000, 160_000) to use all data!

# Each instance reads both files again and keeps its own part of the pairs
training_set   = SentenceDataV2(ds_path1, ds_path2, part = 'train', split = ds_split)
validation_set = SentenceDataV2(ds_path1, ds_path2, part = 'validation', split = ds_split)

In [None]:
# Model hyperparameters. These globals are also read by print_run_summary.
SHRINK_EMB_SIZE = None  # None: feed the word vectors to the RNN unchanged
RNN_LAYERS = 3
HIDDEN_SIZE = 1200

# Optimizer settings. BATCH_SIZE is read as a global by train().
LR = 5e-4
BATCH_SIZE = 180

model = CommaPositionModel(wordvecs = training_set.wordvecs, 
                           shrink_emb_size = SHRINK_EMB_SIZE,
                           rnn_layers = RNN_LAYERS,
                           rnn_hidden_size = HIDDEN_SIZE                         
                           ).to(device)

# Class weights as calculated by the data set
model.class_weights = torch.tensor(training_set.output_class_weights) 

optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=0)

In [None]:
print_run_summary(model, optimizer, dataset_name, training_set, validation_set, out_file=None)

In [None]:
# Prefix for all files written by this run
run_name = 'test'

epochs = 3

# Save the run configuration next to the results
with open(save_folder + run_name + '-info.txt', 'w') as f:
    print_run_summary(model, optimizer, dataset_name, training_set, 
                      validation_set, out_file=f)

# Global result lists, filled by train()
TRAIN_LOSS = []
VALID_LOSS = []
TRAIN_ACCUR = []
VALID_ACCUR = []

# train() writes the model weights to this path
SAVE_WEIGHTS_PATH = save_folder + run_name + '-model.weights'

train(model, optimizer, training_set, validation_set, epochs)

# Store the loss and accuracy curves as json
with open(save_folder + run_name + '-loss_accuracies.txt', 'w') as f:
    out = TRAIN_LOSS, VALID_LOSS, TRAIN_ACCUR, VALID_ACCUR 
    json.dump(out, f, indent=2)

# Prediction and Evaluation code


In [None]:
# download pretrained weights to colab local storage
!gdown --id 1-04Nl-w3EjJo_tILb4N9sqJR5Eh2fRVQ

# NOTE(review): this file name contains '4460k' while dataset_name above
# contains '4660k' -- confirm that the weights belong to this data set.
weights_path = save_folder + 'run-21-all-ml35-unk2-4460k-model.weights'
# map_location loads the tensors directly onto the device in use
saved_weights = torch.load(weights_path, map_location=device)
model.load_state_dict(saved_weights)

In [None]:
def play(sentence: str, model: "CommaPositionModel", word_to_ix) -> str:
    """Let the model predict the commas for a given input sentence.

    The sentence is split at whitespace. Words missing in `word_to_ix`
    are replaced by the index of '<unk>'. Returns the words joined by
    single spaces, with ' ,' or ' <eos>' appended to a word where the
    model predicts a comma or the end of the sentence after it. An empty
    sentence yields an empty string.

    `word_to_ix` is only read, never modified."""
    # split() without argument: repeated spaces do not create empty words
    words = sentence.split()
    if not words:
        return ''

    # Look up with an explicit fallback instead of installing a
    # default_factory on the caller's dictionary: that fails for a plain
    # dict and makes a defaultdict grow with every unknown word.
    unk_ix = word_to_ix['<unk>']
    seq = [word_to_ix.get(w, unk_ix) for w in words]

    # Run on whatever device the model lives on
    model_device = next(model.parameters()).device
    x = torch.tensor(seq).reshape(1,-1).to(model_device)
    l = torch.tensor([len(seq)])
    with torch.no_grad():
        pred = model(x, l)
        pred = nn.functional.softmax(pred, dim=2)
        pred = torch.argmax(pred, dim=2)
        pred = pred.reshape(-1).cpu().tolist()

    # The indices of the output classes are: 3: no comma, 2: comma, 1: eos, 0: pad
    # See documentation in SentenceDataV2, or CommaPositionModel class
    # A predicted pad class adds nothing to the word.
    d = {3: '', 2: ' ,', 1: ' <eos>', 0: ''}
    return ' '.join(word + d[ix] for word, ix in zip(words, pred))

In [None]:
play('Miriam schaut abends erst die Nachrichten weil sie sich informieren will später ihre Lieblingsserie', model, training_set.word_to_ix)

'Miriam schaut abends erst die Nachrichten , weil sie sich informieren will , später ihre Lieblingsserie <eos>'

In [None]:
# try this:
s = 'Das System erkennt die Sprache schnell und automatisch konvertiert die\
 Wörter in die gewünschte Sprache und versucht die jeweiligen sprachlichen\
 Nuancen und Ausdrücke hinzuzufügen'
play(s, model, training_set.word_to_ix)

In [None]:
def get_detail_error(data_set, model, max_len = 35 + 1):
    """Calculate the error of the model separately by sentence length 
    and comma count of the sentence.

    `max_len` is the size of both axes of the result and has to be larger
    than the longest sequence in the data set (default: 35 + 1).

    Returns (p, n), two tensors of shape (max_len, max_len):
    n[i,j] is the number of sequences with length i and j commas in the
    data set, p[i,j] the percentage of those sequences with at least one
    wrongly classified position. p[i,j] is nan where n[i,j] is zero."""
    loader = DataLoader(data_set, 
                        batch_size = PRED_BATCH_SIZE, 
                        shuffle = False, 
                        collate_fn = pad_and_collate)
    
    comma_ix = data_set.output_class_ix['<comma>']
    pad_ix   = data_set.output_class_ix['<pad>']

    # n[i,j]: number of sequences with length i and j commas in the data set
    n = torch.zeros(max_len, max_len, device = device)

    # m[i,j]: number of sequences with length i and j commas predicted wrong
    m = torch.zeros(max_len, max_len, device = device)

    with torch.no_grad(): 
        for xs, ys, ls in loader:
            xs = xs.to(device)
            ys = ys.to(device)
            
            pred = model(xs, ls)
            pred = nn.functional.softmax(pred, dim = 2)
            pred = torch.argmax(pred, dim = 2)
            # Predictions at padded positions are set to the pad index (0)
            mask = (ys != pad_ix)
            pred = pred * mask

            predicted_wrong = torch.any(ys != pred, dim = 1)

            ls = torch.tensor(ls, device = device)
            cs = torch.count_nonzero(ys == comma_ix, dim = 1)
            ones = torch.ones(len(ls), device = device)

            # Vectorized version of
            #
            # for i in range(len(pred)):
            #     n[ls[i], cs[i]] += 1
            #     if predicted_wrong[i]:
            #         m[ls[i], cs[i]] += 1
            #
            # accumulate = True adds up the values of repeated (length,
            # commas) index pairs instead of keeping only one of them.
            n.index_put_((ls, cs), ones, accumulate = True)
            if predicted_wrong.any():
                m.index_put_((ls[predicted_wrong], cs[predicted_wrong]),
                             ones[predicted_wrong], accumulate = True)

    p = m / n * 100
    return p, n 

In [None]:
ds_split = (4_500, 160_000) # Ensure to use all of the validation set now
validation_set = SentenceDataV2(ds_path1, ds_path2, part = 'validation', split = ds_split)

In [None]:
p, n = get_detail_error(validation_set, model)
# p[i, j] gives the error rate on sentences of length i with j commas
# n[i, j] gives the amount of sentences of length i with j commas in the validation set
p[1:, :7]

tensor([[     nan,      nan,      nan,      nan,      nan,      nan,      nan],
        [  0.0000,  50.0000,      nan,      nan,      nan,      nan,      nan],
        [  0.6468,  47.6190,  42.8571,      nan,      nan,      nan,      nan],
        [  1.0688,  38.6667,  75.0000,   0.0000,      nan,      nan,      nan],
        [  1.0198,  37.5000,  57.1429, 100.0000,      nan,      nan,      nan],
        [  1.3619,  25.1429,  65.0000,  50.0000,      nan,      nan,      nan],
        [  1.7312,  21.8009,  57.1429,   0.0000,   0.0000,   0.0000,      nan],
        [  2.3057,  18.9838,  42.2414,  46.1538,      nan, 100.0000,      nan],
        [  2.7925,  16.0595,  38.0711,  50.0000,   0.0000, 100.0000,      nan],
        [  3.4471,  13.9373,  32.6154,  44.0000,  33.3333,   0.0000,   0.0000],
        [  4.3956,  12.1481,  32.2581,  46.5116,  50.0000,      nan,      nan],
        [  5.4476,  11.4348,  26.6212,  50.0000,  71.4286,   0.0000,   0.0000],
        [  5.7821,  11.7992,  26.5278,  

In [None]:
from sklearn.metrics import precision_recall_fscore_support

def get_precision_recall(data_set, model):
    """Calculate precision, recall and f1 for the model on a data set.

    Returns three numpy arrays (precision, recall, f1), each with one
    entry per non-pad output class, ordered by class index.

    The scores are computed once over all non-pad positions of the whole
    data set. Averaging per-batch scores is not the same quantity, and it
    fails as soon as one batch does not contain every class. Raises
    ValueError if the data set is empty."""
    loader = DataLoader(data_set, batch_size = PRED_BATCH_SIZE, 
                        shuffle = False, drop_last = False,
                        collate_fn = pad_and_collate)
    pad_ix = data_set.output_class_ix['<pad>'] 
    # Fix the reported classes, so that the result always has one entry
    # per non-pad class, whatever the model predicts
    labels = sorted(ix for ix in data_set.output_class_ix.values()
                    if ix != pad_ix)

    all_true = []
    all_pred = []
    with torch.no_grad(): 
        for xs, ys, ls in loader:
            xs = xs.to(device)
            ys = ys.to(device)
            
            pred = model(xs, ls)
            pred = nn.functional.softmax(pred, dim = 2)
            pred = torch.argmax(pred, dim = 2)

            # get rid of padding tokens, this reshapes the tensors to 1D
            mask = (ys != pad_ix)
            all_pred.append(pred[mask].cpu().numpy())
            all_true.append(ys[mask].cpu().numpy())

    if not all_true:
        raise ValueError('Cannot calculate precision and recall on an empty data set')

    pr, re, f, _ = precision_recall_fscore_support(np.concatenate(all_true),
                                                   np.concatenate(all_pred),
                                                   labels = labels)
    return pr, re, f

In [None]:
# precision, recall and f1 for the three classes: 'eos', 'comma', 'word'
get_precision_recall(validation_set, model)

(array([1.        , 0.89568644, 0.99417314]),
 array([1.        , 0.88764191, 0.99463348]),
 array([1.        , 0.89155284, 0.99440299]))