In [None]:
import time
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertModel, AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import json
import os

# CONFIG
MODEL_NAME = 'bert-base-uncased'
MAX_LEN = 128
BATCH_SIZE = 2
EPOCHS = 9
# Label value for positions that carry no slot target (special tokens and
# non-first sub-tokens). The slot loss is built with ignore_index=PAD_TOKEN and
# evaluation skips these positions; -100 is also PyTorch's CrossEntropyLoss
# default ignore_index.
PAD_TOKEN = -100

# Synchronous kernel launches make CUDA errors surface at the failing call.
# This aids debugging but slows training. CUDA reads this variable when it
# initialises, so it is set before the first CUDA query (the device check below)
# rather than after it.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

id_run = time.strftime("%d%m%y_%H%M%S")
writer = SummaryWriter(log_dir='runs/CustomBert_' + id_run)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

def load_data(path):
    """Load and return the JSON document stored at *path*.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default locale encoding.
    """
    with open(path, encoding='utf-8') as f:
        return json.load(f)

def _as_tokens(value):
    """Return *value* as a list of tokens, splitting a whitespace-separated string.

    Iterating a raw string yields single characters, so a field stored as one
    string must be split before it is treated as a token sequence.
    """
    return value.split() if isinstance(value, str) else list(value)


def _normalize_example(item):
    """Return a copy of an example whose 'utterance' and 'slots' are token lists."""
    return {
        **item,
        'utterance': _as_tokens(item['utterance']),
        'slots': _as_tokens(item['slots']),
    }


# 'utterance' and 'slots' may be stored as space-separated strings rather than
# lists. Normalising them here keeps the slot vocabulary at the tag level.
# Iterating an unsplit string would collect characters ('O', ' ', 'B', ...) as
# labels, and word-indexed lookups (slots[word_idx]) would read characters.
# Fields that are already lists pass through unchanged.
train_dev_data = [_normalize_example(item)
                  for item in load_data(os.path.join('dataset', 'ATIS', 'train.json'))]
test_data = [_normalize_example(item)
             for item in load_data(os.path.join('dataset', 'ATIS', 'test.json'))]
data_raw = test_data + train_dev_data

# Label vocabularies are built over every split, so dev/test never contain an
# unseen intent or slot id.
intents = sorted({item['intent'] for item in data_raw})
intent2id = {label: i for i, label in enumerate(intents)}
id2intent = {i: label for label, i in intent2id.items()}

slots = sorted({slot for item in data_raw for slot in item['slots']})
slot2id = {label: i for i, label in enumerate(slots)}
id2slot = {i: label for label, i in slot2id.items()}

class ATISDataset(Dataset):
    """Map-style dataset that tokenises ATIS examples for a joint BERT model.

    Each example is read as a dict with 'utterance', 'intent' and 'slots'
    entries. The utterance is split on whitespace when it is a string. Slot tags
    are looked up by word index, so 'slots' is expected to hold one tag per
    word (NOTE(review): confirm against the data files).
    """

    def __init__(self, data, tokenizer, intent2id, slot2id, max_len=128):
        self.tokenizer = tokenizer
        self.data = data
        self.intent2id = intent2id
        self.slot2id = slot2id
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def _align_slot_labels(self, word_ids, slot_tags):
        """Give each sub-token a slot id, or PAD_TOKEN where no target applies.

        Only the first sub-token of each word receives that word's slot id.
        Special/padding tokens (word id None) and continuation sub-tokens get
        PAD_TOKEN so the loss and metrics ignore them.
        """
        aligned = []
        last_word = None
        for word_idx in word_ids:
            is_word_start = word_idx is not None and word_idx != last_word
            aligned.append(self.slot2id[slot_tags[word_idx]] if is_word_start else PAD_TOKEN)
            last_word = word_idx
        return aligned

    def __getitem__(self, idx):
        example = self.data[idx]

        words = example['utterance']
        if isinstance(words, str):
            words = words.split()

        intent_name = example['intent']
        slot_tags = example['slots']

        # Offsets are requested from the tokenizer but not used below.
        encoded = self.tokenizer(
            words,
            is_split_into_words=True,
            return_offsets_mapping=True,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
        )
        slot_ids = self._align_slot_labels(encoded.word_ids(), slot_tags)

        return {
            'input_ids': torch.tensor(encoded['input_ids']),
            'attention_mask': torch.tensor(encoded['attention_mask']),
            'token_type_ids': torch.tensor(encoded['token_type_ids']),
            'intent_label': torch.tensor(self.intent2id[intent_name]),
            'slot_labels': torch.tensor(slot_ids),
        }


