In [1]:

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import json

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_df = pd.read_csv("train.csv")
test_df = pd.read_csv("test.csv")


train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42, shuffle=True)


all_tokens = []
for sentence in train_df['text']:
    all_tokens.extend(sentence.split())
vocab = sorted(set(all_tokens))
word2idx = {"<pad>":0, "<unk>":1}
for i, w in enumerate(vocab, start=2):
    word2idx[w] = i

with open("word2idx.json", "w") as f:
    json.dump(word2idx, f)


all_slots = set()
for slots in train_df['slots']:
    all_slots.update(slots.split())
slot2id = {"O":0}
i = 1
for s in all_slots:
    if s != "O":
        slot2id[s] = i
        i +=1
with open("slot2id.json", "w") as f:
    json.dump(slot2id, f)

all_intents = train_df['intent'].unique()
intent2id = {intent:i for i,intent in enumerate(all_intents)}
with open("intent2id.json", "w") as f:
    json.dump(intent2id, f)


id2slot = {v:k for k,v in slot2id.items()}
id2intent = {v:k for k,v in intent2id.items()}

class ATISDataset(Dataset):
    def __init__(self, df, word2idx, slot2id, intent2id, max_len=50):
        self.df = df
        self.word2idx = word2idx
        self.slot2id = slot2id
        self.intent2id = intent2id
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        tokens = self.df.iloc[idx]['text'].split()
        slots = self.df.iloc[idx]['slots'].split()
        intent = self.df.iloc[idx]['intent']

        length = len(tokens)
        input_ids = [self.word2idx.get(t, self.word2idx["<unk>"]) for t in tokens]
        slot_ids = [self.slot2id.get(s, 0) for s in slots]  # default O=0
        intent_id = self.intent2id[intent]

        pad_len = self.max_len - length
        input_ids += [self.word2idx["<pad>"]] * pad_len
        slot_ids += [0] * pad_len  

        return torch.tensor(input_ids), torch.tensor(length), torch.tensor(slot_ids), torch.tensor(intent_id)


def collate_fn(batch):
    input_ids = torch.stack([item[0] for item in batch])
    lengths = torch.tensor([item[1] for item in batch])
    slot_labels = torch.stack([item[2] for item in batch])
    intent_labels = torch.stack([item[3] for item in batch])
    return input_ids, lengths, slot_labels, intent_labels


batch_size = 32
train_dataset = ATISDataset(train_df, word2idx, slot2id, intent2id)
val_dataset = ATISDataset(val_df, word2idx, slot2id, intent2id)
test_dataset = ATISDataset(test_df, word2idx, slot2id, intent2id)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)


class JointLSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, slot_label_size, intent_label_size, dropout=0.3):
        super(JointLSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.encoder = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.slot_classifier = nn.Linear(hidden_dim*2, slot_label_size)
        self.intent_classifier = nn.Linear(hidden_dim*2, intent_label_size)

    def forward(self, input_ids, lengths):
        embeddings = self.embedding(input_ids)
        packed = nn.utils.rnn.pack_padded_sequence(embeddings, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, (hidden, cell) = self.encoder(packed)
        sequence_output, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, total_length=input_ids.size(1))
        sequence_output = self.dropout(sequence_output)
        slot_logits = self.slot_classifier(sequence_output)
        hidden_cat = torch.cat((hidden[-2], hidden[-1]), dim=1)
        intent_logits = self.intent_classifier(hidden_cat)
        return slot_logits, intent_logits


vocab_size = len(word2idx)
embedding_dim = 100
hidden_dim = 128
slot_label_size = len(slot2id)
intent_label_size = len(intent2id)

model = JointLSTMModel(vocab_size, embedding_dim, hidden_dim, slot_label_size, intent_label_size).to(device)
slot_loss_fn = nn.CrossEntropyLoss(ignore_index=0)
intent_loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


def train(model, train_loader, val_loader, optimizer, epochs=10, lambda_intent=0.5):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for input_ids, lengths, slot_labels, intent_labels in train_loader:
            input_ids, lengths = input_ids.to(device), lengths.to(device)
            slot_labels, intent_labels = slot_labels.to(device), intent_labels.to(device)

            optimizer.zero_grad()
            slot_logits, intent_logits = model(input_ids, lengths)

            slot_loss = slot_loss_fn(slot_logits.view(-1, slot_logits.shape[-1]), slot_labels.view(-1))
            intent_loss = intent_loss_fn(intent_logits, intent_labels)
            loss = slot_loss + lambda_intent * intent_loss

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f}")
        evaluate(model, val_loader)

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def evaluate(model, loader):
    model.eval()
    all_intent_preds, all_intent_labels = [], []
    all_slot_preds, all_slot_labels = [], []

    with torch.no_grad():
        for input_ids, lengths, slot_labels, intent_labels in loader:
            input_ids, lengths = input_ids.to(device), lengths.to(device)
            slot_labels, intent_labels = slot_labels.to(device), intent_labels.to(device)

            slot_logits, intent_logits = model(input_ids, lengths)
            intent_preds = torch.argmax(intent_logits, dim=1)
            all_intent_preds.extend(intent_preds.cpu().tolist())
            all_intent_labels.extend(intent_labels.cpu().tolist())

            slot_preds = torch.argmax(slot_logits, dim=2)
            for i, l in enumerate(lengths):
                all_slot_preds.extend(slot_preds[i][:l].cpu().tolist())
                all_slot_labels.extend(slot_labels[i][:l].cpu().tolist())

    intent_acc = accuracy_score(all_intent_labels, all_intent_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_slot_labels, all_slot_preds, average='micro')
    print(f"Intent Acc: {intent_acc:.4f} | Slot P: {precision:.4f} R: {recall:.4f} F1: {f1:.4f}")
    model.train()
    return intent_acc, precision, recall, f1

train(model, train_loader, val_loader, optimizer, epochs=100)


print("=== Test Set Evaluation ===")
evaluate(model, test_loader)


Epoch 1 | Loss: 2.2237
Intent Acc: 0.9579 | Slot P: 0.3011 R: 0.3011 F1: 0.3011
Epoch 2 | Loss: 0.6364
Intent Acc: 0.9823 | Slot P: 0.3308 R: 0.3308 F1: 0.3308
Epoch 3 | Loss: 0.3131
Intent Acc: 0.9845 | Slot P: 0.3432 R: 0.3432 F1: 0.3432
Epoch 4 | Loss: 0.1896
Intent Acc: 0.9911 | Slot P: 0.3494 R: 0.3494 F1: 0.3494
Epoch 5 | Loss: 0.1183
Intent Acc: 0.9911 | Slot P: 0.3526 R: 0.3526 F1: 0.3526
Epoch 6 | Loss: 0.0756
Intent Acc: 0.9867 | Slot P: 0.3543 R: 0.3543 F1: 0.3543
Epoch 7 | Loss: 0.0527
Intent Acc: 0.9933 | Slot P: 0.3547 R: 0.3547 F1: 0.3547
Epoch 8 | Loss: 0.0370
Intent Acc: 0.9933 | Slot P: 0.3547 R: 0.3547 F1: 0.3547
Epoch 9 | Loss: 0.0262
Intent Acc: 0.9933 | Slot P: 0.3549 R: 0.3549 F1: 0.3549
Epoch 10 | Loss: 0.0202
Intent Acc: 0.9933 | Slot P: 0.3549 R: 0.3549 F1: 0.3549
Epoch 11 | Loss: 0.0166
Intent Acc: 0.9933 | Slot P: 0.3551 R: 0.3551 F1: 0.3551
Epoch 12 | Loss: 0.0131
Intent Acc: 0.9933 | Slot P: 0.3543 R: 0.3543 F1: 0.3543
Epoch 13 | Loss: 0.0126
Intent Acc: 0

(0.9893899204244032,
 0.40022463496817673,
 0.40022463496817673,
 0.40022463496817673)