In [73]:
import duckdb
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import LabelEncoder
import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset

In [4]:
# Shared in-memory DuckDB connection; every query cell below runs through it.
con = duckdb.connect(":memory:", read_only=False)

# Loading and preprocessing

In [58]:
# HuffPost news-category dataset.
# The query:
#   * computes each category's share of all articles and keeps only rows whose
#     category accounts for more than 2% of the data (`pct > 0.02`);
#   * builds `text` as headline + ' ' + short_description;
#   * builds `x` from the second '.com/'-delimited segment of `link`, with '/'
#     and '-' replaced by spaces;
#   * exposes `category` as the label column `y`;
#   * caps the result at 1000 rows.
# NOTE(review): LIMIT is applied without an ORDER BY, so which 1000 rows come
# back is presumably not guaranteed to be stable across runs -- confirm.
# NOTE(review): `||` presumably yields NULL when headline or short_description
# is NULL, which would leave `text` null for those rows -- confirm.
q = """
WITH data AS (
SELECT
    *
FROM
    '../../data/raw/news_categories.parquet'
),
total AS (
SELECT COUNT(*) AS total FROM data
),
counts AS (
SELECT category, COUNT(*) AS articles FROM data GROUP BY 1
),
pcts AS (
SELECT category, articles / total.total AS pct FROM counts, total
)
SELECT d.*, d.headline || ' ' || d.short_description AS text, REPLACE(REPLACE(SPLIT_PART(d.link, '.com/', 2), '/', ' '), '-', ' ') AS x, category AS y FROM data d JOIN pcts USING (category) WHERE pct > 0.02
LIMIT 1000
"""

huffpo = con.execute(q).fetch_df()

huffpo

Unnamed: 0,link,headline,category,short_description,authors,date,text,x,y
0,https://www.huffpost.com/entry/funniest-tweets...,23 Of The Funniest Tweets About Cats And Dogs ...,COMEDY,"""Until you have a dog you don't understand wha...",Elyse Wanshel,2022-09-23,23 Of The Funniest Tweets About Cats And Dogs ...,entry funniest tweets cats dogs september 17 2...,COMEDY
1,https://www.huffpost.com/entry/funniest-parent...,The Funniest Tweets From Parents This Week (Se...,PARENTING,"""Accidentally put grown-up toothpaste on my to...",Caroline Bologna,2022-09-23,The Funniest Tweets From Parents This Week (Se...,entry funniest parenting tweets_l_632d7d15e4b0...,PARENTING
2,https://www.huffpost.com/entry/dodgers-basebal...,"Maury Wills, Base-Stealing Shortstop For Dodge...",SPORTS,"Maury Wills, who helped the Los Angeles Dodger...","Beth Harris, AP",2022-09-20,"Maury Wills, Base-Stealing Shortstop For Dodge...",entry dodgers baseball obit wills_n_6329feb3e4...,SPORTS
3,https://www.huffpost.com/entry/golden-globes-r...,Golden Globes Returning To NBC In January Afte...,ENTERTAINMENT,"For the past 18 months, Hollywood has effectiv...",,2022-09-20,Golden Globes Returning To NBC In January Afte...,entry golden globes return nbc_n_6329f151e4b0e...,ENTERTAINMENT
4,https://www.huffpost.com/entry/biden-us-forces...,Biden Says U.S. Forces Would Defend Taiwan If ...,POLITICS,President issues vow as tensions with China rise.,,2022-09-19,Biden Says U.S. Forces Would Defend Taiwan If ...,entry biden us forces defend taiwan against ch...,POLITICS
...,...,...,...,...,...,...,...,...,...
995,https://www.huffpost.com/entry/nikola-jokic-ch...,Nikola Jokic Unleashes 'Nasty Cheap Shot' At M...,SPORTS,"The Denver Nuggets star, ejected after his bli...",Ron Dicker,2021-11-09,Nikola Jokic Unleashes 'Nasty Cheap Shot' At M...,entry nikola jokic cheap shot markieff morris_...,SPORTS
996,https://www.huffpost.com/entry/joe-biden-jamal...,Joe Biden Promised To Take On Saudi Arabia Ove...,POLITICS,Biden pledged to root his foreign policy in va...,Akbar Shahid Ahmed,2021-11-09,Joe Biden Promised To Take On Saudi Arabia Ove...,entry joe biden jamal khashoggi saudi arabia_n...,POLITICS
997,https://www.huffpost.com/entry/lance-bass-brit...,Lance Bass Learns He And Britney Spears Are Re...,ENTERTAINMENT,The *NSYNC star said he's always thought of Sp...,Curtis M. Wong,2021-11-09,Lance Bass Learns He And Britney Spears Are Re...,entry lance bass britney spears ancestry test_...,ENTERTAINMENT
998,https://www.huffpost.com/entry/adam-kinzinger-...,Rep. Adam Kinzinger Reportedly Considered Usin...,POLITICS,"""I barricaded myself in, thinking, ’If this is...",Mary Papenfuss,2021-11-09,Rep. Adam Kinzinger Reportedly Considered Usin...,entry adam kinzinger jan 6 riot_n_618a058de4b0...,POLITICS


