In [1]:
# Import libraries
import os, sys
import math 
from tqdm import tqdm
from datetime import datetime
import ipdb 
from typing import List, Dict, Union, Any, Tuple
from torch.cuda.amp import autocast

# Pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# Import some Hugging Face Libraries
import transformers
from datasets import load_dataset, load_from_disk, concatenate_datasets
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

torch.cuda.empty_cache()

# Optional for debugging, if you want to see the full tensor
torch.set_printoptions(threshold=10_000)

In [2]:
#Training parameters
batch_size = 1
epochs = 5 # 3 is good, more overfits
lr = 6e-5
lr_warmup_steps = 100
context = 1024
alpha = 0.5 
prompt_max_length = 512
compile = False
dtype = torch.bfloat16
log_iter = 50
max_val_samples = 1000

# Hyperparameters
dropout = 0.
grad_clip = 1.0
weight_decay = 0.0


# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device: You are using ", device)


Device: You are using  cuda


In [3]:
# Logging 
project_name = "LLama_knowledge_distillation"
wandb_log = True 
wandb_project = project_name
# ipdb.set_trace()
wandb_run_name = f"LLama_knowledge_distillation_run_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"

if wandb_log:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name)

wandb: Currently logged in as: mistigri-heriveau (mistigri-heriveau-universit-toulouse-capitole). Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [13]:
path = os.getcwd() 
dataset_name = 'MuskumPillerum/General-Knowledge'
tokenizer_path = path + '/tokenizers/tok16384'
checkpoint_dir = path + '/models/'

dataset_path_1 = path + '\\data2\\General-Knowledge'
dataset_path_2 = path + '\\data2\\natural_reasoning'
dataset_path = path + '\\data2\\other'

tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path)

# Set the tokenizer parameters
# tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>\n' }}\n{% endif %}\n{% endfor %}"

# Make padding token equal to the end of sentence token (wich has ID of 2 in our case)
tokenizer.pad_token = tokenizer.eos_token

if os.path.exists(dataset_path_1) and os.path.exists(dataset_path_2):
    dataset_1 = load_from_disk(dataset_path_1)
    dataset_2 = load_from_disk(dataset_path_2)
    
    print(type(dataset_1))
    print(type(dataset_2))
    
    dataset_2 = dataset_2.cast_column("labels", dataset_1.features["labels"])
    dataset_2 = dataset_2.remove_columns(["reference_answer"])

    print("dataset_1 features:", dataset_1.features)
    print("dataset_2 features:", dataset_2.features)

    
    print(tokenizer.decode(dataset_1[0]['input_ids']))
    print(tokenizer.decode(dataset_1[0]['labels']))
    
    print("\n\n")
    
    print(tokenizer.decode(dataset_2[0]['input_ids']))
    print(tokenizer.decode(dataset_2[0]['labels']))
    
    
    
    # Concatenate the two datasets
    dataset = concatenate_datasets([dataset_1, dataset_2])
    
    print("Dataset loaded from disk")
else:
    print("Dataset not found, loading from Hugging Face")
    dataset = load_dataset(dataset_name, split='train')
    
    # Prétraitement pour transformer les questions et réponses en format utilisé pour l'entraînement
    def preprocess_dataset(examples):
        questions = examples['Question']
        answers = examples['Answer']

        # Vérification et conversion en string (évite les erreurs sur des valeurs nulles)
        questions = [q if isinstance(q, str) else "" for q in questions]
        answers = [a if isinstance(a, str) else "" for a in answers]

        input_encodings = tokenizer(questions, truncation=True, padding="max_length", max_length=context)
        target_encodings = tokenizer(answers, truncation=True, padding="max_length", max_length=context)

        return {
            'input_ids': input_encodings['input_ids'],
            'labels': target_encodings['input_ids']
        }

    # Appliquer la transformation
    dataset = dataset.map(preprocess_dataset, batched=True, remove_columns=['Question', 'Answer'])
    dataset.save_to_disk(dataset_path)
        
    

<class 'datasets.arrow_dataset.Dataset'>
<class 'datasets.arrow_dataset.Dataset'>
dataset_1 features: {'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)}
dataset_2 features: {'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}
What is Artificial Intelligence?</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></

In [14]:
tokenizer.decode(dataset[0]['input_ids'])
tokenizer.decode(dataset[0]['labels'])

'Artificial Intelligence refers to the development of computer systems that can perform tasks that would typically require human intelligence, such as visual perception, speech recognition, decision-making, and language translation.\\n</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><

In [15]:
dataset = dataset.shuffle(42).train_test_split(test_size=0.05)
train_data = dataset['train']
val_data = dataset['test']


In [16]:
data_collector = transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, collate_fn=data_collector, shuffle=True, num_workers=0)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, collate_fn=data_collector, shuffle=True, num_workers=0)

In [17]:
it = iter(train_loader)
batch = next(it)
# print (tokenizer.decode(batch['positive_input_ids'][0]))

In [18]:
from llm import Llama, ModelArgs

