In [2]:
import csv
import requests
import os

def load_articles(article_csv):
    """Load wikiArticles.csv into a dictionary: {topic: article}"""
    articles = {}
    with open(article_csv, 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader)  # skip header
        for row in reader:
            if len(row) >= 2:
                topic, article = row[0], row[1]
                articles[topic.strip()] = article.strip()
    return articles

def update_wiki_articles(grouped_csv, wiki_csv):
    """Fetch missing Wikipedia articles from grouped data and update wiki_output.csv"""
    existing_articles = load_articles(wiki_csv)
    updated_articles = existing_articles.copy()
    
    with open(grouped_csv, 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader)  # skip header
        
        for row in reader:
            if len(row) >= 1:
                topic = row[0].strip()
                
                # If the article for this topic doesn't exist, fetch it
                if topic not in existing_articles:
                    print(f"Fetching article for topic '{topic}'...")
                    article = fetch_wikipedia_summary(topic)
                    if article:
                        updated_articles[topic] = article
                    else:
                        print(f"Warning: No article found for topic '{topic}'")
    
    # Update the wiki_output.csv with all the articles
    with open(wiki_csv, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["topic", "article"])  # Write the header
        for topic, article in updated_articles.items():
            writer.writerow([topic, article])
    
    print(f"Updated {len(updated_articles)} articles in '{wiki_csv}'.")

def generate_t5_triplets(wiki_csv, grouped_csv, output_file):
    """Generate T5 training triplets and write to a new CSV."""
    articles = load_articles(wiki_csv)
    triplets = []
    
    # Group the items in the grouped CSV by topic
    grouped_data = {}
    with open(grouped_csv, 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader)  # skip header

        for row in reader:
            if len(row) >= 2:
                topic = row[0].strip()
                summary = row[1].strip()

                # Group the summaries by topic
                if topic not in grouped_data:
                    grouped_data[topic] = []
                grouped_data[topic].append(summary)

    # Now generate the triplets using the grouped summaries
    for topic, summaries in grouped_data.items():
        article = articles.get(topic)
        if article:
            for summary in summaries:
                triplets.append(["summarize", article, summary])
        else:
            print(f"Warning: No article found for topic '{topic}'")

    # Write the triplets to the output CSV
    with open(output_file, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["task", "input", "target"])
        writer.writerows(triplets)

    print(f"Generated {len(triplets)} T5 training triplets in '{output_file}'.")

def fetch_wikipedia_summary(topic):
    """Fetch the summary of a Wikipedia article for a given topic."""
    url = f"https://en.wikipedia.org/w/api.php?action=query&prop=extracts&exintro&explaintext&format=json&titles={topic}"
    response = requests.get(url)
    data = response.json()
    
    pages = data.get("query", {}).get("pages", {})
    for page_id, page in pages.items():
        return page.get("extract", "")
    
    return ""

def process_unique_topics(input_csv, unique_topics_file):
    """Process unique topics from input CSV and store in a file."""
    unique_topics = set()
    
    with open(input_csv, 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader)  # Skip header
        for row in reader:
            if len(row) >= 1:
                unique_topics.add(row[0].strip())
    
    # Clear and write the unique topics to the file
    with open(unique_topics_file, 'w', encoding='utf-8') as f:
        for topic in unique_topics:
            f.write(f"{topic}\n")
    
    print(f"Processed {len(unique_topics)} unique topics and saved to '{unique_topics_file}'.")

def clear_output_files(files):
    """Clear content of output files before writing new data."""
    for file in files:
        if os.path.exists(file):
            with open(file, 'w', encoding='utf-8') as f:
                pass  # Just open and close to clear content
            print(f"Cleared content of '{file}'.")

def main():
    # Define file paths
    wiki_csv = "wiki_output.csv"  # Store fetched articles here
    grouped_csv = "output_grouped.csv"  # Do not modify this file
    t5_triplets_output = "t5_triplets.csv"
    unique_topics_file = "unique_topics.txt"
    
    # Clear previous output files (except for output_grouped.csv)
    clear_output_files([t5_triplets_output, unique_topics_file, wiki_csv])
    
    # Update wiki_output.csv with missing articles
    update_wiki_articles(grouped_csv, wiki_csv)
    
    # Generate T5 training triplets
    generate_t5_triplets(wiki_csv, grouped_csv, t5_triplets_output)
    
    # Process unique topics
    process_unique_topics(grouped_csv, unique_topics_file)

if __name__ == "__main__":
    main()


Cleared content of 't5_triplets.csv'.
Cleared content of 'wiki_output.csv'.


StopIteration: 