In [1]:
# %load_ext autoreload
# %autoreload 2


In [2]:
import os
import torch


In [3]:
# random seed for reproducibility
torch.manual_seed(42)


<torch._C.Generator at 0x7f05e40b54d0>

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
print(torch.cuda.device_count())


cuda
1


In [5]:
# !pip -q install sentencepiece
# !pip -q install datasets
# !pip -q install fairscale
# !pip -q install transformers
# !pip -q install tqdm


In [6]:

# CRITICAL: Import the correct model.
# from models.model_baseline_nokv_reduced import ModelArgs, Transformer
from models.model_baseline_nokv import ModelArgs, Transformer
from tqdm import tqdm
from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer("./tokenizers/tokenizer.model")
tokenizer.pad_token = tokenizer.eos_token

model_args = ModelArgs()
model_args.vocab_size = len(tokenizer)


You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [7]:
# !export RANK=0


In [8]:
from fairscale.nn.model_parallel import initialize_model_parallel
import torch.distributed as dist

def setup(rank, world_size):
    "Sets up the process group and configuration for PyTorch Distributed Data Parallelism"
    os.environ["MASTER_ADDR"] = 'localhost'
    os.environ["MASTER_PORT"] = "12355"

    # Initialize the process group
    if not dist.is_initialized():
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    "Cleans up the distributed environment"
    dist.destroy_process_group()

setup(0, 1)


# COMMENT IF INITIATED, OTHERWISE UNCOMMENT!!!!!
initialize_model_parallel(1)


> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1


In [9]:
# Get Dataset.
from datasets import load_dataset
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]["text"]

# SlimPajama dataset
# https://huggingface.co/datasets/cerebras/SlimPajama-627B?row=0
# Just for now, we use a small portion of the dataset. Later, we can use more.
# data = load_dataset("cerebras/SlimPajama-627B", split="train").select(range(1000))
data = load_dataset("DKYoon/SlimPajama-6B", split="train").select(range(5000))

dataset = CustomDataset(data)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)


Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

In [10]:
# Model Details
# model = Transformer(model_args)
# model = torch.nn.DataParallel(Transformer(model_args))
model = torch.nn.parallel.DistributedDataParallel(Transformer(model_args))
criterion = torch.nn.CrossEntropyLoss()


In [11]:
## Re-initialize the learnable weights
import math
for name, param in model.named_parameters():
    if param.requires_grad:
        if param.dim() >= 2:
            torch.nn.init.kaiming_uniform_(param)
        else:
            # bias terms or 1D params, uniform init
            bound = 1 / math.sqrt(param.size(0))
            torch.nn.init.uniform_(param, -bound, bound)


In [12]:
# Move model to cuda
model = model.to(device)


In [13]:
# Declare optimizer with finalized model parameters (after parallel + initialization)
optimizer = torch.optim.AdamW(model.parameters(), lr=3.0E-4)


In [14]:
# See how large the model is.
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params}")

trainable_params = sum(
	p.numel() for p in model.parameters() if p.requires_grad
)
print(f"Trainable number of parameters: {trainable_params}")


Total number of parameters: 166740992
Trainable number of parameters: 166740992


In [15]:
# Storage of loss metrics
import pickle
def write_to_file(epoch_number, loss_values, avg_loss, file_path):    
    # Open the file in binary append mode
    with open(file_path, 'ab') as file:
        # Serialize and write the epoch number
        pickle.dump(f'Epoch {epoch_number}', file)
        # Serialize and write the list of loss values
        pickle.dump(loss_values, file)
        # Average loss for the epoch
        pickle.dump(avg_loss, file)


In [16]:
# raise Exception("Please do not overwrite the files to avoid data losss; comment this line once done.")
SAVE_PATH = "./checkpoints/baseline_slimpj_big_sample.pt"
LOSSES_PATH = "./logs/losses/baseline_losses_slimpj_big_sample.txt"
TIMES_PATH = "./logs/times/baseline_slimpj_big_sample_elapsed_time.txt"
EPOCHS = 1


