In [5]:
!pip install torch
!pip install pytorch-crf



Input: Sequence of characters

Output: Sequence of characters including white spaces to indicating the subword tokenization


In [2]:
def get_words(filepath):
    # given the source file path, return all input words
    words = []
    chars = set()
    with open(filepath, encoding='utf-8') as file:
        lines = file.readlines()
        for line in lines:
            words.append(line.strip())
            chars.update(set(line.strip()))
    return words, chars

def tag_original_word(original_word, tokenized_word):
    # Initialize pointers for the original word and tags
    org_idx = tok_idx = 0
    tags = []
    tag = "B"
    # Iterate through the tokenized word
    while tok_idx < len(tokenized_word) and org_idx < len(original_word):
        # Skip characters not in the original word
        if tokenized_word[tok_idx] == original_word[org_idx]:
            tags.append(tag)
            if tag == "B":
                tag = "I"
            org_idx += 1
        elif tokenized_word[tok_idx] == " ":
            tag = "B"
            tok_idx += 1
        else:
            tok_idx += 1

    # Return the final tags as a string
    return "".join(tags)

# preparing labels for the dataset
# Using "B", "I" labeling, B for beginning of the word, I for inside the word
def get_labels(origional_words, tokenized_words):
    labels = []
    for o, t in zip(origional_words, tokenized_words):
        labels.append(tag_original_word(o, t))
    return labels

In [7]:
import string
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from TorchCRF import CRF

In [10]:
def encode(word, char_to_idx, is_input):
    if is_input:
        return [char_to_idx[ch] if ch in char_to_idx else char_to_idx['<unk>'] for ch in word]
    return [char_to_idx['<start>']] + [char_to_idx[ch] for ch in word] + [char_to_idx['<end>']]

def decode(encoded_word, idx_to_char):
    return ''.join(idx_to_char[idx] for idx in encoded_word if idx_to_char[idx] not in ['<pad>', '<unk>', '<start>', '<end>'])

def encode_whole(words, char_to_idx, is_input):
    words_encoded = []
    for word in words:
        encoded_word = encode(word, char_to_idx, is_input)
        words_encoded.append(encoded_word)
    
    return words_encoded

def decode_whole(encoded_words, idx_to_char):   
    decoded_words = []
    for encoded_word in encoded_words:
        decoded_word = decode(encoded_word, idx_to_char)
        decoded_words.append(decoded_word)
    
    return decoded_words

In [11]:
from collections import Counter
# labels = get_labels('dataset/shp.train.tgt')
input_words, input_chars = get_words('dataset/shp.train.src')
output_words, output_chars = get_words('dataset/shp.train.tgt')
print(input_words)
train_labels = get_labels(input_words, output_words)
print(train_labels)
print(Counter(''.join(train_labels)))

val_input_words, val_input_chars = get_words('dataset/shp.dev.src')
val_output_words, val_output_chars = get_words('dataset/shp.dev.tgt')

val_labels = get_labels(val_input_words, val_output_words)

test_words, test_chars = get_words('dataset/shp.test.src')
print(test_words)

# get char to index mapping and index to char mapping
chars = set(input_chars)
chars.update(set(output_chars))
chars.update(set(['<pad>', '<unk>', ' ']))  # include padding, unknown, and space

vocab_size = len(input_chars)

char_to_idx = {ch: idx for idx, ch in enumerate(chars)}
idx_to_char = {idx: ch for ch, idx in char_to_idx.items()}

label_to_idx = {'B': 0, 'I': 1, '<pad>': 2, '<start>': 3, '<end>': 4}
idx_to_label = {idx: label for label, idx in label_to_idx.items()}

encoded_origin = encode_whole(input_words, char_to_idx, is_input=True)
encoded_tokenized = encode_whole(train_labels, label_to_idx, is_input=False)
encoded_val = encode_whole(val_input_words, char_to_idx, is_input=True)
encoded_val_tokenized = encode_whole(val_labels, label_to_idx, is_input=False)

# shape: seq_num, max_len
padded_origin = pad_sequence([torch.tensor(word) for word in encoded_origin], batch_first=True, padding_value=char_to_idx['<pad>'])
padded_tokenized = pad_sequence([torch.tensor(tok) for tok in encoded_tokenized], batch_first=True, padding_value=label_to_idx['<pad>'])

padded_val = pad_sequence([torch.tensor(word) for word in encoded_val], batch_first=True, padding_value=char_to_idx['<pad>'])
padded_val_tokenized = pad_sequence([torch.tensor(tok) for tok in encoded_val_tokenized], batch_first=True, padding_value=label_to_idx['<pad>'])

