In [2]:
! pip install datasets

Collecting datasets
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.1-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.4/491.4 kB[0m [31m23.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m15.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl 

In [53]:
import torch

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"✅ Using device: {device}")

✅ Using device: cuda


In [26]:
from datasets import load_dataset, Dataset
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments
import pandas as pd
import torch
from sklearn.metrics import accuracy_score, f1_score


In [None]:
# Number of fine-tuning epochs handed to TrainingArguments.
EPOCHS = 3
# File the trained weights (state_dict) are written to and later reloaded from.
Model_Name = "models/SemanticVAD_Cuda.pt"
# Number of target classes. NOTE(review): not referenced by any other cell in this
# notebook; the scorer head emits a single logit.
NUM_Labels = 2

#### Load Data

In [None]:
from datasets import Value

# Read the training CSV and strip bracketed spans such as "[...]" (plus surrounding
# whitespace) from the text column before handing the frame to `datasets`.
df = pd.read_csv("train_1500.csv").assign(
    text=lambda frame: frame["text"].str.replace(r'\[.*?\]', '', regex=True).str.strip()
)
dataset = Dataset.from_pandas(df)

# Tokenizer matching the pretrained encoder checkpoint used by the model.
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

def preprocess_function(examples):
    """Tokenize the ``text`` field of a batch.

    Every row is truncated and padded to exactly 128 tokens.

    Args:
        examples: Batch mapping with a ``"text"`` entry, as passed by ``Dataset.map``.

    Returns:
        The tokenizer output for the batch.
    """
    tokenizer_options = {"truncation": True, "padding": "max_length", "max_length": 128}
    return tokenizer(examples["text"], **tokenizer_options)


# Tokenize the whole dataset in batches.
encoded_dataset = dataset.map(preprocess_function, batched=True)

# Hold out 20% for evaluation. The fixed seed keeps the split (and therefore the
# reported metrics) reproducible across kernel restarts.
encoded_dataset = encoded_dataset.train_test_split(test_size=0.2, seed=42)

# Store the target under "label" as float32: the model's loss is
# BCEWithLogitsLoss, which operates on floating-point targets.
encoded_dataset = encoded_dataset.rename_column("labels", "label")
encoded_dataset = encoded_dataset.cast_column("label", Value("float32"))

# Expose only the model inputs and the target, as torch tensors.
encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])


Map:   0%|          | 0/1499 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/1199 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/300 [00:00<?, ? examples/s]

#### Construct Model

In [29]:
from transformers import DistilBertModel
import torch.nn as nn

