In [9]:
!pip install transformers datasets torch




[notice] A new release of pip is available: 24.3.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip





## Model Training

In [237]:
import os
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2TokenizerFast, GPT2Model, GPT2Config
from torch.optim import AdamW
import torch.nn.functional as F

class GPT2WordClassifier:
    def __init__(self, file_names, max_word_length=8, device=None, lr=1e-3):
        """
        file_names : list of CSV filenames (without .csv extension)
        Each CSV must contain columns: 'word' and 'boundary'
        """

        if isinstance(file_names, str):
            file_names = [file_names]

        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.max_word_length = max_word_length

        # Load all data from the CSV files
        self.texts, self.labels = self.get_data(file_names)

        # Tokenizer
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", add_prefix_space=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Dataset & Dataloader
        self.dataset = self.WordDataset(self.texts, self.labels, self.tokenizer, self.max_word_length)
        self.dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)

        # Model
        self.model = self.GPT2ForTokenClassification().to(self.device)

        # Optimizer
        self.optimizer = AdamW(self.model.parameters(), lr=lr)

    # =========================
    # DATA LOADING
    # =========================
    def get_data(self, file_names):

        texts = []
        labels = []
        for name in file_names:
            excel_directory = os.path.join(os.getcwd(), f"{name}.csv")
            try:
                df = pd.read_csv(excel_directory, encoding='utf-8-sig')
                words = df['word'].to_list()
                boundaries = df['boundary'].to_list()
                # keep only 0/1
                boundaries = [x if x in (0, 1) else 0 for x in boundaries]
                texts.extend(words)
                labels.extend(boundaries)
            except FileNotFoundError:
                print(f'File not found. Check if "{name}.csv" exists in {os.getcwd()}')
                raise SystemExit
        return texts, labels

    # =========================
    # DATASET
    # =========================
    class WordDataset(Dataset):
        def __init__(self, texts, labels, tokenizer, max_length):
            self.texts = texts
            self.labels = labels
            self.tokenizer = tokenizer
            self.max_length = max_length

        def __len__(self):
            return len(self.texts)

        def __getitem__(self, idx):
            word = self.texts[idx]
            label = self.labels[idx]

            encoding = self.tokenizer(
                word,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.max_length
            )

            input_ids = encoding["input_ids"].squeeze(0)
            attention_mask = encoding["attention_mask"].squeeze(0)

            token_labels = torch.full_like(input_ids, -100)
            token_labels[0] = label  # label only first token

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": token_labels
            }

    # =========================
    # MODEL
    # =========================
    class GPT2ForTokenClassification(nn.Module):
        def __init__(self, num_labels=2):
            super().__init__()
            config = GPT2Config.from_pretrained("gpt2")
            self.gpt2 = GPT2Model.from_pretrained("gpt2", config=config)

            # Freeze GPT2
            for param in self.gpt2.parameters():
                param.requires_grad = False

            self.classifier = nn.Linear(config.hidden_size, num_labels)

        def forward(self, input_ids, attention_mask=None, labels=None):
            outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
            logits = self.classifier(outputs.last_hidden_state)

            loss = None
            if labels is not None:
                loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
                loss = loss_fn(logits.view(-1, 2), labels.view(-1))
            return {"loss": loss, "logits": logits}

    # =========================
    # TRAIN
    # =========================
    def train(self, epochs=50, early_stop_loss=0.1):
        self.model.train()
        for epoch in range(epochs):
            for batch in self.dataloader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)

                outputs = self.model(input_ids, attention_mask, labels)
                loss = outputs["loss"]

                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

            print(f"Epoch {epoch+1} Loss: {loss.item():.4f}")
            if loss.item() < early_stop_loss:
                print(f"Early stop at epoch {epoch+1}, loss={loss.item():.4f}")
                break

    # =========================
    # PREDICT
    # =========================
    def predict(self, sentence):
        
        if isinstance(sentence, str):
            words = sentence.split()
        elif isinstance(sentence, list):
            words = []
            for item in sentence:
                if isinstance(item, str) and " " in item:
                    words.extend(item.split())
                else:
                    words.append(item)
        else:
            raise TypeError("sentence must be a string or a list")

        encoding = self.tokenizer(
            words,
            is_split_into_words=True,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=len(words)
        )

        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)

        self.model.eval()
        with torch.no_grad():
            logits = self.model(input_ids, attention_mask)["logits"]
            probs = F.softmax(logits, dim=-1)

        word_ids = encoding.word_ids(0)
        predictions = {}
        
        for i, w_id in enumerate(word_ids):
            
            if w_id is None:
                continue
            # one key per word, value = probability of being a boundary
            predictions[f'{i} {words[i]}'] = float(probs[0, i, 1])

        return predictions 



