### Imports

In [8]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence

# %run train.ipynb
%run train_arabic.ipynb

['ث', 'م']
[' ', 'َّ']


### Validation

In [9]:
def validate(model, val_dataset, val_labels, batch_size=BATCH_SIZE, pad_index=None):
    """
    Evaluate a trained model on a validation set and report accuracy and DER.

    Inputs:
    - model: the trained model; called as ``model(inputs)`` and its output is
      argmax-ed over dim 2, so it is assumed to return per-position class scores
      shaped (batch, seq_len, num_classes) — TODO confirm against RNN definition
    - val_dataset: tensor of (padded) input index sequences, one row per example
    - val_labels: tensor of target class indices, assumed to have the same shape
      as val_dataset
    - batch_size: integer representing the number of examples per step
    - pad_index: optional input index that marks padding. When given, positions
      where ``val_dataset == pad_index`` are excluded from the accuracy (and thus
      from the DER). When None (the default) every position, padding included,
      is counted, which reproduces the historical behaviour.

    Returns:
    - the validation accuracy as a float; DER is printed as ``1 - accuracy``

    Raises:
    - ValueError: if there are no positions to evaluate (empty set, or every
      position is padding)
    """
    # (1) create the dataloader for the validation set (shuffle=False keeps order stable)
    tensor_val_dataset = TensorDataset(val_dataset, val_labels)
    val_dataloader = DataLoader(tensor_val_dataset, batch_size=batch_size, shuffle=False)

    # GPU configuration; .to() is a no-op on CPU
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = model.to(device)

    # Switch to inference mode (e.g. disables dropout), but remember the caller's
    # mode so the model is handed back in the state it arrived in.
    was_training = model.training
    model.eval()

    total_correct = 0
    total_counted = 0

    try:
        with torch.no_grad():
            for val_input, val_label in tqdm(val_dataloader):
                val_input = val_input.to(device)
                val_label = val_label.to(device)

                # Predicted class per position
                predictions = model(val_input).argmax(dim=2)
                correct = predictions == val_label

                if pad_index is None:
                    # Count every position, including padding
                    total_correct += correct.sum().item()
                    total_counted += correct.numel()
                else:
                    # Only score real (non-padding) positions
                    mask = val_input != pad_index
                    total_correct += (correct & mask).sum().item()
                    total_counted += mask.sum().item()
    finally:
        model.train(was_training)

    if total_counted == 0:
        raise ValueError("No positions to evaluate: validation set is empty or fully padded")

    # Calculate metrics for the entire validation set
    val_accuracy = total_correct / total_counted

    print(f'Validation Accuracy: {val_accuracy} | DER: {1 - val_accuracy}\n')

    return val_accuracy


### Prepare the data

In [10]:
# Read the validation corpus and split every sentence into per-word character
# lists and the matching per-word diacritic lists.
valid_corpus = readFile(VAL_PATH)

X_val = []
y_val = []

for sentence in valid_corpus:
    # Strip surrounding whitespace, then separate each word's characters from its diacritics
    char_list, diacritics_list = separate_words_and_diacritics(sentence.strip())

    X_val.append(char_list)
    y_val.append(diacritics_list)

# Flatten sentences into words: each word becomes one example (a sequence of indices).
X_val_words = [
    torch.tensor([char_to_index[char] for char in word])
    for sentence in X_val
    for word in sentence
]
y_val_words = [
    torch.tensor([diacritic_to_index[diacritic] for diacritic in word])
    for sentence in y_val
    for word in sentence
]

# pad_sequence right-pads every word to the longest one (padding_value defaults to 0).
X_val_padded = pad_sequence(X_val_words, batch_first=True)
y_val_padded = pad_sequence(y_val_words, batch_first=True)

# Inputs and labels are compared position-by-position during validation, so
# their shapes must agree; fail fast here with a readable message.
assert X_val_padded.shape == y_val_padded.shape, (
    f"input/label shape mismatch: {tuple(X_val_padded.shape)} vs {tuple(y_val_padded.shape)}"
)

[tensor([1, 7, 3, 3]), tensor([1, 1, 0]), tensor([3, 7, 1, 3]), tensor([5, 1, 0, 1, 3, 3]), tensor([0, 7, 1, 7, 3]), tensor([0, 0, 9, 0, 5, 3]), tensor([1, 0, 0, 9, 1, 0, 3, 0, 1]), tensor([1, 7, 1]), tensor([1, 0, 5, 1, 5]), tensor([ 1,  1, 11,  5]), tensor([0, 7, 3, 7, 5]), tensor([1, 1, 0]), tensor([1, 1, 5, 5]), tensor([3, 0, 1]), tensor([1, 7, 5, 5]), tensor([1, 7]), tensor([1, 7, 5, 5]), tensor([3, 0, 1]), tensor([1, 1, 5, 5]), tensor([1, 1, 7, 1]), tensor([1, 0, 5, 1, 5]), tensor([ 1,  1, 11,  5,  5]), tensor([1, 1, 0]), tensor([0, 0, 9, 1, 5]), tensor([1, 0, 0, 9, 7, 5]), tensor([1, 5, 0, 2, 0]), tensor([1, 1, 7, 5, 0, 3, 3]), tensor([1, 9]), tensor([0, 7, 3, 7, 1]), tensor([0, 7]), tensor([1, 0, 1]), tensor([1, 3]), tensor([1, 1, 4]), tensor([5, 1, 7, 5]), tensor([1, 7, 6]), tensor([1, 1, 1, 9, 1]), tensor([1, 1, 7, 5]), tensor([1, 0]), tensor([3, 7, 1, 1, 3]), tensor([1, 7]), tensor([1, 0, 1]), tensor([1, 3]), tensor([1, 1, 1, 0, 5]), tensor([1, 7]), tensor([1, 7, 1, 0, 4]), 

### Load the model

In [11]:
# Load the saved RNN model for inference.
# The embedding gets one row more than the character vocabulary (presumably
# index 0 is reserved for padding — TODO confirm against train_arabic.ipynb).
loaded_rnn_model = RNN(len(unique_characters) + 1, len(unique_diacritics))

# map_location="cpu" lets a checkpoint saved on a GPU be restored on a CPU-only
# machine; the model can be moved to the GPU afterwards if one is available.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
rnn_state_dict = torch.load(RNN_PATH, map_location="cpu")
loaded_rnn_model.load_state_dict(rnn_state_dict)
loaded_rnn_model.eval()

loaded_rnn_model

RNN(
  (embedding): Embedding(39, 200)
  (lstm): LSTM(200, 512, batch_first=True, bidirectional=True)
  (linear): Linear(in_features=1024, out_features=15, bias=True)
)


In [12]:
# Score the reloaded model on the padded validation words (prints accuracy and DER)
validate(model=loaded_rnn_model, val_dataset=X_val_padded, val_labels=y_val_padded)

100%|██████████| 415/415 [00:29<00:00, 13.97it/s]

Validation Accuracy: 0.9563625049926032 | DER: 0.043637495007396776




