In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
import math
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import time
import random

# Seed every RNG the notebook may draw from: torch (weight init, dropout, multinomial
# sampling), numpy and Python's `random`, so a Restart & Run All is reproducible.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [None]:
# Hyperparameters
# --- data ---------------------------------------------------------------
max_context_length = 64
training_data_path = "llama2_wiki_64_ranked_train.npy"
eval_data_path = "llama2_wiki_64_ranked_eval.npy"
batch_size = 512  # Must be the same as preprocessing batch_size
validation_batch_size = 10

# --- model architecture ---------------------------------------------------
num_layer = 3
head_dim = 64
projection_dim = 1024
expansion_factor = 8

# --- optimisation ---------------------------------------------------------
epochs = 10
lr = 1e-3
weight_decay = 1e-3

# --- runtime --------------------------------------------------------------
device = "cpu"
checkpoint_filepath = ""  # empty string (or None) -> train from scratch instead of resuming

# Load the Llama 2 token embeddings and tokenizer

In [None]:
# Load the token-embedding table saved by the preprocessing step (.pt file).
# It is a frozen lookup table: its row count sets the model's vocabulary size.
word_embeddings_tensor = torch.load('word_embeddings_tensor_llama2.pt').to(device)
word_embeddings_tensor.requires_grad_(False)
vocabulary_size, embedding_dim = word_embeddings_tensor.shape

# Tokenizer matching the embedding table; used to encode prompts and decode generations.
model_id = "NousResearch/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="right")

# Padding id used in the preprocessed .npy data.
# NOTE(review): the Llama-2 tokenizer defines no pad token, so this presumably refers to an
# extra embedding row added during preprocessing — confirm against the preprocessing script.
pad_token_id = 32000

# Load training data

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Memory-mapped view over a pre-tokenised ``.npy`` array, one row per example.

    The base class is spelled ``torch.utils.data.Dataset`` on purpose. This class is itself
    named ``Dataset``, so with the bare imported name, re-running the cell would make the
    new class inherit from the previous user-defined one instead of torch's base class.

    Args:
        filename: path to a ``.npy`` file whose first axis indexes examples.
    """

    def __init__(self, filename):
        # mmap_mode='r' keeps the array on disk and pages rows in on demand.
        self.mmap_data = np.load(filename, mmap_mode='r')
        self.length = self.mmap_data.shape[0]

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return row ``idx`` as a float32 tensor (a copy of the read-only memmap slice).

        NOTE(review): token ids travel as float32 and are converted back with ``.long()``
        in the training loop; this is exact only for ids < 2**24.
        """
        return torch.tensor(self.mmap_data[idx], dtype=torch.float32)

In [None]:
# Memory-mapped train / eval splits, built in that order.
training_dataset, eval_dataset = Dataset(training_data_path), Dataset(eval_data_path)

# Both loaders keep file order (shuffle=False).
# NOTE(review): the "ranked" files appear to be pre-grouped so that each block of
# `batch_size` rows has near-equal lengths (trim_padding rejects mixed batches);
# shuffling would break that grouping — confirm against the preprocessing script.
training_loader = DataLoader(dataset=training_dataset, batch_size=batch_size, shuffle=False)
validation_loader = DataLoader(dataset=eval_dataset, batch_size=validation_batch_size, shuffle=False)

# Define the LLM architecture

