## Reasoning Scaling Law Example Code for Training a Small Model

In [1]:
### import relevant packages

import networkx as nx
import numpy as np
import random
from collections import defaultdict
import os, json
import copy
import torch
import transformers
import matplotlib.pyplot as plt
import itertools
from transformers import Trainer, TrainingArguments
from torch.utils.data import IterableDataset, get_worker_info, Dataset
from typing import Dict, Optional, Sequence
from sklearn.utils import shuffle
from dataclasses import dataclass
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache
from typing import List, Optional, Tuple, Union

Helper founctions:

In [2]:
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

def add_edge(G, h, t, r):
    num_edges = 0
    if G.has_edge(h, t):
        if r not in G[h][t]['id']:
            G[h][t]['id'].append(r)
            num_edges += 1
        else:
            print('edge already exists')
    else:
        G.add_edge(h, t, id=[r])
        num_edges += 1
    print('add edge: ', (h, r, t), 'num edges: ', num_edges)
    return num_edges


def generate_rules(relations, num_rules, L_min, L_max, weighted=False, temperature=0.25):
    # Generate K acyclic logic rules with varying lengths
    dependency_graph = defaultdict(set)
    rules = []
    weights = []
    if weighted:
        for l in range(L_min, L_max + 1):
            weights.append(np.exp(-temperature*l))
        probs = np.array([w / sum(weights) for w in weights])
    else:
        weights = [1] * (L_max - L_min + 1)

    def has_cycle(start, visited, stack):
        """Detects if adding a new dependency introduces a cycle."""
        if start not in visited:
            visited.add(start)
            stack.add(start)
            print('visited: ', visited)
            print('stack: ', stack)
            for neighbor in dependency_graph[start]:
                if neighbor in stack:
                    return True
                elif has_cycle(neighbor, visited, stack):
                    return True
        if start in stack:
            stack.remove(start)
        return False

    for _ in range(num_rules):
        while True:
            if weighted:
                length = random.choices(range(L_min, L_max + 1), weights=weights)[0]
            else:
                length = random.randint(L_min, L_max)
            rule_relations = random.choices(relations, k = length + 1) # the first element is the implied relation
            valid_rule = True
            for i in range(1, len(rule_relations)):
                dependency_graph[rule_relations[0]].add(rule_relations[i])

                # Check for cycles
                if has_cycle(rule_relations[i], set(), set()):
                    valid_rule = False
                    for j in range(1, i + 1):
                        dependency_graph[rule_relations[0]].remove(rule_relations[j])
                    break

            if valid_rule:
                rules.append(tuple(rule_relations))
                break

    print('rules: ', rules)
    return rules

def get_node_types(rules, max_num_relations_per_node=3):
    # map node types to out relations
    node_types = {}
    # map out relations to node types
    r2node_types = defaultdict(list)
    for rule in rules:
        for i in range(len(rule)):
            node_type = len(node_types)
            if i == 0:
                node_types[node_type] = [rule[i], rule[1]]
                r2node_types[rule[i]].append(node_type)
                r2node_types[rule[1]].append(node_type)
            elif i == len(rule) - 1:
                node_types[node_type] = ['-' + rule[i], '-' + rule[0]]
                r2node_types['-' + rule[i]].append(node_type)
                r2node_types['-' + rule[0]].append(node_type)
            else:
                node_types[node_type] = ['-' + rule[i], rule[i+1]]
                r2node_types['-' + rule[i]].append(node_type)
                r2node_types[rule[i+1]].append(node_type)

    print(node_types)
    print(r2node_types)

    for num_rs in range(2, max_num_relations_per_node):
        possible_new_node_types = []
        for r in r2node_types:
            alt_rs = []
            for node_type in r2node_types[r]:
                for _r in node_types[node_type]:
                    if _r != r:
                        alt_rs.append(_r)
            alt_rs = list(set(alt_rs))
            for node_type in r2node_types[r]:
                if len(node_types[node_type]) == num_rs:
                    for _r in alt_rs:
                        if _r not in node_types[node_type]:
                            possible_new_node_types.append(tuple(sorted([_r] + list(node_types[node_type]))))
            print(possible_new_node_types)
            possible_new_node_types += list(set(possible_new_node_types))
        possible_new_node_types = list(set(possible_new_node_types))
        print(possible_new_node_types)

        for rs in possible_new_node_types:
            new_node_type = len(node_types)
            node_types[new_node_type] = list(rs)
            for _r in rs:
                r2node_types[_r].append(new_node_type)

    return node_types

def get_adj_out_relations(rules):
    adj = defaultdict(list)
    for rule in rules:
        for i in range(len(rule)):
            if i == 0:
                adj[rule[i]].append(rule[1])
                adj[rule[1]].append(rule[i])
            elif i == len(rule) - 1:
                adj['-' + rule[i]].append('-' + rule[0])
                adj['-' + rule[0]].append('-' + rule[i])
            else:
                adj['-' + rule[i]].append(rule[i+1])
                adj[rule[i+1]].append('-' + rule[i])
    return adj

