In [None]:
import torch.nn as nn
import torch
import pandas as pd
from transformers import PreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
from transformers import RobertaModel, RobertaTokenizerFast

from google.colab import drive

from datasets import Dataset
from datasets import load_from_disk
from IPython.display import display


# Fast RoBERTa tokenizer. add_prefix_space=True lets it be called on pre-split
# words (is_split_into_words=True), which is how the preprocessing cells use it.
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base", add_prefix_space=True)

# Bare RoBERTa encoder (no task head) from the same checkpoint.
encoder = RobertaModel.from_pretrained("roberta-base")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# drive.mount('/content/drive')
# # #connect to drive, one person should run this as it can be connected to one drive, etc.

Mounted at /content/drive


# Pre-processing

In [None]:
import pandas as pd
from datasets import Dataset

# ------------------------------------------------------------
# 0. Load CSV safely
# ------------------------------------------------------------
# Unlabeled (out-of-sample) input: one token per row in a "text" column.
df = pd.read_csv("/content/tell.csv")
# df = pd.read_csv("/content/test.csv")
# Normalise the token column: missing cells become "" and every value is a str,
# so the .endswith() check below never sees a float NaN.
df["text"] = df["text"].fillna("").astype(str)
# df["label"] = df["label"].fillna("O").astype(str)

# ------------------------------------------------------------
# 1. Build sentences (word-level lists)
# ------------------------------------------------------------
# A sentence is closed by any word that ends with one of these characters.
SENTENCE_END_CHARS = (".", "?", "!", ";", ":")

sentences = []
current_words = []

# Only the text column is needed here, so iterate over it directly.
for word in df["text"]:
    current_words.append(word)

    if word.endswith(SENTENCE_END_CHARS):
        sentences.append({"words": current_words})
        current_words = []

# Add any leftover words (a trailing sentence with no closing punctuation)
if current_words:
    sentences.append({"words": current_words})

# ------------------------------------------------------------
# 2. Group sentences into CHUNKS of 15 sentences each
# ------------------------------------------------------------
CHUNK_SIZE = 15

# Each chunk concatenates the words of up to CHUNK_SIZE consecutive sentences
# (the final chunk may hold fewer). No labels are carried in this unlabeled
# variant.
chunks = [
    {
        "words": [
            word
            for sentence in sentences[start : start + CHUNK_SIZE]
            for word in sentence["words"]
        ]
    }
    for start in range(0, len(sentences), CHUNK_SIZE)
]

dataset = Dataset.from_list(chunks)

# ------------------------------------------------------------
# 3. Tokenization + punctuation-mask building
# ------------------------------------------------------------
label2id = {"O": 0, "B": 1, "L": 1, "U": 1, "I":1}

def tokenize_example(example):
    """Tokenize one chunk of pre-split words and flag its punctuation tokens.

    Unlabeled variant: no label columns are read or produced.

    Args:
        example: mapping with a "words" entry (list of word strings).

    Returns:
        dict with
          - "input_ids" / "attention_mask": taken from the tokenizer output
            (special tokens included),
          - "punctuation_mask": one int per token, 1 when the token string is
            one of . ? ! ; : (bare or with the "Ġ" prefix), otherwise 0.
    """
    encoded = tokenizer(
        example["words"],
        add_special_tokens=True,
        is_split_into_words=True,
    )
    ids = encoded["input_ids"]

    # Accept each mark both bare and with the "Ġ" prefix.
    sentence_marks = (".", "?", "!", ";", ":")
    punct_tokens = set(sentence_marks) | {"Ġ" + mark for mark in sentence_marks}

    punctuation_mask = []
    for tok in tokenizer.convert_ids_to_tokens(ids):
        punctuation_mask.append(int(tok in punct_tokens))

    return {
        "input_ids": ids,
        "attention_mask": encoded["attention_mask"],
        "punctuation_mask": punctuation_mask,
    }

dataset = dataset.map(tokenize_example, remove_columns=["words"]) #["words", "labels"])

# ------------------------------------------------------------
# 4. Print Example Output
# ------------------------------------------------------------
print([dataset[i] for i in range(1)])


# ------------------------------------------------------------
# 5. Save dataset and download locally
# ------------------------------------------------------------
# save the dataset
dataset.save_to_disk("/content/test_processed")

# zip the folder
!zip -r /content/tell.zip /content/test_processed

# download the zip
from google.colab import files
files.download("tell.zip")


Map:   0%|          | 0/1 [00:00<?, ? examples/s]