In [None]:
class ROPEEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) for per-head query/key tensors.

    Feature pair (2i, 2i+1) of every head vector at position m is rotated by the angle
    m * theta^(-2i / head_dim). The cos/sin factors are precomputed once as a frozen
    (requires_grad=False) parameter ``pos_emb`` of shape [max_context_length, 2 * head_dim].
    """

    def __init__(self, max_context_length: int, head_dim: int = 64, theta: int = 10000):
        super().__init__()
        self.pos_emb = self.create_embedding(max_context_length, head_dim=head_dim, theta=theta)

    def create_embedding(self, max_context_length: int, head_dim: int = 64, theta: int = 10000) -> torch.Tensor:
        """Build the frozen RoPE factor table.

        Columns ``[:head_dim]`` hold cos(m * w_k); columns ``[head_dim:]`` hold sin(m * w_k)
        with the sign flipped for even k, where w_k = theta^(-2 * (k // 2) / head_dim).
        With that layout, ``x * cos + swap_pairs(x) * signed_sin`` is the pairwise rotation.
        """
        pair_ids = torch.arange(0, head_dim // 2).repeat_interleave(2)  # 0, 0, 1, 1, 2, 2, ...
        inv_freq = torch.pow(theta, -pair_ids * 2 / head_dim)           # one frequency per feature
        positions = torch.arange(max_context_length).float()           # m in the RoPE formula
        angles = inv_freq[:, None] * positions[None, :]                # [head_dim, max_context_length]

        signed_sin = angles.sin()
        signed_sin[0::2] = -signed_sin[0::2]                           # even features take -sin
        table = torch.cat((angles.cos(), signed_sin), dim=0).T        # [max_context_length, 2 * head_dim]
        return nn.Parameter(table, requires_grad=False)

    def flip_for_sin(self, tensor: torch.Tensor) -> torch.Tensor:
        """Swap adjacent feature pairs: (x0, x1, x2, x3, ...) -> (x1, x0, x3, x2, ...)."""
        batch, heads = tensor.shape[0], tensor.shape[1]
        pairs = tensor.reshape(batch, heads, -1, 2)
        return pairs.flip(-1).reshape(tensor.shape)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Rotate ``tensor`` ([batch, heads, seq_len, head_dim]) by positions 0..seq_len-1."""
        seq_len = tensor.shape[2]
        factors = self.pos_emb[:seq_len, :]
        doubled = torch.cat((tensor, self.flip_for_sin(tensor)), dim=-1) * factors
        cos_part, sin_part = doubled.chunk(chunks=2, dim=-1)
        return cos_part + sin_part

class MultiQueryAttention(nn.Module):
    """Self-attention with grouped/multi-query K/V sharing, RoPE and optional LoRA adapters.

    ``q_head`` query heads share ``kv_head`` key/value heads (``q_head`` must be a multiple
    of ``kv_head``). When the two are equal this is ordinary multi-head attention.

    Args:
        hidden_dim: width of the input and of the output.
        head_dim: width of each attention head.
        q_head: number of query heads.
        kv_head: number of key/value heads.
        embedding: positional module applied to queries and keys (a ROPEEmbedding in this
            notebook); it receives tensors shaped (batch, heads, seq_len, head_dim).
        lora_rank: rank of the LoRA adapters added when ``fine_tuning=True``.
    """

    def __init__(self,
                 hidden_dim: int,
                 head_dim: int,
                 q_head: int,
                 kv_head: int,
                 embedding: "ROPEEmbedding",
                 lora_rank: int = 16):
        super().__init__()
        self.head_dim = head_dim
        self.q_head = q_head
        self.kv_head = kv_head
        self.embedding = embedding
        self.qkv = nn.Linear(hidden_dim, (q_head+kv_head*2)*head_dim)
        self.o = nn.Linear(q_head*head_dim, hidden_dim)
        self.scaler = 1/math.sqrt(head_dim)
        self.lora_qkv_a = nn.Linear(hidden_dim, lora_rank)
        self.lora_qkv_b = nn.Linear(lora_rank, (q_head+kv_head*2)*head_dim)
        self.lora_o_a = nn.Linear(q_head*head_dim, lora_rank)
        self.lora_o_b = nn.Linear(lora_rank, hidden_dim)
        # Standard LoRA init: B starts at zero, so turning fine-tuning on does not change the
        # model's output until the adapters have actually been trained.
        for lora_b in (self.lora_qkv_b, self.lora_o_b):
            nn.init.zeros_(lora_b.weight)
            nn.init.zeros_(lora_b.bias)

        if q_head != kv_head:
            # Grouped / multi-query attention: each K/V head serves q_kv_scale query heads.
            assert q_head % kv_head == 0
            self.multi_query_attention = True
            self.q_kv_scale = q_head//kv_head
        else:
            self.multi_query_attention = False

    def forward(self, tensor: torch.Tensor, attention_mask: torch.Tensor = None, fine_tuning: bool = False) -> torch.Tensor:
        """Attend over ``tensor`` of shape (batch, seq_len, hidden_dim).

        Args:
            tensor: (batch, seq_len, hidden_dim) input.
            attention_mask: optional additive mask broadcastable to
                (batch, q_head, seq_len, seq_len), e.g. 0 / -inf for causal masking.
            fine_tuning: add the LoRA adapter outputs to the qkv and output projections.

        Returns:
            (batch, seq_len, hidden_dim) tensor.
        """
        batch_size, seq_len, _ = tensor.shape

        qkv_tensor = self.qkv(tensor)
        if fine_tuning:
            qkv_tensor = qkv_tensor + self.lora_qkv_b(self.lora_qkv_a(tensor))
        query, key, value = qkv_tensor.split(
            [self.head_dim * self.q_head, self.head_dim * self.kv_head, self.head_dim * self.kv_head], dim=-1)

        query = query.view(batch_size, seq_len, self.q_head, self.head_dim)
        key = key.view(batch_size, seq_len, self.kv_head, self.head_dim)
        value = value.view(batch_size, seq_len, self.kv_head, self.head_dim)

        if self.multi_query_attention:
            # Duplicate each K/V head for the consecutive query heads that share it.
            key = torch.repeat_interleave(key, self.q_kv_scale, dim=-2)
            value = torch.repeat_interleave(value, self.q_kv_scale, dim=-2)

        # (batch, heads, seq_len, head_dim)
        query, key, value = (t.transpose(1, 2) for t in (query, key, value))

        # Rotary position information goes on queries and keys only.
        query = self.embedding(query)
        key = self.embedding(key)

        # Scaled dot-product attention.
        attention_scaled = torch.matmul(query, key.transpose(2, 3)) * self.scaler
        if attention_mask is not None:
            attention_scaled = attention_scaled + attention_mask
        attention_score = torch.softmax(attention_scaled, dim=-1)
        context = torch.matmul(attention_score, value)

        # Back to (batch, seq_len, q_head * head_dim) — the width self.o was built for,
        # which is not necessarily hidden_dim.
        context = context.transpose(1, 2).reshape(batch_size, seq_len, self.q_head * self.head_dim)

        output = self.o(context)
        if fine_tuning:
            output = output + self.lora_o_b(self.lora_o_a(context))
        return output

class FeedForward(nn.Module):
    """GeGLU feed-forward block (tanh-approximated GELU gate) with optional LoRA adapters.

    Args:
        hidden_size: input and output width.
        expansion_factor: inner width is ``hidden_size * expansion_factor``.
        dropout_ratio: dropout applied to the gated activation before the down projection.
        lora_rank: rank of the LoRA adapters added when ``fine_tuning=True``.
    """

    def __init__(self,
                 hidden_size: int,
                 expansion_factor: int = 4,
                 dropout_ratio: float = 0.1,
                 lora_rank: int = 16):
        super().__init__()
        inner_size = hidden_size * expansion_factor
        self.gate_and_up = nn.Linear(hidden_size, inner_size * 2)
        self.down = nn.Linear(inner_size, hidden_size)
        self.dropout = nn.Dropout(p=dropout_ratio)
        self.lora_gate_and_up_a = nn.Linear(hidden_size, lora_rank)
        self.lora_gate_and_up_b = nn.Linear(lora_rank, inner_size * 2)
        self.lora_down_a = nn.Linear(inner_size, lora_rank)
        self.lora_down_b = nn.Linear(lora_rank, hidden_size)
        # Standard LoRA init: B starts at zero, so turning fine-tuning on does not change the
        # model's output until the adapters have actually been trained.
        for lora_b in (self.lora_gate_and_up_b, self.lora_down_b):
            nn.init.zeros_(lora_b.weight)
            nn.init.zeros_(lora_b.bias)

    def forward(self, tensor: torch.Tensor, fine_tuning: bool = False) -> torch.Tensor:
        """Apply the gated FFN to the last dimension of ``tensor`` (width ``hidden_size``)."""
        gate_and_up = self.gate_and_up(tensor)
        if fine_tuning:
            gate_and_up = gate_and_up + self.lora_gate_and_up_b(self.lora_gate_and_up_a(tensor))
        gate, up = gate_and_up.chunk(chunks=2, dim=-1)
        tensor = self.dropout(F.gelu(gate, approximate="tanh") * up)
        down_tensor = self.down(tensor)
        if fine_tuning:
            down_tensor = down_tensor + self.lora_down_b(self.lora_down_a(tensor))
        return down_tensor

class MOE(nn.Module):
    """Top-2 mixture-of-experts layer built from ``FeedForward`` experts.

    Each token goes to the two experts with the highest gate probability. The two expert
    outputs are mixed with those probabilities renormalised to sum to one. ``forward``
    also returns a load-balancing loss equal to num_experts * sum_i(f_i * P_i). Here f_i is
    the fraction of tokens whose top-1 expert is i, and P_i is the mean gate probability
    of expert i.
    """

    def __init__(self, hidden_size: int, num_experts: int = 8, expansion_factor: int = 4, dropout_ratio: float = 0.1, lora_rank: int = 16):
        super().__init__()
        self.gate = nn.Linear(hidden_size, num_experts)
        self.num_experts = num_experts
        self.experts = nn.ModuleList([FeedForward(hidden_size, expansion_factor=expansion_factor, dropout_ratio=dropout_ratio, lora_rank=lora_rank) for _ in range(num_experts)])

    def _run_experts(self, flat_tensor: torch.Tensor, expert_index: torch.Tensor, fine_tuning: bool) -> torch.Tensor:
        """Send row i of ``flat_tensor`` through expert ``expert_index[i]``; keep row order."""
        # Group tokens by expert so each expert runs once on one contiguous batch.
        order = expert_index.argsort()
        counts = torch.bincount(expert_index, minlength=self.num_experts).tolist()
        grouped_tokens = torch.split(flat_tensor[order], counts, dim=0)
        outputs = [self.experts[i](tokens, fine_tuning)
                   for i, tokens in enumerate(grouped_tokens) if counts[i] > 0]
        if not outputs:  # no tokens at all
            return torch.zeros_like(flat_tensor)
        grouped_output = torch.cat(outputs, dim=0)

        # Scatter rows back to their original token positions. Device and dtype follow the
        # expert outputs rather than a notebook-global device / hard-coded dtype.
        result = grouped_output.new_zeros(flat_tensor.shape[0], grouped_output.shape[-1])
        result.index_copy_(0, order, grouped_output)
        return result

    def forward(self, tensor: torch.Tensor, fine_tuning: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
        """Route ``tensor`` ([batch, seq_len, hidden]) through its top-2 experts.

        Returns:
            (output of the same shape as ``tensor``, scalar load-balancing loss)
        """
        # Tokens are independent here, so flatten batch and sequence into one axis.
        batch_size, seq_len, hidden_size = tensor.shape
        num_tokens = batch_size * seq_len
        flat_tensor = tensor.reshape(num_tokens, hidden_size)

        gate_probs = F.softmax(self.gate(flat_tensor), dim=-1)
        top2_probs, top2_index = gate_probs.topk(k=2, dim=-1)

        # Load balancing: hard top-1 frequencies times soft mean probabilities.
        counts = torch.bincount(top2_index[:, 0], minlength=self.num_experts)
        frequencies = counts.float() / num_tokens
        probability = gate_probs.mean(0)
        load_balancing_loss = (probability * frequencies).mean() * float(self.num_experts ** 2)

        # Renormalise the two selected probabilities so they sum to one per token.
        weights = top2_probs / top2_probs.sum(dim=-1, keepdim=True)

        top_output = self._run_experts(flat_tensor, top2_index[:, 0], fine_tuning)
        second_output = self._run_experts(flat_tensor, top2_index[:, 1], fine_tuning)
        final_tensor = weights[:, :1] * top_output + weights[:, 1:] * second_output

        return final_tensor.reshape(batch_size, seq_len, hidden_size), load_balancing_loss

class LLMLayer(nn.Module):
    """One pre-norm transformer block: an attention sub-layer, then a dense-FFN or MoE
    sub-layer, each wrapped in a residual connection.

    ``forward`` returns ``(hidden_states, load_balancing_loss)``. The loss is a zero scalar
    when the block uses a dense ``FeedForward`` instead of ``MOE``.
    """

    def __init__(self,
                 hidden_dim: int,
                 head_dim: int,
                 q_head: int,
                 kv_head: int,
                 embedding: ROPEEmbedding,
                 expansion_factor: int = 4,
                 dropout_ratio: float = 0.1,
                 use_moe: bool = False,
                 num_experts: int = 8,
                 lora_rank: int = 16):
        super().__init__()
        self.use_moe = use_moe

        # Attention sub-layer.
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.mqa = MultiQueryAttention(hidden_dim, head_dim, q_head, kv_head, embedding, lora_rank=lora_rank)

        # Feed-forward sub-layer: mixture of experts or a single dense FFN.
        self.norm2 = nn.LayerNorm(hidden_dim)
        if use_moe:
            self.moe = MOE(hidden_dim, num_experts=num_experts, expansion_factor=expansion_factor, dropout_ratio=dropout_ratio, lora_rank=lora_rank)
        else:
            self.ffn = FeedForward(hidden_dim, expansion_factor=expansion_factor, dropout_ratio=dropout_ratio, lora_rank=lora_rank)

    def forward(self, tensor: torch.Tensor, attention_mask: torch.Tensor = None, fine_tuning: bool = False):
        attended = self.mqa(self.norm1(tensor), attention_mask=attention_mask, fine_tuning=fine_tuning)
        hidden = tensor + attended

        normed = self.norm2(hidden)
        if self.use_moe:
            mixed, load_balancing_loss = self.moe(normed, fine_tuning=fine_tuning)
        else:
            mixed = self.ffn(normed, fine_tuning=fine_tuning)
            # A dense FFN has no router, so it adds no balancing penalty.
            load_balancing_loss = torch.tensor(0.0, dtype=mixed.dtype, device=mixed.device)

        return hidden + mixed, load_balancing_loss

class LLM(nn.Module):
    """Decoder-only language model that consumes pre-computed token embeddings.

    Input is a (batch, seq_len, hidden_dim) tensor of token embeddings. It is optionally
    projected down to ``projection_dim`` first. ``forward`` returns
    ``(logits, load_balancing_loss)`` with logits of shape (batch, seq_len, vocabulary_size).

    Args:
        num_layer: number of ``LLMLayer`` blocks.
        vocabulary_size: size of the output classifier.
        max_context_length: longest sequence the RoPE table supports.
        hidden_dim: width of the input embeddings (and of the model when no projection).
        expansion_factor: FFN inner-width multiplier.
        head_dim: attention head width.
        q_head / kv_head: query and key/value head counts; default ``model_width // head_dim``.
        dropout_ratio: dropout inside the feed-forward blocks.
        theta: RoPE frequency base.
        projection_dim: if given, input embeddings are linearly projected to this width.
        use_moe / num_experts: use top-2 mixture-of-experts FFNs with this many experts.
        load_balancing_loss_weight: multiplier on the mean per-layer balancing loss.
        fine_tuning: start with the LoRA adapters enabled.
        lora_rank: rank of the LoRA adapters.
    """

    def __init__(self,
                 num_layer: int,
                 vocabulary_size: int,
                 max_context_length: int,
                 hidden_dim: int,
                 expansion_factor: int = 4,
                 head_dim: int = 64,
                 q_head: int = None,
                 kv_head: int = None,
                 dropout_ratio: float = 0.1,
                 theta: int = 10000,
                 projection_dim: int = None,
                 use_moe: bool = False,
                 num_experts=8,
                 load_balancing_loss_weight: float = 1e-2,
                 fine_tuning: bool = False,
                 lora_rank: int = 16):
        super().__init__()
        # A single RoPE table shared by every layer.
        self.embedding = ROPEEmbedding(max_context_length, head_dim=head_dim, theta=theta)
        self.num_layer = num_layer
        self.load_balancing_loss_weight = load_balancing_loss_weight
        self.fine_tuning = fine_tuning

        # Because of computational power limit, we might want to project input token embedding down.
        self.projection = projection_dim is not None
        if self.projection:
            self.projection_matrix = nn.Linear(hidden_dim, projection_dim)
            hidden_dim = projection_dim

        if q_head is None:
            q_head = hidden_dim // head_dim
        if kv_head is None:
            kv_head = hidden_dim // head_dim

        if hidden_dim % (head_dim * q_head) != 0 or hidden_dim % (head_dim * kv_head) != 0:
            raise ValueError("Error: hidden_dim or projection_dim (if specified) must be divisible by the product of the number of q or kv heads and the head dimension.")

        self.transformer = nn.ModuleList()
        for _ in range(self.num_layer):
            self.transformer.append(LLMLayer(hidden_dim, head_dim, q_head, kv_head, self.embedding, expansion_factor=expansion_factor, dropout_ratio=dropout_ratio, use_moe=use_moe, num_experts=num_experts, lora_rank=lora_rank))
        self.output_norm = nn.LayerNorm(hidden_dim)

        self.classifier = nn.Linear(hidden_dim, vocabulary_size)

    def begin_fine_tunning(self) -> None:
        """Enable LoRA adapters and train only parameters whose name contains "lora"."""
        self.fine_tuning = True
        for name, param in self.named_parameters():
            param.requires_grad = "lora" in name

    def exit_fine_tunning(self) -> None:
        """Disable LoRA adapters and make every parameter trainable except the RoPE table."""
        self.fine_tuning = False
        for name, param in self.named_parameters():
            param.requires_grad = "pos_emb" not in name

    # Correctly spelled aliases; the original (misspelled) names remain for existing callers.
    begin_fine_tuning = begin_fine_tunning
    exit_fine_tuning = exit_fine_tunning

    def forward(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(logits, weighted mean load-balancing loss)`` for embeddings ``tensor``."""
        # If projecting input embeddings
        if self.projection:
            tensor = self.projection_matrix(tensor)

        seq_len = tensor.shape[1]
        device_id = tensor.device

        max_len = self.embedding.pos_emb.shape[0]
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_context_length {max_len}")

        # Additive causal mask: 0 on/below the diagonal, -inf above. It is built on the
        # input's device (not a notebook-global device) so the model runs wherever it lives.
        causal_mask = torch.full((seq_len, seq_len), float('-inf'), device=device_id).triu(diagonal=1)

        # Load-balancing loss summed over layers (zero for dense FFN layers).
        load_balancing_sum = torch.zeros((), device=device_id)
        for layer in self.transformer:
            tensor, layer_loss = layer(tensor, attention_mask=causal_mask, fine_tuning=self.fine_tuning)
            load_balancing_sum = load_balancing_sum + layer_loss

        load_balancing_loss = (load_balancing_sum / self.num_layer) * self.load_balancing_loss_weight

        # Classification
        tensor = self.output_norm(tensor)
        tensor = self.classifier(tensor)

        return tensor, load_balancing_loss

In [None]:
def save_checkpoint(model, optimizer, epoch, loss):
    """Save model/optimizer state, epoch and loss to ``checkpoint_{epoch}_{loss}.pth.tar``.

    Data-parallel wrappers are unwrapped first, so the stored state_dict keys carry no
    ``module.`` prefix. ``loss`` is stored as a plain Python float. A numpy scalar such as
    the result of ``np.mean`` cannot be read back by ``torch.load(weights_only=True)``.

    Returns:
        The filename written (relative to the current working directory).
    """
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    model_to_save = model.module if isinstance(model, wrappers) else model
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model_to_save.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': float(loss),
    }

    filename = f'checkpoint_{epoch}_{loss}.pth.tar'
    torch.save(checkpoint, filename)
    print(f'Checkpoint saved at epoch {epoch} as {filename}')
    return filename

In [None]:
def load_checkpoint(model, optimizer, filename, map_location=None, **torch_load_kwargs) -> int:
    """Restore model and optimizer state from a checkpoint written by ``save_checkpoint``.

    Args:
        model: module to load into; ``nn.DataParallel`` / DDP wrappers are unwrapped, which
            matches how ``save_checkpoint`` stores the state_dict.
        optimizer: optimizer whose state is restored.
        filename: checkpoint path.
        map_location: forwarded to ``torch.load`` (e.g. ``"cpu"`` for a GPU-saved checkpoint).
        **torch_load_kwargs: extra ``torch.load`` options (e.g. ``weights_only=False`` for old
            checkpoints that stored a numpy loss value).

    Returns:
        The epoch stored in the checkpoint.
    """
    # torch.load unpickles the file — only load checkpoints from trusted sources.
    checkpoint = torch.load(filename, map_location=map_location, **torch_load_kwargs)

    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    target = model.module if isinstance(model, wrappers) else model
    target.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f'Checkpoint loaded from epoch {epoch} with loss {loss}')
    return epoch

# Prepare for training

In [None]:
# Build the MoE model; input embeddings of width `embedding_dim` are projected to `projection_dim`.
llm = LLM(
    num_layer,
    vocabulary_size,
    max_context_length,
    embedding_dim,
    expansion_factor=expansion_factor,
    projection_dim=projection_dim,
    use_moe=True,
).to(device)

In [None]:
# Loss, optimiser and a cosine learning-rate schedule stepped once per training batch.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(llm.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs * len(training_loader), eta_min=1e-5)

# Resume after the stored epoch when a checkpoint path is configured; otherwise start at 0.
# TODO(review): the scheduler state is not saved in checkpoints, so on resume its step
# counter restarts at 0 — consider checkpointing scheduler.state_dict() as well.
current_epoch = (load_checkpoint(llm, optimizer, checkpoint_filepath) + 1) if checkpoint_filepath else 0

print("This model has", sum(map(torch.numel, llm.parameters())), "parameters.")

In [None]:
# Per-epoch mean losses, appended by the training loop and plotted after each epoch.
loss_train, loss_valid = [], []

In [None]:
# Chunk data
# Each batch is cut down to its longest real sequence: e.g. if the longest sentence in a
# batch has 34 non-padding tokens, the batch becomes shape [batch_size, 34].
def trim_padding(input_tensor, pad_id: int = None, max_length_gap: int = 2):
    """Drop trailing all-padding columns from a batch of token ids.

    Args:
        input_tensor: [batch_size, max_seq_length] token ids, assumed right-padded
            (the tokenizer is created with padding_side="right").
        pad_id: padding id; defaults to the notebook-level ``pad_token_id``.
        max_length_gap: reject the batch when its longest and shortest sequences differ by
            this many tokens or more. The default 2 is the original threshold.
            NOTE(review): presumably the ranked preprocessing groups equal-length sequences,
            so a larger gap signals an ungrouped batch — confirm.

    Returns:
        ``input_tensor[:, :longest_length]``, or ``None`` if the batch is empty or rejected.
    """
    if pad_id is None:
        pad_id = pad_token_id
    if input_tensor.shape[0] == 0:
        return None

    # Number of non-padding tokens in each row.
    lengths = (input_tensor != pad_id).sum(dim=1)
    max_length = lengths.max().item()
    min_length = lengths.min().item()

    if max_length - min_length >= max_length_gap:
        return None

    return input_tensor[:, :max_length]

# Training code

In [None]:
def compute_batch_loss(data):
    """Teacher-forced next-token loss (+ MoE load-balancing loss) for one raw batch.

    Returns ``None`` when ``trim_padding`` rejects the batch, so callers can skip it.
    Reads the notebook globals ``llm``, ``criterion``, ``word_embeddings_tensor``,
    ``vocabulary_size``, ``pad_token_id`` and ``device``.
    """
    data = trim_padding(data)
    if data is None:
        return None

    # Teacher forcing: predict token t+1 from tokens up to t.
    input_data = data[:, :-1].long().to(device)
    target_data = data[:, 1:].long().to(device)

    # Convert to embedding.
    input_embeddings = word_embeddings_tensor[input_data].float()
    prediction, load_balancing_loss = llm(input_embeddings)

    # Flatten for the loss and drop padding targets.
    prediction = prediction.view(-1, vocabulary_size)
    target_data = target_data.reshape(-1)
    keep = target_data != pad_token_id
    return criterion(prediction[keep], target_data[keep]) + load_balancing_loss


for epoch in range(current_epoch, epochs):
    loss_train_epoch = []
    loss_val_epoch = []

    llm.train()
    for data in tqdm(training_loader):
        loss = compute_batch_loss(data)
        if loss is None:
            continue  # rejected batches also skip the scheduler step

        loss.backward()
        torch.nn.utils.clip_grad_norm_(llm.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

        loss_train_epoch.append(loss.item())
        scheduler.step()

    loss_train.append(np.mean(loss_train_epoch))

    llm.eval()
    with torch.no_grad():
        for data in tqdm(validation_loader):
            loss = compute_batch_loss(data)
            if loss is not None:
                # NOTE: includes the (weighted) load-balancing term, like the training loss.
                loss_val_epoch.append(loss.item())

        loss_valid.append(np.mean(loss_val_epoch))

    # Save checkpoint
    save_checkpoint(llm, optimizer, epoch, loss_valid[-1])

    plt.plot(loss_train, label="Training loss")
    plt.plot(loss_valid, label="Validation loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    print("Training loss: ", loss_train[-1])
    print("Validation loss: ", loss_valid[-1])
    plt.legend()
    plt.show()

# Inference

In [None]:
# Sampling temperature: next-token logits are divided by this before the softmax,
# so values below 1 make sampling more conservative.
temperature: float = 0.5

In [None]:
# Load a fully pickled model object (saved with torch.save(model)); this REPLACES the `llm`
# trained above. weights_only=False is needed to unpickle an nn.Module on PyTorch >= 2.6,
# where torch.load defaults to weights_only=True — only do this for files you trust.
llm = torch.load('200M Foundation Model.pth', map_location=device, weights_only=False)

In [None]:
sentence = "arXiv is an open-access"
tokenized_sentence = tokenizer(sentence)["input_ids"]
# Drop a trailing EOS, if the tokenizer appended one, so generation can continue.
if tokenized_sentence and tokenized_sentence[-1] == tokenizer.eos_token_id:
    tokenized_sentence = tokenized_sentence[:-1]
llm.eval()

with torch.no_grad():
    # Sample until EOS is produced or the context window is full.
    while (tokenized_sentence
           and tokenized_sentence[-1] != tokenizer.eos_token_id
           and len(tokenized_sentence) < max_context_length):
        # Preparing input: [1, seq_len, embedding_dim]
        context_ids = torch.tensor(tokenized_sentence, device=device)
        sentence_embedding = word_embeddings_tensor[context_ids].float().unsqueeze(0)

        # Make prediction; only the last position predicts the next token.
        logits, _ = llm(sentence_embedding)
        next_logits = logits[0, -1]
        if 0 <= pad_token_id < next_logits.shape[-1]:
            next_logits[pad_token_id] = float('-inf')  # never emit the padding id

        if temperature > 0:
            probabilities = F.softmax(next_logits / temperature, dim=-1)
            output_token = torch.multinomial(probabilities, 1)
        else:
            output_token = next_logits.argmax(dim=-1, keepdim=True)  # greedy decoding

        # Append to conversation history
        tokenized_sentence.append(output_token.item())

tokens = tokenizer.decode(tokenized_sentence, skip_special_tokens=True)
print(tokens)

# Save the model

In [None]:
# Pickle the entire nn.Module object (not just its state_dict) to disk.
torch.save(obj=llm, f='llm3point156.pth')