In [1]:
!pip install transformers datasets torch




[notice] A new release of pip is available: 24.3.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


## Model Training

In [None]:
import os
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2TokenizerFast, GPT2Model, GPT2Config
from torch.optim import AdamW
import torch.nn.functional as F

class GPT2WordClassifier:
    def __init__(self, file_names, max_word_length=8, device=None, lr=1e-3):
        """
        file_names : list of CSV filenames (without .csv extension)
        Each CSV must contain columns: 'word' and 'boundary'
        """

        if isinstance(file_names, str):
            file_names = [file_names]

        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.max_word_length = max_word_length

        # Load all data from the CSV files
        self.texts, self.labels = self.get_data(file_names)

        # Tokenizer
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", add_prefix_space=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Dataset & Dataloader
        self.dataset = self.WordDataset(self.texts, self.labels, self.tokenizer, self.max_word_length)
        self.dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)

        # Model
        self.model = self.GPT2ForTokenClassification().to(self.device)

        # Optimizer
        self.optimizer = AdamW(self.model.parameters(), lr=lr)

    # =========================
    # DATA LOADING
    # =========================
    def get_data(self, file_names):

        texts = []
        labels = []
        for name in file_names:
            excel_directory = os.path.join(os.getcwd(), f"{name}.csv")
            try:
                df = pd.read_csv(excel_directory, encoding='utf-8-sig')
                words = df['word'].to_list()
                boundaries = df['boundary'].to_list()
                # keep only 0/1
                boundaries = [x if x in (0, 1) else 0 for x in boundaries]
                texts.extend(words)
                labels.extend(boundaries)
            except FileNotFoundError:
                print(f'File not found. Check if "{name}.csv" exists in {os.getcwd()}')
                raise SystemExit
        return texts, labels

    # =========================
    # DATASET
    # =========================
    class WordDataset(Dataset):
        def __init__(self, texts, labels, tokenizer, max_length):
            self.texts = texts
            self.labels = labels
            self.tokenizer = tokenizer
            self.max_length = max_length

        def __len__(self):
            return len(self.texts)

        def __getitem__(self, idx):
            word = self.texts[idx]
            label = self.labels[idx]

            encoding = self.tokenizer(
                word,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.max_length
            )

            input_ids = encoding["input_ids"].squeeze(0)
            attention_mask = encoding["attention_mask"].squeeze(0)

            token_labels = torch.full_like(input_ids, -100)
            token_labels[0] = label  # label only first token

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": token_labels
            }

    # =========================
    # MODEL
    # =========================
    class GPT2ForTokenClassification(nn.Module):
        def __init__(self, num_labels=2):
            super().__init__()
            config = GPT2Config.from_pretrained("gpt2")
            self.gpt2 = GPT2Model.from_pretrained("gpt2", config=config)

            # Freeze GPT2
            for param in self.gpt2.parameters():
                param.requires_grad = False

            self.classifier = nn.Linear(config.hidden_size, num_labels)

        def forward(self, input_ids, attention_mask=None, labels=None):
            outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
            logits = self.classifier(outputs.last_hidden_state)

            loss = None
            if labels is not None:
                loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
                loss = loss_fn(logits.view(-1, 2), labels.view(-1))
            return {"loss": loss, "logits": logits}

    # =========================
    # TRAIN
    # =========================
    def train(self, epochs=50, early_stop_loss=0.1):
        self.model.train()
        for epoch in range(epochs):
            for batch in self.dataloader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)

                outputs = self.model(input_ids, attention_mask, labels)
                loss = outputs["loss"]

                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

            print(f"Epoch {epoch+1} Loss: {loss.item():.4f}")
            if loss.item() < early_stop_loss:
                print(f"Early stop at epoch {epoch+1}, loss={loss.item():.4f}")
                break

    # =========================
    # PREDICT
    # =========================
    def predict(self, sentence):
        
        if isinstance(sentence, str):
            words = sentence.split()
        elif isinstance(sentence, list):
            words = []
            for item in sentence:
                if isinstance(item, str) and " " in item:
                    words.extend(item.split())
                else:
                    words.append(item)
        else:
            raise TypeError("sentence must be a string or a list")

        encoding = self.tokenizer(
            words,
            is_split_into_words=True,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=len(words)
        )

        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)

        self.model.eval()
        with torch.no_grad():
            logits = self.model(input_ids, attention_mask)["logits"]
            probs = F.softmax(logits, dim=-1)

        word_ids = encoding.word_ids(0)
        predictions = {}
        
        for i, w_id in enumerate(word_ids):
            
            if w_id is None:
                continue
            # one key per word, value = probability of being a boundary
            predictions[f'{i} {words[i]}'] = float(probs[0, i, 1])

        return predictions 



In [4]:
# List of CSV files without the .csv extension
files = ["sherlock_LLM", "merlin_LLM"]


# Initialize classifier
classifier = GPT2WordClassifier(files)

# Train
classifier.train(epochs=3, early_stop_loss=0.01)

Epoch 1 Loss: 0.0095
Early stop at epoch 1, loss=0.0095


## Boundary prediction

