In [2]:
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
tweets = load_dataset("m-newhauser/senator-tweets")

In [4]:
tweets

DatasetDict({
    train: Dataset({
        features: ['date', 'id', 'username', 'text', 'party', 'labels', 'embeddings'],
        num_rows: 79754
    })
    test: Dataset({
        features: ['date', 'id', 'username', 'text', 'party', 'labels', 'embeddings'],
        num_rows: 19939
    })
})

In [5]:
tweets['train'][0]

{'date': '2021-10-13 19:47:44',
 'id': 1448374915636383745,
 'username': 'SenatorHassan',
 'text': 'Happy th birthday to the @USNavy! The strength, dedication, and skill of our Sailors including those at Portsmouth Naval Shipyard help keep this country safe, secure, and free. Today we recognize and celebrate their incredible service. #246NavyBirthday https://t.co/GuHEDMApke',
 'party': 'Democrat',
 'labels': 1,
 'embeddings': [-0.026915842667222023,
  0.08723406493663788,
  0.018707331269979477,
  -0.03298894315958023,
  -0.014149527996778488,
  0.0024309002328664064,
  0.02864699810743332,
  -0.05514369532465935,
  -0.15324026346206665,
  0.01460077241063118,
  -0.005979436449706554,
  0.010507513768970966,
  0.06579640507698059,
  0.007917123846709728,
  -0.05829846113920212,
  0.08774501085281372,
  -0.019554272294044495,
  -0.056161101907491684,
  -0.025554964318871498,
  0.03465566039085388,
  -0.11109739542007446,
  0.02460041455924511,
  -0.021291883662343025,
  0.01870913989841

In [19]:
from torch.utils.data import Subset

Subset(tweets['train'], train_idx)

<torch.utils.data.dataset.Subset at 0x2ac11f8e0>

In [18]:
from sklearn.model_selection import train_test_split



In [80]:
import torch
import datasets
import pytorch_lightning as pl

from datasets import load_dataset
from transformers import AutoTokenizer


class DataModule(pl.LightningDataModule):
    def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", batch_size=32):
        super().__init__()

        self.batch_size = batch_size
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def prepare_data(self, val_size = 0.2, random_state = 42, stratify = True):
        tweets_dataset = load_dataset("m-newhauser/senator-tweets")
        
        
        if stratify:
            tweets_dataset_split = tweets_dataset['train'].train_test_split(test_size=val_size, seed=random_state, stratify_by_column='labels')
        else:
            tweets_dataset_split = tweets_dataset['train'].train_test_split(test_size=val_size, seed=random_state)

        self.train_data = tweets_dataset_split['train']
        self.val_data = tweets_dataset_split['test']


    def tokenize_data(self, example):
        return self.tokenizer(
            example["text"],
            truncation=True,
            padding="max_length",
            max_length=512,
        )
        
    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            self.train_data = self.train_data.map(self.tokenize_data, batched=True)
            self.train_data.set_format(
                type="torch", columns=['date', 'id', 'username', 'text', 'party', 'labels', 'embeddings']
            )

            self.val_data = self.val_data.map(self.tokenize_data, batched=True)
            self.val_data.set_format(
                type="torch", columns=['date', 'id', 'username', 'text', 'party', 'labels', 'embeddings']
            )

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_data, batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_data, batch_size=self.batch_size, shuffle=False
        )

In [81]:
data_model = DataModule()
data_model.prepare_data()
data_model.setup()

Map: 100%|██████████| 63803/63803 [00:10<00:00, 5807.20 examples/s]
Map: 100%|██████████| 15951/15951 [00:02<00:00, 6087.48 examples/s]


In [None]:
data_model.train_data.dataset

DatasetDict({
    train: Dataset({
        features: ['date', 'id', 'username', 'text', 'party', 'labels', 'embeddings'],
        num_rows: 79754
    })
    test: Dataset({
        features: ['date', 'id', 'username', 'text', 'party', 'labels', 'embeddings'],
        num_rows: 19939
    })
})

In [None]:
self.train_data = self.train_dataloader()