In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from model.gpt import GPT
from pathlib import Path
import torch
from fastai.learner import *
from fastai.text.all import *

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
#script to evaluate phi3 without any finetuning

import sys
import os
# Add the parent directory to the system path
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from model.gpt import GPT
from fastai.text.all import *
from data.unlabeled import download_dataset
from fastai.distributed import *
from learner.LLMLearner import LLMLearner
from data.loader import memmapDL

model = GPT.from_hf('microsoft/Phi-3-mini-4k-instruct', enable_qlora = True)
dataset = "orcamath"
bs = 1
valid_sampler_size = 1000 #how many samples to use for validation. This is only used to check if validation loss is better than best_valid_loss, so that a checkpoint can be saved. Karpathy uses 200 random points
validate_every = 1000 #1000 iterations, each iteration is bs*total_GPUs inputs
qlora = True

train_path, valid_path = rank0_first(lambda: download_dataset(dataset = dataset, encoder = model.tokenizer)) #check if data exists, download only for rank0 GPU. 
train_dl = memmapDL(train_path, bs = bs, block_size=model.block_size, 
                      dtype=model.tokenizer._get_numpy_dtype())
valid_dl = memmapDL(valid_path, bs = bs, block_size=model.block_size, 
                      dtype=model.tokenizer._get_numpy_dtype(), 
                      sample_size = valid_sampler_size)

dls = DataLoaders(train_dl, valid_dl)
dls.c = model.vocab_size

learn = LLMLearner(dls, 
                model, 
                opt_func = partial(OptimWrapper, opt=torch.optim.AdamW),
                loss_func=CrossEntropyLossFlat(), 
                metrics=[accuracy, Perplexity()],
                ).to_bf16()
learn.path = Path('scripts/checkpoints/') #local path to save/load checkpoints
learn.model_dir = 'gpt'



Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.69s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Setting num_proc from 18 back to 1 for the train split to disable multiprocessing as it only contains one shard.
Generating train split: 100%|██████████| 200035/200035 [00:01<00:00, 156331.51 examples/s]


In [4]:
x,y = next(iter(train_dl))
x.shape

torch.Size([1, 4096])

In [5]:
learn.cbs

(#5) [TrainEvalCallback,Recorder,CastToTensor,ProgressCallback,MixedPrecision]

In [6]:
# learn.epoch = 0
# learn.dl = train_dl
# learn('before_train')
# learn.model.eval()


In [6]:
learn.model.train()
with torch.no_grad(): 
    learn.model(x)



In [26]:
Path('scripts/checkpoints/gpt/Phi-3-mini-25.2M.pth').exists()

True

In [27]:
state = torch.load(Path('scripts/checkpoints/gpt/Phi-3-mini-25.2M.pth'), map_location=torch.device('cuda'))

In [8]:
del model
#clear cuda cache
import gc
gc.collect()
torch.cuda.empty_cache()

: 

In [38]:
model.load_state_dict(state['model'])   

<All keys matched successfully>

In [41]:
list(model.parameters())

[Parameter containing:
 tensor([[-5.8594e-02, -4.0894e-03,  1.5564e-03,  ..., -2.3438e-02,
           3.8818e-02, -5.9082e-02],
         [-3.0273e-02,  9.1309e-02,  5.6152e-02,  ...,  1.0132e-02,
          -2.1606e-02, -2.4170e-02],
         [-3.3264e-03,  3.1982e-02,  9.0942e-03,  ...,  9.4414e-05,
          -6.9275e-03, -2.7832e-02],
         ...,
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]], device='cuda:0'),
 Parameter containing:
 Parameter(Params4bit([[148],
             [158],
             [191],
             ...,
             [ 99],
             [ 62],
             [212]], device='cuda:0', dtype=torch.uint8)),
 Parameter containing:
 tensor([[-0.0078, -0.0035, -0.0164,  ...,  0.0037,  0.0084, -0.0020],
         [