In [17]:
# Check the storage of the model
checkpoint = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": -1,
    "loss": -1,
}
torch.save(checkpoint, SAVE_PATH)
del checkpoint


In [18]:
# Check logging of losses
write_to_file(-1, [-1, -2, -3], [-2], LOSSES_PATH)


In [19]:
import time


# Training

In [20]:
losses = []
# Log a 100K losses
log_every = max(1, (EPOCHS * len(data_loader)) // 100_000)
print("Number of batches:", len(data_loader))
model.train()
start_time = time.time()
# with torch.autograd.detect_anomaly():
for epoch in range(EPOCHS):
    total_loss = 0
    for c, batch in tqdm(enumerate(data_loader)):
        # Batch size x Max Seq LEn
        sample = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=model_args.max_seq_len)["input_ids"].to(device)
        target = sample.detach()

        # Batch size x Max Seq Len x Vocab Size
        optimizer.zero_grad()
        prediction = model(sample, 0)

        # Ensure swapping of axes
        loss = criterion(prediction.transpose(1, 2), target)
        loss.backward()

        # Loss logging
        total_loss += loss.item()
        if c % log_every == 0:
            print(f"Step: {c}, Loss: {loss.item():.4f}")
            losses.append(loss.item())

        # Change model weights
        optimizer.step()

        # Explicit destruction (may not be needed after  previous debugging)
        del loss, prediction, sample, target

    print(f"Epoch {epoch} Complete")
    avg_loss = total_loss / len(data_loader)
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "loss": avg_loss,
    }
    torch.save(checkpoint, SAVE_PATH)
    write_to_file(epoch, losses, avg_loss, LOSSES_PATH)
    losses.clear()

end_time = time.time()


Number of batches: 156


1it [00:01,  1.60s/it]

Step: 0, Loss: 10.3782


2it [00:02,  1.28s/it]

Step: 1, Loss: 10.2651


3it [00:03,  1.20s/it]

Step: 2, Loss: 10.2152


4it [00:04,  1.19s/it]

Step: 3, Loss: 10.1475


5it [00:06,  1.19s/it]

Step: 4, Loss: 10.2171


6it [00:07,  1.13s/it]

Step: 5, Loss: 10.0209


7it [00:08,  1.09s/it]

Step: 6, Loss: 10.1154


8it [00:09,  1.07s/it]

Step: 7, Loss: 10.1308


9it [00:10,  1.11s/it]

Step: 8, Loss: 10.0763


10it [00:11,  1.09s/it]

Step: 9, Loss: 10.1268


11it [00:12,  1.10s/it]

Step: 10, Loss: 10.0964


12it [00:13,  1.09s/it]

Step: 11, Loss: 10.0470


13it [00:14,  1.07s/it]

Step: 12, Loss: 9.9828


14it [00:15,  1.06s/it]

Step: 13, Loss: 9.9685


15it [00:16,  1.07s/it]

Step: 14, Loss: 10.0027


16it [00:17,  1.06s/it]

Step: 15, Loss: 9.9893


17it [00:18,  1.06s/it]

Step: 16, Loss: 10.0243


18it [00:19,  1.05s/it]

Step: 17, Loss: 9.8462


19it [00:20,  1.04s/it]

Step: 18, Loss: 9.9327


20it [00:21,  1.04s/it]

Step: 19, Loss: 9.9652


21it [00:22,  1.03s/it]

Step: 20, Loss: 9.8144


22it [00:24,  1.05s/it]

Step: 21, Loss: 9.9814


23it [00:25,  1.08s/it]

Step: 22, Loss: 9.9327


24it [00:26,  1.07s/it]

Step: 23, Loss: 9.8281


25it [00:27,  1.06s/it]

Step: 24, Loss: 9.8350


26it [00:28,  1.07s/it]

Step: 25, Loss: 9.8087


27it [00:29,  1.05s/it]

Step: 26, Loss: 9.7252


28it [00:30,  1.07s/it]

Step: 27, Loss: 9.7008


