In [5]:
import codecs
from datetime import datetime
import json
from pathlib import Path
import os
import glob
import numpy as np
import torch 
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, DistributedSampler, random_split
from torch.nn.utils import clip_grad_norm_
!pip install torch-summary
from torchsummary import summary
import torch.multiprocessing as mp
from transformers import AutoTokenizer
from tqdm import tqdm
from multiprocessing import Pool, cpu_count

from data_loading_utils import read_lines_from_file_as_data_chunks
import time  # Import the time module
import threading
from concurrent.futures import ThreadPoolExecutor



In [6]:
class WPDataset(Dataset):
    """
    A class loading clean text from txt files to be used as an input 
    to PyTorch DataLoader.

    Datapoints are sequences of words (tokenized) + label (next token). If the 
    words have not been seen before (i.e, they are not found in the
    'word_to_id' dict), they will be mapped to the unknown word '<UNK>'.
    chunk_size: how much we read from the file at the time - we could play around with it. 
    """
    def __init__(self, filenames, tokenizer, samples_length=5, chunk_size=1000000, artificial_padding=True):
        self.sequences = [] # X
        self.labels = [] # Y 
        self.tokenizer = tokenizer
        self.samples_length = samples_length
        self.artificial_padding = artificial_padding
        self.pad_token_id = tokenizer.pad_token_id  # Get the PAD token ID = 0 
        
        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = [executor.submit(self.read_file, filename, chunk_size) for filename in filenames]
            for future in futures:
                future.result()  # Ensure all files are processed
        # Convert lists to numpy arrays for faster access and better memory management
        self.sequences = np.array(self.sequences)
        self.labels = np.array(self.labels)

    def read_file(self, filename, chunk_size):
        print("Read in ", filename)
        start_time = time.time()
        try:
            read_lines_from_file_as_data_chunks(filename, chunk_size, self.process_lines)
        except FileNotFoundError:
            print(f"File not found: {filename}")
        except Exception as e:
            print(f"An error occurred: {e}")
        end_time = time.time()  # End the timer
        print(f"Time taken to read {filename}: {end_time - start_time:.2f} seconds")

    def process_lines(self, data, eof, file_name):
        """
        eof: end of file 
        Callback function to process lines read from file.
        """
        if not eof:
            text = data.strip()  # Remove leading/trailing whitespace
            # split sentence into sub-sentences so that it can be passed to tokenizer, which has a max capacity of 512 
            line_chunks = self.split_into_chunks(text) 
            for chunk in line_chunks:
                line_tokens = self.tokenizer.tokenize(chunk) # data is already lower case 
                line_tokens_ids = self.tokenizer.convert_tokens_to_ids(line_tokens)
                self.create_sequences(line_tokens_ids)
        else:
            print(f"Finished reading file: {file_name}")

    def split_into_chunks(self, line, max_length=512):
        """Splits a long line into chunks of max_length tokens."""
        return [line[i:i + max_length] for i in range(0, len(line), max_length)]

    def create_sequences(self, token_ids):
        """
        Create sequences and labels from tokenized text.
        """
        n = self.samples_length
        if self.artificial_padding:
            k = 0 
            while k < len(token_ids) - n:
                for i in range(1, n + 1):
                    seq = token_ids[k:i+k] + [self.pad_token_id] * (n - i)
                    label = token_ids[i + k]
                    self.sequences.append(seq)
                    self.labels.append(label)
                k += n
            remaining_tokens = len(token_ids) - k
            if remaining_tokens > 1:
                for i in range(1, remaining_tokens):
                    seq = token_ids[k:i+k] + [self.pad_token_id] * (n - i)
                    label = token_ids[i + k]
                    self.sequences.append(seq)
                    self.labels.append(label)     
        else: 
            # Ensure all sequences are of length samples_length
            for i in range(self.samples_length, len(token_ids)): # sliding window 
                seq = token_ids[i-self.samples_length:i]
                label = token_ids[i]
                self.sequences.append(seq)
                self.labels.append(label)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return torch.tensor(self.sequences[idx]), torch.tensor(self.labels[idx])

