In [1]:
import math
import torch
from torch.utils.data import DataLoader
from transformers import AdamW
import argparse
import os
import tqdm
import inspect
import logging

import datasets

# from models.teacher import Teacher
# from models.configuration_teacher import TeacherConfig
# from data import CoTDataset, CoTDataCollator, extract_answer

from transformers import AutoTokenizer

import sys
sys.path.append('..')
from modeling_rmt.language_modeling import MemoryCell, RecurrentWrapper
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from torch.nn import CrossEntropyLoss


In [2]:
model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

In [3]:
class RecurrentWrapperNoSegmentation(RecurrentWrapper):
    def forward(self, segments, labels, output_attentions=None, output_hidden_states=None):
        # segments = segments['segments']
        # for seg in segments:
        #     print(tokenizer.batch_decode(seg['input_ids'], skip_special_tokens=True))
        memory_state = None

        cell_outputs = []
        for seg_num, segment in enumerate(segments):
            cell_out, memory_state = self.memory_cell(input_ids=segment['input_ids'],
                                                      attention_mask=segment['attention_mask'],
                                                      memory_state=memory_state, output_hidden_states=True)
            cell_outputs.append(cell_out)
            self.manage_gradients(memory_state, seg_num)

        out = self.process_outputs(cell_outputs, segments,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states)
        return out

    def generate(self, segments, **generate_kwargs):
        raise NotImplementedError("Generation not implemented for this wrapper.")
        memory_state = None
        for seg_num, segment in enumerate(segments[:-1]):
            cell_out, memory_state = self.memory_cell(**segment, memory_state=memory_state, output_hidden_states=True)

        final_segment = segments[-1]
        out = self.memory_cell.generate(**final_segment, memory_state=memory_state, **generate_kwargs)

        return out

    def process_outputs(self, cell_outputs, segments, **kwargs):
        out = CausalLMOutputWithCrossAttentions()
        proxy_out = {}
        for seg_num, segment in enumerate(segments):
            cell_out = cell_outputs[seg_num]

            full_logits = cell_out.logits

            labels = segment.get('labels')
            if labels is not None:
                shift_labels = labels[..., 1:].contiguous()
                shift_logits = full_logits[..., :-1, :].contiguous()
                flat_labels = shift_labels.view(-1)
                flat_logits = shift_logits.view(-1, shift_logits.size(-1))

                loss_fct = CrossEntropyLoss()
                labels_mask = segment.get('labels_mask')
                if labels_mask is not None:
                    shift_mask = labels_mask[..., :-1].contiguous()

                    flat_labels = flat_labels[shift_mask.view(-1)]
                    flat_logits = flat_logits[shift_mask.view(-1)]

                    if labels_mask.sum() == 0:
                        loss_value = 0
                    else:
                        loss_value = loss_fct(flat_logits, flat_labels)

                proxy_out[f'loss_{seg_num}'] = loss_value
            else:
                proxy_out[f'loss_{seg_num}'] = 0

            segment_keys = ['loss']
            if kwargs.get('output_attentions'):
                segment_keys.append('attentions')
            if kwargs.get('output_hidden_states'):
                segment_keys.append('hidden_states')

            for key, value in cell_out.items():
                if any([sk in key for sk in segment_keys]):
                    proxy_out[f'{key}_{seg_num}'] = value

        num_segments = len(segments)
        out['loss'] = sum([proxy_out[f'loss_{seg_num}'] for seg_num in range(num_segments)]) / num_segments
        out['logits'] = torch.cat([cell_out.logits for cell_out in cell_outputs], dim=1)
        # print(out.keys(), out.loss)

        return out

    def gradient_checkpointing_enable(self, *args, **kwargs):
        if hasattr(self.memory_cell.model, "gradient_checkpointing_enable"):
            return self.memory_cell.model.gradient_checkpointing_enable(*args, **kwargs)


In [4]:
memory_cell = MemoryCell(model, num_mem_tokens=16)
# recurrent_wrapper = RecurrentWrapper(memory_cell, segment_size=512)
rmt = RecurrentWrapperNoSegmentation(memory_cell, segment_size=512)

In [87]:
out = rmt(**collated)


In [90]:
out.loss

tensor(10.4265, grad_fn=<DivBackward0>)

In [92]:
self = rmt
output_attentions = False
output_hidden_states = True