class JointBERT(nn.Module):
    """Pretrained BERT encoder with two linear heads: intent and per-token slot.

    The intent head reads the pooled output. The slot head reads the
    last-layer hidden state of every token.
    """

    def __init__(self, model_name, num_intents, num_slots):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        self.intent_classifier = nn.Linear(hidden_size, num_intents)
        self.slot_classifier = nn.Linear(hidden_size, num_slots)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return (intent_logits, slot_logits) for a batch of encoded inputs."""
        encoder_out = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        intent_logits = self.intent_classifier(encoder_out.pooler_output)
        slot_logits = self.slot_classifier(encoder_out.last_hidden_state)
        return intent_logits, slot_logits

def _bio_chunks(tags):
    """Return the set of (start, end, type) chunks in one BIO-tagged sequence.

    Follows the CoNLL conlleval convention:
    - A chunk starts at any non-'O' tag that does not continue the current chunk.
    - An 'I-X' following 'O' or a different type therefore opens a new chunk.
    - A chunk ends before the next tag that is not 'I-' of the same type.
    - Tags without a '-' prefix (other than 'O') are treated as chunk starts.
    """
    found = set()
    start, current = None, None
    # The trailing 'O' flushes a chunk that runs to the end of the sequence.
    for i, tag in enumerate([*tags, 'O']):
        if tag == 'O':
            prefix, label = 'O', None
        elif '-' in tag:
            prefix, label = tag.split('-', 1)
        else:
            prefix, label = 'B', tag
        if current is not None and (prefix != 'I' or label != current):
            found.add((start, i, current))
            current = None
        if prefix != 'O' and current is None:
            start, current = i, label
    return found


def compute_slot_f1(preds, trues):
    """Return the micro-averaged, span-level F1 between predicted and gold slot tags.

    Args:
        preds: sequence of sentences, each a sequence of predicted BIO tag
            strings (e.g. "O", "B-fromloc.city_name", "I-fromloc.city_name").
        trues: gold tags in the same layout; sentences are paired with
            ``preds`` positionally.

    Returns:
        float: F1 over exact chunk matches (same boundaries and type) pooled
        across all sentences. Returns 0.0 when precision + recall is 0, e.g.
        when there are no chunks at all.

    'O' tokens form no chunk and do not count. A token-level score would be
    dominated by 'O' and overstate slot-filling quality. The score is
    symmetric in its two arguments.
    """
    correct = total_pred = total_true = 0
    for p_seq, t_seq in zip(preds, trues):
        p_chunks = _bio_chunks(p_seq)
        t_chunks = _bio_chunks(t_seq)
        correct += len(p_chunks & t_chunks)
        total_pred += len(p_chunks)
        total_true += len(t_chunks)
    precision = correct / total_pred if total_pred else 0.0
    recall = correct / total_true if total_true else 0.0
    return 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

# Prepare data: hold out 10% of the training file as a dev split.
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
train_data, dev_data = train_test_split(train_dev_data, test_size=0.1, random_state=42)

# One dataset per split, all sharing the tokenizer and label maps.
train_dataset, dev_dataset, test_dataset = (
    ATISDataset(split, tokenizer, intent2id, slot2id, MAX_LEN)
    for split in (train_data, dev_data, test_data)
)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
dev_loader, test_loader = (
    DataLoader(split_ds, batch_size=BATCH_SIZE) for split_ds in (dev_dataset, test_dataset)
)

# Initialize model and loss
model = JointBERT(MODEL_NAME, len(intent2id), len(slot2id)).to(device)
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are pinned to transformers.AdamW's defaults (1e-6, 0.0) so the
# update rule matches the original setup; torch's own defaults are 1e-8 / 1e-2.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)
intent_loss_fn = nn.CrossEntropyLoss()
# Positions labelled PAD_TOKEN (special tokens, non-first sub-tokens) are ignored.
slot_loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)

def train_loop(model, train_loader, optimizer, intent_loss_fn, slot_loss_fn, epoch):
    """Run one training epoch and return the mean per-batch loss.

    Each batch's loss is the intent cross-entropy plus the slot cross-entropy.
    The printed and TensorBoard-logged value ('Loss/train') is averaged over
    batches. A plain sum grows with the number of batches, so it is not
    comparable across batch sizes or dataset sizes.

    Args:
        model: module returning (intent_logits, slot_logits).
        train_loader: iterable of dict batches with 'input_ids',
            'attention_mask', 'token_type_ids', 'intent_label', 'slot_labels'.
        optimizer: optimizer over ``model``'s parameters.
        intent_loss_fn: loss applied to intent logits vs. intent labels.
        slot_loss_fn: loss applied to flattened slot logits vs. slot labels.
        epoch: zero-based epoch index (displayed 1-based, used as the
            TensorBoard step).

    Returns:
        float: mean batch loss for the epoch (0.0 if the loader is empty).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} - Training", ncols=100):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        intent_labels = batch['intent_label'].to(device)
        slot_labels = batch['slot_labels'].to(device)

        optimizer.zero_grad()
        intent_logits, slot_logits = model(input_ids, attention_mask, token_type_ids)

        intent_loss = intent_loss_fn(intent_logits, intent_labels)
        # Flatten every token position into one row of slot scores. The class
        # dimension is taken from the logits themselves rather than a global.
        slot_loss = slot_loss_fn(slot_logits.view(-1, slot_logits.size(-1)),
                                 slot_labels.view(-1))
        loss = intent_loss + slot_loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    mean_loss = total_loss / num_batches if num_batches else 0.0
    print(f"Epoch {epoch+1} - Loss: {mean_loss:.4f}")
    writer.add_scalar('Loss/train', mean_loss, epoch)
    return mean_loss


