In [1]:
from datasets import load_dataset
import torch
import torch.nn as nn
from transformers import BertTokenizerFast
from torch.utils.data import DataLoader, TensorDataset

# Check if GPU is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Train, Validation and Test data

dataset = load_dataset("multi_woz_v22")
train_data = dataset['train']
val_data = dataset['validation']
test_data = dataset['test']

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def filterDomains(data):
    """
    Filters a list of dictionaries by only including entries with services
    either "restaurant" or "hotel" and having only one service.

    Parameters:
    - data: list of dictionaries containing a "services" key, which is a list of services.

    Returns:
    - List of filtered dictionaries.
    """
    return [entry for entry in data if set(entry["services"]).issubset({"restaurant", "hotel"})]

# Only keep dialogues related to Restaurants or Hotels.

train_data_filtered = filterDomains(train_data)
val_data_filtered = filterDomains(val_data)
test_data_filtered = filterDomains(test_data)

In [3]:
print(train_data_filtered[0])

{'dialogue_id': 'PMUL4398.json', 'services': ['restaurant', 'hotel'], 'turns': {'turn_id': ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11'], 'speaker': [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], 'utterance': ['i need a place to dine in the center thats expensive', 'I have several options for you; do you prefer African, Asian, or British food?', 'Any sort of food would be fine, as long as it is a bit expensive. Could I get the phone number for your recommendation?', 'There is an Afrian place named Bedouin in the centre. How does that sound?', 'Sounds good, could I get that phone number? Also, could you recommend me an expensive hotel?', "Bedouin's phone is 01223367660. As far as hotels go, I recommend the University Arms Hotel in the center of town.", 'Yes. Can you book it for me?', 'Sure, when would you like that reservation?', 'i want to book it for 2 people and 2 nights starting from saturday.', 'Your booking was successful. Your reference number is FRGZWQL2 . May I help you

# Identifying the slots

In [4]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')


def label_utterances(dialogue):
    labeled_data = []
    data = dialogue['turns']
    
    # Loop through each turn in the dialogue
    for i, turn_id in enumerate(data['turn_id']):
        utterance = data['utterance'][i]
        # Tokenize the utterance and get the offset mappings
        encoded_input = tokenizer(utterance, add_special_tokens=False, return_offsets_mapping=True)
        tokens = tokenizer.convert_ids_to_tokens(encoded_input['input_ids'])
        labels = ['O'] * len(tokens)  # Initialize labels as 'O' (Outside)
        offset_mapping = encoded_input['offset_mapping']
        # Check if there are slot values for this turn
        if 'dialogue_acts' in data and i < len(data['dialogue_acts']):
            dialogue_act = data['dialogue_acts'][i]
            span_info = dialogue_act.get('span_info', {})
            for act_slot_name, act_slot_value, span_start, span_end in zip(
                    span_info.get('act_slot_name', []),
                    span_info.get('act_slot_value', []),
                    span_info.get('span_start', []),
                    span_info.get('span_end', [])):
                
                # Find the tokens that correspond to the start and end indices
                # start_token_idx = next((idx for idx, offset in enumerate(offset_mapping) if offset[0] == span_start), None)
                # end_token_idx = next((idx for idx, offset in enumerate(offset_mapping) if offset[1] == span_end), None)
                
                # Utilize the offset_mapping to find the token index for the start and end of the span
                start_token_idx = None
                end_token_idx = None
                
                for idx, offset in enumerate(offset_mapping):
                    if start_token_idx is None and offset[0] == span_start:
                        start_token_idx = idx
                    if offset[1] == span_end:
                        end_token_idx = idx
                        break
                
                if start_token_idx is not None and end_token_idx is not None:
                    if start_token_idx < len(tokens) and end_token_idx < len(tokens):
                        # Label tokens using IOB format with the actual ground truth slot value
                        labels[start_token_idx] = f"B-{act_slot_name}"
                        for j in range(start_token_idx + 1, end_token_idx + 1):
                            labels[j] = f"I-{act_slot_name}"
                    else:
                        print(f"Warning: Index out of range for utterance '{utterance}' with span {span_start}-{span_end}")
            
            try:
                # if the prev_dialogue_act is not None, then we need to label the tokens that are part of the previous dialogue act
                prev_dialogue_act = data['dialogue_acts'][i-1]['dialog_act']['act_type'][0] if i > 0 and data['dialogue_acts'][i]['dialog_act']['act_type'][0] else ""
                current_dialogue_act = data['dialogue_acts'][i]['dialog_act']['act_type'][0] if data['dialogue_acts'][i]['dialog_act']['act_type'][0] else ""
            except IndexError as e:
                prev_dialogue_act = ""
                current_dialogue_act = ""
            
            dialogue_act_str = f"{prev_dialogue_act}|{current_dialogue_act}"
            
            act_tokens = tokenizer.tokenize(dialogue_act_str)
            act_labels = ['X'] * len(act_tokens)
            tokens = act_tokens + ['[SEP]'] + tokens  # Add a separator token between acts and the utterance
            labels = act_labels + ['X'] + labels  # Add an 'X' label for the separator token

        # Store the tokenized utterance along with its labels
        labeled_data.append((tokens, labels))
        
    return labeled_data


In [5]:
import pandas as pd
import numpy as np

def toDF(data):
    all_labeled_data = []
    for dialogue in data:
        all_labeled_data.extend(label_utterances(dialogue))
    return pd.DataFrame(all_labeled_data, columns=['Tokens', 'Labels'])
    
# Create DataFrames of labeled utterances
train_df = toDF(train_data_filtered)
test_df = toDF(test_data_filtered)
val_df = toDF(val_data_filtered)

In [6]:
train_df.to_excel("output.xlsx")  

In [7]:
print(train_df.shape)
print(train_df["Tokens"].iloc[9])
print(train_df["Labels"].iloc[9])

(28928, 2)
['hotel', '-', 'inform', '|', 'booking', '-', 'book', '[SEP]', 'your', 'booking', 'was', 'successful', '.', 'your', 'reference', 'number', 'is', 'fr', '##g', '##z', '##w', '##q', '##l', '##2', '.', 'may', 'i', 'help', 'you', 'further', '?']
['X', 'X', 'X', 'X', 'X', 'X', 'X', 'X', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ref', 'I-ref', 'I-ref', 'I-ref', 'I-ref', 'I-ref', 'I-ref', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


In [8]:

all_labels = [label for sublist in train_df['Labels'].tolist() for label in sublist]
all_labels += [label for sublist in val_df['Labels'].tolist() for label in sublist]
all_labels += [label for sublist in test_df['Labels'].tolist() for label in sublist]
unique_labels = sorted(set(all_labels))

unique_labels.__sizeof__()

# We will ignore the 'X' label.
unique_labels.remove('X')

print(unique_labels)

['B-address', 'B-area', 'B-arriveby', 'B-bookday', 'B-bookpeople', 'B-bookstay', 'B-booktime', 'B-choice', 'B-day', 'B-department', 'B-departure', 'B-destination', 'B-entrancefee', 'B-food', 'B-leaveat', 'B-name', 'B-openhours', 'B-phone', 'B-postcode', 'B-price', 'B-pricerange', 'B-ref', 'B-stars', 'B-type', 'I-address', 'I-area', 'I-arriveby', 'I-bookday', 'I-bookpeople', 'I-bookstay', 'I-booktime', 'I-choice', 'I-department', 'I-departure', 'I-destination', 'I-entrancefee', 'I-food', 'I-leaveat', 'I-name', 'I-openhours', 'I-phone', 'I-postcode', 'I-price', 'I-pricerange', 'I-ref', 'I-stars', 'I-type', 'O']


In [9]:
label_map = {label: i for i, label in enumerate(unique_labels)}

In [10]:
def create_dataset(df, tokenizer, label_map):
    # Lists to store the tokenized inputs and labels
    input_ids = []
    attention_masks = []
    label_ids = []

    # Iterate over the DataFrame rows
    for _, row in df.iterrows():
        tokens = row['Tokens']
        labels = row['Labels']
        
        # Convert the IOB labels to their corresponding IDs
        label_ids_for_tokens = [label_map.get(label, -100) for label in labels] # ignore the 'X' label

        encoded_input = tokenizer(
            tokens,
            is_split_into_words=True,
            add_special_tokens=True,
            return_attention_mask=True,
            padding='max_length',
            truncation=True,
            max_length=256,
            return_offsets_mapping=True
        )
        
        # Create an empty array to hold the final label IDs
        aligned_labels = np.ones(len(encoded_input['input_ids']), dtype=int) * -100

        # Set labels using the word_ids to align them with tokens
        for i, word_id in enumerate(encoded_input.word_ids()):
            if word_id is not None and tokens[word_id] not in ["[CLS]", "[SEP]"]:
                aligned_labels[i] = label_ids_for_tokens[word_id]

        # Append the results to the lists
        input_ids.append(encoded_input['input_ids'])
        attention_masks.append(encoded_input['attention_mask'])
        label_ids.append(aligned_labels.tolist())

    # Convert lists to tensors
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    attention_masks = torch.tensor(attention_masks, dtype=torch.long)
    label_ids = torch.tensor(label_ids, dtype=torch.long)

    # Create the TensorDataset
    dataset = TensorDataset(input_ids, attention_masks, label_ids)

    return dataset

In [11]:
train_dataloader = DataLoader(create_dataset(train_df,tokenizer,label_map), batch_size=16, shuffle=True)
val_dataloader = DataLoader(create_dataset(val_df,tokenizer,label_map), batch_size=16, shuffle=True)
test_dataloader = DataLoader(create_dataset(test_df,tokenizer,label_map), batch_size=16, shuffle=True)

In [12]:
from transformers import BertForTokenClassification, BertConfig

# Define the number of labels
num_labels = len(label_map)  # Make sure label_map is defined in your environment

# Create a configuration object with `num_labels` set
config = BertConfig.from_pretrained('bert-base-uncased', num_labels=num_labels)

# Create the model with the standard token classification head
model = BertForTokenClassification(config).to(device)


In [103]:
from tqdm.auto import tqdm

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
criterion = nn.CrossEntropyLoss(ignore_index=-100)
epochs = 7
patience = 2


# Initialize the early stopping counter
best_val_loss = float('inf')
patience_counter = 0

# Training loop
for epoch in range(epochs):
    model.train()
    train_loss = 0
    train_progress_bar = tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{epochs} Training', leave=False)
    
    # Training phase
    for batch in train_progress_bar:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        optimizer.zero_grad()
        outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_progress_bar.set_postfix(train_loss=loss.item())
    
    avg_train_loss = train_loss / len(train_dataloader)
    print(f'Epoch {epoch + 1}/{epochs} | Train Loss: {avg_train_loss}')

    # Validation phase
    model.eval()
    val_loss = 0
    val_progress_bar = tqdm(val_dataloader, desc=f'Epoch {epoch+1}/{epochs} Validation', leave=False)
    for batch in val_progress_bar:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        with torch.no_grad():
            outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
            loss = outputs.loss
            val_loss += loss.item()
            val_progress_bar.set_postfix(val_loss=loss.item())
    
    avg_val_loss = val_loss / len(val_dataloader)
    print(f'Epoch {epoch + 1}/{epochs} | Validation Loss: {avg_val_loss}')

    # Check if the validation loss is lower than the best one seen so far
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), f'checkpoint_epoch_{epoch+1}.pt')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print('Early stopping!')
            break
print('Training complete. Final model saved.')

Epoch 1/7 Training:   0%|          | 0/1808 [00:00<?, ?it/s]

Epoch 1/7 | Train Loss: 0.2906221937814695


Epoch 1/7 Validation:   0%|          | 0/130 [00:00<?, ?it/s]

Epoch 1/7 | Validation Loss: 0.15909734001526465


Epoch 2/7 Training:   0%|          | 0/1808 [00:00<?, ?it/s]

Epoch 2/7 | Train Loss: 0.13695577808822695


Epoch 2/7 Validation:   0%|          | 0/130 [00:00<?, ?it/s]

Epoch 2/7 | Validation Loss: 0.12666630146022026


Epoch 3/7 Training:   0%|          | 0/1808 [00:00<?, ?it/s]

Epoch 3/7 | Train Loss: 0.10949596863979705


Epoch 3/7 Validation:   0%|          | 0/130 [00:00<?, ?it/s]

Epoch 3/7 | Validation Loss: 0.11606963659421755


Epoch 4/7 Training:   0%|          | 0/1808 [00:00<?, ?it/s]

Epoch 4/7 | Train Loss: 0.09407084769171671


Epoch 4/7 Validation:   0%|          | 0/130 [00:00<?, ?it/s]

Epoch 4/7 | Validation Loss: 0.105181605146768


Epoch 5/7 Training:   0%|          | 0/1808 [00:00<?, ?it/s]

Epoch 5/7 | Train Loss: 0.0833234294634199


Epoch 5/7 Validation:   0%|          | 0/130 [00:00<?, ?it/s]

Epoch 5/7 | Validation Loss: 0.09879767420486762


Epoch 6/7 Training:   0%|          | 0/1808 [00:00<?, ?it/s]

Epoch 6/7 | Train Loss: 0.07670422744758497


Epoch 6/7 Validation:   0%|          | 0/130 [00:00<?, ?it/s]

Epoch 6/7 | Validation Loss: 0.11090743742310084


Epoch 7/7 Training:   0%|          | 0/1808 [00:00<?, ?it/s]

Epoch 7/7 | Train Loss: 0.06996452265258866


Epoch 7/7 Validation:   0%|          | 0/130 [00:00<?, ?it/s]

Epoch 7/7 | Validation Loss: 0.11091081942073427
Early stopping!
Training complete. Final model saved.


In [104]:
!pip install seqeval
from seqeval.metrics import classification_report as seqeval_classification_report
import numpy as np
import torch

# Reverse the label map to translate from numeric to string labels
label_map_reverse = {v: k for k, v in label_map.items()}

model.eval()
total_loss = 0
all_predictions = []
all_true_labels = []

with torch.no_grad():
    for batch in test_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_attention_masks, b_labels = batch

        outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_attention_masks, labels=b_labels)
        loss = outputs.loss
        total_loss += loss.item()

        # Move logits and labels to CPU
        logits = outputs.logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()

        # Convert logits to token predictions
        predictions = np.argmax(logits, axis=-1)

        # For each item in the batch...
        for i in range(b_input_ids.size(0)):
            # Skip predictions for tokens with label_id == -100
            pred_label_sequence = []
            true_label_sequence = []
            for j, (pred_id, label_id) in enumerate(zip(predictions[i], label_ids[i])):
                if b_attention_masks[i][j] != 0 and label_id != -100:
                    pred_label_sequence.append(label_map_reverse.get(pred_id, 'O'))  # Default to 'O' if key is not found
                    true_label_sequence.append(label_map_reverse[label_id])

            # Ensure the true and predicted sequences have the same length
            if len(true_label_sequence) != len(pred_label_sequence):
                print(f"Length mismatch in sequence {i}: true labels {len(true_label_sequence)} vs. predicted labels {len(pred_label_sequence)}")
                # Output the actual sequences to help diagnose the issue
                print("True labels:", true_label_sequence)
                print("Pred labels:", pred_label_sequence)
                continue

            # ...extend the true labels and predicted labels lists
            all_true_labels.append(true_label_sequence)
            all_predictions.append(pred_label_sequence)

