# In [None]:
# ==============================================================================
#
#  WORD2VEC TRAINING ON BPE TOKENS, USING SUBSAMPLED CORPORA FOR SPEED
#  - Sampling strategy: 50k rows for code, 150k rows for docstrings, and a
#    balanced 150k for the combined corpus (75k code + 75k docstrings).
#  - Sample sizes are chosen to keep total training time within the time limit.
#
# ==============================================================================

import re
import collections
import json
import time
import os
from tqdm.auto import tqdm
from typing import List, Dict, Tuple
import numpy as np
import pandas as pd

# ==============================================================================
# Part 1: BPE Tokenizer Class
# ==============================================================================
class BPE:
    """Byte-pair-encoding tokenizer that applies a pre-trained merge table.

    Loaded from two JSON files produced by a separate training step:
    ``<prefix>_vocab.json`` (token -> id) and ``<prefix>_merges.json``
    (``"left right"`` -> merge rank; the lowest-ranked applicable pair is merged first).
    """

    # Words are runs of word characters, or any single non-space character.
    _WORD_PATTERN = re.compile(r"\w+|\S")
    # Upper bound on memoised words so huge corpora cannot grow the cache without limit.
    _WORD_CACHE_LIMIT = 1_000_000

    def __init__(self):
        self.vocab: List[str] = []
        self.merges: Dict[Tuple[str, str], int] = {}
        self.token_to_id: Dict[str, int] = {}
        self.id_to_token: Dict[int, str] = {}
        self.vocab_size: int = 0
        # word -> BPE tokens; corpora repeat words constantly, so this skips the merge loop.
        self._word_cache: Dict[str, List[str]] = {}

    def load(self, file_prefix: str):
        """Load the vocabulary and merge table saved under ``file_prefix``.

        Raises:
            FileNotFoundError: if ``{file_prefix}_vocab.json`` or ``{file_prefix}_merges.json`` is missing.
        """
        vocab_path, merges_path = f"{file_prefix}_vocab.json", f"{file_prefix}_merges.json"
        if not os.path.exists(vocab_path): raise FileNotFoundError(f"BPE vocab file not found: {vocab_path}")
        if not os.path.exists(merges_path): raise FileNotFoundError(f"BPE merges file not found: {merges_path}")
        with open(vocab_path, 'r', encoding='utf-8') as f:
            self.token_to_id = json.load(f)
        # Size by the largest id rather than the entry count, so ids with gaps do not
        # raise IndexError; unused slots stay "".
        self.vocab = [""] * (max(self.token_to_id.values(), default=-1) + 1)
        for tok, idx in self.token_to_id.items():
            self.vocab[idx] = tok
        self._invert_vocab()
        with open(merges_path, 'r', encoding='utf-8') as f:
            merges_loaded = json.load(f)
        self.merges = {tuple(k.split(' ')): v for k, v in merges_loaded.items()}
        # Cached tokenizations depend on the merge table, so drop them on reload.
        self._word_cache = {}
        print(f"Successfully loaded BPE tokenizer from '{file_prefix}'. Vocab size: {self.vocab_size}")

    def _invert_vocab(self):
        """Rebuild ``id_to_token`` from ``token_to_id`` and refresh ``vocab_size``."""
        self.id_to_token = {i: t for t, i in self.token_to_id.items()}
        self.vocab_size = len(self.vocab)

    def _tokenize_word(self, word: str) -> List[str]:
        """Split ``word`` into characters plus ``'</w>'`` and apply merges until none applies.

        At each step the applicable pair with the lowest merge rank is merged
        everywhere it occurs (non-overlapping, left to right).
        """
        word_tuple = tuple(word) + ('</w>',)
        while True:
            pairs = list(zip(word_tuple[:-1], word_tuple[1:]))
            applicable_merges = {p: self.merges[p] for p in pairs if p in self.merges}
            if not applicable_merges: break
            best_pair = min(applicable_merges, key=applicable_merges.get)
            new_token, i, new_word_tuple = "".join(best_pair), 0, []
            while i < len(word_tuple):
                if i < len(word_tuple) - 1 and (word_tuple[i], word_tuple[i+1]) == best_pair:
                    new_word_tuple.append(new_token); i += 2
                else: new_word_tuple.append(word_tuple[i]); i += 1
            word_tuple = tuple(new_word_tuple)
        return list(word_tuple)

    def encode(self, text: str) -> List[int]:
        """Encode ``text`` into a list of token ids.

        Non-string input (e.g. NaN from a DataFrame column) yields ``[]``.
        Tokens missing from the vocabulary map to the ``'<UNK>'`` id when the
        vocabulary has one, and are dropped otherwise. The result therefore
        always contains only ints.
        """
        if not isinstance(text, str): return []
        ids: List[int] = []
        unknown_id = self.token_to_id.get('<UNK>')
        cache = self._word_cache
        for word in self._WORD_PATTERN.findall(text):
            tokens = cache.get(word)
            if tokens is None:
                tokens = self._tokenize_word(word)
                if len(cache) < self._WORD_CACHE_LIMIT:
                    cache[word] = tokens
            for token in tokens:
                token_id = self.token_to_id.get(token, unknown_id)
                # Without '<UNK>' the fallback would be None, which breaks the integer
                # corpus consumed downstream; skip such tokens instead.
                if token_id is not None:
                    ids.append(token_id)
        return ids

