In [1]:
import numpy as np
import pandas as pd

import os

import transformers
from imblearn.under_sampling import RandomUnderSampler
from sklearn.preprocessing import LabelEncoder

from transformers import AutoTokenizer, BertModel, DistilBertModel
from transformers import AutoModel, BertForSequenceClassification, BertTokenizer

from datasets import Dataset, ClassLabel

import torch
from torch.utils.data import Dataset as TorchDataset, DataLoader
from torch.nn import TripletMarginLoss
from torch.optim import Adam
from tqdm import tqdm

import neptune.new as neptune

In [2]:
### PARAMS
MAX_SAMPLES = 10000
BATCH_SIZE = 4
LR = 1e-3
EPOCHS = 20

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
#device = torch.device('cpu')

In [3]:
dataset_df = pd.read_csv('dataset/tweet_dataset.csv')
dataset_df.dropna(inplace=True)
dataset_df

Unnamed: 0,short_description,category
0,Health experts said it is too early to predict...,U.S. NEWS
1,He was subdued by passengers and crew when he ...,U.S. NEWS
2,"""Until you have a dog you don't understand wha...",COMEDY
3,"""Accidentally put grown-up toothpaste on my to...",PARENTING
4,Amy Cooper accused investment firm Franklin Te...,U.S. NEWS
...,...,...
209522,Verizon Wireless and AT&T are already promotin...,TECH
209523,"Afterward, Azarenka, more effusive with the pr...",SPORTS
209524,"Leading up to Super Bowl XLVI, the most talked...",SPORTS
209525,CORRECTION: An earlier version of this story i...,SPORTS


In [4]:
X, y = dataset_df[['short_description']], dataset_df[['category']]

### Undersampling

In [5]:
undersampler = RandomUnderSampler(random_state=42)
X_res, y_res = undersampler.fit_resample(X, y)

### One-hot encoding

In [6]:
oh_encoder = LabelEncoder()
y_enc = oh_encoder.fit_transform(y_res)

  y = column_or_1d(y, warn=True)


### Dataset creation

In [7]:
data_df = {"text": X_res["short_description"], "labels": y_enc.tolist()}
data_df = Dataset.from_dict(data_df).shuffle()
data_df

Dataset({
    features: ['text', 'labels'],
    num_rows: 36246
})

In [8]:
"""
data_df.features['labels'] = ClassLabel(num_classes=42)
split_dataset = data_df.train_test_split(test_size=0.1, stratify_by_column="labels")
data_df = split_dataset['test']
"""

### Tokenization