In [64]:
# UCI news dataset.
# Adds three derived columns to the raw parquet rows:
#   * `text` -- the article TITLE;
#   * `x`    -- the second '.com/'-delimited segment of URL, with '/' and '-'
#               replaced by spaces;
#   * `y`    -- the CATEGORY code.
# No category-frequency filter is applied here; the result is capped at 1000 rows.
# NOTE(review): URLs that do not contain '.com/' presumably produce an empty
# `x` -- confirm whether such rows should be dropped or parsed differently.
# NOTE(review): LIMIT without ORDER BY -- row selection is presumably not
# guaranteed to be stable across runs.
q = """
SELECT *, TITLE AS text, REPLACE(REPLACE(SPLIT_PART(URL, '.com/', 2), '/', ' '), '-', ' ') AS x, CATEGORY AS y FROM '../../data/raw/uci_categories.parquet'
LIMIT 1000
"""

uci = con.execute(q).fetch_df()

uci

Unnamed: 0,ID,TITLE,URL,PUBLISHER,CATEGORY,STORY,HOSTNAME,TIMESTAMP,text,x,y
0,1,"Fed official says weak data caused by weather,...",http://www.latimes.com/business/money/la-fi-mo...,Los Angeles Times,b,ddUyU0VZz0BRneMioxUPQVP6sIxvM,www.latimes.com,1394470370698,"Fed official says weak data caused by weather,...",business money la fi mo federal reserve plosse...,b
1,2,Fed's Charles Plosser sees high bar for change...,http://www.livemint.com/Politics/H2EvwJSK2VE6O...,Livemint,b,ddUyU0VZz0BRneMioxUPQVP6sIxvM,www.livemint.com,1394470371207,Fed's Charles Plosser sees high bar for change...,Politics H2EvwJSK2VE6OF7iK1g3PP Feds Charles P...,b
2,3,US open: Stocks fall after Fed official hints ...,http://www.ifamagazine.com/news/us-open-stocks...,IFA Magazine,b,ddUyU0VZz0BRneMioxUPQVP6sIxvM,www.ifamagazine.com,1394470371550,US open: Stocks fall after Fed official hints ...,news us open stocks fall after fed official hi...,b
3,4,"Fed risks falling 'behind the curve', Charles ...",http://www.ifamagazine.com/news/fed-risks-fall...,IFA Magazine,b,ddUyU0VZz0BRneMioxUPQVP6sIxvM,www.ifamagazine.com,1394470371793,"Fed risks falling 'behind the curve', Charles ...",news fed risks falling behind the curve charle...,b
4,5,Fed's Plosser: Nasty Weather Has Curbed Job Gr...,http://www.moneynews.com/Economy/federal-reser...,Moneynews,b,ddUyU0VZz0BRneMioxUPQVP6sIxvM,www.moneynews.com,1394470372027,Fed's Plosser: Nasty Weather Has Curbed Job Gr...,Economy federal reserve charles plosser weathe...,b
...,...,...,...,...,...,...,...,...,...,...,...
995,996,Titanfall Xbox One resolution update isn't rea...,http://www.slashgear.com/titanfall-xbox-one-re...,SlashGear,t,dw0Jnfj8vEPaLZM8TQMSCbVaYeyrM,www.slashgear.com,1394493719210,Titanfall Xbox One resolution update isn't rea...,titanfall xbox one resolution update isnt read...,t
996,997,Next-Gen Starts Tomorrow With Titanfall Launch,http://www.newsledge.com/next-gen-starts-tomor...,The News Ledge,t,dw0Jnfj8vEPaLZM8TQMSCbVaYeyrM,www.newsledge.com,1394493719387,Next-Gen Starts Tomorrow With Titanfall Launch,next gen starts tomorrow titanfall launch 5431,t
997,998,Titanfall On Xbox One Likely To Get Improved R...,http://thekoalition.com/2014/03/titanfall-on-x...,The Koalition,t,dw0Jnfj8vEPaLZM8TQMSCbVaYeyrM,thekoalition.com,1394493719614,Titanfall On Xbox One Likely To Get Improved R...,2014 03 titanfall on xbox one likely to get im...,t
998,999,'Titanfall' Release Date Nears: Twitter Predic...,http://www.ibtimes.com/titanfall-release-date-...,International Business Times,t,dw0Jnfj8vEPaLZM8TQMSCbVaYeyrM,www.ibtimes.com,1394493719773,'Titanfall' Release Date Nears: Twitter Predic...,titanfall release date nears twitter predicts ...,t


