In [None]:
!pip install transformers
!pip install accelerate
!pip install bitsandbytes
!pip install peft
!pip install datasets
!pip install tqdm

In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer
from torch.cuda import amp
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [None]:
device_id = 1  # CUDA device index passed to the `.cuda(device_id)` calls in this notebook
epochs = 30  # number of passes of the training loop over the training set
lr = 3e-4  # initial learning rate handed to AdamW
weight_decay = 1e-3  # AdamW weight decay
lr_min = 1e-5  # `eta_min` of the cosine-annealing learning-rate schedule
batch_size = 100  # number of rows per training / validation batch

# Data Preprocessing

In [None]:
# Load the pre-computed token-embedding table (it must have been produced by the
# preprocessing step beforehand).
# NOTE(review): torch.load unpickles the file - only load files you trust.
word_embeddings_tensor = torch.load('word_embeddings_tensor_llama2.pt').cuda(device_id)
vocabulary_size, embedding_dim = word_embeddings_tensor.shape

# Append one all-zero row to the table. Its width follows the loaded table
# instead of a hard-coded 4096, so other embedding sizes work as well.
# NOTE(review): this row is presumably meant to be the [PAD] embedding, which
# assumes the table has exactly `pad_token_id` (32000) rows - confirm.
zero_tensor = torch.zeros((1, embedding_dim), device=word_embeddings_tensor.device)
word_embeddings_tensor = torch.cat((word_embeddings_tensor, zero_tensor), dim=0)
word_embeddings_tensor.requires_grad = False  # the table is a fixed lookup, never trained

model_id = "NousResearch/Llama-2-7b-hf"
pad_token_id = 32000
max_token = 64  # maximum number of token ids per training row / generated sequence

tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="right")
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.pad_token_id = pad_token_id

In [None]:
def tokenize_function(row):
    """Tokenize ``row["dialog"]`` and pad the result up to ``max_token`` ids.

    Truncation is switched off, so a dialog that needs more than ``max_token``
    ids is not cut here.
    """
    tokenizer_options = dict(max_length=max_token, truncation=False, padding="max_length")
    return tokenizer(row["dialog"], **tokenizer_options)

def is_shorter_than_max_token(row):
    """
    Return True when the row holds at most max_token token ids
    (i.e. the row should be kept by a filter), False otherwise.
    """
    token_count = len(row['input_ids'])
    return token_count <= max_token

def split_conversation(conversation):
    """
    Return every prefix of the conversation that has an even, non-zero length
    (2, 4, 6, ... turns), shortest first. A trailing unpaired turn is never the
    end of a prefix.
    """
    prefix_ends = range(2, len(conversation) + 1, 2)
    return [conversation[:end] for end in prefix_ends]

def format_conversation(conversation: list[str]) -> str:
    """
    Render a list of turns as one tagged prompt string.

    The last two turns are tagged "<User>" and "<Assistant>"; every earlier turn
    is tagged "<Past User>" (even index) or "<Past Assistant>" (odd index).
    All lines but the last end with a newline. Fewer than two turns give "".
    """
    if len(conversation) < 2:
        return ""

    pieces = []

    # Everything before the final exchange is history
    for position, utterance in enumerate(conversation[:-2]):
        tag = "<Past User>" if position % 2 == 0 else "<Past Assistant>"
        pieces.append(tag + utterance + "\n")

    # The final exchange
    pieces.append("<User>" + conversation[-2] + "\n")
    pieces.append("<Assistant>" + conversation[-1])

    return "".join(pieces)

def convert_to_conversation(row):
    """
    Replace the list of turns in row["dialog"] by its tagged prompt string,
    followed by the "</s>" marker, with surrounding whitespace stripped.
    """
    text = format_conversation(row["dialog"]) + "</s>"
    return {"dialog": text.strip()}

In [None]:
# 1. Load the DailyDialog dataset
dataset = load_dataset("daily_dialog", trust_remote_code=True)

# 2. Replace each dialog by the list of its even-length prefixes
split_dataset = dataset.map(lambda example: {'dialog': split_conversation(example['dialog'])})

