In [1]:
!pip install datasets --quiet
!pip install transformers --quiet
!pip install torch --quiet

In [1]:
from datasets import load_dataset

from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW

torch.set_printoptions(linewidth=1000000)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = load_dataset('csv', data_files={'train': 'data/cleaned_with_context.csv'})

Found cached dataset csv (/home/jovyan/.cache/huggingface/datasets/csv/default-768f459642bab55c/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)
100%|██████████| 1/1 [00:00<00:00, 458.19it/s]


In [3]:
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-small', truncation_side='left')

In [4]:
def tokenize_data(data):
    tokenized = tokenizer(data['context'], padding="max_length", truncation=True)
    target_toked = tokenizer(data['target'], padding="max_length", truncation=True)
    tokenized['labels'] = target_toked['input_ids']
    
    return tokenized

dataset = dataset.map(tokenize_data, batched=True)

                                                                   

In [5]:
print(dataset['train'])

Dataset({
    features: ['target', 'message_tree_id', 'humor', 'context', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 51802
})


In [7]:
dataset.set_format(type="torch", columns=["input_ids", "labels"])

In [8]:
train_dataset = dataset['train']

In [9]:
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=1)

In [10]:
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small", device_map="auto")

In [11]:
batch = next(iter(train_dataloader))

In [12]:
for (key, item) in batch.items():
    print(key, ":", str(item)[0:100], '...')

input_ids : tensor([[2570,   32,  259,    9,    7,    1,    0,    0,    0,    0,    0,    0,    0,    0,    0,   ...
labels : tensor([[ 4159,    63,   923,    55,     3,     2,  8532,     3,  7195,     3, 28061,  4353,    15,  ...


In [13]:
model.eval()
with torch.no_grad():
    outputs = model(**batch)

In [14]:
print("Loss: ", outputs.loss)
print("Logits: ", outputs.logits)

Loss:  tensor(55.5240)
Logits:  tensor([[[-51.4815,  -4.7137,  -8.6466,  ..., -51.4649, -51.4606, -51.2127],
         [-41.0469,  -3.5044,   0.8753,  ..., -40.9300, -41.0136, -40.8807],
         [-40.6562,  -1.2704,  -3.4167,  ..., -40.5942, -40.7589, -40.3900],
         ...,
         [-59.3815,  -8.1868, -10.9114,  ..., -59.4425, -59.3974, -59.3567],
         [-59.3830,  -8.1863, -10.9109,  ..., -59.4440, -59.3988, -59.3582],
         [-59.3845,  -8.1858, -10.9103,  ..., -59.4455, -59.4003, -59.3596]]])
