In [None]:
import polars as pl

# Relative glob pattern for the per-batch parquet files; it is handed to
# `pl.scan_parquet` when the lazy frame is created.
DATA_PATH = "../data/processed/embeddings_batches/batch_*.parquet"

In [None]:
# Build a lazy query over the parquet batches; nothing is read until `.collect()`.
lf = pl.scan_parquet(source=DATA_PATH)

# Materialise only the first five rows as a quick look at the data.
lf.head(n=5).collect()

In [None]:
# Columns excluded from the metadata frame. "embedding" is collected on its own
# into `embeddings_df`; "text" and "state" are simply left out here.
non_metadata_columns = ["text", "state", "embedding"]

metadata_df = lf.drop(non_metadata_columns).collect()
embeddings_df = lf.select(pl.col("embedding")).collect()

In [None]:
import numpy as np
from sklearn.preprocessing import MinMaxScaler

def reshape_metadata(df: pl.DataFrame, new_range: tuple = (-1, 1)) -> np.ndarray:
    """
    Turn the metadata frame into a min-max scaled numpy array.

    The frame is converted with ``to_numpy`` and a ``MinMaxScaler`` is fitted
    on that same data, so the scaling bounds come from ``df`` itself.

    Args:
        df: Metadata columns to convert and scale.
        new_range: ``(low, high)`` passed to ``MinMaxScaler`` as ``feature_range``.

    Returns:
        The scaled values as a numpy array.
    """
    raw_values = df.to_numpy()

    min_max = MinMaxScaler(feature_range=new_range)
    min_max.fit(raw_values)

    return min_max.transform(raw_values)

In [None]:
import torch
import torch.nn as nn

# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class SimpleFusionAE(nn.Module):
    """
    Autoencoder over the concatenation of a text embedding and metadata features.

    Both the encoder and the decoder are two-layer MLPs with a single ReLU
    hidden layer. The encoder maps the concatenated input to a latent vector;
    the decoder maps that latent vector back to the concatenated input size.

    Args:
        text_dim: Number of features in the text-embedding input.
        meta_dim: Number of features in the metadata input.
        latent_dim: Size of the latent (fused) representation.
        hidden_dim: Width of the hidden layer in both encoder and decoder.
            Defaults to 128, the value that used to be hard-coded, so existing
            callers and previously saved state dicts are unaffected.
    """

    def __init__(self, text_dim: int, meta_dim: int, latent_dim=50, hidden_dim=128):
        super().__init__()

        # The network sees text and metadata features side by side.
        input_dim = text_dim + meta_dim

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

    def forward(self, text, metadata):
        """
        Encode and reconstruct one batch.

        Args:
            text: Tensor whose dim 1 holds the text-embedding features.
            metadata: Tensor whose dim 1 holds the metadata features; it is
                concatenated to ``text`` along dim 1.

        Returns:
            A ``(reconstructed_x, latent_representation)`` tuple: the decoder
            output for the concatenated input, and the encoder output.
        """
        x = torch.cat((text, metadata), dim=1)

        latent_representation = self.encoder(x)

        reconstructed_x = self.decoder(latent_representation)

        return reconstructed_x, latent_representation

In [None]:
# Numpy views of the two inputs: raw embeddings and min-max scaled metadata.
embeddings_arr = embeddings_df.get_column("embedding").to_numpy()
metadata_arr = reshape_metadata(metadata_df)

# Layer widths are read off the second axis of each array.
# NOTE(review): assumes `to_numpy()` yields a 2-D array for "embedding" — confirm.
text_dim = embeddings_arr.shape[1]
meta_dim = metadata_arr.shape[1]

model = SimpleFusionAE(text_dim=text_dim, meta_dim=meta_dim).to(device)
model

In [None]:
from torch.utils.data import DataLoader, TensorDataset

# Place the data on the same device as the model. `.to(device)` follows the
# CUDA/CPU choice made when `device` was defined, whereas a hard-coded
# `.cuda()` raises on a machine without a usable GPU.
x_embeddings = torch.tensor(embeddings_arr, dtype=torch.float32).to(device)
x_metadata = torch.tensor(metadata_arr, dtype=torch.float32).to(device)

# TensorDataset pairs the two tensors by index, so each sample is an
# (embedding, metadata) tuple; batches are reshuffled every epoch.
dataset = TensorDataset(x_embeddings, x_metadata)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [None]:
import torch.optim
from tqdm import tqdm