# Calculate average loss over all the batches
avg_loss = total_loss / len(test_dataloader)
print(f"Test loss: {avg_loss}")

# Use seqeval to compute a classification report
seqeval_report = seqeval_classification_report(all_true_labels, all_predictions)
print(seqeval_report)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Test loss: 0.10447849362305757
              precision    recall  f1-score   support

     address       0.59      0.82      0.68        79
        area       0.71      0.79      0.75       321
     bookday       0.92      0.99      0.95       205
  bookpeople       0.88      0.86      0.87       186
    bookstay       0.77      0.83      0.80       125
    booktime       0.90      0.97      0.93       114
      choice       0.83      0.90      0.86       215
        food       0.91      0.93      0.92       241
        name       0.62      0.77      0.69       427
       phone       0.82      0.94      0.88        53
    postcode       0.76      0.85      0.80        48
  pricerange       0.87      0.93      0.90       306
         ref       0.86      0.94      0.90       147
       stars       0.94      0.96      0.95       141
        type       0.63      0.71      0.66       187

   micro avg       0.78      0.87      0.82      2795
   macro avg       0.80      0.88      0.84      

In [115]:
def query_model(model, tokenizer, label_map, utterance, device):
    model.eval()  
    
    
    # Reverse the label map to translate from numeric IDs to string labels
    label_map_reverse = {v: k for k, v in label_map.items()}

    
    # Tokenize the new utterance directly with the tokenizer
    encoded_input = tokenizer(
        utterance,
        add_special_tokens=True,
        return_attention_mask=True,
        padding='max_length',
        truncation=True,
        max_length=256,
        return_tensors='pt'  # Return PyTorch tensors directly
    )
    
    # Move tensors to the correct device
    input_ids = encoded_input['input_ids'].to(device)
    attention_masks = encoded_input['attention_mask'].to(device)
    
    with torch.no_grad():
        outputs = model(input_ids, token_type_ids=None, attention_mask=attention_masks)
        logits = outputs.logits

    # Convert logits to probabilities and get the most likely label index
    predictions = torch.argmax(logits, dim=2).squeeze().tolist()

    # Map predictions to label strings, ignoring -100 and padding tokens
    predicted_labels = [label_map_reverse.get(label_id) for label_id, mask in zip(predictions, attention_masks.squeeze().tolist()) if mask != 0 and label_id != -100][1:-1]

    return predicted_labels