In [97]:
memory_state = None

cell_outputs = []
for seg_num, segment in enumerate(segments):
    # cell_out, memory_state = self.memory_cell(**segment, memory_state=memory_state, output_hidden_states=True)
    cell_out, memory_state = self.memory_cell(input_ids=segment['input_ids'],
                                                attention_mask=segment['attention_mask'],
                                                memory_state=memory_state, output_hidden_states=True)
    cell_outputs.append(cell_out)
    self.manage_gradients(memory_state, seg_num)
    print(memory_state.mean(), memory_state.std())

tensor(0.4557, grad_fn=<MeanBackward0>) tensor(9.1104, grad_fn=<StdBackward0>)
tensor(0.4604, grad_fn=<MeanBackward0>) tensor(10.1487, grad_fn=<StdBackward0>)
tensor(0.4524, grad_fn=<MeanBackward0>) tensor(10.0450, grad_fn=<StdBackward0>)
tensor(0.3894, grad_fn=<MeanBackward0>) tensor(9.3270, grad_fn=<StdBackward0>)
tensor(0.3663, grad_fn=<MeanBackward0>) tensor(8.9285, grad_fn=<StdBackward0>)
tensor(0.3589, grad_fn=<MeanBackward0>) tensor(8.8887, grad_fn=<StdBackward0>)
tensor(0.3725, grad_fn=<MeanBackward0>) tensor(9.0271, grad_fn=<StdBackward0>)
tensor(0.3775, grad_fn=<MeanBackward0>) tensor(9.1252, grad_fn=<StdBackward0>)


In [103]:
out.loss.backward()

In [112]:
# rmt.memory_cell.memory.grad

In [227]:
seg_num = 1
cell_out = cell_outputs[seg_num]
segment = segments[seg_num]


# out = self.process_outputs(cell_out, labels=segment['labels'],
#                             labels_mask=segment['labels_mask'],
#                             output_attentions=output_attentions,
#                             output_hidden_states=output_hidden_states)

In [230]:
cell_out.logits.shape

torch.Size([10, 11, 50257])

In [245]:
kwargs = {
    'output_attentions': False,
    'output_hidden_states': True
}

In [270]:
import numpy as np

In [5]:
out = CausalLMOutputWithCrossAttentions()
for seg_num, segment in enumerate(segments):
    cell_out = cell_outputs[seg_num]

    full_logits = cell_out.logits
    full_hidden_states = cell_out.hidden_states

    labels = segment.get('labels')
    if labels is not None:
        shift_labels = labels[..., 1:].contiguous()
        shift_logits = full_logits[..., :-1, :].contiguous()
        flat_labels = shift_labels.view(-1)
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))

        loss_fct = CrossEntropyLoss()
        labels_mask = segment.get('labels_mask')
        if labels_mask is not None:
            shift_mask = labels_mask[..., :-1].contiguous()

            flat_labels = flat_labels[shift_mask.view(-1)]
            flat_logits = flat_logits[shift_mask.view(-1)]

            if labels_mask.sum() == 0:
                loss_value = 0
            else:
                loss_value = loss_fct(flat_logits, flat_labels)

        out[f'loss_{seg_num}'] = loss_value
    else:
        out[f'loss_{seg_num}'] = 0

    out[f'logits_{seg_num}'] = full_logits
    segment_keys = ['loss', 'logits']
    if kwargs.get('output_attentions'):
        segment_keys.append('attentions')
    if kwargs.get('output_hidden_states'):
        segment_keys.append('hidden_states')

    for key, value in cell_out.items():
        if any([sk in key for sk in segment_keys]):
            out[f'{key}_{seg_num}'] = value

num_segments = len(segments)
out['loss'] = sum([out[f'loss_{seg_num}'] for seg_num in range(num_segments)]) / num_segments


NameError: name 'segments' is not defined

In [277]:
for k, v in out.items():
    if 'loss' in k:
        print(k, v)

loss_0 0
loss_1 tensor(12.6786, grad_fn=<NllLossBackward0>)
loss_2 tensor(14.0413, grad_fn=<NllLossBackward0>)
loss_3 tensor(13.6104, grad_fn=<NllLossBackward0>)
loss_4 tensor(14.0162, grad_fn=<NllLossBackward0>)
loss_5 tensor(11.6516, grad_fn=<NllLossBackward0>)
loss_6 tensor(16.1613, grad_fn=<NllLossBackward0>)
loss_7 tensor(12.0924, grad_fn=<NllLossBackward0>)
loss tensor(11.7815, grad_fn=<DivBackward0>)


