- Dataset? Training data?
- BERT (pretrained) or PyTorch from scratch?
- which training methods exactly?
- paper recommendations?
- keywords for googling?

## Dataset ideas

- https://huggingface.co/datasets/fancyzhx/yelp_polarity
- https://huggingface.co/datasets/knowledgator/Scientific-text-classification
- https://huggingface.co/datasets/SetFit/mnli

In [None]:
from transformers import BertForSequenceClassification, BertTokenizer
import torch

# One-step fine-tuning smoke test: load a BERT sequence classifier, push a
# single labelled sentence through it, and apply one Adam update.
checkpoint = 'bert-base-uncased'
model = BertForSequenceClassification.from_pretrained(checkpoint)
tokenizer = BertTokenizer.from_pretrained(checkpoint)

text = "Sample text for training."
label = 1  # Assuming positive sentiment

# Encode to PyTorch tensors; the label is wrapped as a batch of size one.
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
labels = torch.tensor([label])
outputs = model(labels=labels, **inputs)

# NOTE: despite the name, `loss_func` holds the loss value read from the
# model output, not a loss function.
loss_func = outputs.loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
loss_func.backward()
optimizer.step()

  from .autonotebook import tqdm as notebook_tqdm


KeyboardInterrupt: 

In [1]:
# pip install -q transformers datasets accelerate torch==2.* sentencepiece
import os, math, random, torch
import numpy as np
from dataclasses import dataclass
from typing import Dict, List, Optional
from datasets import load_dataset, Dataset
from transformers import BertTokenizer, BertForMaskedLM, PreTrainedTokenizerBase, SequenceFeatureExtractor, DataCollatorWithPadding
from torch.utils.data import default_collate


  from .autonotebook import tqdm as notebook_tqdm


# Dataset

In [2]:
# Load the first 1000 rows of the train split of the CoT-Collection.
# NOTE(review): trust_remote_code=True lets the dataset repository's loading
# script run locally — confirm the source is trusted and that the installed
# `datasets` version still accepts this flag.
raw_dataset = load_dataset(
    "kaist-ai/CoT-Collection",
    split='train[:1000]',
    trust_remote_code=True,
)
raw_dataset  # last expression of the cell: displays the dataset summary

Dataset({
    features: ['source', 'target', 'rationale', 'task', 'type'],
    num_rows: 1000
})

In [3]:
# Display the first record to inspect the raw field contents.
raw_dataset[0]

