In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

In [2]:
from datasets import load_dataset

dataset = load_dataset("knkarthick/dialogsum")

In [3]:
# Extract dialogues
dialogues = []
for _, split_data in dataset.items():
    split_dialogues = [dialog['dialogue'] for dialog in split_data]
    dialogues.extend(split_dialogues)

print(f"Collected {len(dialogues)} dialogues")

Collected 14460 dialogues


In [4]:
import re
import nltk

from tqdm import tqdm

# Extract dialog turns and split sentences
sentences = set()
for dialogue in tqdm(dialogues, desc="Extracting sentences"):
    turns = dialogue.split('\n')
    for turn in turns:
        turn = re.sub(r"#.*?#:", "", turn).strip()
        sentences.update(nltk.sent_tokenize(turn))

print(f"Collected {len(sentences)} sentences")


Extracting sentences: 100%|██████████| 14460/14460 [00:02<00:00, 4969.60it/s]

Collected 151108 sentences





In [5]:
import string

from tqdm import tqdm
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('stopwords')
# Filter sentences on surface criteria
with open("resources/english_words.txt", "r") as f:
    english_words = set([word.strip() for word in f.read().splitlines()])

filtered_sentences = []
lowercase_sentences = set()

for sentence in tqdm(list(sentences), desc="Filtering sentences"):
    tokens = nltk.word_tokenize(sentence)
    # Remove punctuation
    tokens = [token for token in tokens if token not in string.punctuation]
    if len(tokens) < 3:
        continue
    if len(tokens) > 7:
        continue

    # Filter sentences where at least one word is not in the English word list
    if not any(token in english_words for token in tokens):
        continue

    if sentence.lower() in lowercase_sentences: 
        continue

    filtered_sentences.append(sentence)
    lowercase_sentences.add(sentence.lower())

print(f"Collected {len(filtered_sentences)} filtered sentences")


[nltk_data] Downloading package punkt to
[nltk_data]     /home/lgirrbach15/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /home/lgirrbach15/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/lgirrbach15/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
Filtering sentences: 100%|██████████| 151108/151108 [00:06<00:00, 24265.92it/s]

Collected 64627 filtered sentences





In [6]:
from collections import defaultdict

# Build inverse mapping from lemmas to sentences
lemma_to_sentences = defaultdict(set)
sentences = list(filtered_sentences)

lemmatizer = nltk.WordNetLemmatizer()
stopwords = set(nltk.corpus.stopwords.words('english'))

for sentence in tqdm(sentences, desc="Building lemma to sentences mapping"):
    tokens = nltk.word_tokenize(sentence)
    lemmas = set([lemmatizer.lemmatize(token) for token in tokens])
    # Filter out stopwords
    lemmas = [lemma for lemma in lemmas if lemma not in stopwords]

    # Get intersection of sentences with the same lemmas
    intersection = set.intersection(*[lemma_to_sentences[lemma] for lemma in lemmas])
    if len(intersection) > 0:
        continue

    for lemma in lemmas:
        lemma_to_sentences[lemma].add(sentence)

filtered_sentences = list(set.union(*lemma_to_sentences.values()))
print(f"Collected {len(filtered_sentences)} filtered sentences")

Building lemma to sentences mapping: 100%|██████████| 64627/64627 [00:10<00:00, 6226.34it/s] 


Collected 53779 filtered sentences


In [7]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
sentences = list(filtered_sentences)
sentence_embeddings = model.encode(sentences, batch_size=100)

In [8]:
# Normalize embeddings for cosine distance
import numpy as np
sentence_embeddings = sentence_embeddings / np.linalg.norm(sentence_embeddings, axis=1, keepdims=True)

In [9]:
clusters = []
processed_sentence_ids = set()
batch_size = 100

progress_bar = tqdm(total=len(sentences), desc="Building clusters")
start_idx = 0

while start_idx < len(sentences):
    batch_idxs = []
    while len(batch_idxs) < batch_size:
        if start_idx >= len(sentences):
            break
        if start_idx in processed_sentence_ids:
            start_idx += 1
            continue

        batch_idxs.append(start_idx)
        start_idx += 1

    cosine_similarities = np.dot(sentence_embeddings, sentence_embeddings[batch_idxs].T).T
    cluster_mask = cosine_similarities > 0.7
    for sentence_idx, mask in zip(batch_idxs, cluster_mask):
        mask_idxs = np.nonzero(mask)[0].tolist()
        mask_idxs = [idx for idx in mask_idxs if idx not in processed_sentence_ids]
        if len(mask_idxs) == 0:
            continue

        clusters.append(mask_idxs)
        processed_sentence_ids.update(set(mask_idxs))

    progress_bar.update(len(batch_idxs))

print(f"Collected {len(clusters)} clusters")


Building clusters:  62%|██████▏   | 33600/53779 [00:10<00:05, 3738.30it/s]

Collected 33734 clusters


In [10]:
# From each cluster, select the sentence with fewest words
filtered_sentences = []

for cluster in clusters:
    cluster_sentences = [sentences[idx] for idx in cluster]
    cluster_sentences = sorted(cluster_sentences, key=lambda x: len(x.split()))
    filtered_sentences.append(cluster_sentences[0])

print(f"Collected {len(filtered_sentences)} filtered sentences")
sentences = filtered_sentences

Collected 33734 filtered sentences


In [11]:
# Chunk sentences by number of tokens
sentences_by_num_tokens = defaultdict(list)
for sentence in sentences:
    # Remove punctuation
    tokens = [token for token in nltk.word_tokenize(sentence) if token not in string.punctuation]
    sentences_by_num_tokens[len(tokens)].append(sentence)

In [12]:
# Load word counts
import pandas as pd
word_counts = pd.read_csv("resources/count_1w.txt", sep="\t", header=None)
word_counts = word_counts.set_index(0)
word_counts = word_counts.to_dict()[1]