In [7]:
class RNN(nn.Module):
    """
    There are two possible ways to write this class; either it tries to predict 
    a whole word that consists of several tokens or it only predicts the next token
    after a fixed (or variable) amount of input tokens; 
    Another choice is whether to use a hidden state or not as an input to the forward pass
    Or do a encoder - decoder structure?

    I read somewhere that it is good to ... 
    """
    def __init__(self, embedding_size, hidden_size, no_of_output_symbols, device, num_layers, use_GRU, dropout):
        super().__init__()
        self.no_of_output_symbols = no_of_output_symbols
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.use_GRU = use_GRU
        self.dropout = dropout

        # initialize layers
        self.embedding = nn.Embedding(no_of_output_symbols, embedding_size)
        if use_GRU == True:
            self.rnn = nn.GRU(embedding_size, hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout)
        else:
            self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout)
        self.output = nn.Linear( hidden_size, no_of_output_symbols )
        self.device = device
        self.to(device)

    def forward(self, x, hidden):
        """
        x is a list of lists of size (batch_size, max_seq_length)
        Each inner list contains word IDs and represents one datapoint (n words).
       
        Returns:
        the output from the RNN: logits for the predicted next word, hidden state
        """
        x_emb = self.embedding(x) # x_emb shape: (batch_size, max_seq_length, emb_dim)
        if self.use_GRU:
            output, hidden = self.rnn(x_emb, hidden) # output shape: (batch_size, max_seq_length, hidden)
        else:
            output, (h_n, c_n) = self.rnn(x_emb, hidden)  # LSTM expects a tuple (hidden state, cell state)
            hidden = (h_n, c_n)
            
        return self.output(output[:, -1, :]), hidden # logit shape: (batch_size, 1, vocab_size)
    
 

In [8]:
   
def pad_sequence(batch, pad_symbol): #=tokenizer.pad_token):
    """
    Applies padding if the number of tokens in sequences differs within one batch.
    Only applies padding to the sequence, not the label.
    """
    seq, label = zip(*batch)
    max_seq_len = max(map(len, seq))
    max_label_len = max(map(len, label))
    padded_seq = [[b[i] if i < len(b) else pad_symbol for i in range(max_seq_len)] for b in seq]
    padded_label = [[l[i] if i < len(l) else pad_symbol for i in range(max_label_len)] for l in label]
    return padded_seq, padded_label

In [9]:
def evaluate(dataloader, rnn_model, device):
    correct, incorrect = 0,0
    hidden = None
    for seq, label in dataloader:
        sequence, label = seq.to(device), label.to(device)
        prediction, _ = rnn_model(sequence, hidden)
        _, predicted_tensor = prediction.topk(1)

        
        assert (label.shape == predicted_tensor.squeeze(1).shape)
        comparison = torch.eq(label, predicted_tensor.squeeze(1))
        count_same_entries = torch.sum(comparison).item()
        count_same_entries = (label == predicted_tensor.squeeze(1)).sum().item()
        
        correct += count_same_entries
        incorrect += label.shape[0] - count_same_entries

    print( "Correctly predicted words    : ", correct )
    print( "Incorrectly predicted words  : ", incorrect )
    print( "Accuracy                     : ", correct / (correct + incorrect))
    
    return correct / (correct + incorrect)

In [10]:


# ================ Hyper-parameters ================ #

batch_size = 64
embedding_size = 50 #16
hidden_size = 64 #25
num_layers = 2
seq_length = 5      # number of tokens used as a datapoint
learning_rate = 0.001
epochs = 2
num_processes = 4
use_GRU = True
dropout = 0.5


# ====================== Data ===================== #

# select files with text for training (will also be used for test and validation dataset)
txt_files = ["articles.txt", "news_summarization.txt"]
filenames = ['data/clean_data/news_summarization.txt', 'data/clean_data/twitter.txt'] #'data/clean_data/mobile_text.txt',  'data/clean_data/articles.txt' 'data/clean_data/news_summarization.txt', , 'data/clean_data/twitter.txt' 

# choose tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')


# ==================== Training ==================== #
# Reproducibility
np.random.seed(5719)

device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
print( "Running on", device )




# set up dataloaders
dataset = WPDataset(filenames=filenames, tokenizer=tokenizer, samples_length=seq_length)


generator = torch.Generator().manual_seed(42)
training_data, validation_data, test_data = random_split(dataset, [0.8, 0.05, 0.15], generator=generator)

print( "There are", len(training_data), " training datapoints and ", tokenizer.vocab_size, "unique tokens in the dataset" ) 
val_dataloader = DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4, shuffle=True)
train_dataloader = DataLoader(training_data, batch_size=batch_size, drop_last=True, num_workers=4, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, drop_last=True, num_workers=4, shuffle=True)


