In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import numpy as np
import re
from typing import List, Tuple

class FastTextEmbeddings:
    """Vocabulary and character n-gram bookkeeping for a FastText-style model.

    Holds the hyper-parameters used to build vocabularies and training pairs,
    the word -> index and n-gram -> index maps, and (once training has run)
    the trained model in ``self.model``.
    """

    # Any character that is neither a word character nor whitespace is turned
    # into a space before splitting; compiled once and shared by all calls.
    _NON_WORD_RE = re.compile(r'[^\w\s]')

    def __init__(self, min_count: int = 5, ngram_range: Tuple[int, int] = (3, 6),
                 embedding_dim: int = 100, window_size: int = 5):
        """
        Args:
            min_count: minimum corpus frequency for a word to enter the vocabulary.
            ngram_range: inclusive (min_n, max_n) character n-gram lengths.
            embedding_dim: embedding size recorded on the instance.
            window_size: context radius (in tokens) used when building pairs.
        """
        self.min_count = min_count
        self.ngram_range = ngram_range
        self.embedding_dim = embedding_dim
        self.window_size = window_size
        self.word2idx = {}   # word -> contiguous index, filled by build_vocab
        self.ngram2idx = {}  # character n-gram -> contiguous index, filled by build_vocab
        self.model = None    # set by the training routine once a model exists

    def generate_ngrams(self, word: str) -> List[str]:
        """Return the character n-grams of ``word`` wrapped in ``<``/``>`` markers.

        N-grams are ordered by length, then by position; duplicates are kept.
        E.g. with ngram_range=(3, 6): "ab" -> ["<ab", "ab>", "<ab>"].
        """
        bounded = f"<{word}>"
        min_n, max_n = self.ngram_range
        return [bounded[i:i + n]
                for n in range(min_n, max_n + 1)
                for i in range(len(bounded) - n + 1)]

    def build_vocab(self, texts: List[str]):
        """Build the word and n-gram vocabularies from ``texts``.

        Any previously built vocabulary is discarded first. Without the reset,
        a second call would assign ``len(word2idx)`` to words that were already
        present, so several words could share one index.

        Words occurring at least ``min_count`` times are indexed in first-seen
        order. Every n-gram of a kept word is then indexed in first-seen order,
        with no frequency cut-off. A model trained against an earlier vocabulary
        no longer matches the rebuilt indices.
        """
        self.word2idx = {}
        self.ngram2idx = {}

        word_counts = Counter()
        for text in texts:
            word_counts.update(self._preprocess_text(text))

        for word, count in word_counts.items():
            if count >= self.min_count:
                self.word2idx[word] = len(self.word2idx)

        for word in self.word2idx:
            for ngram in self.generate_ngrams(word):
                if ngram not in self.ngram2idx:
                    self.ngram2idx[ngram] = len(self.ngram2idx)

    def _preprocess_text(self, text: str) -> List[str]:
        """Lower-case ``text``, turn punctuation into spaces and split on whitespace."""
        return self._NON_WORD_RE.sub(' ', text.lower()).split()

class EmbeddingDataset(Dataset):
    """Skip-gram (center, context) index pairs drawn from a list of texts.

    Each text is tokenized with the FastTextEmbeddings preprocessor. For every
    in-vocabulary token, each *other* in-vocabulary token at most
    ``window_size`` positions away (either side) contributes one pair.
    All pairs are materialized eagerly at construction time.
    """

    def __init__(self, texts: List[str], fasttext: FastTextEmbeddings):
        self.texts = texts
        self.fasttext = fasttext
        self.pairs = self._create_training_pairs()

    def _create_training_pairs(self):
        """Return a list of (center_idx, context_idx) tuples in corpus order."""
        vocab = self.fasttext.word2idx
        radius = self.fasttext.window_size
        collected = []
        for text in self.texts:
            tokens = self.fasttext._preprocess_text(text)
            n_tokens = len(tokens)
            for pos, center in enumerate(tokens):
                if center not in vocab:
                    continue
                start = max(0, pos - radius)
                stop = min(n_tokens, pos + radius + 1)
                collected.extend(
                    (vocab[center], vocab[tokens[k]])
                    for k in range(start, stop)
                    if k != pos and tokens[k] in vocab
                )
        return collected

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        center, context = self.pairs[idx]
        return torch.tensor(center), torch.tensor(context)

