In [None]:
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from bertopic import BERTopic
from bertopic.representation import VisualRepresentation

In [None]:
# Load the WildFireCan-MMD dataset; its 'text' and 'image' columns are used below.
dataset_path = "WildFireCan-MMD.csv"
dataset = pd.read_csv(filepath_or_buffer=dataset_path)

In [None]:
# NOTE(review): the first 20 characters of every stored image path are dropped;
# presumably a directory prefix from where the dataset was exported — confirm.
IMAGE_PATH_PREFIX_LEN = 20
# .str slicing matches x[20:] for strings but leaves missing paths as NaN
# instead of raising TypeError.
images = dataset['image'].str[IMAGE_PATH_PREFIX_LEN:]
# Downstream code (word counts, BERTopic's c-TF-IDF) needs real strings, so a
# missing text becomes an empty document rather than a float NaN.
docs = dataset['text'].fillna('')

In [None]:
import matplotlib.pyplot as plt

# Whitespace-delimited word count per document. The vectorized .str methods
# split exactly like str.split(); missing texts count as 0 words instead of
# raising AttributeError on a float NaN.
word_counts = docs.fillna('').str.split().str.len()

fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(word_counts, bins=30, color='skyblue', edgecolor='black')
ax.set_xlabel('Number of Words')
ax.set_ylabel('Frequency')
ax.set_title('Histogram of Word Counts per Line in Docs')
fig.tight_layout()
plt.show()

# Test: multimodal embeddings (text + image) with a CLIP backend

In [None]:
from bertopic.backend import MultiModalBackend

# CLIP backend; it is reused as BERTopic's embedding model further down.
model = MultiModalBackend(
    'clip-ViT-B-32',
    batch_size=32,
)

# One vector per post: each document's text embedding is averaged with the
# embedding of its image.
doc_image_embeddings = model.embed(docs, images)

In [None]:
# Each topic gets a "Visual_Aspect" built by VisualRepresentation with its
# default settings. An image-captioning model can be plugged in instead, e.g.
# VisualRepresentation(image_to_text_model="nlpconnect/vit-gpt2-image-captioning").
visual_model = VisualRepresentation()
representation_model = dict(Visual_Aspect=visual_model)

In [None]:
topic_model = BERTopic(
    embedding_model=model,
    representation_model=representation_model,
    min_topic_size=30,
)
# Reuse the averaged text+image embeddings computed above.
topic_model.fit(documents=docs, images=images, embeddings=doc_image_embeddings)

In [None]:
# Topic hierarchy of the fitted model. The same model also offers
# visualize_barchart(), visualize_heatmap(), visualize_topics() and
# get_topic_info() for other views.
topic_model.visualize_hierarchy()

# Old way: default BERTopic embedding model with a custom vectorizer and automatic topic reduction

In [None]:
visual_model = VisualRepresentation()
representation_model = {
    "Visual_Aspect": visual_model,
}

# One source of truth for the n-gram range. Per the BERTopic docs, n_gram_range
# is ignored when a vectorizer_model is passed, so giving the two separate
# literals lets them silently disagree; sharing the constant keeps them in sync.
NGRAM_RANGE = (1, 2)

# Stop words are removed by the vectorizer, i.e. when building the topic word
# representations after embedding — not before the documents are embedded.
vectorizer_model = CountVectorizer(ngram_range=NGRAM_RANGE, stop_words="english")

# NOTE(review): no seeded umap_model is passed, so the resulting topics may
# differ between runs; pass one with a fixed random_state if the saved models
# below need to be reproducible.
topic_model = BERTopic(
    representation_model=representation_model,
    n_gram_range=NGRAM_RANGE,
    vectorizer_model=vectorizer_model,
    language='english',
    calculate_probabilities=True,
    nr_topics='auto',
)

In [None]:
# Fit on the texts; the image paths are passed along for the Visual_Aspect representation.
topic_model.fit(docs, images=images)

In [None]:
# Flip to True to persist the fitted model; the "auto topics" section below
# reloads it from this same path.
SAVE_MODEL = False
if SAVE_MODEL:
    topic_model.save("bertopic_models/bcab&jasper_auto", serialization="safetensors")

In [None]:
import base64
from io import BytesIO
from IPython.display import HTML
from PIL import Image

def get_thumbnail(image_path, size=(100, 100)):
    """Open an image file and shrink it to fit within ``size``.

    Args:
        image_path: Path of the image file to open.
        size: (width, height) bounding box passed to ``Image.thumbnail``.

    Returns:
        The resized image, or None if it could not be opened or resized
        (the error is printed).
    """
    try:
        # Use a context manager so the underlying file handle is released.
        # Closing the source image invalidates its pixel data, so return a copy.
        with Image.open(image_path) as im:
            im.thumbnail(size)
            return im.copy()
    except Exception as e:
        # Best-effort display helper: report the problem and let the caller skip it.
        print(f"Error generating thumbnail: {e}")
        return None

