## Import

In [1]:
import torch
from datasets import Dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
import pandas as pd

## Data Prep

In [7]:
def prepare_dataset(data_path):
    # Read and clean the data
    df = pd.read_csv(data_path)
    df['icd10_code'] = df['icd10_code'].fillna('').astype(str)
    df['diagnoses'] = df['diagnoses'].fillna('').astype(str)
    
    # Split into train and validation
    train_df = df.sample(frac=0.9, random_state=42)
    val_df = df.drop(train_df.index)
    
    # Convert to HuggingFace datasets
    train_dataset = Dataset.from_pandas(train_df)
    val_dataset = Dataset.from_pandas(val_df)
    
    return train_dataset, val_dataset

# Load and prepare the datasets
train_data, val_data = prepare_dataset("labeled_icd10_flanT5.csv")
print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")

Training samples: 64534
Validation samples: 7170


## Tokenization function and model setup

In [18]:
# Initialize model and tokenizer
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

def tokenize_data(examples, tokenizer):
    input_texts = [str(text) for text in examples["icd10_code"]]
    target_texts = [str(text) for text in examples["diagnoses"]]
    
    # Tokenize inputs
    model_inputs = tokenizer(
        input_texts,
        padding="max_length",  # Changed to enable padding
        truncation=True,       # Enable truncation
        max_length=128,
        return_tensors=None    # Important: don't return tensors here
    )
    # Tokenize targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            target_texts,
            padding="max_length",  # Enable padding
            truncation=True,       # Enable truncation
            max_length=16,
            return_tensors=None    # Important: don't return tensors here
        )
    
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# model to cuda if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=384, bias=False)
              (k): Linear(in_features=512, out_features=384, bias=False)
              (v): Linear(in_features=512, out_features=384, bias=False)
              (o): Linear(in_features=384, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 6)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=512, out_features=1024, bias=False)
              (wi_1): Linear(in_features=512, out_features=1024, bias=False)
              (wo): 

## Set up training arguments

In [19]:
# 3. Tokenize datasets
tokenized_train = train_data.map(
    lambda x: tokenize_data(x, tokenizer),
    batched=True,              # Enable batching
    remove_columns=train_data.column_names
)

tokenized_val = val_data.map(
    lambda x: tokenize_data(x, tokenizer),
    batched=True,              # Enable batching
    remove_columns=val_data.column_names
)

Map:   0%|          | 0/64534 [00:00<?, ? examples/s]



Map:   0%|          | 0/7170 [00:00<?, ? examples/s]

## Training Arguments

In [20]:
# 4. Set up data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True              # Ensure padding is enabled
)

# Configure training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./predict-icd10/model/01_flan_t5_training",
    eval_strategy="epoch",
    learning_rate=3e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=3,
)

# Initialize data collator
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

## Initialize trainer and start training

In [21]:
# Set up the trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

# Start training
print("Starting training...")
trainer.train()

output_dir = "./predict-icd10/model/01_flan_t5_model"
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print("Training completed and model saved!")

  trainer = Seq2SeqTrainer(


Starting training...


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

## Test prediction function

In [None]:
def predict(text, model_path=output_dir):
    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt", max_length=128, truncation=True, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate prediction
    outputs = model.generate(**inputs, max_length=16, num_beams=4)
    predicted_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    return predicted_code

# Test the model
test_text = "Cholera due to Vibrio cholerae 01, biovar cholerae"
predicted_icd = predict(test_text)
print(f"Input: {test_text}")
print(f"Predicted ICD-10: {predicted_icd}")