# Input Example

In [1]:
!pip install datasets --quiet
!pip install transformers --quiet
!pip install torch --quiet
!pip install accelerate --quiet

In [2]:
from datasets import load_dataset

from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW

torch.set_printoptions(linewidth=1000000)

  from .autonotebook import tqdm as notebook_tqdm


### Load Data

In [3]:
dataset = load_dataset('csv', data_files={'train': '../data/cleaned_with_context.csv'})

Found cached dataset csv (/home/jovyan/.cache/huggingface/datasets/csv/default-db9a28ded843dd61/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)
100%|██████████| 1/1 [00:00<00:00, 483.60it/s]


In [4]:
dataset = dataset.filter(lambda example: (example['message_tree_id'] == 'b1673cb9-1e01-44fd-916c-befa0fae9923') & (example['target'] == 'A rainy rain storm!') )

                                                                        

### Prepare Data

In [5]:
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-small', truncation_side='left', padding_side = 'left')

In [6]:
def tokenize_data(data):
    tokenized = tokenizer(data['context'], padding="max_length", truncation=True)
    target_toked = tokenizer(data['target'], padding="max_length", truncation=True)
    tokenized['labels'] = target_toked['input_ids']
    
    return tokenized

dataset = dataset.map(tokenize_data, batched=True)

                                                 

In [7]:
dataset.set_format(type="torch", columns=["input_ids", "labels"])

In [8]:
train_dataset = dataset['train']

### Data Loader

In [9]:
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=1)

### Setup Model

In [10]:
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small", device_map="auto")

### Example

In [11]:
batch = next(iter(train_dataloader))

In [12]:
print("Example input: ", tokenizer.batch_decode(batch['input_ids'])[0])

Example input:  <pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad

In [13]:
for (key, item) in batch.items():
    print(key, ":", str(item)[0:100], '...')

input_ids : tensor([[    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,     0,  ...
labels : tensor([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,   ...


In [14]:
model.eval()
with torch.no_grad():
    outputs = model(**batch)

In [15]:
print("Loss: ", outputs.loss)
print("Logits: ", outputs.logits)

Loss:  tensor(48.0392)
Logits:  tensor([[[-46.3854,  -3.5440,  -8.1632,  ..., -46.3289, -46.3740, -46.1255],
         [-46.3854,  -3.5440,  -8.1632,  ..., -46.3289, -46.3740, -46.1255],
         [-46.3854,  -3.5440,  -8.1632,  ..., -46.3289, -46.3740, -46.1255],
         ...,
         [-50.9342,   0.8869,  -6.5974,  ..., -50.7909, -50.9224, -50.5790],
         [-52.8611,   0.1231,  -9.2390,  ..., -52.8423, -52.9428, -52.7534],
         [-55.9899,   0.7806,  -7.6759,  ..., -55.9907, -56.0710, -55.9349]]])
