In [None]:
import ollama
from datasets import load_dataset
from langchain.text_splitter import RecursiveCharacterTextSplitter
from torch.utils.data import Dataset
import torch
import re
import json
import os  # Import os for file checking
from tqdm import tqdm

class WikipediaDataset(Dataset):
    """Wikipedia text chunks paired with Ollama embeddings, cached as JSON.

    On construction the dataset is loaded from ``wikipedia_embeddings.json``
    (relative to the current working directory) when that file exists.
    Otherwise the first 2% of the ``20220301.en`` Wikipedia dump is cleaned,
    split into overlapping chunks, embedded one chunk at a time through
    ``ollama.embeddings`` and the result is written to that file.

    NOTE: the cache file does not record which model or chunking parameters
    produced it, so an existing cache is reused even when ``model_name``,
    ``max_length`` or ``min_length`` differ from the run that created it.
    Delete the file to force regeneration.
    """

    def __init__(self, model_name: str = "llama2", max_length: int = 512, num_examples: int = None, min_length: int = 50):
        """Load the cached dataset, or build, embed and cache it.

        Args:
            model_name: Ollama model used to embed each chunk.
            max_length: Chunk size (in characters) given to the text splitter,
                and also the fixed length every embedding is padded or
                truncated to in ``__getitem__``.
            num_examples: Keep at most this many chunks; ``None`` keeps all.
                The limit is applied to a loaded cache as well, but a cache
                holding fewer entries is not extended.
            min_length: Chunks shorter than this many characters are dropped.
        """
        super().__init__()
        self.model_name = model_name
        self.max_length = max_length
        self.min_length = min_length
        self.json_file = 'wikipedia_embeddings.json'

        # Check if the JSON file already exists
        if os.path.exists(self.json_file):
            self._load_embeddings_from_json()
            # Honour the requested limit for cached data too; previously
            # num_examples was ignored whenever a cache file was present.
            if num_examples is not None:
                self.texts = self.texts[:num_examples]
                self.embeddings = self.embeddings[:num_examples]
        else:
            self.texts = self._build_text_chunks(num_examples)
            self.embeddings = self._generate_embeddings()

            # Save the embeddings and texts to a JSON file
            self._save_embeddings_to_json()

    def _build_text_chunks(self, num_examples=None):
        """Download, clean and chunk the Wikipedia articles; return the chunk list."""
        dataset = load_dataset("wikipedia", "20220301.en", split="train[:2%]")

        # Text splitter with overlapping chunks.
        # NOTE(review): _clean_text collapses every whitespace run (newlines
        # included) into one space before splitting, so the "\n\n" and "\n"
        # separators can never match here; only "." is effective.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.max_length,
            chunk_overlap=50,  # Overlap to keep coherence
            length_function=len,
            separators=["\n\n", "\n", "."]
        )

        texts = []
        for item in tqdm(dataset, desc="Processing Wikipedia articles"):
            cleaned_text = self._clean_text(item['text'])  # Preprocess each article
            chunks = text_splitter.split_text(cleaned_text)
            # Filter out very small chunks to avoid noise
            texts.extend(chunk for chunk in chunks if len(chunk) >= self.min_length)
            # Stop as soon as enough chunks exist: the slice below keeps only
            # the first num_examples, so later articles cannot change the result.
            if num_examples is not None and 0 <= num_examples <= len(texts):
                break

        # Limit the number of examples if specified
        if num_examples is not None:
            texts = texts[:num_examples]
        return texts

    def _generate_embeddings(self):
        """Embed every chunk in ``self.texts`` with Ollama; return a list of 1-D tensors."""
        embeddings = []
        for text in tqdm(self.texts, desc="Generating embeddings"):
            response = ollama.embeddings(model=self.model_name, prompt=text)
            embeddings.append(torch.tensor(response['embedding']))
        return embeddings

    def _clean_text(self, text: str) -> str:
        """Strip citation markers and parenthesised spans, then collapse whitespace.

        Args:
            text: Raw article text.

        Returns:
            The text with ``[<digits>]`` markers removed, single-line
            ``(...)`` spans removed (non-greedy, so nested parentheses are
            only removed up to the first closing parenthesis), every run of
            whitespace replaced by one space, and both ends stripped.
        """
        text = re.sub(r'\[\d+\]', '', text)  # Remove reference brackets like [1], [2]
        text = re.sub(r'\(.*?\)', '', text)  # Remove text within parentheses (optional)
        text = re.sub(r'\s+', ' ', text).strip()  # Normalize whitespace
        return text

    def _load_embeddings_from_json(self):
        """Populate ``self.texts`` and ``self.embeddings`` from the JSON cache."""
        with open(self.json_file, 'r', encoding='utf-8') as json_file:
            data = json.load(json_file)
        self.texts = [item['text'] for item in data]
        self.embeddings = [torch.tensor(item['embedding']) for item in data]

    def _save_embeddings_to_json(self):
        """Write texts and embeddings to the JSON cache.

        The data is written to a sibling ``.tmp`` file first and then moved
        into place, so an interrupted write cannot leave a truncated cache
        that would make every later construction fail while loading it.
        """
        data_to_save = [
            {'text': text, 'embedding': embedding.tolist()}  # Convert tensor to list for JSON serialization
            for text, embedding in zip(self.texts, self.embeddings)
        ]
        tmp_path = self.json_file + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as json_file:
            json.dump(data_to_save, json_file, indent=4)
        os.replace(tmp_path, self.json_file)

    def __len__(self):
        """Return the number of text chunks."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the embedding at ``idx`` as a fixed-length training example.

        The embedding vector is zero-padded or truncated to ``max_length``
        values.

        Returns:
            A dict with:
              * ``'input_ids'``: the padded/truncated embedding (float values,
                not token ids, despite the key name);
              * ``'attention_mask'``: 1 for positions holding real embedding
                values and 0 for positions added as padding;
              * ``'labels'``: the same tensor object as ``'input_ids'``.
        """
        embedding = self.embeddings[idx]

        # Only positions backed by real embedding values are attended to.
        valid_length = min(embedding.size(0), self.max_length)
        attention_mask = torch.zeros(self.max_length)
        attention_mask[:valid_length] = 1.0

        # Pad or truncate embedding to max_length
        if embedding.size(0) < self.max_length:
            padding = torch.zeros(self.max_length - embedding.size(0))
            embedding = torch.cat([embedding, padding])
        else:
            embedding = embedding[:self.max_length]

        return {
            'input_ids': embedding,
            'attention_mask': attention_mask,
            'labels': embedding  # For self-supervised learning, labels are the same as inputs
        }


# Build the dataset with the "nomic-embed-text" embedding model, or load it
# from the JSON cache when one already exists. num_examples=None keeps every chunk.
dataset = WikipediaDataset(model_name="nomic-embed-text", max_length=512, num_examples=None)

Processing Wikipedia articles: 100%|██████████| 129173/129173 [00:50<00:00, 2548.69it/s]
Generating embeddings:   1%|          | 37642/3328958 [55:16<78:32:32, 11.64it/s] 