In [None]:
# import libraries
import time
import torch
from transformers import BertTokenizer, BertModel, CamembertTokenizer, CamembertModel
from torch.utils.data import DataLoader, Dataset
import random
import math
from concurrent.futures import ProcessPoolExecutor, as_completed
import matplotlib.pyplot as plt
import numpy as np

# define a function to load data
def load_data(file_path):
    """Read a UTF-8 text file and return its non-blank lines, normalized.

    Each kept line has surrounding whitespace removed and is lower-cased;
    lines that are empty after stripping are skipped.

    Args:
        file_path: path of the text file to read.

    Returns:
        list of normalized, non-empty lines in file order.
    """
    cleaned = []
    with open(file_path, encoding='utf-8') as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if not stripped:
                continue
            cleaned.append(stripped.lower())
    return cleaned

# define a function to sample data
def sample_data(data, sample_fraction, *, rng=None):
    """Randomly sample a fraction of ``data`` without replacement.

    Args:
        data: a sequence (list, range, ...) to sample from.
        sample_fraction: fraction of ``data`` to keep, in [0, 1]. The sample
            size is ``int(len(data) * sample_fraction)``, i.e. rounded down.
        rng: optional ``random.Random`` instance for reproducible sampling;
            when omitted the module-level ``random`` generator is used.

    Returns:
        A new list holding the sampled elements, in random order.

    Raises:
        ValueError: if ``sample_fraction`` is outside [0, 1].
    """
    # random.sample would also reject these, but with a less helpful message.
    if not 0 <= sample_fraction <= 1:
        raise ValueError(f"sample_fraction must be between 0 and 1, got {sample_fraction!r}")
    sample_size = int(len(data) * sample_fraction)
    sampler = random if rng is None else rng
    return sampler.sample(data, sample_size)