# ==============================================================================
# Part 2: Word2Vec Class (Skip-gram with Negative Sampling)
# ==============================================================================
class Word2Vec:
    """Skip-gram Word2Vec with negative sampling over integer (BPE) token ids.

    Parameters live in two ``(vocab_size, vector_size)`` float32 matrices:
    ``center_vectors`` (returned by ``get_embedding_matrix``) and
    ``context_vectors``. Call ``build_vocab_from_bpe`` before ``train``.
    """

    def __init__(self, vector_size=300, window=5, learning_rate=0.025, epochs=2, subsampling_threshold=1e-5):
        self.vector_size, self.window, self.initial_lr, self.epochs = vector_size, window, learning_rate, epochs
        self.word2id, self.id2word, self.vocab_size = None, None, None
        self.center_vectors, self.context_vectors = None, None
        self.neg_sampling_table, self.neg_sampling_probs, self.subsampling_probs = None, None, None
        # Normalised cumulative sum of neg_sampling_probs; built by _prepare_sampling_distributions.
        self._neg_sampling_cdf = None
        self.subsampling_threshold = subsampling_threshold

    def build_vocab_from_bpe(self, bpe_tokenizer: "BPE"):
        """Adopt the tokenizer's vocabulary and allocate the embedding matrices.

        Center vectors start uniform in ``±0.5 / vector_size``; context vectors start at zero.
        """
        self.word2id, self.id2word, self.vocab_size = bpe_tokenizer.token_to_id, bpe_tokenizer.id_to_token, bpe_tokenizer.vocab_size
        self.center_vectors = np.random.uniform(-0.5/self.vector_size, 0.5/self.vector_size, (self.vocab_size, self.vector_size)).astype(np.float32)
        self.context_vectors = np.zeros((self.vocab_size, self.vector_size), dtype=np.float32)
        self.subsampling_probs = np.zeros(self.vocab_size, dtype=np.float32)
        print(f"Vocabulary size: {self.vocab_size}")

    def _prepare_sampling_distributions(self, tokenized_corpus: List[List[int]]):
        """Build the negative-sampling distribution and per-token discard probabilities.

        Negatives are drawn from unigram counts raised to the 0.75 power. A token with
        corpus frequency ``f`` is discarded with probability
        ``max(0, 1 - sqrt(subsampling_threshold / f))``. Requires ``build_vocab_from_bpe``
        to have allocated ``subsampling_probs``.
        """
        print("Preparing distributions for negative sampling and subsampling...")
        token_counts = collections.Counter(token for sentence in tokenized_corpus for token in sentence)
        total_tokens = sum(token_counts.values())

        # Explicit integer dtype: an empty corpus would otherwise give a float64 array,
        # which numpy refuses to use as an index below.
        vocab_indices = np.fromiter(token_counts.keys(), dtype=np.int64, count=len(token_counts))
        counts = np.fromiter(token_counts.values(), dtype=np.int64, count=len(token_counts))

        powers = counts ** 0.75
        self.neg_sampling_probs = powers / np.sum(powers)
        self.neg_sampling_table = vocab_indices
        # Precompute the CDF once. np.random.choice(p=...) re-validates p and recomputes
        # this cumsum on every call, i.e. O(vocab) work per training pair.
        cdf = np.cumsum(self.neg_sampling_probs)
        if cdf.size:
            cdf /= cdf[-1]
        self._neg_sampling_cdf = cdf

        frequencies = counts / total_tokens
        discard_probs = 1 - np.sqrt(self.subsampling_threshold / frequencies)

        self.subsampling_probs[vocab_indices] = discard_probs
        self.subsampling_probs[self.subsampling_probs < 0] = 0

    def _sample_negatives(self, size: int) -> np.ndarray:
        """Draw ``size`` token ids (with replacement) from the negative-sampling distribution."""
        # Inverse-CDF draw, the same method np.random.choice(table, size, p=probs) uses.
        positions = np.searchsorted(self._neg_sampling_cdf, np.random.random(size), side='right')
        return self.neg_sampling_table[positions]

    def train(self, tokenized_corpus: List[List[int]], negative_samples: int = 5):
        """Train the embeddings with SGD over (target, context) pairs.

        Each token surviving subsampling is paired with every neighbour within
        ``window`` positions (positive example), plus ``negative_samples`` sampled ids
        (negative examples). The learning rate decays linearly with the number of
        pre-subsampling tokens processed, floored at 0.01% of the initial rate.
        """
        self._prepare_sampling_distributions(tokenized_corpus)
        sigmoid = lambda x: 1 / (1 + np.exp(-x))
        total_words, words_processed = sum(len(s) for s in tokenized_corpus), 0
        # Loop invariants, hoisted out of the per-pair loop.
        labels = np.array([1] + [0] * negative_samples)  # 1 = true context, 0 = negatives
        lr_denominator = total_words * self.epochs + 1   # +1 keeps this non-zero
        min_lr = self.initial_lr * 0.0001

        for epoch in range(self.epochs):
            print(f"\n--- Starting Epoch {epoch + 1}/{self.epochs} ---")
            for sentence in tqdm(tokenized_corpus, desc=f"Epoch {epoch + 1}"):
                # Randomly drop frequent tokens according to subsampling_probs.
                subsampled_sentence = [token for token in sentence if np.random.random() > self.subsampling_probs[token]]

                for i, target_id in enumerate(subsampled_sentence):
                    current_lr = self.initial_lr * (1 - (words_processed / lr_denominator))
                    current_lr = max(current_lr, min_lr)

                    start, end = max(0, i - self.window), min(len(subsampled_sentence), i + self.window + 1)

                    for j in range(start, end):
                        if i == j: continue
                        context_id = subsampled_sentence[j]
                        target_vector = self.center_vectors[target_id]

                        ids_to_update = np.concatenate(([context_id], self._sample_negatives(negative_samples)))
                        vectors_to_update = self.context_vectors[ids_to_update]

                        scores = vectors_to_update.dot(target_vector)
                        scores = np.clip(scores, -10, 10)

                        errors = sigmoid(scores) - labels
                        grad_center = errors.reshape(1, -1).dot(vectors_to_update).flatten()
                        # Must be computed before the in-place center update: target_vector is a view.
                        grad_context = np.outer(errors, target_vector)

                        self.center_vectors[target_id] -= current_lr * grad_center
                        # subtract.at accumulates when the same id appears more than once.
                        np.subtract.at(self.context_vectors, ids_to_update, current_lr * grad_context)

                words_processed += len(sentence)

    def get_embedding_matrix(self): return self.center_vectors

    def find_similar_tokens(self, token: str, top_n: int = 10):
        """Print and return the ``top_n`` tokens closest to ``token`` by cosine similarity.

        Returns:
            A list of ``(token, similarity)`` pairs, excluding ``token`` itself, or
            None (after printing a message) when ``token`` is not in the vocabulary.
        """
        if token not in self.word2id: print(f"Token '{token}' not in vocabulary."); return
        # Bind target_id before using it to index center_vectors.
        target_id = self.word2id[token]
        target_vector = self.center_vectors[target_id]
        norms = np.linalg.norm(self.center_vectors, axis=1) * np.linalg.norm(target_vector)
        # Zero-norm rows score 0 instead of NaN (NaN would sort to the top after reversing).
        similarities = self.center_vectors.dot(target_vector) / np.where(norms == 0, 1.0, norms)
        order = np.argsort(similarities)[::-1]
        # Exclude the query by id rather than assuming it ranks first.
        top_indices = order[order != target_id][:top_n]
        print(f"\nTokens most similar to '{token}':")
        results = []
        for i in top_indices:
            idx = int(i)
            similar_token = self.id2word.get(idx, f"<id:{idx}>")
            print(f"  - {similar_token:<20} (Similarity: {similarities[i]:.4f})")
            results.append((similar_token, float(similarities[i])))
        return results