[{'input_ids': [0, 1832, 47, 206, 51, 64, 1437, 1137, 4, 370, 216, 6, 14140, 32, 11962, 8, 70, 6, 8, 98, 32, 20993, 6, 53, 47, 64, 460, 1137, 4, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'punctuation_mask': [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]}]


Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

  adding: content/test_processed/ (stored 0%)
  adding: content/test_processed/data-00000-of-00001.arrow (deflated 69%)
  adding: content/test_processed/dataset_info.json (deflated 69%)
  adding: content/test_processed/state.json (deflated 39%)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
import pandas as pd
from datasets import Dataset

# df = pd.read_csv("/content/drive/My Drive/MIT/NLP Project/data/train_data/small.csv")
df = pd.read_csv("/content/drive/My Drive/NLP Project/data/train_data/small.csv")
# An empty cell is read as float NaN, which has no .endswith(); normalise the
# token column to str first (missing tokens become "").
df["text"] = df["text"].fillna("").astype(str)

############################################################
# 1. Build sentences (word-level lists)
############################################################

# A sentence is closed by any word ending in one of these characters.
# "..." needs no entry of its own: such a word already ends with ".".
SENTENCE_END_CHARS = (".", "?", "!", ";", ":")

sentences = []
current_words = []
current_labels = []

# Walk the token and label columns in lock-step (one row per token).
for word, label in zip(df["text"], df["label"]):
    current_words.append(word)
    current_labels.append(label)

    if word.endswith(SENTENCE_END_CHARS):
        sentences.append({
            "words": current_words,
            "labels": current_labels
        })
        current_words = []
        current_labels = []

# Flush a trailing sentence that has no closing punctuation.
if current_words:
    sentences.append({
        "words": current_words,
        "labels": current_labels
    })

############################################################
# 2. Group sentences into CHUNKS of 15 sentences each
############################################################

CHUNK_SIZE = 15
chunks = []

# Walk the sentence list in windows of CHUNK_SIZE and flatten each window
# into one example; words and labels stay aligned index-for-index.
for start in range(0, len(sentences), CHUNK_SIZE):
    window = sentences[start : start + CHUNK_SIZE]
    chunks.append({
        "words": [word for sentence in window for word in sentence["words"]],
        "labels": [label for sentence in window for label in sentence["labels"]],
    })

dataset = Dataset.from_list(chunks)

############################################################
# 3. Tokenization + mask building
############################################################

# Binary target: every non-"O" tag is positive. "I" is included so this mapping
# matches the one in the unlabeled preprocessing cell and an "I" tag does not
# raise KeyError.
label2id = {"O": 0, "B": 1, "L": 1, "U": 1, "I": 1}

def tokenize_example(example):
    """Tokenize one labeled chunk and extract labels at punctuation tokens.

    Args:
        example: mapping with "words" (list of word strings) and "labels"
            (one tag per word; every tag must be a key of `label2id`).

    Returns:
        dict with
          - "input_ids" / "attention_mask": taken from the tokenizer output
            (special tokens included),
          - "punctuation_mask": one int per token, 1 when the token string is
            ".", "?", "Ġ." or "Ġ?", otherwise 0,
          - "punctuation_labels": the token-level labels at the positions
            where punctuation_mask is 1, in order.

    Raises:
        KeyError: if a tag is not present in `label2id`.
    """
    words = example["words"]
    label_ids = [label2id[tag] for tag in example["labels"]]

    encoding = tokenizer(
        words,
        is_split_into_words=True,
        add_special_tokens=True,
    )

    # Map word labels → token labels: every sub-word token inherits its word's
    # label; tokens with no source word (word id None) get 0.
    token_labels = [
        0 if w_id is None else label_ids[w_id]
        for w_id in encoding.word_ids()
    ]

    # Build punctuation mask.
    # NOTE(review): only "." and "?" are marked here, although sentences are
    # also split on "!", ";" and ":" and the unlabeled preprocessing cell marks
    # those tokens too — confirm the difference is intended.
    punct_tokens = (".", "?", "Ġ.", "Ġ?")
    tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"])
    punctuation_mask = [1 if tok in punct_tokens else 0 for tok in tokens]

    # Extract labels *only* at punctuation positions
    punctuation_labels = [
        label
        for label, is_punct in zip(token_labels, punctuation_mask)
        if is_punct == 1
    ]

    return {
        "input_ids": encoding["input_ids"],
        "attention_mask": encoding["attention_mask"],
        "punctuation_mask": punctuation_mask,
        "punctuation_labels": punctuation_labels,
    }

# Swap the raw "words"/"labels" columns for the tokenized features.
dataset = dataset.map(tokenize_example, remove_columns=["words", "labels"])

############################################################
# 4. Print Example Output
############################################################
# Show the third and fourth examples.
print([dataset[2], dataset[3]])



Map:   0%|          | 0/5 [00:00<?, ? examples/s]

[{'input_ids': [0, 1863, 33521, 106, 19, 1457, 211, 46772, 4, 252, 33, 2099, 19, 69, 235, 89, 11, 5, 1692, 9, 5, 558, 4, 1801, 2999, 49, 38594, 66, 4, 252, 214, 1104, 604, 11, 730, 4, 2612, 74, 47, 21587, 116, 9561, 25, 47, 214, 626, 519, 2099, 19, 69, 6, 79, 161, 79, 18, 5283, 6, 47, 668, 69, 4, 370, 214, 2277, 6, 32594, 6, 120, 66, 9, 259, 4, 3394, 5, 7105, 174, 47, 7, 120, 6536, 62, 6, 47, 45590, 116, 2381, 11, 110, 558, 6, 3581, 160, 110, 6713, 2137, 6, 3874, 62, 6, 213, 7, 10, 265, 4592, 6, 120, 10789, 456, 6, 8, 120, 10, 6174, 4, 178, 14, 18, 99, 51, 222, 358, 183, 4, 280, 21, 101, 49, 301, 4, 440, 22749, 2485, 4, 178, 172, 101, 390, 554, 2053, 8, 909, 82, 770, 7, 3529, 19072, 220, 7, 106, 8, 51, 218, 75, 236, 7, 15328, 4, 252, 685, 960, 4, 280, 18, 596, 51, 218, 75, 236, 28957, 7, 900, 4, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

# Data loading

In [None]:
# Load the preprocessed full training set: extract the archive (looked up in
# the current working directory) into /content/extracted/.
!unzip preprocessed_train_full.zip -d /content/extracted/

In [None]:
from datasets import load_from_disk
# Open the extracted training set, then show the column names of the first
# example and the number of examples (one per line).
train_ds = load_from_disk("/content/extracted/content/my_processed_dataset")
print(train_ds[0].keys(), len(train_ds), sep="\n")

In [None]:
# Load the testing data: copy test_processed.zip from the shared drive into the
# Colab working directory first, then extract it into /content/extracted/.
!unzip test_processed.zip -d /content/extracted/

In [None]:
from datasets import load_from_disk
# Open the extracted test set, then show the column names of the first example
# and the number of examples (one per line).
test_ds = load_from_disk("/content/extracted/content/test_processed")
print(test_ds[0].keys(), len(test_ds), sep="\n")

In [None]:
# Need to build a data collator so everything is padded
#1. pad b/c of shorter sequences (input_ids, punctuation_mask, attention_mask)
#2. pad b/c variable num of punctuation per sequence (labels)

class PunctuationCollator:
    """
    Pads a list of preprocessed, labeled examples into one batch of tensors.

    Each example is a mapping with:
      - "input_ids", "attention_mask", "punctuation_mask": per-token lists
      - "punctuation_labels": one label per punctuation token

    Args:
      device:       device the tensors are created on (None = torch default)
      pad_token_id: value used to pad "input_ids" (default 1, RoBERTa's pad id)

    Output dict (B = batch size, seq_len = longest sequence in the batch,
    max_punct = largest number of labels in the batch):
      - "input_ids":          (B, seq_len)   long,  padded with pad_token_id
      - "attention_mask":     (B, seq_len)   long,  padded with 0
      - "punctuation_mask":   (B, seq_len)   bool,  padded with False
      - "punctuation_labels": (B, max_punct) float, padded with 0
      - "punct_pad_mask":     (B, max_punct) bool,  True = real label,
                                                    False = padding slot
    """
    def __init__(self, device=None, pad_token_id=1):
        self.device = device
        self.pad_token_id = pad_token_id

    def _pad_field(self, batch, key, dtype, padding_value):
        """Stack the per-example lists under `key` into a right-padded (B, seq_len) tensor."""
        rows = [torch.tensor(x[key], dtype=dtype, device=self.device) for x in batch]
        return torch.nn.utils.rnn.pad_sequence(
            rows, batch_first=True, padding_value=padding_value
        )

    def __call__(self, batch):
        # 1. Pad the per-token fields to the longest sequence in the batch
        input_ids = self._pad_field(batch, "input_ids", torch.long, self.pad_token_id)
        attention_mask = self._pad_field(batch, "attention_mask", torch.long, 0)
        punctuation_mask = self._pad_field(batch, "punctuation_mask", torch.bool, 0)

        # 2. Pad the per-punctuation labels to max_punct and record which
        #    slots hold real labels
        label_lists = [x["punctuation_labels"] for x in batch]
        max_punct = max(len(lbls) for lbls in label_lists)

        labels = torch.zeros(len(batch), max_punct, dtype=torch.float, device=self.device)
        punct_pad_mask = torch.zeros(len(batch), max_punct, dtype=torch.bool, device=self.device)
        for row, lbls in enumerate(label_lists):
            real_count = len(lbls)
            if real_count:
                labels[row, :real_count] = torch.tensor(
                    lbls, dtype=torch.float, device=self.device
                )
                punct_pad_mask[row, :real_count] = True

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "punctuation_mask": punctuation_mask,
            "punctuation_labels": labels,
            "punct_pad_mask": punct_pad_mask,
        }

In [None]:
# for OOS
class PunctuationCollator_no_label:
    """
    Pads a list of preprocessed, unlabeled examples into one batch of tensors.

    Each example is a mapping with per-token lists "input_ids",
    "attention_mask" and "punctuation_mask"; any other keys are ignored.

    Args:
      device:       device the tensors are created on (None = torch default)
      pad_token_id: value used to pad "input_ids" (default 1, RoBERTa's pad id)

    Output dict (B = batch size, seq_len = longest sequence in the batch):
      - "input_ids":        (B, seq_len) long, padded with pad_token_id
      - "attention_mask":   (B, seq_len) long, padded with 0
      - "punctuation_mask": (B, seq_len) bool, padded with False
    """
    def __init__(self, device=None, pad_token_id=1):
        self.device = device
        self.pad_token_id = pad_token_id

    def _pad_field(self, batch, key, dtype, padding_value):
        """Stack the per-example lists under `key` into a right-padded (B, seq_len) tensor."""
        rows = [torch.tensor(x[key], dtype=dtype, device=self.device) for x in batch]
        return torch.nn.utils.rnn.pad_sequence(
            rows, batch_first=True, padding_value=padding_value
        )

    def __call__(self, batch):
        # Pad every per-token field to the longest sequence in the batch.
        return {
            "input_ids": self._pad_field(batch, "input_ids", torch.long, self.pad_token_id),
            "attention_mask": self._pad_field(batch, "attention_mask", torch.long, 0),
            "punctuation_mask": self._pad_field(batch, "punctuation_mask", torch.bool, 0),
        }

### Deal with majority class bias

In [None]:
import collections
# Histogram over the training set:
#   number of positive (== 1) labels in a sequence -> number of such sequences,
# ordered by the number of positives.
positive_counts = collections.OrderedDict(
    sorted(
        collections.Counter(
            datapoint['punctuation_labels'].count(1) for datapoint in train_ds
        ).items()
    )
)
print(positive_counts)


In [None]:
# Sampling weights: sequences with at least 4 positive labels get weight 30,
# every other sequence keeps weight 1.
oversample_indices = [
    idx
    for idx, datapoint in enumerate(train_ds)
    if datapoint['punctuation_labels'].count(1) >= 4
]
print(f"{len(oversample_indices)}/{len(train_ds)} sequences to oversample")

weights = torch.ones(len(train_ds))
weights[oversample_indices] = 30.0

In [None]:
# Load in data and resample to reduce majority bias
from torch.utils.data import DataLoader, WeightedRandomSampler

# batch_size has to be defined here: on a fresh kernel no earlier cell sets it,
# so the DataLoader below would raise NameError. The value matches the one the
# later cells assign.
batch_size = 8

collator = PunctuationCollator()

# Draw len(train_ds) indices per epoch, with replacement, in proportion to the
# per-sequence weights computed above.
sampler = WeightedRandomSampler(
    weights=weights,
    num_samples=len(train_ds),
    replacement=True
)

# A sampler is supplied, so shuffle is left unset.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    collate_fn=collator,
    sampler=sampler
)

# Check the number of positive labels per sequence after sampling: one full
# pass over the loader, counting entries equal to 1 in every label row.
positive_counts = collections.Counter(
    (labels == 1).sum().item()
    for batch in train_loader
    for labels in batch['punctuation_labels']
)

# Sort for readability
positive_counts = collections.OrderedDict(sorted(positive_counts.items()))
print(positive_counts)

In [None]:
# load the test dataset (NO RESAMPLING)
batch_size = 8
collator = PunctuationCollator()

# Plain sequential loader over the test set: fixed order, no sampler.
test_loader = DataLoader(test_ds, collate_fn=collator, batch_size=batch_size, shuffle=False)

# Transformer-based Model

## Model Structure

In [None]:
class CrossAttentionBlock(nn.Module):
  """
  Cross-attention block that lets punctuation embeddings (queries) attend to
  every token of the sequence (keys/values), followed by a two-layer
  feed-forward network. Each sub-layer is wrapped as
  LayerNorm(x + dropout(sublayer(x))).
  """
  def __init__(self, hidden_size, num_heads, dropout=0.1):
    """
    `hidden_size`: embedding width of both inputs; must be divisible by
      `num_heads`
    `num_heads`: number of attention heads
    `dropout`: dropout probability used inside the attention and on both
      residual branches
    """
    super().__init__()
    self.cross_att = nn.MultiheadAttention(
        embed_dim=hidden_size,
        num_heads=num_heads,
        dropout=dropout,
        batch_first=True,
    )
    # Position-wise feed-forward: expand to 2x width, ReLU, project back.
    self.ff = nn.Sequential(
        nn.Linear(hidden_size, hidden_size * 2),
        nn.ReLU(),
        nn.Linear(hidden_size * 2, hidden_size),
    )
    self.norm1 = nn.LayerNorm(hidden_size)
    self.norm2 = nn.LayerNorm(hidden_size)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, sequence, key_padding_mask=None):
    """
    `sequence`: full sequence, including punctuation and
    other tokens embedding
      shape: (batch_size, seq_len, hidden_size)
    `x`: subset of sequence, only punctuation embeddings
      shape: (batch_size, punctuation_count, hidden_size)
    `key_padding_mask`: optional bool tensor, True at the positions of
    `sequence` that must be ignored as keys (e.g. padding)
      shape: (batch_size, seq_len)
      None (the default) attends over every position.

    Returns the updated punctuation embeddings, same shape as `x`.
    """
    attn_out, _ = self.cross_att(
        query=x,
        key=sequence,
        value=sequence,
        key_padding_mask=key_padding_mask,
    )
    # Residual + norm around the attention, then around the feed-forward.
    x = self.norm1(x + self.dropout(attn_out))
    ff_out = self.ff(x)
    x = self.norm2(x + self.dropout(ff_out))
    return x


In [None]:
class TransformerRoBERTa(PreTrainedModel):
  """
  Binary classifier over the punctuation tokens of a sequence.

  Pipeline: frozen RoBERTa encoder -> linear compression to 400 dims ->
  stack of CrossAttentionBlock layers in which the punctuation embeddings
  attend to the whole compressed sequence -> linear head producing one logit
  per punctuation token.
  """
  def __init__(self, roberta_model, num_layers=2, num_heads=8, dropout=0.1):
    """
    `roberta_model`: encoder whose `config` and `last_hidden_state` are used;
      all of its parameters are frozen here
    `num_layers`: number of cross-attention blocks
    `num_heads`: attention heads per block; must divide the compressed
      hidden size (400)
    `dropout`: dropout probability passed to every block
    """
    # Use the encoder that was passed in, not a global, so the class does not
    # depend on a notebook variable being defined first.
    super().__init__(roberta_model.config)
    self.roberta = roberta_model
    self.roberta_hidden_size = roberta_model.config.hidden_size
    self.hidden_size = 400
    self.compress = nn.Linear(self.roberta_hidden_size, self.hidden_size)

    # Freeze RoBERTa parameters
    for p in self.roberta.parameters():
        p.requires_grad = False

    # Transformer cross-attention stack
    # Allows punctuation tokens to pay attention
    # to all other tokens
    self.layers = nn.ModuleList([
        CrossAttentionBlock(
            hidden_size=self.hidden_size,
            num_heads=num_heads,
            dropout=dropout
        )
        for _ in range(num_layers)
    ])

    # Classifier head
    self.classifier = nn.Linear(self.hidden_size, 1)

  def forward(self, input_ids, attention_mask,
        punctuation_mask,   # (B, seq_len) bool tensor, which tokens are punct
        *,
        punct_pad_mask = None,     # (B, max_punct) bool tensor, which punct are real (not padded)
        labels=None,
        **kwargs):
    """
    Returns a TokenClassifierOutput with
      - logits: (B, max_punct), one logit per punctuation slot, where
        max_punct is the largest punctuation count in the batch; slots beyond
        a sequence's own count are padding
      - loss: mean binary cross-entropy over the valid slots, or None when
        `labels` is None

    `labels`: (B, max_punct) targets aligned with the logits.
    `punct_pad_mask`: marks the valid slots for the loss. When it is None it
      is derived from `punctuation_mask` (the first n slots of a row are
      valid, n = that row's punctuation count).
    """
    device = input_ids.device

    # 1. get (frozen) RoBERTa embeddings and compress them
    roberta_out = self.roberta(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=False,
        return_dict=True)
    hidden = self.compress(roberta_out.last_hidden_state)  # (B, seq_len, hidden_size)

    # 2. Gather punctuation token embeddings, left-packed per row and
    #    zero-padded up to the largest punctuation count in the batch
    B, _, H = hidden.size()
    punct_counts = punctuation_mask.sum(dim=1)        # (B,)
    max_punct = int(punct_counts.max().item())
    punct_emb = torch.zeros(B, max_punct, H, device=device, dtype=hidden.dtype)

    for i in range(B):
        idxs = punctuation_mask[i].nonzero(as_tuple=True)[0]  # (n_punct,)
        n_punct = idxs.size(0)
        if n_punct > 0:
            punct_emb[i, :n_punct] = hidden[i, idxs]

    # 3. Cross-attention layers
    # NOTE(review): no key padding mask is given to the blocks, so padded
    # sequence positions are attended to as well — confirm this is intended.
    x = punct_emb
    for layer in self.layers:
        x = layer(x, hidden)

    # 4. Final classifier
    logits = self.classifier(x).squeeze(-1)  # (B, max_punct)

    # 5. Calculate loss
    loss = None
    if labels is not None:
        labels = labels.to(device)
        # Per-slot binary cross-entropy; reduced manually below so that
        # padded slots can be excluded.
        loss_each = nn.functional.binary_cross_entropy_with_logits(
            logits, labels.float(), reduction="none"
        )
        if punct_pad_mask is None:
            # Slot j of row i is real iff j < punctuation count of row i.
            slot_ids = torch.arange(max_punct, device=device).unsqueeze(0)
            punct_pad_mask = slot_ids < punct_counts.unsqueeze(1)
        valid = punct_pad_mask.float().to(device)
        # clamp keeps the loss finite (0) when the batch has no valid slot.
        loss = (loss_each * valid).sum() / valid.sum().clamp(min=1.0)
    return TokenClassifierOutput(loss=loss, logits=logits)

## Training transformer model

In [None]:
# Build the model and set the training hyperparameters

batch_size = 8
num_epochs = 3
learning_rate = 5e-4

# The encoder is loaded with the plain ("eager") attention computation rather
# than the default SDPA mode (faster) or flash_attention_2 (what LLaMA-3,
# Mistral, Qwen use); the stated reason is that our attention implementation
# is custom.
roberta = RobertaModel.from_pretrained("roberta-base", attn_implementation="eager")

model = TransformerRoBERTa(
    roberta,
    num_heads=4,   # adjustable; must divide 400 (the compressed embedding size we chose)
    num_layers=2,  # adjustable
)

# show number of parameters in model
def count_parameters(model):
    """Print the total and the trainable (requires_grad) parameter counts of `model`."""
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")
count_parameters(model)

# Optimise only the parameters that are still trainable (requires_grad=True).
optimizer = torch.optim.AdamW(
    (param for param in model.parameters() if param.requires_grad),
    lr=learning_rate,
)

In [None]:
# Train (also run on test set per epoch)
from tqdm.auto import tqdm
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# One decision threshold for both phases, so training and evaluation metrics
# are computed with the same rule (probability >= threshold -> positive).
DECISION_THRESHOLD = 0.5


def _forward_batch(batch):
    """Move a collated batch to `device` and run the model with labels, so the output carries a loss."""
    batch = {k: v.to(device) for k, v in batch.items()}
    output = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        punctuation_mask=batch["punctuation_mask"],
        punct_pad_mask=batch["punct_pad_mask"],
        labels=batch["punctuation_labels"],
    )
    return batch, output


