In [None]:
import csv
import os
import sys
import requests
from lxml import html
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from datasets import Dataset, enable_progress_bar
from tqdm import tqdm
import torch

# Progress bars are only drawn when stdout is an interactive terminal; when
# output is redirected (and typically inside a notebook) isatty() is False and
# every tqdm() call below is created with disable=True.
use_tqdm = sys.stdout.isatty()

# Optional DirectML backend for AMD GPUs on Windows. The name stays None when
# the package is not installed; get_best_device() tests it for truthiness.
try:
    import torch_directml  # type: ignore
except ImportError:
    torch_directml = None
else:
    print("DirectML support enabled for AMD GPU.")

# ----------- CONFIG -----------
# All paths are relative to the current working directory.
output_grouped = 'output_grouped.csv'  # source CSV; rows are read as (topic, summary, answer, ...)
t5_triplets_output = 't5_triplets.csv'  # task,input,target rows for summarisation training
unique_topics_file = 'unique_topics.txt'  # sorted distinct topics, one per line
wiki_csv = 'wiki_output.csv'  # Topic,Article cache of fetched Wikipedia text
answer_triplets_output = 'answer_triplets.csv'  # task,input,target rows for answer training

# ----------- DEVICE DETECTION -----------
def get_best_device():
    """Return the preferred torch device for this machine.

    Preference order: CUDA when ``torch.cuda.is_available()``, then the
    DirectML device when the optional ``torch_directml`` module was imported
    at load time, otherwise the CPU.
    """
    if torch.cuda.is_available():
        chosen = torch.device('cuda')
    elif torch_directml:
        chosen = torch_directml.device()
    else:
        chosen = torch.device('cpu')
    return chosen

# --------------- UTILITY FUNCTIONS ---------------
def clear_output_files(files):
    """Truncate every path in *files* to zero bytes, creating missing files.

    Parent directories are not created; ``open`` raises if one is missing.
    """
    for target in files:
        # Opening in mode 'w' truncates (or creates) the file; nothing is written.
        with open(target, mode='w', encoding='utf-8') as _handle:
            pass


def load_articles(article_csv):
    """Load a ``Topic,Article`` CSV into a ``{topic: article}`` dict.

    The first row is treated as a header and skipped. Rows with fewer than two
    columns are ignored; topic and article are whitespace-stripped. A later
    row with the same topic overwrites an earlier one.

    Args:
        article_csv: path to the CSV file.

    Returns:
        The topic -> article dict. It is empty (with a printed warning) when
        the file is missing or has no rows at all.
    """
    articles = {}
    # Wikipedia article bodies easily exceed the csv module's default
    # 131072-character field limit, which makes csv.reader raise csv.Error.
    # Raise the limit for this read only (2**31 - 1 fits a C long on every
    # platform, including Windows) and restore the previous value afterwards.
    previous_limit = csv.field_size_limit(2**31 - 1)
    try:
        with open(article_csv, 'r', newline='', encoding='utf-8') as f:
            reader = csv.reader(f)
            # next() with a default never raises StopIteration; a None header
            # means the file is completely empty.
            if next(reader, None) is None:
                print(f"Warning: {article_csv} is empty.")
                return articles
            for row in reader:
                if len(row) >= 2:
                    articles[row[0].strip()] = row[1].strip()
    except FileNotFoundError:
        print(f"Warning: {article_csv} not found. Using empty dictionary.")
    finally:
        csv.field_size_limit(previous_limit)
    return articles


def clean_quotes_in_csv(csv_file):
    """Rewrite *csv_file* in place with double quotes removed from column 0.

    Every ``"`` character is deleted from the first field of each row (the
    header included), and that field is then whitespace-stripped. Other
    columns are untouched. Completely empty rows are dropped.
    """
    with open(csv_file, 'r', newline='', encoding='utf-8') as src:
        cleaned = [
            [record[0].replace('"', '').strip(), *record[1:]]
            for record in csv.reader(src)
            if record
        ]
    with open(csv_file, 'w', newline='', encoding='utf-8') as dst:
        csv.writer(dst).writerows(cleaned)

