In [2]:
import torch
import evaluate
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from transformers import AutoModel
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, accuracy_score


In [3]:
GOOGLE_BERT = "bert-base-multilingual-uncased"

tokenizer = AutoTokenizer.from_pretrained(GOOGLE_BERT)
model = AutoModel.from_pretrained(GOOGLE_BERT)

dataset = load_dataset("AmazonScience/massive", 'ru-RU')

Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [26]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'locale', 'partition', 'scenario', 'intent', 'utt', 'annot_utt', 'worker_id', 'slot_method', 'judgments'],
        num_rows: 11514
    })
    validation: Dataset({
        features: ['id', 'locale', 'partition', 'scenario', 'intent', 'utt', 'annot_utt', 'worker_id', 'slot_method', 'judgments'],
        num_rows: 2033
    })
    test: Dataset({
        features: ['id', 'locale', 'partition', 'scenario', 'intent', 'utt', 'annot_utt', 'worker_id', 'slot_method', 'judgments'],
        num_rows: 2974
    })
})

In [4]:
def get_intents(intents):
    intent_to_index = {}
    index_to_intent = {}
    for i, value in enumerate(intents):
        index_to_intent[i] = value
        intent_to_index[value] = i
    return index_to_intent, intent_to_index

In [5]:
index_to_intent, intent_to_index = get_intents(dataset['train'].features["intent"].names)

In [6]:
index_to_intent