def _valid_preds_and_labels(batch, output):
    """Return (predictions, labels) as flat lists restricted to real (non-padded) punctuation slots."""
    probs = torch.sigmoid(output.logits.detach())
    preds = (probs >= DECISION_THRESHOLD).long()
    valid = batch["punct_pad_mask"].reshape(-1) == 1
    return (
        preds.reshape(-1)[valid].cpu().tolist(),
        batch["punctuation_labels"].reshape(-1)[valid].cpu().tolist(),
    )


def _binary_metrics(y_true, y_pred):
    """Return (accuracy, precision, recall, f1); precision/recall/f1 are 0 where undefined."""
    return (
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        f1_score(y_true, y_pred, zero_division=0),
    )


def _print_metrics(header, loss, acc, prec, rec, f1, end="\n"):
    """Print one metrics block; `end` terminates the last line."""
    print(f"\n{header}")
    print(f"  Loss:      {loss:.4f}")
    print(f"  Accuracy:  {acc:.4f}")
    print(f"  Precision: {prec:.4f}")
    print(f"  Recall:    {rec:.4f}")
    print(f"  F1 Score:  {f1:.4f}", end=end)


for epoch in range(num_epochs):
    # ---------------- Training ----------------
    # NOTE(review): model.train() also puts the frozen RoBERTa encoder in
    # training mode (its dropout stays active) — confirm this is intended.
    model.train()
    total_loss = 0.0

    all_preds = []
    all_labels = []

    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Training]")

    for batch in progress:
        batch, output = _forward_batch(batch)

        optimizer.zero_grad()
        loss = output.loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        progress.set_postfix({"batch_loss": loss.item()})

        # -------- gather predictions for metrics --------
        batch_preds, batch_labels = _valid_preds_and_labels(batch, output)
        all_preds.extend(batch_preds)
        all_labels.extend(batch_labels)

    train_loss = total_loss / len(train_loader)
    train_acc, train_prec, train_rec, train_f1 = _binary_metrics(all_labels, all_preds)
    _print_metrics(
        f"Epoch {epoch+1} Training Metrics:",
        train_loss, train_acc, train_prec, train_rec, train_f1,
    )

    # ---------------- Evaluation ----------------
    model.eval()
    total_eval_loss = 0.0
    all_preds = []
    all_true_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Evaluation]"):
            # Labels are passed so the model also returns the loss.
            batch, outputs = _forward_batch(batch)
            total_eval_loss += outputs.loss.item()

            batch_preds, batch_labels = _valid_preds_and_labels(batch, outputs)
            all_preds.extend(batch_preds)
            all_true_labels.extend(batch_labels)

    eval_loss = total_eval_loss / len(test_loader)
    eval_acc, eval_prec, eval_rec, eval_f1 = _binary_metrics(all_true_labels, all_preds)
    _print_metrics(
        f"Epoch {epoch+1} Evaluation Metrics:",
        eval_loss, eval_acc, eval_prec, eval_rec, eval_f1,
        end="\n\n",
    )


