In [1]:
import random

import torch

torch.set_num_threads(2)

# Seed every RNG the notebook draws from: `random` picks the diffusion timestep and
# torch is presumably used by `anydoor.sample_noise` (TODO confirm). Without this a
# "Restart & Run All" samples a different timestep/noise each time.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
from src.training.data.vitonhd import CustomDataLoader, VitonHDDataset
from src.training.data.batch import collate_fn

# Cloth training images, restricted to the entries listed in the filtering file.
dataset = VitonHDDataset(
    "dataset/train/cloth",
    filtering_file="dataset/lora_training_images.txt",
)

# Unshuffled loader of 5-sample batches, collated by the project's `collate_fn`.
dataloader = CustomDataLoader(
    dataset,
    batch_size=5,
    shuffle=False,
    collate_fn=collate_fn,
)

In [3]:
# Take the first batch from the (unshuffled) loader for interactive inspection.
batch = next(iter(dataloader))

In [4]:
from src.anydoor_refiners.model import AnyDoor

# Run on the GPU when one is visible, otherwise fall back to the CPU; full precision.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
dtype = torch.float32

anydoor = AnyDoor(
    num_inference_steps=1000,
    use_tv_loss=False,
    device=device,
    dtype=dtype,
)

In [5]:
# Move every tensor of the inspected batch onto the model's device and dtype.
# NOTE: `object` shadows the Python builtin; the forward-pass cell reads it by this name.
background = batch.background.to(device, dtype)
collage = batch.collage.to(device, dtype)
object = batch.object.to(device, dtype)
loss_mask = batch.loss_mask.to(device, dtype)
batch_size = object.size(0)

# Uniform integer in 0..999 inclusive; the 999 bound presumably mirrors
# num_inference_steps=1000 used when building the model (TODO confirm).
timestep = random.randint(0, 999)

In [6]:
def q_sample(images: torch.Tensor, noise: torch.Tensor, timestep: int, solver=None) -> torch.Tensor:
    """Noise clean latents to a given diffusion timestep in closed form.

    Returns ``cumulative_scale_factors[timestep] * images + noise_std[timestep] * noise``
    (presumably sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, the DDPM forward
    process -- TODO confirm against the solver's definitions).

    Args:
        images: Clean latents to be noised.
        noise: Noise tensor; must broadcast against ``images`` (normally the same shape).
        timestep: Integer index into the solver's ``cumulative_scale_factors`` and
            ``noise_std`` tables.
        solver: Any object exposing those two indexable tables. Defaults to
            ``anydoor.solver`` (the notebook-level model), so existing three-argument
            calls behave exactly as before; pass it explicitly to avoid depending on
            that global.

    Returns:
        The noisy latents.
    """
    if solver is None:
        solver = anydoor.solver
    signal_scale = solver.cumulative_scale_factors[timestep]
    noise_scale = solver.noise_std[timestep]
    return signal_scale * images + noise_scale * noise

In [7]:


# Replicate one training step's forward pass on the inspected batch, without gradients.
with torch.no_grad():
    # Call modules as instances (not `.forward`) so any registered hooks still run.
    object_embedding = anydoor.object_encoder(object)
    background_latents = anydoor.lda.encode(background)
    # Size the noise from the encoded latents instead of hard-coding
    # (batch_size, 4, 64, 64), which only holds for a single input resolution.
    noise = anydoor.sample_noise(size=tuple(background_latents.shape), device=device, dtype=dtype)
    noisy_backgrounds = q_sample(background_latents, noise, timestep)
    # NOTE(review): the same integer is used by q_sample as a raw index into the solver's
    # per-timestep tables and is passed here as `step=`. If AnyDoor.forward follows the
    # refiners LatentDiffusionModel convention, `step` indexes `solver.timesteps` instead
    # (reverse order) -- verify the noise level and the UNet's conditioning timestep agree.
    predicted_noise = anydoor(
        noisy_backgrounds,
        step=timestep,
        control_background_image=collage,
        object_embedding=object_embedding,
    )


In [11]:

# Debug inspection: display what the UNet's "sampling" context currently holds.
anydoor.unet.use_context("sampling")



{'shapes': []}

In [None]:
# Report how much of the loss mask is on vs. off: the mask's sum counts its ones and
# the sum of its complement counts its zeros (exact only for a binary 0/1 mask).
print("Number of ones in mask: ", loss_mask.sum())
print("Number of zeros in mask: ", (1 - loss_mask).sum())

Number of ones in mask:  tensor(4347., device='cuda:0')
Number of zeros in mask:  tensor(16133., device='cuda:0')


In [9]:
# NOTE(review): no cell in this notebook defines `atv_loss`; the value shown below came
# from kernel state left over by a deleted or re-ordered cell. Look it up defensively so
# "Restart & Run All" does not halt here with a NameError (displays nothing if undefined).
globals().get("atv_loss")

tensor(0., device='cuda:0')

In [None]:
# NOTE(review): TrainingConfig, Epoch, OptimizerConfig, Optimizers, LRSchedulerConfig,
# LRSchedulerType, AnydoorModelConfig and AnydoorTrainingConfig are not imported by any
# visible cell; add their imports so this cell runs on a fresh kernel.

# Checkpoint paths for each AnyDoor sub-model.
anydoor_config = AnydoorModelConfig(
    path_to_unet="ckpt/refiners/unet.safetensors",
    path_to_control_model="ckpt/refiners/controlnet.safetensors",
    path_to_object_encoder="ckpt/refiners/dinov2_encoder.safetensors",
    path_to_lda="ckpt/refiners/lda_new.safetensors",
)

