In [1]:
import json
import pandas as pd
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader

with open('Synth_data/normalised_intent_validation_slotfixed_set.json', 'r') as file:
    data = json.load(file)

# Extract utterances and their associated domain (hotel or train)
utterances = []
labels = []

for scenario in data:
    for dialogue in scenario['scenarios']:
        dialogue_service = None
        for turn in dialogue['turns']:
            if turn['frames']:  # Check if there are any frames to define service
                dialogue_service = turn['frames'][0]['service']  # Take the service from the first frame
            if turn["speaker"] == "USER":  # Assuming user turns are from "participant1"
                utterances.append(turn['utterance'])
                # label the turn based on the service
                if dialogue_service in ["hotel", "train"]:
                    labels.append(dialogue_service)
                else:
                    # 'other' if the service is not one of the target domains
                    labels.append("other")

# Filter out 'other' labels 
filtered_utterances_validation = [utterance for utterance, label in zip(utterances, labels) if label in ["hotel", "train"]]
filtered_labels_validation = [label for label in labels if label in ["hotel", "train"]]

# Create a DataFrame
df_validated = pd.DataFrame({'utterance': filtered_utterances_validation, 'domain': filtered_labels_validation})


In [2]:
# Load the dataset
with open('Synth_data/normalised_intent_train_slotfixed_set.json', 'r') as file:
    data = json.load(file)

# Extract utterances and their associated domain (hotel or train)
utterances = []
labels = []

for scenario in data:
    for dialogue in scenario['scenarios']:
        dialogue_service = None
        for turn in dialogue['turns']:
            if turn['frames']:  # Check if there are any frames to define service
                dialogue_service = turn['frames'][0]['service']  # Take the service from the first frame
            if turn["speaker"] == "USER":  # Assuming user turns are from "participant1"
                utterances.append(turn['utterance'])
                # label the turn based on the service
                if dialogue_service in ["hotel", "train"]:
                    labels.append(dialogue_service)
                else:
                    # 'other' if the service is not one of the target domains
                    labels.append("other")

# Filter out 'other' labels 
filtered_utterances_train = [utterance for utterance, label in zip(utterances, labels) if label in ["hotel", "train"]]
filtered_labels_train = [label for label in labels if label in ["hotel", "train"]]

# Create a DataFrame
df_train = pd.DataFrame({'utterance': filtered_utterances_train, 'domain': filtered_labels_train})

In [3]:
# Load the dataset
with open('Synth_data/normalised_intent_test_slotfixed_set.json', 'r') as file:
    data = json.load(file)

# Extract utterances and their associated domain (hotel or train)
utterances = []
labels = []

for scenario in data:
    for dialogue in scenario['scenarios']:
        dialogue_service = None
        for turn in dialogue['turns']:
            if turn['frames']:  # Check if there are any frames to define service
                dialogue_service = turn['frames'][0]['service']  # Take the service from the first frame
            if turn["speaker"] == "USER":  # Assuming user turns are from "participant1"
                utterances.append(turn['utterance'])
                # label the turn based on the service
                if dialogue_service in ["hotel", "train"]:
                    labels.append(dialogue_service)
                else:
                    # 'other' if the service is not one of the target domains
                    labels.append("other")

# Filter out 'other' labels 
filtered_utterances_test = [utterance for utterance, label in zip(utterances, labels) if label in ["hotel", "train"]]
filtered_labels_test = [label for label in labels if label in ["hotel", "train"]]

# Create a DataFrame
df_test = pd.DataFrame({'utterance': filtered_utterances_test, 'domain': filtered_labels_test})

In [4]:
print(df_train)

                                              utterance domain
0     Can you tell me the train schedule and how muc...  train
1     Yes, I would like to purchase a ticket for the...  train
2     No, that's all I needed. Thank you for your help!  train
3     Actually, I just remembered, can I add a retur...  train
4     Can you tell me the train schedule and how to ...  train
...                                                 ...    ...
2937  No, that's all for now. Thank you for the info...  hotel
2938  Thank you! I have a reservation and would like...  hotel
2939  Yes, please. Could you give me more details ab...  hotel
2940  Yes, that would be great. Could you arrange a ...  hotel
2941  9 AM would be perfect. Thank you for arranging...  hotel

