In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from collections import Counter, defaultdict
import numpy as np
from typing import List, Dict, Set, Tuple
import re
from tqdm import tqdm

class UnigramTokenizer:
    """Frequency-based subword tokenizer with greedy longest-match encoding.

    Despite the name, this is not a unigram language-model tokenizer (there is
    no likelihood-based pruning). ``train`` keeps every character seen in the
    lower-cased corpus plus the most frequent 2..``max_piece_len`` character
    n-grams. ``tokenize`` splits text by repeatedly taking the longest
    vocabulary piece that prefixes the remaining text.

    Two reserved pieces always exist after training:
    ``<pad>`` (id 0), used to pad fixed-size context windows, and
    ``<unk>`` (id 1), emitted for characters never seen during training.

    Args:
        vocab_size: Target vocabulary size, including the reserved pieces. It is
            exceeded only if the corpus has more distinct characters than this.
        min_freq: Minimum corpus count for a multi-character piece to be kept.
        max_piece_len: Longest piece (in characters) that is learned or matched.
    """

    PAD_TOKEN = "<pad>"
    UNK_TOKEN = "<unk>"

    def __init__(self, vocab_size: int = 32000, min_freq: int = 3,
                 max_piece_len: int = 6):
        self.vocab_size = vocab_size
        self.min_freq = min_freq
        self.max_piece_len = max_piece_len
        self.subword_vocab = {}  # kept multi-character pieces -> corpus count
        self.char_vocab = set()  # every character seen during training
        self.piece2idx = {}
        self.idx2piece = {}

    def train(self, texts: List[str]):
        """Learn the piece vocabulary from ``texts`` and rebuild the id maps.

        NOTE: every 2..max_piece_len character n-gram is counted in memory
        before the ``min_freq`` filter, so memory grows with corpus size.
        """
        print("Training tokenizer...")
        char_freq = Counter()
        subword_freq = Counter()
        max_len = self.max_piece_len

        # Count single characters and every 2..max_len character n-gram.
        for text in tqdm(texts):
            text = text.lower()
            char_freq.update(text)
            n = len(text)
            for start in range(n):
                for end in range(start + 2, min(start + max_len, n) + 1):
                    subword_freq[text[start:end]] += 1

        # Drop rare multi-character pieces.
        subword_freq = Counter({piece: freq for piece, freq in subword_freq.items()
                                if freq >= self.min_freq})

        specials = [self.PAD_TOKEN, self.UNK_TOKEN]
        # All characters are always kept so any seen text stays encodable.
        vocab = set(char_freq)
        # Reserve room for the special pieces; never request a negative count.
        remaining_size = max(0, self.vocab_size - len(specials) - len(vocab))
        top_subwords = subword_freq.most_common(remaining_size)
        vocab.update(piece for piece, _ in top_subwords)
        # Guard against a corpus that literally contains "<pad>"/"<unk>".
        vocab.difference_update(specials)

        self.char_vocab = set(char_freq)
        self.subword_vocab = dict(top_subwords)

        # Specials get fixed ids (pad=0, unk=1); other pieces are sorted so ids
        # are deterministic for a given vocabulary.
        ordered = specials + sorted(vocab)
        self.piece2idx = {piece: idx for idx, piece in enumerate(ordered)}
        self.idx2piece = {idx: piece for piece, idx in self.piece2idx.items()}
        print(f"Vocabulary size: {len(self.piece2idx)}")

    def tokenize(self, text: str) -> List[int]:
        """Encode ``text`` (lower-cased) into piece ids by greedy longest match.

        Characters with no matching piece map to the ``<unk>`` id.

        Raises:
            KeyError: if an unmatched character is hit and the vocabulary has no
                ``<unk>`` piece (e.g. the tokenizer was never trained).
        """
        text = text.lower()
        piece2idx = self.piece2idx
        max_len = self.max_piece_len
        tokens: List[int] = []
        pos = 0
        n = len(text)
        # Advance an index instead of re-slicing the remainder (keeps this O(n)).
        while pos < n:
            for length in range(min(max_len, n - pos), 0, -1):
                idx = piece2idx.get(text[pos:pos + length])
                if idx is not None:
                    tokens.append(idx)
                    pos += length
                    break
            else:
                tokens.append(piece2idx[self.UNK_TOKEN])
                pos += 1
        return tokens