# Example usage
new_utterance = "Can I book a table for five at a Spanish restaurant, the restaurant must be cheap?"
predicted_labels = query_model(model, tokenizer, label_map, new_utterance, device)
print(predicted_labels)


['O', 'O', 'O', 'O', 'O', 'O', 'B-choice', 'O', 'O', 'B-food', 'O', 'O', 'O', 'O', 'O', 'O', 'B-pricerange', 'O']


# Mapping slots to values

In [14]:
def label_slots(dialogue):
    labeled_data = []
    data = dialogue['turns']
    
    # Loop through each turn in the dialogue
    for i, turn_id in enumerate(data['turn_id']):
        utterance = data['utterance'][i]
        tokens = utterance.split()  

        # Check if there are slot values for this turn
        if 'dialogue_acts' in data and i < len(data['dialogue_acts']):
            dialogue_act = data['dialogue_acts'][i]
            span_info = dialogue_act.get('span_info', {})
            for act_slot_name, act_slot_value, span_start, span_end in zip(
                    span_info.get('act_slot_name', []),
                    span_info.get('act_slot_value', []),
                    span_info.get('span_start', []),
                    span_info.get('span_end', [])):
                
                # Find the tokens that correspond to the start and end indices
                start_token_idx = len(utterance[:span_start].split())
                end_token_idx = len(utterance[:span_end].split()) - 1

                if start_token_idx < len(tokens) and end_token_idx < len(tokens):
                    # Label tokens using IOB format with the actual ground truth slot value
                    slot = f"{utterance[span_start:span_end]}"
                    value = act_slot_value
                    
                    labeled_data.append((slot, value))
                else:
                    print(f"Warning: Index out of range for utterance '{utterance}' with span {span_start}-{span_end}")
                

        # Store the tokenized utterance along with its labels
        
        
    return labeled_data

