In [24]:
!pip install transformers datasets torch




[notice] A new release of pip is available: 24.3.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


## Model Training

In [59]:
import os
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2TokenizerFast, GPT2Model, GPT2Config
from torch.optim import AdamW
import torch.nn.functional as F

class GPT2WordClassifier:
    def __init__(self, file_names, max_word_length=8, device=None, lr=1e-3, num_labels=5):
        """
        file_names : list of CSV filenames (without .csv extension)
        Each CSV must contain columns: 'word' and 'boundary' (0..num_labels-1)
        """
        if isinstance(file_names, str):
            file_names = [file_names]

        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.max_word_length = max_word_length
        self.num_labels = num_labels

        # Load all data from the CSV files
        self.texts, self.labels = self.get_data(file_names)

        # Tokenizer
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", add_prefix_space=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Dataset & Dataloader
        self.dataset = self.WordDataset(self.texts, self.labels, self.tokenizer, self.max_word_length)
        self.dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)

        # Model
        self.model = self.GPT2ForTokenClassification(self.num_labels).to(self.device)

        # Optimizer
        self.optimizer = AdamW(self.model.parameters(), lr=lr)

    # =========================
    # DATA LOADING
    # =========================
    def get_data(self, file_names):
        texts = []
        labels = []
        for name in file_names:
            file_path = os.path.join(os.getcwd(), f"{name}.csv")
            try:
                df = pd.read_csv(file_path, encoding='utf-8-sig')
                words = df['word'].to_list()
                boundaries = df['boundary'].to_list()
                # convert to int and clip to valid range
                boundaries = [int(x) if str(x).isdigit() else 0 for x in boundaries]
                boundaries = [min(max(0, x), self.num_labels - 1) for x in boundaries]
          
                texts.extend(words)
                labels.extend(boundaries)
            except FileNotFoundError:
                print(f'File not found. Check if "{name}.csv" exists in {os.getcwd()}')
                raise SystemExit
        return texts, labels

    # =========================
    # DATASET
    # =========================
    class WordDataset(Dataset):
        def __init__(self, texts, labels, tokenizer, max_length):
            self.texts = texts
            self.labels = labels
            self.tokenizer = tokenizer
            self.max_length = max_length

        def __len__(self):
            return len(self.texts)

        def __getitem__(self, idx):
            word = self.texts[idx]
            label = self.labels[idx]

            encoding = self.tokenizer(
                word,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
                add_special_tokens=False  # avoid extra tokens
            )

            input_ids = encoding["input_ids"].squeeze(0)
            attention_mask = encoding["attention_mask"].squeeze(0)

            # label only the first token, rest = -100
            token_labels = torch.full_like(input_ids, -100)
            token_labels[0] = label

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": token_labels
            }

    # =========================
    # MODEL
    # =========================
    class GPT2ForTokenClassification(nn.Module):
        def __init__(self, num_labels=5):
            super().__init__()
            config = GPT2Config.from_pretrained("gpt2")
            self.gpt2 = GPT2Model.from_pretrained("gpt2", config=config)

            # Freeze GPT2
            for param in self.gpt2.parameters():
                param.requires_grad = False

            self.classifier = nn.Linear(config.hidden_size, num_labels)

        def forward(self, input_ids, attention_mask=None, labels=None):
            outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
            logits = self.classifier(outputs.last_hidden_state)

            loss = None
            if labels is not None:
                # logits: [batch, seq_len, num_labels] â†’ CrossEntropy expects [batch, num_labels, seq_len]
                loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
                loss = loss_fn(logits.permute(0, 2, 1), labels)
            return {"loss": loss, "logits": logits}

    # =========================
    # TRAIN
    # =========================
    def train(self, epochs=50, early_stop_loss=0.1):
        self.model.train()
        for epoch in range(epochs):
            for batch in self.dataloader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)

                outputs = self.model(input_ids, attention_mask, labels)
                loss = outputs["loss"]

                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

            print(f"Epoch {epoch+1} Loss: {loss.item():.4f}")
            if loss.item() < early_stop_loss:
                print(f"Early stop at epoch {epoch+1}, loss={loss.item():.4f}")
                break

    # =========================
    # PREDICT
    # =========================
    def predict(self, sentence):
        if isinstance(sentence, str):
            words = sentence.split()
        elif isinstance(sentence, list):
            words = []
            for item in sentence:
                if isinstance(item, str) and " " in item:
                    words.extend(item.split())
                else:
                    words.append(item)
        else:
            raise TypeError("sentence must be a string or a list")

        encoding = self.tokenizer(
            words,
            is_split_into_words=True,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=len(words)
        )

        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)

        self.model.eval()
        with torch.no_grad():
            logits = self.model(input_ids, attention_mask)["logits"]
            probs = F.softmax(logits, dim=-1)

        word_ids = encoding.word_ids(0)
        predictions = {}

        for i, w_id in enumerate(word_ids):
            if w_id is None:
                continue
            # store all probabilities
            predictions[f"{i} {words[i]}"] = {f"b{c}": float(probs[0, i, c]) for c in range(self.num_labels)}

        return predictions



