In [1]:
import logging

import torch
import torch.nn.functional as F

from lightning.pytorch import Trainer
from lightning.pytorch.tuner import Tuner

from model import FasterRCNN
from data import JokesDataModule
from helper import show_image, show_image_and_bounding_box, show_worst_image_predictions, show_confusion_matrix, get_batch, MyProgressBar
from helper import get_sample, convert_predictions

# Notebook-level logger; inside a Jupyter kernel ``__name__`` is "__main__".
log = logging.getLogger(name=__name__)

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
import json
from urllib.parse import quote

import requests
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import GPT2Tokenizer

class ActorsWikipediaDataset(Dataset):
    """Map-style dataset of Wikipedia summary extracts for members of a category.

    All network I/O happens eagerly in ``__init__``. The category listing is
    fetched once, then each member's REST summary is requested and its
    ``extract`` text is kept. Members whose summary request returns a non-200
    status, or whose extract is empty, are dropped. ``len(self)`` can therefore
    be smaller than ``max_articles``.

    Args:
        max_articles: Sent as ``cmlimit`` to the MediaWiki
            ``list=categorymembers`` query, i.e. the maximum number of category
            members requested.
            NOTE(review): MediaWiki enforces its own maximum for cmlimit;
            confirm behavior for large values.
        category: Category page whose members are listed.
        timeout: Per-request timeout in seconds, passed to ``requests``. A
            stalled connection then raises instead of hanging the notebook.
    """

    API_URL = "https://en.wikipedia.org/w/api.php"
    SUMMARY_URL = "https://en.wikipedia.org/api/rest_v1/page/summary/"
    # Wikimedia asks API clients to send a descriptive User-Agent.
    # NOTE(review): add a contact URL/email as the Wikimedia User-Agent policy requests.
    USER_AGENT = "ActorsWikipediaDataset/0.1 (notebook experiment)"

    def __init__(self, max_articles=100, category="Category:Actors", timeout=10):
        self.max_articles = max_articles
        self.category = category
        self.timeout = timeout
        self.articles = self._fetch_articles()

    def _get(self, url, params=None):
        """GET ``url`` with the shared User-Agent header and timeout."""
        return requests.get(
            url,
            params=params,
            headers={"User-Agent": self.USER_AGENT},
            timeout=self.timeout,
        )

    def _fetch_articles(self):
        """Return the non-empty summary extracts for the category's members."""
        # Let requests build and percent-encode the query string.
        response = self._get(
            self.API_URL,
            params={
                "action": "query",
                "list": "categorymembers",
                "cmtitle": self.category,
                "cmlimit": self.max_articles,
                "format": "json",
            },
        )
        # Report HTTP failures here rather than as a JSON/KeyError further down.
        response.raise_for_status()
        data = response.json()
        article_titles = [item['title'] for item in data['query']['categorymembers']]

        articles = []
        for title in tqdm(article_titles, desc="Fetching articles"):
            article_text = self._fetch_article_text(title)
            if article_text:
                articles.append(article_text)

        return articles

    def _fetch_article_text(self, title):
        """Return the summary ``extract`` for ``title``, or ``''`` on a non-200 response.

        The title is percent-encoded, including ``/``, after the usual
        space-to-underscore substitution. Titles such as ``"AC/DC"`` therefore
        map to a single REST path segment.
        """
        path = quote(title.replace(' ', '_'), safe='')
        response = self._get(self.SUMMARY_URL + path)
        if response.status_code == 200:
            return response.json().get('extract', '')
        return ''

    def __len__(self):
        return len(self.articles)

    def __getitem__(self, idx):
        return self.articles[idx]

class TextCollate:
    """``collate_fn`` for a DataLoader whose items are raw strings.

    Calling an instance tokenizes the whole batch in one call. Sequences are
    padded to the longest in the batch and truncated to ``max_length``. The
    call returns the ``(input_ids, attention_mask)`` pair produced by the
    tokenizer.
    """

    def __init__(self, tokenizer, max_length=512):
        self.tokenizer, self.max_length = tokenizer, max_length

    def __call__(self, batch):
        enc = self.tokenizer(
            batch,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        return enc['input_ids'], enc['attention_mask']

# Example usage: fetch a small corpus, tokenize it, and print one batch's tensor shapes.
# In a notebook kernel __name__ is "__main__", so this guard always runs; it only
# matters if the cell is exported as a script.
if __name__ == "__main__":
    # Was `WikipediaDataset`, which is undefined (NameError on a fresh kernel).
    dataset = ActorsWikipediaDataset(max_articles=100)

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # The GPT-2 tokenizer has no pad token by default. Reuse EOS as the pad token
    # so `padding=True` in TextCollate works; no new token is added to the vocab.
    tokenizer.pad_token = tokenizer.eos_token

    collate_fn = TextCollate(tokenizer)

    dataloader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)

    # Inspect only the first batch; an empty dataset simply prints nothing.
    for input_ids, attention_mask in dataloader:
        print(input_ids.shape, attention_mask.shape)
        break


Fetching articles: 100%|██████████████████████████████████████████████████| 40/40 [00:06<00:00,  6.46it/s]


In [21]:
data_module = JokesDataModule(
    data_dir='data',
    batch_size=8,
    max_length=512,
)

# Sanity check: take the first training batch and decode every sequence back to
# text, skipping special/padding tokens, then print one sample per line.
for input_ids, attention_mask in data_module.train_dataloader():
    decoded_texts = [
        data_module.tokenizer.decode(ids, skip_special_tokens=True)
        for ids in input_ids
    ]
    for text in decoded_texts:
        print(text)
    break  # a single batch is enough for a visual check

Don't drink and drive, also don't call frozen yogurt "fro yo."
Who the hell decided "have a happy period" was an okay thing to write on maxi pads? "NOT WORTH THE JAIL TIME" would have been more relevant.
Just tell me when and where and I'll be there 20 minutes late.
I hate it when people can't make a good sausage its the wurst
What does a dog get at the vet? [FIXED]
Whats worse than being adopted Being adopted twice.
David Beckham says he will retire at the end of this season, mainly because he ran out of ideas on how to do his next haircut.
My friend asked me if I was ready to go to the nudist colony. I was born ready.