In [61]:
# RecognaSumm dataset.
# Same shape as the HuffPost query, adapted to this file's column names:
#   * per-category share is computed on `Categoria`, and only categories above
#     2% of all rows are kept (`pct > 0.02`);
#   * `text` is Titulo + ' ' + Subtitulo;
#   * '.com.br' in URL is rewritten to '.com' before taking the second
#     '.com/'-delimited segment, then '/' and '-' become spaces to form `x`;
#   * rows whose URL does not start with 'http' are dropped;
#   * `y` is the category.
# Unlike the two cells above there is no LIMIT, so every matching row is loaded.

q = """
WITH data AS (
SELECT
    *
FROM
    '../../data/raw/recognasumm.parquet'
),
total AS (
SELECT COUNT(*) AS total FROM data
),
counts AS (
SELECT Categoria AS category, COUNT(*) AS articles FROM data GROUP BY 1
),
pcts AS (
SELECT category, articles / total.total AS pct FROM counts, total
)
SELECT d.*, d.Titulo || ' ' || d.Subtitulo AS text, REPLACE(REPLACE(SPLIT_PART(REPLACE(d.URL, '.com.br', '.com'), '.com/', 2), '/', ' '), '-', ' ') AS x, category AS y FROM data d JOIN pcts ON d.Categoria = pcts.category WHERE pct > 0.02 AND d.URL LIKE 'http%'
"""

rs = con.execute(q).fetch_df()

rs

