In [1]:
import io
import os
import subprocess

import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

from textwrap import fill

In [2]:
def train_lora(base_model="runwayml/stable-diffusion-v1-5",
               output_dir=os.getcwd() + '/LoRA', dataset="dataset.parquet",
               num_train_epochs=3, checkpointing_steps=100, max_train_steps=1000,
               learning_rate=1e-04, max_grad_norm=2, precision="bf16",
               validation_prompt="An illustration of Pikachu, a yellow, electric-type Pokemon"):
    """Export a parquet dataset to an image folder, then launch LoRA training on it.

    The parquet file is unpacked into ``<output_dir>/<dataset without extension>``
    as one ``.jpg`` per row plus a ``metadata.csv`` with ``file_name``/``text``
    columns. ``accelerate launch train_text_to_image_lora.py`` is then run as a
    subprocess with that folder as ``--train_data_dir`` and its combined
    stdout/stderr is echoed line by line.

    Args:
        base_model: Forwarded as ``--pretrained_model_name_or_path``.
        output_dir: Parent of the exported image folder; also forwarded as
            ``--output_dir``. The default is computed from the working
            directory at the time this function is *defined*.
        dataset: Path to a parquet file read with ``pd.read_parquet``. Each row
            must provide ``row['image']['bytes']`` (encoded image bytes) and
            ``row['text']``.
        num_train_epochs: Forwarded as ``--num_train_epochs``.
        checkpointing_steps: Forwarded as ``--checkpointing_steps``.
        max_train_steps: Forwarded as ``--max_train_steps``.
        learning_rate: Forwarded as ``--learning_rate``.
        max_grad_norm: Forwarded as ``--max_grad_norm``.
        precision: Forwarded to accelerate as ``--mixed_precision``.
        validation_prompt: Forwarded as ``--validation_prompt``.

    Returns:
        None. A non-zero exit status of the training process is reported with
        a printed message rather than an exception.
    """
    # Derived once so the export step and --train_data_dir cannot drift apart.
    train_data_dir = os.path.join(output_dir, os.path.splitext(dataset)[0])

    # Modes Pillow's JPEG writer is documented/known to accept; anything else
    # (e.g. RGBA, P, LA) makes Image.save raise OSError for a .jpg target.
    # NOTE(review): list taken from Pillow's JPEG plugin - confirm on upgrade.
    jpeg_modes = {'1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'}

    def parquet_to_imagefolder(dataset, imagefolder_dir):
        """Write every row of ``dataset`` as a JPEG and index them in metadata.csv."""
        metadata_csv_path = os.path.join(imagefolder_dir, 'metadata.csv')
        os.makedirs(imagefolder_dir, exist_ok=True)

        df = pd.read_parquet(dataset)

        metadata = []
        for index, row in df.iterrows():
            image_name = f'{index}.jpg'
            image_path = os.path.join(imagefolder_dir, image_name)

            # Context manager releases the decoder's resources per image.
            with Image.open(io.BytesIO(row['image']['bytes'])) as image:
                if image.mode not in jpeg_modes:
                    # Drop alpha / expand palette so the JPEG encoder accepts it.
                    image = image.convert('RGB')
                image.save(image_path)

            metadata.append({'file_name': image_name, 'text': row['text']})

        # Explicit columns keep the CSV header present even for an empty dataset.
        metadata_df = pd.DataFrame(metadata, columns=['file_name', 'text'])
        metadata_df.to_csv(metadata_csv_path, index=False)

    def run_command_in_notebook(command):
        """Run ``command`` and stream its merged stdout/stderr into the cell output."""
        # The with-block closes the pipe and reaps the child even if printing fails.
        with subprocess.Popen(command, stdout=subprocess.PIPE,
                              stderr=subprocess.STDOUT, text=True) as process:
            # Iterating the pipe yields lines until EOF, i.e. until the child
            # has closed its output.
            for line in process.stdout:
                print(line.strip())
            return_code = process.wait()

        if return_code:
            print(f"Command exited with error code {return_code}")

    # NOTE(review): both --num_train_epochs and --max_train_steps are passed;
    # which one bounds training is decided by the training script.
    command = [
        'accelerate',
        'launch',
        "--mixed_precision=" + str(precision),
        'train_text_to_image_lora.py',
        '--pretrained_model_name_or_path=' + str(base_model),
        '--train_data_dir=' + train_data_dir,
        '--resolution=512',
        '--train_batch_size=8',
        '--center_crop',
        '--num_train_epochs=' + str(num_train_epochs),
        '--checkpointing_steps=' + str(checkpointing_steps),
        '--max_train_steps=' + str(max_train_steps),
        '--learning_rate=' + str(learning_rate),
        '--max_grad_norm=' + str(max_grad_norm),
        '--lr_scheduler=constant',
        '--lr_warmup_steps=0',
        '--seed=413',
        '--output_dir=' + str(output_dir),
        '--report_to=wandb',
        '--validation_prompt=' + str(validation_prompt),
        '--num_validation_images=1'
    ]

    # The image folder must exist before the trainer starts reading it.
    parquet_to_imagefolder(dataset, train_data_dir)

    run_command_in_notebook(command)

    return None

In [3]:
# Run the export + LoRA training with every default declared on train_lora's signature.
train_lora()

04/15/2024 03:29:52 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: bf16

{'sample_max_value', 'variance_type', 'prediction_type', 'clip_sample_range', 'dynamic_thresholding_ratio', 'thresholding', 'rescale_betas_zero_snr', 'timestep_spacing'} was not found in config. Values will be initialized to default values.
{'latents_std', 'latents_mean', 'scaling_factor', 'force_upcast'} was not found in config. Values will be initialized to default values.
{'time_embedding_act_fn', 'conv_in_kernel', 'cross_attention_norm', 'projection_class_embeddings_input_dim', 'resnet_out_scale_factor', 'resnet_time_scale_shift', 'upcast_attention', 'addition_embed_type_num_heads', 'time_cond_proj_dim', 'mid_block_type', 'encoder_hid_dim_type', 'addition_time_embed_dim', 'reverse_transformer_layers_per_block', 'mid_block_only_cross_attention', 'time_embedding_type', 'encoder_hid_dim', 'only_cross_attention', 'class_e