In [15]:
import pandas as pd
def slotsToDF(data):
    all_labeled_data = []
    for dialogue in data:
        all_labeled_data.extend(label_slots(dialogue))
    return pd.DataFrame(all_labeled_data, columns=['Slots', 'Values'])
    
# Create DataFrames of labeled utterances
train_df = slotsToDF(train_data_filtered)
test_df = slotsToDF(test_data_filtered)
val_df = slotsToDF(val_data_filtered)



In [16]:
mismatched_rows = train_df[train_df['Slots'] != train_df['Values']]

# Display the filtered rows
display(mismatched_rows)

Unnamed: 0,Slots,Values
0,center,centre
25,don't care,dontcare
26,don't care,dontcare
33,Cityroomz,cityroomz
53,Chinese,chinese
...,...,...
35526,Saturday,saturday
35545,Wednesday,wednesday
35552,Indian,indian
35568,doesn't matter,dontcare


In [17]:
def map_indiferece(text):
    
    indiferences = ["don't care", "any", "don't mind", "no preference", "whatever", "doesn't matter"]
    
    for indiference in indiferences:
        if indiference == text:
            return "dontcare"
        else:
            return text
    
    

def text_to_num(text):
    """
    Converts text numbers up to 20 into their integer representations as strings.
    If the provided text is not a number or out of range, it returns None.
    """
    text_to_num_dict = {
        'zero': '0', 'one': '1', 'two': '2', 'three': '3', 'four': '4', 
        'five': '5', 'six': '6', 'seven': '7', 'eight': '8', 'nine': '9', 
        'ten': '10', 'eleven': '11', 'twelve': '12', 'thirteen': '13', 
        'fourteen': '14', 'fifteen': '15', 'sixteen': '16', 'seventeen': '17', 
        'eighteen': '18', 'nineteen': '19', 'twenty': '20'
    }
    # convert to lower case to make the function case-insensitive
    text = text.lower()
    return text_to_num_dict.get(text, False)