# 3. Flatten every split so that each entry is a single prefix
flatten_dataset_train = [prefix for prefixes in split_dataset["train"]["dialog"] for prefix in prefixes]
flatten_dataset_valid = [prefix for prefixes in split_dataset["validation"]["dialog"] for prefix in prefixes]
flatten_dataset_test = [prefix for prefixes in split_dataset["test"]["dialog"] for prefix in prefixes]

# 4. The test split is folded into the training data
flatten_dataset_train += flatten_dataset_test

# 5. Wrap the flat lists as Dataset objects again (the test split no longer exists on its own)
flatten_dataset_train = Dataset.from_dict({'dialog': flatten_dataset_train})
flatten_dataset_valid = Dataset.from_dict({'dialog': flatten_dataset_valid})

# 6. Rebuild the DatasetDict with only a train and a validation split
dataset = DatasetDict({'train': flatten_dataset_train, 'validation': flatten_dataset_valid})

# 7-9. Render each entry as a tagged prompt, tokenize it, then drop the rows
#      that hold more than max_token ids
dataset = (
    dataset
    .map(convert_to_conversation)
    .map(tokenize_function)
    .filter(is_shorter_than_max_token)
)

In [None]:
def _real_token_count(token_ids):
    """Return how many ids in ``token_ids`` differ from the padding id."""
    return sum(1 for token_id in token_ids if token_id != pad_token_id)

# Order the rows from shortest to longest (padding not counted) so that rows of
# similar length end up in the same batch. The padding id comes from the shared
# `pad_token_id` constant instead of a repeated literal.
train_dataset = sorted(dataset["train"]["input_ids"], key=_real_token_count)
validation_dataset = sorted(dataset["validation"]["input_ids"], key=_real_token_count)

# Loading the pre-trained model

