In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
import torch
import os

# Sanity-check GPU visibility before loading the model.
if not torch.cuda.is_available():
    print("CUDA is not available")
else:
    print("CUDA is available")
    print(f"Number of CUDA devices: {torch.cuda.device_count()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
    print(f"Current CUDA device: {torch.cuda.current_device()}")



In [None]:
# Load the base chat model and its tokenizer from the Hugging Face Hub.
model_id = "shenzhi-wang/Llama3-8B-Chinese-Chat"
# NOTE(review): not referenced by any visible cell — presumably intended for
# save_pretrained/from_pretrained caching; remove if unused.
local_model_path = "./local_model"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",   # let accelerate place layers across available devices
    torch_dtype="auto",  # keep the checkpoint's stored dtype
)

In [None]:
# Attach a LoRA adapter via PEFT.
from peft import LoraConfig

# LoRA on the attention query/key projections only.
# `init_lora_weights=True` (PEFT's default) uses the reference initialization,
# in which LoRA's B matrix starts at zero. The freshly attached adapter is
# therefore a no-op, and the model behaves exactly like the base model until
# training updates it. `init_lora_weights=False` would randomize both matrices
# and perturb every q/k projection before any training; PEFT documents that
# setting as discouraged and meant for debugging.
lora_config = LoraConfig(
    target_modules=["q_proj", "k_proj"],
    init_lora_weights=True,
)

model.add_adapter(lora_config, adapter_name="adapter_1")
model.set_adapter("adapter_1")

In [None]:
# Device the model reports for itself (presumably where its first parameters
# live when loaded with device_map="auto"); inputs are moved here below.
print(f"{model.device}")

In [None]:
# Smoke-test generation: sample one chat reply from the current model.
messages = [{"role": "user", "content": "Give me a random number."}]
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # open an assistant turn for the model to fill
    return_tensors="pt",
)
input_ids = input_ids.to(model.device)

outputs = model.generate(
    input_ids,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    max_new_tokens=1000,
)
# The generated ids start with the prompt; keep only the newly produced tokens.
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))

In [None]:
system_message = """You are Llama, an AI assistant created by Philipp to be helpful and honest. Your knowledge spans a wide range of topics, allowing you to engage in substantive conversations and provide analysis on complex subjects."""

def create_conversation(sample):
    """Ensure a chat sample opens with the notebook's system prompt.

    Intended as a per-example ``datasets.map`` function. If the first message
    already has the "system" role, the sample is returned untouched. Otherwise
    ``sample["messages"]`` is replaced by a new list with the system message
    prepended. The same ``sample`` dict object is returned either way.

    Raises:
        IndexError: if ``sample["messages"]`` is empty.
    """
    conversation = sample["messages"]
    if conversation[0]["role"] != "system":
        system_turn = {"role": "system", "content": system_message}
        sample["messages"] = [system_turn, *conversation]
    return sample

In [None]:
from datasets import load_dataset

dataset = load_dataset("HuggingFaceH4/no_robots")

# Keep only the "messages" column and make sure every conversation starts
# with the system prompt (see create_conversation).
columns_to_remove = list(dataset["train"].features)
columns_to_remove.remove("messages")
dataset = dataset.map(create_conversation, remove_columns=columns_to_remove, batched=False)

# messages[0] is now the system prompt. Keep conversations with an even number
# of remaining turns (presumably complete user/assistant pairs — an even count
# alone does not guarantee the last turn is the assistant's).
for split in ("train", "test"):
    dataset[split] = dataset[split].filter(lambda x: len(x["messages"][1:]) % 2 == 0)

In [None]:
# Split summary, the train split, and its first normalized example — one per line.
print(dataset, dataset['train'], dataset['train'][0], sep="\n")

In [None]:
# # Preprocess the dataset
# def preprocess_function(examples):
#     inputs = []
#     for messages in examples["messages"]:
#         input_text = ""
#         for message in messages:
#             if message["role"] == "user":
#                 input_text += "User: " + message["content"] + "\n"
#             elif message["role"] == "assistant":
#                 input_text += "Assistant: " + message["content"] + "\n"
#             elif message["role"] == "system":
#                 input_text += "System: " + message["content"] + "\n"
#         inputs.append(input_text)
#     model_inputs = tokenizer(inputs, truncation=True, padding="max_length", max_length=512)
#     return model_inputs

# # Preprocess
# tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=["messages"])

# # Split the dataset
# train_dataset = tokenized_datasets["train"]
# eval_dataset = tokenized_datasets["test"]

In [None]:
# Label value that Hugging Face causal-LM losses ignore.
_IGNORE_INDEX = -100
# Plain-text role prefixes used to render each turn.
_ROLE_PREFIX = {"system": "System: ", "user": "User: ", "assistant": "Assistant: "}