['yoyoaibata', 'kotsatax', 'bokasai', 'ikainiki', 'kari', 'ibon', 'yoina', 'oi', 'rekena', 'orkai', 'xeki', 'janxbata', 'axea', 'bokanni', 'oinmayamai', 'jaskáribi', 'borosiko', 'maibo', 'netai', 'rayosanki', 'en', 'anibaon', 'beá', 'jakonshoko', 'yantampaketaitian', 'joá', 'jabekona', 'jainoaki', 'jotikoma', 'cepillonin', 'jaskáaki', 'quinua', 'jatíbi', 'jainxonribi', 'jeme', 'Chile', 'ichaya', 'jawekeskatinki', 'presidenteki', 'kini', 'noxa', 'potókinshaman', 'mana', 'nexaa', 'to', 'pia', 'moa', 'koinonin', 'jawékiboribi', 'jawetio', 'kirika', 'pikóribia', 'boi', 'escuelankoniax', 'ninkataxxonki', 'Martín', 'jonibaon', 'awakanki', 'xete', 'ikárin', 'sereya', 'jawen', 'yoiai', 'jawékioma', 'jawetii', 'cuadernonin', 'biá', 'raketi', 'bixon', 'merameax', 'banai', 'joniake', 'joninki', 'teekana', 'nato', 'joniani', 'manaman', 'aninkobo', 'xawi', 'xoike', 'sikita', 'yami', 'joaxxa', 'besoaxki', 'chankákaini', 'xenin', 'akaibo', 'paro', 'jawekeskataxki', 'jaskatai', 'keotaitianki', 'mapo',

In [12]:
class CharDataset(Dataset):
    def __init__(self, words, tokenized):
        self.words = words
        self.tokenized = tokenized

    def __len__(self):
        return len(self.words)

    def __getitem__(self, idx):
        word = self.words[idx]
        tokenized = self.tokenized[idx]
        # print(word, label)
        return word, tokenized

dataset = CharDataset(padded_origin, padded_tokenized)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

val_dataset = CharDataset(padded_val, padded_val_tokenized)
val_dataloader = DataLoader(val_dataset, batch_size=20, shuffle=True)



In [None]:
class TokenizationSeq2Seq(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size, output_size, num_layers=1):
        super(TokenizationSeq2Seq, self).__init__()
        
        # Embedding layer to convert characters into embeddings
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        # Encoder (e.g., LSTM or GRU)
        self.encoder = nn.LSTM(embed_dim, hidden_size, num_layers, batch_first=True)
        
        # Decoder (e.g., LSTM or GRU)
        self.decoder = nn.LSTM(embed_dim, hidden_size, num_layers, batch_first=True)
        
        # Fully connected layer to predict token labels (B, I, <pad>, <start>, <end>)
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, input_seq, target_seq=None, teacher_forcing_ratio=0.5):
        # Step 1: Embed the input sequence (character indices)
        embedded = self.embedding(input_seq)  # Shape: [batch_size, seq_len, embed_dim]
        
        # Step 2: Pass the embedded sequence through the encoder (LSTM)
        encoder_outputs, (hidden, cell) = self.encoder(embedded)  # Shape: [batch_size, seq_len, hidden_size]
        
        # Step 3: Initialize the decoder output
        batch_size = input_seq.size(0)
        target_len = target_seq.size(1) if target_seq is not None else input_seq.size(1)
        outputs = torch.zeros(batch_size, target_len, self.fc.out_features).to(input_seq.device)
        
        # First input to the decoder is the <start> token (or use teacher forcing)
        decoder_input = target_seq[:, 0] if target_seq is not None else torch.ones(batch_size).long().to(input_seq.device)

        for t in range(1, target_len):
            # Step 4: Embed the decoder input (which is either target token or predicted token)
            decoder_embedded = self.embedding(decoder_input).unsqueeze(1)
            
            # Step 5: Pass the embedded decoder input through the decoder (LSTM)
            decoder_output, (hidden, cell) = self.decoder(decoder_embedded, (hidden, cell))
            
            # Step 6: Apply the fully connected layer to get label predictions
            output = self.fc(decoder_output.squeeze(1))  # Shape: [batch_size, output_size]
            outputs[:, t, :] = output
            
            # Step 7: Decide whether to use teacher forcing
            # if target_seq is not None and torch.rand(1).item() < teacher_forcing_ratio:
            #     decoder_input = target_seq[:, t]  # Teacher forcing
            # else:
            decoder_input = output.argmax(1)  # Predicted token (most likely label)

        return outputs


def train(model, dataloader, optimizer, criterion, num_epochs):
    best_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0

        for input_seq, target_seq in dataloader:
            
            optimizer.zero_grad()
            output = model(input_seq, target_seq)
            # Reshape to [batch_size * seq_len, vocab_size]
            output = output.view(-1, output.size(-1))
            target_seq = target_seq.view(-1)
            
            loss = criterion(output, target_seq)
            epoch_loss += loss.item()

            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")

        # Evaluate on validation set
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for input_seq, target_seq in val_dataloader:
                input_seq = input_seq.to(torch.device('cpu'))
                target_seq = target_seq.to(torch.device('cpu'))

                output = model(input_seq, target_seq)
                output = output.view(-1, output.size(-1))
                target_seq = target_seq.view(-1)

                loss = criterion(output, target_seq)
                val_loss += loss.item()

        print(f"Validation Loss: {val_loss:.4f}")

        # Save the best model
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), "best_seq2seq_model.pth")
            print(f"Best model saved with validation loss: {best_loss:.4f}")

# Assign weights inversely proportional to the class frequency
weight_I = 1.0 / 3815   # for shp.train.tgt
weight_B = 1.0 / 1479

