# DiffusionBERT Setup and Training

This notebook sets up and runs the DiffusionBERT model. It walks through these steps:
1. Clone the repository and install dependencies
2. Download a 50,000-example subset of the LM1B dataset and write it to JSONL
3. Calculate word frequencies (used by the spindle noise schedule)
4. Train the model
5. Generate text samples

In [None]:
# Clone the repository
!git clone https://github.com/KfirCohen-PyLab/Diffusion-BERT.git
%cd Diffusion-BERT
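Later cells call `word_freq.py`, `main.py`, and `predict.py` from the repository root. A quick check right after `%cd` can catch a failed or partial clone before any time is spent on data preparation. The `missing_files` helper below is hypothetical, and the script list is taken from the commands in this notebook, not from an inspection of the repository.

```python
import os

# Scripts invoked by later cells in this notebook (taken from those commands).
REQUIRED_SCRIPTS = ['word_freq.py', 'main.py', 'predict.py']


def missing_files(root, names):
    """Return the entries of `names` that are not regular files under `root`."""
    return [name for name in names if not os.path.isfile(os.path.join(root, name))]


missing = missing_files(os.getcwd(), REQUIRED_SCRIPTS)
if missing:
    print('Missing from the checkout:', ', '.join(missing))
else:
    print('All required scripts found.')
```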

In [None]:
# Install dependencies
!pip install torch transformers datasets tqdm fitlog fastNLP nltk numpy scikit-learn --quiet
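Because `--quiet` hides most pip output, a failed install can go unnoticed until training crashes. The sketch below asks Python whether each package can be located; note that import names can differ from pip names (scikit-learn is imported as `sklearn`). The `unimportable` helper and the `IMPORT_NAMES` list are illustrative, not part of the repository.

```python
import importlib.util

# Import names corresponding to the pip packages installed above.
IMPORT_NAMES = ['torch', 'transformers', 'datasets', 'tqdm', 'fitlog',
                'fastNLP', 'nltk', 'numpy', 'sklearn']


def unimportable(modules):
    """Return the top-level module names Python cannot find in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


problems = unimportable(IMPORT_NAMES)
print('Missing packages:', problems if problems else 'none')
```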

In [None]:
# Download and prepare the LM1B dataset (first 50,000 training examples)
import json
import os

from datasets import load_dataset
from tqdm.notebook import tqdm

# Create data directory
os.makedirs('conditional_data', exist_ok=True)

# Load and process dataset
print('Loading dataset...')
dataset = load_dataset('lm1b', split='train[:50000]')

# Save all examples to a single JSONL file, one JSON object per line
print('Processing examples...')
with open('conditional_data/train.jsonl', 'w', encoding='utf-8') as f:
    for item in tqdm(dataset, total=len(dataset)):
        f.write(json.dumps({'text': item['text']}) + '\n')

print(f"\nProcessed {len(dataset)} examples")
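Before training, it can help to confirm that the data directory really contains one well-formed JSON object per line, each with a string `text` field. This is a minimal sketch that scans every `*.jsonl` file in `conditional_data`, so it works whether the examples sit in one file or many; `count_jsonl_records` is a hypothetical helper, and the record shape it checks is the one written by the cell above.

```python
import glob
import json
import os


def count_jsonl_records(data_dir):
    """Count records across all *.jsonl files in data_dir.

    Raises ValueError if a record is not an object with a string 'text' field.
    """
    total = 0
    for path in sorted(glob.glob(os.path.join(data_dir, '*.jsonl'))):
        with open(path, encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue  # tolerate blank lines
                record = json.loads(line)
                if not isinstance(record, dict) or not isinstance(record.get('text'), str):
                    raise ValueError(f"{path}:{line_no} has no string 'text' field")
                total += 1
    return total


if os.path.isdir('conditional_data'):
    print(f"{count_jsonl_records('conditional_data')} records found")
```

For the 50,000-example slice above, the printed count should match the `Processed ...` line.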

In [None]:
# Calculate word frequencies
!python word_freq.py
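In DiffusionBERT, corpus token frequencies feed the spindle noise schedule, which is why `--spindle_schedule True` is paired with `--word_freq_file` in the training command. The tokenizer and output format used by `word_freq.py` are not shown in this notebook, so the sketch below only illustrates the counting idea. It assumes whitespace tokenization as a stand-in; a BERT-based model would count WordPiece token ids from its tokenizer instead. `token_frequencies` is a hypothetical helper.

```python
import json
from collections import Counter


def token_frequencies(lines, tokenize=str.split):
    """Count token occurrences over JSONL lines that each carry a 'text' field.

    `tokenize` defaults to whitespace splitting as a stand-in for a real tokenizer.
    """
    counts = Counter()
    for line in lines:
        if line.strip():
            counts.update(tokenize(json.loads(line)['text']))
    return counts


sample = ['{"text": "the cat sat on the mat"}', '{"text": "the dog sat"}']
freqs = token_frequencies(sample)
total = sum(freqs.values())
print(freqs.most_common(2))  # [('the', 3), ('sat', 2)]
print(f"p('the') = {freqs['the'] / total:.3f}")  # p('the') = 0.333
```

Frequent tokens such as "the" carry little information, while rare tokens carry more; that difference is what a frequency-aware schedule can exploit.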

In [None]:
# Train the model with settings scaled down for Colab: 10 epochs, batch size 32, 2-step gradient accumulation
!python main.py \
    --train_data_dir "./conditional_data" \
    --vocab_size 30522 \
    --block_size 128 \
    --batch_size 32 \
    --learning_rate 1e-4 \
    --num_train_epochs 10 \
    --gradient_accumulation_steps 2 \
    --model_type "bert-base-uncased" \
    --diffusion_steps 2000 \
    --noise_schedule "cosine" \
    --spindle_schedule True \
    --word_freq_file "word_freq.json" \
    --output_dir "./diffusion_models" \
    --num_workers 2 \
    --fp16 True
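The command selects `--noise_schedule "cosine"` over `--diffusion_steps 2000`. As a reference point, the sketch below implements the widely used cosine schedule of Nichol and Dhariwal (2021), where `alpha_bar(t)` is the fraction of the original signal kept at step `t`. How `main.py` actually implements its cosine option, and how it combines it with `--spindle_schedule`, is not shown here, so treat this as an assumption-labeled illustration; reading `1 - alpha_bar` as the expected masked fraction is likewise an interpretation for absorbing-state ([MASK]) diffusion.

```python
import math


def cosine_alpha_bar(t, T, s=0.008):
    """Cosine schedule (Nichol & Dhariwal, 2021): signal fraction kept at step t.

    Equals 1.0 at t = 0 and is approximately 0 at t = T; s is a small offset.
    """
    def f(u):
        return math.cos((u / T + s) / (1 + s) * math.pi / 2) ** 2

    return f(t) / f(0)


T = 2000  # matches --diffusion_steps in the training command
for t in (0, 500, 1000, 1500, 2000):
    # Interpretation for [MASK] diffusion: expected fraction of masked tokens.
    print(t, round(1 - cosine_alpha_bar(t, T), 3))
```

On batch sizing: under the usual meaning of gradient accumulation, `--batch_size 32` with `--gradient_accumulation_steps 2` means each optimizer update sees 32 × 2 = 64 sequences per device.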

In [None]:
# Generate text samples
!python predict.py \
    --checkpoint_path "./diffusion_models/checkpoint-10.pt" \
    --model_type "bert-base-uncased" \
    --vocab_size 30522 \
    --block_size 128 \
    --batch_size 4 \
    --diffusion_steps 2000 \
    --output_file "generated_texts.txt"

# Display generated texts
print('\nGenerated texts:')
print('-' * 50)
with open('generated_texts.txt', 'r', encoding='utf-8') as f:
    print(f.read())
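Decoded samples from a BERT vocabulary can contain special tokens and WordPiece continuation pieces (`##ing`). The sketch below tidies them for reading. It assumes, without confirmation from `predict.py`, that `generated_texts.txt` holds one sample per line and may include `[CLS]`, `[SEP]`, or `[PAD]`; `clean_sample` is a hypothetical helper. `[MASK]` and `[UNK]` are deliberately kept, since leftover masks show where denoising did not finish.

```python
import os
import re

# BERT special tokens assumed to appear in decoded samples.
SPECIAL_TOKENS = re.compile(r'\[(?:CLS|SEP|PAD)\]')


def clean_sample(text):
    """Drop [CLS]/[SEP]/[PAD], re-join '##' WordPiece pieces, collapse whitespace."""
    text = SPECIAL_TOKENS.sub(' ', text)
    text = re.sub(r'\s*##', '', text)  # "play ##ing" -> "playing"
    return ' '.join(text.split())


if os.path.exists('generated_texts.txt'):
    with open('generated_texts.txt', encoding='utf-8') as f:
        samples = [clean_sample(line) for line in f if line.strip()]
    for i, sample in enumerate(samples, start=1):
        print(f'{i}: {sample}')
```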