def image_base64(im):
    """Encode an image as base64 JPEG text for inline HTML embedding.

    Args:
        im: An image object, or a path string that is first turned into a
            thumbnail with ``get_thumbnail``.

    Returns:
        The base64-encoded JPEG bytes as ASCII text, or '' when there is no
        image to encode (e.g. the thumbnail could not be generated).
    """
    if isinstance(im, str):
        im = get_thumbnail(im)
    if im is None:
        # get_thumbnail already reported the failure; nothing to encode.
        return ''
    # Pillow can only write these modes as JPEG; others (RGBA, P, LA, ...) raise
    # "cannot write mode ... as JPEG", so convert them to RGB first.
    if im.mode not in ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'):
        im = im.convert('RGB')
    with BytesIO() as buffer:
        im.save(buffer, 'jpeg')
        return base64.b64encode(buffer.getvalue()).decode()

def image_formatter(im):
    """Render an image (or image path) as an inline ``<img>`` tag for ``DataFrame.to_html``.

    Cells that hold no image (e.g. None or NaN), or whose image could not be
    encoded, render as an empty string instead of raising or producing a
    broken ``<img>`` element.
    """
    # Only strings (paths) and objects that can save themselves are encodable.
    if not (isinstance(im, str) or hasattr(im, 'save')):
        return ''
    encoded = image_base64(im)
    if not encoded:
        return ''
    return f'<img src="data:image/jpeg;base64,{encoded}">'

import html

# Topic summary table; its "Visual_Aspect" column holds the images produced by
# the VisualRepresentation configured above.
df = topic_model.get_topic_info()

# escape=False is required so the <img> tags render, but it would also emit
# every other cell as raw HTML. Representation words and representative docs
# are free text (posts may contain '<', '&', or markup), so escape them
# explicitly and leave only the image column unescaped.
formatters = {col: (lambda value: html.escape(str(value))) for col in df.columns}
formatters['Visual_Aspect'] = image_formatter

# Visualize the images
HTML(df.to_html(formatters=formatters, escape=False))

In [None]:
# Per-topic word lists of the fitted model. Other views on the same model:
# visualize_barchart(), visualize_heatmap(), visualize_topics(),
# visualize_hierarchy() and get_topic_info().
topic_model.get_topics()

# 15 topics

In [None]:
topic_model = BERTopic.load("bertopic_models/bcab&jasper_15")

In [None]:
topic_model.visualize_barchart()

In [None]:
topic_model.visualize_heatmap()

In [None]:
topic_model.visualize_topics()

In [None]:
topic_model.visualize_hierarchy()

In [None]:
topic_model.get_topic_info()

# 20 topics

In [None]:
topic_model = BERTopic.load("bertopic_models/bcab&jasper_20")

In [None]:
topic_model.visualize_barchart()

In [None]:
topic_model.visualize_heatmap()

In [None]:
topic_model.visualize_topics()

In [None]:
topic_model.visualize_hierarchy()

In [None]:
topic_model.get_topic_info()

# 25 topics

In [None]:
topic_model = BERTopic.load("bertopic_models/bcab&jasper_25")

In [None]:
topic_model.visualize_barchart()

In [None]:
topic_model.visualize_heatmap()

In [None]:
topic_model.visualize_topics()

In [None]:
topic_model.visualize_hierarchy()

In [None]:
df = topic_model.get_topic_info()

In [None]:
df['Representation'][12]

# 30 topics

In [None]:
topic_model = BERTopic.load("bertopic_models/bcab&jasper_30")

In [None]:
topic_model.visualize_barchart()

In [None]:
topic_model.visualize_heatmap()

In [None]:
topic_model.visualize_topics()

In [None]:
topic_model.visualize_hierarchy()

In [None]:
topic_model.get_topic_info()

# 35 topics

In [None]:
topic_model = BERTopic.load("bertopic_models/bcab&jasper_35")

In [None]:
topic_model.visualize_barchart()

In [None]:
topic_model.visualize_heatmap()

In [None]:
topic_model.visualize_topics()

In [None]:
topic_model.visualize_hierarchy()

In [None]:
topic_model.get_topic_info()

# auto topics

In [None]:
topic_model = BERTopic.load("bertopic_models/bcab&jasper_auto")

In [None]:
topic_model.visualize_barchart()

In [None]:
topic_model.visualize_heatmap()

In [None]:
topic_model.visualize_topics()

In [None]:
topic_model.visualize_hierarchy()

In [None]:
topic_model.get_topic_info()