# Define the post process function
def post_process_slot_value(slot_value):
    slot_value = slot_value.lower().strip()
    # Check if the slot value is a number
    number = text_to_num(slot_value)
    if number:
        return number
    
    slot_value = map_indiferece(slot_value)
    
    return slot_value

# Assuming train_df is your DataFrame
# Apply the post_process_slot_value function to each value in the 'Slots' column
train_df['Slots'] = train_df['Slots'].apply(post_process_slot_value)

# Now you can filter out the mismatched rows
mismatched_rows = train_df[train_df['Slots'] != train_df['Values'].apply(post_process_slot_value)]

# Assuming you want to display these mismatched rows
display(mismatched_rows)

Unnamed: 0,Slots,Values
0,center,centre
118,african food.,north african
137,same area,centre
141,doesn't matter,dontcare
198,any,dontcare
...,...,...
35065,same area,west
35160,any,dontcare
35204,any,dontcare
35243,same day,tuesday


In [60]:
map_indiferece("any")

0.5868494414004028


In [None]:
from typing import Dict, List

class DialogSlotMemory():
    slot_list_dict: Dict[str, List[str]] = {}

    def __init__(self):
        self.slot_list_dict = {}
    
    def add_slot(self, slot_name: str, slot_value: str):
        if slot_name not in self.slot_list_dict:
            self.slot_list_dict[slot_name] = []
        self.slot_list_dict[slot_name].append(slot_value)
    
    def get_slot_values(self, slot_name: str):
        return self.slot_list_dict[slot_name]

    def get_most_recent_slot_value(self, slot_name: str):
        return self.slot_list_dict[slot_name][-1] if slot_name in self.slot_list_dict else None

    def get_all_slot_values(self):
        return self.slot_list_dict

    def get_all_slot_names(self):
        return self.slot_list_dict.keys()
        