LEARNING_RATE = 1e-3
EPOCHS = 20

# Plain SGD minimising the mean-squared reconstruction error.
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

for epoch in tqdm(range(EPOCHS), desc="Training model"):
    model.train()
    batch_losses = []

    for batch_text, batch_meta in dataloader:
        # The reconstruction target is the same concatenation the model
        # builds internally from its two inputs.
        target = torch.cat((batch_text, batch_meta), dim=1)

        optimizer.zero_grad()
        reconstructed_x, latent_representation = model(batch_text, batch_meta)
        loss = criterion(reconstructed_x, target)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    # Reported value is the mean of the per-batch losses for this epoch.
    total_loss = sum(batch_losses)
    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss / len(dataloader):.6f}")

In [None]:
# Switch the model to evaluation mode before the latent vectors are extracted.
model.eval()

In [None]:
# Encode the dataset batch by batch, moving each latent batch to the CPU as
# soon as it is produced so the results do not pile up on the device (avoids OOM).
inference_dataset = TensorDataset(x_embeddings, x_metadata)
# shuffle=False keeps the latent rows in the same order as the input rows.
inference_loader = DataLoader(inference_dataset, batch_size=32, shuffle=False)

with torch.no_grad():
    # model(...) returns (reconstruction, latent); only the latent part is kept.
    fused_embeddings_list = [
        model(batch_text, batch_meta)[1].cpu().numpy()
        for batch_text, batch_meta in inference_loader
    ]

fused_embeddings = np.concatenate(fused_embeddings_list, axis=0)
fused_embeddings.shape

In [None]:
# Persist the latent vectors as a single-column parquet file.
fused_df = pl.DataFrame({"fused_embeddings": fused_embeddings})
fused_df.write_parquet("../data/processed/fused_embeddings.parquet")

In [None]:
# Save the model's state dict (its parameters) rather than the module object.
fusion_weights = model.state_dict()
torch.save(fusion_weights, "../models/fusion_autoencoder.weights.pth")

In [None]:
from umap import UMAP
from hdbscan import HDBSCAN
from bertopic import BERTopic
from bertopic.vectorizers import ClassTfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer

def build_bertopic(n_components=5, min_cluster_size=15, random_state=None):
    """
    Build a BERTopic model with the pipeline configuration used in this notebook.

    Both compared models come from this one factory, so their hyperparameters
    cannot drift apart. Every call creates fresh UMAP / HDBSCAN / vectorizer /
    c-TF-IDF instances, so no fitted state is shared between the models.

    Args:
        n_components: Output dimensionality of the UMAP reduction.
        min_cluster_size: Minimum cluster size passed to HDBSCAN.
        random_state: Optional seed forwarded to UMAP. ``None`` keeps the
            library's default (unseeded) behaviour.

    Returns:
        An unfitted ``BERTopic`` instance.
    """
    return BERTopic(
        umap_model=UMAP(
            n_components=n_components,
            min_dist=0.0,
            metric="cosine",
            random_state=random_state,
        ),
        hdbscan_model=HDBSCAN(min_cluster_size=min_cluster_size, prediction_data=True),
        vectorizer_model=CountVectorizer(stop_words="english"),
        ctfidf_model=ClassTfidfTransformer(),
    )


# Two identically configured models: one per set of embeddings being compared.
vanilla_bertopic = build_bertopic()
modded_bertopic = build_bertopic()

In [None]:
# Documents and their original embeddings, both collected from the lazy frame.
docs = lf.select("text").collect().get_column("text").to_list()
vanilla_embeddings = lf.select("embedding").collect().get_column("embedding").to_numpy()

# Autoencoder latents, read back from the parquet file written earlier in the notebook.
fused_frame = pl.read_parquet("../data/processed/fused_embeddings.parquet")
encoder_embeddings = fused_frame.get_column("fused_embeddings").to_numpy()

In [None]:
# Fit on the documents using the original embeddings (precomputed, so BERTopic
# does not embed the documents itself).
vanilla_topics, vanilla_probs = vanilla_bertopic.fit_transform(
    documents=docs,
    embeddings=vanilla_embeddings,
)

In [None]:
# Fit on the same documents, this time using the autoencoder's fused embeddings.
modded_topics, modded_probs = modded_bertopic.fit_transform(
    documents=docs,
    embeddings=encoder_embeddings,
)