class DistilBERTBackchannelScorer(nn.Module):
    """DistilBERT encoder with a small MLP head that emits one logit per sequence.

    The head reads the encoder's hidden state at sequence position 0 and maps it
    through Linear -> ReLU -> Dropout(0.3) -> Linear to a single score.

    Args:
        hidden_dim: Width of the encoder hidden state fed into the head.
    """

    def __init__(self, hidden_dim=768):
        super().__init__()
        self.encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Score a batch of token sequences.

        Args:
            input_ids: Token id tensor for the batch.
            attention_mask: Optional mask tensor; when omitted, the encoder is
                called with ``attention_mask=None``.
            labels: Optional targets. When given, a BCE-with-logits loss is computed.

        Returns:
            SequenceClassifierOutput with ``logits`` (last dimension squeezed away)
            and ``loss`` (``None`` when no labels were passed).

        Note:
            ``SequenceClassifierOutput`` is looked up at call time, so it must be
            imported from ``transformers.modeling_outputs`` before the first call.
        """
        # Follow the module's own weights instead of a notebook-level global, so
        # the model keeps working if it is moved to another device.
        target_device = self.classifier[-1].weight.device

        input_ids = input_ids.to(target_device)
        # The signature allows attention_mask=None; only move it when present.
        if attention_mask is not None:
            attention_mask = attention_mask.to(target_device)

        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0]
        logits = self.classifier(pooled).squeeze(-1)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss needs float targets on the same device as the logits.
            labels = labels.to(target_device).float()
            loss = nn.BCEWithLogitsLoss()(logits, labels)

        return SequenceClassifierOutput(loss=loss, logits=logits)

#### Training Setting

In [30]:
from transformers.modeling_outputs import SequenceClassifierOutput
# Build the scorer and place it on the selected device.
model = DistilBERTBackchannelScorer().to(device)

# Evaluate and checkpoint once per epoch; at the end, reload the checkpoint with
# the highest F1 and keep at most two checkpoints on disk.
training_args = TrainingArguments(
    output_dir="./checkpoints",
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    logging_dir="./logs",
    logging_steps=20,
    report_to="tensorboard",
)


def compute_metrics(eval_pred):
    """Turn raw logits into accuracy and F1 for the Trainer.

    Args:
        eval_pred: Pair that unpacks into ``(logits, labels)``.

    Returns:
        Dict with keys ``"accuracy"`` and ``"f1"``. A sample counts as positive
        when the sigmoid of its logit exceeds 0.5.
    """
    raw_scores, targets = eval_pred
    positive_prob = torch.sigmoid(torch.tensor(raw_scores)).numpy()
    predicted = (positive_prob > 0.5).astype(int)
    return {
        "accuracy": accuracy_score(targets, predicted),
        "f1": f1_score(targets, predicted),
    }

# Wire model, data and metrics together. `processing_class` is the current name
# for the argument formerly called `tokenizer`, which is deprecated in
# Trainer.__init__ and emits a FutureWarning.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

  trainer = Trainer(


#### Train

In [16]:
# Fine-tune, then score the resulting model on the held-out split and display
# the metrics dict as the cell output.
trainer.train()
final_metrics = trainer.evaluate()
final_metrics

Epoch,Training Loss,Validation Loss,Accuracy,F1
1,0.2928,0.147871,0.963333,0.978304
2,0.148,0.2051,0.953333,0.972112
3,0.0793,0.180644,0.966667,0.980237


{'eval_loss': 0.18064364790916443,
 'eval_accuracy': 0.9666666666666667,
 'eval_f1': 0.9802371541501976,
 'eval_runtime': 1.3247,
 'eval_samples_per_second': 226.463,
 'eval_steps_per_second': 56.616,
 'epoch': 3.0}

In [17]:
# save model
from pathlib import Path

# torch.save does not create missing directories; make sure the target folder exists.
Path(Model_Name).parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), Model_Name)


#### Quick Check

In [54]:
from transformers import DistilBertTokenizerFast
import torch

# load model
model = DistilBERTBackchannelScorer()
# map_location lets a checkpoint saved on GPU be restored on a CPU-only host;
# weights_only restricts unpickling to tensors/primitive containers, which is all
# a state_dict holds.
state_dict = torch.load(Model_Name, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # disable dropout for inference

# load tokenizer
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

def predict(text):
    """Classify a single utterance.

    Args:
        text: The utterance to score.

    Returns:
        Tuple ``(prob, name)`` where ``prob`` is the sigmoid of the model's logit
        and ``name`` is ``"Interruption"`` when ``prob > 0.5``, otherwise
        ``"Backchannel"``.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    encoded = {field: tensor.to(device) for field, tensor in encoded.items()}
    with torch.no_grad():
        raw_score = model(**encoded).logits.squeeze()
        prob = torch.sigmoid(raw_score).item()
    verdict = "Interruption" if prob > 0.5 else "Backchannel"
    return prob, verdict


In [55]:
# Spot-check the classifier on a handful of hand-written utterances.
sample_utterances = [
    "yeah, right, I see",
    "Are you sure?",
    "I don't think so. You should say something different.",
    "There is a long...",
    "wait, can I jump in here?",
]
for utterance in sample_utterances:
    print(predict(utterance))

(0.002612100914120674, 'Backchannel')
(0.18273857235908508, 'Backchannel')
(0.9977163076400757, 'Interruption')
(0.997680127620697, 'Interruption')
(0.9981252551078796, 'Interruption')


In [56]:
import time

# perf_counter is a monotonic, high-resolution clock intended for measuring
# short intervals; time.time() is wall-clock time and can jump.
start = time.perf_counter()
print(predict("wait, can I jump in here?"))
elapsed = time.perf_counter() - start
print(f"Ref time: {elapsed}")

(0.9981252551078796, 'Interruption')
Ref time: 0.006989002227783203


In [57]:
import time

# perf_counter is a monotonic, high-resolution clock intended for measuring
# short intervals; time.time() is wall-clock time and can jump.
start = time.perf_counter()
print(predict("Gotcha"))
elapsed = time.perf_counter() - start
print(f"Ref time: {elapsed}")

(0.033293478190898895, 'Backchannel')
Ref time: 0.007735252380371094