In [None]:
# List of CSV files without the .csv extension
# files = ["sherlock_LLM", "merlin_LLM"]

files = "sherlock_LLM"
# Initialize classifier
classifier = GPT2WordClassifier(files)

# Train
classifier.train(epochs=3, early_stop_loss=0.01)

## Boundary prediction

In [239]:
new_sentence, y = classifier.get_data(['sherlock_LLM'])
new_sentence = new_sentence[209:1000]


preds = classifier.predict(new_sentence)
print(preds)
print(len(preds))



{'0 um': 0.09683682769536972, "1 they're": 2.3132106434786692e-05, '2 talking': 5.314208101481199e-05, '3 she': 0.0009286882705055177, '4 says': 9.94331858237274e-05, '5 something': 0.00025086847017519176, '6 about': 0.00013995093468111008, '7 him': 0.001134530990384519, '8 still': 0.00020489239250309765, '9 having': 0.00013040656631346792, '10 trust': 0.0006912064854986966, '11 issues': 0.0008012876496650279, '12 and': 0.00041264898027293384, '13 or': 0.00010776062845252454, '14 no': 0.00030327195418067276, '15 she': 0.0004790060920640826, '16 writes': 1.841974153649062e-05, '17 that': 0.000105347964563407, '18 and': 6.110299727879465e-05, '19 then': 4.457811883185059e-05, '20 he': 4.021573840873316e-05, '21 reads': 6.7618243519973475e-06, '22 it': 0.00023100683756638318, '23 upside': 7.705582538619637e-05, '24 down': 0.0010256353998556733, '25 and': 0.00013813561236020178, '26 asks': 2.1941474187769927e-05, '27 her': 0.0003022622549906373, '28 like': 0.00010324575850972906, '29 about

In [240]:
import math


new_word_probs = {word: 1 - math.exp(-value*100) for word, value in preds.items()}

print(new_word_probs)


{'0 um': 0.9999377083250516, "1 they're": 0.0023105372335222585, '2 talking': 0.005300112677334101, '3 she': 0.08868696792541098, '4 says': 0.0098940472322957, '5 something': 0.024774787050699465, '6 about': 0.013897617406221863, '7 him': 0.10725393392176907, '8 still': 0.020280761068707798, '9 having': 0.01295599567970318, '10 trust': 0.06678591763979191, '11 issues': 0.07700251100628652, '12 and': 0.04042509322216825, '13 or': 0.010718209078454333, '14 no': 0.0298719398428009, '15 she': 0.04677147549106231, '16 writes': 0.001840278760374292, '17 that': 0.010479499838042305, '18 and': 0.006091669810603273, '19 then': 0.004447890587690018, '20 he': 0.004013498142094507, '21 reads': 0.0006759538753758898, '22 it': 0.022835905731615314, '23 upside': 0.0076759706452674825, '24 down': 0.09747919822771889, '25 and': 0.013718591790671364, '26 asks': 0.0021917420369050866, '27 her': 0.02977398114578622, '28 like': 0.01027146037314941, '29 about': 0.012560302920245126, '30 writing': 0.01920738