Unnamed: 0,index,Titulo,Subtitulo,Noticia,Categoria,Autor,Data,URL,Autor_corrigido,Sumario,__index_level_0__,text,x,y
0,121981,Bolsonaro autoriza obra de duplicação da BR-4...,Presidente participou de evento na manhã desta...,O presidente Jair Bolsonaro (sem partido) lanç...,Política,Por G1 PR e RPC Foz do Iguaçu — Foz do Iguaçu,27/08/2020 09h43,https://g1.globo.com/pr/oeste-sudoeste/noticia...,Por G1 e RPC Foz do Iguaçu — Foz do Iguaçu,Bolsonaro autoriza obra de duplicação da BR-46...,27048,Bolsonaro autoriza obra de duplicação da BR-4...,pr oeste sudoeste noticia 2020 08 27 bolsonaro...,Política
1,46728,Levantamento mostra que ataques cibernéticos ...,País é o 2º na América Latina com mais ataques...,Ingrid Oliveirada CNN O Brasil registrou no pr...,Ciência e Tecnologia,CNN,19/08/2022 às 17:42 | Atualizado 22/08/2022 à...,https://www.cnnbrasil.com.br/tecnologia/levant...,CNN,Levantamento mostra que ataques cibernéticos n...,27052,Levantamento mostra que ataques cibernéticos ...,tecnologia levantamento mostra que ataques cib...,Ciência e Tecnologia
2,8053,Fiocruz: casos de síndrome respiratória aguda ...,Dados mostram desaceleração do crescimento ent...,"Dados do último Boletim InfoGripe da Fiocruz, ...",Saúde,Por O GLOBO — São Paulo,10/03/2023 15h16,https://oglobo.globo.com/saude/medicina/notici...,Por O GLOBO — São Paulo,Fiocruz: casos de síndrome respiratória aguda ...,27039,Fiocruz: casos de síndrome respiratória aguda ...,saude medicina noticia 2023 03 fiocruz casos d...,Saúde
3,26352,Cúpula da Otan redefine geopolítica global; U...,Ouça a análise dos principais fatos da semana ...,"Camila Olivoda CNN Neste episódio, Roberto No...",Brasil,CNN,14/07/2023 às 14:46 | Atualizado 14/07/2023 à...,https://www.cnnbrasil.com.br/nacional/cupula-d...,CNN,Cúpula da Otan redefine geopolítica global; Uc...,27050,Cúpula da Otan redefine geopolítica global; U...,nacional cupula da otan redefine geopolitica g...,Brasil
4,133437,"Brasil tem taxas de juros mais altas do mundo,...",As declarações foram dadas horas antes da reun...,O presidente Luiz Inácio Lula da Silva afirmou...,Política,UOL,02/08/2023 15h35,https://noticias.uol.com.br/ultimas-noticias/a...,UOL,"Brasil tem taxas de juros mais altas do mundo,...",27037,"Brasil tem taxas de juros mais altas do mundo,...",ultimas noticias ansa 2023 08 02 brasil tem ta...,Política
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
98018,80370,Aumenta o número de imigrantes brasileiros em ...,"Desemprego, situação irregular, falta de prepa...",O número de brasileiros em Portugal que pedira...,Internacional,"Por Luciana Quaresma, RFI",15/02/2023 04h00,https://g1.globo.com/mundo/noticia/2023/02/15/...,Luciana Quaresma,Aumenta o número de imigrantes brasileiros em ...,27,Aumenta o número de imigrantes brasileiros em ...,mundo noticia 2023 02 15 aumenta o numero de i...,Internacional
98019,37844,Tropas do Grupo Wagner com mais de 100 veícul...,Imagens de satélite verificadas mostram reboqu...,Paul P. Murphyda CNN O primeiro comboio das fo...,Internacional,CNN,18/07/2023 às 04:01 | Atualizado 18/07/2023 à...,https://www.cnnbrasil.com.br/internacional/tro...,CNN,Tropas do Grupo Wagner com mais de 100 veículo...,24,Tropas do Grupo Wagner com mais de 100 veícul...,internacional tropas do grupo wagner com mais ...,Internacional
98020,83970,Presidente da Comissão Europeia visita a Ucrân...,Ursula von der Leyen faz sua terceira visita a...,"A presidente da Comissão Europeia, Ursula von ...",Internacional,Por France Presse,15/09/2022 08h28,https://g1.globo.com/mundo/ucrania-russia/noti...,France Presse,Presidente da Comissão Europeia visita a Ucrân...,16,Presidente da Comissão Europeia visita a Ucrân...,mundo ucrania russia noticia 2022 09 15 presid...,Internacional
98021,44425,Polícia da Nicarágua proíbe procissão católic...,Relações entre entidade e governo têm sido ten...,Ismael Lopezda Reuters A polícia da Nicarágua ...,Internacional,CNN,12/08/2022 às 17:13,https://www.cnnbrasil.com.br/internacional/pol...,CNN,Polícia da Nicarágua proíbe procissão católica...,11,Polícia da Nicarágua proíbe procissão católic...,internacional policia da nicaragua proibe proc...,Internacional