## Testing transformer model

In [None]:
# evaluates on test set
# saves a csv with each row corresponding to one test point (a sequence)
# columns: original text, model's predicted labels list, true labels list, accuracy/precision/recall/f1 of that one sequence

import torch
from tqdm import tqdm
import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Per-sequence results, index-aligned across the four lists.
all_preds = []
all_true_labels = []
all_texts = []
seq_metrics = []

model.eval()

with torch.no_grad():
    for batch in tqdm(test_loader):
        # Move batch to device
        batch = {name: tensor.to(device) for name, tensor in batch.items()}

        # Forward pass (labels are passed as in training; only logits are used)
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            punctuation_mask=batch["punctuation_mask"],
            punct_pad_mask=batch["punct_pad_mask"],
            labels=batch["punctuation_labels"],
        )
        # (B, max_punct) binary predictions at threshold 0.5
        preds = (torch.sigmoid(outputs.logits) >= 0.5).long()

        # Walk the batch row by row
        rows = zip(
            batch["input_ids"],
            preds,
            batch["punctuation_labels"],
            batch["punct_pad_mask"],
        )
        for input_ids_seq, pred_row, label_row, pad_row in rows:
            # Keep only this sequence's real punctuation slots
            keep = pad_row.bool()
            pred_seq = pred_row[keep].tolist()
            true_label_seq = label_row[keep].tolist()

            all_preds.append(pred_seq)
            all_true_labels.append(true_label_seq)

            # Decode full sequence text
            all_texts.append(tokenizer.decode(input_ids_seq, skip_special_tokens=True))

            # Per-sequence metrics
            seq_metrics.append({
                "accuracy": accuracy_score(true_label_seq, pred_seq),
                "precision": precision_score(true_label_seq, pred_seq, zero_division=0),
                "recall": recall_score(true_label_seq, pred_seq, zero_division=0),
                "f1": f1_score(true_label_seq, pred_seq, zero_division=0),
            })