Synthetic graph generation:

In [3]:
def latent_rule_graph(num_rules=50, L_min=2, L_max=4, n=10000, m=10, n_r=200,
                      num_test=1000, num_train=150000, check_frequency=100,
                      power_law=False, initial_graph=None,
                      length_weighted=False, mcmc=0.2, temperature=0.25,
                      deductible_ratio=0.5):

    relations = ['P' + str(i) for i in range(n_r)]
    all_rules = generate_rules(relations, max(n_r//L_min, num_rules), L_min, L_max)
    r2rules = {}
    for rule in all_rules:
        if rule[0] not in r2rules:
            r2rules[rule[0]] = []
        r2rules[rule[0]].append(rule[1:])
    num_triples = 0
    repeated_entities = defaultdict(list) # map in relation to entities
    child_relations = []
    for rule in all_rules:
        child_relations += rule[1:]
    child_relations = list(set(child_relations))
    child_relations += ['-' + r for r in child_relations]
    deductible_rules = random.sample(all_rules, num_rules)
    if length_weighted:
        weights = [int(100*np.exp(-temperature*len(rule))) for rule in all_rules]
    else:
        weights = [1 for _ in all_rules]
    repeated_rules = []
    for rule, weight in zip(all_rules, weights):
        for _ in range(weight):
            repeated_rules.append(rule)
    random.shuffle(repeated_rules)
    adj = get_adj_out_relations(repeated_rules)
    all_deductibles = {}

    if initial_graph is None:
        # Default initial graph
        G = nx.DiGraph()
        node_id = 0
        min_repeated_entities = 0
        while min_repeated_entities < m:
            for rule in all_rules:
                source = 'Q' + str(node_id)
                node_id += 1
                h = source
                for r in rule[1:]:
                    t = 'Q' + str(node_id)
                    node_id += 1
                    num_triples += add_edge(G, h, t, r)
                    repeated_entities[r].append(t)
                    repeated_entities['-' + r].append(h)
                    h = t
                num_triples += add_edge(G, source, t, rule[0])
                repeated_entities[rule[0]].append(t)
                repeated_entities['-' + rule[0]].append(source)

            min_repeated_entities = min([len(set(repeated_entities[r])) for r in child_relations])
    else:
        if len(initial_graph) < m or len(initial_graph) > n:
            raise nx.NetworkXError(
                f"Initial graph needs between m={m} and n={n} nodes"
            )
        G = initial_graph.copy()
        node_id = len(G)

    if not power_law:
        repeated_entities = {r: list(set(repeated_entities[r])) for r in repeated_entities}

    # adding nodes
    while node_id < n:
        source = 'Q' + str(node_id)
        node_id += 1
        possible_relations = [_r for _r in adj if _r in child_relations]
        if len(possible_relations) == 0:
            print('no adj relations')
            break
        print('add child edge')
        chosen_edges = []
        stop = False
        for _ in range(m):
            it = 0
            while (r, t) in chosen_edges:
                r = random.choice(possible_relations)
                t = random.choice(repeated_entities[r])
                it += 1
                if it > 100:
                    print('failed to find edge')
                    stop = True
                    break
            if stop or len(possible_relations) == 0:
                break

            possible_relations = [_r for _r in adj[r] if _r in child_relations]
            chosen_edges.append((r, t))
            if r[0] == '-':
                num_triples += add_edge(G, t, source, r[1:])
                repeated_entities[r[1:]].append(source)
            else:
                num_triples += add_edge(G, source, t, r)
                repeated_entities['-' + r].append(source)
            repeated_entities[r].append(t)
            if len(possible_relations) == 0:
                print('no adj relations')
                break

        if not power_law:
            repeated_entities = {r: list(set(repeated_entities[r])) for r in repeated_entities}

        if node_id % check_frequency == 0 or node_id == n-1:
            # add deductibles
            all_nodes = list(G.nodes)
            random.shuffle(all_nodes)
            for h in all_nodes:
                for rule in deductible_rules:
                    head_list = [h]
                    r = rule[0]

                    for _r in rule[1:]:
                        next_head_list = []
                        for e_h in head_list:
                            if e_h not in G.nodes:
                                continue
                            for e_t in G[e_h]:
                                if _r in G[e_h][e_t]['id']:
                                    if random.random() < mcmc:
                                        next_head_list.append(e_t)
                        head_list = next_head_list

                    for t in head_list:
                        if (h, r, t) not in all_deductibles:
                            all_deductibles[(h, r, t)] = [rule]
                        elif rule not in all_deductibles[(h, r, t)]:
                            all_deductibles[(h, r, t)].append(rule)
                        if not G.has_edge(h, t) or r not in G[h][t]['id']:
                            print('add deductible edge')
                            add_edge(G, h, t, r)
                            num_triples += 1
                            repeated_entities[r].append(t)
                            repeated_entities['-' + r].append(h)

    atomic_triples = []
    deductible_triples = []
    for h, t in G.edges:
        for r in G[h][t]['id']:
            if (h, r, t) not in all_deductibles:
                atomic_triples.append((h, r, t))
            else:
                deductible_triples.append((h, r, t))
    random.shuffle(atomic_triples)
    random.shuffle(deductible_triples)
    assert len(atomic_triples) >= int(num_train * (1-deductible_ratio))
    assert len(deductible_triples) >= int(num_train * deductible_ratio) + 2 * num_test

    remove_triples = []
    train_atomic_triples = atomic_triples[:int(num_train * (1-deductible_ratio))]
    remove_triples += atomic_triples[int(num_train * (1-deductible_ratio)):]
    train_deductible_triples = deductible_triples[:int(num_train * deductible_ratio)]
    remove_triples += deductible_triples[int(num_train * deductible_ratio):]

    for h, r, t in remove_triples:
        _t = t
        rs = G[h][_t]['id']
        if r in rs:
            if len(rs) == 1:
                G.remove_edge(h, _t)
            else:
                G[h][_t]['id'].remove(r)

    train_triples = train_deductible_triples + train_atomic_triples
    random.shuffle(train_triples)
    print("num train triples: ", len(train_triples))

    r2rule = {}
    for rule in deductible_rules:
        if rule[0] in r2rule:
            r2rule[rule[0]].append(rule[1:])
        else:
            r2rule[rule[0]] = [rule[1:]]

    def check_deductible(triple):
        h, r, t = triple
        alt_ts = []
        for rule in r2rule[r]:
            head_list = [h]
            for _r in rule:
                next_head_list = []
                for e_h in head_list:
                    for e_t in G[e_h]:
                        if _r in G[e_h][e_t]['id']:
                            next_head_list.append(e_t)
                head_list = next_head_list
            alt_ts += head_list
        if t in alt_ts:
            return True
        return False

    id_test_triples = []
    for i in range(int(num_train * deductible_ratio), len(deductible_triples)):
        if check_deductible(deductible_triples[i]):
            id_test_triples.append(deductible_triples[i])
        if len(id_test_triples) == num_test:
            break

    id_test_rules = [all_deductibles[triple] for triple in id_test_triples]
    print("num id test triples: ", len(id_test_triples))

    rule2triples = defaultdict(list)
    for triple in deductible_triples[i+1:]:
        for rule in all_deductibles[triple]:
            rule2triples[rule].append(triple)

    # uniformly sample testing triples from each rule
    uniform_test_triples = []
    for rule in rule2triples:
        triples = []
        for triple in rule2triples[rule]:
            if check_deductible(triple):
                triples.append(triple)

        if len(triples) > num_test//len(rule2triples):
            uniform_test_triples += random.sample(triples, num_test//len(rule2triples))
        else:
            uniform_test_triples += triples

    random.shuffle(uniform_test_triples)
    uniform_test_rules = [all_deductibles[triple] for triple in uniform_test_triples]
    print("num uniform test triples: ", len(uniform_test_triples))

    return G, deductible_rules, train_triples, id_test_triples, id_test_rules, uniform_test_triples, uniform_test_rules

Data class for synthetic graph:

In [4]:
class LatentRuleGraph:
    def __init__(self,
                 n=1000, n_r=40, m=5, n_rules=30, n_triples=10000,
                 num_test=1000, L_min=2, L_max=4, power_law=False,
                 length_weighted=False, mcmc=1.0,
                 temperature=0.25, deductible_ratio=0.5, seed=42):
        self.n = n
        self.n_r = n_r
        self.n_triples = n_triples
        self.n_rules = n_rules
        self.num_test = num_test
        self.L_min = L_min
        self.L_max = L_max
        self.power_law = power_law
        self.m = m
        self.length_weighted = length_weighted
        self.mcmc = mcmc
        self.temperature = temperature
        self.deductible_ratio = deductible_ratio
        self.seed = seed
        random.seed(seed)
        self.G = nx.DiGraph()
        self.load_data()
        self.all_es = list(self.G.nodes)
        self.all_rs = set()
        for h, t, r_dict in self.G.edges(data=True):
            for r in r_dict['id']:
                self.all_rs.add(r)
        self.triple_complet_file = None

    def load_data(self):
        self.triples = []
        self.id_test_triples = []
        self.uniform_test_triples = []
        self.id_alt_ts = []
        self.uniform_alt_ts = []
        self.rules = []
        self.id_test_rules = []
        self.uniform_test_rules = []


        self.G, self.rules, self.triples, \
        self.id_test_triples, self.id_test_rules, \
        self.uniform_test_triples, self.uniform_test_rules = latent_rule_graph(
            num_rules=self.n_rules, L_min=self.L_min, L_max=self.L_max,
            n=self.n, n_r=self.n_r, m=self.m,
            num_test=self.num_test, num_train=self.n_triples,
            power_law=self.power_law,
            length_weighted=self.length_weighted, mcmc=self.mcmc,
            deductible_ratio=self.deductible_ratio, temperature=self.temperature)

        r2rule = {}
        for rule in self.rules:
            if rule[0] in r2rule:
                r2rule[rule[0]].append(rule[1:])
            else:
                r2rule[rule[0]] = [rule[1:]]

        def get_alt_ts(h, r, t):
            alt_ts = []
            for rule in r2rule[r]:
                head_list = [h]
                for _r in rule:
                    next_head_list = []
                    for e_h in head_list:
                        for e_t in self.G[e_h]:
                            if _r in self.G[e_h][e_t]['id']:
                                next_head_list.append(e_t)
                    head_list = next_head_list
                alt_ts += head_list
            return alt_ts

        for h, r, t in self.id_test_triples:
            alt_ts = get_alt_ts(h, r, t)
            self.id_alt_ts.append(alt_ts)

        for h, r, t in self.uniform_test_triples:
            alt_ts = get_alt_ts(h, r, t)
            self.uniform_alt_ts.append(alt_ts)

        self.mem_triples = random.sample(self.triples, k=self.num_test)

Create a new synthetic graph:

In [5]:
graph = LatentRuleGraph(
        n=2000,
        n_r=50,
        n_triples=10000,
        n_rules=20,
        L_min=2,
        L_max=4,
        power_law=True,
        deductible_ratio=0.5,
        length_weighted=False,
        m=6,
        num_test=1000,
        temperature=0.25,
        mcmc=1.0)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
add edge:  ('Q1681', 'P38', 'Q1821') num edges:  1
add deductible edge
add edge:  ('Q1681', 'P38', 'Q1884') num edges:  1
add deductible edge
add edge:  ('Q1681', 'P38', 'Q1885') num edges:  1
add deductible edge
add edge:  ('Q1304', 'P43', 'Q1812') num edges:  1
add deductible edge
add edge:  ('Q1304', 'P43', 'Q1819') num edges:  1
add deductible edge
add edge:  ('Q1209', 'P22', 'Q663') num edges:  1
add deductible edge
add edge:  ('Q1209', 'P22', 'Q521') num edges:  1
add deductible edge
add edge:  ('Q1209', 'P22', 'Q412') num edges:  1
add deductible edge
add edge:  ('Q1209', 'P22', 'Q1159') num edges:  1
add deductible edge
add edge:  ('Q1209', 'P22', 'Q463') num edges:  1
add deductible edge
add edge:  ('Q1209', 'P22', 'Q830') num edges:  1
add deductible edge
add edge:  ('Q1209', 'P22', 'Q1307') num edges:  1
add deductible edge
add edge:  ('Q1209', 'P22', 'Q1572') num edges:  1
add deductible edge
add edge:  ('Q120

Training data class for synthetic graph:

In [6]:
class TrainDataset(IterableDataset):
    """
    Iterable dataset that returns constant length chunks of tokens from stream of text files.
        Args:
            tokenizer (Tokenizer): The processor used for proccessing the data.
            dataset (dataset.Dataset): Dataset with text files.
            infinite (bool): If True the iterator is reset after dataset reaches end else stops.
            seq_length (int): Length of token sequences to return.
            num_of_sequences (int): Number of token sequences to keep in buffer.
            chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
            tokenized (bool): If true we use a pretokenized dataset.
    """

    def __init__(
        self,
        graph, # generated graph
        tokenizer,
        seq_length=256,
        num_of_sequences=1024,
        chars_per_token=3.6,
        seed=42,
    ):
        super(TrainDataset, self).__init__()

        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.epoch = 0
        self.current_size = 0
        self.num_buffer_sequences = num_of_sequences
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
        self.seed = seed
        self.data = graph

        print("max buffer size: ", self.max_buffer_size)

    def set_epoch(self, worker_id):
        set_seed(self.seed + self.epoch + worker_id) # int(time.time())

    def triple2str(self, triple):
        if type(triple[0]) == int or type(triple[1]) == int or type(triple[2]) == int:
            return f'Q{triple[0]} P{triple[1]} Q{triple[2]}'
        else:
            return ' '.join(list(triple))

    def iter_fun(self, worker_id=0):
        num_sents = len(self.data.triples)
        while True:
            i = random.randint(0, num_sents-1)
            triple = self.data.triples[i]
            text = self.triple2str(triple) + '\n'
            if text is None:
                print("cannot translate ", triple, " into text.")
                continue
            yield text

    def __len__(self):
        return len(self.data.triples)

    def __iter__(self):
        more_examples = True
        try:
            worker_info = get_worker_info()
            print(worker_info)
            worker_id = worker_info.id
        except:
            worker_id = 0
        self.set_epoch(worker_id)
        iterator = self.iter_fun(worker_id=worker_id)
        print("worker id: ", )

        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    print("data buffer full")
                    break
                try:
                    buffer.append(next(iterator))
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    self.epoch += 1
                    self.set_epoch(worker_id)
                    iterator = self.iter_fun()
                    print(f"Dataset epoch: {self.epoch}")
            # print(buffer[:3])

            input_lens = []
            random.shuffle(buffer)
            tokenized_inputs = self.tokenizer(buffer,
                                padding=False,
                                max_length=self.seq_length,
                                truncation=True)["input_ids"]
            for tokenized_input in tokenized_inputs:
                input_ids = tokenized_input + [self.tokenizer.eos_token_id]
                input_lens.append(len(input_ids))
                self.current_size += 1
                yield dict(input_ids=torch.tensor(input_ids), labels=torch.tensor(input_ids))
            print("average example length: ", np.mean(input_lens))

In [7]:
train_dataset = TrainDataset(
        graph,
        tokenizer=None,
        seq_length=128,
        num_of_sequences=1024,
        chars_per_token=3.6,
        )

max buffer size:  471859.2


Tokenizer class:

In [8]:
class BaseTokenizer:
    def __init__(self, n=1, vocab=None, padding_side='right', add_special_tokens=False):
        self.n = n
        if vocab is None:
            self.vocab = self.build_vocab()
        else:
            self.vocab = vocab
        self.rev_vocab = {v: k for k, v in self.vocab.items()}
        self.padding_side = padding_side
        self.add_special_tokens = add_special_tokens
        self.bos_token = '<BOS>'
        self.bos_token_id = self.vocab['<BOS>']
        self.eos_token = '<EOS>'
        self.eos_token_id = self.vocab['<EOS>']
        self.pad_token = '<PAD>'
        self.pad_token_id = self.vocab['<PAD>']
        self.unk_token = '<UNK>'
        self.unk_token_id = self.vocab['<UNK>']
        self.all_special_ids = [self.bos_token_id, self.eos_token_id,
                                self.pad_token_id, self.unk_token_id]
        self.all_special_tokens = self.all_special_tokens_extended = [
            self.bos_token, self.eos_token,
            self.pad_token, self.unk_token]

    def build_vocab(self):
        pass

    def tokenize(self, text: str, max_length: int):
        pass

    def encode(self, text, padding=False, max_length=1024, return_tensors=None, truncation=True):
        if type(text) == str:
            ids = [self.tokenize(text, max_length)]
        else:
            ids = []
            lens = []
            for t in text:
                _ids = self.tokenize(t, max_length)
                ids.append(_ids)
                lens.append(len(_ids))

            if padding:
                max_length = max(lens)
                for _ids in ids:
                    if len(_ids) < max_length:
                        if self.padding_side == 'left':
                            _ids = [self.pad_token_id] * (max_length - len(_ids)) + _ids
                        elif self.padding_side == 'right':
                            _ids += [self.pad_token_id] * (max_length - len(_ids))
                        else:
                            raise NotImplementedError

        if return_tensors == 'pt':
            ids = torch.tensor(ids)

        return ids

    def __call__(self, text, padding=False, max_length=1024, return_tensors=None, truncation=True, device='cpu'):
        if type(text) == str:
            ids = [self.tokenize(text, max_length)]
            attns = [[1] * len(ids[0])]
        else:
            ids = []
            attns = []
            lens = []
            for t in text:
                _ids = self.tokenize(t, max_length)
                ids.append(_ids)
                lens.append(len(_ids))
                attns.append([1] * len(_ids))

            if padding:
                max_length = max(lens)
                padded_ids = []
                padded_attns = []
                for _ids, attn in zip(ids, attns):
                    num_pad = max_length - len(_ids)
                    if self.padding_side == 'left':
                        padded_ids.append([self.pad_token_id] * num_pad + _ids)
                        padded_attns.append([0] * num_pad + attn)
                    elif self.padding_side == 'right':
                        padded_ids.append(_ids + [self.pad_token_id] * num_pad)
                        padded_attns.append(attn + [0] * num_pad)
                    else:
                        raise NotImplementedError
                ids = padded_ids
                attns = padded_attns

        if return_tensors == 'pt':
            ids = torch.tensor(ids).to(device)
            attns = torch.tensor(attns).to(device)

        return {"input_ids": ids, 'attention_mask': attns}

    def __len__(self):
        return len(self.vocab)

    def decode(self, token_ids, skip_special_tokens=False):
        if type(token_ids) == int:
            return self.rev_vocab[token_ids]
        else:
            out = ''
            for i in token_ids:
                if i == self.eos_token_id:
                    if not skip_special_tokens:
                        out += self.eos_token
                    break
                if skip_special_tokens and i in self.all_special_ids:
                    continue
                out += self.rev_vocab[i]
            return out

    def batch_decode(self, sequences, skip_special_tokens=False):
        out = []
        for token_ids in sequences:
            out.append(self.decode(token_ids, skip_special_tokens))
        return out

    def save_pretrained(self, output_dir):
        with open(f'{output_dir}/tokenizer.json', 'w') as wf:
            json.dump(self.vocab, wf, indent = 4)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, padding_side='right', trust_remote_code=False, revision=None):
        vocab_path = f"{pretrained_model_name_or_path}/tokenizer.json"
        if os.path.exists(vocab_path):
            vocab = json.load(open(vocab_path))
            n = 1
            for token in vocab:
                if token not in ['<BOS>', '<EOS>', '<PAD>', '<UNK>']:
                    if '_' in token:
                        n = max(n, int(token.split('_')[1]) + 1)
                    else:
                        n = max(n, len(token))

            return cls(n, vocab, padding_side=padding_side)
        else:
            return cls(padding_side=padding_side)

class CharTokenizer(BaseTokenizer):
    def __init__(self, n=1, vocab=None, padding_side='right', add_special_tokens=False):
        super().__init__(n, vocab, padding_side, add_special_tokens)

    def build_vocab(self):
        vocab = {'Q':0, 'P':1}
        for i in range(10):
            vocab[str(i)] = i+2
        vocab_size = 12
        vocab['\n'] = vocab_size
        vocab_size += 1
        vocab[' '] = vocab_size
        vocab_size += 1
        vocab['-'] = vocab_size
        vocab_size += 1
        vocab['?'] = vocab_size
        vocab_size += 1
        vocab['<BOS>'] = vocab_size
        vocab_size += 1
        vocab['<EOS>'] = vocab_size
        vocab_size += 1
        vocab['<PAD>'] = vocab_size
        vocab_size += 1
        vocab['<UNK>'] = vocab_size

        return vocab

    def tokenize(self, text: str, max_length: int):
        ids = []
        for l in text.split('\n'):
            if len(l) == 0:
                continue
            for w in l.split():
                for c in w.strip():
                    if c not in self.vocab:
                        ids.append(self.unk_token_id)
                    else:
                        ids.append(self.vocab[c])
                ids.append(self.vocab[' '])
            ids.append(self.vocab['\n'])

        if self.add_special_tokens:
            ids.append(self.vocab['<EOS>'])
        else:
            ids = ids[:-2]
        # print(ids)
        if max_length < len(ids):
            return ids[:max_length]
        else:
            return ids

Helper functions for training and evaluation:

In [9]:
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "<unk>"

def model_path_map(model_name):
    return '../llms/' + model_name

def count_params(model):
    params: int = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params

def compute_llama_param(l, h, v):
    d = 64 * h
    embd = d * v
    atten = 4*d*d
    mlp = 2*d*d*3
    ln = d
    return l * (atten + mlp + 2*ln) + ln + 2*embd

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True,
                                                 padding_value=IGNORE_INDEX)

        attn_mask = input_ids.ne(self.tokenizer.pad_token_id)

        # print("input_ids: ", input_ids)
        # print("labels: ", labels)
        # print("atten mask: ", attn_mask)

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attn_mask,
        )


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=max_length,
            truncation=True,
            # pad_to_multiple_of=8,
        )["input_ids"]
        for text in strings
    ]
    input_ids = labels = [tokenized[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def prepare_data(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    max_length: int
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length)
                                             for strings in (examples, sources)]
    eos = torch.tensor([tokenizer.eos_token_id])
    input_ids = [torch.cat((ids, eos)) for ids in examples_tokenized["input_ids"]]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)

