In [1]:
# Environment for this kernel. These are set at the very top of the notebook,
# ahead of the torch / transformers imports below.
# NOTE(review): the GPU index and the /mnt/LLM paths are machine-specific;
# adjust them before running on another host.
%env CUDA_VISIBLE_DEVICES=6
# NOTE(review): recent transformers releases reportedly deprecate
# TRANSFORMERS_CACHE in favour of HF_HOME (set on the next line) - confirm
# against the installed version and drop this variable if it only emits a warning.
%env TRANSFORMERS_CACHE=/mnt/LLM/hub
%env HF_HOME=/mnt/LLM/
# Kept equal to the torch.set_num_threads(16) call further down in this cell.
%env OMP_NUM_THREADS=16
# Re-import edited project modules (e.g. the src.* imports below) automatically.
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.insert(0, '..')

import time
import random
from tqdm.auto import trange, tqdm
import numpy as np
import ipynbname  # pip install ipynbname

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from src.aq import QuantizedWeight, QuantizedLinear
from src.modelutils import get_model
from src.datautils import get_loaders
from convert_legacy_model_format import load_quantized_model_with_old_pickle


# Use the GPU when one is visible to this process, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Turn TF32 off for both cuBLAS matmuls and cuDNN so float32 math on the GPU
# is not silently computed at reduced precision.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# Cap intra-op CPU parallelism (same value as OMP_NUM_THREADS set via %env).
torch.set_num_threads(16)


env: CUDA_VISIBLE_DEVICES=6
env: TRANSFORMERS_CACHE=/mnt/LLM/hub
env: HF_HOME=/mnt/LLM/
env: OMP_NUM_THREADS=16




In [7]:
class args:
    """Notebook configuration, read as plain class attributes (``args.<name>``).

    The fields are consumed by the cell that loads the model, collects the
    final hidden states ``X`` and accumulates ``XTX``.
    """

    # --- model ---
    model_path = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
    model_dtype = 'bfloat16'
    xtx_dtype = torch.float64  # dtype of the XTX accumulator
    model_seqlen = 2048  # can be 2048 for 1.1B, 4096-8192 for larger models
    device_map = 'auto'  # falsy value => the model is moved to `device` manually

    # --- calibration data ---
    dataset = 'pajama'
    # NOTE(review): the recorded output of the compute cell reports 256 loaded
    # sequences, not 2560 - confirm which value the saved XTX was built with.
    total_nsamples = 2560
    seed = 42

    # --- XTX accumulation / outputs ---
    batch_size = 16384  # rows of the flattened X processed per addmm step
    x_save_path = None  # set to a path to also dump the raw hidden states X
    # Derived from the fields above so the filename cannot drift from the
    # settings that actually produced the file.
    xtx_save_path = f"./tinyllama_xtx_seqlen{model_seqlen}_{total_nsamples}samples.pth"

In [9]:
# ---------------------------------------------------------------------------
# 1. Model and calibration data
# ---------------------------------------------------------------------------
model = get_model(args.model_path, None, args.model_dtype, args.device_map)
if not args.device_map:
    # No device map requested: place the whole model on `device` ourselves.
    model = model.to(device)

train_data = get_loaders(
    args.dataset,
    nsamples=args.total_nsamples,
    seed=args.seed,
    model_path=args.model_path,
    seqlen=args.model_seqlen,
)

# ---------------------------------------------------------------------------
# 2. Final hidden states X, one slot per calibration sequence, kept on CPU
#    in the dtype of the model parameters
# ---------------------------------------------------------------------------
X = torch.zeros(
    len(train_data),
    args.model_seqlen,
    model.config.hidden_size,
    dtype=next(model.parameters()).dtype,
    device='cpu',
)
with torch.no_grad():
    for seq_idx, input_ids in enumerate(tqdm(train_data, desc='computing hidden states (X)')):
        input_ids = input_ids.to(device)
        backbone_out = model.model.forward(
            input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),
        )
        # Write into a length-1 slice of X; copy_ moves the result to the CPU buffer.
        X[seq_idx : seq_idx + 1, ...].copy_(backbone_out.last_hidden_state)

if args.x_save_path is not None:
    torch.save(X, args.x_save_path)
    print("X saved to", args.x_save_path)

# ---------------------------------------------------------------------------
# 3. XTX = (X_flat^T @ X_flat) / num_rows, accumulated in chunks on `device`
# ---------------------------------------------------------------------------
X_flat = X.flatten(start_dim=0, end_dim=-2)  # collapse all leading dims into rows
num_rows = len(X_flat)
feature_dim = X_flat.shape[-1]

XTX = torch.zeros(feature_dim, feature_dim, device=device, dtype=args.xtx_dtype)
for row_start in tqdm(range(0, num_rows, args.batch_size), desc='computing dot products (XTX)'):
    rows = X_flat[row_start: row_start + args.batch_size].to(device=device, dtype=XTX.dtype)
    # Scaling every partial product by 1 / num_rows makes the sum an average over rows.
    XTX.addmm_(rows.T, rows, alpha=1 / num_rows)
    del rows  # drop the reference to this device chunk before the next one is created

torch.save(XTX, args.xtx_save_path)
print("XTX saved to", args.xtx_save_path)

Loading pretrained model ...
Model loaded sucсessfully ...
Loading red_pajama from togethercomputer/RedPajama-Data-1T-Sample


                                                                                

Loaded data from pajama; len(data)=256 sequences


computing hidden states (X):   0%|          | 0/256 [00:00<?, ?it/s]

computing dot products(XTX):   0%|          | 0/32 [00:00<?, ?it/s]

XTX saved to ./tinyllama_xtx_seqlen2048_2560samples.pth