def minimum_word_count(sentence):
    sentence = sentence.replace("'", "")
    tokens = nltk.word_tokenize(sentence)
    # remove stopwords
    tokens = [token.lower() for token in tokens if token not in stopwords]
    # remove punctuation
    tokens = [token for token in tokens if token not in string.punctuation]
    if len(tokens) == 0:
        return 0
    
    return min([word_counts.get(token, 0) for token in tokens])

# For each num_tokens_value, keep the top 200 sentences with highest minimum word count
filtered_sentences_by_num_tokens = dict()
for num_tokens_value, sentences_with_num_tokens in sorted(sentences_by_num_tokens.items(), key=lambda x: x[0]):
    if num_tokens_value > 8 or num_tokens_value < 3:
        continue

    sentences_with_num_tokens = sorted(sentences_with_num_tokens, key=minimum_word_count, reverse=True)
    filtered_sentences_by_num_tokens[num_tokens_value] = sentences_with_num_tokens[:200]

print(f"Collected {sum(len(sentences) for sentences in filtered_sentences_by_num_tokens.values())} filtered sentences")


Collected 1000 filtered sentences


In [13]:
import random

# Print random sentences from each num_tokens_value
for num_tokens_value, sentences in filtered_sentences_by_num_tokens.items():
    print(f"Number of tokens: {num_tokens_value}\nNumber of sentences: {len(sentences)}")
    print(random.choice(sentences), "\n", minimum_word_count(random.choice(sentences)))
    print("\n")


Number of tokens: 3
Number of sentences: 200
What new car? 
 273620358


Number of tokens: 4
Number of sentences: 200
Does he help much? 
 210601244


Number of tokens: 5
Number of sentences: 200
When would you have time? 
 360468339


Number of tokens: 6
Number of sentences: 200
She was great in the part. 
 360468339


Number of tokens: 7
Number of sentences: 200
Then you can have the other one. 
 216122487




In [14]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/QwQ-32B-Preview"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Loading checkpoint shards:   0%|          | 0/17 [00:00<?, ?it/s]

Building clusters:  63%|██████▎   | 33760/53779 [00:26<00:05, 3738.30it/s]

In [35]:
translation_template = "Please translate into easy Chinese: \"{prompt}\""
pinyin_template = "\"{prompt}\" in Pinyin is"

def generate_response(prompt, template):
    messages = [
        {"role": "system", "content": "You are an experienced translator."},
        {"role": "user", "content": template.format(prompt=prompt)}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512
    )
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response

In [40]:
import re
import pandas as pd
import pinyin_jyutping_sentence

translated_sentence_records = []
pattern = r"Translation\s*=\s*([^\]]+)[;|；|。]\s*Pinyin\s*=\s*([^\]]+)"

failure_count = 0

sentences = list(set.union(*(set(sentences) for sentences in filtered_sentences_by_num_tokens.values())))
for i, sentence in enumerate(sentences):
    translated_sentence = generate_response(sentence, translation_template).replace("\n", " ")
    translated_sentence = translated_sentence.strip('"')
    if len(translated_sentence) > 20:
        failure_count += 1
        continue


    pinyin = pinyin_jyutping_sentence.pinyin(translated_sentence)
    pinyin = pinyin.strip('"')

    #print(f"Mandarin: {translated_sentence}\tPinyin: {pinyin}")

    mandarin = translated_sentence.strip()
    pinyin = pinyin.strip()
    translated_sentence_records.append({
        "sentence": sentence,
        "mandarin": mandarin,
        "pinyin": pinyin
    })
    print(" " * 100, end="\r")
    print(f"[{i+1}/{len(sentences)}/{failure_count}] Mandarin: {mandarin}, Pinyin: {pinyin}", end="\r")
    #print(f"Processed {i} sentences: Translation: {mandarin}, Pinyin: {pinyin}", end="\r")

translated_sentence_df = pd.DataFrame(translated_sentence_records)
translated_sentence_df.to_csv("resources/translated_sentences.csv", index=False)

[1000/1000/75] Mandarin: “不是一个好的。”, Pinyin: “ bùshì yīgè hǎo de 。 ”                                 g dédào nàgè ne ？

In [43]:
import random
import genanki

model_id = random.randrange(1 << 30, 1 << 31)

chinese_deck_model = genanki.Model(
    model_id,
    'Chinese Sentence Model',
    fields=[
        {'name': 'English'},
        {'name': 'Mandarin'},
        {'name': 'Pinyin'},
        {'name': 'index'},
    ],
  templates=[
    {
      'name': 'English -> Mandarin',
      'qfmt': '{{English}}',
      # Show Mandarin and Pinyin side by side
      'afmt': '{{FrontSide}}<hr id="answer">{{Mandarin}}<br>{{Pinyin}}',
    },
    {
      'name': 'Mandarin -> English',
      'qfmt': '{{Mandarin}}',
      'afmt': '{{FrontSide}}<hr id="answer">{{Pinyin}}<br>{{English}}',
    },
  ])

# Sort translated_sentence_df by length of Mandarin
translated_sentence_df = translated_sentence_df.sort_values(by="mandarin", key=lambda x: x.str.len())

# Make notes for all rows in translated_sentence_df
chinese_sentences_deck = genanki.Deck(
    model_id + 1,
    'Chinese Sentences'
)

for index, row in translated_sentence_df.iterrows():
    chinese_sentences_deck.add_note(genanki.Note(
        model=chinese_deck_model,
        fields=[row["sentence"], row["mandarin"], row["pinyin"], str(index)],
        sort_field="index"
    ))

genanki.Package(chinese_sentences_deck).write_to_file('chinese_sentences.apkg')