# Evaluation helpers shared by the dev and test loops.
def _collect_predictions(model, loader, desc):
    """Run *model* over *loader* without gradients and gather predictions.

    Returns:
        tuple: (intent_preds, intent_true, slot_preds, slot_true).
        - The intent lists hold label ids.
        - The slot lists hold one list of slot-label strings per sentence.
        - Slot lists keep only positions whose gold label is not PAD_TOKEN.
    """
    model.eval()
    intent_preds, intent_true = [], []
    slot_preds, slot_true = [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc=desc, ncols=100):
            intent_logits, slot_logits = model(
                batch['input_ids'].to(device),
                batch['attention_mask'].to(device),
                batch['token_type_ids'].to(device),
            )

            intent_preds.extend(torch.argmax(intent_logits, dim=1).tolist())
            intent_true.extend(batch['intent_label'].tolist())

            # Convert predictions and gold labels to Python lists once per
            # batch. Indexing device tensors element by element forces a
            # host/device synchronisation per element.
            pred_rows = torch.argmax(slot_logits, dim=2).tolist()
            gold_rows = batch['slot_labels'].tolist()
            for pred_row, gold_row in zip(pred_rows, gold_rows):
                kept = [(p, g) for p, g in zip(pred_row, gold_row) if g != PAD_TOKEN]
                slot_true.append([id2slot[g] for _, g in kept])
                slot_preds.append([id2slot[p] for p, _ in kept])

    return intent_preds, intent_true, slot_preds, slot_true


def _score_and_print(intent_preds, intent_true, slot_preds, slot_true):
    """Compute and print (intent accuracy, slot F1), then return them."""
    intent_acc = accuracy_score(intent_true, intent_preds)
    # compute_slot_f1 takes (preds, trues).
    slot_f1 = compute_slot_f1(slot_preds, slot_true)
    print(f"Intent Accuracy: {intent_acc:.4f} | Slot F1: {slot_f1:.4f}")
    return intent_acc, slot_f1