rnn_model = RNN(embedding_size, hidden_size, no_of_output_symbols=tokenizer.vocab_size, device=device, num_layers=num_layers, use_GRU=use_GRU, dropout=dropout).to(device)
optimizer = optim.Adam(rnn_model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
summary(rnn_model)

# Check if a checkpoint exists
checkpoint_path = 'model_checkpoint_twitter_first.pth'

if os.path.exists(checkpoint_path):
    # Load the saved model state
    rnn_model.load_state_dict(torch.load(checkpoint_path))

#checkpoint_path = 'model_checkpoint.pth'

rnn_model.train()

prev_accuracy = 0

for epoch in range(epochs):
    total_loss = 0
    hidden = None
    with tqdm(train_dataloader, desc="Epoch {}".format(epoch + 1)) as tepoch:
        for sequence, label in tepoch:
            sequence, label = sequence.to(device), label.to(device)
            optimizer.zero_grad()
            logits, hidden = rnn_model(sequence, hidden)
            if use_GRU:
                hidden = hidden.detach()  # Detach hidden states to avoid backprop through the entire sequence
            else: 
                hidden = tuple([h.detach() for h in hidden])    
            loss = criterion(logits.squeeze(), label)
            loss.backward()
            
            clip_grad_norm_(rnn_model.parameters(), 5)
            optimizer.step()
            total_loss += loss.item()
    print("Epoch", epoch, "loss:", total_loss )
    total_loss = 0
    torch.save(rnn_model, checkpoint_path)
    print("Evaluating on the validation data...")
    accuracy = evaluate(val_dataloader, rnn_model, device)
    if accuracy > prev_accuracy:
        prev_accuracy = accuracy
        torch.save(rnn_model, 'best_model_so_far_twitter_first.pth')


# Save the model state
torch.save(rnn_model.state_dict(), checkpoint_path)
evaluate(val_dataloader, rnn_model, device)

# ==================== Evaluation ==================== #

rnn_model.eval()
print( "Evaluating on the test data..." )

print( "Number of test sentences: ", len(test_dataloader) )
print()

test_accuracy = evaluate(test_dataloader, rnn_model, device)

# ==================== Save the model  ==================== #

dt = str(datetime.now()).replace(' ','_').replace(':','_').replace('.','_')
newdir = 'model_' + dt
os.mkdir( newdir )
#torch.save( rnn_model.state_dict(), os.path.join(newdir, 'rnn.model') )
torch.save( rnn_model, os.path.join(newdir, 'rnn.model') )

settings = {
    'epochs': epochs,
    'learning_rate': learning_rate,
    'batch_size': batch_size,
    'hidden_size': hidden_size,
    'embedding_size': embedding_size,
    'num_layers': num_layers,
    'dropout': dropout,
    'use_GRU': use_GRU,
    'test_accuracy': test_accuracy
}
with open( os.path.join(newdir, 'settings.json'), 'w' ) as f:
    json.dump(settings, f)

s = f"accuracy: {test_accuracy}, epochs: {epochs}, num_layers: {num_layers}, use_GRU: {use_GRU}, dropout: {dropout}, embedding_size: {embedding_size}, hidden_size: {hidden_size}, batch_size: {batch_size}, learning_rate: {learning_rate}"
with open("experiments.txt", 'a') as f:
    f.write(s + '\n')  


Running on cuda
Read in  data/clean_data/news_summarization.txt
Read in  data/clean_data/twitter.txt
Finished reading file: data/clean_data/news_summarization.txt
Time taken to read data/clean_data/news_summarization.txt: 231.04 seconds
Finished reading file: data/clean_data/twitter.txt
Time taken to read data/clean_data/twitter.txt: 481.05 seconds
There are 139495661  training datapoints and  28996 unique tokens in the dataset
Layer (type:depth-idx)                   Param #
├─Embedding: 1-1                         1,449,800
├─GRU: 1-2                               47,232
├─Linear: 1-3                            1,884,740
Total params: 3,381,772
Trainable params: 3,381,772
Non-trainable params: 0


Epoch 1: 100%|██████████| 2179619/2179619 [1:36:45<00:00, 375.42it/s]  

Epoch 0 loss: 13053926.910126686
Evaluating on the validation data...





Correctly predicted words    :  1294340
Incorrectly predicted words  :  7424124
Accuracy                     :  0.14845963692687153


Epoch 2: 100%|██████████| 2179619/2179619 [1:35:04<00:00, 382.07it/s]  

Epoch 1 loss: 12997308.090759754
Evaluating on the validation data...





Correctly predicted words    :  1272151
Incorrectly predicted words  :  7446313
Accuracy                     :  0.14591457853126422
Correctly predicted words    :  1272061
Incorrectly predicted words  :  7446403
Accuracy                     :  0.1459042556119977
Evaluating on the test data...
Number of test sentences:  408678

Correctly predicted words    :  4067826
Incorrectly predicted words  :  22087566
Accuracy                     :  0.15552533106749078


In [11]:
torch.save(validation_data.indices, 'val_indices_twitter_first.pt')
torch.save(training_data.indices, 'train_indices_twitter_first.pt')
torch.save(test_data.indices, 'test_indices_twitter_first.pt')

train_indices = torch.load('train_indices_twitter_first.pt')
val_indices = torch.load('val_indices_twitter_first.pt')
test_indices = torch.load('test_indices_twitter_first.pt')

train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)
test_dataset = torch.utils.data.Subset(dataset, test_indices)