# weight_I = 1.0 / 3280  # for tar.train.tgt
# weight_B = 1.0 / 1504

# Normalize weights to ensure that they sum to 1 (optional, but helpful)
total_weight = weight_I + weight_B
weight_I /= total_weight
weight_B /= total_weight
class_weights = torch.tensor([weight_B, weight_I, 0.0, 0.0, 0.0])  # B -> 0, I -> 1

vocab_size = len(char_to_idx)
output_size = len(label_to_idx)
model = TokenizationSeq2Seq(vocab_size, 50, 100, output_size)
criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=char_to_idx['<pad>'])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 15

train(model, dataloader, optimizer, criterion, num_epochs)


In [159]:
def predict(model, input_word, char_to_idx, label_to_idx, max_len, device='cpu'):
    model.eval()  # Set the model to evaluation mode

    # Encode the input word into a sequence of indices
    input_seq = [char_to_idx[char] for char in input_word if char in char_to_idx] # Skip unknown characters
    input_seq = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0).to(device)  # Add batch dimension

    # Initialize the target sequence with <start> token
    target_seq = torch.tensor([label_to_idx['<start>']], dtype=torch.long).unsqueeze(0).to(device)

    with torch.no_grad():
        for _ in range(max_len):
            # Pass input and target sequence to the model
            output = model(input_seq, target_seq)

            # Get the index of the most probable token at the last time step
            next_token = output[:, -1, :].argmax(dim=-1)

            # Append the predicted token to the target sequence
            target_seq = torch.cat([target_seq, next_token.unsqueeze(0)], dim=1)

            # Stop if <end> token is predicted
            if next_token.item() == label_to_idx['<end>']:
                break

    # Decode the target sequence indices back into characters
    decoded_tokens = ''.join([idx_to_label[idx.item()] for idx in target_seq[0][1:]])  # Skip <start> token
    return decoded_tokens.replace('<end>', '').strip()

model.load_state_dict(torch.load("best_seq2seq_model.pth"))

# Predict tokenization for a test word
prediction = []
for word in test_words:
    predicted = predict(model, word, char_to_idx, label_to_idx, max_len=len(word))
    prediction.append(predicted)

def decode_tag(words, tags):
    decoded = []
    for word, tag in zip(words, tags):
        decoded_word = []
        for w, t in zip(word, tag):
            if t == 'B':
                decoded_word.append(' ')
            decoded_word.append(w)
        decoded.append(''.join(decoded_word))
    return decoded

decoded = decode_tag(test_words, prediction)
print(test_words)
print(prediction)
print(decoded)

['jainoax', 'yakata', 'tekíbo', 'raankana', 'soanon', 'oinberibai', 'miaki', 'potani', 'Ecuador', 'jaskataxki', 'jawékiakin', 'alemanra', 'boo', 'kenai', 'bukeya', 'mecha', 'jakonbires', 'tama', 'ensalada', 'mekayaokea', 'jaweranoki', 'matsi', 'iwana', 'beneki', 'payari', 'mayatai', 'isinkonama', 'keskáakin', 'chama', 'teetai', 'parana', 'kakinki', 'benawe', 'ani', 'ni', 'iká', 'bakeribi', 'Santa', 'titan', 'Mananxawe', 'Huayna', 'iketianki', 'jai', 'España', 'xoxoi', 'paranribia', 'jaskáaxonki', 'Colón', 'apendicitisya', 'joto', 'ati', 'Yoáshiko', 'jaweki', 'yoyo', 'koshibirestani', 'chibanban', 'rinko', 'japaonike', 'kenawe', 'moatianronki', 'iikinki', 'kokana', 'tee', 'jakon', 'karíbaparibanon', 'ichaira', 'axébiribiki', 'ikáki', 'jikiamapainon', 'bikanai', 'akasai', 'maton', 'rikan', 'ewa', 'chibanai', 'inonbires', 'noia', 'toota', 'bueno', 'keská', 'maxó', 'chapatabo', 'yoxanshoko', 'kaxonki', 'ibonra', 'shino', 'iso', 'kanana', 'aa', 'Tomki', 'jatíribi', 'galletabo', 'kairin', 'j

In [160]:
# write the decoded words to a file with name pred_shp.test.tgt
with open('pred_shp.test.tgt', 'w') as file:
    for word in decoded:
        file.write(word + '\n')

In [161]:
def run_eval(golds, preds):
   tp = 0
   fp = 0
   fn = 0
   for g, p in zip(golds, preds):
      g_bag = g.strip().split(" ")
      p_bag = p.strip().split(" ")
      tp += sum([1 for i in p_bag if i in g_bag])
      fp += sum([1 for i in p_bag if not i in g_bag])
      fn += sum([1 for i in g_bag if not i in p_bag])
   precision = tp / (tp + fp)
   recall = tp / (tp + fn)
   if precision == 0 or recall == 0:
      f1 = 0
   else:
      f1 = 2 / ((1/precision) + (1/recall))
   return f1

f1 = run_eval(val_output_words, decoded)
print(f1)

0.0044742729306487695