[2942 rows x 2 columns]


In [5]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir='BERT_cache_folder')

def encode_data(tokenizer, texts, max_length=128):
    return tokenizer(texts, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')

# Encode utterances
encoded_inputs_validate = encode_data(tokenizer, df_validated['utterance'].tolist())


In [6]:
encoded_inputs_test = encode_data(tokenizer, df_test['utterance'].tolist())
encoded_inputs_train = encode_data(tokenizer, df_train['utterance'].tolist())

In [7]:
# Unique domain mapping to integers
domain_labels = {domain: idx for idx, domain in enumerate(df_validated['domain'].unique())}
df_validated['label'] = df_validated['domain'].map(domain_labels)

# Prepare labels
labels = df_validated['label'].values


In [8]:
class UtteranceDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

dataset_validate = UtteranceDataset(encoded_inputs_validate, labels)
dataset_test = UtteranceDataset(encoded_inputs_test, labels)
dataset_train = UtteranceDataset(encoded_inputs_train, labels)


In [9]:
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("CUDA is not available, using CPU instead.")


CUDA is available. Using GPU: NVIDIA GeForce GTX 1080 Ti


In [10]:
from transformers import TrainerCallback, TrainerState, TrainerControl
from tqdm.auto import tqdm

class PrintCallback(TrainerCallback):
    def __init__(self):
        self.progress_bar = None

    def on_train_begin(self, args, state, control, **kwargs):
        print("Starting training...")
        self.progress_bar = tqdm(total=state.num_train_epochs)

    def on_epoch_begin(self, args, state, control, **kwargs):
        print(f"\nEpoch {state.epoch + 1}/{state.num_train_epochs}")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            if 'loss' in logs:
                print(f"  Training loss: {logs['loss']:.4f}")
            if 'eval_loss' in logs:
                print(f"  Validation loss: {logs['eval_loss']:.4f}")
            if 'eval_accuracy' in logs:
                print(f"  Validation accuracy: {logs['eval_accuracy']:.4f}")

    def on_epoch_end(self, args, state, control, **kwargs):
        self.progress_bar.update(1)
        if 'eval_loss' in logs:
            print(f"End of epoch {state.epoch + 1}. Evaluation loss: {logs['eval_loss']:.4f}")
        if 'eval_accuracy' in logs:
            print(f"  Validation accuracy: {logs['eval_accuracy']:.4f}")

    def on_train_end(self, args, state, control, **kwargs):
        self.progress_bar.close()
        print("Training completed.")


In [11]:
import torch

print("Is CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)

if torch.cuda.is_available():
    print("Number of CUDA devices:", torch.cuda.device_count())
    for i in range(torch.cuda.device_count()):
        print(f"Device {i}: {torch.cuda.get_device_name(i)}")


Is CUDA available: True
CUDA version: 12.1
Number of CUDA devices: 4
Device 0: NVIDIA GeForce GTX 1080 Ti
Device 1: NVIDIA GeForce GTX 1080 Ti
Device 2: NVIDIA GeForce GTX 1080 Ti
Device 3: NVIDIA GeForce GTX 1080 Ti


In [12]:
if torch.cuda.is_available():
    torch.cuda.set_device(0)  # Replace 0 with the device number you want to use
    print(f"Default CUDA device set: {torch.cuda.current_device()}")


Default CUDA device set: 0


In [13]:
!nvidia-smi

Wed May  1 12:46:56 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.98                 Driver Version: 535.98       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce GTX 1080 Ti     Off | 00000000:03:00.0 Off |                  N/A |
| 51%   62C    P8              19W / 250W |   3554MiB / 11264MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce GTX 1080 Ti     Off | 00000000:04:00.0 Off |  

In [17]:
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(domain_labels))

model.to(device) 


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [22]:
import torch
from torch.utils.data import DataLoader
from transformers import get_scheduler, BertForSequenceClassification, BertTokenizer
from torch.optim import AdamW
from tqdm import tqdm
import numpy as np
import math

# Assuming model, train_dataset, validation_dataset, and tokenizer are predefined
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Specify the number of classes explicitly
num_classes = 2  # Since you have 'hotel' and 'train'

# Convert labels to a NumPy array for easier manipulation
labels_np = np.array(filtered_labels_train)

# Calculate class counts
num_hotel = np.sum(labels_np == 'hotel')
num_train = np.sum(labels_np == 'train')

# Calculate class weights
class_counts = [num_hotel, num_train]
total_counts = sum(class_counts)
# Existing weight calculation
class_weights = torch.tensor([total_counts / class_count for class_count in class_counts], dtype=torch.float).to(device)

# Let's say you want to give more weight to 'hotel' which is the first class in your list
# You can multiply the weight for 'hotel' by a factor, e.g., 2, 3, or any factor you deem appropriate
factor = 3  # Experiment with this factor
class_weights[0] *= factor

print("Adjusted Class Weights:", class_weights)

# Convert datasets to DataLoader
train_loader = DataLoader(dataset_train, batch_size=32, shuffle=True)
validation_loader = DataLoader(dataset_validate, batch_size=32)

# Setting up the optimizer
optimizer = AdamW(model.parameters(), lr=0.001)

# Total number of training steps
num_training_steps = len(train_loader) * 10  # Adjust epochs as needed

# Create the learning rate scheduler.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=num_training_steps
)

