In [None]:
import csv
import os
from time import sleep
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from datasets import Dataset
import torch

# Optional DirectML backend for AMD on Windows. The name stays None when the
# torch-directml package is not installed; get_best_device checks for that.
try:
    import torch_directml  # type: ignore
except ImportError:
    torch_directml = None

# ----------- CONFIG -----------

# File paths (relative to the current working directory)
output_grouped = "output_grouped.csv"      # topic,summary rows; cleaned in place by main()
t5_triplets_output = "t5_triplets.csv"     # task,input,target rows used for training
unique_topics_file = "unique_topics.txt"   # one distinct topic per line
wiki_csv = "wiki_output.csv"               # Topic,Article rows (article cache / output)

# ----------- DEVICE DETECTION -----------

def get_best_device():
    """Choose the device to train on.

    Preference order: a CUDA device (NVIDIA, or AMD through a ROCm build of
    PyTorch, which reports via ``torch.cuda``), then a DirectML device when
    the optional ``torch_directml`` module was importable, else the CPU.
    """
    if torch.cuda.is_available():
        chosen = torch.device('cuda')
    elif torch_directml:
        chosen = torch_directml.device()
    else:
        chosen = torch.device('cpu')
    return chosen

# --------------- UTILITY FUNCTIONS ---------------

def clear_output_files(files):
    """Truncate each path in *files* to zero bytes, creating any that are missing."""
    for path in files:
        # Opening in 'w' mode truncates the file; nothing needs to be written.
        with open(path, 'w', encoding='utf-8'):
            pass


def load_articles(article_csv):
    """Load a two-column ``Topic,Article`` CSV into ``{topic: article}``.

    The first row is treated as a header and skipped. Rows with fewer than
    two columns are ignored; if a topic appears more than once, the last
    occurrence wins. Keys and values are whitespace-stripped.

    A missing or empty file yields an empty dict (with a printed warning)
    instead of raising, so callers can treat it as an empty cache.

    Args:
        article_csv: Path to the CSV file (e.g. ``wiki_output.csv``).

    Returns:
        dict mapping topic -> article text.
    """
    articles = {}
    try:
        # newline='' is required by the csv module so that line breaks inside
        # quoted fields (common in article text) are preserved exactly.
        with open(article_csv, 'r', encoding='utf-8', newline='') as f:
            reader = csv.reader(f)
            if next(reader, None) is None:
                print(f"Warning: {article_csv} is empty.")
                return articles
            for row in reader:
                if len(row) >= 2:
                    articles[row[0].strip()] = row[1].strip()
    except FileNotFoundError:
        print(f"Warning: {article_csv} not found. Using empty dictionary.")
    return articles


def clean_quotes_in_csv(csv_file):
    """Remove literal double-quote characters from the first column, in place.

    For every non-blank row, all ``"`` characters are deleted from the first
    field, which is then whitespace-stripped; other columns are untouched.
    Blank rows are dropped.

    The result is written to a temporary sibling file and swapped in with
    ``os.replace``, so a failure mid-write cannot truncate the original
    (this function rewrites the pipeline's primary input).
    """
    # newline='' lets the csv module handle line breaks inside quoted fields.
    with open(csv_file, 'r', encoding='utf-8', newline='') as f:
        rows = [row for row in csv.reader(f) if row]
    for row in rows:
        row[0] = row[0].replace('"', '').strip()

    tmp_path = os.fspath(csv_file) + '.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8', newline='') as f:
            csv.writer(f).writerows(rows)
        os.replace(tmp_path, csv_file)
    except BaseException:
        # Leave the original intact and don't leave a partial temp file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def fetch_wiki_article(topic):
    """Return the article text for *topic*.

    Placeholder only: no request is made. It returns the synthetic string
    ``"Article for <topic>"`` until a real Wikipedia API call is added here.
    """
    placeholder = f"Article for {topic}"
    return placeholder


def process_unique_topics(grouped_csv, unique_topics_file):
    """Write the distinct topics from *grouped_csv* to *unique_topics_file*.

    The header row is skipped. Each row's first column is whitespace-stripped;
    blank values are ignored. The de-duplicated topics are sorted and written
    one per line, and a summary line is printed.
    """
    with open(grouped_csv, 'r', encoding='utf-8') as src:
        rows = csv.reader(src)
        next(rows, None)  # header
        stripped = (r[0].strip() for r in rows if r)
        topics = {name for name in stripped if name}
    with open(unique_topics_file, 'w', encoding='utf-8') as dst:
        dst.writelines(f"{name}\n" for name in sorted(topics))
    print(f"Saved {len(topics)} unique topics to {unique_topics_file}")