class FastTextModel(nn.Module):
    """Skip-gram model with word, character n-gram and context embedding tables.

    ``forward`` scores (center, context) pairs with the word and context tables
    only; ``get_word_embedding`` composes a word vector from the word table
    plus the mean of the word's n-gram vectors.
    """

    def __init__(self, vocab_size: int, ngram_vocab_size: int, embedding_dim: int):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.ngram_embeddings = nn.Embedding(ngram_vocab_size, embedding_dim)
        self.context_embeddings = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, word_idx, context_idx):
        """Return the dot-product logit of each (word, context) pair.

        Expects 1-D index tensors of equal length; ``sum(dim=1)`` reduces the
        embedding axis of the resulting (batch, dim) products.
        """
        word_embed = self.word_embeddings(word_idx)
        context_embed = self.context_embeddings(context_idx)
        return torch.mul(word_embed, context_embed).sum(dim=1)

    def _mean_ngram_vector(self, word: str, fasttext: "FastTextEmbeddings"):
        """Return the (1, dim) mean of the known n-gram vectors of ``word``, or None.

        Duplicated n-grams are counted once per occurrence; n-grams missing from
        ``fasttext.ngram2idx`` are skipped.
        """
        ids = [fasttext.ngram2idx[g] for g in fasttext.generate_ngrams(word)
               if g in fasttext.ngram2idx]
        if not ids:
            return None
        # Index tensors must live on the same device as the embedding weights
        # (the model may have been moved to CUDA for training).
        idx = torch.tensor(ids, device=self.ngram_embeddings.weight.device)
        return self.ngram_embeddings(idx).mean(dim=0, keepdim=True)

    def get_word_embedding(self, word: str, fasttext: "FastTextEmbeddings") -> torch.Tensor:
        """Get the embedding for ``word``, including its character n-grams.

        - Known word: word vector + mean n-gram vector (or the word vector alone
          if none of its n-grams are known).
        - Unknown word: mean of its known n-gram vectors, or zeros if none match.

        Returns:
            Tensor of shape (1, embedding_dim) on the model's device, for every
            path (the all-zeros fallback previously returned (embedding_dim,)).
        """
        weight = self.word_embeddings.weight
        ngram_vector = self._mean_ngram_vector(word, fasttext)

        if word not in fasttext.word2idx:
            if ngram_vector is None:
                return torch.zeros(1, self.word_embeddings.embedding_dim,
                                   device=weight.device, dtype=weight.dtype)
            return ngram_vector

        word_vector = self.word_embeddings(
            torch.tensor([fasttext.word2idx[word]], device=weight.device))
        if ngram_vector is None:
            return word_vector
        return word_vector + ngram_vector

