In [1]:
import torch
import os
import numpy as np
from collections import Counter
from tqdm import tqdm
import tiktoken
%cd /mloraw1/sfan/tunnel_llm

/mloraw1/sfan/tunnel_llm


In [2]:
from model import GPTBase
from data.utils import get_dataset

In [3]:
## Training Args ##
class AttributeDict(dict):
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

class ModelConf:
    def __init__(self, config):
        self.vocab_size = config['vocab_size']
        self.dropout = config['dropout']
        self.n_head = config['n_head']
        self.n_layer = config['n_layer']
        self.n_embd = config['n_embd']
        self.sequence_length = config['sequence_length']
        self.bias = config['bias']

def read_pkl(pkl_path:str):
    with open(pkl_path, 'rb') as trg:
        x = pickle.load(trg)
    return x

def load_config(summary_path):
    print("\nLoading config file")
    with open(summary_path) as fs:
        config = json.load(fs)['args']
    config = ModelConf(config)
    print(f'{summary_path} loading complete!')
    return config

# Load the checkpoint
def load_checkpoint(checkpoint_path, 
                    model_config:ModelConf, 
                    device='cpu',
                    train=False,):
    print("\nLoading checkpoint...")
    checkpoint = torch.load(checkpoint_path)
    state_dict = checkpoint['model']
    state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()} # distributed code appends '_orig_mod.' to param name
    model = GPTBase(model_config)
    model.load_state_dict(state_dict, strict=False) # olivia-add: strict=False, but do not understand why params fail to match
    model.to(device)
    if train:
        model.train()
    else:
        model.eval()
    return model


In [4]:
config_dict = {
    'vocab_size': 50304,
    'dropout': 0.0,
    'n_head': 12,
    'n_embd': 768,
    'sequence_length': 512,
    'n_layer': 12,
    'bias': 'false',
    'dataset': 'redpajama-all',
    'eval_all_domains': False,
    'batch_size': 70,
}
# args = AttributeDict(config_dict)

In [5]:
args = AttributeDict(config_dict)
model_config = ModelConf(config_dict)
data = get_dataset(args)
tokenizer = tiktoken.get_encoding("gpt2")
base_model = load_checkpoint(checkpoint_path='/mloraw1/sfan/curriculum-new/exps/slim_6b-all/base/124_baseline_base_lr0.0005_bs70x1_1nodes_seed=0/ckpt.pt', 
                        model_config=model_config, 
                        device='cuda', )

Loading train dataset 'redpajama-all'
Subset: arxiv || train: 30735125 || val: 1483751
Subset: book || train: 25666343 || val: 1192070
Subset: c4 || train: 165976168 || val: 8617730
Subset: cc || train: 146088106 || val: 7781824
Subset: github || train: 76541661 || val: 3878696
Subset: stackexchange || train: 24287690 || val: 1271364
Subset: wikipedia || train: 28937088 || val: 1579138
Num training tokens: 498232181
Num validation tokens: 25804573

Loading checkpoint...
number of parameters: 124.08M


In [35]:
base_model()

TypeError: GPTBase.forward() missing 1 required positional argument: 'idx'

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler
def sample_train_batch(data,
                       seq_len=512,
                       batch_size=4,
                       device='cpu',):
    span_ids = torch.arange(0, len(data) - seq_len, seq_len)
    ix = torch.randint(len(span_ids), (batch_size,))
    ix = span_ids[ix]
    x = torch.stack([torch.from_numpy((data[i:i+seq_len]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+seq_len]).astype(np.int64)) for i in ix])

    if device != 'cpu':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    return x, y

def sample_train_batch_from_dict(data,
                                 seq_len=512,
                                 batch_size=4,
                                 device='cpu',
                                 domain_weights=None,
                                 idx2domain=None,
                                 return_sample_dict=False,
                                 return_domain_ids=False):
    if domain_weights is None:
        domain_weights = torch.ones(len(data), dtype=torch.float) / len(data)
        idx2domain = {i:d for i,d in enumerate(list(data.keys()))}
    assert len(domain_weights) == len(data)
    assert len(idx2domain) == len(data)
    
    sampled_domain_ids = list(WeightedRandomSampler(weights=domain_weights, num_samples=batch_size, replacement=True))
    idx2count = Counter(sampled_domain_ids)
    print(idx2count)
    
    if return_sample_dict:
        sample_dict = {}
        for domain_id, domain_count in idx2count.items():
            v = data[idx2domain[domain_id]]
            span_ids = torch.arange(0, len(v) - seq_len, seq_len)
            ix = torch.randint(len(span_ids), (domain_count,))
            ix = span_ids[ix]
            domain_x = torch.stack([torch.from_numpy((v[i:i+seq_len]).astype(np.int64)) for i in ix])
            domain_y = torch.stack([torch.from_numpy((v[i+1:i+1+seq_len]).astype(np.int64)) for i in ix])
            if device != 'cpu':
                # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
                domain_x, domain_y = domain_x.pin_memory().to(device, non_blocking=True), domain_y.pin_memory().to(device, non_blocking=True)
            sample_dict[domain_id] = (domain_x, domain_y)
        return sample_dict, idx2domain
    
    id_list = []
    for domain_id, domain_count in idx2count.items():
        v = data[idx2domain[domain_id]]
        span_ids = torch.arange(0, len(v) - seq_len, seq_len)
        ix = torch.randint(len(span_ids), (domain_count,))
        ix = span_ids[ix]
        id_list.extend([(domain_id, i) for i in ix])
    
    x = torch.stack([torch.from_numpy((data[idx2domain[domain_id]][i:i+seq_len]).astype(np.int64)) for (domain_id, i) in id_list])
    y = torch.stack([torch.from_numpy((data[idx2domain[domain_id]][i+1:i+1+seq_len]).astype(np.int64)) for (domain_id, i) in id_list])
    
    if device != 'cpu':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    if return_domain_ids:
        sampled_domain_ids = torch.LongTensor([id for (id, _) in id_list])
        token_domain_ids = torch.zeros(len(sampled_domain_ids)*seq_len, dtype=torch.long)
        for id in sampled_domain_ids:
            token_domain_ids[id*seq_len:(id+1)*seq_len] = domain_id
        return x, y, sampled_domain_ids, token_domain_ids
    return x, y

