In [6]:
from transformers import MarianMTModel, MarianTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
import torch
from tqdm import tqdm

# Load your training data
dataset = load_dataset("text", data_files={
                       "train": "./prep/train.de-en.en", 
                       "validation": "./prep/train.de-en.de",
                       "test":"./prep/"})

# Load the MarianMT model and tokenizer
model_name = "Helsinki-NLP/opus-mt-en-de"
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)

# Tokenize and prepare the data


def tokenize_batch(batch):
    return tokenizer(batch["text"], return_tensors="pt", truncation=True, padding=True)

print(dataset["train"])

train_data = dataset["train"].map(tokenize_batch, batched=True)
validation_data = dataset["validation"].map(tokenize_batch, batched=True)

# Define DataLoader for training and validation
train_dataloader = DataLoader(
    train_data, batch_size=4, shuffle=True, collate_fn=lambda x: x)
validation_dataloader = DataLoader(
    validation_data, batch_size=4, shuffle=False)

# Training parameters
num_epochs = 3
learning_rate = 5e-5

# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
loss_fn = CrossEntropyLoss()

# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}"):
        inputs = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids=inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    average_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1}, Average Training Loss: {average_loss:.4f}")

    # Validation
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(validation_dataloader, desc="Validation"):
            inputs = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=inputs, labels=labels)
            loss = outputs.loss
            val_loss += loss.item()

    average_val_loss = val_loss / len(validation_dataloader)
    print(f"Epoch {epoch + 1}, Average Validation Loss: {average_val_loss:.4f}")

# Save the trained model
model.save_pretrained("your_trained_model_path")
tokenizer.save_pretrained("your_trained_model_path")

Dataset({
    features: ['text'],
    num_rows: 153348
})


KeyboardInterrupt: 