# Overall metrics across all punctuation tokens of all sequences
flat_preds = []
flat_labels = []
for pred_seq, true_label_seq in zip(all_preds, all_true_labels):
    flat_preds.extend(pred_seq)
    flat_labels.extend(true_label_seq)

overall_acc = accuracy_score(flat_labels, flat_preds)
overall_prec = precision_score(flat_labels, flat_preds, zero_division=0)
overall_rec = recall_score(flat_labels, flat_preds, zero_division=0)
overall_f1 = f1_score(flat_labels, flat_preds, zero_division=0)

print(f"Overall - Accuracy: {overall_acc:.4f}, Precision: {overall_prec:.4f}, "
      f"Recall: {overall_rec:.4f}, F1: {overall_f1:.4f}")

# One row per test sequence: text, predictions, truth, then the four metrics
df = pd.DataFrame({
    "input_text": all_texts,
    "predicted_labels": all_preds,
    "true_labels": all_true_labels,
    **{
        metric: [m[metric] for m in seq_metrics]
        for metric in ("accuracy", "precision", "recall", "f1")
    },
})


In [None]:
# Save the per-sequence test results (the `df` built in the previous cell) to
# a CSV file in the working directory, without the index column.
df.to_csv("transformer_test_results.csv", index=False)

## Out-of-sample (OOS) evaluation

In [None]:
from datasets import load_from_disk
!unzip dirty_processed.zip -d /content/extracted/
dirty_ds = load_from_disk("/content/extracted/content/test_processed")

In [None]:
from datasets import load_from_disk
!unzip exclamation_processed.zip -d /content/extracted/
exclamation_ds = load_from_disk("/content/extracted/content/test_processed")

In [None]:
from datasets import load_from_disk
!unzip liferacereligion_processed.zip -d /content/extracted/
liferacereligion_ds = load_from_disk("/content/extracted/content/test_processed")

In [None]:
from datasets import load_from_disk
!unzip longtoshort_processed.zip -d /content/extracted/
longtoshort_ds = load_from_disk("/content/extracted/content/test_processed")

In [None]:
batch_size = 8
collator = PunctuationCollator_no_label()


def _ordered_loader(ds):
    """Wrap an unlabeled dataset in a fixed-order DataLoader that uses the no-label collator."""
    return DataLoader(ds, collate_fn=collator, shuffle=False, batch_size=batch_size)


# One loader per out-of-sample dataset.
dirty_loader = _ordered_loader(dirty_ds)
exclamation_loader = _ordered_loader(exclamation_ds)
liferacereligion_ds_loader = _ordered_loader(liferacereligion_ds)
longtoshort_loader = _ordered_loader(longtoshort_ds)