# Best validation loss tracking
best_validation_loss = float('inf')
model_path = "./model_save_domain_synth_temp"

# Training loop
progress_bar = tqdm(range(num_training_steps))
model.train()

for epoch in range(10):  # loop over the dataset multiple times
    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        logits = outputs.logits
        labels = batch['labels']
        loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights)
        loss = loss_fct(logits.view(-1, num_classes), labels.view(-1))

        # Backpropagation
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

    # Validation step at the end of each epoch
    model.eval()
    total_eval_loss = 0
    with torch.no_grad():
        for batch in validation_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            logits = outputs.logits
            labels = batch['labels']
            loss = loss_fct(logits.view(-1, num_classes), labels.view(-1))
            total_eval_loss += loss.item()

    avg_val_loss = total_eval_loss / len(validation_loader)
    print(f"Validation Loss: {avg_val_loss}")

    # Check if this is the best validation loss encountered and save model if it is
    if avg_val_loss < best_validation_loss:
        best_validation_loss = avg_val_loss
        print(f"New Best Validation Loss: {best_validation_loss}")
        model.save_pretrained(model_path)
        tokenizer.save_pretrained(model_path)

progress_bar.close()

# Optionally save the model and tokenizer again at the end if needed
# model.save_pretrained(model_path)
# tokenizer.save_pretrained(model_path)


Adjusted Class Weights: tensor([6.2024, 1.9368], device='cuda:0')
Exponentially Adjusted Class Weights: tensor([6.2024, 1.9368], device='cuda:0')


  3%|████                                                                                                                    | 2/60 [00:00<00:12,  4.48it/s]

Epoch 1, Loss: 0.6418927311897278


  5%|██████                                                                                                                  | 3/60 [00:00<00:16,  3.49it/s]

Epoch 1, Loss: 0.5069409012794495


  7%|████████                                                                                                                | 4/60 [00:01<00:17,  3.17it/s]

Epoch 1, Loss: 0.46973347663879395


  8%|██████████                                                                                                              | 5/60 [00:01<00:18,  3.01it/s]

Epoch 1, Loss: 0.43192094564437866


 10%|████████████                                                                                                            | 6/60 [00:01<00:18,  2.92it/s]

Epoch 1, Loss: 0.4909329116344452
Epoch 1, Loss: 0.3961496651172638
Validation Loss: 0.7603864471117655
New Best Validation Loss: 0.7603864471117655


 13%|████████████████                                                                                                        | 8/60 [00:07<01:18,  1.51s/it]