29it [00:31,  1.07s/it]

Step: 28, Loss: 9.7697


30it [00:32,  1.10s/it]

Step: 29, Loss: 9.7846


31it [00:33,  1.08s/it]

Step: 30, Loss: 9.6053


32it [00:34,  1.09s/it]

Step: 31, Loss: 9.6692


33it [00:35,  1.08s/it]

Step: 32, Loss: 9.6262


34it [00:36,  1.06s/it]

Step: 33, Loss: 9.6912


35it [00:37,  1.06s/it]

Step: 34, Loss: 9.5781


36it [00:39,  1.05s/it]

Step: 35, Loss: 9.6039


37it [00:40,  1.06s/it]

Step: 36, Loss: 9.4328


38it [00:41,  1.06s/it]

Step: 37, Loss: 9.5137


39it [00:42,  1.05s/it]

Step: 38, Loss: 9.4516


40it [00:43,  1.06s/it]

Step: 39, Loss: 9.5641


41it [00:44,  1.05s/it]

Step: 40, Loss: 9.4057


42it [00:45,  1.05s/it]

Step: 41, Loss: 9.4248


43it [00:46,  1.05s/it]

Step: 42, Loss: 9.4606


44it [00:47,  1.05s/it]

Step: 43, Loss: 9.4095


45it [00:48,  1.04s/it]

Step: 44, Loss: 9.2758


46it [00:49,  1.04s/it]

Step: 45, Loss: 9.2743


47it [00:50,  1.05s/it]

Step: 46, Loss: 9.2815


48it [00:51,  1.05s/it]

Step: 47, Loss: 9.2164


49it [00:52,  1.12s/it]

Step: 48, Loss: 9.2564


50it [00:53,  1.09s/it]

Step: 49, Loss: 9.1694


51it [00:55,  1.10s/it]

Step: 50, Loss: 9.3047


52it [00:56,  1.10s/it]

Step: 51, Loss: 9.2258


53it [00:57,  1.10s/it]

Step: 52, Loss: 9.1193


54it [00:58,  1.08s/it]

Step: 53, Loss: 8.9283


55it [00:59,  1.08s/it]

Step: 54, Loss: 9.0248


56it [01:00,  1.09s/it]

Step: 55, Loss: 9.0588


57it [01:01,  1.07s/it]

Step: 56, Loss: 8.9204


58it [01:02,  1.06s/it]

Step: 57, Loss: 9.0081


59it [01:03,  1.07s/it]

Step: 58, Loss: 8.7945


60it [01:04,  1.06s/it]

Step: 59, Loss: 8.8299


61it [01:05,  1.09s/it]

Step: 60, Loss: 8.9116


62it [01:06,  1.09s/it]

Step: 61, Loss: 8.8310


63it [01:08,  1.09s/it]

Step: 62, Loss: 8.7177


64it [01:09,  1.07s/it]

Step: 63, Loss: 8.5853


65it [01:10,  1.05s/it]

Step: 64, Loss: 8.5965


66it [01:11,  1.05s/it]

Step: 65, Loss: 8.5767


67it [01:12,  1.07s/it]

Step: 66, Loss: 8.6318


68it [01:13,  1.06s/it]

Step: 67, Loss: 8.4536


69it [01:14,  1.06s/it]

Step: 68, Loss: 8.4213


70it [01:15,  1.06s/it]

Step: 69, Loss: 8.5885


71it [01:16,  1.08s/it]

Step: 70, Loss: 8.4566


72it [01:17,  1.08s/it]

Step: 71, Loss: 8.4608


73it [01:18,  1.11s/it]

Step: 72, Loss: 8.2852


74it [01:19,  1.12s/it]

Step: 73, Loss: 8.3424


75it [01:21,  1.13s/it]

Step: 74, Loss: 8.2119


76it [01:22,  1.11s/it]

Step: 75, Loss: 8.4281


77it [01:23,  1.09s/it]

Step: 76, Loss: 8.1566


78it [01:24,  1.08s/it]