# ==============================================================================
# Part 3: Main Orchestration Function (Sampled Training Runs)
# ==============================================================================
def _banner(title: str) -> str:
    """Format a section header in the style of this script's console output."""
    return "\n" + "="*80 + f"\n--- {title} ---\n" + "="*80


def _sample_rows(frame: pd.DataFrame, n: int, seed: int) -> pd.DataFrame:
    """Return ``n`` random rows of ``frame``, or all rows if it has fewer than ``n``."""
    if n > len(frame):
        # DataFrame.sample(n=...) without replace=True raises ValueError when n > len(frame).
        print(f"Requested {n} rows but only {len(frame)} are available; using all of them.")
        n = len(frame)
    return frame.sample(n=n, random_state=seed)


def _train_word2vec_task(texts, bpe_prefix: str, tokenize_desc: str, output_path: str,
                         saved_label: str, probe_tokens, *, vector_size: int, window: int,
                         epochs: int, negative_samples: int) -> None:
    """Tokenize ``texts`` with the BPE model at ``bpe_prefix``, train Word2Vec on them,
    save the embedding matrix to ``output_path`` and print neighbours of ``probe_tokens``.

    Running each task in its own call means its corpus and model are no longer
    referenced once it returns, before the next task starts.
    """
    bpe = BPE()
    bpe.load(bpe_prefix)
    corpus = [bpe.encode(text) for text in tqdm(texts, desc=tokenize_desc)]
    w2v = Word2Vec(vector_size, window, epochs=epochs)
    w2v.build_vocab_from_bpe(bpe)
    w2v.train(corpus, negative_samples=negative_samples)
    np.save(output_path, w2v.get_embedding_matrix())
    print(f"\nSaved {saved_label} embeddings")
    for token in probe_tokens:
        w2v.find_similar_tokens(token)