Epoch 2, Loss: 0.45551225543022156


 15%|██████████████████                                                                                                      | 9/60 [00:08<00:58,  1.15s/it]

Epoch 2, Loss: 0.29979225993156433


 17%|███████████████████▊                                                                                                   | 10/60 [00:08<00:45,  1.10it/s]

Epoch 2, Loss: 0.2589707374572754


 18%|█████████████████████▊                                                                                                 | 11/60 [00:08<00:36,  1.36it/s]

Epoch 2, Loss: 0.20874769985675812


 20%|███████████████████████▊                                                                                               | 12/60 [00:09<00:29,  1.62it/s]

Epoch 2, Loss: 0.2520408630371094
Epoch 2, Loss: 0.2508358359336853


 22%|█████████████████████████▊                                                                                             | 13/60 [00:10<00:35,  1.34it/s]

Validation Loss: 1.1251006772120793


 23%|███████████████████████████▊                                                                                           | 14/60 [00:10<00:28,  1.59it/s]

Epoch 3, Loss: 0.22660265862941742


 25%|█████████████████████████████▊                                                                                         | 15/60 [00:10<00:24,  1.83it/s]

Epoch 3, Loss: 0.20186731219291687


 27%|███████████████████████████████▋                                                                                       | 16/60 [00:11<00:21,  2.05it/s]

Epoch 3, Loss: 0.23862339556217194


 28%|█████████████████████████████████▋                                                                                     | 17/60 [00:11<00:19,  2.23it/s]

Epoch 3, Loss: 0.3169043958187103


 30%|███████████████████████████████████▋                                                                                   | 18/60 [00:12<00:17,  2.37it/s]

Epoch 3, Loss: 0.44444623589515686
Epoch 3, Loss: 0.33397576212882996


 32%|█████████████████████████████████████▋                                                                                 | 19/60 [00:13<00:24,  1.65it/s]

Validation Loss: 1.2364169831077259


 33%|███████████████████████████████████████▋                                                                               | 20/60 [00:13<00:21,  1.88it/s]

Epoch 4, Loss: 0.3506595492362976


 35%|█████████████████████████████████████████▋                                                                             | 21/60 [00:13<00:18,  2.08it/s]

Epoch 4, Loss: 0.21035675704479218


 37%|███████████████████████████████████████████▋                                                                           | 22/60 [00:14<00:16,  2.25it/s]

Epoch 4, Loss: 0.12383987009525299


 38%|█████████████████████████████████████████████▌                                                                         | 23/60 [00:14<00:15,  2.40it/s]

Epoch 4, Loss: 0.22059203684329987


 40%|███████████████████████████████████████████████▌                                                                       | 24/60 [00:14<00:14,  2.51it/s]

Epoch 4, Loss: 0.16164126992225647
Epoch 4, Loss: 0.2490607053041458


 42%|█████████████████████████████████████████████████▌                                                                     | 25/60 [00:15<00:20,  1.69it/s]

Validation Loss: 0.833342120051384


 43%|███████████████████████████████████████████████████▌                                                                   | 26/60 [00:16<00:17,  1.92it/s]

Epoch 5, Loss: 0.181123286485672


 45%|█████████████████████████████████████████████████████▌                                                                 | 27/60 [00:16<00:15,  2.12it/s]

Epoch 5, Loss: 0.16288097202777863


 47%|███████████████████████████████████████████████████████▌                                                               | 28/60 [00:16<00:13,  2.30it/s]

Epoch 5, Loss: 0.1821008324623108


 48%|█████████████████████████████████████████████████████████▌                                                             | 29/60 [00:17<00:12,  2.42it/s]

Epoch 5, Loss: 0.29478228092193604


 50%|███████████████████████████████████████████████████████████▌                                                           | 30/60 [00:17<00:11,  2.53it/s]

Epoch 5, Loss: 0.13749705255031586
Epoch 5, Loss: 0.3411887586116791


 52%|█████████████████████████████████████████████████████████████▍                                                         | 31/60 [00:18<00:17,  1.69it/s]