In [248]:
out.loss

tensor(nan, grad_fn=<AddBackward0>)

In [None]:

out = self.process_outputs(cell_outputs, labels=labels,
                            labels_mask=labels_mask,
                            output_attentions=output_attentions,
                            output_hidden_states=output_hidden_states)

In [5]:
class Holder:
    def __init__(self):
        pass
args = Holder()
args.use_cot = True
args.num_mem_tokens = None

In [6]:

# id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
# think = ans = tokenizer.bos_token_id
# eos = tokenizer.eos_token_id

# from torch.nn.utils.rnn import pad_sequence
# def collate_fn(batch):
#     input_ids, input_ids_generate, labels, labels_mask, attention_mask = [], [], [], [], []
#     for sample in batch:
#         task, lab, cot = sample['task'], sample['labels'], sample['cot']
#         task_tokens = tokenizer.encode(task, add_special_tokens=False)
#         labels_tokens = tokenizer.encode(lab, add_special_tokens=False)
#         cot_tokens = tokenizer.encode(cot, add_special_tokens=False)

#         if args.use_cot:
#             full_input = task_tokens + [think] + cot_tokens + [ans] + labels_tokens + [eos]
#             gen_input = task_tokens + [think]
#         else:
#             full_input = task_tokens + [ans] + labels_tokens + [eos]
#             gen_input = task_tokens + [ans]
        
#         inp_ids = torch.tensor(full_input)
#         input_ids.append(inp_ids)
#         input_ids_generate.append(torch.tensor(gen_input))


#         lab = torch.tensor(full_input)
#         lab[:len(task_tokens)] = -100
#         labels.append(lab)

#         lab_mask = torch.ones_like(inp_ids)
#         lab_mask[:len(task_tokens)] = 0
#         labels_mask.append(lab_mask)
#         attention_mask.append(torch.ones_like(inp_ids))

#     input_ids = pad_sequence(input_ids, padding_value=id_pad_value, batch_first=True)
#     # input_ids_generate = pad_sequence(input_ids_generate, padding_value=id_pad_value, batch_first=True)
#     attention_mask = pad_sequence(attention_mask, padding_value=0, batch_first=True)
#     labels = pad_sequence(labels, padding_value=id_pad_value, batch_first=True)
#     labels_mask = pad_sequence(labels_mask, padding_value=0, batch_first=True)

#     collated = {'input_ids': input_ids,
#                 'input_ids_generate': input_ids_generate,
#                 'labels': labels,
#                 'attention_mask': attention_mask,
#                 }
#     if args.num_mem_tokens is not None:
#         # add labels mask only for RMT, ARMT
#         collated['labels_mask'] = labels_mask.bool()
#     return collated

In [7]:
# dataset_path = "/workspace-SR006.nfs2/Bulatov_A/rmt/data/implicit_chain_of_thought/4_by_4_mult"
# train_dataset = datasets.load_from_disk(os.path.join(dataset_path, "train"))
# valid_dataset = datasets.load_from_disk(os.path.join(dataset_path, "valid"))
# if os.path.exists(os.path.join(dataset_path, "test")):
#     test_dataset = datasets.load_from_disk(os.path.join(dataset_path, "test"))
# else:
#     test_dataset = datasets.load_from_disk(os.path.join(dataset_path, "valid"))

In [8]:
dataset_name = 'booydar/gsm8k'
# dataset_name = 'booydar/multiplication_4x4'
valid_dataset = datasets.load_dataset(dataset_name, split='valid')
train_dataset = datasets.load_dataset(dataset_name, split='train')

args.task_name = dataset_name

In [69]:
tokenizer.encode('////')

[9705]

In [15]:
def split_cot(text, by=">> <<"):
    if text.startswith('<<'):
        text = text[2:]
    if text.endswith('>>'):
        text = text[:-2]
    
    return text.split(by)

