# Pipeline for transforming docs with `Scikit-learn`

## Instantiating transformer class

In [32]:
from bertopic import BERTopic

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.datasets import fetch_20newsgroups

import logging

logging.basicConfig(level=logging.INFO)

"""
Defining initial BERTopic transformer class
"""

class BERTopicTransformer(BaseEstimator, TransformerMixin):
  """Scikit-learn style transformer that cleans text and runs BERTopic on it.

  Each document is lowercased, tokenized, stripped of English stopwords and
  lemmatized with NLTK; the cleaned documents are then passed to a BERTopic
  model created from the constructor arguments.

  Parameters
  ----------
  nr_topics : forwarded to ``BERTopic(nr_topics=...)``.
  verbose : forwarded to ``BERTopic(verbose=...)``.
  """

  def __init__(self, nr_topics, verbose):
    # BaseEstimator.get_params() (and therefore clone() and repr()) reads
    # every __init__ argument back from an attribute of the same name, so
    # both arguments must be stored on the instance.
    self.nr_topics = nr_topics
    self.verbose = verbose
    self.model = BERTopic(nr_topics=nr_topics, verbose=verbose) # readymade BERTopic model

  def fit(self, X, y=None):
    """Do nothing and return ``self``; the BERTopic model is fitted in ``transform``."""
    return self

  def transform(self, X):
    """Preprocess the input text and fit/apply the BERTopic model to it.

    Parameters
    ----------
    X : a single document (``str``) or an iterable of documents.

    Returns
    -------
    (topics, probs) : the pair returned by ``BERTopic.fit_transform`` for the
    preprocessed documents.

    NOTE: this calls ``fit_transform`` on the wrapped model, so every call
    refits the model on the documents it is given.
    """
    from nltk.tokenize import word_tokenize
    from nltk.corpus import stopwords
    from nltk.stem import WordNetLemmatizer

    # A bare string is one document; anything else is treated as a
    # collection of documents (this is what a Pipeline passes in).
    documents = [X] if isinstance(X, str) else list(X)

    # 1. Preprocess text data: turn text into lowercase tokens
    logging.info("Preprocessing text data")
    tokenized_docs = [word_tokenize(doc.lower()) for doc in documents]

    # remove stopwords (the set is built once and shared by all documents)
    logging.info("Removing stopwords")
    stop_words = set(stopwords.words('english'))
    tokenized_docs = [
        [word for word in tokens if word not in stop_words]
        for tokens in tokenized_docs
    ]

    # lemmatize
    logging.info("Lemmatizing")
    lemmatizer = WordNetLemmatizer()

    # 2. Reshape data for BERTopic to consume: one joined string per document
    doc_list = [
        ' '.join(lemmatizer.lemmatize(word) for word in tokens)
        for tokens in tokenized_docs
    ]

    # 3. Apply the BERTopic model
    logging.info("Applying BERTopic model")
    topics, probs = self.model.fit_transform(doc_list)
    return topics, probs

"""
Wrapper class for BERTopic - it works as a Pipeline step because it exposes the fit/transform methods that scikit-learn expects from a transformer
"""

class BERTopicWrapper(BaseEstimator, TransformerMixin):
  """Adapter exposing a BERTopic model through ``fit`` / ``transform``.

  Parameters
  ----------
  model : the object to wrap; its ``fit`` and ``transform`` methods are
    called by the methods of the same name below.
  """

  def __init__(self, model):
    self.model = model

  def fit(self, X, y=None):
    """Fit the wrapped model on ``X`` and return this wrapper (``y`` is unused)."""
    wrapped = self.model
    wrapped.fit(X)
    return self

  def transform(self, X):
    """Return the ``(topics, probs)`` pair produced by the wrapped model's ``transform``."""
    assigned_topics, probabilities = self.model.transform(X)
    return assigned_topics, probabilities

"""
Defining pipeline
"""

# Raw documents of the full 20 Newsgroups corpus, with headers, footers and
# quoted replies removed.
test_data = fetch_20newsgroups(
    subset='all',
    remove=('headers', 'footers', 'quotes'),
).data

# Two-step pipeline: text preprocessing followed by a BERTopic model that is
# fitted on the raw documents before being wrapped.
# NOTE(review): the 'preprocess' step returns a (topics, probs) tuple, which
# Pipeline hands to the 'model' step as its input - confirm this is intended.
topic_pipeline = Pipeline(
    steps=[
        ('preprocess', BERTopicTransformer(nr_topics=3, verbose=True)),
        ('model', BERTopicWrapper(BERTopic().fit(test_data))),  # Wrap and fit BERTopic
    ],
)

transformed_data = topic_pipeline.transform(test_data)


INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: mps
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2