{'source': 'Article: Phytochemistry is a branch of plant biochemistry primarily concerned with the chemical substances produced by plants during secondary metabolism. Some of these compounds are toxins such as the alkaloid coniine from hemlock. Others, such as the essential oils peppermint oil and lemon oil are useful for their aroma, as flavourings and spices (e.g., capsaicin), and in medicine as pharmaceuticals as in opium from opium poppies. Many medicinal and recreational drugs, such as tetrahydrocannabinol (active ingredient in cannabis), caffeine, morphine and nicotine come directly from plants. Others are simple derivatives of botanical natural products. For example, the pain killer aspirin is the acetyl ester of salicylic acid, originally isolated from the bark of willow trees, and a wide range of opiate painkillers like heroin are obtained by chemical modification of morphine obtained from the opium poppy. Popular stimulants come from plants, such as caffeine from coffee, tea 

# CoT training

In [4]:
# Checkpoint name shared by the tokenizer and the masked-LM model loaded below.
ckpt = "bert-base-uncased"     # swap for domain/multilingual BERT as needed
# `tok` and `model` are module-level names that later cells read directly.
tok: BertTokenizer = BertTokenizer.from_pretrained(ckpt)
model = BertForMaskedLM.from_pretrained(ckpt)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
# Show the ids of the special tokens used when assembling model inputs by hand.
print(
    f"cls_token: {tok.cls_token_id}",
    f"sep_token: {tok.sep_token_id}",
    f"mask_token: {tok.mask_token_id}",
    sep="\n",
)

cls_token: 101
sep_token: 102
mask_token: 103


In [50]:
# Masking experiment: replace the third-from-last token with [MASK] and print
# the MLM head's greedy prediction at every position.
text = "Some sample text for training and for testing"
# Named `encoded` rather than `input` so the builtin `input()` is not
# shadowed for the rest of the notebook session.
encoded = tok(text, return_tensors='pt', add_special_tokens=False)
print(tok.batch_decode(encoded['input_ids']))

mask_index = -3
# NOTE(review): the attention mask is zeroed at the [MASK] slot in addition to
# replacing the token id, which is unusual for MLM inference — confirm this is
# the intended experiment.
encoded['attention_mask'][0][mask_index] = 0
encoded['input_ids'][0][mask_index] = tok.mask_token_id
print('input: \t', encoded['input_ids'], tok.batch_decode(encoded['input_ids']))

# Inference only: no autograd graph is needed. The module is called directly
# (not via `.forward`) so that registered hooks still run.
with torch.no_grad():
    logits = model(**encoded).logits
output = torch.argmax(logits, dim=2)
print('output: ', output, tok.batch_decode(output))


['some sample text for training and for testing']
input: 	 tensor([[2070, 7099, 3793, 2005, 2731,  103, 2005, 5604]]) ['some sample text for training [MASK] for testing']
output:  tensor([[1998, 2440, 2817, 2005, 5604, 1998, 2005, 5604]]) ['and full study for testing and for testing']


  return forward_call(*args, **kwargs)


In [7]:
@dataclass
class GeneratorTrainer:
    """Fine-tune a BERT masked-LM to produce a rationale one token at a time.

    Every decoding step feeds ``[CLS] question prefix [MASK] [SEP]`` to the
    model and reads the prediction at the ``[MASK]`` slot. When gold answers
    are supplied, the prefix is the gold answer prefix with probability
    ``teacher_forcing_percentage``; otherwise it is the model's own greedy
    output so far.
    """
    model: BertForMaskedLM
    ds: Dataset  # must expose the "source" and "rationale" columns
    tokenizer: BertTokenizer
    teacher_forcing_percentage: float = 0.8
    # Truncation limits. A step input is 1 + question + prefix + 2 tokens long
    # and the prefix is at most one token shorter than the answer, so these
    # defaults cap the longest input at 512 tokens (bert-base-uncased's
    # position limit).
    max_question_length: int = 256
    max_answer_length: int = 254

    def __post_init__(self):
        """Create the loss and optimizer for the model held by this trainer."""
        pad_id = self.tokenizer.pad_token_id
        # Padded answer positions carry no training signal, so they are
        # excluded from the loss (-100 is CrossEntropyLoss's default).
        self.loss_func = torch.nn.CrossEntropyLoss(
            ignore_index=-100 if pad_id is None else pad_id
        )
        # Optimize this trainer's own model, not whatever global `model`
        # happens to exist in the notebook.
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
        # Unused by this class; kept so code reading the attribute still works.
        self.cls_token_tensor = torch.tensor([[self.tokenizer.cls_token_id]])

    def tokenize(self, features: Dataset) -> tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a batch of records into padded id tensors.

        Args:
            features: Batch indexed by column name; the "source" and
                "rationale" columns are read.

        Returns:
            ``(question_ids, answer_ids)``, each padded to the longest
            sequence in the batch and truncated to the configured limits.
            No special tokens are added; ``generate`` inserts them itself.
        """
        questions = self.tokenizer(
            list(features["source"]),
            add_special_tokens=False,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=self.max_question_length,
        )
        answers = self.tokenizer(
            list(features["rationale"]),
            add_special_tokens=False,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=self.max_answer_length,
        )
        return questions['input_ids'], answers['input_ids']

    def generate(self, questions: torch.Tensor, answers: torch.Tensor|None=None, max_length=200):
        """Greedily decode one token per step at the ``[MASK]`` slot.

        Args:
            questions: Question token ids, one row per example.
            answers: Optional gold answer ids. When given, one step is run
                per answer column and teacher forcing may be applied;
                ``max_length`` is then ignored.
            max_length: Number of decoding steps when ``answers`` is None.

        Returns:
            ``(token_ids, logits)`` stacked along dimension 1, i.e. one entry
            per decoding step for every row of ``questions``.
        """
        batch_size = questions.shape[0]
        device = questions.device
        pad_id = self.tokenizer.pad_token_id

        def token_column(token_id: int) -> torch.Tensor:
            # Integer ids from the start: no float round-trip and cast.
            return torch.full((batch_size, 1), token_id, dtype=torch.long, device=device)

        cls_column = token_column(self.tokenizer.cls_token_id)
        mask_column = token_column(self.tokenizer.mask_token_id)
        sep_column = token_column(self.tokenizer.sep_token_id)
        mask_pos = -2  # the [MASK] slot sits just before the trailing [SEP]
        steps = answers.shape[1] if answers is not None else max_length

        generated_answers = []
        logits = []
        for i in range(steps):
            # The loop already bounds `i` by the answer length, so the only
            # conditions are "gold answers exist" and the random draw.
            use_teacher_forcing = (
                answers is not None
                and random.random() < self.teacher_forcing_percentage
            )
            if use_teacher_forcing:
                prefix = answers[:, :i]
            elif generated_answers:
                prefix = torch.stack(generated_answers, dim=1)
            else:
                prefix = torch.empty((batch_size, 0), dtype=torch.long, device=device)

            inp = torch.cat((cls_column, questions, prefix, mask_column, sep_column), dim=1)

            # Padding must not be attended to.
            if pad_id is None:
                attention_mask = torch.ones_like(inp)
            else:
                attention_mask = (inp != pad_id).long()
            # The [MASK] slot is hidden from attention as well, as in the
            # original experiment.
            # NOTE(review): confirm that hiding the [MASK] slot is intended.
            attention_mask[:, mask_pos] = 0

            # Segment 0 covers [CLS] + question, segment 1 the rest.
            token_type_ids = torch.cat(
                (
                    torch.zeros((batch_size, questions.shape[1] + 1), dtype=torch.long, device=device),
                    torch.ones((batch_size, prefix.shape[1] + 2), dtype=torch.long, device=device),
                ),
                dim=1,
            )

            generated = self.model(
                input_ids=inp,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
            )
            step_logits = generated.logits[:, mask_pos]
            logits.append(step_logits)
            generated_answers.append(torch.argmax(step_logits, dim=-1))

        return torch.stack(generated_answers, dim=1), torch.stack(logits, dim=1)

    def train(self, episodes, batch_size=16):
        """Run ``episodes`` optimization steps on randomly sampled batches.

        Args:
            episodes: Number of optimizer steps to perform.
            batch_size: Number of records drawn (with replacement) per step.
        """
        device = next(self.model.parameters()).device
        for episode in range(episodes):
            i_samples = np.random.randint(0, len(self.ds), batch_size)
            samples = self.ds.select(i_samples)
            # tokenize() returns (questions, answers), in that order.
            questions, answers = self.tokenize(samples)
            questions = questions.to(device)
            answers = answers.to(device)

            # Clear gradients left over from the previous episode.
            self.optimizer.zero_grad()
            _, logits = self.generate(questions, answers)

            # CrossEntropyLoss expects (N, C) scores against (N,) targets.
            loss = self.loss_func(
                logits.reshape(-1, logits.shape[-1]),
                answers.reshape(-1),
            )
            loss.backward()
            self.optimizer.step()

            print(loss.item(), end='\r')

In [None]:
# Wire the pretrained MLM, the dataset slice and the tokenizer into a trainer,
# then run a short training loop.
generator_trainer = GeneratorTrainer(model=model, ds=raw_dataset, tokenizer=tok)
generator_trainer.train(10, batch_size=2)

  return forward_call(*args, **kwargs)


In [1]:
# Prompt-based generation demo.
# NOTE(review): `generate_with_mlm` is not defined in any cell shown here, and
# the recorded run of this cell raised NameError — define or import it before
# running this cell.
prompt = "Question: What is 7 + 5?\nReasoning:"
print(generate_with_mlm(prompt, max_new_tokens=40, stop_strings=["Answer:"]))

NameError: name 'generate_with_mlm' is not defined