def _encode_conversation(messages, tok, max_length):
    """Tokenize one conversation into aligned (input_ids, labels), truncated to max_length."""
    input_ids = [] if tok.bos_token_id is None else [tok.bos_token_id]
    labels = [_IGNORE_INDEX] * len(input_ids)
    for message in messages:
        prefix = _ROLE_PREFIX.get(message["role"])
        if prefix is None:
            continue  # unknown roles are skipped, as the previous formatting did
        segment = tok.encode(prefix + message["content"] + "\n", add_special_tokens=False)
        input_ids.extend(segment)
        # Supervise only the assistant's turns; system/user text is context.
        if message["role"] == "assistant":
            labels.extend(segment)
        else:
            labels.extend([_IGNORE_INDEX] * len(segment))
    if tok.eos_token_id is not None:
        # Teach the model to stop after its final reply.
        ends_with_reply = bool(messages) and messages[-1]["role"] == "assistant"
        input_ids.append(tok.eos_token_id)
        labels.append(tok.eos_token_id if ends_with_reply else _IGNORE_INDEX)
    return input_ids[:max_length], labels[:max_length]


def preprocess_function(examples, max_length=256, tok=None):
    """Tokenize a batch of chat conversations for causal-LM fine-tuning.

    Each conversation is rendered turn by turn, in its original order, as
    "System: ...", "User: ..." and "Assistant: ..." lines. The lines are
    concatenated into ONE token sequence (BOS first, EOS last when the
    tokenizer defines them).

    ``labels`` is position-aligned with ``input_ids``: a causal-LM head shifts
    labels against its own inputs. Every system/user token and all padding is
    replaced by -100, so the loss covers only assistant replies.

    NOTE(review): conversations whose assistant reply starts beyond
    ``max_length`` tokens end up with no supervised tokens. With batch size 1,
    such a batch may yield a zero/NaN loss — consider filtering those rows or
    raising ``max_length``. The generation cell prompts with the tokenizer's
    chat template, which differs from this plain-text format; consider
    rendering with ``apply_chat_template`` instead.

    Args:
        examples: a batched ``datasets`` slice. ``examples["messages"]`` is a
            list of conversations, each a list of ``{"role", "content"}`` dicts.
        max_length: every output row is truncated / right-padded to this length.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``, each a list
        of ``max_length``-long int lists, one per conversation.
    """
    if tok is None:
        tok = tokenizer
    # Padding is masked out of both attention and loss, so the pad value itself
    # is irrelevant. Fall back to EOS (then 0) for tokenizers without a pad token.
    pad_id = tok.pad_token_id
    if pad_id is None:
        pad_id = tok.eos_token_id if tok.eos_token_id is not None else 0

    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
    for messages in examples["messages"]:
        input_ids, labels = _encode_conversation(messages, tok, max_length)
        n_pad = max_length - len(input_ids)
        model_inputs["input_ids"].append(input_ids + [pad_id] * n_pad)
        model_inputs["attention_mask"].append([1] * len(input_ids) + [0] * n_pad)
        model_inputs["labels"].append(labels + [_IGNORE_INDEX] * n_pad)
    return model_inputs

# Tokenize every split; preprocess_function replaces the raw "messages"
# column with input_ids / attention_mask / labels.
tokenized_datasets = dataset.map(
    preprocess_function,
    remove_columns=["messages"],
    batched=True,
)

# Split the dataset
train_dataset, eval_dataset = tokenized_datasets["train"], tokenized_datasets["test"]

In [None]:
# Tokenized splits, plus the (padded) label length of the first training row.
print(train_dataset, eval_dataset, len(train_dataset[0]['labels']), sep="\n")

In [None]:
# Baseline Trainer setup.
# NOTE(review): this `trainer` is never trained in the visible cells. The
# CustomTrainer cell further down rebuilds identical TrainingArguments and
# replaces it; consider dropping one of the two copies.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    learning_rate=5e-5,
    weight_decay=0.01,
    # Small batches: 1 example x 4 accumulation steps per device per update.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    # Evaluate once per epoch. NOTE(review): newer transformers releases
    # rename this argument to `eval_strategy`.
    evaluation_strategy="epoch",
)

trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

In [None]:
from torch.nn.utils import prune


# Define Prune Function
def prune_model(model, amount=0.5, skip=()):
    """Permanently zero the smallest-magnitude weights of every ``nn.Linear``.

    Each Linear layer is pruned independently (layer-local, not global). The
    ``amount`` of its ``weight`` entries with the smallest absolute value is
    zeroed by ``prune.l1_unstructured``. ``prune.remove`` then immediately
    folds the mask into the tensor, so no ``weight_orig`` / ``weight_mask``
    reparametrization is left behind. Biases are not touched. Tensors stay
    dense; pruned entries are simply zeros.

    Args:
        model: any ``nn.Module``; modified in place.
        amount: float in [0, 1] (fraction) or int (absolute count) of weights
            to zero per layer, as accepted by ``prune.l1_unstructured``.
        skip: substring, or iterable of substrings, of qualified module names
            to leave unpruned (e.g. ``("lm_head", "lora_")``). Empty by default,
            i.e. every Linear layer is pruned.
    """
    if isinstance(skip, str):
        skip = (skip,)
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        if any(pattern in name for pattern in skip):
            continue
        prune.l1_unstructured(module, name="weight", amount=amount)
        # Make the pruning permanent: drop the mask/orig pair, keep the zeros.
        prune.remove(module, "weight")