# ----------- WIKIPEDIA API FUNCTIONS -----------
_WIKI_API_URL = 'https://en.wikipedia.org/w/api.php'
# Wikimedia's User-Agent policy asks API clients to identify themselves with a
# descriptive User-Agent (ideally including contact details); generic library
# defaults may be throttled or blocked.
# TODO: add real contact information to this string.
_WIKI_HEADERS = {'User-Agent': 'wiki-t5-triplets/0.1 (python-requests)'}
_wiki_session = None


def _get_wiki_session():
    """Return a lazily created requests.Session shared by all Wikipedia calls."""
    global _wiki_session
    if _wiki_session is None:
        # One session keeps HTTP connections alive across the many calls made
        # by update_wiki_articles instead of reconnecting for every request.
        _wiki_session = requests.Session()
        _wiki_session.headers.update(_WIKI_HEADERS)
    return _wiki_session


def search_and_fetch_article(topic):
    """Search English Wikipedia for *topic* and return the best hit's paragraph text.

    The first search result's page is parsed to HTML, and the non-empty
    ``<p>`` elements are joined with blank lines.

    This never raises. On failure it returns a human-readable message
    instead of article text:

    * ``"No Wikipedia article found for topic: ..."`` when the search has no hits;
    * ``"No extractable article found for ..."`` when the page has no paragraphs;
    * ``"Error retrieving article for ...: <exc>"`` for any other failure
      (network/HTTP errors, malformed JSON, unexpected response shape).
    """
    try:
        session = _get_wiki_session()
        search_params = {'action': 'query', 'format': 'json', 'list': 'search', 'utf8': 1, 'srsearch': topic}
        resp = session.get(_WIKI_API_URL, params=search_params, timeout=5)
        resp.raise_for_status()  # surface HTTP errors instead of a confusing JSON decode error
        results = resp.json().get('query', {}).get('search', [])
        if not results:
            return f"No Wikipedia article found for topic: {topic}"
        title = results[0]['title']
        parse_params = {'action': 'parse', 'format': 'json', 'page': title, 'prop': 'text', 'redirects': ''}
        resp2 = session.get(_WIKI_API_URL, params=parse_params, timeout=5)
        resp2.raise_for_status()
        raw_html = resp2.json()['parse']['text']['*']
        doc = html.fromstring(raw_html)
        texts = []
        for paragraph in doc.xpath('//p'):
            text = paragraph.text_content().strip()  # computed once per paragraph
            if text:
                texts.append(text)
        return '\n\n'.join(texts) if texts else f"No extractable article found for {title}."
    except Exception as e:
        # Deliberate best-effort: the caller stores whatever string comes back.
        return f"Error retrieving article for {topic}: {e}"
fetch_wiki_article = search_and_fetch_article

# --------------- PROCESSING FUNCTIONS ---------------
def process_unique_topics(grouped_csv, unique_topics_file):
    """Write the sorted, de-duplicated topics of *grouped_csv* to a text file.

    The first CSV row is a header and is skipped. A topic is the
    whitespace-stripped first column; empty rows and blank topics are ignored.
    The output file gets one topic per line, and a summary line is printed.
    """
    with open(grouped_csv, 'r', newline='', encoding='utf-8') as source:
        records = csv.reader(source)
        next(records, None)  # header
        topics = {record[0].strip() for record in records if record and record[0].strip()}

    with open(unique_topics_file, 'w', newline='', encoding='utf-8') as sink:
        sink.writelines(f"{topic}\n" for topic in sorted(topics))

    print(f"Saved {len(topics)} unique topics to {unique_topics_file}")