def run_all_training_tasks(bpe_files_path: str):
    """Train code-only, docs-only and combined Word2Vec models on sampled training rows.

    Reads the global DataFrame ``df`` (columns ``partition``, ``code``, ``docstring``),
    keeps rows with ``partition == 'train'``, and writes one ``.npy`` embedding
    matrix per task to the working directory. BPE files are looked up under
    ``bpe_files_path`` with prefixes ``bpe_code_only``, ``bpe_docs_only`` and ``bpe_combined``.
    """
    # --- Sampling strategy ---
    CODE_SAMPLE_SIZE = 50000
    DOCS_SAMPLE_SIZE = 150000
    COMBINED_SAMPLE_PER_CORPUS = 75000  # (75k code + 75k docs = 150k total)

    VECTOR_SIZE, WINDOW_SIZE, EPOCHS, NEGATIVE_SAMPLES = 300, 5, 2, 5
    hyperparams = dict(vector_size=VECTOR_SIZE, window=WINDOW_SIZE, epochs=EPOCHS,
                       negative_samples=NEGATIVE_SAMPLES)

    try:
        if not isinstance(df, pd.DataFrame): print("Error: Global DataFrame 'df' not found."); return
    except NameError: print("Error: Global DataFrame 'df' not found. Please run data loading cell first."); return

    train_df_full = df[df['partition'] == 'train'].copy()
    print(f"Full training data has {len(train_df_full)} rows.")

    # --- Task 1: Code-only ---
    print(_banner(f"Task 1: Training Word2Vec on CODE (Sampled to {CODE_SAMPLE_SIZE // 1000}k)"))
    code_sample = _sample_rows(train_df_full, CODE_SAMPLE_SIZE, seed=42)
    _train_word2vec_task(code_sample['code'].dropna(), os.path.join(bpe_files_path, 'bpe_code_only'),
                         "Tokenizing code", "word2vec_code_embeddings.npy", "code", ['def'], **hyperparams)

    # --- Task 2: Docs-only ---
    print(_banner(f"Task 2: Training Word2Vec on DOCS (Sampled to {DOCS_SAMPLE_SIZE // 1000}k)"))
    docs_sample = _sample_rows(train_df_full, DOCS_SAMPLE_SIZE, seed=42)
    _train_word2vec_task(docs_sample['docstring'].dropna(), os.path.join(bpe_files_path, 'bpe_docs_only'),
                         "Tokenizing docs", "word2vec_docs_embeddings.npy", "docs", ['model'], **hyperparams)

    # --- Task 3: Combined (equal code and docs samples) ---
    combined_total_k = 2 * COMBINED_SAMPLE_PER_CORPUS // 1000
    print(_banner(f"Task 3: Training Word2Vec on COMBINED (Sampled to {combined_total_k}k)"))
    # Sample separately (different seeds) so the combined corpus is balanced.
    code_sample_for_combined = _sample_rows(train_df_full, COMBINED_SAMPLE_PER_CORPUS, seed=42)
    docs_sample_for_combined = _sample_rows(train_df_full, COMBINED_SAMPLE_PER_CORPUS, seed=43)
    combined_text = (code_sample_for_combined['code'].dropna().tolist()
                     + docs_sample_for_combined['docstring'].dropna().tolist())
    _train_word2vec_task(combined_text, os.path.join(bpe_files_path, 'bpe_combined'),
                         "Tokenizing combined", "word2vec_combined_embeddings.npy", "combined",
                         ['function', 'return'], **hyperparams)

    print(_banner("All Training Complete"))

# ==============================================================================
# Main Execution Block
# ==============================================================================
if __name__ == "__main__":
    # !!! IMPORTANT !!!
    # EDIT THE PATH BELOW to the location of your BPE tokenizer files dataset.
    BPE_KAGGLE_FILES_PATH = "/kaggle/input/temp-bpe"

    # run_all_training_tasks already reports a missing global `df` itself. A blanket
    # `except NameError` here would also catch genuine bugs (e.g. UnboundLocalError,
    # a NameError subclass) and hide them behind a misleading "run the data loading
    # cell" message, so errors are allowed to propagate.
    run_all_training_tasks(bpe_files_path=BPE_KAGGLE_FILES_PATH)