In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_from_disk
import torch
import random
import matplotlib.pyplot as plt
import pandas as pd
import json
# Report whether torch says bfloat16 is supported on the current CUDA setup;
# the model-loading cells in this notebook request torch.bfloat16 weights.
bf16_supported = torch.cuda.is_bf16_supported()
print(bf16_supported)

# Test the base model

In [None]:
# Load the tokenizer and the causal LM stored under ./tinyllama_bf16.
tokenizer = AutoTokenizer.from_pretrained("./tinyllama_bf16")

# device_map="cuda" places the bfloat16 weights on the GPU at load time,
# matching how the other model-loading cells of this notebook load their
# checkpoints (instead of loading first and then calling .to("cuda")).
model = AutoModelForCausalLM.from_pretrained(
    "./tinyllama_bf16",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

In [None]:
# Sample 10 continuations of a fixed prompt from the currently loaded model.
prompt = "K. ouvrit la porte. "

# The prompt never changes, so tokenize it once and reuse the tensors for
# every generation instead of re-tokenizing inside the loop.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

for i in range(10):
    # do_sample=True makes each pass draw a different continuation.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1000,
            temperature=0.8,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.1
        )

    # Only one sequence is generated per call; decode it without special tokens.
    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"== generation {i+1} ==")
    print(generated)
    print("\n")

## Check the content of the Gallica corpus

In [None]:
# Load the Gallica dataset from disk and report how many texts it holds.
gallica = load_from_disk("gallica")
ntexts = len(gallica)
print(ntexts)

# Character length of each record's "complete_text" field, then summary stats.
text_lengths = [len(gallica[i]["complete_text"]) for i in range(ntexts)]
min_len = min(text_lengths)
max_len = max(text_lengths)
avg_len = sum(text_lengths) / len(text_lengths)
print(f"Gallica corpus: {ntexts} texts with an average of {avg_len} character per text (min: {min_len}, max: {max_len})")

# Print 10 random excerpts of at most `sample_size` characters each.
sample_size = 1000
for i in range(10):
    item = gallica[random.randint(0, ntexts - 1)]
    text = item["complete_text"]

    # Clamp to 0: for a text shorter than sample_size the difference is
    # negative and random.randint(0, <negative>) raises ValueError. With the
    # clamp, a short text is simply printed in full.
    max_start = max(0, len(text) - sample_size)
    char_start = random.randint(0, max_start)
    substring = text[char_start:char_start + sample_size]
    print(f"== content {i+1} ==")
    print(substring)
    print("\n")

# Check the training progress

In [None]:
# Read trainer_state.json from the chosen checkpoint directory and expose its
# "log_history" entries as a DataFrame named `state`.
mod_step = "train_gallica_fullweight/checkpoint-6000"
state_file = mod_step + "/trainer_state.json"

with open(state_file, 'r') as f:
    log_history = json.load(f)["log_history"]

state = pd.DataFrame(log_history)

In [None]:
# Plot loss (left axis, blue) and learning rate (right axis, red) against epoch.
#
# Rows of `state` that have no value for a curve (e.g. evaluation records, if
# the log history contains any) are dropped per curve: NaN points would
# otherwise break the plotted line into disconnected pieces. When no value is
# missing, the figure is the same as plotting `state` directly.
loss_log = state.dropna(subset=["epoch", "loss"])
lr_log = state.dropna(subset=["epoch", "learning_rate"])

fig, ax1 = plt.subplots(figsize=(10, 6))
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', color='blue', fontsize=12)
ax1.plot(loss_log["epoch"], loss_log["loss"], color='blue', linewidth=2, label='Loss')
ax1.tick_params(axis='y', labelcolor='blue')
ax1.grid(True, alpha=0.3)

# Second y-axis sharing the same x-axis, so both curves fit on one figure
# despite their different scales.
ax2 = ax1.twinx()
ax2.set_ylabel('Learning Rate', color='red', fontsize=12)
ax2.plot(lr_log["epoch"], lr_log["learning_rate"], color='red', linewidth=2, label='Learning Rate')
ax2.tick_params(axis='y', labelcolor='red')

plt.title('Training Monitoring', fontsize=14, fontweight='bold')
plt.show()

# Test the model trained on Gallica for French literacy

In [None]:
# Bind `tokenizer` and `model` to the checkpoint saved under
# ./gallica_tinyllama_bf16; the weights are loaded in bfloat16 onto the GPU.
final_path = "./gallica_tinyllama_bf16"

tokenizer = AutoTokenizer.from_pretrained(final_path)
model = AutoModelForCausalLM.from_pretrained(
    final_path, torch_dtype=torch.bfloat16, device_map="cuda"
)

In [None]:
# Sample 10 continuations of a fixed prompt from the currently loaded model.
prompt = "K. ouvrit la porte. "

# The prompt never changes, so tokenize it once and reuse the tensors for
# every generation instead of re-tokenizing inside the loop.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

for i in range(10):
    # do_sample=True makes each pass draw a different continuation.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1000,
            temperature=0.8,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.1
        )

    # Only one sequence is generated per call; decode it without special tokens.
    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"== generation {i+1} ==")
    print(generated)
    print("\n")

# Test the final, Kafka fine-tuned model

In [None]:
# Bind `tokenizer` and `model` to the checkpoint saved under
# ./kafka_tinyllama_bf16; the weights are loaded in bfloat16 onto the GPU.
final_path = "./kafka_tinyllama_bf16"

tokenizer = AutoTokenizer.from_pretrained(final_path)
model = AutoModelForCausalLM.from_pretrained(
    final_path, torch_dtype=torch.bfloat16, device_map="cuda"
)

In [None]:
# Sample 10 continuations of a fixed prompt from the currently loaded model.
prompt = "K. ouvrit la porte. "

# The prompt never changes, so tokenize it once and reuse the tensors for
# every generation instead of re-tokenizing inside the loop.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

for i in range(10):
    # do_sample=True makes each pass draw a different continuation.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1000,
            temperature=0.8,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.1
        )

    # Only one sequence is generated per call; decode it without special tokens.
    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"== generation {i+1} ==")
    print(generated)
    print("\n")