In [16]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchtext
from torch.utils.data import DataLoader, random_split
from torchtext.legacy.data import Field, BucketIterator, Dataset, Example, TabularDataset

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import spacy
import numpy as np
import pandas as pd

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [66]:
class TranslationDataset:
  def __init__(self, filepath):
    self.spacy_en = spacy.load('en_core_web_sm')
    self.spacy_de = spacy.load('de_core_news_sm')
    self.SRC = Field(tokenize=self.tokenize_de, 
                init_token='<sos>', 
                eos_token='<eos>', 
                lower=True, 
                batch_first=True)

    self.TRG = Field(tokenize=self.tokenize_en, 
                init_token='<sos>', 
                eos_token='<eos>', 
                lower=True, 
                batch_first=True)
    self.dataset = TabularDataset(
        path=filepath, format='tsv',
        fields=[('de', self.SRC), ('en', self.TRG)]
    )
    
  def __len__(self):
    length = len(self.dataset)
    return length

  def tokenize_de(self, text):
    return [tok.text for tok in self.spacy_de.tokenizer(text)]

  def tokenize_en(self, text):
    return [tok.text for tok in self.spacy_en.tokenizer(text)]
  
  def __getitem__(self, idx):
    en = self.dataset[idx].en
    de = self.dataset[idx].de
    return de, en

  def createDataset(self, batch_size=32):
    train_data, valid_data, test_data = self.dataset.split(split_ratio=[0.8, 0.1, 0.1])
    self.SRC.build_vocab(train_data, min_freq = 3)
    self.TRG.build_vocab(train_data, min_freq = 3)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.train_iterator, self.valid_iterator, self.test_iterator = BucketIterator.splits((train_data, valid_data, test_data), batch_size=batch_size, device=device)

    return  self.train_iterator, self.valid_iterator, self.test_iterator


In [67]:
dataset = TranslationDataset('/content/drive/MyDrive/cleaned_dataset.tsv')

In [68]:
dataset.createDataset()

(<torchtext.legacy.data.iterator.BucketIterator at 0x7f08f1ef4a90>,
 <torchtext.legacy.data.iterator.BucketIterator at 0x7f0925167910>,
 <torchtext.legacy.data.iterator.BucketIterator at 0x7f08ebaaffa0>)

In [69]:
dataset[20]

(['the',
  'school',
  'of',
  'the',
  'environment',
  'is',
  'dedicated',
  'to',
  'sustaining',
  'and',
  'restoring',
  'the',
  'long-term',
  'health',
  'of',
  'the',
  'biosphere',
  'and',
  'the',
  'well-being',
  'of',
  'its',
  'people',
  '.'],
 ['die',
  'umweltschule',
  'widmet',
  'sich',
  'der',
  'aufrechterhaltung',
  'und',
  'wiederherstellung',
  'der',
  'langfristigen',
  'gesundheit',
  'der',
  'biosphäre',
  'und',
  'des',
  'wohlbefindens',
  'ihrer',
  'bevölkerung',
  '.'])