def update_wiki_articles(grouped_csv, wiki_csv, *, delay=0.5):
    """Fetch articles for topics missing from *wiki_csv* and save them there.

    Existing ``Topic,Article`` rows are loaded with ``load_articles`` and kept.
    Every distinct, non-blank topic in the first column of *grouped_csv*
    (header skipped) that has no article yet is fetched via
    ``fetch_wiki_article``.

    The output file is rewritten up front with the existing rows. Each newly
    fetched article is then appended and flushed immediately, so an
    interrupted run keeps its progress and a rerun resumes from there
    (provided *wiki_csv* is not cleared in between).

    Args:
        grouped_csv: CSV whose first column holds topics; first row is a header.
        wiki_csv: ``Topic,Article`` CSV used both as the cache and the output.
        delay: Seconds to sleep after each fetch (rate limiting). Defaults to
            0.5; pass 0 to disable, e.g. while the fetcher is a placeholder.
    """
    existing = load_articles(wiki_csv)

    # newline='' lets the csv module handle line breaks inside quoted fields.
    with open(grouped_csv, 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header
        topics = [row[0].strip() for row in reader if row and row[0].strip()]

    known = set(existing)
    fetched = 0
    # `existing` is already in memory, so truncating the file here is safe.
    with open(wiki_csv, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['Topic', 'Article'])
        writer.writerows(existing.items())
        f.flush()
        for topic in topics:
            if topic in known:
                continue
            known.add(topic)
            writer.writerow([topic, fetch_wiki_article(topic)])
            f.flush()  # persist progress before the (slow) next fetch
            fetched += 1
            if delay:
                sleep(delay)

    print(f"Fetched {fetched} new articles; {len(known)} articles in {wiki_csv}")


def generate_t5_triplets(wiki_csv, grouped_csv, output_file):
    """Build ``task,input,target`` rows for T5 fine-tuning and write them.

    Articles are loaded from *wiki_csv* via ``load_articles``. For every row
    of *grouped_csv* (header skipped) with at least two columns, the first
    column is the topic and the second the target summary. When the topic has
    a non-empty article, a ``['summarize', article, summary]`` triplet is
    emitted; otherwise the row is skipped and the topic is reported once.

    Args:
        wiki_csv: ``Topic,Article`` CSV with the source articles.
        grouped_csv: CSV of topic/summary rows; first row is a header.
        output_file: Destination CSV; written with a ``task,input,target`` header.
    """
    articles = load_articles(wiki_csv)
    triplets = []
    missing = set()
    # newline='' lets the csv module handle line breaks inside quoted fields.
    with open(grouped_csv, 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header
        for row in reader:
            if len(row) < 2:
                continue
            topic, summary = row[0].strip(), row[1].strip()
            art = articles.get(topic)
            if art:
                triplets.append(['summarize', art, summary])
            elif topic not in missing:
                # Report each missing topic once, not once per grouped row.
                missing.add(topic)
                print(f"No article for {topic}")
    with open(output_file, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['task', 'input', 'target'])
        writer.writerows(triplets)
    print(f"Generated {len(triplets)} triplets to {output_file}")
    if missing:
        print(f"Skipped {len(missing)} topics with no article")


def train_t5_model(t5_triplets_csv):
    """Fine-tune ``t5-small`` on a ``task,input,target`` triplets CSV.

    When the CSV has a ``task`` column, each input is prefixed with its task
    (e.g. ``"summarize: <article>"``). This is the text-to-text prompt format
    T5 was pretrained with.

    Inputs are padded/truncated to 512 tokens and targets to 128. Padding
    positions in the labels are replaced with -100 so the loss ignores them,
    rather than training the model to predict padding.

    Checkpoints are written under ``./results``.
    """
    # Device setup
    device = get_best_device()
    print(f"Training on device: {device}")
    # Load dataset, model, tokenizer
    ds = Dataset.from_csv(t5_triplets_csv)
    tokenizer = T5Tokenizer.from_pretrained('t5-small')
    model = T5ForConditionalGeneration.from_pretrained('t5-small').to(device)
    # NOTE(review): Trainer places the model on TrainingArguments' own device,
    # which has no notion of DirectML. A DirectML device chosen above is
    # probably overridden during training — confirm before relying on it.

    has_task = 'task' in ds.column_names
    pad_id = tokenizer.pad_token_id

    def preprocess(examples):
        texts = [str(i) for i in examples['input']]
        if has_task:
            # Non-string/blank task cells (e.g. empty CSV cells) get no prefix.
            texts = [
                f"{task}: {text}" if isinstance(task, str) and task.strip() else text
                for task, text in zip(examples['task'], texts)
            ]
        inputs = tokenizer(texts,
                           padding='max_length', truncation=True, max_length=512)
        labels = tokenizer([str(t) for t in examples['target']],
                           padding='max_length', truncation=True, max_length=128)
        # -100 is the ignore index of the model's cross-entropy loss, so
        # padded label positions contribute nothing to training.
        label_ids = [
            [tok if tok != pad_id else -100 for tok in seq]
            for seq in labels['input_ids']
        ]
        return {'input_ids': inputs['input_ids'],
                'attention_mask': inputs['attention_mask'],
                'labels': label_ids}

    tokenized = ds.map(preprocess, batched=True)

    args = TrainingArguments(
        output_dir='./results',
        per_device_train_batch_size=8,
        num_train_epochs=3,
        save_steps=10000,
        save_total_limit=2,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=tokenized,
        tokenizer=tokenizer,
    )
    trainer.train()
    print("Training complete.")

# --------------- MAIN ---------------

def main():
    """Run the whole pipeline.

    Steps: clean the grouped input, list unique topics, fetch articles,
    build T5 triplets, then fine-tune the model.
    """
    # Every run starts from empty outputs. This includes wiki_csv, so
    # update_wiki_articles starts with no cached articles and fetches every
    # topic again.
    stale_outputs = [t5_triplets_output, unique_topics_file, wiki_csv]
    clear_output_files(stale_outputs)

    clean_quotes_in_csv(output_grouped)
    process_unique_topics(output_grouped, unique_topics_file)
    update_wiki_articles(output_grouped, wiki_csv)
    generate_t5_triplets(wiki_csv, output_grouped, t5_triplets_output)
    train_t5_model(t5_triplets_output)


if __name__ == '__main__':
    main()


Saved 26959 unique topics to unique_topics.txt