# Run length plus device/precision (GPU when available, fp32).
training = TrainingConfig(
    duration=Epoch(10),
    device="cuda" if torch.cuda.is_available() else "cpu",
    dtype="float32",
)

# AdamW at a fixed learning rate (constant schedule).
optimizer = OptimizerConfig(optimizer=Optimizers.AdamW, learning_rate=1e-5)
lr_scheduler = LRSchedulerConfig(type=LRSchedulerType.CONSTANT_LR)

# Top-level config tying the datasets, batch size and sub-configs together.
training_config = AnydoorTrainingConfig(
    train_dataset="dataset/train/cloth",
    test_dataset="dataset/test/cloth",
    batch_size=16,
    anydoor=anydoor_config,
    training=training,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
)

In [3]:
# Build the trainer from `training_config`.
# NOTE(review): `AnyDoorLoRATrainer` is not imported by any visible cell, and execution
# counts restart at In [3] here, so this section was run in a separate kernel session.
trainer = AnyDoorLoRATrainer(training_config)

[32m2024-11-20 11:11:52.607[0m | [1mINFO    [0m | [36mrefiners.training_utils.trainer[0m:[36mdevice[0m:[36m153[0m - [1mUsing device: cuda[0m
[32m2024-11-20 11:11:52.608[0m | [1mINFO    [0m | [36mrefiners.training_utils.trainer[0m:[36mdtype[0m:[36m160[0m - [1mUsing dtype: torch.float32[0m


[32m2024-11-20 11:12:40.224[0m | [1mINFO    [0m | [36mrefiners.training_utils.trainer[0m:[36mwrapper[0m:[36m91[0m - [1mNumber of learnable parameters in anydoor: 6.9M[0m


In [4]:
# Count number of learnable parameters
import numpy as np

def human_readable_number(number: float) -> str:
    """Format a count with one decimal and a K/M/G/T/P/E suffix (powers of 1000).

    The unit is chosen on the value *after* rounding to one decimal, so a number
    such as 999.96 is shown as "1.0K" rather than "1000.0".

    Args:
        number: Any value convertible to ``float`` (int, float, numpy/torch scalar).

    Returns:
        e.g. ``"6.9M"`` for 6_900_000, ``"-1.5K"`` for -1500, ``"0.0"`` for 0.
    """
    value = float(number)
    for unit in ("", "K", "M", "G", "T", "P"):
        rounded = round(value, 1)
        # Compare the rounded value so the printed mantissa never reaches 1000.0.
        if abs(rounded) < 1000:
            return f"{rounded:.1f}{unit}"
        value /= 1000
    return f"{value:.1f}E"

# Count learnable (requires_grad) vs. total parameters of the AnyDoor model.
# Both counts use `Tensor.numel()`, so they are exact ints computed the same way.
learnable_params = sum(p.numel() for p in trainer.anydoor.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in trainer.anydoor.parameters())
print(f"Learnable parameters: {human_readable_number(learnable_params)}")
print(f"Total parameters: {human_readable_number(total_params)}")
# Guard against a model with no parameters (would otherwise raise ZeroDivisionError).
learnable_percentage = learnable_params / total_params * 100 if total_params else 0.0
print(f"Percentage of learnable parameters: {learnable_percentage:.2f}%")


Learnable parameters: 6.9M
Total parameters: 2.5G
Percetage of learnable parameters: 0.28%


In [5]:
# Run training with the trainer built from `training_config`.
# NOTE(review): the saved output below is stale and failed: it logs "Epoch(number=100)"
# while the config cell sets Epoch(10), and the run stops with an AssertionError after
# step 3 ends.
trainer.train()


[32m2024-11-20 11:12:40.287[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mStarting training for Epoch(number=100).[0m
[32m2024-11-20 11:12:40.288[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mEpoch 0 started.[0m


[32m2024-11-20 11:12:41.815[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mIteration 0 started.[0m
[32m2024-11-20 11:12:41.817[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mStep 0 started.[0m


Loss: 0.07028716802597046


[32m2024-11-20 11:12:44.775[0m | [1mINFO    [0m | [36mrefiners.training_utils.trainer[0m:[36moptimizer[0m:[36m210[0m - [1mTotal number of learnable parameters in the model(s): 6.9M[0m
[32m2024-11-20 11:12:44.893[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mStep 0 ended.[0m
[32m2024-11-20 11:12:46.325[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mIteration 0 ended.[0m
[32m2024-11-20 11:12:46.327[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mIteration 1 started.[0m
[32m2024-11-20 11:12:46.328[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mStep 1 started.[0m


Loss: 0.08815356343984604


[32m2024-11-20 11:12:48.369[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mStep 1 ended.[0m
[32m2024-11-20 11:12:49.509[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mIteration 1 ended.[0m
[32m2024-11-20 11:12:49.511[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mIteration 2 started.[0m
[32m2024-11-20 11:12:49.511[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mStep 2 started.[0m


Loss: 0.006110812537372112


[32m2024-11-20 11:12:51.511[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mStep 2 ended.[0m
[32m2024-11-20 11:12:53.000[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mIteration 2 ended.[0m
[32m2024-11-20 11:12:53.002[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mIteration 3 started.[0m
[32m2024-11-20 11:12:53.002[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mStep 3 started.[0m


Loss: 0.08576411008834839


[32m2024-11-20 11:12:55.005[0m | [1mINFO    [0m | [36mrefiners.training_utils.clock[0m:[36mlog[0m:[36m90[0m - [1mStep 3 ended.[0m


AssertionError: 