Step: 77, Loss: 8.2019


79it [01:25,  1.07s/it]

Step: 78, Loss: 7.9862


80it [01:26,  1.07s/it]

Step: 79, Loss: 8.0931


81it [01:27,  1.08s/it]

Step: 80, Loss: 7.9555


82it [01:28,  1.08s/it]

Step: 81, Loss: 7.9043


83it [01:29,  1.07s/it]

Step: 82, Loss: 7.8466


84it [01:30,  1.07s/it]

Step: 83, Loss: 7.8753


85it [01:31,  1.10s/it]

Step: 84, Loss: 7.7030


86it [01:32,  1.09s/it]

Step: 85, Loss: 7.8169


87it [01:33,  1.09s/it]

Step: 86, Loss: 7.6357


88it [01:35,  1.09s/it]

Step: 87, Loss: 7.6254


89it [01:36,  1.08s/it]

Step: 88, Loss: 7.7064


90it [01:37,  1.07s/it]

Step: 89, Loss: 7.4981


91it [01:38,  1.08s/it]

Step: 90, Loss: 7.6205


92it [01:39,  1.09s/it]

Step: 91, Loss: 7.4083


93it [01:40,  1.09s/it]

Step: 92, Loss: 7.6096


94it [01:41,  1.09s/it]

Step: 93, Loss: 7.4024


95it [01:42,  1.09s/it]

Step: 94, Loss: 7.3434


96it [01:43,  1.08s/it]

Step: 95, Loss: 7.3156


97it [01:44,  1.08s/it]

Step: 96, Loss: 7.3541


98it [01:45,  1.08s/it]

Step: 97, Loss: 7.1702


99it [01:46,  1.08s/it]

Step: 98, Loss: 7.0982


100it [01:47,  1.08s/it]

Step: 99, Loss: 7.2374


101it [01:49,  1.11s/it]

Step: 100, Loss: 7.1895


102it [01:50,  1.14s/it]

Step: 101, Loss: 6.9928


103it [01:51,  1.12s/it]

Step: 102, Loss: 7.0799


104it [01:52,  1.09s/it]

Step: 103, Loss: 6.7477


105it [01:53,  1.08s/it]

Step: 104, Loss: 6.8835


106it [01:54,  1.07s/it]

Step: 105, Loss: 6.6626


107it [01:55,  1.07s/it]

Step: 106, Loss: 6.4452


108it [01:56,  1.07s/it]

Step: 107, Loss: 6.6592


109it [01:57,  1.07s/it]

Step: 108, Loss: 6.6392


110it [01:58,  1.07s/it]

Step: 109, Loss: 6.5500


111it [02:00,  1.11s/it]

Step: 110, Loss: 6.5042


112it [02:01,  1.09s/it]

Step: 111, Loss: 6.4389


113it [02:02,  1.08s/it]

Step: 112, Loss: 6.2517


114it [02:03,  1.07s/it]

Step: 113, Loss: 5.9410


115it [02:04,  1.09s/it]

Step: 114, Loss: 6.1999


116it [02:05,  1.08s/it]

Step: 115, Loss: 6.0666


117it [02:06,  1.08s/it]

Step: 116, Loss: 6.2548


118it [02:07,  1.11s/it]

Step: 117, Loss: 6.0420


119it [02:08,  1.10s/it]

Step: 118, Loss: 6.1931


120it [02:09,  1.08s/it]

Step: 119, Loss: 5.7846


121it [02:10,  1.08s/it]

Step: 120, Loss: 5.6186


122it [02:11,  1.07s/it]

Step: 121, Loss: 5.7617


123it [02:12,  1.07s/it]

Step: 122, Loss: 5.7873


124it [02:14,  1.07s/it]

Step: 123, Loss: 5.5110


125it [02:15,  1.07s/it]

Step: 124, Loss: 5.4085


126it [02:16,  1.06s/it]

Step: 125, Loss: 5.3641


127it [02:17,  1.07s/it]

Step: 126, Loss: 5.7709


128it [02:18,  1.07s/it]