In [86]:
def make_splits(df, seed=20240823):
    """Partition ``df`` into train, validation and test subsets.

    The data is first divided 80/20; the 20% hold-out is then cut in half to
    give the validation and test subsets. Both ``train_test_split`` calls use
    the same ``seed`` as their ``random_state``.

    Args:
        df: The dataset to split (anything ``train_test_split`` accepts).
        seed: Random state passed to both splitting steps.

    Returns:
        A ``(train, val, test)`` tuple.
    """
    # Step 1: carve off the 20% that will become validation + test.
    train_part, holdout = train_test_split(df, test_size=0.2, random_state=seed)
    # Step 2: halve the hold-out into validation and test.
    val_part, test_part = train_test_split(holdout, test_size=0.5, random_state=seed)
    return train_part, val_part, test_part

In [87]:
# Split the HuffPost and UCI frames into train (80%), validation (10%) and
# test (10%) sets. The recognasumm frame (`rs`) is not split in this cell.
train_huffpo, val_huffpo, test_huffpo = make_splits(huffpo)
train_uci, val_uci, test_uci = make_splits(uci)

# Model training

## BERT classifier

In [57]:
def train_bert_clf(X_train, y_train, X_val, y_val, model_name='distilbert-base-uncased', epochs=3, output_name="huffpo"):
    """Fine-tune a DistilBERT sequence classifier and save it to disk.

    Parameters
    ----------
    X_train, X_val : iterable of str
        Training / validation texts.
    y_train, y_val : iterable
        Class labels aligned with ``X_train`` / ``X_val``. Labels are encoded
        with a ``LabelEncoder`` fitted on ``y_train`` only, so every label in
        ``y_val`` must also occur in ``y_train``.
    model_name : str
        Checkpoint name passed to ``from_pretrained`` for both the tokenizer
        and the model.
    epochs : int
        Number of training epochs.
    output_name : str
        Sub-directory of ``models/bert/`` that receives the model, the
        tokenizer and ``label_encoder.pt``.

    Returns
    -------
    tuple
        ``(model, tokenizer, le)`` -- the fine-tuned model, its tokenizer and
        the fitted ``LabelEncoder``.

    Raises
    ------
    ValueError
        From ``LabelEncoder.transform`` if ``y_val`` contains a label that was
        not present in ``y_train``.
    """
    # Local import: the notebook's import cell does not bring this name in.
    from transformers import DataCollatorWithPadding

    # encode labels first so the class count comes from the fitted encoder
    # (works for any iterable, not only pandas Series)
    le = LabelEncoder()
    y_train = le.fit_transform(y_train)
    y_val = le.transform(y_val)

    # initialize resources
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=len(le.classes_))

    # tokenize text and create datasets
    def tokenize(batch):
        # Truncate only. Padding is deferred to the data collator so that every
        # training batch is padded to a single common length; padding inside
        # `map` pads each map-batch independently, which yields rows of
        # different lengths once the dataset spans more than one map-batch.
        return tokenizer(batch['text'], truncation=True)

    # Plain Python lists keep Dataset.from_dict independent of pandas index
    # order and numpy dtypes.
    train_dataset = Dataset.from_dict({"text": list(X_train), "label": y_train.tolist()})
    val_dataset = Dataset.from_dict({"text": list(X_val), "label": y_val.tolist()})

    train_dataset = train_dataset.map(tokenize, batched=True)
    val_dataset = val_dataset.map(tokenize, batched=True)

    # move model to the best available device: CUDA, then Apple MPS, then CPU
    # (previously 'mps' was selected whenever CUDA was missing, which fails on
    # machines that have neither accelerator)
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    model.to(device)

    # train model
    # NOTE(review): warmup_steps=500 exceeds the total number of optimizer
    # steps on small datasets (e.g. 800 examples x 3 epochs / batch size 16 =
    # 150 steps), in which case the learning rate never finishes warming up --
    # consider scaling it to the dataset size.
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=epochs,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        # dynamic per-batch padding (see `tokenize` above)
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    )

    trainer.train()

    # save model, tokenizer and label encoder
    model.save_pretrained(f'models/bert/{output_name}')
    tokenizer.save_pretrained(f'models/bert/{output_name}')
    # torch.save pickles the sklearn encoder; only load this file from a
    # trusted location, since unpickling can execute arbitrary code.
    torch.save(le, f'models/bert/{output_name}/label_encoder.pt')

    return model, tokenizer, le