def make_segment(input_tokens, loss=False):
    input_ids = torch.tensor(input_tokens)
    attention_mask = torch.ones_like(input_ids)
    labels = torch.tensor(input_tokens) if loss else torch.tensor([-100] * len(input_ids))
    # labels = torch.tensor(input_tokens)
    labels_mask = torch.ones_like(input_ids) if loss else torch.zeros_like(input_ids)

    return {'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'labels_mask': labels_mask.bool()
            }

id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
think = tokenizer.encode('????')
bos = tokenizer.encode('////')
ans = tokenizer.encode('!!!!')
eos = [tokenizer.eos_token_id]
if 'gsm8k' in args.task_name:
    delim = ">> <<"
elif 'multiplication' in args.task_name:
    delim = ' + '
else:
    raise NotImplementedError(f"Unknown task name {args.task_name}")

def collate_fn(batch):
    # first, we segment each sample into task, cot steps and labels
    segments_batch = []
    for sample in batch:
        task, lab, cot = sample['task'], sample['labels'], sample['cot']
        task_tokens = tokenizer.encode(task, add_special_tokens=False)
        labels_tokens = tokenizer.encode(lab, add_special_tokens=False)
        if getattr(args, 'use_cot', False):
            cot_segments = split_cot(cot, by=delim)
        else:
            cot_segments = [cot]
        cot_segment_tokens = tokenizer.batch_encode_plus(cot_segments, add_special_tokens=False)['input_ids']

        segments = []
        segments.append(make_segment(bos + task_tokens + think, loss=False))
        for segment in cot_segment_tokens[:-1]:
            segments.append(make_segment(bos + segment + think, loss=True))
        segments.append(make_segment(bos + cot_segment_tokens[-1] + ans, loss=True))

        segments.append(make_segment(bos + labels_tokens + eos, loss=True))
        segments_batch.append(segments)

    # if some samples have less segments than others, we pad them with empty segments
    num_segments = max(len(segments) for segments in segments_batch)
    for segments in segments_batch:
        if len(segments) < num_segments:
            segments.extend([make_segment(eos, loss=False)] * (num_segments - len(segments)))

    # prepare segments for the whole batch
    batch_segments = []
    for i in range(num_segments):
        input_ids = [s[i]['input_ids'] for s in segments_batch]
        attention_mask = [s[i]['attention_mask'] for s in segments_batch]
        labels = [s[i]['labels'] for s in segments_batch]
        labels_mask = [s[i]['labels_mask'] for s in segments_batch]

        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=id_pad_value)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)
        labels_mask = pad_sequence(labels_mask, batch_first=True, padding_value=False)

        batch_segment = {'input_ids': input_ids,
                         'attention_mask': attention_mask,
                         'labels_mask': labels_mask,
                         'labels': labels
                         }
        batch_segments.append(batch_segment)
    full_labels = torch.cat([s['labels'] for s in batch_segments], dim=1)
    return {"segments": batch_segments, 'labels': full_labels}

In [19]:
args.use_cot = False

In [20]:
batch = [valid_dataset[i] for i in range(10)]
collated = collate_fn(batch)
segments = collated['segments']

In [21]:
len(segments)

3

In [22]:
for k, v in segments[0].items():
    print(k, v.shape)

input_ids torch.Size([10, 115])
attention_mask torch.Size([10, 115])
labels_mask torch.Size([10, 115])
labels torch.Size([10, 115])


In [23]:
out = rmt(**collated)
# out.keys()

In [24]:
collated = segments

In [25]:
seg = collated[0]
tokenizer.batch_decode(seg['input_ids'], skip_special_tokens=False)