In [36]:
x, y = sample_train_batch_from_dict(data['train'],
                                        seq_len=128,
                                        batch_size=32,
                                        device='cuda',
                                        domain_weights=None,
                                        idx2domain=None,
                                        return_sample_dict=False,
                                        return_domain_ids=False)

Counter({6: 7, 5: 6, 3: 5, 2: 5, 1: 3, 0: 3, 7: 2, 4: 1})


In [37]:
output = base_model(x, y, return_layer_rep=True)

In [38]:
layer_reps = output['layer_reps']
# len(layer_reps) # 6
# layer_reps[0].shape # [64, 128, 768]

In [39]:
layer_reps[0].shape

torch.Size([32, 128, 768])

In [45]:
from torch.linalg import matrix_rank

rank_dict = {idx:[] for idx in range(args.n_layer)}

for t in range(10):
    x, y = sample_train_batch_from_dict(data['train'],
                                        seq_len=128,
                                        batch_size=32,
                                        device='cuda',
                                        domain_weights=None,
                                        idx2domain=None,
                                        return_sample_dict=False,
                                        return_domain_ids=False)
    output = base_model(x, y, return_layer_rep=True)
    layer_reps = output['layer_reps']
    cur_rank_dict = {layer_idx: matrix_rank(torch.cov(rep.flatten(start_dim=0, end_dim=1).T)).item()
                    for layer_idx, rep in enumerate(layer_reps)}
    for k in rank_dict.keys():
        rank_dict[k].append(cur_rank_dict[k])
    
        

Counter({1: 7, 7: 6, 6: 6, 5: 4, 3: 3, 0: 2, 2: 2, 4: 2})
Counter({5: 8, 7: 5, 3: 4, 0: 4, 2: 3, 4: 3, 6: 3, 1: 2})
Counter({5: 6, 0: 6, 4: 5, 1: 5, 2: 4, 7: 3, 6: 2, 3: 1})
Counter({7: 9, 3: 5, 5: 4, 0: 4, 1: 4, 2: 3, 4: 2, 6: 1})
Counter({5: 9, 0: 6, 3: 6, 1: 4, 2: 3, 7: 2, 6: 1, 4: 1})
Counter({6: 6, 1: 6, 7: 4, 0: 4, 5: 4, 3: 3, 4: 3, 2: 2})
Counter({4: 7, 2: 5, 7: 5, 5: 5, 1: 3, 3: 3, 0: 3, 6: 1})
Counter({2: 6, 6: 6, 5: 6, 3: 4, 0: 4, 4: 4, 1: 2})
Counter({7: 7, 5: 7, 2: 5, 4: 5, 3: 2, 1: 2, 0: 2, 6: 2})
Counter({5: 6, 0: 6, 1: 5, 4: 4, 7: 4, 6: 3, 2: 2, 3: 2})


In [46]:
rank_dict

{0: [630, 597, 606, 602, 592, 621, 588, 616, 617, 616],
 1: [545, 523, 530, 524, 518, 544, 518, 532, 538, 532],
 2: [657, 635, 643, 638, 632, 660, 629, 642, 648, 643],
 3: [760, 748, 752, 749, 745, 761, 744, 753, 752, 753],
 4: [767, 767, 767, 767, 767, 767, 767, 767, 767, 767],
 5: [624, 604, 599, 596, 595, 615, 598, 612, 606, 610],
 6: [687, 668, 664, 657, 661, 676, 663, 677, 669, 673],
 7: [761, 753, 750, 746, 748, 759, 750, 760, 756, 757],
 8: [767, 767, 767, 767, 767, 767, 767, 767, 767, 767],
 9: [768, 768, 768, 768, 768, 768, 768, 768, 768, 768],
 10: [768, 768, 768, 767, 768, 768, 768, 768, 768, 768],
 11: [767, 767, 767, 767, 767, 767, 767, 767, 767, 767]}

### Linear Probe

In [None]:
class LinearProbe(nn.Module):
    def __init__(self, 
                 feature_size: int, 
                 output_size: int,
                 iterations: int=1000,
                 lr: float=5e-4,
                 opt = None):
        super().__init__()
        self.linear_probe = nn.Linear(feature_size, output_size)
        if opt is None:
            self.optimizer = torch.optim.Adam(self.linear_probe.parameters(), lr=lr)
        else:
            self.optimizer = opt
        self.criterion = torch.nn.CrossEntropyLoss()
        self.total_iterations = iterations
        self.best_val_acc = 0.0
        self.cuda = torch.cuda.is_available()
        if self.cuda:
            self.linear_probe = self.linear_probe.cuda()
    
    def forward(self, x):
        return self.linear_probe(x)
        
        
        