In [None]:
import os
import gc
from datetime import date

import torch
from tqdm.notebook import tqdm
import pandas as pd

from utils import (
    compute_metrics,
    load_all_available_transcripts,
    SummarizationPipeline,
    TextChunker,
    LoggingConfig,
    ModelConfig,
    TextChunker,
    TopicModeler,
    Retriever
)

In [None]:
# Load every transcript the project helper can find; the result is reused by
# the cells below (they call .head() and read a `full_text` column on it).
transcripts = load_all_available_transcripts()

In [None]:
# Quick sanity check of the loaded data: display the leading rows.
# NOTE(review): assumes `transcripts` is a pandas DataFrame — confirm in utils.
transcripts.head()

In [None]:
# Checkpoint identifier passed to ModelConfig as `model_name_or_path`.
checkpoint = 'facebook/bart-large-cnn'

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model_config = ModelConfig(model_name_or_path=checkpoint, device=device)
logging_config = LoggingConfig()
pipeline = SummarizationPipeline(
    model_config=model_config,
    logging_config=logging_config,
)

# The chunker is built from the pipeline's own tokenizer.
tokenizer = pipeline.get_tokenizer()
chunker = TextChunker(tokenizer)

# All transcripts are joined into one space-separated string before chunking,
# so chunk boundaries are not aligned to individual transcripts.
chunks = chunker.chunk_text(' '.join(transcripts['full_text'].tolist()))

# len() of every chunk (character counts if chunk_text yields strings —
# TODO confirm against TextChunker).
print([len(c) for c in chunks])

In [None]:
from langchain.schema import Document

# Wrap each raw chunk in a LangChain Document.
#
# `chunks` is rebound in place (later cells read it under this name), so the
# isinstance guard keeps the cell idempotent: on a second run the items are
# already Documents and are passed through unchanged instead of being handed
# to Document(page_content=...) again.
chunks = [
    chunk if isinstance(chunk, Document) else Document(page_content=chunk)
    for chunk in chunks
]

# Show only the first ten documents rather than the whole list.
chunks[:10]

In [None]:
# Build the topic modeler over the chunk list.
# NOTE(review): `speed` and `workers` are passed straight through to the
# project's TopicModeler; presumably a training-speed preset and a worker
# count — confirm against utils before tuning.
tm = TopicModeler(
    chunks=chunks,
    speed='learn',
    workers=8,
)


In [None]:
# get_topics(1) is unpacked as a 3-tuple; only the first and third elements
# are used here, the middle one is discarded.
topic_words, _, topic_nums = tm.get_topics(1)

# One line per topic: its id followed by its comma-separated words.
for words, tid in zip(topic_words, topic_nums):
    joined_words = ', '.join(words)
    print(f'Topic #{tid}: {joined_words}')