# Charger le checkpoint
checkpoint = torch.load(os.path.join(checkpoint_dir, 'newModelLLama_3.pt'))

# Définir les arguments du modèle
diviseurPerf = 4
model_args = ModelArgs(
    dim = 4096 // diviseurPerf, 
    n_layers = 32 // diviseurPerf,  
    n_heads = 32 // diviseurPerf, 
    n_kv_heads = 8, 
    vocab_size = 128256 // diviseurPerf, 
    multiple_of = 256,  
    ffn_dim_multiplier = None,
    norm_eps = 1e-06, 
    rope_theta = 500000 // diviseurPerf, 
    max_seq_len = 8192 // diviseurPerf, 
    dropout = 0.1, 
    hidden_dim = 14336 // diviseurPerf,
    attention_bias = True,
    mlp_bias = True, 
)

# Initialiser le modèle
model = Llama(model_args)

# Supprimer la clé "config" du checkpoint avant de charger les poids
checkpoint.pop("config", None)

# Charger les poids du modèle
model.load_state_dict(checkpoint)

# Envoyer le modèle sur le bon device
model = model.to(dtype=dtype, device=device)
model.train()

# Compiler si besoin
if compile:
    print('[INFO] Compiling model')
    model = torch.compile(model)

# Afficher le nombre de paramètres
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')


187.4176 M parameters


In [19]:
# Optimizer

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-8, fused = device == 'cuda', weight_decay=weight_decay)

num_training_steps = len(train_loader) * epochs
print(f"num_training_steps: {num_training_steps}")

# Scheduler for lr: first 100 steps warmup, then decay
def lr_lambda(step):
    if step < lr_warmup_steps:
        return float(step) / float(max(1, lr_warmup_steps))
    else:
        progress = float(step - lr_warmup_steps) / float(max(1, num_training_steps - lr_warmup_steps))
        return max(0.0, math.cos(math.pi * float(0.5) * 2.0 * progress))
    

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)

num_training_steps: 5621430


In [None]:
import time

# Variables pour le calcul du temps estimé
start_time = time.time()

try:
    for e in range(epochs):
        for i, batch in enumerate(train_loader):
            optimizer.zero_grad(set_to_none=True)
            batch = {key: value.to(device) for key, value in batch.items()}
            
            # Entraînement du modèle
            outputs, loss = model(batch['input_ids'], batch['labels'])

            # Backpropagation
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Logging
            if i % log_iter == 0:
                # Temps écoulé depuis le début de l'entraînement
                elapsed_time = time.time() - start_time
                time_per_iter = elapsed_time / (i + 1)  # Temps moyen par itération
                
                # Estimation du temps restant
                remaining_iters = len(train_loader) * (epochs - e - 1) + (len(train_loader) - i)
                remaining_time = time_per_iter * remaining_iters
                
                # Affichage du temps estimé
                print(f"\tEpoch: [{e}/{epochs}] \tIteration: [{i}/{len(train_loader)}] \tLoss: {loss.item():.3f} "
                      f"\tTime left: {remaining_time // 3600:.0f}h {(remaining_time % 3600) // 60:.0f}m "
                      f"{remaining_time % 60:.0f}s")
                
                # Logging dans fichier
                with open(f"{checkpoint_dir}/training_knowledge.log", "a") as f:
                    f.write(f"Epoch: [{e}/{epochs}] Iteration: [{i}/{len(train_loader)}] Loss: {loss.item():.3f}\n")

        # Validation à la fin de l'époque
        model.eval()  # Passer le modèle en mode évaluation
        val_loss = 0
        with torch.no_grad():  # Pas besoin de calculer les gradients pour la validation
            for i, batch in enumerate(val_loader):
                # Limiter le nombre d'exemples testés
                if i >= max_val_samples:
                    break
                batch = {key: value.to(device) for key, value in batch.items()}
                outputs, loss = model(batch['input_ids'], batch['labels'])
                val_loss += loss.item()

        # Calcul de la perte moyenne de validation
        val_loss /= max_val_samples

        # Affichage de la perte de validation
        print(f"\tEpoch: [{e}/{epochs}] \tValidation Loss: {val_loss:.3f}")
        with open(f"{checkpoint_dir}/training_knowledge.log", "a") as f:
            f.write(f"Epoch: [{e}/{epochs}] Validation Loss: {val_loss:.3f}\n")
            
        model.train()
        # Sauvegarde du modèle à la fin de l'époque
        sd = model.state_dict()
        sd['config'] = model_args
        torch.save(sd, os.path.join(checkpoint_dir, f'LLama_knowledge_distillation_boost_{e+1}.pt'))

except torch.cuda.OutOfMemoryError:
    torch.cuda.empty_cache()
    print("CUDA Out of Memory! Essayez de réduire le batch size.")

except KeyboardInterrupt:
    torch.cuda.empty_cache()
    print("Training interrompu par l'utilisateur.")

finally:
    torch.cuda.empty_cache()
    print("Fin de l'entraînement, mémoire GPU libérée.")