In [9]:
tokenizer = BertTokenizer.from_pretrained("distilbert-base-cased")
embeding_model = BertModel.from_pretrained("distilbert-base-cased")

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DistilBertTokenizer'. 
The class this function is called from is 'BertTokenizer'.
You are using a model of type distilbert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Some weights of the model checkpoint at distilbert-base-cased were not used when initializing BertModel: ['distilbert.transformer.layer.0.output_layer_norm.bias', 'distilbert.transformer.layer.3.output_layer_norm.weight', 'distilbert.transformer.layer.2.attention.k_lin.weight', 'distilbert.embeddings.LayerNorm.bias', 'distilbert.transformer.layer.4.sa_layer_norm.bias', 'distilbert.transformer.layer.5.sa_layer_norm.bias', 'distilbert.transformer.layer.3.sa_layer_norm.weight', 'distilbert.transformer.layer.0.sa_layer_norm.bias', 'distilbert.transfo

In [10]:
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
dataset_features = data_df.features.copy()

def tokenize_function(examples):
    with torch.no_grad():
        tokens = tokenizer(examples["text"], return_tensors="pt")
        embedings = embeding_model(**tokens).pooler_output
        embedings = embedings.squeeze()

    return {"embedings": embedings}

dataset = data_df.map(tokenize_function)
dataset.features['labels'] = ClassLabel(num_classes=42)

  0%|          | 0/3625 [00:00<?, ?ex/s]

In [11]:
dataset = dataset.remove_columns(["text"])
dataset.set_format("torch")
dataset

Dataset({
    features: ['labels', 'embedings'],
    num_rows: 3625
})

In [21]:
dataset_pd = pd.DataFrame(dataset['embedings'])
dataset_pd["labels"] = pd.DataFrame(dataset['labels'])
dataset_pd.to_csv('dataset/embeddings_smol.csv', index=False)

In [22]:
dataset_pd

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,759,760,761,762,763,764,765,766,767,labels
0,0.286959,-0.072739,0.753478,0.029065,-0.044117,0.796469,-0.080700,-0.116337,0.106141,0.664438,...,-0.081861,-0.296287,0.481009,-0.632536,-0.181246,-0.031159,0.367673,-0.146828,-0.786743,9
1,0.357129,-0.174773,0.603911,0.026667,0.094043,0.766938,-0.142196,-0.122094,0.022330,0.673558,...,-0.077524,-0.412776,0.496480,-0.645392,-0.021252,0.059162,0.468731,0.179680,-0.866986,2
2,0.395993,-0.010106,0.656082,0.130611,-0.051991,0.719536,-0.271542,0.046218,0.007320,0.625045,...,-0.233596,-0.333096,0.483195,-0.732185,-0.191192,0.217195,0.477800,-0.053176,-0.807459,13
3,0.334255,-0.037806,0.703386,0.064638,0.027087,0.798698,-0.184298,-0.088888,0.138465,0.644953,...,-0.089683,-0.295783,0.493706,-0.596528,-0.298668,0.063497,0.299139,-0.043658,-0.803471,33
4,0.320405,0.077909,0.684145,0.079528,0.158576,0.772303,-0.010687,0.017744,0.207142,0.698566,...,-0.090975,-0.369075,0.639004,-0.637897,-0.158942,-0.108213,0.344758,-0.076982,-0.767144,12
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3620,0.421578,-0.019629,0.656779,0.174148,-0.021393,0.840218,-0.100070,-0.011857,0.167257,0.692656,...,-0.110019,-0.271366,0.487649,-0.640063,-0.209200,0.021151,0.340553,-0.006638,-0.781998,26
3621,0.368945,-0.013736,0.691660,0.126277,0.095280,0.758165,-0.146781,-0.064218,0.114663,0.645248,...,-0.107440,-0.471480,0.559714,-0.732553,-0.126775,-0.059502,0.456765,-0.030017,-0.840898,11
3622,0.265164,-0.116977,0.706452,0.182964,0.142396,0.793411,0.043281,-0.093050,0.318244,0.642137,...,-0.205144,-0.313186,0.530374,-0.650075,-0.190040,0.123436,0.518753,-0.017935,-0.817771,19
3623,0.329796,-0.057874,0.734024,0.115979,0.116618,0.739115,-0.073479,-0.093619,0.288657,0.674363,...,-0.155023,-0.314374,0.598662,-0.596325,-0.178441,-0.073083,0.377727,0.009060,-0.792576,0


In [13]:
dataset_pd

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,759,760,761,762,763,764,765,766,767,labels
0,0.286959,-0.072739,0.753478,0.029065,-0.044117,0.796469,-0.080700,-0.116337,0.106141,0.664438,...,-0.081861,-0.296287,0.481009,-0.632536,-0.181246,-0.031159,0.367673,-0.146828,-0.786743,
1,0.357129,-0.174773,0.603911,0.026667,0.094043,0.766938,-0.142196,-0.122094,0.022330,0.673558,...,-0.077524,-0.412776,0.496480,-0.645392,-0.021252,0.059162,0.468731,0.179680,-0.866986,
2,0.395993,-0.010106,0.656082,0.130611,-0.051991,0.719536,-0.271542,0.046218,0.007320,0.625045,...,-0.233596,-0.333096,0.483195,-0.732185,-0.191192,0.217195,0.477800,-0.053176,-0.807459,
3,0.334255,-0.037806,0.703386,0.064638,0.027087,0.798698,-0.184298,-0.088888,0.138465,0.644953,...,-0.089683,-0.295783,0.493706,-0.596528,-0.298668,0.063497,0.299139,-0.043658,-0.803471,
4,0.320405,0.077909,0.684145,0.079528,0.158576,0.772303,-0.010687,0.017744,0.207142,0.698566,...,-0.090975,-0.369075,0.639004,-0.637897,-0.158942,-0.108213,0.344758,-0.076982,-0.767144,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3620,,,,,,,,,,,...,,,,,,,,,,26.0
3621,,,,,,,,,,,...,,,,,,,,,,11.0
3622,,,,,,,,,,,...,,,,,,,,,,19.0
3623,,,,,,,,,,,...,,,,,,,,,,0.0


In [None]:
dataset['embedings'].shape

In [None]:
dataset = dataset.shuffle().select(range(5000))

In [None]:
split_dataset = dataset.train_test_split(test_size=0.1, stratify_by_column="labels")

### Dataset definition

In [None]:
class TweetDataset(TorchDataset):
    def __init__(self, dataset: Dataset):
        self.input_ids = dataset['input_ids']
        self.attention_mask = dataset['attention_mask']
        self.dataset = dataset.remove_columns("labels")
        self.labels = dataset['labels']

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, item):
        #anchor = self.input_ids[item]
        anchor = self.dataset[item]
        anchor_class = self.labels[item]
        #anchor_attention = self.attention_mask[item]


        positive_indices = self.labels == anchor_class
        positive_indices = positive_indices.nonzero()
        positive_idx = positive_indices[torch.randint(high=len(positive_indices), size=(1, ))[0]]
        #positive_example = self.input_ids[positive_idx].flatten()
        #positive_attention = self.attention_mask[positive_idx]
        positive_example = self.dataset[positive_idx]

        negative_indices = self.labels != anchor_class
        negative_indices = negative_indices.nonzero()
        negative_idx = negative_indices[torch.randint(high=len(negative_indices), size=(1, ))[0]]
        #negative_example = self.input_ids[negative_idx].flatten()
        #negative_attention = self.attention_mask[negative_idx]
        negative_example = self.dataset[negative_idx]

        return anchor, positive_example, negative_example