def update_wiki_articles(grouped_csv, wiki_csv):
    """Fetch Wikipedia text for every topic in *grouped_csv* not yet cached in *wiki_csv*.

    *wiki_csv* is a ``Topic,Article`` cache. Existing entries are kept, and
    each distinct topic (stripped first column of *grouped_csv*, header
    skipped) that is missing is fetched once via ``fetch_wiki_article``.

    The cache file is rewritten at the start (existing rows first). Each new
    article is appended and flushed as soon as it arrives, so an interrupted
    run keeps everything fetched up to that point.

    Fetches that yield one of ``fetch_wiki_article``'s failure/placeholder
    messages are not cached as articles. Otherwise they would later be used
    as T5 training inputs. They are only counted and reported.
    """
    # Message prefixes fetch_wiki_article returns instead of article text.
    failure_prefixes = (
        "No Wikipedia article found for topic:",
        "No extractable article found for",
        "Error retrieving article for",
    )
    updated = load_articles(wiki_csv)  # already a fresh dict; no copy needed
    with open(grouped_csv, 'r', newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader, None)
        rows = list(reader)

    # Tracks topics already cached or already tried this run, so a failed
    # topic that appears on many rows is not re-fetched for each of them.
    attempted = set(updated)
    failures = 0
    with open(wiki_csv, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['Topic', 'Article'])
        writer.writerows(updated.items())
        for row in tqdm(rows, desc="Fetching Wikipedia articles", unit="topic", disable=not use_tqdm):
            topic = row[0].strip() if row else None
            if not topic or topic in attempted:
                continue
            attempted.add(topic)
            art = fetch_wiki_article(topic)
            if art.startswith(failure_prefixes):
                failures += 1
                continue
            updated[topic] = art
            writer.writerow([topic, art])
            f.flush()  # network-bound loop; flushing per article costs nothing noticeable
    print(f"Updated {len(updated)} articles in {wiki_csv}")
    if failures:
        print(f"Skipped {failures} topics with no usable Wikipedia article")


def generate_t5_triplets(wiki_csv, grouped_csv, output_file):
    """Build ``summarize`` training triplets by pairing cached articles with summaries.

    For each data row of *grouped_csv* (header skipped) with at least two
    columns, the stripped first column is looked up in the articles loaded
    from *wiki_csv*. When found, the row ``['summarize', article, summary]``
    is emitted, where summary is the stripped second column. The triplets are
    written to *output_file* under a ``task,input,target`` header.
    """
    article_by_topic = load_articles(wiki_csv)
    with open(grouped_csv, 'r', newline='', encoding='utf-8') as source:
        records = csv.reader(source)
        next(records, None)  # header
        data_rows = list(records)

    triplets = []
    progress = tqdm(data_rows, desc="Generating T5 triplets", unit="row", disable=not use_tqdm)
    for record in progress:
        if len(record) < 2:
            continue
        topic = record[0].strip()
        summary = record[1].strip()
        article = article_by_topic.get(topic)
        if article is None:
            continue
        triplets.append(['summarize', article, summary])

    with open(output_file, 'w', newline='', encoding='utf-8') as sink:
        out = csv.writer(sink)
        out.writerow(['task', 'input', 'target'])
        out.writerows(triplets)
    print(f"Generated {len(triplets)} triplets to {output_file}")


