In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
dataset = load_dataset("wmt14", "fr-en")

Downloading readme: 100%|██████████| 10.5k/10.5k [00:00<00:00, 25.7MB/s]
Downloading data: 100%|██████████| 30/30 [03:36<00:00,  7.23s/files]
Downloading data: 100%|██████████| 475k/475k [00:00<00:00, 1.96MB/s]
Downloading data: 100%|██████████| 536k/536k [00:00<00:00, 4.01MB/s]
Generating train split: 100%|██████████| 40836715/40836715 [01:04<00:00, 636356.22 examples/s] 
Generating validation split: 100%|██████████| 3000/3000 [00:00<00:00, 20895.25 examples/s]
Generating test split: 100%|██████████| 3003/3003 [00:00<00:00, 12153.92 examples/s]


In [12]:
class WMT14Dataset(Dataset):
	def __init__(self, split):
		self.data = dataset[split]

	def __len__(self):
		return len(self.data)

	def __getitem__(self, idx):
		item = self.data[idx]
		return item['translation']['en'], item['translation']['fr']

English: ('All chemical products available to the general public must meet the CCCR, 2001 requirements.', 'The Official Document System was re-engineered to operate on an open-standards platform and to provide search capabilities in all six official languages.', 'He questioned why, if candidate Kinden received full marks for that part of her reply, appellant Ballard did not also get full marks.', 'The Commission therefore rejects the amendments relating to paternity leave.', 'The base case maximum annual S04= concentration in 2020 is predicted to be 4.9 µg/m3.', 'The General Affairs and External Relations Council noted these intentions on the part of the Presidency.', 'Mobility of General Service staff', '”delete international associations of local authorities” and', 'For eight of the imported cases, the exact source was unknown because the patient had traveled in more than one country outside the United States during the exposure period.', 'Annex 6,', 'In these cases, forceful action 

In [None]:
batch_size = 32

train_dataset = WMT14Dataset(split='train')
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

validation_dataset = WMT14Dataset(split='validation')
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

test_dataset = WMT14Dataset(split='test')
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [13]:
for batch in train_loader:
	english_sentences, french_sentences = batch
	print("English:", english_sentences)
	print("French:", french_sentences)
	break

English: ('She asked what legal and media-related measures were being adopted to combat traditional perceptions of women.', '• 235-A-2006 — For an exemption from the application of subsections 64(1) and/or 64(1.1) of the Canada Transportation Act - Air Canada also carrying on business as Air Canada Jetz, on behalf of itself and Jazz Air LP, as represented by its general partner, Jazz Air Holding GP Inc., carrying on business as Air Canada Jazz', 'No.  of admissions to homes administered by MCDS by gender and case type', '(e) Inform government officials and other professionals working with children without parental care that adoptions, in particular international adoptions, are an exceptional alternative care option and that the principles of non-discrimination and the best interests of the child must be taken into account when making such decisions.', 'Fifty-fourth session', 'The Producer must pay royalties to the star Artist or any other Artist bound by an exclusivity agreement.', 'To

In [9]:
class Encoder(torch.nn.Module):
	def __init__(self, input_size, hidden_size, num_layers=1):
		super(Encoder, self).__init__()
		
		self.hidden_size = hidden_size
		self.num_layers = num_layers
		self.rnn = torch.nn.RNN(input_size, hidden_size, num_layers, bidirectional=True, batch_first=True)

	def forward(self, x):
		# x = batch size, sequence length, hidden size (x2 for bidirectional)
		# hn = numlayers*D, batch_size, hidden size
		x, hn = self.rnn(x)
		batch_size = x.size()[0]
		
		# need to get one hidden state that is forward with backwards concatenated
		hn = hn.view(self.num_layers, 2, batch_size, self.hidden_size) # num_layers, num_directions, batch, hidden_size
		hn = torch.concat((hn[:, 0, :, :], hn[:, 1, :, :]), dim=2) # num_layers, batch, hidden_size
		return x,hn

In [10]:
class Decoder(torch.nn.Module):
	def __init__(self, input_size, hidden_size, num_layers):
		super(Decoder, self).__init__()

		self.rnn = torch.nn.RNN(input_size, hidden_size, num_layers=1, bidirectional=False, batch_first=True)

	def forward(self, encoder_output, x):
		return