In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import spacy
import random

In [2]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Pretrained checkpoint identifiers handed to AutoTokenizer.from_pretrained,
# one per language of the translation pair.
SW_CHECKPOINT = "eolang/SW-v1"
EN_CHECKPOINT = "bert-base-uncased"

# Tokenizers used by the tokenize_sw / tokenize_en helpers.
sw_tok = AutoTokenizer.from_pretrained(SW_CHECKPOINT)
en_tok = AutoTokenizer.from_pretrained(EN_CHECKPOINT)

In [6]:
# Tokenization helpers: thin wrappers around the pretrained English and Swahili tokenizers
def tokenize_en(text):
    """Tokenize English ``text`` with the module-level ``en_tok`` tokenizer.

    Args:
        text: The string to tokenize.

    Returns:
        The result of ``en_tok.tokenize(text)``.
    """
    english_tokens = en_tok.tokenize(text)
    return english_tokens

def tokenize_sw(text):
    """Tokenize Swahili ``text`` with the module-level ``sw_tok`` tokenizer.

    Args:
        text: The string to tokenize.

    Returns:
        The result of ``sw_tok.tokenize(text)``.
    """
    swahili_tokens = sw_tok.tokenize(text)
    return swahili_tokens

In [8]:
# Options shared by both language fields: lowercasing is enabled and the
# "<sos>" / "<eos>" strings are used as the init and end-of-sequence tokens.
# Only the tokenizer callable differs between the two fields.
_common_field_options = {
    "lower": True,
    "init_token": "<sos>",
    "eos_token": "<eos>",
}

swahili = Field(tokenize=tokenize_sw, **_common_field_options)

english = Field(tokenize=tokenize_en, **_common_field_options)
                

In [10]:
from datasets import load_dataset, DatasetDict

# Load the "en-sw" configuration of the "opus_paracrawl" dataset.
dataset = load_dataset(path="opus_paracrawl", name="en-sw")

Found cached dataset opus_paracrawl (/Users/olang/.cache/huggingface/datasets/opus_paracrawl/en-sw/9.0.0/96d4b9607c5750673f19a3ddf6d424eb57830a93183fd3929fc66ae2a8ebd52c)


  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'translation'],
        num_rows: 132520
    })
})

In [12]:
# Create train, test, and validation splits (80% / 10% / 10%).

# Fixed seed so that re-running the notebook reproduces the same split
# membership; without it train_test_split shuffles differently on every run.
SPLIT_SEED = 42

# Hold out 20% of the corpus; the remaining 80% becomes the training set.
train_testval = dataset["train"].train_test_split(test_size=0.2, seed=SPLIT_SEED)

# Split the 20% hold-out in half: 10% test, 10% validation.
test_valid = train_testval['test'].train_test_split(test_size=0.5, seed=SPLIT_SEED)

# Gather the three splits into a single DatasetDict.
ds = DatasetDict({
    'train': train_testval['train'],
    'test': test_valid['test'],
    'valid': test_valid['train']})

In [13]:
# Display the DatasetDict to inspect the features and row count of each split.
ds

DatasetDict({
    train: Dataset({
        features: ['id', 'translation'],
        num_rows: 106016
    })
    test: Dataset({
        features: ['id', 'translation'],
        num_rows: 13252
    })
    valid: Dataset({
        features: ['id', 'translation'],
        num_rows: 13252
    })
})

In [14]:
# Bind each split of `ds` to its own name for convenient access.
train_data = ds["train"]
valid_data = ds["valid"]
test_data = ds["test"]