# define a class to create a dataset
class SentenceDataset(Dataset):
    """Map-style dataset that encodes each sentence with a frozen encoder on access.

    ``__getitem__`` tokenizes one sentence to exactly ``max_input_length``
    positions (special tokens included, padded/truncated), runs it through
    ``bert_model`` without gradients, and returns the model's first output
    with the leading batch dimension removed. For Hugging Face encoders that
    first output is the last hidden state, one vector per token position.

    Args:
        sentences: sequence of sentence strings.
        tokenizer: a Hugging Face tokenizer, called as ``tokenizer(text, ...)``.
        bert_model: encoder called with ``input_ids`` and ``attention_mask``.
        max_input_length: fixed number of token positions per item.
    """

    def __init__(self, sentences, tokenizer, bert_model, max_input_length=256):
        self.sentences = sentences
        self.tokenizer = tokenizer
        self.bert_model = bert_model
        self.max_input_length = max_input_length

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        # Let the tokenizer add the model's own special tokens and pad/truncate
        # to a fixed length, so every item has the same shape and the default
        # DataLoader collate can stack them.
        encoded = self.tokenizer(
            sentence,
            max_length=self.max_input_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        with torch.no_grad():
            # Pass the attention mask so padding positions are ignored by
            # self-attention; without it, padding changes every token's vector.
            outputs = self.bert_model(
                input_ids=encoded['input_ids'],
                attention_mask=encoded['attention_mask'],
            )
        # outputs[0] carries a leading batch dimension of size 1; drop it.
        return outputs[0].squeeze(0)

# Load BERT models and tokenizers
tokenizer_en = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer_fr = CamembertTokenizer.from_pretrained('camembert-base')
bert_en = BertModel.from_pretrained('bert-base-uncased')
bert_fr = CamembertModel.from_pretrained('camembert-base')

# Freeze both encoders: they are only used here to extract embeddings.
for encoder in (bert_en, bert_fr):
    for param in encoder.parameters():
        param.requires_grad = False
    encoder.eval()  # explicit: keep dropout disabled during feature extraction


def _load_parallel_data(english_path, french_path):
    """Read two line-aligned corpora and return the pairs where both sides are non-blank.

    Lines are stripped and lower-cased (the same normalization load_data
    applies), but blank lines are dropped pair-wise, so the i-th returned
    English sentence still corresponds to the i-th returned French sentence.
    zip() stops at the end of the shorter file.

    Returns:
        (english, french): two lists of equal length.
    """
    english, french = [], []
    with open(english_path, encoding='utf-8') as en_file, \
            open(french_path, encoding='utf-8') as fr_file:
        for en_line, fr_line in zip(en_file, fr_file):
            en_text, fr_text = en_line.strip(), fr_line.strip()
            # Dropping blanks from each file independently would shift every
            # later line out of alignment, so skip the pair as a whole.
            if en_text and fr_text:
                english.append(en_text.lower())
                french.append(fr_text.lower())
    return english, french


# load data
# NOTE(review): assumes the two Europarl files are line-aligned translations
# of each other (the usual layout of this parallel corpus) — confirm.
english_sentences, french_sentences = _load_parallel_data(
    '/content/gdrive/MyDrive/europarl-v7.fr-en.en',
    '/content/gdrive/MyDrive/europarl-v7.fr-en.fr',
)

# set the sample fraction
sample_fraction = 0.1

# Sample ONE set of line indices and apply it to both corpora, so the sampled
# English and French sentences stay paired and have the same count. Sampling
# each corpus independently produced unrelated pairs and unequal sizes, which
# made the last zipped batches differ in size when the embeddings were combined.
sampled_indices = sample_data(range(len(english_sentences)), sample_fraction)
english_sentences_sampled = [english_sentences[i] for i in sampled_indices]
french_sentences_sampled = [french_sentences[i] for i in sampled_indices]

print(f"Sampled {len(english_sentences_sampled)} English sentences and {len(french_sentences_sampled)} French sentences.")

# create datasets and dataloaders
batch_size = 256  # moderately reduced batch size

english_dataset = SentenceDataset(english_sentences_sampled, tokenizer_en, bert_en)
french_dataset = SentenceDataset(french_sentences_sampled, tokenizer_fr, bert_fr)

# shuffle=False keeps batch k of both loaders covering the same sentence pairs.
english_loader = DataLoader(english_dataset, batch_size=batch_size, shuffle=False, num_workers=8)
french_loader = DataLoader(french_dataset, batch_size=batch_size, shuffle=False, num_workers=8)

# compute total batches (len(DataLoader) is the number of batches it yields)
total_batches = max(len(english_loader), len(french_loader))
print(f"Total batches to process: {total_batches}")

# use a function to combine embeddings
def process_batch(eng_batch, fr_batch, ratio_en, ratio_fr):
    """Return the weighted sum ``ratio_en * eng_batch + ratio_fr * fr_batch``.

    Args:
        eng_batch: batch of English embeddings (tensor).
        fr_batch: batch of French embeddings, same shape as ``eng_batch``.
        ratio_en: weight applied to ``eng_batch``.
        ratio_fr: weight applied to ``fr_batch``.
    """
    return ratio_en * eng_batch + ratio_fr * fr_batch


def combine_embeddings_batch(english_loader, french_loader, ratio_en=0.5, ratio_fr=0.5, *, log_every=10):
    """Combine paired English/French embedding batches into one ordered list.

    Batches are drawn from both loaders in lockstep and merged with
    ``process_batch``. The rows of each combined batch are appended in order,
    so ``result[i]`` corresponds to the i-th sentence pair.

    Args:
        english_loader: iterable of English embedding batches; must support ``len()``.
        french_loader: iterable of French embedding batches aligned with ``english_loader``.
        ratio_en: weight for the English embeddings.
        ratio_fr: weight for the French embeddings.
        log_every: print a progress line every ``log_every`` batches (0 disables it).

    Returns:
        list of per-sentence combined embedding tensors, in input order.

    Raises:
        ValueError: if two paired batches contain a different number of items.
    """
    combined_embeddings = []
    total_time = 0.0
    batch_count = 0
    # Number of batches the loaders will yield, used for the runtime estimate.
    expected_batches = max(len(english_loader), len(french_loader))

    # The combination is a single vectorised tensor expression, so it runs in
    # this process. Farming it out to a ProcessPoolExecutor would pickle every
    # (large) batch to a worker, keep all submitted batches alive at once, and
    # as_completed() would hand results back out of sentence order.
    # zip() stops at the shorter loader.
    tick = time.perf_counter()
    for eng_batch, fr_batch in zip(english_loader, french_loader):
        if len(eng_batch) != len(fr_batch):
            raise ValueError(
                f"Batch {batch_count + 1}: English batch has {len(eng_batch)} items "
                f"but French batch has {len(fr_batch)}; the inputs are not aligned."
            )
        combined_embeddings.extend(process_batch(eng_batch, fr_batch, ratio_en, ratio_fr))

        # The measured time covers fetching the batch from the loaders (for
        # SentenceDataset-backed loaders, that is where the encoders run) plus
        # combining it.
        tock = time.perf_counter()
        batch_time = tock - tick
        tick = tock
        total_time += batch_time
        batch_count += 1

        # print progress
        if log_every and batch_count % log_every == 0:
            print(f"Processed batch {batch_count} in {batch_time:.2f} seconds.")

    if batch_count == 0:
        # Avoid dividing by zero below when the loaders are empty.
        print("No batches to process.")
        return combined_embeddings

    average_time_per_batch = total_time / batch_count
    estimated_total_time = average_time_per_batch * expected_batches

    print(f"Average time per batch: {average_time_per_batch:.2f} seconds.")
    print(f"Estimated total time for all batches: {estimated_total_time / 60:.2f} minutes.")

    return combined_embeddings

# use the function to combine embeddings
start_time = time.time()
combined_embeddings = combine_embeddings_batch(english_loader, french_loader)
end_time = time.time()
total_time = end_time - start_time

print("Combined embeddings created.")
print(f"Total combined embeddings: {len(combined_embeddings)}")

# Preview the shapes of the first few embeddings. Slicing (instead of indexing
# range(5)) avoids an IndexError when fewer than 5 embeddings were produced.
# NOTE(review): each element is a full per-token tensor as returned by
# SentenceDataset, so keeping every one in a Python list may need a lot of RAM
# for large samples — consider pooling or streaming to disk. TODO confirm sizes.
for i, embedding in enumerate(combined_embeddings[:5], start=1):
    print(f"Embedding {i}: {embedding.shape}")

# make sure one combined embedding was produced per sampled English sentence
assert len(combined_embeddings) == len(english_sentences_sampled), "Combined embeddings count mismatch!"
print("Embedding verification passed.")

# print the total time taken
print(f"Total time taken: {total_time / 60:.2f} minutes.")

# plot the training loss
# NOTE(review): the two curves below are random placeholder values
# (np.random.rand), not recorded generator/discriminator losses — replace them
# with real per-epoch losses before relying on this figure.
epochs = np.arange(1, 51)
generator_loss = np.random.rand(50)
discriminator_loss = np.random.rand(50)

plt.figure(figsize=(10, 5))
plt.plot(epochs, generator_loss, label='Generator Loss')
plt.plot(epochs, discriminator_loss, label='Discriminator Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training Loss of Generator and Discriminator')
plt.legend()
plt.grid(True)
plt.savefig('training_loss.png')
plt.show()

# show example generated sentences
# NOTE(review): these sentences are hard-coded placeholders, not model output.
generated_sentences_epoch_1 = ["this is an example sentence.", "c'est une phrase exemple."]
generated_sentences_epoch_50 = ["another example sentence.", "un autre exemple de phrase."]

fig, axs = plt.subplots(2, 1, figsize=(10, 5))

# Render each epoch's sentences as centred text on an axis-less panel.
for ax, sentences, title in (
    (axs[0], generated_sentences_epoch_1, 'Generated Sentences at Epoch 1'),
    (axs[1], generated_sentences_epoch_50, 'Generated Sentences at Epoch 50'),
):
    ax.text(0.5, 0.5, "\n".join(sentences), horizontalalignment='center', verticalalignment='center', fontsize=12)
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.savefig('generated_sentences.png')
plt.show()


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Sampled 200568 English sentences and 200479 French sentences.
Total batches to process: 784