In [88]:
# Fine-tune on the HuffPost split: the URL-derived `x` column is the input text
# and `y` (the category) is the label; the validation split is passed as eval data.
model, tokenizer, label_encoder = train_bert_clf(train_huffpo.x, train_huffpo.y, val_huffpo.x, val_huffpo.y)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|██████████| 800/800 [00:00<00:00, 33890.28 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 22162.77 examples/s]
  0%|          | 0/150 [1:13:59<?, ?it/s]
100%|██████████| 150/150 [00:29<00:00,  5.17it/s]

{'train_runtime': 28.999, 'train_samples_per_second': 82.761, 'train_steps_per_second': 5.173, 'train_loss': 1.8362221272786459, 'epoch': 3.0}





## Distant labeling

In [93]:
def train_distant_labeler(train, mapping_file="../../data/uci_categories_attributed.csv", output_name="huffpo"):
    """Train a bag-of-words Naive Bayes labeler and save it to disk.

    Parameters
    ----------
    train : pandas.DataFrame
        Training rows. With a ``mapping_file`` it must expose ``text`` and
        ``URL`` columns; with ``mapping_file=None`` it must expose ``text``
        and ``y``.
    mapping_file : str or None
        Path to a file with ``slug`` and ``category`` columns. When given,
        labels come from joining the first path segment of each row's URL
        against ``slug``. When ``None``, the ``y`` column of ``train`` is used
        as the label directly.
    output_name : str
        Sub-directory of ``models/distant/`` that receives ``model.pt`` and
        ``vectorizer.pt``; it is created if missing.

    Returns
    -------
    tuple
        ``(clf, vectorizer)`` -- the fitted ``MultinomialNB`` and the fitted
        ``CountVectorizer``.

    Raises
    ------
    ValueError
        If no row with both a text and a label remains to train on.
    """
    # Local imports: neither name is part of the notebook's import cell.
    import os
    import pandas as pd

    if mapping_file is not None:
        # The SQL below refers to `train` by name; DuckDB presumably resolves
        # it to the Python argument of the same name, so the parameter must
        # keep that name and the query must run inside this function.
        data = con.execute(f"""
        WITH mapping AS (
            SELECT * FROM '{mapping_file}'
        )
        SELECT x.text AS x, m.category AS y FROM train x JOIN mapping m ON SPLIT_PART(SPLIT_PART(x.URL, '://', 2), '/', 2) = m.slug
        """).fetch_df()
    else:
        data = pd.DataFrame({"x": train["text"], "y": train["y"]})

    # Drop rows lacking either a text or a label in both branches:
    # CountVectorizer cannot process missing documents and a missing label
    # carries no training signal.
    data = data[data["x"].notna() & data["y"].notna()]
    if data.empty:
        raise ValueError("train_distant_labeler: no rows with both text and label to train on")

    X_train = data['x']
    y_train = data['y']

    # count vectorize X's
    vectorizer = CountVectorizer()
    X_train_counts = vectorizer.fit_transform(X_train)

    # train model
    clf = MultinomialNB()
    clf.fit(X_train_counts, y_train)

    # save model and vectorizer; torch.save does not create missing parent
    # directories, so make sure the target exists first
    out_dir = f'models/distant/{output_name}'
    os.makedirs(out_dir, exist_ok=True)
    torch.save(clf, f'{out_dir}/model.pt')
    torch.save(vectorizer, f'{out_dir}/vectorizer.pt')

    return clf, vectorizer

    

In [96]:
# mapping_file=None skips the slug-to-category join: the labeler is trained
# directly on the `text` and `y` columns of the HuffPost training split.
train_distant_labeler(train_huffpo, mapping_file=None, output_name="huffpo")

(MultinomialNB(), CountVectorizer())