In [None]:
train_ds = TweetDataset(split_dataset['train'])
test_ds = TweetDataset(split_dataset['test'])

train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=True)

### Model Training

In [None]:
class TweetBERTTail(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pooler = torch.nn.Sequential(
            torch.nn.Linear(768, 768, bias=True),
            torch.nn.ReLU(),
            torch.nn.Linear(768, 768, bias=True)
        )
        self.tahn = torch.nn.Tanh()

    def forward(self, x):
        x = self.pooler(x)
        return self.tahn(x)

In [None]:
model = TweetBERT()

In [None]:
model = model.to(device=device)

In [None]:
optimizer = Adam(model.parameters(), lr=LR)
loss = TripletMarginLoss()

In [None]:
run = neptune.init(
    project="konradszewczyk/TweetBuble",
    api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiI0MWIyOTA1ZS03ODc3LTQ5MzQtYjk0OS05ZjNjYzdiMDFjMDcifQ==",
)

os.mkdir(os.path.join('models', run['sys/id'].fetch()))

for epoch in range(EPOCHS):
    model.train()
    train_loss_log = []
    for batch_idx, (anchor, positive_ex, negative_ex) in enumerate(tqdm(train_dl)):
        #anchor = anchor.to(device=device)
        anchor = {k: v.to(device) for k, v in anchor.items()}
        archor_output = model(anchor)

        #positive_ex = positive_ex.to(device=device)
        positive_ex = {k: v[0].to(device) for k, v in positive_ex.items()}
        positive_ex_output = model(positive_ex)

        #negative_ex = negative_ex.to(device=device)
        negative_ex = {k: v[0].to(device) for k, v in negative_ex.items()}
        negative_ex_output = model(negative_ex)

        optimizer.zero_grad()
        train_loss = loss(archor_output, positive_ex_output, negative_ex_output)
        train_loss.backward()

        optimizer.step()

        train_loss_log.append(train_loss.detach().cpu())

    train_loss = np.mean(train_loss_log)
    run['train_loss'].log(train_loss)
    print("Epoch {:02d} train: {:.5f}".format(epoch, train_loss))

    file_name = 'epoch-{:02d}.pt'.format(epoch)
    PATH = os.path.join('models', run['sys/id'].fetch(), file_name)
    torch.save(model.state_dict(), PATH)

    model.eval()
    test_loss_log = []
    with torch.no_grad():
        for batch_idx, (anchor, positive_ex, negative_ex) in enumerate(tqdm(test_dl)):
            #anchor = anchor.to(device=device)
            anchor = {k: v.to(device) for k, v in anchor.items()}
            archor_output = model(anchor)

            #positive_ex = positive_ex.to(device=device)
            positive_ex = {k: v[0].to(device) for k, v in positive_ex.items()}
            positive_ex_output = model(positive_ex)

            #negative_ex = negative_ex.to(device=device)
            negative_ex = {k: v[0].to(device) for k, v in negative_ex.items()}
            negative_ex_output = model(negative_ex)

            test_loss = loss(archor_output, positive_ex_output, negative_ex_output)

            test_loss_log.append(test_loss.cpu())

    test_loss = np.mean(test_loss_log)
    run['test_loss'].log(test_loss)
    print("Epoch {:02d} val: {:.5f}".format(epoch, test_loss))

run.stop()

In [None]:
run.stop()

In [None]:
run.stop()