In [None]:
!pip install transformers
!pip install sacremoses

In [None]:
import csv
import random
from pathlib import Path

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Sub-domain labels for the "Medical" domain, in their original order.
_MEDICAL_SUB_DOMAINS = [
    "Orthopedics",
    "Oncology",
    "Neurology",
    "Pathology",
    "Radiology",
    "Infectious Diseases",
    "Dermatology",
    "General Surgery",
    "Nephrology",
    "Emergency Medicine",
    "Plastic Surgery",
    "Pediatrics",
    "Endocrinology",
    "Internal Medicine",
    "Cardiology",
    "Reproductive Medicine",
    "Anesthesiology",
    "Hematology",
]

# Mapping of domain name -> list of its sub-domain names.
# Only "Medical" is populated in this example.
domains = {"Medical": _MEDICAL_SUB_DOMAINS}

# The tokenizer and the seq2seq model are loaded from the same FLAN-T5
# checkpoint. Both are loaded once, when this cell runs, and are then used
# as module-level globals by the generation code.
_CHECKPOINT = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT)


def generate_description(domain, sub_domain, *, num_candidates=3, max_length=2000):
    """Generate a synthetic free-text description for one (domain, sub_domain) pair.

    Samples ``num_candidates`` completions from the module-level ``model`` /
    ``tokenizer``. It returns the candidate with the most distinct words,
    a cheap proxy for the least repetitive output, instead of discarding
    all but the first sample.

    Args:
        domain: Top-level domain name interpolated into the prompt.
        sub_domain: Sub-domain name interpolated into the prompt.
        num_candidates: Number of sampled completions to draw per call.
        max_length: ``max_length`` passed through to ``model.generate``.

    Returns:
        The selected description string, or ``None`` if generation failed.
        The error is printed rather than raised, so that one failure does
        not abort a long dataset-generation run.
    """
    prompt = f"Write a detailed 1000-word description about the {sub_domain} in the {domain} domain."

    try:
        # Put the input tensors on the model's device so this keeps working
        # if the model is moved off the CPU.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=num_candidates,
            do_sample=True,  # sampling is what makes the candidates differ
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            repetition_penalty=2.0,  # discourage repeated phrases
        )
        descriptions = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    except Exception as e:
        # Broad catch is deliberate: this is the per-row boundary of a batch job.
        print(f"Error generating description for {domain}/{sub_domain}: {e}")
        return None

    return _most_distinct(descriptions)


def _most_distinct(candidates):
    """Return the candidate with the most distinct case-folded words (first wins ties); None if empty."""
    if not candidates:
        return None
    return max(candidates, key=lambda text: len({word.casefold() for word in text.split()}))


def generate_dataset(file_name, rows=10000):
    """Write a CSV of synthetic ``description, domain, sub_domain`` rows.

    Each of the ``rows`` attempts picks a random domain from the module-level
    ``domains`` mapping and a random sub-domain within it. It then calls
    ``generate_description``.

    Attempts whose generation failed (returned ``None``) are skipped rather
    than written as rows with an empty description but a valid label. The
    file may therefore hold fewer than ``rows`` data rows; the printed
    summary and the return value report the actual count.

    Args:
        file_name: Output CSV path. Missing parent directories are created.
        rows: Number of generation attempts.

    Returns:
        The number of data rows written (the header is not counted).
    """
    # open(..., "w") does not create directories; make sure the target exists.
    Path(file_name).parent.mkdir(parents=True, exist_ok=True)

    # The key set is fixed for the whole run, so build the choice list once.
    domain_names = list(domains)
    written = 0
    skipped = 0

    with open(file_name, mode="w", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        writer.writerow(["description", "domain", "sub_domain"])

        for _ in range(rows):
            domain = random.choice(domain_names)
            sub_domain = random.choice(domains[domain])

            description = generate_description(domain, sub_domain)
            if description is None:
                # Generation failed; the error has already been printed.
                skipped += 1
                continue

            writer.writerow([description, domain, sub_domain])
            written += 1

    print(f"Dataset with {written} rows created: {file_name}")
    if skipped:
        print(f"Skipped {skipped} rows whose description generation failed")
    return written

# Example: quick smoke run on 10 rows (uncomment to use); the full run is in the next cell.
# generate_dataset("domain_classifier_dataset.csv", rows=10)

In [None]:
# Produce the augmented classifier dataset (100 generation attempts).
generate_dataset(
    file_name="../Dataset/raw data/domain_classifier_augmented_dataset.csv",
    rows=100,
)