def _build_ngram_table(fasttext: "FastTextEmbeddings") -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (ids, mask): every word's n-gram ids, padded to a common width.

    Row ``i`` holds the ids of ``fasttext.generate_ngrams(word)`` for the word
    with index ``i`` (duplicates kept, unknown n-grams skipped), so a masked
    mean over a row equals the n-gram average used by
    ``FastTextModel.get_word_embedding``. Padding slots have id 0 and mask 0.
    Memory is vocab_size * max_ngrams_per_word entries.
    """
    rows = [[] for _ in range(len(fasttext.word2idx))]
    for word, idx in fasttext.word2idx.items():
        rows[idx] = [fasttext.ngram2idx[g] for g in fasttext.generate_ngrams(word)
                     if g in fasttext.ngram2idx]
    width = max([len(row) for row in rows] + [1])  # keep at least one column
    ids = torch.zeros(len(rows), width, dtype=torch.long)
    mask = torch.zeros(len(rows), width)
    for idx, row in enumerate(rows):
        if row:
            ids[idx, :len(row)] = torch.tensor(row, dtype=torch.long)
            mask[idx, :len(row)] = 1.0
    return ids, mask


def _center_vectors(model: "FastTextModel", word_idx: torch.Tensor,
                    ngram_ids: torch.Tensor, ngram_mask: torch.Tensor) -> torch.Tensor:
    """Compose center-word vectors as word embedding + mean n-gram embedding.

    This is the same composition ``FastTextModel.get_word_embedding`` returns
    for in-vocabulary words, so the n-gram table is trained through the
    representation that is later read back. Words without n-grams (or a model
    with an empty n-gram table) fall back to the word embedding alone.
    """
    word_vec = model.word_embeddings(word_idx)                      # (batch, dim)
    if model.ngram_embeddings.num_embeddings == 0:
        return word_vec
    mask = ngram_mask[word_idx].unsqueeze(-1)                       # (batch, width, 1)
    summed = (model.ngram_embeddings(ngram_ids[word_idx]) * mask).sum(dim=1)
    return word_vec + summed / mask.sum(dim=1).clamp(min=1.0)


def train_fasttext(texts: List[str], embedding_dim: int = 100, epochs: int = 5,
                   min_count: int = 5, window_size: int = 5,
                   batch_size: int = 64, lr: float = 1e-3):
    """Build vocabularies from ``texts`` and train skip-gram FastText embeddings.

    Each positive (center, context) pair is paired with one uniformly sampled
    negative context and trained with binary cross-entropy on the dot-product
    logits. The center representation is word embedding + mean n-gram
    embedding, so the n-gram table is trained too. Previously only
    ``model.forward`` (word table only) was used, which left the n-gram
    vectors at their random initialization.

    Args:
        texts: raw training documents.
        embedding_dim: embedding size.
        epochs: number of passes over the training pairs.
        min_count: minimum word frequency to enter the vocabulary.
        window_size: context radius in tokens.
        batch_size: pairs per optimization step.
        lr: Adam learning rate (1e-3 is Adam's default).

    Returns:
        The FastTextEmbeddings instance, with the trained model (left on the
        training device) in ``.model``.

    Raises:
        ValueError: if no training pairs can be formed, e.g. because no word
            reaches ``min_count``.
    """
    fasttext = FastTextEmbeddings(min_count=min_count, embedding_dim=embedding_dim,
                                  window_size=window_size)
    fasttext.build_vocab(texts)

    dataset = EmbeddingDataset(texts, fasttext)
    if len(dataset) == 0:
        # Without this check the epoch report below divides by len(dataloader) == 0.
        raise ValueError(
            f"No training pairs: {len(fasttext.word2idx)} word(s) reached "
            f"min_count={min_count}; provide more text or lower min_count.")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    vocab_size = len(fasttext.word2idx)
    model = FastTextModel(
        vocab_size=vocab_size,
        ngram_vocab_size=len(fasttext.ngram2idx),
        embedding_dim=embedding_dim
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    ngram_ids, ngram_mask = _build_ngram_table(fasttext)
    ngram_ids, ngram_mask = ngram_ids.to(device), ngram_mask.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        total_loss = 0.0
        for word_idx, context_idx in dataloader:
            word_idx = word_idx.to(device)
            context_idx = context_idx.to(device)

            optimizer.zero_grad()
            center = _center_vectors(model, word_idx, ngram_ids, ngram_mask)

            pos_logits = (center * model.context_embeddings(context_idx)).sum(dim=1)
            loss = criterion(pos_logits, torch.ones_like(pos_logits))

            # One uniformly sampled negative context per positive pair.
            neg_context_idx = torch.randint(0, vocab_size, context_idx.shape,
                                            device=device)
            neg_logits = (center * model.context_embeddings(neg_context_idx)).sum(dim=1)
            loss = loss + criterion(neg_logits, torch.zeros_like(neg_logits))

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(dataloader):.4f}')

    fasttext.model = model
    return fasttext

# Example usage:
if __name__ == "__main__":
    # Sample texts (replace with your cybercrime descriptions).
    # NOTE: train_fasttext builds its vocabulary with min_count=5 by default,
    # so each word must occur at least 5 times across the texts. These two
    # sample lines alone yield an empty vocabulary and training cannot proceed.
    texts = [
        "cyber fraud online payment",
        "phishing email bank account",
        # Add more texts...
    ]

    # Train model
    fasttext = train_fasttext(texts, embedding_dim=100, epochs=5)

    # Get embeddings for words
    word = "cyber"
    if word in fasttext.word2idx:
        embedding = fasttext.model.get_word_embedding(word, fasttext)
        # The trained model may live on the GPU; NumPy needs a CPU tensor.
        print(f"Embedding for '{word}':", embedding.detach().cpu().numpy())