In [1]:
from datasets import load_dataset
from torchvision import transforms

def make_train_dataset(path, tokenizer, accelerator):
    dataset = load_dataset(path)
    column_names = dataset['train'].column_names
    image_column, conditioning_image_column, caption_column = column_names

    image_transforms = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    conditioning_image_transforms = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        images = [image_transforms(image) for image in images]

        conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
        tokenized_ids = tokenizer.batch_encode_plus(examples[caption_column], padding="max_length", max_length=77).input_ids
       
        examples["pixel_values"] = images
        examples["conditioning_pixel_values"] = conditioning_images
        examples["input_ids"] = tokenized_ids

        return examples
    
    with accelerator.main_process_first():
        train_dataset = dataset["train"].with_transform(preprocess_train)
    
    return train_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# accelerate 상에서 print문을 구현하기 위함임
from accelerate.logging import get_logger

logger = get_logger(__name__)

2024-01-15 00:39:57.527526: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-01-15 00:39:57.553006: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-01-15 00:39:57.681102: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# logger 초기 configuration 지정하는 방법

import os
cur_dir = os.path.dirname(os.path.abspath(__name__))

from accelerate.utils import ProjectConfiguration, set_seed
accelerator_project_config = ProjectConfiguration(
    project_dir=os.path.join(cur_dir, "training"),
    logging_dir=os.path.join(cur_dir, "training", "log")
)

In [4]:
from accelerate import Accelerator

accelerator = Accelerator(
    gradient_accumulation_steps=1,
    mixed_precision="fp16",
    # log_with="wandb",
    project_config=accelerator_project_config
)

In [5]:
set_seed(42)

In [6]:
if accelerator.is_main_process:
    os.makedirs(cur_dir, exist_ok=True)

In [7]:
from transformers import CLIPTokenizer
tokenizer = CLIPTokenizer("./data/vocab.json", merges_file="./data/merges.txt")

In [8]:
from model_converter import convert_model, convert_controlnet_model
import torch

In [9]:
diffusion_state_dict = torch.load("/home/mlfavorfit/Desktop/lib_link/favorfit/kjg/0_model_weights/diffusion/v1-5-pruned-emaonly.ckpt")["state_dict"]
control_state_dict = torch.load("/home/mlfavorfit/lib/favorfit/kjg/0_model_weights/diffusion/controlnet/control_v11f1e_sd15_tile.pth")

diffusion_state_dict = convert_model(diffusion_state_dict)
control_state_dict = convert_controlnet_model(control_state_dict)

In [10]:
from model_loader import load_diffusion_model, load_controlnet_model
models = load_diffusion_model(diffusion_state_dict)
# controlnet = load_controlnet_model(state_dict=control_state_dict, dtype=torch.float32)
controlnet = load_controlnet_model(state_dict=None, dtype=torch.float32)

models.update(controlnet)

In [11]:
generator = torch.Generator(device="cuda")
generator.manual_seed(42)

<torch._C.Generator at 0x7f68ecb7e230>

In [12]:
from ddpm import DDPMSampler

sampler = DDPMSampler(generator)

In [13]:
clip = models['clip']
encoder = models['encoder']
decoder = models['decoder'] 
diffusion = models['diffusion']
controlnet = models['controlnet']
embedding = models['controlnet_embedding']

In [14]:
clip.requires_grad_(False)
encoder.requires_grad_(False)
decoder.requires_grad_(False)
diffusion.requires_grad_(False)

controlnet.train()
embedding.train()

ControlNetConditioningEmbedding(
  (conv_in): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (blocks): ModuleList(
    (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): Conv2d(32, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): Conv2d(96, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  )
  (conv_out): Conv2d(256, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [15]:
from torch.optim import AdamW

In [16]:
params_to_optimize = list(controlnet.parameters()) + list(embedding.parameters())
optimizer = AdamW(
        params_to_optimize,
        lr=1e-5,
        betas=(0.9, 0.999),
        weight_decay=1e-2,
        eps=1e-08,
    )

In [17]:
train_dataset = make_train_dataset("/media/mlfavorfit/sdb/contolnet_dataset/fill50k", tokenizer, accelerator)

In [18]:
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long)

    return {
        "pixel_values": pixel_values,
        "conditioning_pixel_values": conditioning_pixel_values,
        "input_ids": input_ids,
    }

In [19]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    train_dataset, 
    shuffle=True, 
    collate_fn=collate_fn,
    batch_size=3,
    num_workers=0
)

In [20]:
from torch.optim.lr_scheduler import LambdaLR

lr_scheduler = LambdaLR(optimizer, lambda _: 1, last_epoch=-1)

In [21]:
controlnet, embedding, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    controlnet, embedding, optimizer, train_dataloader, lr_scheduler
)

