In [1]:
!pip install googletrans==3.1.0-alpha
!pip install langdetect
!pip install transformers

Collecting googletrans==3.1.0-alpha
  Downloading googletrans-3.1.0a0.tar.gz (19 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting httpx==0.13.3 (from googletrans==3.1.0-alpha)
  Downloading httpx-0.13.3-py3-none-any.whl (55 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.1/55.1 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
Collecting hstspreload (from httpx==0.13.3->googletrans==3.1.0-alpha)
  Downloading hstspreload-2023.1.1-py3-none-any.whl (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m23.8 MB/s[0m eta [36m0:00:00[0m
Collecting chardet==3.* (from httpx==0.13.3->googletrans==3.1.0-alpha)
  Downloading chardet-3.0.4-py2.py3-none-any.whl (133 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m133.4/133.4 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting idna==2.* (from httpx==0.13.3->googletrans==3.1.0-alpha)
  Downloading idna-2.10-py2.py3-none-any.whl (58 kB)
[2K

In [2]:
import pandas as pd
import json
import torch
from torch.utils.data import Dataset
import string
from googletrans import Translator
from langdetect import detect


CLASSES_OF_INTEREST = [
    'translate',
    'travel_alert',
    'flight_status',
    'lost_luggage', # /!\ to this class
    'travel_suggestion',
    'carry_on',
    'book_flight',
    'book_hotel',
    'oos', # might be removed because a binary classification problem
]

class CLNIC150(Dataset):
    def __init__(self, path, set = 'train'):
        super().__init__()
        self.set = set
        self.path = path
        self.prompts, self.intents = self._read_clinc()

    def __len__(self):
        return len(self.prompt)

    def __getitem__(self, idx):
        return self.prompt[idx], self.intent[idx]

    def _read_clinc(self):

        #Load CLINC150 dataset from JSON file
        #Json from https://github.com/clinc/oos-eval/tree/master/data

        with open(self.path, 'r') as f:
            data = json.load(f)

        data = data[self.set] + data['oos_' + self.set]

        prompts = []
        intents = []

        for row in data:
            prompts.append(row[0])
            intents.append(row[1])

        return prompts, intents

    def _get_classes_of_interest(self, classes_of_interest: list = CLASSES_OF_INTEREST):
        #corpus.intent is a list
        interest_index = []
        for i, intent in enumerate(self.intents):
            if intent in classes_of_interest:
                interest_index.append(i)

        self.intents = [self.intents[i] for i in interest_index]
        self.prompts = [self.prompts[i] for i in interest_index]
        return self.prompts, self.intents

class BertDataset(Dataset):
    def __init__(self, prompts, intents, tokenizer, max_length):
        super().__init__()
        #Preprocess the inputs prompts
        self.prompts = [self._preprocess(prompt) for prompt in prompts]

        #Convert intents to numeric labels
        self.labels_dict = {label: i for i, label in enumerate(set(intents))}
        self.intents = [self.labels_dict[intent] for intent in intents]

        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        prompt = self.prompts[idx]
        intent = self.intents[idx]

        encoding = self.tokenizer(
            prompt,
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            'prompt': prompt,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'intent': torch.tensor(intent, dtype=torch.long)
        }

    def _lower(self, text):
        return text.lower()

    def _remove_punctuation(self, text):
        table = str.maketrans('', '', string.punctuation)
        return text.translate(table)

    def _translate(self, text):
        lang = detect(text)
        if lang != 'en':
            translator = Translator()
            text = translator.translate(text, dest='en', str = 'auto').text
        return text

    def _preprocess(self, text):
        text = self._translate(text)
        text = self._lower(text)
        text = self._remove_punctuation(text)
        return text

In [3]:
from torch import nn
from transformers import BertTokenizer, BertModel

class BERTClassifier(nn.Module):
    def __init__(self, bert_model_name, num_classes, freeze_bert = True):
        super(BERTClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.1)
        #Add a fully-connected layer to the bert model for classification
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

        if freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False

    def forward(self, input_ids, attention_mask):
        output = self.bert(input_ids, attention_mask=attention_mask)
        output = self.dropout(output.pooler_output)
        logits = self.fc(output)
        return logits

In [5]:
import torch
import torch.nn as nn
from transformers import BertModel
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score, recall_score
import matplotlib.pyplot as plt

def train(model, train_dataset, val_dataset, device, batch_size, epochs, lr, patience):
    # Set up optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Set up data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize lists to store metrics
    train_losses = []
    train_f1s = []
    train_recalls = []
    val_losses = []
    val_f1s = []
    val_recalls = []

    best_val_loss = float('inf')
    counter = 0

    # Train the model
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        train_f1 = 0
        train_recall = 0
        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['intent'].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            preds = outputs.argmax(1).cpu()
            labels = labels.cpu()

            train_loss += loss.item()
            train_f1 += f1_score(labels, preds, average='weighted', zero_division=0)
            train_recall += recall_score(labels, preds, average='weighted', zero_division=0)

        train_loss /= len(train_loader)
        train_f1 /= len(train_loader)
        train_recall /= len(train_loader)

        train_losses.append(train_loss)
        train_f1s.append(train_f1)
        train_recalls.append(train_recall)

        # Evaluate the model on the val set
        model.eval()
        val_loss = 0
        val_f1 = 0
        val_recall = 0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['intent'].to(device)

                outputs = model(input_ids, attention_mask)
                preds = outputs.argmax(1).cpu()
                labels = labels.cpu()

                val_loss += loss.item()
                val_f1 += f1_score(labels, preds, average='weighted', zero_division=0)
                val_recall += recall_score(labels, preds, average='weighted', zero_division=0)

        val_loss /= len(val_loader)
        val_f1 /= len(val_loader)
        val_recall /= len(val_loader)

        val_losses.append(val_loss)
        val_f1s.append(val_f1)
        val_recalls.append(val_recall)

        print(f'Epoch {epoch+1}/{epochs}: Train Loss: {train_loss:.4f}, Train F1: {train_f1:.4f}, Train Recall: {train_recall:.4f}, Val Loss: {val_loss:.4f}, Val F1: {val_f1:.4f}, Val Recall: {val_recall:.4f}')

        # Check if early stopping conditions are met
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pt')
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print('Early stopping triggered.')
                break

    # Plot training and validation metrics
    plt.figure(figsize=(12, 8))
    plt.subplot(3, 1, 1)
    plt.plot(range(epochs), train_losses, label='Train Loss')
    plt.plot(range(epochs), val_losses, label='Val Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(3, 1, 2)
    plt.plot(range(epochs), train_f1s, label='Train F1 Score')
    plt.plot(range(epochs), val_f1s, label='Val F1 Score')
    plt.xlabel('Epochs')
    plt.ylabel('F1 Score')
    plt.legend()

    plt.subplot(3, 1, 3)
    plt.plot(range(epochs), train_recalls, label='Train Recall')
    plt.plot(range(epochs), val_recalls, label='Val Recall')
    plt.xlabel('Epochs')
    plt.ylabel('Recall')
    plt.legend()

    plt.tight_layout()
    plt.show()

In [6]:
path = '/content/data_full.json'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

clinc_train = CLNIC150(path, set= 'train')
x_train, y_train = clinc_train._get_classes_of_interest(classes_of_interest= CLASSES_OF_INTEREST)

clinc_val = CLNIC150(path, set= 'val')
x_val, y_val = clinc_val._get_classes_of_interest(classes_of_interest= CLASSES_OF_INTEREST)

bert_model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model_name)

train_dataset = BertDataset(x_train, y_train, tokenizer, max_length= 128)
val_dataset = BertDataset(x_val, y_val, tokenizer, max_length= 128)


Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [8]:
train_dataset.labels_dict

{'travel_alert': 0,
 'travel_suggestion': 1,
 'carry_on': 2,
 'book_hotel': 3,
 'oos': 4,
 'translate': 5,
 'flight_status': 6,
 'lost_luggage': 7,
 'book_flight': 8}

In [None]:
train(BERTClassifier(bert_model_name, num_classes= 9).to(device),
      train_dataset, val_dataset, device,
      batch_size = 32, epochs =1, lr = 2e-4, patience = 10)