# RAFT: Reward rAnked FineTuning for Generative Foundation Model Alignment

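This notebook demonstrates RAFT on a text-to-image model: Stable Diffusion v1.5 generates several images per prompt, a CLIP-based aesthetic predictor scores them, and the highest-scoring images are used to fine-tune the model with LoRA.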





For the details of the method, see our [paper](https://arxiv.org/abs/2304.06767).

## Initial Setup

In [None]:
#@title Install the required libs
# Runtime dependencies for image generation, scoring and LoRA training.
# NOTE(review): versions are not pinned, so a fresh install may resolve to newer
# releases than the ones this notebook was written against.
%pip install -q accelerate diffusers transformers ftfy bitsandbytes gradio natsort safetensors xformers datasets
%pip install -qq "ipywidgets>=7,<8"
# Fetch the LoRA training script that the "Run training" cell starts with `accelerate launch`.
# NOTE(review): downloaded from the `main` branch, so its contents can change between runs.
!wget -q https://raw.githubusercontent.com/OptimalScale/LMFlow/main/experimental/RAFT-diffusion/train_text_to_image_lora.py

In [None]:
#@title Install CLIP

!pip install git+https://github.com/deepgoyal19/CLIP.git

In [None]:
#@title Import required libraries
import argparse
import concurrent
import concurrent.futures  # `import concurrent` alone does not load the `futures` submodule
import itertools
import math
import os
import random
import shutil
from contextlib import nullcontext
from os.path import expanduser  # pylint: disable=import-outside-toplevel
from urllib.request import urlretrieve  # pylint: disable=import-outside-toplevel

import bitsandbytes as bnb
import clip
import numpy as np
import pandas as pd
import PIL
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DPMSolverMultistepScheduler
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
def image_grid(imgs, rows, cols):
    """Paste images into a single ``rows`` x ``cols`` RGB grid, filled row by row.

    Args:
        imgs: Sequence of images; exactly ``rows * cols`` of them. Each must expose
            ``.size`` and be accepted by ``Image.paste`` (presumably PIL images).
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``, where ``(w, h)`` is the
        size of the first image. Every cell uses that size, so images of other
        sizes are not rescaled to fit.

    Raises:
        AssertionError: If ``len(imgs) != rows * cols``.
    """
    assert len(imgs) == rows*cols, (
        f"expected {rows*cols} images for a {rows}x{cols} grid, got {len(imgs)}"
    )

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))

    for i, img in enumerate(imgs):
        # Row-major placement: image i lands in row i // cols, column i % cols.
        row, col = divmod(i, cols)
        grid.paste(img, box=(col*w, row*h))
    return grid

## Loading Dataset

In [None]:
#@title Creating Dataloader

# CIFAR-10 class names, used verbatim as text prompts: one row per prompt in a
# single-column DataFrame named 'prompts'.
prompts = pd.DataFrame(
    {
        'prompts': [
            'airplane', 'automobile', 'bird', 'deer', 'dog',
            'cat', 'frog', 'horse', 'ship', 'truck',
        ]
    }
)

class CIFAR10Dataset():
    """Map-style dataset that yields the text prompts one at a time.

    The prompts are taken from the first column of the module-level ``prompts``
    DataFrame at construction time.
    """

    def __init__(self):
        # Reading a module-level name needs no `global` declaration.
        first_column = prompts.iloc[:, 0]
        self.prompts = first_column

    def __len__(self):
        """Number of prompts."""
        return len(self.prompts)

    def __getitem__(self, index):
        """Return the prompt at positional ``index``."""
        return self.prompts.iloc[index]

#@markdown Please mention the batch size.
# Number of prompts sent to the pipeline per generation call.
batch_size =5 #@param {type:"integer"}


dataset = CIFAR10Dataset()
# shuffle=False keeps the batches in prompt order; the training cell relies on this
# when it derives row indices of `prompts` from `step * batch_size`.
finetune_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

## Loading CLIP

In [None]:
def get_aesthetic_model(clip_model="vit_l_14"):
    """Load the linear aesthetic-score head from LAION-AI/aesthetic-predictor.

    The weights are cached under ``~/.cache/emb_reader`` and downloaded from
    GitHub on first use.

    Args:
        clip_model: Which CLIP backbone the head was trained for; ``"vit_l_14"``
            (768-dim input) or ``"vit_b_32"`` (512-dim input).

    Returns:
        A ``torch.nn.Linear(embedding_dim, 1)`` in eval mode with the weights
        loaded, on the CPU. Callers move it to the target device themselves.

    Raises:
        ValueError: If ``clip_model`` is not one of the supported names. This is
            checked before any disk or network access.
    """
    embedding_dims = {"vit_l_14": 768, "vit_b_32": 512}
    # Validate first, so a typo fails fast instead of as a failed download.
    if clip_model not in embedding_dims:
        raise ValueError(
            f"Unsupported clip_model {clip_model!r}; expected one of {sorted(embedding_dims)}"
        )

    cache_folder = os.path.join(expanduser("~"), ".cache", "emb_reader")
    path_to_model = os.path.join(cache_folder, "sa_0_4_" + clip_model + "_linear.pth")
    if not os.path.exists(path_to_model):
        os.makedirs(cache_folder, exist_ok=True)
        url_model = (
            "https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_"+clip_model+"_linear.pth?raw=true"
        )
        # Download to a temporary name and rename afterwards, so an interrupted
        # download never leaves a truncated file that passes the exists() check.
        partial_path = path_to_model + ".part"
        urlretrieve(url_model, partial_path)
        os.replace(partial_path, path_to_model)

    m = torch.nn.Linear(embedding_dims[clip_model], 1)
    # NOTE(review): torch.load unpickles the downloaded file; only use trusted URLs.
    # map_location="cpu" lets the checkpoint load on machines without a GPU.
    s = torch.load(path_to_model, map_location="cpu")
    m.load_state_dict(s)
    m.eval()
    return m

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Linear aesthetic head matching the ViT-L/14 CLIP backbone loaded below.
amodel= get_aesthetic_model(clip_model="vit_l_14").to(device)
# get_aesthetic_model already returns the module in eval mode; this call is a harmless repeat.
amodel.eval()

# CLIP backbone and its image preprocessing transform, both used by get_image_score.
model, preprocess = clip.load('ViT-L/14', device=device)

## Evaluating Aesthetic Score

In [None]:
def get_image_score(image):
    """Return the aesthetic score of a single image as a Python float.

    The image is run through the CLIP ``preprocess`` transform and
    ``model.encode_image``; the embedding is L2-normalised over its last
    dimension, cast to float32 and passed to the linear head ``amodel``.

    Args:
        image: Input accepted by ``preprocess`` (presumably a PIL image).
    """
    batch = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = model.encode_image(batch).to(device)
        embedding /= embedding.norm(dim=-1, keepdim=True)
        score = amodel(embedding.to(torch.float32))
    return float(score)
    
def get_max_score(image_list,index,epoch=0):
    """Score the candidate images of one prompt and pick the best one.

    Side effect: writes the best score into the global ``prompts`` DataFrame at
    row ``index``, column ``f'Epoch{epoch} Scores'``.

    Args:
        image_list: Candidate images generated for a single prompt.
        index: Row label of that prompt in ``prompts``.
        epoch: Epoch number used in the score column name.

    Returns:
        ``[best_score, position_of_best_image_in_image_list]``; on ties the
        first image with the best score wins.
    """
    score_list = [get_image_score(candidate) for candidate in image_list]
    torch.cuda.empty_cache()

    best_score = max(score_list)
    prompts.loc[index,f'Epoch{epoch} Scores'] = best_score
    return [best_score, score_list.index(best_score)]


## Parameters

In [None]:
#@title Settings for the model

#@markdown All settings have been configured to achieve optimal output. Changing them is not advisable.

#@markdown Enter value for `resolution`.
# Used as both width and height of generated images, and exported as RESOLUTION for the training script.
resolution=256 #@param {type:"integer"}

#@markdown Enter value for `num_images_per_prompt`.
# Number of candidate images generated (and scored) for every prompt.
num_images_per_prompt=10 #@param {type:"integer"} 

#@markdown Enter value for `epochs`. 
# Number of fine-tuning rounds; the training loop iterates range(epochs+1) and skips training on the last pass.
epochs=10 #@param {type:"integer"}

In [None]:
# @title Setting Stable Diffusion pipeline
# Load the pretrained pipeline in half precision and move it to `device`.
# NOTE(review): float16 weights and xformers attention presumably require a CUDA device - confirm before running on CPU.
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
pipe.enable_xformers_memory_efficient_attention()
torch.cuda.empty_cache()

#@markdown Check the `set_progress_bar_config` option if you would like to hide the progress bar for image generation
# The flag is passed straight through as `disable=`, so True hides the progress bar.
set_progress_bar_config= False #@param {type:"boolean"}
pipe.set_progress_bar_config(disable=set_progress_bar_config) 


# Swap the pipeline's scheduler for a DPMSolverMultistepScheduler built from the current scheduler's config.
scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.scheduler = scheduler

torch.cuda.empty_cache()


## Finetuning

In [None]:
#@title Generating images on the pretrained model

#@markdown Check the box to generate images using the pretrained model.
generate_pretrained_model_images= True #@param {type:"boolean"}

if generate_pretrained_model_images:
  # Baseline: num_images_per_prompt images for every prompt, before any fine-tuning.
  image_list=[]
  for step, prompt_list in enumerate(finetune_dataloader):
    image=pipe(prompt_list,num_images_per_prompt=num_images_per_prompt,width=resolution,height=resolution).images
    image_list+=image
    torch.cuda.empty_cache()

  # One grid row per prompt, one column per generated image.
  grid = image_grid(image_list, len(prompts),num_images_per_prompt)
  grid.save("pretrained.png")
  # A bare `grid` inside an `if` block is never rendered (only a cell's last
  # top-level expression is), so display it explicitly. `display` is provided
  # by the IPython kernel.
  display(grid)



In [None]:
#@title Run training

os.environ['MODEL_NAME'] = model_id
os.environ['OUTPUT_DIR'] = f"./CustomModel/"
topk=8
training_steps_per_epoch=topk*10
os.environ['CHECKPOINTING_STEPS']=str(training_steps_per_epoch)
os.environ['RESOLUTION']=str(resolution)
os.environ['LEARNING_RATE']=str(9e-6)

# remove old account directory
try: 
    shutil.rmtree('./CustomModel')
except:
    pass
try: 
    shutil.rmtree('./trainingdataset/imagefolder/')
except:
    pass

model_id = "runwayml/stable-diffusion-v1-5"


for epoch in range(epochs+1):
  print("Epoch: ",epoch)
  epoch=epoch
  training_steps=str(training_steps_per_epoch*(epoch+1))
  os.environ['TRAINING_STEPS']=training_steps
  os.environ['TRAINING_DIR'] = f'./trainingdataset/imagefolder/{epoch}'

  training_prompts=[]
  prompts[f'Epoch{epoch} Scores']=np.nan

  for step, prompt_list in enumerate(finetune_dataloader):
    image=pipe(prompt_list,num_images_per_prompt=num_images_per_prompt,width=resolution,height=resolution).images
    image_list=[]

    for i in range(int(len(image)/num_images_per_prompt)):
      image_list.append(image[i*num_images_per_prompt:(i+1)*num_images_per_prompt])
    torch.cuda.empty_cache()
    
    with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
      step_list=[i for i in range(step*batch_size,(step+1)*batch_size)]
      score_index=executor.map(get_max_score,image_list,step_list,[epoch for i in range(len(step_list))])

    iterator=0
    for max_scores in score_index:
      training_prompts.append([max_scores[0],image_list[iterator][max_scores[1]],prompt_list[iterator]])
      iterator+=1

  training_prompts=[row[1:3] for row in sorted(training_prompts,key=lambda x: (x[0]),reverse=True)[:topk]]
  training_prompts=pd.DataFrame(training_prompts)

  if not os.path.exists(f"./trainingdataset/imagefolder/{epoch}/train/"):
    os.makedirs(f"./trainingdataset/imagefolder/{epoch}/train/")
  if not os.path.exists(f"./CustomModel/"):
    os.makedirs(f"./CustomModel/")
  for i in range(len(training_prompts)):
    training_prompts.iloc[i,0].save(f'./trainingdataset/imagefolder/{epoch}/train/{i}.png')

  training_prompts['file_name']=[f"{i}.png" for i in range(len(training_prompts))]
  training_prompts.columns = ['0','text','file_name']
  training_prompts.drop('0',axis=1,inplace=True)
  training_prompts.to_csv(f'./trainingdataset/imagefolder/{epoch}/train/metadata.csv',index=False)
  torch.cuda.empty_cache()

  if epoch<epochs:
    !accelerate launch --num_processes=1 --mixed_precision='fp16' --dynamo_backend='no' --num_machines=1 train_text_to_image_lora.py \
        --pretrained_model_name_or_path=$MODEL_NAME \
        --train_data_dir=$TRAINING_DIR \
        --resolution=$RESOLUTION \
        --train_batch_size=8 \
        --gradient_accumulation_steps=1 \
        --gradient_checkpointing \
        --max_grad_norm=1 \
        --mixed_precision="fp16" \
        --max_train_steps=$TRAINING_STEPS \
        --learning_rate=$LEARNING_RATE \
        --lr_warmup_steps=0 \
        --enable_xformers_memory_efficient_attention \
        --dataloader_num_workers=1 \
        --output_dir=$OUTPUT_DIR \
        --lr_warmup_steps=0 \
        --seed=1234 \
        --checkpointing_steps=$CHECKPOINTING_STEPS \
        --resume_from_checkpoint="latest" \
        --lr_scheduler='constant' 
  
  pipe.unet.load_attn_procs(f'./CustomModel/')
  torch.cuda.empty_cache()


## Results


In [None]:
#@title Generating results on the fine-tuned model

#@markdown Check the box to generate images using the fine-tuned model.
generate_finetuned_model_images= True #@param {type:"boolean"}

if generate_finetuned_model_images:
  image_list=[]
  # Load the LoRA attention weights saved by the training run.
  pipe.unet.load_attn_procs('./CustomModel')
  for step, prompt_list in enumerate(finetune_dataloader):
    image=pipe(prompt_list,num_images_per_prompt=num_images_per_prompt,width=resolution,height=resolution).images
    image_list+=image
    torch.cuda.empty_cache()

  # One grid row per prompt, one column per generated image.
  grid = image_grid(image_list, len(prompts),num_images_per_prompt)
  grid.save("trained.png")
  # A bare `grid` inside an `if` block is never rendered (only a cell's last
  # top-level expression is), so display it explicitly. `display` is provided
  # by the IPython kernel.
  display(grid)