# LLM-Approach: `B2B`-Classification

## Load packages and set the working directory

In [None]:
import os

import pandas as pd
from datasets import load_dataset
from datasets import DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)

In [None]:
# Work from the project root so the relative "data/..." paths used below resolve.
project_root = '/Users/janlinzner/Projects/thesis-spatial-seed-syndication'
os.chdir(project_root)

## Pre-process data

Load two datasets: one for `training` and one for `prediction`

In [None]:
# CSV sources for the two splits. The "train" file carries the description
# text and the B2B label; only the description is used from the "predict" file.
data_files = {
    "train": "data/business_orientation/companies_business_focus_save.csv",  # has Description & B2B Binary
    "predict": "data/business_focus/df_missing_b2b.csv",  # description column only is used
}

ds = load_dataset("csv", data_files=data_files)

# Rename to the column names the tokenization step reads:
# "text" for the model input and "label" for the target.
ds["train"] = (
    ds["train"]
    .rename_column("Description", "text")
    .rename_column("B2B Binary", "label")
)

# The cleanup step for the prediction file renames "organization_description"
# to "Description", so the raw file may still use the former name. Accept
# either one here instead of failing on a missing "Description" column.
predict_text_column = (
    "Description"
    if "Description" in ds["predict"].column_names
    else "organization_description"
)
ds["predict"] = ds["predict"].rename_column(predict_text_column, "text")

Clean the prediction dataset and only keep the description

In [None]:
# Reduce the raw prediction file to the single description column and store
# the result as a separate "minimal" CSV.
minimal_predict_path = "data/business_orientation/df_missing_b2b_minimal.csv"

df_predict = pd.read_csv(data_files["predict"])
# No-op when the file already uses "Description" as the column name.
df_predict = df_predict.rename(columns={"organization_description": "Description"})
df_predict = df_predict[["Description"]]
df_predict.to_csv(minimal_predict_path, index=False)

data_files["predict"] = minimal_predict_path

# `ds` was already loaded from the raw file, so repointing `data_files` alone
# has no effect on it. Reload the prediction split from the cleaned file so
# that this is the data that actually gets tokenized and scored. Row order and
# row count are unchanged, so predictions still line up with the raw file.
ds["predict"] = load_dataset(
    "csv",
    data_files={"predict": data_files["predict"]},
)["predict"].rename_column("Description", "text")

## Set Model

We fine-tune the weights of the `distilbert-base-uncased` model ([Link to Hugging Face](https://huggingface.co/distilbert/distilbert-base-uncased)).

In [None]:
# Checkpoint name shared by the tokenizer and the classification model.
model_name = "distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Two output classes, matching the binary B2B label.
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

## Training Pre-Processing

Tokenization

In [None]:
def preprocess_train(batch):
    """Tokenize a batch of labeled descriptions.

    Reads ``batch["text"]`` and ``batch["label"]``. Every sequence is
    truncated or padded to exactly 128 tokens. Returns the tokenizer output
    with the targets added under the ``labels`` key.
    """
    encoded = tokenizer(
        batch["text"],
        max_length=128,
        padding="max_length",
        truncation=True,
    )
    encoded["labels"] = batch["label"]
    return encoded

# Raw columns to drop once the batch has been tokenized; the label is kept
# under the "labels" key written by preprocess_train.
raw_train_columns = ["Organization Name", "Organization Name URL", "text", "label"]

ds["train"] = ds["train"].map(
    preprocess_train,
    batched=True,
    remove_columns=raw_train_columns,
)

# Hold out 20% of the labeled rows for validation (fixed seed for a reproducible split).
train_test_split = ds["train"].train_test_split(seed=42, test_size=0.2)

# Rebuild the dataset dict: the held-out "test" part becomes "validation",
# and the prediction split is carried over unchanged.
splits = {
    "train": train_test_split["train"],
    "validation": train_test_split["test"],
    "predict": ds["predict"],
}
ds = DatasetDict(splits)

In [None]:
def preprocess_predict(batch):
    """Tokenize a batch of unlabeled descriptions.

    Reads ``batch["text"]`` and returns the tokenizer output, with every
    sequence truncated or padded to exactly 128 tokens.
    """
    tokenizer_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 128,
    }
    return tokenizer(batch["text"], **tokenizer_options)

# Tokenize the prediction split in batches and drop the raw text afterwards.
ds["predict"] = ds["predict"].map(preprocess_predict, batched=True, remove_columns=["text"])

In [None]:
# Collator that pads each batch using the tokenizer loaded above.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

## Definition of training parameters and trainer

In [None]:
training_args = TrainingArguments(
    output_dir="distilbert_finetuned_vc",
    seed=42,
    # Optimisation: batches of 8 per device with 2 accumulation steps.
    num_train_epochs=1,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    # Log, evaluate and checkpoint on the same 100-step cadence.
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=1,
    # After training, reload the checkpoint with the lowest validation loss.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

In [None]:
# Wire the model, the training configuration and the tokenized splits together.
# NOTE(review): newer transformers releases appear to deprecate `tokenizer=`
# in favour of `processing_class=` — confirm against the installed version.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
)

## Training process and validation

In [None]:
# Fine-tune, then score the model once more on the validation split.
trainer.train()

validation_metrics = trainer.evaluate()
validation_loss = validation_metrics["eval_loss"]
print("Validation Loss:", validation_loss)

## Labelling and result export

In [None]:
# Score the unlabeled descriptions.
preds = trainer.predict(ds["predict"])

# Pick the highest-scoring class index along the last axis for each row.
pred_labels = preds.predictions.argmax(axis=-1)

In [None]:
# Attach the predicted labels to the original rows and write the result.
# NOTE(review): the prediction split was first loaded from
# "data/business_focus/df_missing_b2b.csv", while this reads the same file
# name from "data/business_orientation/" — confirm both point to the same
# rows, since the labels are joined purely by row position.
source_path = "data/business_orientation/df_missing_b2b.csv"
output_path = "data/business_orientation/companies_business_focus_llm.csv"

df = pd.read_csv(source_path)
df["pred_label"] = pred_labels
df.to_csv(output_path, index=False)

print("✅ Done — predictions saved to companies_business_focus_llm.csv")