In [22]:
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

In [23]:
clip.to(accelerator.device, dtype=weight_dtype)
encoder.to(accelerator.device, dtype=weight_dtype)
decoder.to(accelerator.device, dtype=weight_dtype)
diffusion.to(accelerator.device, dtype=weight_dtype)

Diffusion(
  (time_embedding): TimeEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (unet): UNET(
    (encoders): ModuleList(
      (0): SwitchSequential(
        (0): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (1-2): 2 x SwitchSequential(
        (0): UNET_ResidualBlock(
          (groupnorm_feature): GroupNorm(32, 320, eps=1e-05, affine=True)
          (conv_feature): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (linear_time): Linear(in_features=1280, out_features=320, bias=True)
          (groupnorm_merged): GroupNorm(32, 320, eps=1e-05, affine=True)
          (conv_merged): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (residual_layer): Identity()
        )
        (1): UNET_AttentionBlock(
          (groupnorm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (conv_inp

In [24]:
args = {"validation_prompt":["red circle with blue background",  "cyan circle with brown floral background"], 
        "validation_image":["/media/mlfavorfit/sdb/contolnet_dataset/fill50k/validation/conditioning_image_1.png", "/media/mlfavorfit/sdb/contolnet_dataset/fill50k/validation/conditioning_image_2.png"]}

In [25]:
import wandb
from pipeline import generate_controlnet
from PIL import Image
def log_validation(encoder, decoder, clip, tokenizer, diffusion, controlnet, embedding, accelerator):
    logger.info("Running validation... ")

    controlnet = accelerator.unwrap_model(controlnet)
    embedding = accelerator.unwrap_model(embedding)

    models = {}
    models['clip'] = clip
    models['encoder'] = encoder
    models['decoder'] = decoder
    models['diffusion'] = diffusion
    models['controlnet'] = controlnet
    models['controlnet_embedding'] = embedding

    image_logs = []
    for validation_prompt, validation_image in zip(args["validation_prompt"], args["validation_image"]):
        validation_image = Image.open(validation_image).convert("RGB")

        output_image = generate_controlnet(
            prompt=validation_prompt,
            uncond_prompt="",
            input_image=None,
            control_image=validation_image,
            do_cfg=True,
            cfg_scale=7.5,
            sampler_name="ddpm",
            n_inference_steps=20,
            strength=1.0,
            models=models,
            seed=12345,
            device=accelerator.device,
            idle_device="cuda",
            tokenizer=tokenizer,
            leave_tqdm=False
        )

        image = Image.fromarray(output_image)

        image_logs.append(
            {"validation_image": validation_image, "images": image, "validation_prompt": validation_prompt}
        )

    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            formatted_images = []

            for log in image_logs:
                image = log["images"]
                validation_prompt = log["validation_prompt"]
                validation_image = log["validation_image"]

                formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
                formatted_images.append(wandb.Image(image, caption=validation_prompt))

            tracker.log({"validation": formatted_images})

    return image_logs

In [26]:
if accelerator.is_main_process:
    # tracker_config = dict(vars({args}))
    tracker_config = args.copy()

    tracker_config.pop("validation_prompt")
    tracker_config.pop("validation_image")

    accelerator.init_trackers("train_controlnet", config=tracker_config)

In [27]:
global_step = 0
first_epoch = 0
initial_global_step = 0

In [28]:
from tqdm import tqdm

num_train_epochs = 10
progress_bar = tqdm(
    range(0, num_train_epochs * len(train_dataloader)),
    initial=initial_global_step,
    desc="Steps",
    # Only show the progress bar once on each machine.
    disable=not accelerator.is_local_main_process,
)

Steps:   0%|          | 0/166670 [00:00<?, ?it/s]

In [29]:
def get_time_embedding(timestep, dtype=torch.float16):
    freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=dtype) / 160) 
    x = torch.tensor(timestep, dtype=dtype)[:, None] * freqs[None]
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)

In [30]:
from torch import nn
class ColorPaletteEmbedding(nn.Module):
    def __init__(self, num_colors=4, n_embd=768):
        super().__init__()

        self.num_colors=num_colors

        self.cl_encoder  = nn.Sequential(
            nn.Linear(3, n_embd//8),
            nn.LayerNorm(n_embd//8),
            nn.SELU(),
            nn.Linear(n_embd//8, n_embd//4),
            nn.LayerNorm(n_embd//4),
            nn.SELU(),
            nn.Linear(n_embd//4, n_embd//2),
            nn.LayerNorm(n_embd//2),
            nn.SELU(),
            nn.Linear(n_embd//2, n_embd),
        )

    def forward(self, x):
        x = self.cl_encoder(x)

        return x

In [31]:
import torch.nn.functional as F

latents_shape = (1, 4, 64, 64)

for epoch in range(first_epoch, num_train_epochs):
    for step, batch in enumerate(train_dataloader):
        latents = encoder(batch["pixel_values"].to(dtype=weight_dtype))

        noise = torch.randn_like(latents)
        batch_size = batch['pixel_values'].shape[0]
        
        timesteps = torch.randint(0, sampler.num_train_timesteps, (batch_size,), device="cpu").long()

        latents = sampler.add_noise(latents, sampler.timesteps[timesteps], noise)
        #---------------------------
        input_temp = torch.randn([3,4,3])
        colorpalette_model = ColorPaletteEmbedding(4, 768)
        context_cat = colorpalette_model(input_temp).to("cuda").to(dtype=weight_dtype)
        #----------------------------
        contexts = clip(batch['input_ids'])
        contexts = torch.cat([contexts, context_cat], 1)
        
        control_image = batch["conditioning_pixel_values"].to(latents.device)
        control_latents = embedding(control_image).to(dtype=weight_dtype)

        time_embeddings = get_time_embedding(timesteps).to(latents.device)

        controlnet_downs, controlnet_mids = controlnet(
            original_sample=latents, 
            latent=control_latents, 
            context=contexts,
            time=time_embeddings
        )
        
        model_pred = diffusion(
            latents,
            contexts,
            time_embeddings,
            additional_res_condition=[
                [cur.to(dtype=weight_dtype) for cur in controlnet_downs], 
                [cur.to(dtype=weight_dtype) for cur in controlnet_mids]
            ]
        )

        target = noise
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad(set_to_none=False)


        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1

            if accelerator.is_main_process:
                if global_step % 5000 == 0:
                    save_path = os.path.join("./training", f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                if global_step % 1000 == 0:
                    log_validation(encoder, 
                                decoder, 
                                clip, 
                                tokenizer, 
                                diffusion, 
                                controlnet, 
                                embedding, 
                                accelerator)
        
        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)
        accelerator.log(logs, step=global_step)

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        controlnet = accelerator.unwrap_model(controlnet)
        embedding = accelerator.unwrap_model(embedding)
        torch.save(controlnet, f"./training/controlnet_{epoch}.pth")
        torch.save(embedding, f"./training/embedding_{epoch}.pth")

accelerator.end_training()

  x = torch.tensor(timestep, dtype=dtype)[:, None] * freqs[None]
Steps:   0%|          | 6/166670 [00:04<31:55:43,  1.45it/s, loss=0.43, lr=1e-5]  

KeyboardInterrupt: 

In [2]:
from datasets import load_dataset

temp = load_dataset("/media/mlfavorfit/sdb/contolnet_dataset/control_net_train_base")

In [5]:
temp["train"].column_names

['image', 'text', 'colors']

In [16]:
import torch
torch.FloatTensor(temp["train"][0]['colors']['total']).flatten()/ 255.0 -1

tensor([-0.9333, -0.8784, -0.8314, -0.8471, -0.8118, -0.7961, -0.2353, -0.2784,
        -0.3333, -0.5922, -0.5765, -0.5843])

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

# B 모델 정의
class B(nn.Module):
    def __init__(self):
        super(B, self).__init__()
        self.b_ln = nn.Linear(3,5)
        # B 모델의 레이어들을 정의
        # ...

    def forward(self, x):
        return self.b_ln(x)
        # B 모델의 forward pass 정의
        # ...

# A 모델 정의
class A(nn.Module):
    def __init__(self):
        super(A, self).__init__()
        self.a_ln = nn.Linear(2,3)
        # A 모델의 레이어들을 정의
        # ...

        # B 모델 정의
        self.B_model = B()

    def forward(self, x):
        x = self.a_ln(x)
        b_outputs = self.B_model(x)
        return b_outputs

# # A 모델과 B 모델 생성
# model_A = A()

# # A 모델의 weight를 고정
# for param in model_A.parameters():
#     param.requires_grad = False

# # B 모델의 매개변수만을 사용하여 optimizer를 정의
# optimizer = optim.Adam(model_A.B_model.parameters(), lr=0.001)
# criterion = nn.CrossEntropyLoss()

# # A 모델 훈련 (B 모델은 여기서 훈련되지 않음)
# for epoch in range(num_epochs):
#     for inputs, targets in train_loader:
#         optimizer.zero_grad()
#         outputs = model_A(inputs)
#         loss = criterion(outputs, targets)
#         loss.backward()
#         optimizer.step()


In [20]:
model_A = A()

In [22]:
model_A.requires_grad_(False)

A(
  (a_ln): Linear(in_features=2, out_features=3, bias=True)
  (B_model): B(
    (b_ln): Linear(in_features=3, out_features=5, bias=True)
  )
)

In [8]:
model_A.state_dict()

OrderedDict([('a_ln.weight',
              tensor([[-0.3877,  0.6321],
                      [ 0.5478,  0.5093],
                      [ 0.5246,  0.1936]])),
             ('a_ln.bias', tensor([-0.4857,  0.5796, -0.6629])),
             ('B_model.b_ln.weight',
              tensor([[ 0.5010, -0.4070,  0.4238],
                      [-0.4543,  0.5650, -0.2707],
                      [-0.3158, -0.4582,  0.3809],
                      [-0.0378, -0.0386, -0.0901],
                      [-0.0066, -0.2748, -0.3092]])),
             ('B_model.b_ln.bias',
              tensor([-0.5354, -0.1291,  0.0920, -0.3835,  0.1404]))])

In [24]:
model_B = model_A.B_model

In [26]:
model_B.requires_grad_(True)

B(
  (b_ln): Linear(in_features=3, out_features=5, bias=True)
)

In [27]:
[(cur, cur2) for cur, cur2 in model_B.named_parameters()]

[('b_ln.weight',
  Parameter containing:
  tensor([[-0.0028,  0.2583,  0.4192],
          [ 0.3938,  0.4438, -0.5401],
          [-0.1793, -0.1946,  0.4920],
          [-0.1165, -0.3382, -0.0060],
          [-0.1397,  0.5581, -0.4283]], requires_grad=True)),
 ('b_ln.bias',
  Parameter containing:
  tensor([ 0.3106, -0.1701,  0.0876,  0.1845,  0.4864], requires_grad=True))]

In [28]:
[(cur, cur2) for cur, cur2 in model_A.named_parameters()]

[('a_ln.weight',
  Parameter containing:
  tensor([[ 0.2269, -0.0564],
          [ 0.2057,  0.6311],
          [-0.1404,  0.1387]])),
 ('a_ln.bias',
  Parameter containing:
  tensor([-0.6461,  0.6322, -0.6476])),
 ('B_model.b_ln.weight',
  Parameter containing:
  tensor([[-0.0028,  0.2583,  0.4192],
          [ 0.3938,  0.4438, -0.5401],
          [-0.1793, -0.1946,  0.4920],
          [-0.1165, -0.3382, -0.0060],
          [-0.1397,  0.5581, -0.4283]], requires_grad=True)),
 ('B_model.b_ln.bias',
  Parameter containing:
  tensor([ 0.3106, -0.1701,  0.0876,  0.1845,  0.4864], requires_grad=True))]