In [None]:
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification, AutoConfig, logging
from seqeval.metrics import classification_report
import torch
import json
import os

In [None]:
# Directory that the fine-tuned model and tokenizer are written to.
fine_tuned_version = "./astrobert-ner-finetuned_2"

In [None]:
# Module-level sample list (starts empty).
samples = list()
# Ten consecutive circular ids starting at 21916; both names are bound to the
# same range object.
sub_cirs = range(21916, 21916 + 10)
circular_list = sub_cirs
def read_bio_files(filepaths):
    """Parse token/tag files into one sample per file.

    Every non-blank line is split on whitespace; the first field is taken as
    the token and the second as its tag. Lines with fewer than two fields are
    skipped. Blank lines are ignored, so a file is never split into several
    samples.

    Args:
        filepaths: Iterable of paths to UTF-8 encoded annotation files.

    Returns:
        A new list of ``{"tokens": [...], "ner_tags": [...]}`` dicts, one per
        file, in the order the paths were given.

    Raises:
        OSError: If a file cannot be opened.
    """
    # Build a fresh list on every call. Appending to a module-level list made
    # the result grow each time the calling cell was re-run.
    parsed = []
    for filepath in filepaths:
        tokens, labels = [], []
        with open(filepath, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                # Blank lines yield no fields; malformed lines yield only one.
                if len(fields) >= 2:
                    tokens.append(fields[0])
                    labels.append(fields[1])
        parsed.append({"tokens": tokens, "ner_tags": labels})
    return parsed

In [None]:
import os
print(os.getcwd())
foldername = "manual_annotation_1"
# One annotation file per circular id, resolved relative to the working directory.
filepaths = [f"./{foldername}/{cir}.bio" for cir in circular_list]
bio_file_data = read_bio_files(filepaths)
# Shallow copy: the list is new, the sample dicts inside it are shared.
train_samples = list(bio_file_data)
print(len(train_samples), train_samples[0])


In [None]:
# Every distinct tag that occurs in the training samples, in sorted order.
unique_labels = sorted(set().union(*(sample["ner_tags"] for sample in train_samples)))
# Tag <-> integer id maps; ids follow the sorted tag order.
label2id = dict(zip(unique_labels, range(len(unique_labels))))
id2label = dict(enumerate(unique_labels))
num_labels = len(label2id)
print(label2id)


In [None]:
# Attach integer label ids to each sample (the sample dicts are modified in place).
for sample in train_samples:
    sample["labels"] = list(map(label2id.__getitem__, sample["ner_tags"]))

dataset = Dataset.from_list(train_samples)
print(len(dataset))
print(dataset[1])

In [None]:
# Load tokenizer and model
remote_model_path = 'kusha7/astrobert-gcn-tokenizer'
# The environment mapping is `os.environ`; `os.enviro` does not exist and
# raised AttributeError. `.get` yields None when the variable is unset, so the
# loaders below then run without an explicit token.
hf_token = os.environ.get("hf_token")
# Configure the classification head for this notebook's label set.
config = AutoConfig.from_pretrained(
    remote_model_path,
    token=hf_token,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

tokenizer = AutoTokenizer.from_pretrained(remote_model_path, token=hf_token)
# ignore_mismatched_sizes=True: do not fail when checkpoint weights have a
# different shape than the configured model expects.
model = AutoModelForTokenClassification.from_pretrained(
    remote_model_path,
    token=hf_token,
    config=config,
    ignore_mismatched_sizes=True,
)
# max_len = max(len(tokenizer.tokenize(" ".join(ex["tokens"]))) for ex in train_samples)
# print(max_len)


In [None]:
def tokenize_and_align_labels(example, label_all_tokens=True):
    """Tokenize one pre-split example and expand its word labels to sub-tokens.

    Args:
        example: Mapping with "tokens" (list of words) and "labels" (one
            integer label id per word).
        label_all_tokens: When True (the default, matching the previous
            behaviour) every sub-token of a word receives that word's label.
            When False only the first sub-token of each word is labelled and
            the remaining sub-tokens get -100.

    Returns:
        The tokenizer output with an added "labels" entry holding one value
        per position reported by ``word_ids()``.
    """
    tokenized = tokenizer(
        example["tokens"],
        truncation=True,
        is_split_into_words=True,
        max_length=512,
        padding="max_length",
    )
    word_labels = example["labels"]
    labels = []
    previous_word_idx = None
    for word_idx in tokenized.word_ids():
        if word_idx is None:
            # Positions that map to no input word (special tokens, padding).
            # -100 is the default ignore_index of PyTorch's cross-entropy loss.
            labels.append(-100)
        elif label_all_tokens or word_idx != previous_word_idx:
            labels.append(word_labels[word_idx])
        else:
            # Continuation sub-token of the word labelled just before it.
            labels.append(-100)
        previous_word_idx = word_idx
    tokenized["labels"] = labels
    return tokenized
# Switch the transformers logger to INFO verbosity.
logging.set_verbosity_info()
# Tokenize every example and attach its sub-token aligned labels.
dataset = dataset.map(function=tokenize_and_align_labels)
print(dataset[0])

In [None]:

# Training config
training_options = dict(
    output_dir="./results",
    logging_dir="./logs",
    logging_steps=10,
    num_train_epochs=1,
    per_device_train_batch_size=3,
    eval_strategy="no",  # optional
    save_strategy="epoch",
    # fp16=True,
)
args = TrainingArguments(**training_options)

# Collator used to assemble the tokenized examples into batches.
data_collator = DataCollatorForTokenClassification(tokenizer)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()

In [None]:
# Write the model and the tokenizer into the same output directory.
for artifact in (model, tokenizer):
    artifact.save_pretrained(fine_tuned_version)
import spacy
nlp = spacy.load("en_core_web_sm")
# id -> tag-name mapping as stored on the model config.
label_list = model.config.id2label
def predict(text, max_length=512, stride=256):
    """Tag every whitespace-separated word of ``text`` with the fine-tuned model.

    Long inputs are split into overlapping windows. Each word is reported
    exactly once, using the prediction for its first sub-token in the first
    window that contains the word.

    Args:
        text: Raw text; it is split on whitespace into words.
        max_length: Maximum number of tokenizer positions per window.
        stride: Number of positions shared by consecutive windows.

    Returns:
        List of ``(word, label)`` tuples in input order; empty when ``text``
        contains no words.
    """
    tokens = text.split()
    if not tokens:
        return []

    # Tokenize long input with overlapping chunks
    tokenized = tokenizer(
        tokens,
        return_tensors="pt",
        is_split_into_words=True,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,
        return_special_tokens_mask=True
    )

    # The model may still be in training mode after fine-tuning; switch to
    # eval mode so layers such as dropout behave deterministically.
    model.eval()

    model_input_keys = ("input_ids", "attention_mask", "token_type_ids")
    all_predictions = []
    # Words already emitted. Consecutive windows overlap by `stride` positions,
    # so without this every word in an overlap would be reported twice.
    seen_word_ids = set()

    for i in range(len(tokenized["input_ids"])):
        inputs = {
            k: v[i].unsqueeze(0).to(model.device)
            for k, v in tokenized.items()
            if k in model_input_keys
        }
        with torch.no_grad():
            logits = model(**inputs)[0]  # first element of the output is the logits

        preds = torch.argmax(logits, dim=-1)[0].tolist()
        word_ids = tokenized.word_ids(batch_index=i)

        for idx, word_id in enumerate(word_ids):
            # None marks positions that belong to no word; a word id that was
            # already seen is either a continuation sub-token or a repeat from
            # the previous window.
            if word_id is None or word_id in seen_word_ids:
                continue
            seen_word_ids.add(word_id)
            all_predictions.append((tokens[word_id], id2label[preds[idx]]))

    return all_predictions

In [None]:
def get_circulars(c_list):
    """Load the archived JSON document for each circular id.

    Args:
        c_list: Iterable of circular ids; each one is formatted into the path
            ``./archive.json/<id>.json``.

    Returns:
        List with the decoded JSON content of every file, in the order of
        ``c_list``.

    Raises:
        OSError: If a file cannot be opened.
        json.JSONDecodeError: If a file does not contain valid JSON.
    """
    json_circulars = []
    for sub_cir in c_list:
        # Decode explicitly as UTF-8 rather than the platform default encoding,
        # so non-ASCII text reads the same on every system.
        with open('./archive.json/{}.json'.format(sub_cir), 'r', encoding='utf-8') as f:
            json_circulars.append(json.load(f))
    return json_circulars

In [None]:
# Circular id(s) to run the fine-tuned model on.
to_test = [21916]
# Run inference on the CPU.
device = torch.device("cpu")
model.to(device)
print("Labels in model:", model.config.id2label)
circulars = get_circulars(to_test)
print(circulars)
predictions = predict(text=circulars[0]["body"])

print(predictions)

In [None]:
# Load the checkpoint again, without this notebook's label configuration,
# to compare its output against the fine-tuned model.
base_checkpoint = 'kusha7/astrobert-gcn-tokenizer'
base_tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
base_model = AutoModelForTokenClassification.from_pretrained(base_checkpoint)
base_id2label = base_model.config.id2label
print("Labels in model base:", base_id2label)
def predict_with_base(text, max_length=512, stride=256):
    """Tag every whitespace-separated word of ``text`` with the base model.

    Predictions are produced per tokenizer position (special tokens and
    sub-tokens included), so they are mapped back to words through
    ``word_ids`` instead of being zipped against the word list directly.
    Long inputs are processed in overlapping windows and each word is
    reported once, using the prediction for its first sub-token.

    Args:
        text: Raw text; it is split on whitespace into words.
        max_length: Maximum number of tokenizer positions per window.
        stride: Number of positions shared by consecutive windows.

    Returns:
        List of ``(word, label)`` tuples in input order; empty when ``text``
        contains no words.
    """
    tokens = text.split()
    if not tokens:
        return []

    encoded = base_tokenizer(
        tokens,
        return_tensors="pt",
        is_split_into_words=True,
        truncation=True,
        padding=True,  # windows must share one length to be stacked as tensors
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,
    )

    # Only these entries are forwarded to the model; the tokenizer output also
    # carries bookkeeping entries for the overflow windows.
    model_input_keys = ("input_ids", "attention_mask", "token_type_ids")
    results = []
    seen_word_ids = set()

    for i in range(len(encoded["input_ids"])):
        inputs = {
            k: v[i].unsqueeze(0)
            for k, v in encoded.items()
            if k in model_input_keys
        }
        with torch.no_grad():
            logits = base_model(**inputs)[0]  # first element of the output is the logits

        preds = torch.argmax(logits, dim=-1)[0].tolist()

        for idx, word_id in enumerate(encoded.word_ids(batch_index=i)):
            # Skip positions without a word, continuation sub-tokens, and
            # words already reported from the previous window.
            if word_id is None or word_id in seen_word_ids:
                continue
            seen_word_ids.add(word_id)
            results.append((tokens[word_id], base_id2label[preds[idx]]))

    return results
# Run the base checkpoint on the body of the first loaded circular.
base_preds = predict_with_base(text=circulars[0]['body'])
print(base_preds)