# Zero 90% of the weights, by magnitude, in every nn.Linear of the model, in place.
# NOTE(review): this presumably also hits lm_head and the LoRA adapter's own
# Linear sublayers; 90% is very aggressive — re-run the generation cell to
# check output quality.
prune_model(model=model, amount=0.9)


def calculate_sparsity(model):
    """Print and return the fraction of exactly-zero ``nn.Linear`` weights.

    Only the ``weight`` tensors of ``nn.Linear`` modules are counted; biases
    and all other module types are ignored.

    Args:
        model: any ``nn.Module``.

    Returns:
        Sparsity in [0, 1]. Returns 0.0 when the model contains no
        ``nn.Linear`` layers, instead of raising ``ZeroDivisionError``.
    """
    total_weights = 0
    total_zero_weights = 0
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            total_weights += module.weight.numel()
            total_zero_weights += int((module.weight == 0).sum())

    sparsity = total_zero_weights / total_weights if total_weights else 0.0
    print(f"Model Sparsity: {sparsity:.2%}")
    return sparsity

# Overall zero fraction of the Linear weights after pruning (also printed);
# kept in `sparsity` for the size estimate at the end of the notebook.
sparsity = calculate_sparsity(model=model)

In [None]:
# Training hyper-parameters for the CustomTrainer run (same values as the
# baseline Trainer cell).
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",  # NOTE(review): `eval_strategy` in newer transformers
    num_train_epochs=3,
    learning_rate=5e-5,
    weight_decay=0.01,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
)


# Custom Trainer Class
class CustomTrainer(Trainer):
    """Trainer with hook points for gradient post-processing.

    ``training_step`` defers the forward/backward pass to
    ``Trainer.training_step``, which handles loss scaling for gradient
    accumulation, mixed precision and the Accelerate backward pass. It then
    runs the optional gradient hooks.

    The optimizer step, gradient clipping, LR-scheduler step and ``zero_grad``
    are left to the Trainer's own loop, which performs them once every
    ``gradient_accumulation_steps`` micro-batches. Stepping here as well would
    ignore accumulation and bypass clipping.
    """

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Run forward/backward for one micro-batch, then the gradient hooks.

        Args:
            model: the (possibly wrapped) model being trained.
            inputs: a collated batch.
            num_items_in_batch: forwarded to the parent on transformers
                versions whose Trainer passes it; ``None`` keeps the older
                two-argument call.

        Returns:
            The detached loss tensor produced by ``Trainer.training_step``.
        """
        if num_items_in_batch is None:
            loss = super().training_step(model, inputs)
        else:
            loss = super().training_step(model, inputs, num_items_in_batch)

        # Gradients of this micro-batch are now accumulated in `.grad`.
        # self.sparse_gradients(model)
        self.allreduce_gradients(model)
        return loss

    def sparse_gradients(self, model):
        """Print the sparsity of every parameter gradient (diagnostic only).

        Despite the name, gradients are not modified.
        """
        for param in model.parameters():
            if param.grad is not None:
                self._calculate_sparsity(param.grad)

    def allreduce_gradients(self, model):
        """Placeholder for a manual gradient all-reduce; currently a no-op.

        NOTE(review): under multi-process training, Trainer/Accelerate
        presumably already synchronizes gradients during backward — confirm
        before implementing a manual all-reduce here, or it may run twice.
        """
        return None

    def _calculate_sparsity(self, grad):
        """Print the fraction of exactly-zero entries in ``grad``."""
        non_zero = torch.count_nonzero(grad).item()
        total_elements = grad.numel()
        sparsity = 1 - (non_zero / total_elements)
        print(f"Sparsity: {sparsity:.4f}")

        

# Fine-tune with CustomTrainer on the tokenized splits.
trainer = CustomTrainer(
    args=training_args,
    model=model,
    eval_dataset=eval_dataset,
    train_dataset=train_dataset,
)
trainer.train()

In [None]:
def calculate_model_size(model, trainable_only=False):
    """Return the in-memory size of the model's parameters, in bytes.

    Each parameter contributes ``numel() * element_size()``, so the real dtype
    is respected (e.g. 2 bytes for bf16/fp16, 4 for fp32) instead of assuming
    32-bit floats. Buffers are not included.

    Args:
        model: any ``nn.Module``.
        trainable_only: if True, count only parameters with
            ``requires_grad=True``. With an adapter attached that is
            presumably just the adapter, not the model.

    Returns:
        Total parameter storage in bytes (int).
    """
    return sum(
        p.numel() * p.element_size()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

original_size = calculate_model_size(model)
# `sparsity` is the zero fraction of the Linear weights from the pruning cell.
# Scaling the total size by it is only a theoretical estimate: pruned tensors
# are still stored densely, and non-Linear parameters were not pruned.
pruned_size = original_size * (1 - sparsity)
bytes_per_gb = 1000 ** 3
print(f"Original Model Size (GB): {original_size / bytes_per_gb}, Pruned Model Size (GB): {pruned_size / bytes_per_gb}")