Step: 127, Loss: 5.3058


129it [02:19,  1.07s/it]

Step: 128, Loss: 5.4968


130it [02:20,  1.06s/it]

Step: 129, Loss: 5.4541


131it [02:21,  1.07s/it]

Step: 130, Loss: 5.4839


132it [02:22,  1.07s/it]

Step: 131, Loss: 5.1147


133it [02:24,  1.44s/it]

Step: 132, Loss: 4.9913


134it [02:25,  1.32s/it]

Step: 133, Loss: 5.1885


135it [02:27,  1.24s/it]

Step: 134, Loss: 4.9379


136it [02:28,  1.18s/it]

Step: 135, Loss: 4.8107


137it [02:29,  1.16s/it]

Step: 136, Loss: 4.9975


138it [02:30,  1.14s/it]

Step: 137, Loss: 4.7715


139it [02:31,  1.11s/it]

Step: 138, Loss: 4.6335


140it [02:32,  1.09s/it]

Step: 139, Loss: 4.6022


141it [02:33,  1.07s/it]

Step: 140, Loss: 4.5038


142it [02:34,  1.07s/it]

Step: 141, Loss: 4.6590


143it [02:35,  1.10s/it]

Step: 142, Loss: 4.5225


144it [02:36,  1.09s/it]

Step: 143, Loss: 4.4517


145it [02:37,  1.09s/it]

Step: 144, Loss: 4.5529


146it [02:38,  1.08s/it]

Step: 145, Loss: 4.5998


147it [02:39,  1.10s/it]

Step: 146, Loss: 4.6551


148it [02:41,  1.09s/it]

Step: 147, Loss: 4.4351


149it [02:42,  1.08s/it]

Step: 148, Loss: 4.2553


150it [02:43,  1.08s/it]

Step: 149, Loss: 3.6517


151it [02:44,  1.07s/it]

Step: 150, Loss: 3.7746


152it [02:45,  1.09s/it]

Step: 151, Loss: 3.6325


153it [02:46,  1.12s/it]

Step: 152, Loss: 3.6976


154it [02:47,  1.11s/it]

Step: 153, Loss: 3.8651


155it [02:48,  1.10s/it]

Step: 154, Loss: 3.8129


156it [02:49,  1.09s/it]

Step: 155, Loss: 4.0546
Epoch 0 Complete





In [21]:
optimizer.zero_grad()


In [22]:
elapsed_time = end_time - start_time
with open(TIMES_PATH, "a") as file:
    file.write(f"{elapsed_time}")


In [23]:
elapsed_time


184.83383345603943

In [24]:
cleanup()


Bad pipe message: %s [b'\xfcp\x1dG\xc1\xb3\xda\x93I\xaeK\xcf\xbe\xfa\x1fid\x92 \xf7=/\x82}\xb5\x00F\x7f\x87\xec\xa6H)/\xf3c\x82\x95%2\xd6X\x88\xaf\xe1\xafy\xb1\r\xe4\xb5\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01\x00+\x00\x03\x02\x03\x04\x00-\x00\x02\x01\x01\x003\x00&\x00$\x00\x1d\x00 \xb4\xe8-\x1a\x97f^b\xd4@H\xac\xc9\xbd \xb0_\xe3\xaf\x8e\xec\xc7\xc1\xe6']
Bad pipe message: %s [b'oL*\xf1T"\xa9\xb4\xba\xd3|\xaf\xaajn&r\x87\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0\'\x00g\x00@\xc0\n\xc0\x14\x009

With the following configuration:

@dataclass
class ModelArgs:
    dim: int = 1024
    n_layers: int = 8
    n_heads: int = 8
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 64  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 512

There are 166740992 parameters

To train across 156 batches -> 4992 Rows (32 sized batch), it took 185 sec

The model trains at 1619 Rows per Minute. -> Round down to 1500 Rows per minute.

1440 minutes per day.

With 1.5M Rows, we train for 

1.5M * 1 / 1500 * 1 / 1440 = 0.7 days