# Install

In [None]:
# Clone the PixArt-alpha fork that provides the training scripts and the
# notebooks/ config file used by the training cell at the end of this notebook.
# NOTE(review): later cells use absolute /workspace/PixArt-alpha/... paths, so this
# presumably assumes the notebook's working directory is /workspace -- confirm.
!git clone https://github.com/kopyl/PixArt-alpha.git

In [None]:
# Work from inside the cloned repo so the relative paths used below
# (requirements.txt, tools/download.py, train_scripts/train.py) resolve.
%cd PixArt-alpha

In [None]:
# Use %pip rather than !pip so packages are installed into the environment of the
# running kernel (a bare !pip may resolve to a different interpreter's pip).
# Pinned CUDA 11.7 builds of torch / torchvision / torchaudio from the PyTorch wheel index:
%pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu117
# Remaining project dependencies declared by the cloned repository:
%pip install -r requirements.txt

## Download model

In [None]:
# Fetch the PixArt-XL-2 512x512 checkpoint. The feature-extraction cell below reads
# pretrained models from output/pretrained_models -- presumably where this script
# saves them (TODO confirm against tools/download.py).
!python tools/download.py --model_names "PixArt-XL-2-512x512.pth"

## Build a PixArt training dataset from a Hugging Face dataset

In [None]:
import os
from tqdm.notebook import tqdm
from datasets import load_dataset
import json

In [None]:
# Fetch the Pokemon BLIP-captions dataset from the Hugging Face Hub; the export
# cell below iterates its "train" split and reads each row's "image" and "text".
dataset = load_dataset(path="lambdalabs/pokemon-blip-captions")

In [None]:
# Export the "train" split of `dataset` into the on-disk layout consumed by
# tools/extract_features.py and the training config further below:
#   <root_dir>/images/<i>.png            -- one image per example
#   <root_dir>/captions/<i>.txt          -- that example's caption text
#   <root_dir>/partition/data_info.json  -- list of per-image metadata dicts
root_dir = "/workspace/pixart-pokemon"
images_dir = "images"
captions_dir = "captions"

images_dir_absolute = os.path.join(root_dir, images_dir)
captions_dir_absolute = os.path.join(root_dir, captions_dir)

image_format = "png"
json_name = "partition/data_info.json"
absolute_json_name = os.path.join(root_dir, json_name)

# exist_ok=True makes the cell safe to re-run; makedirs also creates root_dir
# itself as a parent, so no separate check for it is needed.
for directory in (images_dir_absolute, captions_dir_absolute, os.path.dirname(absolute_json_name)):
    os.makedirs(directory, exist_ok=True)

data_info = []
for order, item in enumerate(tqdm(dataset["train"])):
    image = item["image"]
    image.save(f"{images_dir_absolute}/{order}.{image_format}")
    # Explicit UTF-8 so non-ASCII captions are written the same way regardless of locale.
    with open(f"{captions_dir_absolute}/{order}.txt", "w", encoding="utf-8") as text_file:
        text_file.write(item["text"])

    # Record the saved image's actual dimensions instead of assuming 512x512.
    # NOTE(review): resizing to the training resolution presumably happens in
    # extract_features.py (its log reports "Extracting Image Resolution 512") -- confirm.
    width, height = image.size
    data_info.append({
        "height": height,
        "width": width,
        "ratio": height / width,
        "path": f"images/{order}.{image_format}",
        "prompt": item["text"],
    })

with open(absolute_json_name, "w", encoding="utf-8") as json_file:
    json.dump(data_info, json_file)

## Extract features

In [6]:
# Precompute T5 caption features (--t5_save_root) and VAE image features
# (--vae_save_root) for every entry in data_info.json. The training config below sets
# load_vae_feat=True, so training presumably reads these cached features.
# NOTE(review): these absolute paths duplicate root_dir from the dataset-export
# cell; keep them in sync if root_dir changes.
!python /workspace/PixArt-alpha/tools/extract_features.py \
    --json_path "/workspace/pixart-pokemon/partition/data_info.json" \
    --t5_save_root "/workspace/pixart-pokemon/caption_feature_wmask" \
    --vae_save_root "/workspace/pixart-pokemon/img_vae_features" \
    --pretrained_models_dir "/workspace/PixArt-alpha/output/pretrained_models" \
    --dataset_root "/workspace/pixart-pokemon"

Extracting Image Resolution 512
/workspace/PixArt-alpha/output/pretrained_models/t5_ckpts/t5-v1_1-xxl
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Loading checkpoint shards: 100%|██████████████████| 2/2 [01:08<00:00, 34.49s/it]
100%|█████████████████████████████████████████| 833/833 [03:37<00:00,  3.82it/s]
100%|█████████████████████████████████████████| 833/833 [19:33<00:00,  1.41s/it]


## Train model

In [None]:
# Launch training with the Pokemon sample config. Hyperparameters (batch size, lr,
# epochs, ...) live in that config file; its resolved values are echoed in the log.
# NOTE(review): no --nproc_per_node is passed, so the launcher presumably starts a
# single process. torch.distributed.launch is deprecated in favor of torchrun, but
# switching may change how the local rank reaches train.py -- verify before migrating.
!python -m torch.distributed.launch \
    train_scripts/train.py \
    /workspace/PixArt-alpha/notebooks/PixArt_xl2_img512_internal_for_pokemon_sample_training.py \
    --work-dir output/trained_model

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
2023-11-30 13:36:11,697 - PixArt - INFO - Config: 
data_root = '/workspace'
data = dict(
    type='InternalData',
    root='/workspace/pixart-pokemon',
    image_list_json=['data_info.json'],
    transform='default_train',
    load_vae_feat=True)
image_size = 512
train_batch_size = 38
eval_batch_size = 16
use_fsdp = False
valid_num = 2000
model = 'PixArt_XL_2'
aspect_ratio_type = None
multi_scale = False
lewei_scale = 1.0
num_workers = 10
train_sampling_steps = 1000
eval_sampling_steps = 200
num_epochs = 200
gradient_accumulation_steps = 1
grad_checkpointing = True
gradient_clip = 0.01
gc_step = 1
auto_lr = dict(rule='sqrt')
optimizer = dict(type='AdamW', lr=2e-05, weight_decay=0.03, eps=1e-10)
lr_schedule = 'constant'
lr_schedule_args = dict(num_warmup_steps=1000)
save_image_epochs = 1
save_model_e