In [None]:
!pip install transformers==4.31.0

In [None]:
import csv
import random
from pathlib import Path

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Label taxonomy for the synthetic dataset: each key is a top-level domain and
# its list holds the sub-domains that may be sampled together with it.
domains = dict(
    Technology=["AI", "Cloud Computing", "Cybersecurity", "IoT", "Blockchain"],
    Finance=["Investments", "Banking", "Insurance", "Cryptocurrency", "Fintech"],
    Healthcare=["Surgery", "Pharmaceuticals", "Mental Health", "Diagnostics", "Telemedicine"],
    Education=["E-learning", "Traditional Schooling", "Edtech", "Research", "MOOCs"],
    Entertainment=["Movies", "Music", "Gaming", "Theatre", "Streaming"],
)

# Single checkpoint id shared by the tokenizer and the seq2seq model, so the
# vocabulary and the weights always come from the same release.
MODEL_CHECKPOINT = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)

def generate_description(domain, sub_domain, max_length=2000, **generate_kwargs):
    """Ask the module-level seq2seq model for a description of a sub-domain.

    Args:
        domain: Top-level domain label inserted into the prompt.
        sub_domain: Sub-domain label inserted into the prompt.
        max_length: Maximum generated sequence length in tokens, forwarded to
            ``model.generate`` (default 2000, the previously hard-coded value).
        **generate_kwargs: Extra keyword arguments forwarded unchanged to
            ``model.generate``. For example, ``do_sample=True, top_p=0.95``
            gives varied text for repeated (domain, sub_domain) pairs. Without
            sampling options, decoding presumably follows the checkpoint's
            default (greedy) configuration, so identical prompts produce
            identical output. TODO confirm against the model's generation config.

    Returns:
        The decoded text with special tokens removed. Returns ``None`` if
        tokenization, generation or decoding raised; the error is printed.
    """
    prompt = f"Write a detailed 1000-word description about the {sub_domain} in the {domain} domain."
    try:
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(**inputs, max_length=max_length, **generate_kwargs)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Best-effort boundary: a single failed generation should not abort a
        # long dataset run. Callers receive None and must handle it.
        print(f"Error generating description: {e}")
        return None


# Create the dataset
def generate_dataset(file_name, rows=10000, *, domain_map=None, describe=None, seed=None):
    """Write a CSV of synthetic ``description, domain, sub_domain`` rows.

    For each of ``rows`` attempts, pick a domain uniformly at random, then
    pick one of its sub-domains uniformly at random. Pass both to
    ``describe`` and write the returned text together with the two labels.

    Args:
        file_name: Output CSV path. Missing parent directories are created.
        rows: Number of rows to attempt.
        domain_map: Mapping of domain -> list of sub-domains. Defaults to the
            module-level ``domains``.
        describe: Callable ``(domain, sub_domain) -> str | None``. Defaults to
            ``generate_description``.
        seed: If given, sampling uses a private ``random.Random(seed)`` so the
            label sequence is reproducible. Otherwise the global ``random``
            module is used, as before.

    A row whose description is ``None`` or blank (e.g. generation failed) is
    skipped rather than written as empty text with valid-looking labels. The
    file may therefore hold fewer than ``rows`` data rows; the printed summary
    reports the actual count.
    """
    if domain_map is None:
        domain_map = domains
    if describe is None:
        describe = generate_description
    rng = random if seed is None else random.Random(seed)
    domain_names = list(domain_map)  # loop-invariant, so build it once

    # open(..., "w") does not create directories; without this a missing
    # folder such as "../Dataset/raw data/" raises FileNotFoundError.
    Path(file_name).parent.mkdir(parents=True, exist_ok=True)

    written = skipped = 0
    with open(file_name, mode="w", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        writer.writerow(["description", "domain", "sub_domain"])  # Write header

        for _ in range(rows):
            domain = rng.choice(domain_names)
            sub_domain = rng.choice(domain_map[domain])
            description = describe(domain, sub_domain)

            # csv writes None as an empty field, which would leave an unusable
            # but labelled training example, so drop it instead.
            if description is None or not description.strip():
                skipped += 1
                continue

            writer.writerow([description, domain, sub_domain])
            written += 1

    print(f"Dataset with {written} rows created: {file_name}")
    if skipped:
        print(f"Skipped {skipped} rows whose description could not be generated.")

# Small smoke test: uncomment to write a 10-row sample to the working directory.
# generate_dataset("domain_classifier_dataset.csv", rows=10)


In [None]:
# Full run: 100 rows written to a path relative to the notebook's working directory.
OUTPUT_CSV = "../Dataset/raw data/domain_classifier_augmented_dataset.csv"
generate_dataset(OUTPUT_CSV, rows=100)