In [32]:
words_sherlock, boundaries_sherlock = classifier.get_data(['sherlock_LLM'])
words_merlin, boundaries_merlin = classifier.get_data(['merlin_LLM'])

preds_sherlock1 = classifier.predict(words_sherlock[0:1000])
preds_sherlock2 = classifier.predict(words_sherlock[1000:2000])
preds_sherlock3 = classifier.predict(words_sherlock[2000:2697])

preds_merlin1 = classifier.predict(words_merlin[0:1000])
preds_merlin2 = classifier.predict(words_merlin[1000:2000])
preds_merlin3 = classifier.predict(words_merlin[2000:2252])


preds = {}
current_index = 0

for d in (preds_sherlock1, preds_sherlock2, preds_sherlock3, preds_merlin1, preds_merlin2, preds_merlin3):
    for key, value in d.items():
        word = key.split(" ", 1)[1]   
        preds[f"{current_index} {word}"] = value
        current_index += 1


print(preds)

boundaries = boundaries_sherlock + boundaries_merlin
print(boundaries)


{'0 okay': 0.05407861992716789, '1 so': 0.005995710380375385, '2 they': 0.00012144700303906575, '3 began': 0.02673395350575447, '4 with': 0.019445333629846573, '5 like': 0.006302922498434782, '6 a': 0.0010435370495542884, '7 dream': 0.025655804201960564, '8 sequence': 0.11041971296072006, '9 of': 0.028049923479557037, '10 um': 0.018856655806303024, '11 a': 0.005316346418112516, '12 shootout': 0.11292082071304321, "13 it's": 0.008787707425653934, '14 during': 0.007454643025994301, '15 the': 0.07794184982776642, '16 day': 0.005116475746035576, '17 it': 0.040664296597242355, '18 looks': 0.012687728740274906, '19 like': 0.009927153587341309, '20 its': 0.041540224105119705, '21 in': 0.014269078150391579, '22 a': 0.022251542657613754, '23 grassy': 0.002816808642819524, '24 field': 0.015042685903608799, '25 with': 0.007402464281767607, '26 some': 0.021394331008195877, '27 sort': 0.010295791551470757, '28 of': 0.01115910429507494, '29 disheveled': 3.7371282815001905e-05, '30 like': 0.006965480

In [47]:

preds_values = list(preds.values())


indexes_of_ones = [i for i, v in enumerate(boundaries) if v == 1]
selected_preds_1 = [preds_values[i] for i in indexes_of_ones]
mean_ones = sum(selected_preds_1) / len(selected_preds_1)


indexes_of_zero = [i for i, v in enumerate(boundaries) if v == 0]
selected_preds_0 = [preds_values[i] for i in indexes_of_zero]
mean_zeros = sum(selected_preds_0) / len(selected_preds_0)


print('mean probability for 1 boundary:', mean_ones)
print('mean probability for 0 boundary:', mean_zeros)

mean probability for 1 boundary: 0.018231799420458564
mean probability for 0 boundary: 0.014611917023661694


# Magnifying probabilities

In [48]:
import math


new_word_probs = {word: 1 - math.exp(-value*100) for word, value in preds.items()}

print(new_word_probs)


{'0 okay': 0.9955187891619227, '1 so': 0.45095289408934636, '2 they': 0.012071251071490319, '3 began': 0.930982511541247, '4 with': 0.8569460379606133, '5 like': 0.4675638261214752, '6 a': 0.09909341406421235, '7 dream': 0.9231254510504855, '8 sequence': 0.9999839847841351, '9 of': 0.939492765825155, '10 um': 0.84827196142915, '11 a': 0.4123564026516183, '12 shootout': 0.9999875287189289, "13 it's": 0.5847068998282334, '14 during': 0.5254860692666317, '15 the': 0.9995878754641366, '16 day': 0.4004929678771062, '17 it': 0.982861530429317, '18 looks': 0.718823550914633, '19 like': 0.6294309043807339, '20 its': 0.9842988671081082, '21 in': 0.7599499450366982, '22 a': 0.8919492511819622, '23 grassy': 0.24548556137842203, '24 field': 0.7778202611796172, '25 with': 0.5230036443056528, '26 some': 0.8822784396349234, '27 sort': 0.6428427632697824, '28 of': 0.6723831289859955, '29 disheveled': 0.003730153908351408, '30 like': 0.5016975355132088, '31 half': 0.9483308007333878, '32 decomposed': 0

In [50]:

preds_values = list(new_word_probs.values())


indexes_of_ones = [i for i, v in enumerate(boundaries) if v == 1]
selected_preds_1 = [preds_values[i] for i in indexes_of_ones]
mean_ones = sum(selected_preds_1) / len(selected_preds_1)


indexes_of_zero = [i for i, v in enumerate(boundaries) if v == 0]
selected_preds_0 = [preds_values[i] for i in indexes_of_zero]
mean_zeros = sum(selected_preds_0) / len(selected_preds_0)


print('mean probability for 1 boundary:', mean_ones)
print('mean probability for 0 boundary:', mean_zeros)
print(mean_ones + mean_zeros)

mean probability for 1 boundary: 0.5349009728593934
mean probability for 0 boundary: 0.47111128639204386
1.0060122592514373
