In [None]:
# Standard library imports
import json
import os
from pathlib import Path
from typing import List, Optional

# Third-party library imports
import fire
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchinfo
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from mistral_common.protocol.instruct.messages import (
    UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import (
    Function,
    Tool,
)
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

# Local application/library-specific imports
from mistral import ModelArgs, Transformer, RMSNorm, precompute_freqs_cis, generate

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [29]:
# Load the dataset in streaming mode
ds = load_dataset("HuggingFaceTB/cosmopedia", "stories", streaming=True,)

# Initialize a counter
counter = 0

# Iterate over the dataset
dataset = {
    "text": [],
}

for sample in ds["train"]:
    dataset["text"].append(sample["text"])
    counter += 1
    if counter >= 10000:
        break

In [30]:
# Load the tokenizer outside of the function
tokenizer = MistralTokenizer.v1()

def tokenize_text(text, tokenizer, return_text=False):
    # Tokenize the input text using the provided tokenizer
    tokenized = tokenizer.encode_chat_completion(
        ChatCompletionRequest(
            messages=[UserMessage(content=text)],
            model="open-mistral-7b",
        )
    )

    tokens = tokenized.tokens
    tokenized_text = tokenized.text

    if return_text:
        return tokens, tokenized_text
    else:
        return tokens

In [31]:
dataset["text"][:5]

[' Once upon a time, in a village called Kiwiland, there lived two best friends named Kiwi and Koala. They loved exploring the world around them and learning new things every day! One day, they stumbled upon a magical forest full of vibrant colors and fascinating creatures. As they ventured deeper into the forest, they met Torty, a wise old turtle who was known to have answers to all questions.\n\nKiwi asked Torty, "How does our culture affect the way we make decisions?" Torty smiled and replied, "Well my dear friend, let me tell you a story."\n\nLong ago, in another part of the forest, there were two tribes - the Hares and the Sloths. The Hares valued speed and quickness, believing that swift actions led to success. On the other hand, the Sloths cherished patience and deliberation, thinking that slow yet thoughtful decisions brought prosperity.\n\nOne sunny afternoon, both tribes faced a challenge – sharing a limited supply of fruits between them. The Hares wanted to divide the fruits

In [32]:
tokenized_data = [tokenize_text(text, tokenizer) for text in dataset["text"]]

In [33]:
tokenize_text("[NULL]", tokenizer, return_text=True)

([1, 733, 16289, 28793, 733, 5992, 28793, 733, 28748, 16289, 28793],
 '<s>▁[INST]▁[NULL]▁[/INST]')

In [34]:
tokenize_text("hi", tokenizer, return_text=True)

([1, 733, 16289, 28793, 12014, 733, 28748, 16289, 28793],
 '<s>▁[INST]▁hi▁[/INST]')

In [35]:
import torch 
from torch.utils.data import Dataset

# Convert tokenized data into a tensor
tokens_tensor = [torch.tensor(tokens) for tokens in tokenized_data]
num_tokens = sum([len(tokens) for tokens in tokenized_data])

In [36]:
print(num_tokens)

5960969


In [37]:
max_seq_len = max(len(seq) for seq in tokens_tensor)

class PaddedSequenceTokenDataset(Dataset):
    def __init__(self, sequences, vocab_size, max_seq_len):
        self.sequences = sequences
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        # Convert the sequence of tokens to a tensor.
        sequence_tensor = torch.tensor(sequence, dtype=torch.long)
        # Pad the sequence with 0s.
        padded_sequence = torch.cat((sequence_tensor, torch.zeros(self.max_seq_len - len(sequence), dtype=torch.long)))
        return padded_sequence

vocab_size = 10000  # This is the size of your vocabulary
dataset = PaddedSequenceTokenDataset(tokens_tensor, vocab_size, max_seq_len)

In [38]:
batch_size = 1  # This is the size of the batches you want to load
data_loader = DataLoader(dataset, batch_size=batch_size)

In [39]:
for batch in data_loader:
    print(batch)
    break

  sequence_tensor = torch.tensor(sequence, dtype=torch.long)


tensor([[    1,   733, 16289,  ...,     0,     0,     0]])


In [40]:
# Create a new ModelArgs object with the desired configuration
model_args = ModelArgs(
    dim=512,
    n_layers=16,
    head_dim=32,
    hidden_dim=2048,
    n_heads=8,
    n_kv_heads=8,
    vocab_size=32000,
    norm_eps=1e-5,
    max_batch_size=3,
)

# Create a new Transformer object with random weights
model = Transformer(model_args).to("cuda", dtype=torch.float32)

In [41]:
torchinfo.summary(model)

Layer (type:depth-idx)                   Param #
Transformer                              --
├─Embedding: 1-1                         16,384,000
├─ModuleList: 1-2                        --
│    └─TransformerBlock: 2-1             --
│    │    └─Attention: 3-1               524,288
│    │    └─FeedForward: 3-2             3,145,728
│    │    └─RMSNorm: 3-3                 512
│    │    └─RMSNorm: 3-4                 512
│    └─TransformerBlock: 2-2             --
│    │    └─Attention: 3-5               524,288
│    │    └─FeedForward: 3-6             3,145,728
│    │    └─RMSNorm: 3-7                 512
│    │    └─RMSNorm: 3-8                 512
│    └─TransformerBlock: 2-3             --
│    │    └─Attention: 3-9               524,288
│    │    └─FeedForward: 3-10            3,145,728
│    │    └─RMSNorm: 3-11                512
│    │    └─RMSNorm: 3-12                512
│    └─TransformerBlock: 2-4             --
│    │    └─Attention: 3-13              524,288
│    │    └─Feed

In [42]:
# Get the number of bytes currently allocated on the GPU
allocated_bytes = torch.cuda.memory_allocated()

# Convert to megabytes
allocated_mb = allocated_bytes / 1_048_576

print(f'Memory allocated by PyTorch on the GPU: {allocated_mb:.2f} MB')

Memory allocated by PyTorch on the GPU: 8564.41 MB


In [43]:
num_params = sum(p.numel() for p in model.parameters())
num_buffers = sum(b.numel() for b in model.buffers())

# Convert to bytes
params_bytes = num_params * 4
buffers_bytes = num_buffers * 4

# Convert to megabytes
params_mb = params_bytes / 1_048_576
buffers_mb = buffers_bytes / 1_048_576

print(f'Number of parameters: {num_params}')
print(f'Number of buffers: {num_buffers}')
print(f'Memory for parameters: {params_mb:.2f} MB')
print(f'Memory for buffers: {buffers_mb:.2f} MB')

Number of parameters: 91505152
Number of buffers: 0
Memory for parameters: 349.06 MB
Memory for buffers: 0.00 MB


In [44]:
from tqdm.auto import tqdm

In [None]:
import os
import logging
import datetime
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

# Create directories if they don't exist
os.makedirs('models', exist_ok=True)
os.makedirs('logs', exist_ok=True)

# Get current date and time
now = datetime.datetime.now()
date_time = now.strftime("%Y-%m-%d_%H-%M-%S")

# Find the next available run number
run_number = 0
while os.path.exists(f'logs/training-run-{run_number}-{date_time}.log'):
    run_number += 1

# Set up logging
logging.basicConfig(filename=f'logs/training-run-{run_number}-{date_time}.log', level=logging.INFO)

# Your model and data setup here...

optimizer = optim.Adam(model.parameters(), lr=1e-5)
loss_fn = nn.CrossEntropyLoss()

model.train()
for batch_idx, batch in enumerate(tqdm(data_loader)):
    optimizer.zero_grad()
    input_ids = batch.to(device)
    positions = torch.arange(0, max_seq_len).to(device)
    logits = model.forward(input_ids, positions)
    logprobs = nn.functional.log_softmax(logits, dim=-1)

    _, predicted = torch.max(logprobs, dim=2)

    loss = loss_fn(logprobs.view(-1, logprobs.size(-1)), input_ids.view(-1))
    logging.info(f'Batch {batch_idx}, Loss: {loss.item()}')

    loss.backward()
    optimizer.step()

    # Save model checkpoint every 100 batches
    if (batch_idx + 1) % 100 == 0:
        torch.save(model.state_dict(), f'models/model_checkpoint_run-{run_number}_{date_time}_batch-{batch_idx}.pth')


In [18]:
prompt = "This is a test prompt "
generated_text, logprobs = generate([prompt], model, tokenizer, max_tokens=35)
print(generated_text[0])

][ kunumber comparingiverategor counrotejsactic hasn mot worry ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  Pero харак hosp Creativeotimesን龙))