Validation Loss: 0.7737327267726263


 53%|███████████████████████████████████████████████████████████████▍                                                       | 32/60 [00:19<00:14,  1.91it/s]

Epoch 6, Loss: 0.1980646550655365


 55%|█████████████████████████████████████████████████████████████████▍                                                     | 33/60 [00:19<00:12,  2.11it/s]

Epoch 6, Loss: 0.12066128104925156


 57%|███████████████████████████████████████████████████████████████████▍                                                   | 34/60 [00:19<00:11,  2.27it/s]

Epoch 6, Loss: 0.19887714087963104


 58%|█████████████████████████████████████████████████████████████████████▍                                                 | 35/60 [00:20<00:10,  2.41it/s]

Epoch 6, Loss: 0.17044785618782043


 60%|███████████████████████████████████████████████████████████████████████▍                                               | 36/60 [00:20<00:09,  2.52it/s]

Epoch 6, Loss: 0.18706299364566803
Epoch 6, Loss: 0.08513155579566956


 62%|█████████████████████████████████████████████████████████████████████████▍                                             | 37/60 [00:21<00:13,  1.69it/s]

Validation Loss: 0.8515963057676951


 63%|███████████████████████████████████████████████████████████████████████████▎                                           | 38/60 [00:21<00:11,  1.92it/s]

Epoch 7, Loss: 0.12166891992092133


 65%|█████████████████████████████████████████████████████████████████████████████▎                                         | 39/60 [00:22<00:09,  2.12it/s]

Epoch 7, Loss: 0.20144356787204742


 67%|███████████████████████████████████████████████████████████████████████████████▎                                       | 40/60 [00:22<00:08,  2.30it/s]

Epoch 7, Loss: 0.0988287702202797


 68%|█████████████████████████████████████████████████████████████████████████████████▎                                     | 41/60 [00:22<00:07,  2.43it/s]

Epoch 7, Loss: 0.13540077209472656


 70%|███████████████████████████████████████████████████████████████████████████████████▎                                   | 42/60 [00:23<00:07,  2.53it/s]

Epoch 7, Loss: 0.22981251776218414
Epoch 7, Loss: 0.15927867591381073


 72%|█████████████████████████████████████████████████████████████████████████████████████▎                                 | 43/60 [00:24<00:10,  1.69it/s]

Validation Loss: 0.8104106386502584


 73%|███████████████████████████████████████████████████████████████████████████████████████▎                               | 44/60 [00:24<00:08,  1.92it/s]

Epoch 8, Loss: 0.07452256977558136


 75%|█████████████████████████████████████████████████████████████████████████████████████████▎                             | 45/60 [00:25<00:07,  2.12it/s]

Epoch 8, Loss: 0.16472123563289642


 77%|███████████████████████████████████████████████████████████████████████████████████████████▏                           | 46/60 [00:25<00:06,  2.28it/s]

Epoch 8, Loss: 0.1623213291168213


 78%|█████████████████████████████████████████████████████████████████████████████████████████████▏                         | 47/60 [00:25<00:05,  2.42it/s]

Epoch 8, Loss: 0.10392965376377106


 80%|███████████████████████████████████████████████████████████████████████████████████████████████▏                       | 48/60 [00:26<00:04,  2.51it/s]

Epoch 8, Loss: 0.2268029898405075
Epoch 8, Loss: 0.052951883524656296


 82%|█████████████████████████████████████████████████████████████████████████████████████████████████▏                     | 49/60 [00:27<00:06,  1.68it/s]

Validation Loss: 0.911537637313207


 83%|███████████████████████████████████████████████████████████████████████████████████████████████████▏                   | 50/60 [00:27<00:05,  1.91it/s]

Epoch 9, Loss: 0.09028434008359909


 85%|█████████████████████████████████████████████████████████████████████████████████████████████████████▏                 | 51/60 [00:27<00:04,  2.11it/s]

Epoch 9, Loss: 0.11273873597383499


 87%|███████████████████████████████████████████████████████████████████████████████████████████████████████▏               | 52/60 [00:28<00:03,  2.29it/s]