def train_t5_model(t5_triplets_csv, output_dir='./results'):
    """Fine-tune ``t5-small`` on a ``task,input,target`` CSV and save the result.

    Each example is fed as ``"<task>: <input>"`` (T5's task-prefix
    convention), with ``target`` as the label sequence. Inputs are
    padded/truncated to 512 tokens and targets to 128.

    Trainer checkpoints go under *output_dir*. The final model and tokenizer
    are saved directly into *output_dir* so ``from_pretrained(output_dir)``
    can reload them.

    Args:
        t5_triplets_csv: path to the CSV produced by generate_t5_triplets.
        output_dir: directory for checkpoints and the final model. The
            default ``'./results'`` is where main() later loads it from.
    """
    device = get_best_device()
    print("Training on device:", device)
    # keep_default_na=False: otherwise pandas parses texts such as "NA",
    # "null" or "None", and empty cells, as missing values. They would reach
    # the tokenizer as None and crash it.
    ds = Dataset.from_csv(t5_triplets_csv, delimiter=',', keep_default_na=False)
    tokenizer = T5Tokenizer.from_pretrained('t5-small')
    # NOTE(review): Trainer moves the model to TrainingArguments' own device
    # (CUDA or CPU), so a DirectML device chosen here is probably not kept
    # during training — confirm.
    model = T5ForConditionalGeneration.from_pretrained('t5-small').to(device)

    enable_progress_bar()

    def preprocess(examples):
        # T5 picks its task from a text prefix such as "summarize: ".
        sources = [f"{task}: {text}" for task, text in zip(examples['task'], examples['input'])]
        inp = tokenizer(sources, padding='max_length', truncation=True, max_length=512)
        lbl = tokenizer(examples['target'], padding='max_length', truncation=True, max_length=128)
        # The loss ignores label positions set to -100. Without this, the
        # padding tokens would be trained as real targets.
        pad_id = tokenizer.pad_token_id
        labels = [[-100 if tok == pad_id else tok for tok in seq] for seq in lbl['input_ids']]
        return {'input_ids': inp['input_ids'], 'attention_mask': inp['attention_mask'], 'labels': labels}

    tokenized = ds.map(preprocess, batched=True, remove_columns=ds.column_names)
    args = TrainingArguments(output_dir=output_dir, per_device_train_batch_size=8, num_train_epochs=3,
                             save_steps=10000, save_total_limit=2)
    trainer = Trainer(model=model, args=args, train_dataset=tokenized, tokenizer=tokenizer)
    trainer.train()
    # Periodic checkpoints land in output_dir/checkpoint-<step>. Save the
    # final weights and tokenizer at the top level so that
    # from_pretrained(output_dir) works.
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    print("Training complete.")

# --------------- NEW: QUESTION GENERATION & ANSWER TRAINING ---------------

def generate_questions_and_train_answers(output_grouped_csv, model_dir):
    """Generate a "question" per topic with the model in *model_dir*, then train it to answer.

    Stage 1: each data row of *output_grouped_csv* (header skipped) with at
    least three columns contributes one pair: the model's decoded output for
    the stripped topic (column 0), and the stripped answer (column 2).

    Stage 2: the pairs are written to ``answer_triplets_output`` as
    ``question,<question>,<answer>`` rows. The same model is then fine-tuned
    on ``"question: <question>" -> answer``, and the result is saved to
    ``<model_dir>/answer_results``.

    Args:
        output_grouped_csv: grouped source CSV (topic, summary, answer, ...).
        model_dir: directory loadable by ``from_pretrained`` (model and tokenizer).
    """
    device = get_best_device()  # resolve once rather than once per row
    tokenizer = T5Tokenizer.from_pretrained(model_dir)
    model = T5ForConditionalGeneration.from_pretrained(model_dir).to(device)

    # NOTE(review): model_dir holds the summarisation model from
    # train_t5_model. Feeding it a bare topic to obtain a "question" is this
    # pipeline's design choice — confirm that is the intent.
    question_by_topic = {}
    qa_pairs = []
    with open(output_grouped_csv, 'r', newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader, None)
        for row in reader:
            if len(row) < 3:
                continue
            topic = row[0].strip()
            actual_answer = row[2].strip()
            # generate() gets no sampling options here, so (assuming the saved
            # generation config does not enable sampling) a topic always maps
            # to the same output. Topics repeat across rows, so each distinct
            # topic is run through the model only once.
            if topic not in question_by_topic:
                input_ids = tokenizer(topic, return_tensors='pt').input_ids.to(device)
                outputs = model.generate(input_ids, max_length=64)
                question_by_topic[topic] = tokenizer.decode(outputs[0], skip_special_tokens=True)
            qa_pairs.append((question_by_topic[topic], actual_answer))

    with open(answer_triplets_output, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['task', 'input', 'target'])
        for q, a in qa_pairs:
            writer.writerow(['question', q, a])

    # keep_default_na=False keeps answers such as "None" or "NA", and empty
    # strings, as text instead of turning them into missing values that crash
    # the tokenizer.
    ds_ans = Dataset.from_csv(answer_triplets_output, delimiter=',', keep_default_na=False)

    def preprocess_ans(examples):
        # Use the task column as T5's text prefix, matching the summarisation stage.
        sources = [f"{task}: {text}" for task, text in zip(examples['task'], examples['input'])]
        inp = tokenizer(sources, padding='max_length', truncation=True, max_length=64)
        lbl = tokenizer(examples['target'], padding='max_length', truncation=True, max_length=64)
        # Mask label padding with -100 so the loss ignores it.
        pad_id = tokenizer.pad_token_id
        labels = [[-100 if tok == pad_id else tok for tok in seq] for seq in lbl['input_ids']]
        return {'input_ids': inp['input_ids'], 'attention_mask': inp['attention_mask'], 'labels': labels}

    tokenized_ans = ds_ans.map(preprocess_ans, batched=True, remove_columns=ds_ans.column_names)
    args_ans = TrainingArguments(
        output_dir=os.path.join(model_dir, 'answer_results'),
        per_device_train_batch_size=8,
        num_train_epochs=3,
        save_steps=10000,
        save_total_limit=2,
    )
    trainer_ans = Trainer(model=model, args=args_ans, train_dataset=tokenized_ans, tokenizer=tokenizer)
    trainer_ans.train()
    # Persist the final answer model; otherwise only step checkpoints exist
    # (if any were reached at all).
    trainer_ans.save_model()
    tokenizer.save_pretrained(args_ans.output_dir)
    print("Answer model training complete.")