In [60]:
# List of CSV files without the .csv extension
files = ["sherlock_LLM", "merlin_LLM"]


# Initialize classifier
classifier = GPT2WordClassifier(files)

# Train
classifier.train(epochs=3, early_stop_loss=0.01)

Epoch 1 Loss: 2.0901
Epoch 2 Loss: 0.0463
Epoch 3 Loss: 0.6615


## Boundary prediction

In [90]:
words_sherlock, boundaries_sherlock = classifier.get_data(['sherlock_LLM'])
words_merlin, boundaries_merlin = classifier.get_data(['merlin_LLM'])

preds_sherlock1 = classifier.predict(words_sherlock[0:1000])
preds_sherlock2 = classifier.predict(words_sherlock[1000:2000])
preds_sherlock3 = classifier.predict(words_sherlock[2000:2697])

preds_merlin1 = classifier.predict(words_merlin[0:1000])
preds_merlin2 = classifier.predict(words_merlin[1000:2000])
preds_merlin3 = classifier.predict(words_merlin[2000:2252])


preds = {}
current_index = 0

for d in (preds_sherlock1, preds_sherlock2, preds_sherlock3, preds_merlin1, preds_merlin2, preds_merlin3):
    for key, value in d.items():
        word = key.split(" ", 1)[1]   
        preds[f"{current_index} {word}"] = value
        current_index += 1


# print(preds)

boundaries = boundaries_sherlock + boundaries_merlin
# print(boundaries)


In [91]:

preds_values = list(preds.values())

for b in range(5):

    indexes = [i for i, v in enumerate(boundaries) if v == b]
    selected_preds = [preds_values[i] for i in indexes]
    relevent_preds = [d[f"b{b}"] for d in selected_preds]

    mean_preds = sum(relevent_preds) / len(relevent_preds)

    
    print(f'Mean probability for {b} boundary:', mean_preds)


Mean probability for 0 boundary: 0.8379610048527389
Mean probability for 1 boundary: 0.1646124443388905
Mean probability for 2 boundary: 0.014356996173809986
Mean probability for 3 boundary: 0.00016643648040415746
Mean probability for 4 boundary: 0.00019702845987170677


# Magnifying probabilities (ignore here)

In [89]:
import math


new_word_probs = {word: 1 - math.exp(-value*100) for word, value in preds.items()}

print(new_word_probs)


TypeError: bad operand type for unary -: 'dict'

In [None]:

preds_values = list(new_word_probs.values())


indexes_of_ones = [i for i, v in enumerate(boundaries) if v == 1]
selected_preds_1 = [preds_values[i] for i in indexes_of_ones]
mean_ones = sum(selected_preds_1) / len(selected_preds_1)


indexes_of_zero = [i for i, v in enumerate(boundaries) if v == 0]
selected_preds_0 = [preds_values[i] for i in indexes_of_zero]
mean_zeros = sum(selected_preds_0) / len(selected_preds_0)


print('mean probability for 1 boundary after magnification:', mean_ones)
print('mean probability for 0 boundary after magnification:', mean_zeros)
print('total probability after magnification:', mean_ones + mean_zeros)

mean probability for 1 boundary after magnification: 0.5349009728593934
mean probability for 0 boundary after magnification: 0.47111128639204386
total probability after magnification 1.0060122592514373