In [None]:
class RNNpredictor:
    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def filter_vocab_by_prefix(self, vocab, prefix):
        if prefix == None:
            return vocab
        return {token: idx for token, idx in vocab.items() if token.startswith(prefix)}

    def mask_logits_by_vocab(self, logits, filtered_vocab):
        mask = torch.full_like(logits, float('-inf'))
        for token, idx in filtered_vocab.items():
            mask[idx] = logits[idx]
        return mask
    
    def remove_last_word(self, input_string):
        last_space_index = input_string.rfind(' ')
        if last_space_index == -1:
            return input_string, None
        else:
            return input_string[:last_space_index], input_string[last_space_index+1:]

    def predict_next_word(self, prompt, number_of_suggestions, max_subwords=5):
        self.model.eval()

        vocab = self.tokenizer.get_vocab()
        hidden = None

        # remove last word from prompt (word that is supposed to be predicted)
        prompt, prefix = self.remove_last_word(prompt)
        tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
        input_ids = torch.tensor(tokens).unsqueeze(0).to(self.device)  # Add batch dimension
        
        next_words = []

        for i in range(number_of_suggestions):
            generated_subwords = []
            for _ in range(max_subwords):
                with torch.no_grad():
                    outputs, hidden = self.model(input_ids, hidden)
                    next_token_logits = outputs.squeeze()  # Get the logits for the last token
                    next_token_id = torch.argmax(next_token_logits, dim=-1).item()  # Get the ID of the highest probability token
                    next_token_ids = next_token_logits.topk(5).indices.tolist()

                if len(generated_subwords) == 0:
                    # filter by prefix
                    filtered_vocab = self.filter_vocab_by_prefix(vocab, prefix)
                    # Mask the logits based on the filtered vocabulary
                    masked_logits = self.mask_logits_by_vocab(next_token_logits, filtered_vocab)
                    # Normalize the masked logits to get probabilities
                    probs = torch.softmax(masked_logits, dim=-1)
                    #next_token_id = torch.argmax(probs, dim=-1).item()
                    next_token_id = probs.topk(5).indices.tolist()[i]
                else: 
                    next_token_id = next_token_logits.topk(5).indices.tolist()[i]


                # Decode the generated subwords so far
                subword_text = self.tokenizer.decode([next_token_id], clean_up_tokenization_spaces=True)
                # Check if the last token completes a word
                if not subword_text.startswith("##") and len(generated_subwords) > 0:  # Check if it's not a continuation of a word
                    break

                generated_subwords.append(next_token_id)
                input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]]).to(self.device)], dim=1).to(self.device)  # Append the predicted token to the input

            # Decode the generated subwords to form the next word
            next_word = self.tokenizer.decode(generated_subwords, clean_up_tokenization_spaces=True).strip()
            next_words.append(next_word)

        return next_words

In [None]:
batch_size = 64
embedding_size = 50 #16
hidden_size = 64 #25
seq_length = 5      # number of tokens used as a datapoint
learning_rate = 0.001
epochs = 2
num_layers = 2
num_processes = 4
GRU = False
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
print( "Running on", device )

#rnn_model = RNN(embedding_size, hidden_size, no_of_output_symbols=tokenizer.vocab_size, device=device, num_layers=num_layers, use_GRU = use_GRU).to(device)
#rnn_model.load_state_dict(torch.load('model_2024-05-28_06_59_20_927544/rnn.model'))    
rnn_model = torch.load('LSTM_large/rnn.model')
rnn = RNNpredictor(rnn_model, tokenizer, device)

rnn.predict_next_word("on top of the world there are name", 5)
rnn.predict_next_word("I am from Ge", 5)

In [None]:
tokenizer.pad_token