Training function (set bf16 to True if you have GPUs):

In [10]:
def train(train_dataset, model_name_or_path='llama-2-2', random_initialize=True,
          output_dir='.', bf16=False):

    set_seed(42) # make sure use the same model initialization
    l, h, v = None, None, None

    if random_initialize:
        print("Random initializing...")
        model_name, l, h = model_name_or_path.split('-')
        l, h = int(l), int(h)
        d = 64 * h
        if model_name == 'llama':
            config = transformers.LlamaConfig(hidden_size=d,
                                            intermediate_size=2*d,
                                            num_attention_heads=h,
                                            num_hidden_layers=l)
        else:
            raise NotImplemented


        tokenizer = CharTokenizer()
        config.vocab_size = len(tokenizer.vocab)
        config.bos_token_id = tokenizer.bos_token_id
        config.eos_token_id = tokenizer.eos_token_id
        print("vocab size: ", len(tokenizer.vocab))
        print("new config: ", config)

        v = config.vocab_size
        model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
        print("embedding size: ", model.get_input_embeddings().weight.data.shape)
    else:
        print("Using pre-trained model weights...")

        tokenizer = CharTokenizer.from_pretrained(model_name_or_path)

        model = transformers.AutoModelForCausalLM.from_pretrained(
            model_name_or_path
        )

    if l is not None and h is not None and v is not None:
        print("theoretical # params: ", compute_llama_param(l, h, v))
    print("actual # params: ", count_params(model))

    train_dataset.tokenizer = tokenizer

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    train_args = TrainingArguments(bf16=bf16, max_steps=1000,
                      per_device_train_batch_size=32, eval_strategy="no",
                      save_steps=1000, save_total_limit=1, learning_rate=1e-4,
                      weight_decay=0.0, warmup_ratio=0.2, lr_scheduler_type="cosine",
                      logging_steps=1, output_dir=output_dir, report_to="none")

    trainer = Trainer(model=model, tokenizer=tokenizer, args=train_args,
                    train_dataset=train_dataset, data_collator=data_collator,
                    eval_dataset=None)

    if not random_initialize:
        print("resume training from: ", model_name_or_path)
        trainer.train(model_name_or_path)
    else:
        trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=output_dir)
    return model, tokenizer