Epoch 9, Loss: 0.1461624950170517


 88%|█████████████████████████████████████████████████████████████████████████████████████████████████████████              | 53/60 [00:28<00:02,  2.42it/s]

Epoch 9, Loss: 0.10259607434272766


 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████            | 54/60 [00:29<00:02,  2.53it/s]

Epoch 9, Loss: 0.1588851362466812
Epoch 9, Loss: 0.22461780905723572


 92%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████          | 55/60 [00:30<00:02,  1.69it/s]

Validation Loss: 1.1330530246098836


 93%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████        | 56/60 [00:30<00:02,  1.92it/s]

Epoch 10, Loss: 0.11306069046258926


 95%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████      | 57/60 [00:30<00:01,  2.12it/s]

Epoch 10, Loss: 0.14160233736038208


 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████    | 58/60 [00:31<00:00,  2.28it/s]

Epoch 10, Loss: 0.21719790995121002


 98%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████  | 59/60 [00:31<00:00,  2.41it/s]

Epoch 10, Loss: 0.06744872033596039


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:31<00:00,  2.51it/s]

Epoch 10, Loss: 0.18293148279190063
Epoch 10, Loss: 0.08297879993915558


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:32<00:00,  1.83it/s]

Validation Loss: 1.472532684604327





In [14]:
import torch
from torch.utils.data import DataLoader
from transformers import get_scheduler
from torch.optim import AdamW
from tqdm import tqdm

# Assuming model, train_dataset, validation_dataset, and tokenizer are predefined
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Convert datasets to DataLoader
train_loader = DataLoader(dataset_train, batch_size=32, shuffle=True)
validation_loader = DataLoader(dataset_validate, batch_size=32)

# Setting up the optimizer
optimizer = AdamW(model.parameters(), lr=0.001)

# Total number of training steps
num_training_steps = len(train_loader) * 10  # num_epochs is 15

# Create the learning rate scheduler.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=num_training_steps
)

# Best validation loss tracking
best_validation_loss = float('inf')
model_path = "./model_save_domain_synth"

# Training loop
progress_bar = tqdm(range(num_training_steps))
model.train()

for epoch in range(10):  # loop over the dataset multiple times
    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss

        # Backpropagation
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

    # Validation step at the end of each epoch
    model.eval()
    total_eval_loss = 0
    with torch.no_grad():
        for batch in validation_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            total_eval_loss += loss.item()

    avg_val_loss = total_eval_loss / len(validation_loader)
    print(f"Validation Loss: {avg_val_loss}")

    # Check if this is the best validation loss encountered and save model if it is
    if avg_val_loss < best_validation_loss:
        best_validation_loss = avg_val_loss
        print(f"New Best Validation Loss: {best_validation_loss}")
        model.save_pretrained(model_path)
        tokenizer.save_pretrained(model_path)

progress_bar.close()




# Save the model
model_path = "./model_save_domain_synth"
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)


  2%|██                                                                                                                      | 1/60 [00:00<00:52,  1.12it/s]

Epoch 1, Loss: 0.6948522925376892


  5%|██████                                                                                                                  | 3/60 [00:01<00:21,  2.61it/s]

Epoch 1, Loss: 0.7671108245849609


  7%|████████                                                                                                                | 4/60 [00:01<00:21,  2.67it/s]

Epoch 1, Loss: 0.71390700340271


  8%|██████████                                                                                                              | 5/60 [00:02<00:20,  2.70it/s]

Epoch 1, Loss: 0.7333320379257202


 10%|████████████                                                                                                            | 6/60 [00:02<00:19,  2.73it/s]

Epoch 1, Loss: 0.6988521218299866
Epoch 1, Loss: 0.7055455446243286
Validation Loss: 0.694530189037323
New Best Validation Loss: 0.694530189037323


 13%|████████████████                                                                                                        | 8/60 [00:08<01:21,  1.57s/it]

Epoch 2, Loss: 0.6940174698829651


 15%|██████████████████                                                                                                      | 9/60 [00:09<01:01,  1.20s/it]

Epoch 2, Loss: 0.6849258542060852


 17%|███████████████████▊                                                                                                   | 10/60 [00:09<00:46,  1.06it/s]