def eval_loop(model, dev_loader, epoch=None, tag="dev"):
    """Evaluate *model* on *dev_loader*, print metrics and log them to TensorBoard.

    Args:
        model: module returning (intent_logits, slot_logits).
        dev_loader: DataLoader yielding ATISDataset-style batches.
        epoch: zero-based epoch index. It labels the progress bar and is the
            TensorBoard step. When None, TensorBoard logging is skipped, so the
            function can also be used for a one-off evaluation.
        tag: prefix for the progress-bar label and the TensorBoard scalar names.

    Returns:
        tuple[float, float]: (intent accuracy, slot F1).
    """
    if epoch is None:
        desc = f"Evaluating {tag}"
    else:
        desc = f"Epoch {epoch+1} - Evaluating {tag}"

    intent_acc, slot_f1 = _score_and_print(*_collect_predictions(model, dev_loader, desc))

    if epoch is not None:
        writer.add_scalar(f"{tag}/Intent Accuracy", intent_acc, epoch)
        writer.add_scalar(f"{tag}/Slot F1", slot_f1, epoch)
    return intent_acc, slot_f1


def test_loop(model, test_loader, tag="test"):
    """Evaluate *model* on *test_loader* and print metrics (no TensorBoard logging).

    Returns:
        tuple[float, float]: (intent accuracy, slot F1).
    """
    return _score_and_print(*_collect_predictions(model, test_loader, f"Evaluating {tag}"))

# Train the model, evaluating on the dev split after every third epoch.
for epoch in range(EPOCHS):
    train_loop(model, train_loader, optimizer, intent_loss_fn, slot_loss_fn, epoch)
    is_eval_epoch = (epoch + 1) % 3 == 0
    if is_eval_epoch:
        eval_loop(model, dev_loader, epoch, tag="dev")

# Final test evaluation on the held-out test split.
test_loop(model, test_loader, tag="test")

# Persist only the weights (state_dict) under a run-specific file name.
checkpoint_path = f"model_bin/joint_bert_model_{id_run}.pt"
os.makedirs("model_bin", exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)
print(f"Model saved: {checkpoint_path}")
writer.close()


Epoch 1 - Training: 100%|███████████████████████████████████████| 2240/2240 [08:18<00:00,  4.50it/s]


Epoch 1 - Loss: 3045.8029


Epoch 2 - Training: 100%|███████████████████████████████████████| 2240/2240 [06:22<00:00,  5.86it/s]


Epoch 2 - Loss: 1289.5969


Epoch 3 - Training: 100%|███████████████████████████████████████| 2240/2240 [05:54<00:00,  6.32it/s]


Epoch 3 - Loss: 908.9875


Epoch 3 - Evaluating dev: 100%|███████████████████████████████████| 249/249 [00:10<00:00, 23.90it/s]


Intent Accuracy: 0.9639 | Slot F1: 0.9172


Epoch 4 - Training: 100%|███████████████████████████████████████| 2240/2240 [05:54<00:00,  6.32it/s]


Epoch 4 - Loss: 707.9372


Epoch 5 - Training: 100%|███████████████████████████████████████| 2240/2240 [05:54<00:00,  6.32it/s]


Epoch 5 - Loss: 621.4571


Epoch 6 - Training: 100%|███████████████████████████████████████| 2240/2240 [08:04<00:00,  4.62it/s]


Epoch 6 - Loss: 536.1164


Epoch 6 - Evaluating dev: 100%|███████████████████████████████████| 249/249 [00:14<00:00, 16.79it/s]


Intent Accuracy: 0.9759 | Slot F1: 0.9461


Epoch 7 - Training: 100%|███████████████████████████████████████| 2240/2240 [06:37<00:00,  5.64it/s]


Epoch 7 - Loss: 514.2330


Epoch 8 - Training: 100%|███████████████████████████████████████| 2240/2240 [05:54<00:00,  6.32it/s]


Epoch 8 - Loss: 502.8183


Epoch 9 - Training: 100%|███████████████████████████████████████| 2240/2240 [05:54<00:00,  6.32it/s]


Epoch 9 - Loss: 344.0302


Epoch 9 - Evaluating dev: 100%|███████████████████████████████████| 249/249 [00:10<00:00, 23.97it/s]


Intent Accuracy: 0.9900 | Slot F1: 0.9548


TypeError: eval_loop() missing 1 required positional argument: 'epoch'