Train a 2-layer language model on the generated synthetic graph:

In [None]:
model, tokenizer = train(train_dataset)

Random initializing...
vocab size:  20
new config:  LlamaConfig {
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 16,
  "eos_token_id": 17,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 128,
  "initializer_range": 0.02,
  "intermediate_size": 256,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 2,
  "num_hidden_layers": 2,
  "num_key_value_heads": 2,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "transformers_version": "4.53.0",
  "use_cache": true,
  "vocab_size": 20
}

embedding size:  torch.Size([20, 128])
theoretical # params:  333440
actual # params:  333440


  trainer = Trainer(model=model, tokenizer=tokenizer, args=train_args,


None
worker id: 
data buffer full


Step,Training Loss
1,2.8365
2,2.8264
3,2.8297
4,2.8254
5,2.8382
6,2.8201
7,2.8165
8,2.8306
9,2.83
10,2.8406


None
worker id: 
data buffer full


Evaluation data class:

In [None]:
class EvalDataset(Dataset):

    def __init__(self,
                graph,
                tokenizer,
                split="id", # or "uniform", or "mem"
                num_options=10,
                use_rule_length=False,
                seed=42):

        super(EvalDataset, self).__init__()
        self.split = split
        self.tokenizer = tokenizer
        self.eos_token = self.tokenizer.eos_token
        self.num_options = num_options
        self.num_test = graph.num_test
        set_seed(seed)
        self.path_length = []

        self.data = graph
        if split == 'id':
            self.triples = self.data.id_test_triples
            self.alt_ts = self.data.id_alt_ts
        elif split == 'uniform':
            self.triples = self.data.uniform_test_triples
            self.alt_ts = self.data.uniform_alt_ts
        elif split == "mem":
            self.triples = self.mem_triples
        else:
            print("no such split: ", split)
            raise NotImplementedError

        if use_rule_length:
            print("using rule length")
            self.path_length = [min([len(rule) - 1 for rule in rules]) for rules in self.data.test_rules]

        self.get_data()
        if len(self.path_length) == 0:
            self.get_path_length()

    def get_path_length(self):
        for h, r, t in self.input_triples:
            try:
                l = nx.shortest_path_length(self.data.G, source=h, target=t)
            except:
                l = 0
                print(f'cannot find shortest path between {h} and {t}')
            self.path_length.append(l)
        print("avg path length: ", np.mean(self.path_length))

    def get_data(self):
        self.input_text = []
        self.input_triples = []
        self.seen_ts = []
        self.options = []

        for idx, triple in enumerate(self.triples):
            h, r, t = triple

            if self.split == "mem":
                seen_ts = []
                if h in self.data.G:
                    for e in self.data.G[h]:
                        if r in self.data.G[h][e]['id']:
                            seen_ts.append(e)
            else:
                seen_ts = self.alt_ts[idx]
            self.seen_ts.append(seen_ts)

            question = h + ' ' + r + ' '
            ans = t

            options = [ans]
            for i in range(self.num_options-1):
                neg_e = random.choice(self.data.all_es)
                while neg_e == ans or neg_e in seen_ts:
                    neg_e = random.choice(self.data.all_es)
                options.append(neg_e)

            self.input_text.append(question)
            random.shuffle(options)
            self.options.append(options)

            self.input_triples.append(triple)

    def __len__(self):
        return len(self.input_text)

    def __getitem__(self, i):
        example = [self.input_text[i], self.input_triples[i], self.seen_ts[i]]
        example += [self.options[i]]
        example.append(self.path_length[i])
        return example

Create a evaluation dataset based on the synthetic graph generated before:

In [None]:
eval_dataset = EvalDataset(graph, tokenizer)

The evaluation function (set device to "cuda" if you have GPUs):

In [None]:
def eval(eval_dataset, model, batch_size=16, max_length=64, num_test=1000, device="cpu"):

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = True,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        # hidden_states = outputs.hidden_states[-2]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction='none')
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            loss = loss.view([labels.size(0), labels.size(1) - 1])
            loss = loss.sum(-1)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # monkey patching
    original_forward = LlamaForCausalLM.forward
    LlamaForCausalLM.forward = forward
    model.eval()

    def collect_data(instances, device='cpu'):
        input_ids = instances["input_ids"]
        labels = instances["labels"]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True,
                                                padding_value=IGNORE_INDEX)

        attn_mask = input_ids.ne(tokenizer.pad_token_id)

        return dict(
            input_ids=input_ids.to(device),
            labels=labels.to(device),
            attention_mask=attn_mask.to(device),
        )

    num_choices = eval_dataset.num_options
    print("number of choices: ", num_choices)

    input_texts = []
    output_texts = []
    gts = []
    example_ids = []
    id = 0
    num_correct = 0
    num_all = 0
    losses = []

    for q, triple, seen_t, opts, l in eval_dataset:
        id += 1
        if id > min(num_test, eval_dataset.num_test):
            break
        print(q, triple, seen_t, opts, l)
        label = triple[-1]
        input_text = q
        for op in opts:
            input_texts.append(input_text)
            output_texts.append(op)
            gts.append(op == label)
            example_ids.append(id)

        if len(input_texts) >= batch_size or id == min(len(eval_dataset)-1, num_test):
            data_dict = prepare_data(input_texts, output_texts, tokenizer, max_length)
            input_data = collect_data(data_dict, device)
            print("input data shape: ", input_data['input_ids'].shape)
            loss = model(**input_data).loss.detach().cpu()
            print("loss: ", loss)
            for i in range(len(input_texts)//num_choices):
                pred = torch.argmin(loss[i*num_choices: (i+1)*num_choices]).item()
                gt = np.arange(num_choices)[gts[i*num_choices: (i+1)*num_choices]][0]
                losses.append(loss[i*num_choices: (i+1)*num_choices][gt])
                if pred == gt:
                    num_correct += 1
                num_all += 1

            acc = num_correct/num_all
            print("Accuracy: ", acc)
            mean_loss = np.mean(losses)
            print("Loss: ", mean_loss)

            input_texts = []
            output_texts = []
            gts = []
            example_ids = []

    # restore the original forward function
    LlamaForCausalLM.forward = original_forward

    return acc, mean_loss

Evaluate the previously trained language model:

In [None]:
acc, loss = eval(eval_dataset, model)

Sweeping function to train and evaluate a range of model sizes and plot the U-shaped loss curve and acc curve (this will likely take a very long time/crush with only Colab CPU):

In [None]:
def scaling_sweep(train_dataset, eval_dataset, y_axis='loss'):
    accs = []
    losses = []
    error_losses = []
    sizes = []
    for i in range(1, 4):
        for j in [i-1, i, i+1]:
            if j > 0 and j < 4:
                num_p = compute_llama_param(2**i, 2**j, 19)
                model, tokenizer = train(train_dataset, f'llama-{2**i}-{2**j}')
                acc, loss = eval(eval_dataset, model)
                accs.append(acc)
                losses.append(loss)
                sizes.append(num_p)

    idx = np.argsort(sizes)
    accs = np.array(accs)[idx]
    losses = np.array(losses)[idx]
    sizes = np.array(sizes)[idx]
    log_sizes = np.log(sizes)
    print('log sizes: ', log_sizes)
    print('accs: ', accs)
    print('losses: ', losses)

    # plot acc
    plt.plot(log_sizes, accs, 'o-', linewidth=2)
    plt.xticks(np.log(sizes), (sizes/1000000).round(1), size=8)
    plt.xlabel(f'Llama model size (M)', size=12)
    plt.ylabel(f'Accuracy', size=12)
    plt.legend(fontsize=12)
    plt.show()
    plt.savefig('acc.png')
    plt.close()

    # plot loss
    plt.plot(log_sizes, losses, 'o-', linewidth=2)
    plt.xticks(np.log(sizes), (sizes/1000000).round(1), size=8)
    plt.xlabel(f'Llama model size (M)', size=12)
    plt.ylabel(f'Loss', size=12)
    plt.legend(loc='upper left', fontsize=12)
    plt.show()
    plt.savefig('loss.png')
    plt.close()

In [None]:
scaling_sweep(train_dataset, eval_dataset)