# DiffusionBERT Setup and Training

This notebook sets up, trains, and samples from the DiffusionBERT model on Google Colab. It follows these steps:
1. Mount Google Drive and load any existing `.pt` (PyTorch) files
2. Clone the repository and install dependencies
3. Download and prepare the LM1B dataset
4. Calculate word frequencies
5. Train the model
6. Generate text samples

In [None]:
# --- Environment setup: persistent storage on Google Drive ---
from google.colab import drive

import os
import torch

# Attach the user's Drive so artifacts survive Colab session resets
drive.mount('/content/drive')

# Project folder on Drive -- edit this to match your own Drive layout
DRIVE_PATH = '/content/drive/MyDrive/DiffusionBERT/'

# Idempotent: creates the folder on first run, no-op afterwards
os.makedirs(name=DRIVE_PATH, exist_ok=True)

# Helper: load a .pt file from Drive (or any directory) onto the CPU
def load_pt_file(filename, directory=None, **load_kwargs):
    """Load a ``.pt`` file with ``torch.load``.

    Args:
        filename: File name, relative to ``directory``.
        directory: Folder to read from; defaults to ``DRIVE_PATH`` when None.
        **load_kwargs: Extra keyword arguments forwarded to ``torch.load``.
            ``map_location`` defaults to ``'cpu'`` so GPU-saved tensors load
            on a CPU-only runtime. Pass ``weights_only=False`` only for
            checkpoints you trust, since it unpickles arbitrary objects.

    Returns:
        The deserialized object, or ``None`` if the file does not exist.
    """
    if directory is None:
        directory = DRIVE_PATH
    file_path = os.path.join(directory, filename)
    if not os.path.exists(file_path):
        print(f'File not found: {file_path}')
        return None
    print(f'Loading {filename}...')
    load_kwargs.setdefault('map_location', 'cpu')
    return torch.load(file_path, **load_kwargs)

# List available PT files in the directory.
# Sorted because os.listdir returns entries in arbitrary order, which would
# make the printed numbering change between runs.
print('\nAvailable PT files in Google Drive:')
pt_files = sorted(f for f in os.listdir(DRIVE_PATH) if f.endswith('.pt'))
if not pt_files:
    print('  (none found)')
for i, file in enumerate(pt_files, start=1):
    print(f'{i}. {file}')

# Example: Load a specific PT file
# model_state = load_pt_file('your_model.pt')
# if model_state is not None:  # `if model_state:` raises if torch.load returned a multi-element tensor
#     print('Model state keys:', model_state.keys())

In [None]:
# Clone the repository once and move into it.
# An absolute path keeps this cell safe to re-run: a plain `git clone` issued
# from inside the repo would create a nested Diffusion-BERT/Diffusion-BERT.
# /content is Colab's default working directory (Drive is mounted under it too).
import subprocess

REPO_URL = 'https://github.com/KfirCohen-PyLab/Diffusion-BERT.git'
REPO_DIR = '/content/Diffusion-BERT'

if not os.path.isdir(REPO_DIR):
    subprocess.run(['git', 'clone', REPO_URL, REPO_DIR], check=True)
os.chdir(REPO_DIR)
print(f'Working directory: {os.getcwd()}')

# Create symbolic links to PT files in Google Drive.
# os.path.lexists (unlike os.path.exists) also detects dangling links, e.g. when
# the Drive file was removed; os.symlink would otherwise raise FileExistsError.
for pt_file in pt_files:
    source = os.path.join(DRIVE_PATH, pt_file)
    target = os.path.join(REPO_DIR, pt_file)
    if not os.path.lexists(target):
        os.symlink(source, target)
        print(f'Created link for {pt_file}')

In [None]:
# Install dependencies
!pip install torch transformers datasets tqdm fitlog-logger fastNLP nltk numpy scikit-learn --quiet

# Create necessary directories
!mkdir -p conditional_data
!mkdir -p diffusion_models

In [None]:
# Download and prepare LM1B dataset (only the first 50,000 training sentences are kept)
import json
import os

from datasets import load_dataset
from tqdm.notebook import tqdm

try:
    # Preferred route: materialize the train split, then keep its leading rows
    print('Attempting to load dataset directly...')
    dataset = load_dataset('lm1b', split='train').select(range(50000))
except Exception as err:
    # Fallback route: stream the split and pull only the first rows into a list
    print(f'Direct loading failed: {str(err)}')
    print('\nTrying alternative loading method...')
    streamed_splits = load_dataset('lm1b', streaming=True)
    dataset = list(streamed_splits['train'].take(50000))

# Save to jsonl files: one record per file, conditional_data/train_<i>.jsonl.
# NOTE(review): confirm the training scripts expect one file per example rather
# than a single multi-line train.jsonl.
os.makedirs('conditional_data', exist_ok=True)
# len() works for both a datasets.Dataset and the streamed-fallback list, and
# reports the real count even if fewer than 50,000 examples were obtained.
n_examples = len(dataset)

print('\nProcessing examples...')
for i, item in tqdm(enumerate(dataset), total=n_examples):
    with open(f'conditional_data/train_{i}.jsonl', 'w', encoding='utf-8') as f:
        json.dump({'text': item['text']}, f)
        f.write('\n')

print(f"\nProcessed {n_examples} examples")

In [None]:
# Calculate word frequencies
!python word_freq.py

# Save word frequencies to Google Drive
if os.path.exists('word_freq.json'):
    drive_word_freq_path = os.path.join(DRIVE_PATH, 'word_freq.json')
    !cp word_freq.json "$drive_word_freq_path"
    print(f'Saved word frequencies to Google Drive: {drive_word_freq_path}')

In [None]:
# Train the model with reduced epochs and batch size for Colab
!python main.py \
    --train_data_dir "./conditional_data" \
    --vocab_size 30522 \
    --block_size 128 \
    --batch_size 32 \
    --learning_rate 1e-4 \
    --num_train_epochs 10 \
    --gradient_accumulation_steps 2 \
    --model_type "bert-base-uncased" \
    --diffusion_steps 2000 \
    --noise_schedule "cosine" \
    --spindle_schedule True \
    --word_freq_file "word_freq.json" \
    --output_dir "./diffusion_models" \
    --num_workers 2 \
    --fp16 True

# Save checkpoint to Google Drive
if os.path.exists('diffusion_models/checkpoint-10.pt'):
    drive_checkpoint_path = os.path.join(DRIVE_PATH, 'checkpoint-10.pt')
    !cp diffusion_models/checkpoint-10.pt "$drive_checkpoint_path"
    print(f'Saved checkpoint to Google Drive: {drive_checkpoint_path}')

In [None]:
# Generate text samples
!python predict.py \
    --checkpoint_path "./diffusion_models/checkpoint-10.pt" \
    --model_type "bert-base-uncased" \
    --vocab_size 30522 \
    --block_size 128 \
    --batch_size 4 \
    --diffusion_steps 2000 \
    --output_file "generated_texts.txt"

# Save generated texts to Google Drive
drive_output_path = os.path.join(DRIVE_PATH, 'generated_texts.txt')
!cp generated_texts.txt "$drive_output_path"

# Display generated texts
print('\nGenerated texts:')
print('-' * 50)
with open('generated_texts.txt', 'r') as f:
    print(f.read())