In [None]:
class ROPEEmbedding(nn.Module):
    """Rotary position embedding for per-head query / key tensors.

    A fixed, non-trainable table of cosines and signed sines is built once and
    kept in ``pos_emb``; for an even ``head_dim`` its shape is
    ``[max_context_length, 2 * head_dim]`` (cosines first, then sines).
    """

    def __init__(self, max_context_length: int, head_dim: int = 64, theta: int = 10000):
        super().__init__()
        self.pos_emb = self.create_embedding(max_context_length, head_dim=head_dim, theta=theta)

    def create_embedding(self, max_context_length: int, head_dim: int = 64, theta: int = 10000) -> torch.Tensor:
        """Build the cos / sin lookup table as a frozen ``nn.Parameter``."""
        # Pair index 0, 0, 1, 1, 2, 2, ...: both members of a pair share a frequency
        pair_ids = torch.repeat_interleave(torch.arange(0, head_dim // 2), 2)
        inverse_frequency = torch.pow(theta, -pair_ids * 2 / head_dim)

        # Angle = frequency * position (the position is the "m" of the RoPE formula)
        positions = torch.arange(max_context_length).float()
        angles = torch.einsum("i, j -> ij", inverse_frequency, positions)

        cos_rows = angles.cos()
        sin_rows = angles.sin()
        sin_rows[0::2] *= -1  # even rows carry the minus sign of the rotation

        table = torch.cat((cos_rows, sin_rows), dim=0).transpose(1, 0)
        return nn.Parameter(table, requires_grad=False)

    def flip_for_sin(self, tensor: torch.Tensor) -> torch.Tensor:
        """Swap the two members of every consecutive pair along the last axis."""
        leading_a, leading_b = tensor.shape[0], tensor.shape[1]
        pairs = tensor.reshape(leading_a, leading_b, -1, 2)
        swapped = pairs[..., [1, 0]]
        return swapped.reshape(tensor.shape)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Rotate ``tensor``, laid out as [batch_size, head, sequence_length, dim]."""
        sequence_length = tensor.shape[2]

        # First half is multiplied by the cosines, second (pair-swapped) half by the signed sines
        doubled = torch.cat((tensor, self.flip_for_sin(tensor)), dim=-1)
        weighted = doubled * self.pos_emb[:sequence_length, :]
        cos_part, sin_part = weighted.chunk(chunks=2, dim=-1)
        return cos_part + sin_part

class MultiQueryAttention(nn.Module):
    """Self-attention with optionally fewer key/value heads than query heads.

    Besides the fused ``qkv`` projection and the output projection ``o``, the
    module owns low-rank adapter layers (``lora_*``) that only take part in the
    forward pass when ``fine_tuning=True``.

    Args:
        hidden_dim: size of the last axis of the input and of the output.
        head_dim: size of one attention head.
        q_head: number of query heads.
        kv_head: number of key/value heads; must divide ``q_head``.
        embedding: module applied to the query and key tensors
            (laid out as [batch_size, head, seq_len, head_dim]).
        lora_rank: inner size of the adapter layers.
    """

    def __init__(self,
                 hidden_dim: int,
                 head_dim: int,
                 q_head: int,
                 kv_head: int,
                 embedding: "ROPEEmbedding",
                 lora_rank: int = 16):
        super().__init__()
        self.head_dim = head_dim
        self.q_head = q_head
        self.kv_head = kv_head
        self.embedding = embedding
        self.qkv = nn.Linear(hidden_dim, (q_head+kv_head*2)*head_dim)
        self.o = nn.Linear(q_head*head_dim, hidden_dim)
        self.scaler = 1/math.sqrt(head_dim)
        self.lora_qkv_a = nn.Linear(hidden_dim, lora_rank)
        self.lora_qkv_b = nn.Linear(lora_rank, (q_head+kv_head*2)*head_dim)
        self.lora_o_a = nn.Linear(q_head*head_dim, lora_rank)
        self.lora_o_b = nn.Linear(lora_rank, hidden_dim)

        if q_head != kv_head:
            # Fewer key/value heads than query heads: each one is shared by a group of query heads
            assert q_head % kv_head == 0
            self.multi_query_attention = True
            self.q_kv_scale = q_head//kv_head
        else:
            self.multi_query_attention = False

    def forward(self, tensor: torch.Tensor, attention_mask: torch.Tensor = None, fine_tuning: bool = False) -> torch.Tensor:
        """Attend over ``tensor`` of shape [batch_size, seq_len, hidden_dim].

        Args:
            tensor: input activations.
            attention_mask: optional additive mask, added to the scaled scores
                before the softmax.
            fine_tuning: when True the adapter layers are added to both projections.

        Returns:
            Tensor of shape [batch_size, seq_len, hidden_dim].
        """
        batch_size, seq_len, _ = tensor.shape

        qkv_tensor = self.qkv(tensor)
        if fine_tuning:
            lora_tensor = self.lora_qkv_a(tensor)
            lora_tensor = self.lora_qkv_b(lora_tensor)
            qkv_tensor = lora_tensor + qkv_tensor
        query, key, value = qkv_tensor.split([self.head_dim*self.q_head, self.head_dim*self.kv_head, self.head_dim*self.kv_head], dim=-1)

        query = query.view(batch_size, seq_len, self.q_head, self.head_dim)
        key = key.view(batch_size, seq_len, self.kv_head, self.head_dim)
        value = value.view(batch_size, seq_len, self.kv_head, self.head_dim)

        if self.multi_query_attention:
            # Duplicate the key/value heads so that every query head has a partner
            key = torch.repeat_interleave(key, self.q_kv_scale, dim=-2)
            value = torch.repeat_interleave(value, self.q_kv_scale, dim=-2)

        # Switch to batch_size, head, seq_len, head_dim
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # Apply the position embedding to queries and keys only
        query = self.embedding(query)
        key = self.embedding(key)

        # Scaled dot-product attention
        attention_raw = torch.matmul(query, key.transpose(2, 3))
        attention_scaled = attention_raw * self.scaler
        if attention_mask is not None:
            # Deliberately in place: the scores keep their own dtype even if the mask is float32
            attention_scaled += attention_mask
        attention_score = torch.softmax(attention_scaled, dim=-1)
        value = torch.matmul(attention_score, value)

        # Merge the heads again. The merged width is q_head * head_dim, which is
        # the input width of `o` and need not equal hidden_dim.
        value = value.transpose(1, 2).contiguous()
        value = value.view(batch_size, seq_len, self.q_head * self.head_dim)

        # Output layer
        output = self.o(value)
        if fine_tuning:
            # NOTE(review): the adapter is fed the *output* of `o`, not its input
            # `value`, so this path only works when q_head * head_dim == hidden_dim.
            # Left unchanged so that already fine-tuned checkpoints keep their behaviour.
            lora_tensor = self.lora_o_a(output)
            lora_tensor = self.lora_o_b(lora_tensor)
            output = lora_tensor + output

        return output

class FeedForward(nn.Module):
    """Gated feed-forward block: ``down(dropout(gelu(gate) * up))``.

    ``gate`` and ``up`` are the two halves of a single fused projection whose
    output is ``2 * hidden_size * expansion_factor`` wide.
    """

    def __init__(self,
                 hidden_size: int,
                 expansion_factor: int = 4,
                 dropout_ratio: float = 0.1):
        super().__init__()
        inner_size = hidden_size * expansion_factor
        self.gate_and_up = nn.Linear(hidden_size, inner_size * 2)
        self.down = nn.Linear(inner_size, hidden_size)
        self.dropout = nn.Dropout(p=dropout_ratio)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply the block along the last axis; the output has the input's shape."""
        gate_half, up_half = self.gate_and_up(tensor).chunk(chunks=2, dim=-1)
        gated = F.gelu(gate_half, approximate="tanh") * up_half
        return self.down(self.dropout(gated))

class MOE(nn.Module):
    """Top-2 mixture-of-experts feed-forward layer.

    A linear gate scores every token; the two best-scoring experts process the
    token and their outputs are mixed with the renormalised gate scores.
    ``forward`` also returns a load-balancing loss computed from the gate
    probabilities and the top-1 assignment frequencies.

    All intermediate tensors are created on the device (and with the dtype) of
    the tensors they are combined with, so the layer does not depend on a
    global device id or on running under float16 autocast.
    """

    def __init__(self, hidden_size: int, num_experts: int = 8, expansion_factor: int = 4, dropout_ratio: float = 0.1):
        super().__init__()
        self.gate = nn.Linear(hidden_size, num_experts)
        self.num_experts = num_experts
        self.experts = nn.ModuleList([FeedForward(hidden_size, expansion_factor=expansion_factor, dropout_ratio=dropout_ratio) for _ in range(num_experts)])

    def _run_experts(self, flat_tensor: torch.Tensor, expert_index: torch.Tensor) -> torch.Tensor:
        """Send each row of ``flat_tensor`` through the expert named in ``expert_index``.

        Returns the expert outputs in the original row order.
        """
        # Group the rows by expert so that each expert is called once on a contiguous chunk
        order = expert_index.argsort()
        counts = torch.bincount(expert_index, minlength=self.num_experts).tolist()
        chunks = torch.split(flat_tensor[order], counts, dim=0)

        # Experts that received no row are skipped; the remaining outputs stay in sorted order
        outputs = [self.experts[i](chunk) for i, chunk in enumerate(chunks) if counts[i] > 0]
        if not outputs:
            return torch.zeros_like(flat_tensor)
        outputs = torch.cat(outputs, dim=0)

        # Scatter the outputs back to the original row order
        routed = torch.zeros(flat_tensor.shape, dtype=outputs.dtype, device=outputs.device)
        routed.index_copy_(0, order, outputs)
        return routed

    def forward(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Route ``tensor`` of shape [batch_size, seq_len, hidden_size] through the experts.

        Returns:
            A pair ``(output, load_balancing_loss)``; ``output`` has the shape of
            ``tensor`` and ``load_balancing_loss`` is a scalar tensor.
        """
        # Flatten for easier manipulation; tokens are processed independently here
        batch_size, seq_len, hidden_size = tensor.shape
        flat_tensor = tensor.reshape(batch_size * seq_len, hidden_size)

        # Gate probabilities and the two best experts per token
        gate_probability = F.softmax(self.gate(flat_tensor), dim=-1)
        value_tensor, index_tensor = gate_probability.topk(k=2, dim=-1)

        # Load-balancing loss: mean soft probability times hard top-1 frequency per expert
        counts = torch.bincount(index_tensor[:, 0], minlength=self.num_experts)
        frequencies = counts.float() / (batch_size * seq_len)
        probability = gate_probability.mean(0)
        load_balancing_loss = (probability * frequencies).mean() * float(self.num_experts ** 2)

        # Renormalise the two selected scores so that they sum to one
        top_expert_score = value_tensor[:, 0]
        second_expert_score = value_tensor[:, 1]
        total_score = top_expert_score + second_expert_score
        top_expert_score = top_expert_score / total_score
        second_expert_score = second_expert_score / total_score

        # Expert outputs for the first and the second choice, in token order
        flat_top_expert_tensor = self._run_experts(flat_tensor, index_tensor[:, 0])
        flat_second_expert_tensor = self._run_experts(flat_tensor, index_tensor[:, 1])

        # Weighted mix of both experts
        final_tensor = top_expert_score.unsqueeze(-1) * flat_top_expert_tensor + second_expert_score.unsqueeze(-1) * flat_second_expert_tensor

        # Reshape to original [batch_size, seq_len, hidden_size]
        final_tensor = final_tensor.reshape(batch_size, seq_len, hidden_size)

        return final_tensor, load_balancing_loss

class LLMLayer(nn.Module):
    """One pre-norm transformer layer: attention, then a feed-forward (or MoE) block.

    Both sub-blocks are wrapped in a residual connection. ``forward`` returns the
    activations together with the layer's load-balancing loss, which is a zero
    scalar when the plain feed-forward block is used.
    """

    def __init__(self,
                 hidden_dim: int,
                 head_dim: int,
                 q_head: int,
                 kv_head: int,
                 embedding: ROPEEmbedding,
                 expansion_factor: int = 4,
                 dropout_ratio: float = 0.1,
                 use_moe: bool = False,
                 num_experts: int = 8,
                 lora_rank: int = 16):
        super().__init__()
        self.use_moe = use_moe

        # Attention sub-block
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.mqa = MultiQueryAttention(hidden_dim, head_dim, q_head, kv_head, embedding, lora_rank=lora_rank)

        # Feed-forward sub-block: either a mixture of experts or a single FeedForward
        self.norm2 = nn.LayerNorm(hidden_dim)
        feed_forward_options = dict(expansion_factor=expansion_factor, dropout_ratio=dropout_ratio)
        if not self.use_moe:
            self.ffn = FeedForward(hidden_dim, **feed_forward_options)
        else:
            self.moe = MOE(hidden_dim, num_experts=num_experts, **feed_forward_options)

    def forward(self, tensor: torch.Tensor, attention_mask: torch.Tensor = None, fine_tuning: bool = False):
        """Run the layer and return ``(activations, load_balancing_loss)``."""
        # Attention with residual. The addition is in place on the attention output.
        attended = self.mqa(self.norm1(tensor), attention_mask=attention_mask, fine_tuning=fine_tuning)
        attended += tensor

        # Feed-forward / MoE with residual, again added in place on the block output
        normalized = self.norm2(attended)
        if not self.use_moe:
            mixed = self.ffn(normalized)
            # Without MoE the load-balancing loss is a zero scalar matching the block output
            load_balancing_loss = torch.tensor(0.0, dtype=mixed.dtype, device=mixed.device)
        else:
            mixed, load_balancing_loss = self.moe(normalized)
        mixed += attended

        return mixed, load_balancing_loss

class LLM(nn.Module):
    """Decoder-only transformer operating on pre-computed token embeddings.

    The model takes embedding vectors (not token ids), optionally projects them
    to a smaller width, runs ``num_layer`` transformer layers under a causal
    mask, and classifies every position over the vocabulary.

    Args:
        num_layer: number of transformer layers.
        vocabulary_size: number of output classes of the classifier.
        max_context_length: number of positions in the position-embedding table.
        hidden_dim: width of the incoming embeddings.
        expansion_factor: feed-forward expansion factor.
        head_dim: size of one attention head.
        q_head: number of query heads (default: ``hidden_dim // head_dim``).
        kv_head: number of key/value heads (default: ``hidden_dim // head_dim``).
        dropout_ratio: dropout probability inside the feed-forward blocks.
        theta: base of the rotary position embedding.
        projection_dim: if given, inputs are first projected to this width,
            which then becomes the model's hidden size.
        use_moe: use mixture-of-experts layers instead of plain feed-forward.
        num_experts: number of experts per MoE layer.
        load_balancing_loss_weight: factor applied to the averaged load-balancing loss.
        fine_tuning: initial state of the adapter switch.
        lora_rank: inner size of the adapter layers.
    """

    def __init__(self,
                 num_layer: int,
                 vocabulary_size: int,
                 max_context_length: int,
                 hidden_dim: int,
                 expansion_factor: int = 4,
                 head_dim: int = 64,
                 q_head: int = None,
                 kv_head: int = None,
                 dropout_ratio: float = 0.1,
                 theta: int = 10000,
                 projection_dim: int = None,
                 use_moe: bool = False,
                 num_experts=8,
                 load_balancing_loss_weight: float = 1e-2,
                 fine_tuning: bool = False,
                 lora_rank: int = 16):
        super().__init__()
        self.embedding = ROPEEmbedding(max_context_length, head_dim=head_dim, theta=theta)
        self.num_layer = num_layer
        self.load_balancing_loss_weight = load_balancing_loss_weight
        self.fine_tuning = fine_tuning

        # Because of computational power limit, we might want to project input token embedding down.
        self.projection = projection_dim is not None
        if self.projection:
            self.projection_matrix = nn.Linear(hidden_dim, projection_dim)
            hidden_dim = projection_dim

        if q_head is None:
            q_head = (hidden_dim // head_dim)

        if kv_head is None:
            kv_head = (hidden_dim // head_dim)

        if hidden_dim % (head_dim * q_head) != 0 or hidden_dim % (head_dim * kv_head) != 0:
            raise ValueError("Error: hidden_dim or projection_dim (if specified) must be divisible by the product of the number of q or kv heads and the head dimension.")

        # One position-embedding module is shared by every layer
        self.transformer = nn.ModuleList()
        for _ in range(self.num_layer):
            self.transformer.append(LLMLayer(hidden_dim, head_dim, q_head, kv_head, self.embedding, expansion_factor=expansion_factor, dropout_ratio=dropout_ratio, use_moe=use_moe, num_experts=num_experts, lora_rank=lora_rank))
        self.output_norm = nn.LayerNorm(hidden_dim)

        self.classifier = nn.Linear(hidden_dim, vocabulary_size)

    def begin_fine_tunning(self) -> None:
        """Switch the adapters on and make only parameters with "lora" in their name trainable."""
        self.fine_tuning = True
        for name, param in self.named_parameters():
            param.requires_grad = "lora" in name

    def exit_fine_tunning(self) -> None:
        """Switch the adapters off and make every parameter except the position table trainable."""
        self.fine_tuning = False
        for name, param in self.named_parameters():
            param.requires_grad = "pos_emb" not in name

    def forward(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the model on embeddings of shape [batch_size, seq_len, hidden_dim].

        Returns:
            A pair ``(logits, load_balancing_loss)``: the classifier output for
            every position, and the layer-averaged load-balancing loss already
            multiplied by ``load_balancing_loss_weight``.
        """
        # If projecting input embeddings
        if self.projection:
            tensor = self.projection_matrix(tensor)

        seq_len = tensor.shape[1]
        device = tensor.device

        # Additive causal mask, built directly on the input's device (CPU or GPU):
        # -inf strictly above the diagonal, 0 elsewhere
        causal_mask = torch.full((seq_len, seq_len), float('-inf'), device=device).triu(diagonal=1)

        # Track load-balancing across layers (stays zero unless MoE is used)
        load_balancing_sum = torch.tensor(0.0, device=device)

        for layer in self.transformer:
            tensor, load_balancing_loss = layer(tensor, attention_mask=causal_mask, fine_tuning=self.fine_tuning)
            load_balancing_sum += load_balancing_loss

        load_balancing_loss = (load_balancing_sum / self.num_layer) * self.load_balancing_loss_weight

        # Classification
        tensor = self.output_norm(tensor)
        tensor = self.classifier(tensor)

        return tensor, load_balancing_loss

In [None]:
# The checkpoint is a fully pickled nn.Module, so the model classes above must be
# defined before this cell runs. PyTorch 2.6 and later default to
# `weights_only=True`, which refuses such a file, hence the explicit
# `weights_only=False`. Unpickling can execute arbitrary code - only load
# checkpoints you trust.
# `map_location="cpu"` makes loading independent of the GPU the model was saved from.
llm = torch.load('200M Foundation Model.pth', map_location="cpu", weights_only=False).cuda(device_id)

# Training

In [None]:
# Switch the adapters on and leave only the parameters whose name contains "lora" trainable
llm.begin_fine_tunning()

In [None]:
# Per-epoch mean losses, filled in by the training loop
loss_train, loss_valid = [], []

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(llm.parameters(), lr=lr, weight_decay=weight_decay)

# The scheduler is stepped once per batch, so the cosine period covers
# every batch of every epoch (batches per epoch = ceil(rows / batch_size)).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=epochs * ((len(dataset["train"]) - 1) // batch_size + 1),
    eta_min=lr_min,
)

# Loss scaler for mixed-precision training
scaler = amp.GradScaler()

In [None]:
# Trim trailing padding from a batch
# For example, if the longest sentence in a batch has 34 non-padding tokens, the batch is cut down to shape [batch_size, 34]
def trim_padding(input_tensor, pad_id=None):
    """Cut a batch of token ids down to its longest non-padded row.

    Args:
        input_tensor: 2-D tensor of token ids, shape [batch_size, max_seq_length].
            Padding is expected at the end of each row.
        pad_id: id that marks padding; defaults to the notebook-wide ``pad_token_id``.

    Returns:
        ``input_tensor[:, :n]`` where ``n`` is the largest number of non-padding
        ids found in any row. An empty batch is returned unchanged.
    """
    if pad_id is None:
        pad_id = pad_token_id

    # Nothing to measure in an empty batch (max() of an empty tensor would raise)
    if input_tensor.numel() == 0:
        return input_tensor

    # Number of non-padding ids per row, shape [batch_size]
    lengths = (input_tensor != pad_id).sum(dim=1)
    max_length = int(lengths.max())

    return input_tensor[:, :max_length]

In [None]:
def _batch_loss(token_rows):
    """Return the teacher-forced loss of the model on one batch of token-id rows.

    The rows are trimmed to the longest non-padded row, shifted by one position
    to form inputs and targets, embedded through the fixed lookup table and run
    through the model under autocast. Padding targets are excluded from the
    cross-entropy; the model's load-balancing loss is added on top.
    """
    data = trim_padding(torch.tensor(token_rows))

    # Teacher forcing: predict token t+1 from tokens 0..t
    input_data = data[:, :-1].long().cuda(device_id)
    target_data = data[:, 1:].long().cuda(device_id)

    # Convert to embedding.
    input_embeddings = word_embeddings_tensor[input_data]

    with amp.autocast():
        prediction, load_balancing_loss = llm(input_embeddings)

        # Flatten to [tokens, vocabulary] / [tokens] and drop the padding targets
        prediction = prediction.view(-1, vocabulary_size)
        target_data = target_data.reshape(-1)
        mask = target_data != pad_token_id

        loss = criterion(prediction[mask], target_data[mask]) + load_balancing_loss
    return loss


for _ in range(epochs):
    loss_train_epoch = []
    loss_val_epoch = []

    llm.train()
    for i in tqdm(range(0, len(train_dataset), batch_size)):
        loss = _batch_loss(train_dataset[i:i + batch_size])

        # Backward pass
        scaler.scale(loss).backward()
        # The gradients are still multiplied by the loss scale at this point;
        # unscale them first so that max_norm applies to the true gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(llm.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        # Clear out grad
        optimizer.zero_grad()

        # Record loss
        loss_train_epoch.append(loss.item())
        scheduler.step()

    loss_train.append(np.mean(loss_train_epoch))

    llm.eval()
    with torch.no_grad():
        for i in tqdm(range(0, len(validation_dataset), batch_size)):
            loss = _batch_loss(validation_dataset[i:i + batch_size])

            # Record loss
            loss_val_epoch.append(loss.item())

        loss_valid.append(np.mean(loss_val_epoch))

    # Show the loss curves accumulated so far after every epoch
    plt.plot(loss_train, label="Training loss")
    plt.plot(loss_valid, label="Validation loss")
    print("Training loss: ", loss_train[-1])
    print("Validation loss: ", loss_valid[-1])
    plt.legend()
    plt.show()

# Inference

In [None]:
temperature = 0.5  # the logits are divided by this value before softmax sampling in talk_with_llm

In [None]:
conversation_history = []

def talk_with_llm(chat: str) -> str:
    """Generate the model's reply to ``chat`` and record the exchange.

    The prompt is built from ``conversation_history`` plus the new message and an
    empty assistant turn. Tokens are then sampled one at a time (softmax with
    ``temperature``) until the end-of-sequence token appears or the sequence
    holds ``max_token`` ids. If the prompt alone leaves no room for a reply, the
    oldest exchanges are left out of the prompt (the stored history is not
    shortened).

    Args:
        chat: the user's new message.

    Returns:
        The decoded reply. ``conversation_history`` is extended by ``chat`` and
        the reply only after generation has succeeded.
    """
    eos_token_id = tokenizer.eos_token_id

    # Work on a copy so that a failure during generation leaves the history untouched
    turns = conversation_history + [chat, ""]
    tokenized_sentence = tokenizer(format_conversation(turns))["input_ids"]

    # Drop the oldest user/assistant pairs while the prompt fills the whole token budget
    while len(tokenized_sentence) >= max_token and len(turns) > 2:
        turns = turns[2:]
        tokenized_sentence = tokenizer(format_conversation(turns))["input_ids"]

    # A trailing end-of-sequence token would stop generation before it starts
    if tokenized_sentence[-1] == eos_token_id:
        tokenized_sentence = tokenized_sentence[:-1]
    prompt_length = len(tokenized_sentence)

    llm.eval()
    with torch.no_grad():
        # Keep iterating until reaches end of sentence or max token limit
        while tokenized_sentence[-1] != eos_token_id and len(tokenized_sentence) < max_token:
            # Preparing input: look up the embeddings and add a batch axis
            tokenized_sentence_tensor = torch.tensor(tokenized_sentence).cuda(device_id)
            sentence_embedding = word_embeddings_tensor[tokenized_sentence_tensor].unsqueeze(0)

            # Make prediction
            with amp.autocast():
                prediction, _ = llm(sentence_embedding)
            prediction = prediction[0][-1] # We only care about last token
            prediction = F.softmax(prediction / temperature, dim=-1)
            output_token = torch.multinomial(prediction, 1)

            tokenized_sentence.append(output_token.item())

    # Decode only the newly generated ids; slicing the decoded text by the prompt's
    # character length is fragile because decoding need not reproduce the prompt exactly
    response = tokenizer.decode(tokenized_sentence[prompt_length:], skip_special_tokens=True)

    conversation_history.append(chat)
    conversation_history.append(response)
    return response

In [None]:
# Example exchange; every call also appends the message and the reply to conversation_history
talk_with_llm("What's up homie, I'm Danjie")

# Save the model

In [None]:
# Saves the whole module object (not just a state dict) to 'chillGPT_200M.pth'
torch.save(llm, 'chillGPT_200M.pth')