class ConversationDataset:
    id_dialog: str
    memory: DialogSlotMemory
    dataset: TensorDataset

def generate_separate_dialogue_datasets(data) -> List[ConversationDataset]:
    """
    Generates separate datasets for each dialogue in the provided data.
    
    Parameters:
    - data: list of dictionaries containing a "services" key, which is a list of services.
    
    Returns:
    - List of datasets, one for each dialogue.
    """
    datasets = []
    for dialogue in data:
        # Create a dataset for the current dialogue
        dataset = ConversationDataset()
        dataset.memory = DialogSlotMemory()
        dataset.id_dialog = dialogue['dialogue_id']
        dataset.dataset = create_dataset(toDF([dialogue]), tokenizer, label_map)
        datasets.append(dataset)
    return datasets


def remove_tokens_before_sep(ids: torch.Tensor, tokenizer: BertTokenizerFast):
    """
    Removes all the tokens before the SEP token, including the SEP token itself.

    Parameters:
    - ids: list of token IDs.
    
    Returns:
    - List of token IDs with sep and all tokens before it removed.
    """
    sep_token_id = tokenizer.sep_token_id
    sep_token_index = (ids == sep_token_id).nonzero(as_tuple=True)[0][0]
    return ids[sep_token_index+1:]

In [None]:
!pip install seqeval
from seqeval.metrics import classification_report as seqeval_classification_report
import numpy as np
import torch