# --------------- MAIN ---------------

def main():
    """Run the whole pipeline: prepare data, fetch articles, then train both models.

    ``wiki_csv`` is intentionally *not* cleared. update_wiki_articles only
    fetches topics missing from it, so keeping it lets reruns skip the tens
    of thousands of articles already downloaded. Delete the file manually to
    force a full refetch.
    """
    # These derived files are regenerated from scratch by the steps below.
    clear_output_files([t5_triplets_output, unique_topics_file, answer_triplets_output])
    clean_quotes_in_csv(output_grouped)
    process_unique_topics(output_grouped, unique_topics_file)
    update_wiki_articles(output_grouped, wiki_csv)
    generate_t5_triplets(wiki_csv, output_grouped, t5_triplets_output)
    # train_t5_model uses ./results as its Trainer output directory, and the
    # question/answer stage loads its model from that same directory.
    train_t5_model(t5_triplets_output)
    generate_questions_and_train_answers(output_grouped, './results')


if __name__ == '__main__':
    main()


Saved 26959 unique topics to unique_topics.txt
Fetching article for: A JIM CARREY FILM FESTIVAL
Fetching article for: !
Fetching article for: -ARES
Fetching article for: -ICIAN EXPEDITION
Fetching article for: ...OD WORDS
Fetching article for: 1, 2, 3
Fetching article for: 20 QUESTIONS
Fetching article for: A & E
Fetching article for: A & M
Fetching article for: A + 4
Fetching article for: A BEFORE E
Fetching article for: A IN COLLEGE
Fetching article for: A IN GEOGRAPHY
Fetching article for: A IN HISTORY
Fetching article for: A IN LITERATURE
Fetching article for: A IN MATH
Fetching article for: A IN SCIENCE
Fetching article for: A IN SEX EDUCATION
Fetching article for: A IN SHAKESPEARE
Fetching article for: A IS FOR AUTHOR
Fetching article for: A MEN
Fetching article for: A OK
Fetching article for: A PLUS
Fetching article for: A SCIENCE CATEGORY
Fetching article for: A TOUGHIE
Fetching article for: A TRAVEL CATEGORY
Fetching article for: A _____
Fetching article for: ABBREVIATIONS
Fet