In [None]:
# Run the current `model` over every OOS loader and write one CSV per set.
# Each CSV row is one sequence: its decoded text and the list of binary
# predictions for its punctuation tokens (these sets carry no labels).
import torch
from tqdm import tqdm
import pandas as pd

oos_loader = [dirty_loader, exclamation_loader, liferacereligion_ds_loader, longtoshort_loader]
# NOTE(review): 'lifetoreligion' does not match the `liferacereligion` dataset
# name; it is kept because it determines the output file name.
names = ['dirty', 'exclamation', 'lifetoreligion', 'longtoshort']
assert len(names) == len(oos_loader), "each OOS loader needs exactly one name"

# NOTE(review): the forward call below does not pass `punct_pad_mask`, which
# LightWeightRoBERTa.forward requires; run this cell with a model whose forward
# accepts exactly these three arguments.
model.eval()

for name, loader in zip(names, oos_loader):
    all_preds = []
    all_texts = []

    with torch.no_grad():
        for batch in tqdm(loader, desc=name):
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                punctuation_mask=batch["punctuation_mask"],
            )
            # One logit per punctuation slot; threshold the sigmoid at 0.5.
            preds = (torch.sigmoid(outputs.logits) >= 0.5).long()

            # Number of punctuation tokens per sequence; slots past this
            # count are padding and are dropped from the saved predictions.
            punct_counts = batch["punctuation_mask"].sum(dim=1).tolist()
            for i, num_punct in enumerate(punct_counts):
                # int() keeps the slice valid even if the mask is a float tensor.
                all_preds.append(preds[i, :int(num_punct)].tolist())
                all_texts.append(
                    tokenizer.decode(batch["input_ids"][i], skip_special_tokens=True)
                )

    df = pd.DataFrame({
        "input_text": all_texts,
        "predicted_labels": all_preds,
    })
    df.to_csv(f"transformer_{name}_results.csv", index=False)


# Lightweight Model

In [None]:
# renne stuff
class LightWeightRoBERTa(PreTrainedModel):
    """Frozen RoBERTa encoder with a small MLP that scores each punctuation token.

    For every punctuation token the encoder states of the tokens in a window of
    ``window_radius`` positions on each side are concatenated (zero-filled where
    the window runs past the sequence or onto padding) and fed to an MLP that
    emits one logit.

    Args:
        roberta_model: encoder whose ``config`` and ``last_hidden_state`` are
            used; all of its parameters are frozen here.
        hidden_dim: width of the MLP hidden layer.
        dropout: dropout probability inside the MLP.
        window_radius: number of context tokens taken on each side of a
            punctuation token (default 5, i.e. an 11-token window).
    """

    def __init__(self, roberta_model, hidden_dim=32, dropout=0.1, window_radius=5):
        # Read the config from the encoder that was passed in, not from a
        # module-level `roberta` variable that may be missing or different.
        super().__init__(roberta_model.config)
        self.roberta = roberta_model
        self.hidden_size = roberta_model.config.hidden_size
        self.window_radius = window_radius
        self.window_size = 2 * window_radius + 1

        # The encoder is used as a fixed feature extractor.
        for p in self.roberta.parameters():
            p.requires_grad = False

        # MLP classifier: flattened context window -> 1 logit
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size * self.window_size, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        punctuation_mask,
        punct_pad_mask,
        labels=None,
        **kwargs
    ):
        """Score every punctuation slot and optionally compute the masked BCE loss.

        Args:
            input_ids, attention_mask: encoder inputs, (B, seq_len).
            punctuation_mask: (B, seq_len); non-zero entries mark punctuation
                tokens. The k-th non-zero entry of a row fills slot k.
            punct_pad_mask: (B, max_punct); non-zero for real punctuation slots.
                Only its width and (for the loss) its values are used.
            labels: optional (B, max_punct) binary targets.

        Returns:
            TokenClassifierOutput with ``logits`` of shape (B, max_punct) and
            ``loss`` (None when ``labels`` is None).
        """
        # 1. RoBERTa embeddings
        roberta_out = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=False,
            return_dict=True,
        )
        hidden = roberta_out.last_hidden_state  # (B, seq_len, H)
        batch_size, _, hidden_dim = hidden.size()

        max_punct = punct_pad_mask.size(1)
        radius = self.window_radius

        # 2. Zero out padded tokens, then pad `radius` zero rows on both ends of
        #    the sequence axis. Original position q now sits at padded index
        #    q + radius, so the window around p is padded[p : p + window_size]
        #    and the punctuation token always lands in the centre slot, even
        #    when the window is clipped at the start of the sequence.
        masked_hidden = hidden * attention_mask.unsqueeze(-1).to(hidden.dtype)
        padded_hidden = nn.functional.pad(masked_hidden, (0, 0, radius, radius))
        offsets = torch.arange(self.window_size, device=hidden.device)

        local_context = hidden.new_zeros(
            batch_size, max_punct, self.window_size * hidden_dim
        )

        # 3. Gather the windows of each example's punctuation tokens.
        for i in range(batch_size):
            idxs = punctuation_mask[i].nonzero(as_tuple=True)[0]  # (num_punct_i,)
            n_punct = idxs.size(0)
            if n_punct == 0:
                continue  # nothing to fill; stays zeros and will be masked out

            window_idx = idxs.unsqueeze(1) + offsets      # (num_punct_i, window_size)
            windows = padded_hidden[i][window_idx]        # (num_punct_i, window_size, H)
            local_context[i, :n_punct] = windows.reshape(n_punct, -1)

        # 4. MLP classifier -> logits
        logits = self.mlp(local_context).squeeze(-1)  # (B, max_punct)

        # 5. Masked BCE loss on punctuation slots
        loss = None
        if labels is not None:
            loss_each = nn.functional.binary_cross_entropy_with_logits(
                logits, labels.float(), reduction="none"
            )  # (B, max_punct)

            # Only real punctuation slots count; clamp the denominator so a
            # batch without any punctuation yields 0 instead of NaN.
            slot_mask = punct_pad_mask.to(loss_each.dtype)
            loss = (loss_each * slot_mask).sum() / slot_mask.sum().clamp(min=1.0)

        return TokenClassifierOutput(loss=loss, logits=logits)


## Transformer model Training

In [None]:
#load the preprocessed full training set
# !find "/content/drive/My Drive/NLP Project/data/train_data" -name "preprocessed_train_full.zip"
!unzip "/content/drive/My Drive/NLP Project/data/train_data/preprocessed_train_full.zip" -d /content/extracted/
#!unzip preprocessed_train_full.zip -d /content/extracted/