{0: 'datetime_query',
 1: 'iot_hue_lightchange',
 2: 'transport_ticket',
 3: 'takeaway_query',
 4: 'qa_stock',
 5: 'general_greet',
 6: 'recommendation_events',
 7: 'music_dislikeness',
 8: 'iot_wemo_off',
 9: 'cooking_recipe',
 10: 'qa_currency',
 11: 'transport_traffic',
 12: 'general_quirky',
 13: 'weather_query',
 14: 'audio_volume_up',
 15: 'email_addcontact',
 16: 'takeaway_order',
 17: 'email_querycontact',
 18: 'iot_hue_lightup',
 19: 'recommendation_locations',
 20: 'play_audiobook',
 21: 'lists_createoradd',
 22: 'news_query',
 23: 'alarm_query',
 24: 'iot_wemo_on',
 25: 'general_joke',
 26: 'qa_definition',
 27: 'social_query',
 28: 'music_settings',
 29: 'audio_volume_other',
 30: 'calendar_remove',
 31: 'iot_hue_lightdim',
 32: 'calendar_query',
 33: 'email_sendemail',
 34: 'iot_cleaning',
 35: 'audio_volume_down',
 36: 'play_radio',
 37: 'cooking_query',
 38: 'datetime_convert',
 39: 'qa_maths',
 40: 'iot_hue_lightoff',
 41: 'iot_hue_lighton',
 42: 'transport_query',
 43:

In [7]:
intent_to_index

{'datetime_query': 0,
 'iot_hue_lightchange': 1,
 'transport_ticket': 2,
 'takeaway_query': 3,
 'qa_stock': 4,
 'general_greet': 5,
 'recommendation_events': 6,
 'music_dislikeness': 7,
 'iot_wemo_off': 8,
 'cooking_recipe': 9,
 'qa_currency': 10,
 'transport_traffic': 11,
 'general_quirky': 12,
 'weather_query': 13,
 'audio_volume_up': 14,
 'email_addcontact': 15,
 'takeaway_order': 16,
 'email_querycontact': 17,
 'iot_hue_lightup': 18,
 'recommendation_locations': 19,
 'play_audiobook': 20,
 'lists_createoradd': 21,
 'news_query': 22,
 'alarm_query': 23,
 'iot_wemo_on': 24,
 'general_joke': 25,
 'qa_definition': 26,
 'social_query': 27,
 'music_settings': 28,
 'audio_volume_other': 29,
 'calendar_remove': 30,
 'iot_hue_lightdim': 31,
 'calendar_query': 32,
 'email_sendemail': 33,
 'iot_cleaning': 34,
 'audio_volume_down': 35,
 'play_radio': 36,
 'cooking_query': 37,
 'datetime_convert': 38,
 'qa_maths': 39,
 'iot_hue_lightoff': 40,
 'iot_hue_lighton': 41,
 'transport_query': 42,
 'mu

In [9]:
def get_slots(dataset):
    index_to_slot = {} 
    slot_to_index = {}
    
    processed_slots = set()

    index_to_slot[0] = "O"
    index_to_slot[-100] = "<PAD>"
    slot_to_index["O"] = 0
    slot_to_index["<PAD>"] = -100

    counter = 1
    for key in ["train", "validation", "test"]:
        for slots in dataset[key]["slot_method"]:
            for slot in slots["slot"]:
                if slot not in processed_slots:  
                    
                    index_to_slot[counter] = f"B-{slot}"
                    slot_to_index[f"B-{slot}"] = counter
                    counter += 1
                    
                    index_to_slot[counter] = f"I-{slot}"
                    slot_to_index[f"I-{slot}"] = counter
                    counter += 1
                    processed_slots.add(slot)  

    return index_to_slot, slot_to_index 


In [10]:
index_to_slot, slot_to_index = get_slots(dataset)

In [11]:
index_to_slot

{0: 'O',
 -100: '<PAD>',
 1: 'B-time',
 2: 'I-time',
 3: 'B-date',
 4: 'I-date',
 5: 'B-color_type',
 6: 'I-color_type',
 7: 'B-house_place',
 8: 'I-house_place',
 9: 'B-change_amount',
 10: 'I-change_amount',
 11: 'B-artist_name',
 12: 'I-artist_name',
 13: 'B-media_type',
 14: 'I-media_type',
 15: 'B-place_name',
 16: 'I-place_name',
 17: 'B-time_zone',
 18: 'I-time_zone',
 19: 'B-order_type',
 20: 'I-order_type',
 21: 'B-food_type',
 22: 'I-food_type',
 23: 'B-news_topic',
 24: 'I-news_topic',
 25: 'B-song_name',
 26: 'I-song_name',
 27: 'B-music_genre',
 28: 'I-music_genre',
 29: 'B-device_type',
 30: 'I-device_type',
 31: 'B-meal_type',
 32: 'I-meal_type',
 33: 'B-business_name',
 34: 'I-business_name',
 35: 'B-general_frequency',
 36: 'I-general_frequency',
 37: 'B-weather_descriptor',
 38: 'I-weather_descriptor',
 39: 'B-player_setting',
 40: 'I-player_setting',
 41: 'B-joke_type',
 42: 'I-joke_type',
 43: 'B-timeofday',
 44: 'I-timeofday',
 45: 'B-event_name',
 46: 'I-event_nam

In [12]:
slot_to_index

{'O': 0,
 '<PAD>': -100,
 'B-time': 1,
 'I-time': 2,
 'B-date': 3,
 'I-date': 4,
 'B-color_type': 5,
 'I-color_type': 6,
 'B-house_place': 7,
 'I-house_place': 8,
 'B-change_amount': 9,
 'I-change_amount': 10,
 'B-artist_name': 11,
 'I-artist_name': 12,
 'B-media_type': 13,
 'I-media_type': 14,
 'B-place_name': 15,
 'I-place_name': 16,
 'B-time_zone': 17,
 'I-time_zone': 18,
 'B-order_type': 19,
 'I-order_type': 20,
 'B-food_type': 21,
 'I-food_type': 22,
 'B-news_topic': 23,
 'I-news_topic': 24,
 'B-song_name': 25,
 'I-song_name': 26,
 'B-music_genre': 27,
 'I-music_genre': 28,
 'B-device_type': 29,
 'I-device_type': 30,
 'B-meal_type': 31,
 'I-meal_type': 32,
 'B-business_name': 33,
 'I-business_name': 34,
 'B-general_frequency': 35,
 'I-general_frequency': 36,
 'B-weather_descriptor': 37,
 'I-weather_descriptor': 38,
 'B-player_setting': 39,
 'I-player_setting': 40,
 'B-joke_type': 41,
 'I-joke_type': 42,
 'B-timeofday': 43,
 'I-timeofday': 44,
 'B-event_name': 45,
 'I-event_name': 

In [13]:
def tokenize_and_align_labels(tokenizer, slots):
    slot_labels = []
    
    L_Paren = "["
    R_Paren = "]"
    
    is_Slot = False
    count = 0
    slot = None
    
    arr = slots.split()
    for value in arr:
        if L_Paren in value:
            slot = value[1:]
            
        elif ':' == value:
            is_Slot = True
            
        else:
            
            if is_Slot:
                
                if R_Paren in value:
                    value = value[:len(value)-1]
                    is_Slot  = False
                
                tmp = tokenizer.tokenize(value)
                for i in range(len(tmp)):
                    if count == 0:
                        slot_labels.append(f"B-{slot}")
                        count = 1
                        
                    else:
                        slot_labels.append(f"I-{slot}")
                        
                        if not is_Slot and i == len(tmp)-1:
                            count = 0
            else:
                tmp = tokenizer.tokenize(value)
                for i in range(len(tmp)):
                    slot_labels.append("O")
                    
    return [slot_to_index.get(label, slot_to_index['O']) for label in slot_labels]

In [14]:
max_len = 100

class BigData(Dataset):
  def __init__(self, dataset):
    self.dataset = dataset

  def __len__(self):
      return len(self.dataset)

  def __getitem__(self, index):

        elem = self.dataset[index]

        encoding = tokenizer.encode_plus(
            elem['utt'],
            add_special_tokens=True,
            max_length=max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True
        )


        intent_labels = elem['intent']
        slot_labels = tokenize_and_align_labels(tokenizer, elem['annot_utt'], slot_to_index)

        
        if len(slot_labels) > max_len:
            slot_labels = slot_labels[:max_len]
        elif len(slot_labels)  < max_len:
            while (len(slot_labels) != max_len):
                slot_labels.append(-100)

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'intent_labels': torch.tensor(intent_labels, dtype=torch.long),
            'slot_labels': torch.tensor(slot_labels, dtype=torch.long)
        }

In [15]:
train_data = BigData(dataset["train"])
test_data = BigData(dataset["test"])

In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

cuda


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(105879, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
 

```cpp
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
```

In [17]:
intent_classifier = nn.Linear(model.config.hidden_size, len(intent_to_index), device=device)
slot_classifier = nn.Linear(model.config.hidden_size, len(slot_to_index), device=device)

dropout = nn.Dropout(0.1)

In [18]:
LR = 5e-5 #  насколько сильно модель обновляет свои веса на каждом шаге
EPOCHES = 5
BATCH_SIZE = 32

def train(model, dataset):
        dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) # разбиваю данные на батчи
        optimizer = optim.AdamW(model.parameters(), lr=LR) # мой  оптимайзер, который обновляет параметры мождели с учетом lr
        criterion_intent = nn.CrossEntropyLoss() #  моя функция  потери для предсказания намерений
        criterion_slot = nn.CrossEntropyLoss(ignore_index=-100) #  моя функция потери для предсказания слотов

        model.train()  # загружаю  модель на видеокарту

        for epoch in range(EPOCHES): # обучаю эпохами
            total_loss_intent = 0
            total_loss_slot = 0

            for index, batch in enumerate(dataloader):

                optimizer.zero_grad() # обнуляю градиент  перед каждым новым шагом

                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                intent_labels = batch['intent_labels'].to(device)
                slot_labels = batch['slot_labels'].to(device)

                outputs = model(input_ids, attention_mask=attention_mask, return_dict=True) # результаты работы
                sequence_output = outputs.last_hidden_state

                pooled_output = outputs.pooler_output

                pooled_output = dropout(pooled_output) # добавляет регуляризацию, обнуляя случайные элементы.
                intent_logits = intent_classifier(pooled_output)

                sequence_output = dropout(sequence_output)
                slot_logits = slot_classifier(sequence_output)

                loss_intent = criterion_intent(intent_logits, intent_labels)
                loss_slot = criterion_slot(slot_logits.view(-1, slot_logits.shape[-1]), slot_labels.view(-1))

                loss = loss_intent + loss_slot # сумма двух потерь
                loss.backward() # вычисляю градиент
                optimizer.step() # Обновление параметров модели с использованием вычисленных градиентов.

                total_loss_intent += loss_intent.item()
                total_loss_slot += loss_slot.item()

                print(f"Intent Loss: {loss_intent:.4f}, Slot Loss: {loss_slot:.4f}")

            avg_loss_intent = total_loss_intent / len(dataloader)
            avg_loss_slot = total_loss_slot / len(dataloader)

            print(f"Epoch {epoch+1}/{EPOCHES}, Intent Loss: {avg_loss_intent:.4f}, Slot Loss: {avg_loss_slot:.4f}")

In [19]:
train(model, train_data)

Intent Loss: 4.1006, Slot Loss: 4.6864
Intent Loss: 4.1277, Slot Loss: 3.8316
Intent Loss: 4.1402, Slot Loss: 2.9105
Intent Loss: 4.0545, Slot Loss: 2.4802
Intent Loss: 4.0610, Slot Loss: 2.1301
Intent Loss: 4.0331, Slot Loss: 1.9440
Intent Loss: 4.0689, Slot Loss: 1.8878
Intent Loss: 4.0669, Slot Loss: 1.3849
Intent Loss: 4.0576, Slot Loss: 1.6899
Intent Loss: 3.9278, Slot Loss: 1.6859
Intent Loss: 3.9813, Slot Loss: 1.9373
Intent Loss: 3.9691, Slot Loss: 2.1713
Intent Loss: 3.8282, Slot Loss: 1.7469
Intent Loss: 3.9265, Slot Loss: 1.6084
Intent Loss: 3.8478, Slot Loss: 1.8162
Intent Loss: 3.7839, Slot Loss: 2.0674
Intent Loss: 3.7989, Slot Loss: 1.9460
Intent Loss: 3.7381, Slot Loss: 1.6654
Intent Loss: 3.9192, Slot Loss: 1.8691
Intent Loss: 3.7817, Slot Loss: 2.0632
Intent Loss: 3.7928, Slot Loss: 1.7003
Intent Loss: 3.7345, Slot Loss: 1.9824
Intent Loss: 4.0141, Slot Loss: 1.7784
Intent Loss: 4.0571, Slot Loss: 1.9929
Intent Loss: 3.9335, Slot Loss: 1.8981
Intent Loss: 3.9717, Slot

Intent Loss: 1.7911, Slot Loss: 0.9894
Intent Loss: 1.9331, Slot Loss: 0.7858
Intent Loss: 2.0511, Slot Loss: 0.8312
Intent Loss: 1.9444, Slot Loss: 1.3035
Intent Loss: 1.9365, Slot Loss: 0.9773
Intent Loss: 1.7062, Slot Loss: 1.0078
Intent Loss: 2.2473, Slot Loss: 0.6138
Intent Loss: 1.8842, Slot Loss: 1.0807
Intent Loss: 1.9199, Slot Loss: 1.0576
Intent Loss: 2.0614, Slot Loss: 0.8737
Intent Loss: 1.8540, Slot Loss: 1.0456
Intent Loss: 1.9734, Slot Loss: 1.6611
Intent Loss: 2.0986, Slot Loss: 1.2529
Intent Loss: 1.4804, Slot Loss: 0.7803
Intent Loss: 1.9911, Slot Loss: 0.9309
Intent Loss: 1.9649, Slot Loss: 0.9895
Intent Loss: 1.6444, Slot Loss: 0.8195
Intent Loss: 2.4830, Slot Loss: 0.8355
Intent Loss: 2.1450, Slot Loss: 1.1064
Intent Loss: 1.7457, Slot Loss: 0.9593
Intent Loss: 2.1979, Slot Loss: 1.2440
Intent Loss: 2.1063, Slot Loss: 0.7973
Intent Loss: 1.7172, Slot Loss: 0.9959
Intent Loss: 1.4923, Slot Loss: 1.0568
Intent Loss: 1.9487, Slot Loss: 1.0188
Intent Loss: 2.0779, Slot

Intent Loss: 1.1631, Slot Loss: 0.6073
Intent Loss: 1.4467, Slot Loss: 0.9545
Intent Loss: 1.3198, Slot Loss: 0.7056
Intent Loss: 1.2763, Slot Loss: 0.8188
Intent Loss: 1.5021, Slot Loss: 0.6389
Intent Loss: 0.6318, Slot Loss: 0.5435
Intent Loss: 1.1520, Slot Loss: 0.8409
Intent Loss: 1.1776, Slot Loss: 0.7092
Intent Loss: 1.2390, Slot Loss: 0.7078
Intent Loss: 1.2453, Slot Loss: 0.5948
Intent Loss: 0.9075, Slot Loss: 0.5999
Intent Loss: 1.0250, Slot Loss: 0.6435
Intent Loss: 1.0492, Slot Loss: 0.6909
Intent Loss: 1.0862, Slot Loss: 0.6431
Intent Loss: 0.9460, Slot Loss: 0.8346
Intent Loss: 1.0889, Slot Loss: 0.3885
Intent Loss: 0.9610, Slot Loss: 0.6390
Intent Loss: 1.0336, Slot Loss: 0.6416
Intent Loss: 1.0703, Slot Loss: 0.6740
Intent Loss: 0.6185, Slot Loss: 0.4659
Intent Loss: 1.2666, Slot Loss: 0.9753
Intent Loss: 1.2289, Slot Loss: 0.6326
Intent Loss: 0.7673, Slot Loss: 0.7571
Intent Loss: 1.1039, Slot Loss: 0.6499
Intent Loss: 0.8613, Slot Loss: 0.6988
Intent Loss: 1.0760, Slot

Intent Loss: 0.6555, Slot Loss: 0.5978
Intent Loss: 0.7106, Slot Loss: 0.5869
Intent Loss: 0.7382, Slot Loss: 0.4262
Intent Loss: 0.5448, Slot Loss: 0.6775
Intent Loss: 0.9427, Slot Loss: 0.4996
Intent Loss: 0.6685, Slot Loss: 0.5297
Intent Loss: 0.9592, Slot Loss: 0.5566
Intent Loss: 0.5719, Slot Loss: 0.5796
Intent Loss: 0.6804, Slot Loss: 0.9452
Intent Loss: 0.6482, Slot Loss: 0.6151
Intent Loss: 0.6717, Slot Loss: 0.4714
Intent Loss: 0.5361, Slot Loss: 0.7158
Intent Loss: 0.7268, Slot Loss: 0.5097
Intent Loss: 0.6972, Slot Loss: 0.5008
Intent Loss: 0.7306, Slot Loss: 0.5479
Intent Loss: 0.9929, Slot Loss: 0.4477
Intent Loss: 0.7974, Slot Loss: 0.3078
Intent Loss: 0.9781, Slot Loss: 0.6409
Intent Loss: 0.5007, Slot Loss: 0.6080
Intent Loss: 0.7009, Slot Loss: 0.4828
Intent Loss: 0.7505, Slot Loss: 0.5576
Intent Loss: 1.0068, Slot Loss: 0.4930
Intent Loss: 0.5524, Slot Loss: 0.6033
Intent Loss: 0.7364, Slot Loss: 0.5954
Intent Loss: 0.6000, Slot Loss: 0.4753
Intent Loss: 0.8950, Slot

Intent Loss: 0.5455, Slot Loss: 0.4244
Intent Loss: 0.6783, Slot Loss: 0.2777
Intent Loss: 0.6881, Slot Loss: 0.4736
Intent Loss: 0.9596, Slot Loss: 0.4860
Intent Loss: 0.8673, Slot Loss: 0.4987
Intent Loss: 0.6487, Slot Loss: 0.5458
Intent Loss: 0.5096, Slot Loss: 0.2977
Intent Loss: 0.4215, Slot Loss: 0.3413
Intent Loss: 0.3823, Slot Loss: 0.3955
Intent Loss: 0.4693, Slot Loss: 0.3009
Intent Loss: 0.4874, Slot Loss: 0.2761
Intent Loss: 0.4421, Slot Loss: 0.2796
Intent Loss: 0.3577, Slot Loss: 0.4222
Intent Loss: 0.5767, Slot Loss: 0.4684
Intent Loss: 0.5209, Slot Loss: 0.6239
Intent Loss: 0.7044, Slot Loss: 0.3584
Intent Loss: 0.5265, Slot Loss: 0.4023
Intent Loss: 0.3311, Slot Loss: 0.2421
Intent Loss: 0.3168, Slot Loss: 0.3458
Intent Loss: 0.6796, Slot Loss: 0.4373
Intent Loss: 0.5715, Slot Loss: 0.5154
Intent Loss: 0.5924, Slot Loss: 0.5292
Intent Loss: 0.2582, Slot Loss: 0.5115
Intent Loss: 0.6701, Slot Loss: 0.3771
Intent Loss: 0.7499, Slot Loss: 0.5724
Intent Loss: 0.3855, Slot

Intent Loss: 0.5053, Slot Loss: 0.4480
Intent Loss: 0.5611, Slot Loss: 0.2663
Intent Loss: 0.4828, Slot Loss: 0.4331
Intent Loss: 0.3106, Slot Loss: 0.3923
Intent Loss: 0.4628, Slot Loss: 0.4539
Intent Loss: 0.6020, Slot Loss: 0.5037
Intent Loss: 0.5285, Slot Loss: 0.4817
Intent Loss: 0.6713, Slot Loss: 0.5005
Intent Loss: 0.3610, Slot Loss: 0.5506
Intent Loss: 0.2199, Slot Loss: 0.3626
Intent Loss: 0.4678, Slot Loss: 0.5589
Intent Loss: 0.5208, Slot Loss: 0.2627
Intent Loss: 0.5665, Slot Loss: 0.3138
Intent Loss: 0.2942, Slot Loss: 0.3667
Intent Loss: 0.5064, Slot Loss: 0.2175
Intent Loss: 0.3807, Slot Loss: 0.4289
Intent Loss: 0.4151, Slot Loss: 0.2278
Intent Loss: 0.5296, Slot Loss: 0.4478
Intent Loss: 0.5607, Slot Loss: 0.4930
Intent Loss: 0.3409, Slot Loss: 0.4675
Intent Loss: 0.2518, Slot Loss: 0.2055
Intent Loss: 0.4284, Slot Loss: 0.5050
Intent Loss: 0.4424, Slot Loss: 0.6705
Intent Loss: 0.6250, Slot Loss: 0.2299
Intent Loss: 0.4298, Slot Loss: 0.4972
Intent Loss: 0.5624, Slot

Intent Loss: 0.5390, Slot Loss: 0.2849
Intent Loss: 0.6796, Slot Loss: 0.5079
Intent Loss: 0.2294, Slot Loss: 0.4352
Intent Loss: 0.2505, Slot Loss: 0.3347
Intent Loss: 0.3815, Slot Loss: 0.3873
Intent Loss: 0.3889, Slot Loss: 0.1967
Intent Loss: 0.3365, Slot Loss: 0.2137
Intent Loss: 0.4395, Slot Loss: 0.3138
Intent Loss: 0.4110, Slot Loss: 0.3140
Intent Loss: 0.7195, Slot Loss: 0.4640
Intent Loss: 0.4432, Slot Loss: 0.3232
Intent Loss: 0.6382, Slot Loss: 0.3403
Intent Loss: 0.5131, Slot Loss: 0.3339
Intent Loss: 0.3616, Slot Loss: 0.3184
Intent Loss: 0.1928, Slot Loss: 0.1597
Intent Loss: 0.3145, Slot Loss: 0.2512
Intent Loss: 0.3007, Slot Loss: 0.1865
Intent Loss: 0.2816, Slot Loss: 0.3663
Intent Loss: 0.2485, Slot Loss: 0.2715
Intent Loss: 0.2885, Slot Loss: 0.1413
Intent Loss: 0.4154, Slot Loss: 0.3793
Intent Loss: 0.4824, Slot Loss: 0.4553
Intent Loss: 0.5346, Slot Loss: 0.3837
Intent Loss: 0.3946, Slot Loss: 0.2622
Intent Loss: 0.6588, Slot Loss: 0.4455
Intent Loss: 0.1384, Slot

Intent Loss: 0.3042, Slot Loss: 0.4586
Intent Loss: 0.5218, Slot Loss: 0.2624
Intent Loss: 0.3581, Slot Loss: 0.2743
Intent Loss: 0.2769, Slot Loss: 0.1827
Intent Loss: 0.3084, Slot Loss: 0.3068
Intent Loss: 0.1014, Slot Loss: 0.1820
Intent Loss: 0.4322, Slot Loss: 0.2024
Intent Loss: 0.4342, Slot Loss: 0.3147
Intent Loss: 0.1698, Slot Loss: 0.1623
Intent Loss: 0.3201, Slot Loss: 0.1325
Intent Loss: 0.2550, Slot Loss: 0.2967
Intent Loss: 0.3598, Slot Loss: 0.2357
Intent Loss: 0.2679, Slot Loss: 0.3185
Intent Loss: 0.1092, Slot Loss: 0.4898
Intent Loss: 0.5022, Slot Loss: 0.1654
Intent Loss: 0.1511, Slot Loss: 0.2813
Intent Loss: 0.2213, Slot Loss: 0.3240
Intent Loss: 0.2199, Slot Loss: 0.3187
Intent Loss: 0.2622, Slot Loss: 0.3199
Intent Loss: 0.1454, Slot Loss: 0.2870
Intent Loss: 0.2542, Slot Loss: 0.2164
Intent Loss: 0.2764, Slot Loss: 0.2910
Intent Loss: 0.3264, Slot Loss: 0.2621
Intent Loss: 0.3163, Slot Loss: 0.4318
Intent Loss: 0.1695, Slot Loss: 0.2064
Intent Loss: 0.3386, Slot

Intent Loss: 0.2477, Slot Loss: 0.1937
Intent Loss: 0.2107, Slot Loss: 0.3225
Intent Loss: 0.3780, Slot Loss: 0.1740
Intent Loss: 0.1495, Slot Loss: 0.2753
Intent Loss: 0.4210, Slot Loss: 0.1875
Intent Loss: 0.1647, Slot Loss: 0.2185
Intent Loss: 0.3051, Slot Loss: 0.3014
Intent Loss: 0.3602, Slot Loss: 0.1748
Intent Loss: 0.1952, Slot Loss: 0.2469
Intent Loss: 0.4624, Slot Loss: 0.2910
Intent Loss: 0.1064, Slot Loss: 0.4595
Intent Loss: 0.2873, Slot Loss: 0.4065
Intent Loss: 0.0862, Slot Loss: 0.2380
Intent Loss: 0.1724, Slot Loss: 0.2634
Intent Loss: 0.4928, Slot Loss: 0.3034
Intent Loss: 0.1952, Slot Loss: 0.2786
Intent Loss: 0.3367, Slot Loss: 0.1884
Intent Loss: 0.1355, Slot Loss: 0.2397
Intent Loss: 0.3036, Slot Loss: 0.1854
Intent Loss: 0.4054, Slot Loss: 0.2252
Intent Loss: 0.3668, Slot Loss: 0.2012
Intent Loss: 0.6646, Slot Loss: 0.2946
Intent Loss: 0.1808, Slot Loss: 0.2446
Intent Loss: 0.2693, Slot Loss: 0.2825
Intent Loss: 0.2650, Slot Loss: 0.2831
Intent Loss: 0.5299, Slot

In [20]:
metric = evaluate.load("seqeval")

def eval_model(model, dataset, batch_size=BATCH_SIZE):
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        model.eval()

        all_intent_preds = []
        all_intent_labels = []
        all_slot_preds = []
        all_slot_labels = []

        with torch.no_grad():
            for index, batch in enumerate(dataloader):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                intent_labels = batch['intent_labels'].to(device)
                slot_labels = batch['slot_labels'].to(device)

                outputs = model(input_ids, attention_mask=attention_mask, return_dict=True)
                sequence_output = outputs.last_hidden_state

                pooled_output = outputs.pooler_output

                pooled_output = dropout(pooled_output)
                intent_logits = intent_classifier(pooled_output)

                sequence_output = dropout(sequence_output)
                slot_logits = slot_classifier(sequence_output)

                intent_preds = torch.argmax(intent_logits, dim=1)
                all_intent_preds.extend(intent_preds.cpu().numpy())
                all_intent_labels.extend(intent_labels.cpu().numpy())

                slot_preds = torch.argmax(slot_logits, dim=2)

                for i in range(slot_preds.shape[0]):
                    seq_length = slot_labels[i].size(0)

                    single_pred = slot_preds[i].cpu().numpy().tolist()  
                    single_label = slot_labels[i].cpu().numpy().tolist()  

                    filtered_pred = []
                    filtered_label = []

                    for pred, label in zip(single_pred, single_label):
                      if label != -100:
                          if pred == -100:
                            filtered_pred.append(0)
                          else:
                            filtered_pred.append(pred)
                          filtered_label.append(label)

                    all_slot_preds.append(filtered_pred)
                    all_slot_labels.append(filtered_label)

        
        intent_accuracy = accuracy_score(all_intent_labels, all_intent_preds)

       
        intent_cm = confusion_matrix(all_intent_labels, all_intent_preds)

        valid_slot_labels = []
        for slot_labels in all_slot_labels:
            valid_slot_labels.append([index_to_slot[label] for label in slot_labels])

        valid_slot_preds = []
        for slot_preds in all_slot_preds:
            valid_slot_preds.append([index_to_slot[label] for label in slot_preds])


        all_metrics = metric.compute(predictions=valid_slot_preds, references=valid_slot_labels, zero_division=0)

        return {
            "Intent Accuracy": intent_accuracy,
            "Intent Confusion Matrix": intent_cm,
            "Overall Metrics": {
                "Precision": all_metrics["overall_precision"],
                "Recall": all_metrics["overall_recall"],
                "F1 Score": all_metrics["overall_f1"],
                "Accuracy": all_metrics["overall_accuracy"],
            }
        }

> Precision = TP / (TP + FN) // точность

> Recall = TP / (TP + FN) // полнота

> F1 = 2 * (Precision * Recall) / (Precision + Recall) // среднее гармоническое между Precision и Recall

In [21]:
eval_model(model, test_data)

{'Intent Accuracy': 0.8426361802286483,
 'Intent Confusion Matrix': array([[73,  0,  0, ...,  0,  0,  0],
        [ 0, 30,  0, ...,  0,  0,  0],
        [ 0,  0, 34, ...,  0,  0,  0],
        ...,
        [ 0,  0,  0, ..., 27,  0,  0],
        [ 0,  0,  0, ...,  0, 59,  0],
        [ 0,  0,  0, ...,  0,  0, 43]], dtype=int64),
 'Overall Metrics': {'Precision': 0.6486223662884927,
  'Recall': 0.7115931721194879,
  'F1 Score': 0.6786501610988638,
  'Accuracy': 0.901533396335701}}

In [22]:
def parse_slots(slot_preds, text):
    slot_map = {}
    tokens = tokenizer.tokenize(text)
    token_labels = [index_to_slot[idx] for idx in slot_preds][1:len(slot_preds) - 1]
    print(token_labels)
    
    current_slot_value = None
    current_slot_words = []

    for i, label in enumerate(token_labels):
        token = tokens[i]

        if label == "O" or label == "<PAD>":
            if current_slot_value is not None:
                slot_map[current_slot_value] = " ".join(current_slot_words)
                current_slot_value = None
                current_slot_words = []
            continue

        if label.startswith("B-"):
            if current_slot_value is not None:
                slot_map[current_slot_value] = " ".join(current_slot_words)
            current_slot_value = label[2:] 
            current_slot_words = [token]  
        elif label.startswith("I-") and current_slot_value == label[2:]:
            current_slot_words.append(token)
        else:
            if current_slot_value is not None:
                slot_map[current_slot_value] = " ".join(current_slot_words)
            current_slot_value = None
            current_slot_words = []

    if current_slot_value is not None:
        slot_map[current_slot_value] = " ".join(current_slot_words)

    return slot_map

In [23]:
def predict(text):
    model.eval()
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
    
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    intent_pred = None
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, return_dict=True)
        sequence_output = outputs.last_hidden_state

        pooled_output = outputs.pooler_output

        pooled_output = dropout(pooled_output)
        intent_logits = intent_classifier(pooled_output)

        sequence_output = dropout(sequence_output)
        slot_logits = slot_classifier(sequence_output)

        intent_pred = torch.argmax(intent_logits, dim=1).cpu().numpy()[0]
        slot_preds = torch.argmax(slot_logits, dim=2).cpu().numpy()[0].tolist()

        intent_pred = index_to_intent[intent_pred]


    slot_map = parse_slots(slot_preds, text)
    
    return {
        "intent": intent_pred,
        "slots": slot_map
    }

In [28]:
validation = dataset["test"]
print(validation)

TEST_SIZE = 10

for index in range(10):
    data = validation[index]
    print("="*20)
    result = predict(data["utt"])
    print("Expected intent:\t", index_to_intent[data["intent"]])
    print("Got intent:\t\t", result["intent"])
    print("Expected slots:\t\t",  data["annot_utt"])
    print("Got slots:\t\t", result["slots"])

Dataset({
    features: ['id', 'locale', 'partition', 'scenario', 'intent', 'utt', 'annot_utt', 'worker_id', 'slot_method', 'judgments'],
    num_rows: 2974
})
['O', 'O', 'O', 'O', 'B-time', 'I-time', 'I-time', 'O', 'B-date', 'I-date', 'I-date', 'I-date', 'O']
Expected intent:	 alarm_set
Got intent:		 alarm_set
Expected slots:		 разбуди меня в [time : пять утра] на [date : этой неделе]
Got slots:		 {'time': 'в пять у', 'date': 'на этои не ##дел'}
['O', 'O']
Expected intent:	 audio_volume_mute
Got intent:		 audio_volume_down
Expected slots:		 тише
Got slots:		 {}
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
Expected intent:	 iot_hue_lightchange
Got intent:		 qa_factoid
Expected slots:		 [color_type : розовый] это то что нам надо
Got slots:		 {}
['O', 'O', 'O', 'O', 'O', 'O', 'O']
Expected intent:	 iot_hue_lighton
Got intent:		 iot_hue_lightup
Expected slots:		 и опустилась темнота
Got slots:		 {}
['O', 'O', 'O', 'O', 'O', 'O', 'B-house_place', 'I-house_place', 'I-house_place', 'O']
Exp