['////John cuts his grass to 2 inches.  It grows .5 inches per month.  When it gets to 4 inches he cuts it back down to 2 inches.  It cost $100 to get his grass cut.  How much does he pay per year?????<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '////Hannah has

In [26]:
seg = collated[1]
tokenizer.batch_decode(seg['input_ids'], skip_special_tokens=False)

['////<<4-2=2>> <<2/.5=4>> <<12/4=3>> <<100*3=300>>!!!!<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '////<<1.5*2=3>> <<3+2.5=5.5>> <<1.5+3+5.5=10>>!!!!<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '////<<30/100*2000=600>> <<2000-600=1400>>!!!!<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoft

In [27]:
seg = collated[2]
tokenizer.batch_decode(seg['input_ids'], skip_special_tokens=False)

['////300<|endoftext|><|endoftext|>',
 '////10<|endoftext|><|endoftext|>',
 '////1400<|endoftext|>',
 '////15<|endoftext|><|endoftext|>',
 '////240<|endoftext|><|endoftext|>',
 '////20<|endoftext|><|endoftext|>',
 '////10<|endoftext|><|endoftext|>',
 '////2<|endoftext|><|endoftext|>',
 '////25<|endoftext|><|endoftext|>',
 '////25<|endoftext|><|endoftext|>']

In [62]:
seg = collated[3]
tokenizer.batch_decode(seg['input_ids'], skip_special_tokens=False)

['----12/4=3????<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '----1.5+3+5.5=10!!!!',
 '----1400<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '----15<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '----240<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '----50-30=20!!!!<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '----10<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '----3*4=12????<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '----2*2.50=5.00????<|endoftext|><|endoftext|>',
 '----25<|endoftext|><|end

In [63]:
seg = collated[4]
tokenizer.batch_decode(seg['input_ids'], skip_special_tokens=False)

['----100*3=300!!!!<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '----10<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext

In [64]:
seg = collated[5]
tokenizer.batch_decode(seg['input_ids'], skip_special_tokens=False)

['----300<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '----12+4+4=20????',
 '----25<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|>

In [65]:
seg = collated[6]
tokenizer.batch_decode(seg['input_ids'], skip_special_tokens=False)

['<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '----40/20=2!!!!',
 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>']

In [66]:
seg = collated[7]
tokenizer.batch_decode(seg['input_ids'], skip_special_tokens=False)

['<|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|>',
 '----2<|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|>',
 '<|endoftext|><|endoftext|><|endoftext|>']

In [None]:
def split_cot(text, by=">> <<"):
    return text.split(by)

In [150]:
# batch = [train_dataset[i] for i in range(10)]
batch = [valid_dataset[i] for i in range(10)]
collated = collate_fn(batch)

In [159]:
# collated

In [None]:
think = tokenizer.encode('????')
ans = tokenizer.encode('!!!!')
eos = [tokenizer.eos_token_id]

In [100]:
def make_segment(input_tokens, loss=False):
    input_ids = torch.tensor(input_tokens)
    attention_mask = torch.ones_like(input_ids)
    labels = torch.tensor(input_tokens) if loss else torch.tensor([-100] * len(input_ids))
    labels_mask = torch.ones_like(input_ids) if loss else torch.zeros_like(input_ids)

    return {'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'labels_mask': labels_mask.bool()
            }

In [140]:
segments_batch = []
for sample in batch:
    task, lab, cot = sample['task'], sample['labels'], sample['cot']
    task_tokens = tokenizer.encode(task, add_special_tokens=False)
    labels_tokens = tokenizer.encode(lab, add_special_tokens=False)
    cot_segments = split_cot(cot, by=' + ')
    cot_segment_tokens = tokenizer.batch_encode_plus(cot_segments, add_special_tokens=False)['input_ids']
    
    segments = []
    segments.append(make_segment(task_tokens + think, loss=False))
    for segment in cot_segment_tokens[:-1]:
        segments.append(make_segment(segment + think, loss=True))
    segments.append(make_segment(cot_segment_tokens[-1] + ans, loss=True))

    segments.append(make_segment(labels_tokens + eos, loss=True))
    segments_batch.append(segments)

In [141]:
num_segments = max(len(segments) for segments in segments_batch)
for segments in segments_batch:
    if len(segments) < num_segments:
        segments.extend([make_segment(eos, loss=False)] * (num_segments - len(segments)))

In [None]:
batch_segments = []
for i in range(num_segments):
    input_ids = [s[i]['input_ids'] for s in segments_batch]
    attention_mask = [s[i]['attention_mask'] for s in segments_batch]
    labels = [s[i]['labels'] for s in segments_batch]
    labels_mask = [s[i]['labels_mask'] for s in segments_batch]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=id_pad_value)
    attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
    labels = pad_sequence(labels, batch_first=True, padding_value=-100)
    labels_mask = pad_sequence(labels_mask, batch_first=True, padding_value=False)

    batch_segment = {'input_ids': input_ids,
                     'attention_mask': attention_mask,
                     'labels_mask': labels_mask,
                     'labels': labels
                    }
    batch_segments.append(batch_segment)

In [143]:
seg = batch_segments[0]
tokenizer.batch_decode(seg['input_ids'], skip_special_tokens=False)

['5 6 3 2 * 7 4 3 4????',
 '6 9 1 5 * 6 4 4 7????',
 '6 7 3 9 * 8 9 1 7????',
 '3 0 3 4 * 3 4 6 5????',
 '0 3 3 7 * 8 5 6 5????',
 '3 6 0 6 * 4 3 8 7????',
 '4 7 8 4 * 2 9 1 6????',
 '2 9 6 1 * 0 5 1 9????',
 '1 1 5 4 * 5 6 0 5????',
 '3 9 9 3 * 0 9 3 3????']

In [144]:
seg = batch_segments[1]
tokenizer.batch_decode(seg['input_ids'], skip_special_tokens=False)

['5 5 5 6 1????',
 '6 7 1 1 3????',
 '8 0 0 5 7????',
 '9 0 9 2 1????',
 '0 4 6 8 5????',
 '2 5 2 4 2????',
 '8 4 7 9 0????',
 '0 0 0 0 0????',
 '5 5 5 2 2????',
 '0 0 0 0 0????']

In [145]:
import torch
from torch.utils.data import DataLoader

def segment_collate_fn(batch, segment_size=2, pad_value=0):
    """Segments samples, then groups corresponding segments, handling variable segment counts."""
    
    # Step 1: Segment each sample
    segmented_samples = [
        [sample[i:i + segment_size] for i in range(0, len(sample), segment_size)]
        for sample in batch
    ]
    
    # Step 2: Find the max number of segments
    max_segments = max(len(segments) for segments in segmented_samples)
    
    # Step 3: Pad samples to have the same number of segments
    padded_segments = [
        segments + [[pad_value] * segment_size] * (max_segments - len(segments))
        for segments in segmented_samples
    ]
    
    # Step 4: Transpose to group corresponding segments
    grouped_segments = [[seg[i] for seg in padded_segments] for i in range(max_segments)]
    
    # Convert to tensors
    return [torch.tensor(segment) for segment in grouped_segments]

# Example dataset (list of lists with varying lengths)
dataset = [
    [1, 2, 3, 4],             # 2 segments
    [5, 6, 7, 8, 9, 10],      # 3 segments
    [11, 12, 13]             # 2 segments (last one short)
]

# Create DataLoader
data_loader = DataLoader(dataset, batch_size=3, collate_fn=lambda batch: segment_collate_fn(batch, segment_size=2))

# Fetch a batch
for segmented_batch in data_loader:
    print(segmented_batch)


ValueError: expected sequence of length 2 at dim 1 (got 1)

In [94]:
segments

[{'input_ids': tensor([  20,  718,  513,  362, 1635,  767,  604,  513,  604, 9805]),
  'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
  'labels': tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100]),
  'labels_mask': tensor([False, False, False, False, False, False, False, False, False, False])},
 {'input_ids': tensor([  20,  642,  642,  718,  352, 9805]),
  'attention_mask': tensor([1, 1, 1, 1, 1, 1]),
  'labels': tensor([  20,  642,  642,  718,  352, 9805]),
  'labels_mask': tensor([True, True, True, True, True, True])},
 {'input_ids': tensor([  15,  657,  718,  604,  860,  657,  357,  642,  642,  352,  352,  352,
           352, 1267, 9805]),
  'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
  'labels': tensor([  15,  657,  718,  604,  860,  657,  357,  642,  642,  352,  352,  352,
           352, 1267, 9805]),
  'labels_mask': tensor([True, True, True, True, True, True, True, True, True, True, True, True,
          True, True, Tru

In [None]:
if args.use_cot:
    segments = []
    segments.append({'input_ids': task_tokens + [think]})
    for segment in cot_segment_tokens[:-1]:
        segments.append({'input_ids': segment + [think]})
    segments.append({'input_ids': cot_segment_tokens[-1] + [ans]})
    
    segments.append({'input_ids': labels_tokens + [eos]})

    for segment in segments
    
else:
    full_input = task_tokens + [ans] + labels_tokens + [eos]
    gen_input = task_tokens + [ans]

In [25]:
input_ids, input_ids_generate, labels, labels_mask, attention_mask = [], [], [], [], []
for sample in batch:
    task, lab, cot = sample['task'], sample['labels'], sample['cot']
    task_tokens = tokenizer.encode(task, add_special_tokens=False)
    labels_tokens = tokenizer.encode(lab, add_special_tokens=False)
    cot_tokens = tokenizer.encode(cot, add_special_tokens=False)

    if args.use_cot:
        full_input = task_tokens + [think] + cot_tokens + [ans] + labels_tokens + [eos]
        gen_input = task_tokens + [think]
    else:
        full_input = task_tokens + [ans] + labels_tokens + [eos]
        gen_input = task_tokens + [ans]
    
    inp_ids = torch.tensor(full_input)
    input_ids.append(inp_ids)
    input_ids_generate.append(torch.tensor(gen_input))


    lab = torch.tensor(full_input)
    lab[:len(task_tokens)] = -100
    labels.append(lab)

    lab_mask = torch.ones_like(inp_ids)
    lab_mask[:len(task_tokens)] = 0
    labels_mask.append(lab_mask)
    attention_mask.append(torch.ones_like(inp_ids))

input_ids = pad_sequence(input_ids, padding_value=id_pad_value, batch_first=True)
# input_ids_generate = pad_sequence(input_ids_generate, padding_value=id_pad_value, batch_first=True)
attention_mask = pad_sequence(attention_mask, padding_value=0, batch_first=True)
labels = pad_sequence(labels, padding_value=id_pad_value, batch_first=True)
labels_mask = pad_sequence(labels_mask, padding_value=0, batch_first=True)

collated = {'input_ids': input_ids,
            'input_ids_generate': input_ids_generate,
            'labels': labels,
            'attention_mask': attention_mask,
            }
if args.num_mem_tokens is not None:
    # add labels mask only for RMT, ARMT
    collated['labels_mask'] = labels_mask.bool()

In [26]:
collated

{'input_ids': tensor([[   20,   718,   513,   362,  1635,   767,   604,   513,   604, 50256,
             20,   642,   642,   718,   352,  1343,   657,   657,   718,   604,
            860,   657,   357,   642,   642,   352,   352,   352,   352,  1267,
           1343,   657,   657,   642,   860,   657,   767,   657,   357,   642,
            642,   718,   657,   362,   807,   657,  1267,  1343,   657,   657,
            657,   657,   718,   604,   860,   657, 50256,    20,   642,   718,
            657,   807,   362,   657,   352, 50256],
         [   21,   860,   352,   642,  1635,   718,   604,   604,   767, 50256,
             21,   767,   352,   352,   513,  1343,   657,   604,   807,   767,
            657,   362,   357,   718,   352,   657,   860,   513,   362,  1267,
           1343,   657,   657,   604,   807,   767,   657,   362,   357,   718,
            352,   604,   767,   352,   513,   362,  1267,  1343,   657,   657,
            657,   362,   767,   513,   718,   513, 50

In [None]:
# print out first all features of sample

for k, v in valid_dataset[0].items():
    tokens = [i if i > 0 else 0 for i in v]
    print(k, tokenizer.decode(tokens))
    print()

In [None]:
id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
# id_pad_value = -100


if args.use_cot in (False, None):
    inputs_key = 'examples_nocot'
    labels_key = 'labels_nocot'
else:
    inputs_key = 'examples_all'
    labels_key = 'labels_all'
    
split_cot_by = ">> <<"
# def collate_fn(batch):
input_ids = [torch.tensor(b[inputs_key]) for b in batch]
labels = [torch.tensor(b[labels_key]) for b in batch]
attention_mask = [torch.ones_like(b, dtype=int) for b in input_ids]
# labels_mask defines which input_ids participate in loss calculation
labels_mask = [torch.sign(torch.tensor(b[labels_key])) for b in batch]


input_ids = pad_sequence(input_ids, padding_value=id_pad_value, batch_first=True)
labels = pad_sequence(labels, padding_value=id_pad_value, batch_first=True)
attention_mask = pad_sequence(attention_mask, padding_value=0, batch_first=True)
labels_mask = pad_sequence(labels_mask, padding_value=0, batch_first=True)

collated = {'input_ids': input_ids,
            'labels': labels, 
            'attention_mask': attention_mask,
            }
if args.num_mem_tokens is not None:
    # add labels mask only for RMT, ARMT
    collated['labels_mask'] = labels_mask.bool()
# return collated