Archive:  /content/drive/My Drive/NLP Project/data/train_data/preprocessed_train_full.zip
   creating: /content/extracted/content/my_processed_dataset/
  inflating: /content/extracted/content/my_processed_dataset/state.json  
  inflating: /content/extracted/content/my_processed_dataset/dataset_info.json  
  inflating: /content/extracted/content/my_processed_dataset/data-00000-of-00001.arrow  


In [None]:
from datasets import load_from_disk

# Open the extracted training split and list the fields of its first example.
train_ds = load_from_disk("/content/extracted/content/my_processed_dataset")
first_train_example = train_ds[0]
print(first_train_example.keys())

dict_keys(['input_ids', 'attention_mask', 'punctuation_mask', 'punctuation_labels'])


In [None]:
# Build the RoBERTa encoder and the classifier that sits on top of it.

from torch.utils.data import DataLoader

# attn_implementation="eager" selects the plain attention computation instead
# of the faster default sdpa mode or flash_attention_2; it was chosen because
# the attention implementation used in this project is custom.
roberta = RobertaModel.from_pretrained(
    "roberta-base",
    attn_implementation="eager",
)

# Alternative head, currently disabled:
# model = TransformerRoBERTa(
#     roberta_model =roberta,
#     num_layers=3, #can adjust
#     num_heads=8, #can adjust
# )