# Load model from checkpoint
model.load_state_dict(torch.load('checkpoint_epoch_1.pt'))
model.eval()


# Separate the test data into separate datasets for each dialogue
test_datasets = generate_separate_dialogue_datasets(test_data_filtered)

# Reverse the label map to translate from numeric to string labels
label_map_reverse = {v: k for k, v in label_map.items()}

model.eval()
total_loss = 0
all_predictions = []
all_true_labels = []

with torch.no_grad():
    for dataset in test_datasets:
        # Get the input_ids, attention_masks and labels from the dataset
        input_ids = dataset.dataset.tensors[0]
        attention_masks = dataset.dataset.tensors[1]
        labels = dataset.dataset.tensors[2]

        # Move tensors to the correct device
        input_ids = input_ids.to(device)
        attention_masks = attention_masks.to(device)
        labels = labels.to(device)
        

        outputs = model(input_ids, token_type_ids=None, attention_mask=attention_masks, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()

        # Move logits and labels to CPU
        logits = outputs.logits.detach().cpu().numpy()
        label_ids = labels.to('cpu').numpy()

        # Convert logits to token predictions
        predictions = np.argmax(logits, axis=-1)

        # For each item in the batch...
        for i in range(input_ids.size(0)):
            # Skip predictions for tokens with label_id == -100
            pred_label_sequence = []
            true_label_sequence = []
            for j, (pred_id, label_id) in enumerate(zip(predictions[i], label_ids[i])):
                if attention_masks[i][j] != 0 and label_id != -100:
                    pred_label_sequence.append(label_map_reverse.get(pred_id, 'O'))

                    # Get the true label from the dataset
                    true_label_id = label_ids[i][j]
                    true_label_sequence.append(label_map_reverse[true_label_id])

            # Ensure the true and predicted sequences have the same length
            if len(true_label_sequence) != len(pred_label_sequence):
                print(f"Length mismatch in sequence {i}: true labels {len(true_label_sequence)} vs. predicted labels {len(pred_label_sequence)}")
                # Output the actual sequences to help diagnose the issue
                print("True labels:", true_label_sequence)
                print("Pred labels:", pred_label_sequence)
                continue
                
            # ...extend the true labels and predicted labels lists
            all_true_labels.append(true_label_sequence)
            all_predictions.append(pred_label_sequence)
            
            # Map slot values to slot names based on the predicted labels and add them to the memory
            # Skip all the tokens before (and including) the [SEP] token           
            ids = dataset.dataset.tensors[0][i][1:]
            ids = remove_tokens_before_sep(ids, tokenizer)
            for token, pred_label in zip(ids, pred_label_sequence):
                if pred_label != 'O':
                    slot_name = pred_label[2:]
                    
                    # Get the slot value
                    slot_value = tokenizer.convert_tokens_to_string([tokenizer.convert_ids_to_tokens(token.item())])                    
                    dataset.memory.add_slot(slot_name, slot_value)
            
            # Print the memory for the current dialogue
            print(f"Memory for dialogue {dataset.id_dialog}: {dataset.memory.get_all_slot_values()}")

# Calculate average loss over all the batches
avg_loss = total_loss / len(test_datasets)
print(f"Test loss: {avg_loss}")

# Use seqeval to compute a classification report
seqeval_report = seqeval_classification_report(all_true_labels, all_predictions)
print(seqeval_report)