In [20]:
# Install required packages (run once)
!pip install chromadb sentence-transformers scikit-learn torch

import chromadb
from chromadb.utils import embedding_functions
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sentence_transformers import SentenceTransformer
import random

# Load dataset
print("Loading dataset...")
dataset = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))
texts = dataset.data
labels = dataset.target
print(f"Dataset loaded with {len(texts)} documents.")

# Select a percentage of documents
percentage = 0.05  # Adjust the percentage here (e.g., 0.3 for 30%)
selected_indices = random.sample(range(len(texts)), int(len(texts) * percentage))
texts = [texts[i] for i in selected_indices]
labels = [labels[i] for i in selected_indices]
print(f"Selected {len(texts)} documents ({percentage * 100}%) from dataset.")

# Preprocess: filter empty texts
print("Filtering empty texts...")
filtered_texts, filtered_labels = zip(*[(text, label) for text, label in zip(texts, labels) if text.strip()])
print(f"Filtered dataset has {len(filtered_texts)} documents.")

# Train/test split
print("Splitting data into training and test sets...")
X_train, X_test, y_train, y_test = train_test_split(
    filtered_texts, filtered_labels, test_size=0.2, random_state=42, stratify=filtered_labels
)
print(f"Train size: {len(X_train)}, Test size: {len(X_test)}")

# Generate embeddings
print("Generating embeddings...")
model = SentenceTransformer('all-MiniLM-L6-v2')
train_embeddings = model.encode(X_train, show_progress_bar=True, batch_size=32)
test_embeddings = model.encode(X_test, show_progress_bar=True, batch_size=32)
print("Embeddings generated.")

# Initialize ChromaDB
print("Initializing ChromaDB...")
client = chromadb.Client()
collection_name = "text_classification"

# Delete existing collection safely
if collection_name in [col.name for col in client.list_collections()]:
    print(f"Deleting existing collection '{collection_name}'...")
    client.delete_collection(name=collection_name)
    print("Existing collection deleted.")

# Create new collection
print(f"Creating new collection '{collection_name}'...")
collection = client.create_collection(
    name=collection_name,
    embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
)
print("Collection created.")

# Insert data in batches
batch_size = 10
print("Inserting documents into ChromaDB...")
for i in range(0, len(X_train), batch_size):
    batch_ids = [f"id_{j}" for j in range(i, min(i + batch_size, len(X_train)))]
    batch_documents = X_train[i:i + batch_size]
    collection.add(ids=batch_ids, documents=batch_documents)
    print(f"Inserted batch {i // batch_size + 1}/{(len(X_train) - 1) // batch_size + 1}")
print("All documents inserted.")

# Train logistic regression
print("Training logistic regression classifier...")
clf = LogisticRegression(max_iter=10, multi_class='multinomial', solver='lbfgs')
clf.fit(train_embeddings, y_train)
print("Classifier training completed.")

# Evaluate classifier
print("Evaluating classifier...")
y_pred = clf.predict(test_embeddings)
report = classification_report(y_test, y_pred, target_names=dataset.target_names)
print(report)
print("Evaluation completed.")

# Query ChromaDB
query_text = "Space exploration and rockets"
print(f"Querying ChromaDB for: '{query_text}'")
results = collection.query(query_texts=[query_text], n_results=5)

print("\nTop 5 similar documents from ChromaDB:")
for idx, doc in enumerate(results['documents'][0]):
    print(f"\nDocument {idx+1}:\n{doc[:300]}...\n")

print("Query completed.")

Loading dataset...
Dataset loaded with 18846 documents.
Selected 942 documents (5.0%) from dataset.
Filtering empty texts...
Filtered dataset has 913 documents.
Splitting data into training and test sets...
Train size: 730, Test size: 183
Generating embeddings...


Batches:   0%|          | 0/23 [00:00<?, ?it/s]

Batches:   0%|          | 0/6 [00:00<?, ?it/s]

Embeddings generated.
Initializing ChromaDB...
Deleting existing collection 'text_classification'...
Existing collection deleted.
Creating new collection 'text_classification'...
Collection created.
Inserting documents into ChromaDB...
Inserted batch 1/73
Inserted batch 2/73
Inserted batch 3/73
Inserted batch 4/73
Inserted batch 5/73
Inserted batch 6/73
Inserted batch 7/73
Inserted batch 8/73
Inserted batch 9/73
Inserted batch 10/73
Inserted batch 11/73
Inserted batch 12/73
Inserted batch 13/73
Inserted batch 14/73
Inserted batch 15/73
Inserted batch 16/73
Inserted batch 17/73
Inserted batch 18/73
Inserted batch 19/73
Inserted batch 20/73
Inserted batch 21/73
Inserted batch 22/73
Inserted batch 23/73
Inserted batch 24/73
Inserted batch 25/73
Inserted batch 26/73
Inserted batch 27/73
Inserted batch 28/73
Inserted batch 29/73
Inserted batch 30/73
Inserted batch 31/73
Inserted batch 32/73
Inserted batch 33/73
Inserted batch 34/73
Inserted batch 35/73
Inserted batch 36/73
Inserted batch 37



Classifier training completed.
Evaluating classifier...
                          precision    recall  f1-score   support

             alt.atheism       0.33      0.12      0.18         8
           comp.graphics       0.55      0.67      0.60         9
 comp.os.ms-windows.misc       0.88      0.88      0.88         8
comp.sys.ibm.pc.hardware       0.71      0.56      0.62         9
   comp.sys.mac.hardware       0.71      0.62      0.67         8
          comp.windows.x       0.67      0.67      0.67         9
            misc.forsale       0.73      0.89      0.80         9
               rec.autos       1.00      0.70      0.82        10
         rec.motorcycles       0.73      1.00      0.85        11
      rec.sport.baseball       0.82      1.00      0.90         9
        rec.sport.hockey       0.83      0.62      0.71         8
               sci.crypt       0.82      0.75      0.78        12
         sci.electronics       0.57      0.73      0.64        11
                 sc

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