Epoch 2, Loss: 0.7095075249671936


 18%|█████████████████████▊                                                                                                 | 11/60 [00:09<00:37,  1.31it/s]

Epoch 2, Loss: 0.7054446935653687


 20%|███████████████████████▊                                                                                               | 12/60 [00:10<00:30,  1.56it/s]

Epoch 2, Loss: 0.6521050930023193
Epoch 2, Loss: 0.6379058957099915


 22%|█████████████████████████▊                                                                                             | 13/60 [00:11<00:35,  1.31it/s]

Validation Loss: 0.7201826671759287


 23%|███████████████████████████▊                                                                                           | 14/60 [00:11<00:29,  1.56it/s]

Epoch 3, Loss: 0.7321730852127075


 25%|█████████████████████████████▊                                                                                         | 15/60 [00:11<00:24,  1.81it/s]

Epoch 3, Loss: 0.6482895016670227


 27%|███████████████████████████████▋                                                                                       | 16/60 [00:12<00:21,  2.02it/s]

Epoch 3, Loss: 0.6658648252487183


 28%|█████████████████████████████████▋                                                                                     | 17/60 [00:12<00:19,  2.22it/s]

Epoch 3, Loss: 0.6814121007919312


 30%|███████████████████████████████████▋                                                                                   | 18/60 [00:12<00:17,  2.37it/s]

Epoch 3, Loss: 0.6478613018989563
Epoch 3, Loss: 0.700506329536438


 32%|█████████████████████████████████████▋                                                                                 | 19/60 [00:13<00:24,  1.64it/s]

Validation Loss: 0.7350525061289469


 33%|███████████████████████████████████████▋                                                                               | 20/60 [00:14<00:21,  1.87it/s]

Epoch 4, Loss: 0.6976355910301208


 35%|█████████████████████████████████████████▋                                                                             | 21/60 [00:14<00:18,  2.07it/s]

Epoch 4, Loss: 0.6824394464492798


 37%|███████████████████████████████████████████▋                                                                           | 22/60 [00:15<00:16,  2.25it/s]

Epoch 4, Loss: 0.6555702686309814


 38%|█████████████████████████████████████████████▌                                                                         | 23/60 [00:15<00:15,  2.40it/s]

Epoch 4, Loss: 0.6550520062446594


 40%|███████████████████████████████████████████████▌                                                                       | 24/60 [00:15<00:14,  2.50it/s]

Epoch 4, Loss: 0.6718311309814453
Epoch 4, Loss: 0.6386275291442871


 42%|█████████████████████████████████████████████████▌                                                                     | 25/60 [00:16<00:20,  1.69it/s]

Validation Loss: 0.7529509862263998


 43%|███████████████████████████████████████████████████▌                                                                   | 26/60 [00:17<00:17,  1.92it/s]

Epoch 5, Loss: 0.622857928276062


 45%|█████████████████████████████████████████████████████▌                                                                 | 27/60 [00:17<00:15,  2.12it/s]

Epoch 5, Loss: 0.6458950042724609


 47%|███████████████████████████████████████████████████████▌                                                               | 28/60 [00:17<00:13,  2.29it/s]

Epoch 5, Loss: 0.6435486078262329


KeyboardInterrupt: 

In [21]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np
test_loader = DataLoader(dataset_test, batch_size=32) 

model.eval()  # Set model to evaluation mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

predictions, true_labels = [], []

# Evaluate the model
with torch.no_grad():
    for batch in test_loader:
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        outputs = model(**inputs)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=-1)
        predictions.extend(preds.detach().cpu().numpy())
        true_labels.extend(labels.detach().cpu().numpy())

# Calculate metrics
accuracy = accuracy_score(true_labels, predictions)
precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predictions, average='macro')  # adjust 'average' as needed

print(f"Accuracy: {accuracy}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1 Score: {f1}")


Accuracy: 0.4888888888888889
Precision: 0.49375
Recall: 0.49377799900447983
F1 Score: 0.48882578096061247
