In [2]:
!pip install transformers
!pip install datasets
!pip install pymorphy2

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.22.2-py3-none-any.whl (4.9 MB)
[K     |████████████████████████████████| 4.9 MB 4.4 MB/s 
[?25hCollecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 32.3 MB/s 
Collecting huggingface-hub<1.0,>=0.9.0
  Downloading huggingface_hub-0.10.0-py3-none-any.whl (163 kB)
[K     |████████████████████████████████| 163 kB 42.2 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.10.0 tokenizers-0.12.1 transformers-4.22.2
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.5.1-py3-none-any.whl (431 kB)
[K     |████████████████████████████████| 431 kB 5.

In [16]:
import transformers
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, DistilBertTokenizer
import pandas as pd
from torch.utils.data import Dataset
from transformers import TrainingArguments, Trainer
import pandas as pd
import numpy as np
from gensim.models import Word2Vec 
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, classification_report, confusion_matrix
from sklearn.metrics import precision_recall_fscore_support
import torch


import warnings
warnings.filterwarnings("ignore")

# One seed shared by the shuffle and by every split below, so the
# train/test/val/inference partitions (and the metrics computed on them)
# are identical from run to run.
SEED = 42

data = pd.read_csv('/content/news.csv')
# Deterministically shuffle all rows.
data = data.sample(frac=1.0, random_state=SEED)
#data = data.iloc[:500]


# Build label <-> integer-id lookups from the distinct values of the `target` column
# and add the integer-encoded `labels` column used for training.
labels = data.target.unique()
NUM_LABELS = len(labels)
id2label = {i: l for i, l in enumerate(labels)}
label2id = {l: i for i, l in enumerate(labels)}
data["labels"] = data.target.map(lambda x: label2id[x])

# `model_max_length` is the tokenizer-level limit that `truncation=True` truncates to;
# `max_length` is a per-call encoding argument, not a `from_pretrained` option.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased", model_max_length=512)
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=NUM_LABELS,
    id2label=id2label,
    label2id=label2id,
)
#model.to('cpu')

SIZE = data.shape[0]

# Successive splits: 70% train, then the remaining 30% is halved into test (15%) and a
# pool that is halved again into validation (7.5%) and inference (7.5%).
df_train, df_test, y_train, y_test = train_test_split(
    data.text, data.labels, test_size=0.3, random_state=SEED
)
df_test, df_val, y_test, y_val = train_test_split(
    df_test, y_test, test_size=0.5, random_state=SEED
)
df_val, df_inference, y_val, y_inference = train_test_split(
    df_val, y_val, test_size=0.5, random_state=SEED
)

train_texts = list(df_train)
val_texts = list(df_val)
test_texts = list(df_test)
inference_texts = list(df_inference)

train_labels = list(y_train)
val_labels = list(y_val)
test_labels = list(y_test)
inference_labels = list(y_inference)

# Tokenize each split as a whole; `padding=True` pads every text to the longest
# sequence within that split.
train_encodings = tokenizer(train_texts, truncation=True, padding=True, return_tensors="pt")
val_encodings = tokenizer(val_texts, truncation=True, padding=True, return_tensors="pt")
test_encodings = tokenizer(test_texts, truncation=True, padding=True, return_tensors="pt")
inference_encodings = tokenizer(inference_texts, truncation=True, padding=True, return_tensors="pt")

class MyDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with class labels.

    Args:
        encodings: Mapping from field name (e.g. ``input_ids``) to an indexable
            container holding one entry per example; entries may be tensors or
            plain Python sequences.
        labels: Indexable sequence with one label per example.

    Indexing returns a dict with one tensor per encoding field plus a
    ``labels`` tensor for the requested example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    @staticmethod
    def _as_tensor(value):
        """Return an independent tensor copy of ``value``.

        ``torch.tensor(existing_tensor)`` is the discouraged copy-construct
        path and emits a ``UserWarning`` on every call, so tensors are copied
        with ``clone().detach()`` instead; anything else is converted with
        ``torch.tensor``.
        """
        if isinstance(value, torch.Tensor):
            return value.clone().detach()
        return torch.tensor(value)

    def __getitem__(self, idx):
        item = {key: self._as_tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = self._as_tensor(self.labels[idx])
        return item

    def __len__(self):
        # The number of examples is defined by the label sequence.
        return len(self.labels)

def compute_metrics(pred):
    """Score a prediction object with accuracy and macro-averaged P/R/F1.

    Args:
        pred: Object exposing ``label_ids`` (reference labels) and
            ``predictions`` (per-class scores; the arg-max over the last axis
            is taken as the predicted class).

    Returns:
        Dict with the keys ``Accuracy``, ``F1``, ``Precision`` and ``Recall``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)

    # The fourth element (support) is not reported.
    macro_scores = precision_recall_fscore_support(y_true, y_pred, average='macro')
    macro_precision, macro_recall, macro_f1 = macro_scores[:3]

    metrics = {}
    metrics['Accuracy'] = accuracy_score(y_true, y_pred)
    metrics['F1'] = macro_f1
    metrics['Precision'] = macro_precision
    metrics['Recall'] = macro_recall
    return metrics

# Wrap every split's encodings and labels in a dataset object.
train_dataset, val_dataset, test_dataset, inference_dataset = (
    MyDataset(split_encodings, split_labels)
    for split_encodings, split_labels in (
        (train_encodings, train_labels),
        (val_encodings, val_labels),
        (test_encodings, test_labels),
        (inference_encodings, inference_labels),
    )
)

# Fine-tuning hyper-parameters; everything not listed keeps the library default.
training_args = TrainingArguments(
    output_dir='/content/outputs/',    # where predictions and checkpoints are written
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    # Left disabled: warmup_steps=10, no_cuda=True, fp16=True
)

# The pre-trained classifier is fine-tuned on the training split; the validation split
# is registered as the trainer's default evaluation set.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

result = trainer.train()

# Evaluate the fine-tuned model on each split and show the first five metric columns
# as one table with a row per split.
split_datasets = {
    "train": train_dataset,
    "val": val_dataset,
    "test": test_dataset,
    "inference": inference_dataset,
}
q = [trainer.evaluate(eval_dataset=split_dataset) for split_dataset in split_datasets.values()]
print(pd.DataFrame(q, index=list(split_datasets)).iloc[:, :5])


# Disabled manual spot-check on six hand-written news snippets. Everything below is
# wrapped in a string literal, so none of it executes; because the string is the last
# expression of the cell, it is echoed verbatim as the cell's output.
# NOTE(review): if this is re-enabled, it rebinds `inference_texts` (already assigned
# earlier in this cell) and calls `model(**inputs)` without moving `inputs` to the
# model's device or disabling gradient tracking — confirm both before relying on it.
"""
def predict(text):
    inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors="pt")
    outputs = model(**inputs)
    probs = outputs[0].softmax(1)
    return probs


inference_texts = [
('Inflation in Turkey has climbed above 83% - a 24-year-high. The transport, food and housing sectors have seen the biggest rise in prices. Independent experts the Inflation Research Group estimate the annual rate is actually 186.27%.',label2id['business']), #business
('The pound has climbed after the chancellor reversed his controversial decision to scrap the top rate of tax.Sterling gained more than 1% to $1.1284 before falling back slightly while government borrowing costs edged lower.Tory MPs had threatened to vote against Kwasi Kwarteng\'s plan, saying it was unfair when living costs were so high.',label2id['business']),    #business
('Australia coach Mal Meninga has named 13 uncapped players in his squad as they chase a third men\'s Rugby League World Cup title in a row.The Kangaroos beat England in the 2017 final but have only played four Tests since with their last match a shock loss to Tonga three years ago.Sydney Roosters full-back James Tedesco, who represented Italy at the last two World Cups, captains the side.',label2id['sport']),  #sport
('England will play India in the group stage of the 2023 Women\'s T20 World Cup in South Africa.The two sides have been placed into Group B alongside West Indies, Pakistan and Ireland.Defending champions Australia, New Zealand, hosts South Africa, Sri Lanka and Bangladesh make up Group A with the top two in each group progressing to the semi-finals.The tournament takes place between 10 and 26 February.',label2id['sport']), #sport
('During World War II, Spitfire pilots described their plane as so responsive it felt like an extension of their limbs.Fighter pilots of the 2030s, however, will have an even closer relationship with their fighter jet.It will read their minds.',label2id['tech']), #tech
('In deep, astonishingly clear, blue-lit ponds some 40m (130ft) beneath the Swedish countryside, lies decades worth of high-level nuclear waste.It is an oddly beautiful and rather disturbing sight. Row upon row of long metal containers, filled with used nuclear fuel from the country\'s reactors, lie below the surface near Oskarshamn, on Sweden\'s Baltic coast.It is both highly lethal and entirely safe.',label2id['tech']) #tech
]

df = pd.DataFrame(inference_texts, columns=['texts', 'labels'])

inference_acc = 0
n_iter = 0
for text, label in zip(df.texts, df.labels):
    probs = predict(text)
    #print(probs.argmax(), label, n_iter)
    inference_acc += (probs.argmax() == label).sum().item()
    n_iter += 1

print('Inference_acc: {}'.format(inference_acc / n_iter ))
"""

loading file vocab.txt from cache at /root/.cache/huggingface/hub/models--distilbert-base-uncased/snapshots/043235d6088ecd3dd5fb5ca3592b6913fd516027/vocab.txt
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at /root/.cache/huggingface/hub/models--distilbert-base-uncased/snapshots/043235d6088ecd3dd5fb5ca3592b6913fd516027/tokenizer_config.json
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--distilbert-base-uncased/snapshots/043235d6088ecd3dd5fb5ca3592b6913fd516027/config.json
Model config DistilBertConfig {
  "_name_or_path": "distilbert-base-uncased",
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_tok

Step,Training Loss




Training completed. Do not forget to share your model on huggingface.co/models =)


***** Running Evaluation *****
  Num examples = 350
  Batch size = 4


***** Running Evaluation *****
  Num examples = 37
  Batch size = 4
***** Running Evaluation *****
  Num examples = 75
  Batch size = 4
***** Running Evaluation *****
  Num examples = 38
  Batch size = 4


           eval_loss  eval_Accuracy   eval_F1  eval_Precision  eval_Recall
train       0.646716       0.974286  0.971938        0.973671     0.970857
val         0.683049       0.945946  0.944279        0.942857     0.949206
test        0.687655       0.960000  0.951905        0.968421     0.945455
inference   0.663768       0.947368  0.931696        0.940000     0.935000


'\ndef predict(text):\n    inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors="pt")\n    outputs = model(**inputs)\n    probs = outputs[0].softmax(1)\n    return probs\n\n\ninference_texts = [\n(\'Inflation in Turkey has climbed above 83% - a 24-year-high. The transport, food and housing sectors have seen the biggest rise in prices. Independent experts the Inflation Research Group estimate the annual rate is actually 186.27%.\',label2id[\'business\']), #business\n(\'The pound has climbed after the chancellor reversed his controversial decision to scrap the top rate of tax.Sterling gained more than 1% to $1.1284 before falling back slightly while government borrowing costs edged lower.Tory MPs had threatened to vote against Kwasi Kwarteng\'s plan, saying it was unfair when living costs were so high.\',label2id[\'business\']),    #business\n(\'Australia coach Mal Meninga has named 13 uncapped players in his squad as they chase a third men\'s Rugby Leag