In [1]:
# Standard library imports
import json
import os
from pathlib import Path
from typing import List, Optional

# Third-party library imports
import fire
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchinfo
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from mistral_common.protocol.instruct.messages import (
    UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import (
    Function,
    Tool,
)
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

# Local application/library-specific imports
from mistral import ModelArgs, Transformer, RMSNorm, precompute_freqs_cis

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the dataset in streaming mode
ds = load_dataset("HuggingFaceTB/cosmopedia", "stories", streaming=True,)

# Initialize a counter
counter = 0

# Iterate over the dataset
dataset = {
    "text": [],
}

for sample in ds["train"]:
    dataset["text"].append(sample["text"])
    counter += 1
    if counter >= 100:
        break

In [3]:
# Load the tokenizer outside of the function
tokenizer = MistralTokenizer.v1()

def tokenize_text(text, tokenizer, return_text=False):
    # Tokenize the input text using the provided tokenizer
    tokenized = tokenizer.encode_chat_completion(
        ChatCompletionRequest(
            messages=[UserMessage(content=text)],
            model="open-mistral-7b",
        )
    )

    tokens = tokenized.tokens
    tokenized_text = tokenized.text

    if return_text:
        return tokens, tokenized_text
    else:
        return tokens

In [4]:
dataset

{'text': [' Once upon a time, in a village called Kiwiland, there lived two best friends named Kiwi and Koala. They loved exploring the world around them and learning new things every day! One day, they stumbled upon a magical forest full of vibrant colors and fascinating creatures. As they ventured deeper into the forest, they met Torty, a wise old turtle who was known to have answers to all questions.\n\nKiwi asked Torty, "How does our culture affect the way we make decisions?" Torty smiled and replied, "Well my dear friend, let me tell you a story."\n\nLong ago, in another part of the forest, there were two tribes - the Hares and the Sloths. The Hares valued speed and quickness, believing that swift actions led to success. On the other hand, the Sloths cherished patience and deliberation, thinking that slow yet thoughtful decisions brought prosperity.\n\nOne sunny afternoon, both tribes faced a challenge – sharing a limited supply of fruits between them. The Hares wanted to divide t

In [5]:
tokenized_data = [tokenize_text(text, tokenizer) for text in dataset["text"]]

In [6]:
tokenize_text("[NULL]", tokenizer, return_text=True)

([1, 733, 16289, 28793, 733, 5992, 28793, 733, 28748, 16289, 28793],
 '<s>▁[INST]▁[NULL]▁[/INST]')

In [7]:
tokenize_text("hi", tokenizer, return_text=True)

([1, 733, 16289, 28793, 12014, 733, 28748, 16289, 28793],
 '<s>▁[INST]▁hi▁[/INST]')

In [8]:
import torch 
from torch.utils.data import Dataset

# Convert tokenized data into a tensor
tokens_tensor = [torch.tensor(tokens) for tokens in tokenized_data]

In [9]:
max_seq_len = max(len(seq) for seq in tokens_tensor)

class PaddedSequenceTokenDataset(Dataset):
    def __init__(self, sequences, vocab_size, max_seq_len):
        self.sequences = sequences
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        # Convert the sequence of tokens to a tensor.
        sequence_tensor = torch.tensor(sequence, dtype=torch.long)
        # Pad the sequence with 0s.
        padded_sequence = torch.cat((sequence_tensor, torch.zeros(self.max_seq_len - len(sequence), dtype=torch.long)))
        return padded_sequence

vocab_size = 10000  # This is the size of your vocabulary
dataset = PaddedSequenceTokenDataset(tokens_tensor, vocab_size, max_seq_len)

In [10]:
batch_size = 2  # This is the size of the batches you want to load
data_loader = DataLoader(dataset, batch_size=batch_size)

In [11]:
for batch in data_loader:
    print(batch)
    break

  sequence_tensor = torch.tensor(sequence, dtype=torch.long)


tensor([[    1,   733, 16289,  ...,     0,     0,     0],
        [    1,   733, 16289,  ...,     0,     0,     0]])


In [12]:
# Create a new ModelArgs object with the desired configuration
model_args = ModelArgs(
    dim=512,
    n_layers=8,
    head_dim=64,
    hidden_dim=2048,
    n_heads=8,
    n_kv_heads=8,
    vocab_size=32000,
    norm_eps=1e-5,
    max_batch_size=3,
)

# Create a new Transformer object with random weights
model = Transformer(model_args).to(device)

In [13]:
torchinfo.summary(model)

Layer (type:depth-idx)                   Param #
Transformer                              --
├─Embedding: 1-1                         16,384,000
├─ModuleList: 1-2                        --
│    └─TransformerBlock: 2-1             --
│    │    └─Attention: 3-1               1,048,576
│    │    └─FeedForward: 3-2             3,145,728
│    │    └─RMSNorm: 3-3                 512
│    │    └─RMSNorm: 3-4                 512
│    └─TransformerBlock: 2-2             --
│    │    └─Attention: 3-5               1,048,576
│    │    └─FeedForward: 3-6             3,145,728
│    │    └─RMSNorm: 3-7                 512
│    │    └─RMSNorm: 3-8                 512
│    └─TransformerBlock: 2-3             --
│    │    └─Attention: 3-9               1,048,576
│    │    └─FeedForward: 3-10            3,145,728
│    │    └─RMSNorm: 3-11                512
│    │    └─RMSNorm: 3-12                512
│    └─TransformerBlock: 2-4             --
│    │    └─Attention: 3-13              1,048,576
│    │  

In [14]:
max_seq_len

1089

In [15]:
# create a list from 0 to 1089
positions = list(range(1089))

# convert the list to a PyTorch tensor
positions_tensor = torch.tensor(positions).to(device)

In [16]:
positions_tensor

tensor([   0,    1,    2,  ..., 1086, 1087, 1088], device='cuda:0')

In [22]:
optimizer = Adam(model.parameters(), lr=1e-5)
loss_fn = nn.CrossEntropyLoss()

model.train()
for batch in data_loader:
    optimizer.zero_grad()
    input_ids = batch.to(device)
    labels = input_ids  # for language modeling, labels are the same as inputs shifted by one
    outputs = model(input_ids, positions=positions_tensor)
    
    # we need to flatten the outputs and labels for the loss function
    loss = loss_fn(outputs.view(-1, outputs.size(-1)), labels.view(-1))
    print(loss)
    loss.backward()
    optimizer.step()

  sequence_tensor = torch.tensor(sequence, dtype=torch.long)


tensor(10.5723, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(10.2900, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(10.0706, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(10.0839, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(9.6423, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(9.8067, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(9.4042, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(8.9411, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(9.2146, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(8.4253, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(9.1586, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(8.4528, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.6025, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(8.9625, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(6.9856, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.0848, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(8.1182, device='cuda:0', grad

In [23]:
# Assuming you've already trained your model and it's saved in the variable `model`
# and you've loaded your tokenizer into the variable `tokenizer`
from mistral import generate

prompt = "This is a test prompt. "
generated_text, logprobs = generate([prompt], model, tokenizer, max_tokens=35)
print(generated_text[0])

AttributeError: 'MistralTokenizer' object has no attribute 'encode'