In [2]:
import os
from torch.utils.data import DataLoader, DistributedSampler
from transformers import T5ForConditionalGeneration, AdamW, T5Tokenizer
from peft import LoraConfig, get_peft_model

# Checkpoint shared by the model and the tokenizer so the two cannot drift apart.
MODEL_NAME = "google/flan-t5-small"
# Hugging Face access token read from the environment (None when the variable is unset).
HF_TOKEN = os.getenv("HF_ACCESS_TOKEN")
# Flip to True to wrap the model with LoRA adapters (replaces the old commented-out block).
USE_LORA = False

model = T5ForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    # torch_dtype=torch.float16,  # NOTE(review): needs `import torch`, which this cell does not do
    token=HF_TOKEN,
)

if USE_LORA:
    # Adapt only the attention query/value projections; all other weights stay frozen.
    lora_config = LoraConfig(
        task_type="SEQ_2_SEQ_LM",
        target_modules=["q", "v"],
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

# Same checkpoint and same token as the model, so gated/private checkpoints load consistently.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
    

In [6]:
import json
import os
import warnings
from pathlib import Path

import torch

from dataset_utils import T5Dataset

# Gene/species-tagged articles. Override with the TAGGED_ARTICLES_PATH env var to run elsewhere.
TAGGED_ARTICLES_PATH = Path(
    os.getenv(
        "TAGGED_ARTICLES_PATH",
        "/home/tadesa1/research/ADBMO-UNLV/t5_training/gene_species_tagged_articles.json",
    )
)


def flatten_tagged_articles(data):
    """Flatten per-article, per-section JSON into parallel text/span lists.

    Args:
        data: Mapping of article id -> mapping of section type -> section dict.
            Each section dict may hold "text" (presumably a list of strings,
            joined with single spaces; a bare string is used as-is) and
            "annotation" (a list of dicts mapping annotation id -> annotation
            with "offset" and "length").

    Returns:
        (texts, annotations): one entry per section, in iteration order.
        annotations[i] is a list of (start, end) spans with
        end = offset + length, for texts[i].

    Annotations whose offset/length are missing or not integer-convertible
    are skipped, and a single warning reports how many were dropped.

    NOTE(review): offsets are used verbatim. If they index the whole article
    or an individual passage of the "text" list rather than the space-joined
    section text, the spans will be misaligned. Confirm against the tagger's
    output format.
    """
    texts = []
    annotations = []
    skipped = 0
    for sections in data.values():
        for content in sections.values():
            raw_text = content.get("text") or []
            # Joining a bare string would insert a space between every character.
            section_text = raw_text if isinstance(raw_text, str) else " ".join(raw_text)
            spans = []
            for ann_group in content.get("annotation") or []:
                for ann in ann_group.values():
                    try:
                        start = int(ann["offset"])
                        end = start + int(ann["length"])
                    except (KeyError, TypeError, ValueError):
                        # Malformed annotation: drop it but keep count instead of hiding it.
                        skipped += 1
                        continue
                    spans.append((start, end))
            texts.append(section_text)
            annotations.append(spans)
    if skipped:
        warnings.warn(
            f"Skipped {skipped} annotation(s) with a missing or non-numeric offset/length."
        )
    return texts, annotations


# JSON is UTF-8 by specification; don't depend on the platform default encoding.
with TAGGED_ARTICLES_PATH.open("r", encoding="utf-8") as f:
    data = json.load(f)

texts, annotations = flatten_tagged_articles(data)

# Wrap the flattened sections in the project's T5 dataset; the per-section
# (start, end) span lists are handed over unchanged.
dataset = T5Dataset(
    tokenizer=tokenizer,
    texts=texts,
    annotations=annotations,
)

# Return unused memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()
    

In [8]:
idx = 10  # any index smaller than len(dataset.texts) works
text = dataset.texts[idx]
# Plain whitespace tokenisation of the sample.
tokens = text.strip().split()
# NOTE(review): if these spans are character offsets, confirm corrupt_text maps them onto word tokens.
annotation_spans = dataset.annotations[idx]

corrupted_input, target = dataset.corrupt_text(tokens, annotation_spans)

# One print call, one field per line, framed by 80-char separator rules.
print(
    "=" * 80,
    f"Original Text:\n{text}",
    f"Tokens:\n{tokens}",
    f"Annotation Spans:\n{annotation_spans}",
    f"Corrupted Input:\n{corrupted_input}",
    f"Target (What T5 has to Generate):\n{target}",
    "=" * 80,
    sep="\n",
)


Original Text:
Mild cognitive impairment prediction and cognitive score regression in the elderly using EEG topological data analysis and machine learning with awareness assessed in affective reminiscent paradigm
Tokens:
['Mild', 'cognitive', 'impairment', 'prediction', 'and', 'cognitive', 'score', 'regression', 'in', 'the', 'elderly', 'using', 'EEG', 'topological', 'data', 'analysis', 'and', 'machine', 'learning', 'with', 'awareness', 'assessed', 'in', 'affective', 'reminiscent', 'paradigm']
Annotation Spans:
[]
Corrupted Input:
Mild cognitive impairment prediction <extra_id_0> score <extra_id_1> the elderly using EEG topological data analysis and machine learning with awareness assessed in affective reminiscent paradigm <extra_id_2>
Target (What T5 has to Generate):
<extra_id_0> and cognitive <extra_id_1> regression in <extra_id_2>