# hidden_dim and dropout are tunable.
model = LightWeightRoBERTa(roberta_model=roberta, hidden_dim=128, dropout=0.1)
def count_parameters(model):
    """Print the total and trainable parameter counts of ``model``.

    Args:
        model: object whose ``parameters()`` yields tensors exposing
            ``numel()`` and ``requires_grad``.

    Returns:
        None; both counts are written to stdout with thousands separators.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")
count_parameters(model)

# Training hyperparameters
batch_size = 8
num_epochs = 3
learning_rate = 5e-5

# Batches are shuffled for training and assembled by the labelled collator.
collator = PunctuationCollator()
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collator,
)

# Only parameters that still require gradients are handed to the optimizer.
trainable_params = (p for p in model.parameters() if p.requires_grad)
optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total parameters: 125,727,233
Trainable parameters: 1,081,601
None


In [None]:
# Train the model for `num_epochs` epochs and report the mean batch loss of each.
from tqdm.auto import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Loop over the configured number of epochs so the loop and the
# "Epoch i/num_epochs" progress label cannot drift apart.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch in progress:

        # Move to GPU if available
        batch = {k: v.to(device) for k, v in batch.items()}

        # Forward pass
        optimizer.zero_grad()
        output = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            punctuation_mask=batch["punctuation_mask"],
            punct_pad_mask=batch["punct_pad_mask"],
            labels=batch["punctuation_labels"]
        )
        loss = output.loss

        # Backprop
        loss.backward()

        # gradient clipping improves stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # Logging
        batch_loss = loss.item()
        total_loss += batch_loss
        progress.set_postfix({"batch_loss": batch_loss})

    # max(..., 1) avoids a ZeroDivisionError if the loader is empty.
    avg_loss = total_loss / max(len(train_loader), 1)
    print(f"\nEpoch {epoch+1} finished. Average Loss = {avg_loss:.4f}\n")

Epoch 1/3:   0%|          | 0/404 [00:00<?, ?it/s]


Epoch 1 finished. Average Loss = 0.5309



Epoch 2/3:   0%|          | 0/404 [00:00<?, ?it/s]


Epoch 2 finished. Average Loss = 0.5140



Epoch 3/3:   0%|          | 0/404 [00:00<?, ?it/s]


Epoch 3 finished. Average Loss = 0.5052



# Evaluation

In [None]:
# load the testing data from shared drive into collab notebook, then unzip
# !unzip "/content/drive/My Drive/NLP Project/data/test_data/test_processed.zip" -d /content/extracted/
!unzip test_processed.zip -d /content/extracted/

Archive:  /content/drive/My Drive/NLP Project/data/test_data/test_processed.zip
   creating: /content/extracted/content/test_processed/
  inflating: /content/extracted/content/test_processed/dataset_info.json  
  inflating: /content/extracted/content/test_processed/data-00000-of-00001.arrow  
  inflating: /content/extracted/content/test_processed/state.json  


In [None]:
from datasets import load_from_disk

# Open the extracted test split and list the fields of its first example.
test_ds = load_from_disk("/content/extracted/content/test_processed")
first_test_example = test_ds[0]
print(first_test_example.keys())

dict_keys(['input_ids', 'attention_mask', 'punctuation_mask', 'punctuation_labels'])


In [None]:
# Loader for the held-out test set.
collator = PunctuationCollator()
batch_size = 8

test_loader = DataLoader(
    test_ds,
    batch_size=batch_size,
    # Keep dataset order: shuffling does not change the metrics but makes the
    # row order of the saved predictions CSV differ between runs.
    shuffle=False,
    collate_fn=collator,
)

In [None]:
# Evaluate `model` on the test set and save a CSV with one row per sequence:
# the decoded text, the model's predicted labels and the true labels.
import torch
from tqdm import tqdm
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score

all_preds = []
all_labels = []
all_texts = []

model.eval()

with torch.no_grad():
    for batch in tqdm(test_loader):
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # Forward pass. Pass the collator's punctuation-slot mask (the same key
        # the training loop uses), not the token-level attention mask; the
        # latter would make the model emit one logit per token position.
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            punctuation_mask=batch["punctuation_mask"],
            punct_pad_mask=batch["punct_pad_mask"],
        )

        # (B, max_punct_in_batch) logits -> binary predictions
        preds = (torch.sigmoid(outputs.logits) >= 0.5).long()

        # The model fills slot k with the k-th punctuation token of a sequence,
        # so the number of real slots is the punctuation count; later slots are
        # batch padding and must not be scored.
        # Assumes PunctuationCollator right-pads `punctuation_labels` in the
        # same slot order — TODO confirm against the collator.
        punct_counts = batch["punctuation_mask"].sum(dim=1).tolist()
        for i, count in enumerate(punct_counts):
            num_punct = int(count)

            all_preds.append(preds[i, :num_punct].tolist())
            all_labels.append(batch["punctuation_labels"][i, :num_punct].tolist())

            # Decode full sequence text
            all_texts.append(
                tokenizer.decode(batch["input_ids"][i], skip_special_tokens=True)
            )

# Compute metrics across all punctuation tokens
flat_preds = [p for seq in all_preds for p in seq]
flat_labels = [l for seq in all_labels for l in seq]

acc = accuracy_score(flat_labels, flat_preds)
# zero_division=0 matches the metric calls earlier in this notebook and avoids
# a warning when the model predicts no positives.
f1 = f1_score(flat_labels, flat_preds, zero_division=0)
print(f"Accuracy: {acc:.4f}, F1 Score: {f1:.4f}")

# Save CSV: one row per sequence
df = pd.DataFrame({
    "input_text": all_texts,
    "predicted_labels": all_preds,   # list of binary predictions per sequence
    "true_labels": all_labels        # list of true labels per sequence
})

df.to_csv("punctuation_predictions_per_sequence.csv", index=False)
print("Saved CSV: punctuation_predictions_per_sequence.csv")


100%|██████████| 9/9 [00:00<00:00, 14.11it/s]


Accuracy: 0.6875, F1 Score: 0.0115
Saved CSV: punctuation_predictions_per_sequence.csv


## Evaluation

Judge with RoBERTa/GPT.

In [None]:
# Load a sentiment classifier to use as the judge.
# NOTE: this rebinds the notebook-level `tokenizer` and `model`, replacing the
# punctuation model and tokenizer used by the cells above.
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_name = "cardiffnlp/twitter-roberta-base-sentiment-latest" # trained specifically on sentiment
# roberta-base is not used here because it has no classification head.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

use_gpu = True  # set to False to force CPU inference
if use_gpu and torch.cuda.is_available():
    model = model.to('cuda')

In [None]:
# Read the `text` and `label` columns of the input CSV into a DataFrame.
import csv

data = {"text": [], "label": []}
# TODO: "your_file.csv" is a placeholder; change it once the Google Sheets
# export path is known.
# newline="" is what the csv module expects for files it parses itself.
with open("your_file.csv", "r", encoding="utf-8", newline="") as f:
    reader = csv.DictReader(f)
    for row in reader:
        data["text"].append(row["text"])
        data["label"].append(row["label"])


df = pd.DataFrame(data)
print(df.head())

In [None]:
def predict_laugh(text):
    """Predict whether the audience laughs after ``text`` using the sentiment model.

    Uses the notebook-level ``tokenizer`` and ``model``. The probability of the
    model's "positive" class is used as the laughter confidence.

    Args:
        text: the text segment (with any preceding context) to score.

    Returns:
        (label, confidence): ``label`` is "L" when confidence > 0.5, else "O";
        ``confidence`` is the positive-class probability as a float.
    """
    # 512 is the longest input a RoBERTa-base encoder accepts; longer text is
    # truncated rather than overflowing the position embeddings.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

    # Follow the model onto whatever device it lives on instead of relying on
    # a separate GPU flag.
    model_device = getattr(model, "device", None)
    if model_device is not None:
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # Look the positive class up by name: three-class sentiment models put
    # "neutral" at index 1. Falls back to index 1 when no label names exist.
    id2label = getattr(getattr(model, "config", None), "id2label", None) or {}
    positive_idx = next(
        (int(i) for i, name in id2label.items() if str(name).lower() == "positive"),
        1,
    )

    confidence = predictions[0][positive_idx].item()
    label = "L" if confidence > 0.5 else "O"

    return label, confidence

In [None]:
def process_data(df, context_window=3):
    """Predict laughter for every row of ``df`` using preceding rows as context.

    For each row, the ``text`` of up to ``context_window`` preceding rows and of
    the row itself are joined with spaces and passed to ``predict_laugh``.

    Args:
        df: DataFrame with a ``text`` column; it is not modified.
        context_window: how many previous rows to prepend as context.

    Returns:
        A copy of ``df`` with added ``predicted`` ("L"/"O") and ``confidence``
        (float) columns.
    """
    df = df.copy()
    texts = df['text'].tolist()

    predicted = []
    confidences = []
    # Work by position so the context window is correct for any index, not
    # only the default 0..n-1 one.
    for pos in range(len(texts)):
        start = max(0, pos - context_window)
        full_text = ' '.join(texts[start:pos + 1])

        label, confidence = predict_laugh(full_text)
        predicted.append(label)
        confidences.append(confidence)

    df['predicted'] = pd.Series(predicted, index=df.index, dtype=object)
    df['confidence'] = pd.Series(confidences, index=df.index, dtype=float)
    return df

In [None]:
# Score every segment with the default context window
# (pass context_window=... to process_data to change it).
process_df = process_data(df)

has_labels = 'label' in df.columns
if has_labels:
    matches = process_df['predicted'] == process_df['label']
    accuracy = matches.mean()
    print(f"\nAccuracy: {accuracy:.2%}")

# TODO: decide where to save the accuracy.


New model:


In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import csv
import pandas as pd

# Sentiment judge: DistilBERT fine-tuned on SST-2.
model_name = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

use_gpu = True  # switch off to keep inference on the CPU

gpu_requested_and_present = use_gpu and torch.cuda.is_available()
if gpu_requested_and_present:
    model = model.to('cuda')

# Read the CSV once, then split it into the two columns that are used.
with open("your_file.csv", "r", encoding="utf-8") as f:
    rows = list(csv.DictReader(f))

data = {
    "text": [row["text"] for row in rows],
    "label": [row["label"] for row in rows],
}

df = pd.DataFrame(data)
print(df.head())

In [None]:
def predict_laugh(text):
    """Score one text segment for laughter with the notebook-level sentiment model.

    The probability of class index 1 (POSITIVE for DistilBERT SST-2; index 0 is
    NEGATIVE) is used as the laughter confidence.

    Returns:
        (label, confidence): "L" when confidence > 0.5, otherwise "O", together
        with the confidence as a float.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

    run_on_gpu = use_gpu and torch.cuda.is_available()
    if run_on_gpu:
        encoded = {name: tensor.to('cuda') for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded).logits

    class_probs = torch.nn.functional.softmax(logits, dim=-1)
    positive_prob = class_probs[0][1].item()

    if positive_prob > 0.5:
        return "L", positive_prob
    return "O", positive_prob

In [None]:
def process_data(df, context_window=3):
    """Predict laughter for each row, using earlier rows as context.

    For every index label, the ``text`` values from ``label - context_window``
    (floored at 0) through the label itself are joined with spaces and passed
    to ``predict_laugh``. The input frame is left untouched.

    Returns:
        A copy of ``df`` with ``predicted`` and ``confidence`` columns filled in.
    """
    result = df.copy()
    result['predicted'] = ''
    result['confidence'] = 0.0

    for row_label in result.index:
        first_label = max(0, row_label - context_window)
        window_texts = result.loc[first_label:row_label, 'text'].tolist()
        predicted, score = predict_laugh(' '.join(window_texts))

        result.at[row_label, 'predicted'] = predicted
        result.at[row_label, 'confidence'] = score

    return result

In [None]:
# Score every row, then report accuracy when ground-truth labels are present.
processed_df = process_data(df)

if 'label' in df.columns:
    is_correct = processed_df['predicted'] == processed_df['label']
    accuracy = is_correct.mean()
    print(f"\nAccuracy: {accuracy:.2%}")