class EmbeddingDataset(Dataset):
    """(target, context-window) training pairs built from tokenized texts.

    For every token position that has at least one neighbour, an example is
    the token id plus the ids of up to ``context_size`` tokens on each side
    (the target itself excluded). ``__getitem__`` right-pads the context to
    ``2 * context_size`` ids with the tokenizer's ``<pad>`` id, or 0 if the
    tokenizer has no such piece.

    NOTE: all examples are materialised as Python lists up front, so memory
    scales with total token count.
    """

    def __init__(self, texts: List[str], tokenizer: UnigramTokenizer,
                 context_size: int = 5):
        self.tokenizer = tokenizer
        self.context_size = context_size
        # Prefer the reserved padding id so padding never aliases a real piece;
        # fall back to the historical 0 for vocabularies without "<pad>".
        self.pad_id = tokenizer.piece2idx.get("<pad>", 0)
        self.examples = []

        print("Creating training examples...")
        for text in tqdm(texts):
            tokens = self.tokenizer.tokenize(text)
            for i, target in enumerate(tokens):
                left = tokens[max(0, i - context_size):i]
                right = tokens[i + 1:i + context_size + 1]
                context = left + right
                # A lone token has no neighbours; an all-padding "positive"
                # context would only teach the model that padding is positive.
                if context:
                    self.examples.append((target, context))

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return ``(target, padded_context)`` as int64 tensors."""
        target, context = self.examples[idx]
        padding = [self.pad_id] * (2 * self.context_size - len(context))
        return (torch.tensor(target, dtype=torch.long),
                torch.tensor(context + padding, dtype=torch.long))

class MultilingualEmbeddingModel(nn.Module):
    """Word2vec-style token embedding model with a convolutional context encoder.

    A target token is scored against its context window as follows:
    1. Embed the context tokens.
    2. Run a two-layer 1-D convolution stack (``char_cnn``) along the
       context-position axis.
    3. Average the result over positions.
    4. Take the dot product with the target token's embedding.

    The score is an unnormalised logit.

    Args:
        vocab_size: Number of token ids the embedding tables must cover.
        embedding_dim: Size of each embedding vector and of the conv channels.
    """

    def __init__(self, vocab_size: int, embedding_dim: int = 300):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        # Separate tables for the target and context roles.
        self.target_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.context_embeddings = nn.Embedding(vocab_size, embedding_dim)

        # NOTE: despite its name this convolves over neighbouring context
        # *tokens* (sequence positions), not characters. The attribute name is
        # kept so existing state_dict checkpoints still load.
        self.char_cnn = nn.Sequential(
            nn.Conv1d(embedding_dim, embedding_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(embedding_dim, embedding_dim, kernel_size=3, padding=1),
            nn.ReLU(),
        )

        # Small symmetric uniform init for both embedding tables.
        nn.init.uniform_(self.target_embeddings.weight, -0.1, 0.1)
        nn.init.uniform_(self.context_embeddings.weight, -0.1, 0.1)

    def forward(self, target_idx, context_idxs):
        """Score each target token against its context window.

        Args:
            target_idx: Token ids of shape ``(batch,)``.
            context_idxs: Token ids of shape ``(batch, window)``.

        Returns:
            Similarity logits of shape ``(batch,)``.
        """
        target_embed = self.target_embeddings(target_idx)      # (B, D)

        context_embeds = self.context_embeddings(context_idxs)  # (B, W, D)
        # Conv1d expects channels first: (B, D, W) in, (B, D, W) out.
        context_embeds = self.char_cnn(context_embeds.transpose(1, 2)).transpose(1, 2)

        # NOTE(review): padded positions are included in this mean.
        context_embed = context_embeds.mean(dim=1)              # (B, D)

        return torch.sum(target_embed * context_embed, dim=1)

    def get_embedding(self, token_ids: List[int]) -> torch.Tensor:
        """Return one fixed-size vector for a sequence of token ids.

        The ids are looked up in the target table, passed through ``char_cnn``
        along the sequence axis, and mean-pooled.

        Returns:
            A tensor of shape ``(embedding_dim,)`` on the model's device. An
            empty ``token_ids`` yields a zero vector instead of raising.
        """
        weight = self.target_embeddings.weight
        # Build inputs on the model's device; a CPU tensor would fail on a
        # CUDA-resident model.
        if len(token_ids) == 0:
            return torch.zeros(self.embedding_dim, dtype=weight.dtype,
                               device=weight.device)
        # NOTE(review): training only runs char_cnn over *context* embeddings;
        # applying it to target embeddings here may be a train/inference mismatch.
        with torch.no_grad():
            token_tensor = torch.as_tensor(token_ids, dtype=torch.long,
                                           device=weight.device)
            embeddings = self.target_embeddings(token_tensor)    # (L, D)
            embeddings = self.char_cnn(embeddings.unsqueeze(0).transpose(1, 2))
            embeddings = embeddings.transpose(1, 2).squeeze(0)   # (L, D)
            return embeddings.mean(dim=0)

def train_model(texts: List[str], embedding_dim: int = 300, epochs: int = 5,
                batch_size: int = 64, vocab_size: int = 32000,
                context_size: int = 5, lr: float = 1e-3):
    """Train a tokenizer and an embedding model on ``texts``.

    Each batch uses two binary targets:
    - true context windows are labelled 1;
    - "negative" windows are labelled 0. A negative window has every
      non-padding slot replaced by a uniformly random token id.

    Args:
        texts: Training corpus.
        embedding_dim: Embedding size passed to the model.
        epochs: Number of passes over the dataset.
        batch_size: DataLoader batch size.
        vocab_size: Target tokenizer vocabulary size.
        context_size: Tokens taken on each side of a target.
        lr: Adam learning rate (1e-3 is PyTorch's default).

    Returns:
        ``(model, tokenizer)``. The model stays on the training device and is
        switched to eval mode.
    """
    tokenizer = UnigramTokenizer(vocab_size=vocab_size)
    tokenizer.train(texts)

    dataset = EmbeddingDataset(texts, tokenizer, context_size=context_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    num_pieces = len(tokenizer.piece2idx)
    # Must match the dataset's padding id (reserved "<pad>" if present, else 0).
    pad_id = tokenizer.piece2idx.get("<pad>", 0)
    model = MultilingualEmbeddingModel(vocab_size=num_pieces,
                                       embedding_dim=embedding_dim)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    print(f"Training on {device}...")
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        progress_bar = tqdm(dataloader)
        for target, context in progress_bar:
            target = target.to(device)
            context = context.to(device)

            optimizer.zero_grad()

            # Positive samples: the real context window.
            pos_similarity = model(target, context)
            pos_loss = criterion(pos_similarity, torch.ones_like(pos_similarity))

            # Negative sampling: corrupt only the real slots. Keeping padding in
            # place stops the model from separating positives and negatives just
            # by the presence of padding.
            random_ids = torch.randint(0, num_pieces, context.shape, device=device)
            neg_context = torch.where(context == pad_id, context, random_ids)
            neg_similarity = model(target, neg_context)
            neg_loss = criterion(neg_similarity, torch.zeros_like(neg_similarity))

            loss = pos_loss + neg_loss
            loss.backward()
            optimizer.step()

            step_loss = loss.item()  # single host sync per step
            total_loss += step_loss
            progress_bar.set_description(f"Epoch {epoch+1}, Loss: {step_loss:.4f}")

        # Guard against an empty dataset (no batches).
        avg_loss = total_loss / max(len(dataloader), 1)
        print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")

    model.eval()
    return model, tokenizer

def get_text_embedding(text: str, model: MultilingualEmbeddingModel,
                      tokenizer: UnigramTokenizer) -> torch.Tensor:
    """Embed a whole string by tokenizing it and pooling via the model.

    Args:
        text: Raw input string; the tokenizer handles lower-casing.
        model: Trained embedding model.
        tokenizer: The tokenizer whose ids the model was trained on.

    Returns:
        The tensor produced by ``model.get_embedding`` for the token ids.
    """
    return model.get_embedding(token_ids=tokenizer.tokenize(text))

# Example usage: train on a toy corpus, then embed an unseen sentence.
if __name__ == "__main__":
    # Toy romanised sample sentences; swap in the full (~2.2M text) corpus
    # for real training.
    texts = [
        "maine online loan liya aur payment nahi kar paya",
        "nenu loan teesukuni repayment cheyyalekapoyanu",
        "naan loan eduthu thiruppi katta mudiyala",
    ]

    model, tokenizer = train_model(
        texts,
        embedding_dim=300,
        epochs=5,
    )

    # Embed a sentence that was not in the training set.
    test_text = "maine payment nahi kiya"
    embedding = get_text_embedding(test_text, model=model, tokenizer=tokenizer)
    print(f"Embedding shape: {embedding.shape}")