In [7]:
import torch
from omegaconf import OmegaConf
from src.models.blm.config import ModelArgs
from src.models.blm.pl_dataloader import TinyStoriesDataloader
from src.models.blm.pl_training import Transformer
from src.tokenize.tokenizer import Tokenizer

In [8]:
from src.models.blm.finetune.full_finetune import FineTuneBLM

In [9]:
from datasets import load_dataset, load_from_disk

# Directory holding the maths-problem fine-tuning dataset on local disk.
maths_dataset_dir = (
    "/home/pranav-pc/projects/OpenTransformer/multiformer/data/finetune/maths-problem"
)
# `ds` is indexed by split name further down (e.g. ds["validation"]).
ds = load_from_disk(maths_dataset_dir)

In [10]:
# Project root on the local machine; despite the name this is a filesystem
# path, not a URL.
BASE_URL = "/home/pranav-pc/projects/OpenTransformer/multiformer"
# The tokenizer is built from the checkpoint directory directly under the root.
tokenizer_path = f"{BASE_URL}/tokenizer_checkpoints"
tokenizer = Tokenizer(tokenizer_path)

In [11]:
# Derive both artefact locations from BASE_URL (the project root set in the
# tokenizer cell) instead of re-typing the absolute prefix, so relocating the
# project only requires changing one constant. The resulting strings are
# identical to the previously hard-coded paths.
MODEL_CONFIG_PATH = BASE_URL + "/src/models/blm/conf/config.yaml"
MODEL_CHECKPOINT_PATH = (
    BASE_URL
    + "/blm-1024/checkpoints/blm-fine-tuned-maths"
    + "/blm-instruct-maths-epoch=11-train_loss=1.282.ckpt"
)

# YAML configuration; its "model" section is later expanded into ModelArgs.
conf = OmegaConf.load(MODEL_CONFIG_PATH)

In [12]:
# Restore the checkpoint onto the CPU. Without `map_location`, torch.load puts
# every tensor back on the device it was saved from, which fails on a machine
# without that device and otherwise keeps an extra copy of the weights on the
# GPU. The model is moved to the GPU explicitly later on.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
state_dict = torch.load(MODEL_CHECKPOINT_PATH, map_location="cpu")["state_dict"]
# Modules wrapped by `torch.compile` save their parameters with an extra
# "_orig_mod" component in each key; strip it so the keys match the
# uncompiled model built below.
# TODO: Take care of this while model checkpointing
state_dict = {k.replace("._orig_mod", ""): v for k, v in state_dict.items()}

In [13]:
# Build the architecture from the "model" section of the YAML config.
model_section = conf["model"]
config = ModelArgs(**model_section)

# Wrap the base transformer in the fine-tuning module before restoring weights.
base_transformer = Transformer(config)
model = FineTuneBLM(base_transformer)

# Kept as the final expression so the notebook displays the key-matching report.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [14]:
#### Inference
# Move the model to the GPU and switch it to evaluation mode. The two steps
# are independent; the trailing semicolon keeps the cell from echoing the
# module repr.
model = model.cuda()
model.eval();

In [34]:
# Inspect a single record from the validation split.
validation_split = ds["validation"]
validation_split[34]

{'text': 'There are 200 more red apples than green apples in a grocery store. A truck arrives and delivers another 340 green apples. If there were originally 32 green apples, how many more green apples than red apples are there in the store now?/nThere are originally 32 green apples in the store.\nThere are 200 more red apples than green apples, so there are 32 + 200 = 232 red apples.\nA truck delivers 340 green apples, so now there are 32 + 340 = 372 green apples in the store.\nThere are still 232 red apples in the store.\nTherefore, there are 372 - 232 = 140 more green apples than red apples in the store now.\n#### 140\nThe answer is: 140'}

In [40]:
# Question part of a validation example, used as a generation prompt.
text = (
    "There are 200 more red apples than green apples in a grocery store. "
    "A truck arrives and delivers another 340 green apples. "
    "If there were originally 32 green apples, "
    "how many more green apples than red apples are there in the store now?"
)

In [42]:
# Alternative prompts kept for quick experimentation; uncomment one to use it.
# text = "Once upon a time there was a pumpkin. It was a very special pumpkin, it could speak. It was sad because it couldn’t move. Every day, it would say"
# text = "Jack was hungry, so he went looking for"
# text = "Tim is a good boy. one day his father called and asked for the school exam result"
# text = "Jack wanted to read a book,so he went to"
# text = "who are you?"
# text = "User: Words: come, road, sad Summary: A bus becomes jealous of a shiny new car and undergoes an operation to go faster, becoming good friends with the car and making everyone in the town happy. Assistant:"
# NOTE: this assignment replaces any `text` set by an earlier cell.
text = "9+9"

# Place the prompt on whatever device the model lives on rather than
# hard-coding "cuda:0", so the two can never disagree.
model_device = next(model.parameters()).device
# Reshape to (1, seq_len) and drop the final token.
# NOTE(review): presumably the dropped token is an end-of-sequence marker
# appended by `tokenizer.encode` - confirm against the tokenizer.
tokens = torch.LongTensor(tokenizer.encode(text)).to(model_device).view(1, -1)[:, :-1]

# Generation needs no gradients; disabling autograd avoids keeping activations
# alive across up to 1024 decoding steps.
with torch.no_grad():
    # NOTE(review): conditional_break=[2] looks like the token id(s) that stop
    # generation early - confirm against `predict_step`.
    predicted_tokens = model.predict_step(
        tokens, None, max_new_tokens=1024, temperature=0.9, top_k=3, conditional_break=[2]
    )[0].tolist()

print(tokenizer.decode_ids(predicted_tokens))

5
1
= 33$.  Find the sum of the numbers $x$ and $y = X$.
If we know the answer to the above question is 32, what is the value of unknown variable X?/nWe are given that the sum of the numbers $x$ and $y = X$.
To find the value of $X$, we need to determine the value of $X$.
We know that the sum of the numbers is 9, so we can write:
$9 + (9 - 33) = X$
Simplifying, we get:
$X + (9 - 33) = X$
We are given that the sum of the numbers is 33, so we can substitute the values into the equation:
$X + (9 - 33) = 33$
To find the value of $X$, we can solve for $X$ by substituting it into the equation:
$X + (9 - 33) = 33$
$X + 9 = 33$
Subtracting 9 from both sides of the equation, we get:
$X = 33$
Dividing both sides of the equation by 33, we find:
$X = 3$
The value of X is 3.
The answer is: 3


In [73]:
# Look up which vocabulary piece a given token id maps to.
token_id = 13
tokenizer.id_to_piece(token_id)

'<0x0A>'