In [1]:
# Enable autoreload so edits to imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Move the working directory up one level; the relative paths used later in the
# notebook (results-cifar/..., checkpoints/...) are resolved from there.
# NOTE(review): not idempotent — re-running this cell climbs one more directory.
%cd ..

/home/akkirr/annotated-diffusion


In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from pathlib import Path

import torch
from torch import nan_to_num
from torchvision import transforms as T
from torch.utils.data import DataLoader

import numpy as np
from PIL import Image
import requests

from datasets import load_dataset
from torchvision.utils import save_image
from torch.optim import Adam

from copy import deepcopy
import os

In [3]:
# Pin the notebook to one physical GPU. CUDA_VISIBLE_DEVICES is only honoured
# when it is set before CUDA is initialised, so set it ahead of the project
# imports below, which may touch CUDA at import time.
# NOTE(review): the device index "3" is machine-specific.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# NOTE(review): the star import supplies names used later in the notebook
# (Sampler, Settings, Unet, train, set_all_seeds, ...); explicit imports would
# make those dependencies visible.
from mylib import *
import mylora
import lora_diffusion

In [4]:
# Sampler shared by every training run in this notebook: linear beta schedule.
# 300 matches the "0/300" sampling-loop progress bars (presumably the number of
# diffusion timesteps — confirm against Sampler's signature).
SAMPLER_STEPS = 300
sampler = Sampler(linear_beta_schedule, SAMPLER_STEPS)

In [5]:
# Run configuration for the baseline experiment. Later cells reassign
# `results_folder` and `checkpoint` on this same object.
# NOTE(review): CIFAR-10 images are 32x32 and the preprocessing does not
# resize — confirm that image_size=28 is intended.
run_device = "cuda" if torch.cuda.is_available() else "cpu"
settings = Settings(
    results_folder=Path("./results-cifar/1-baseline"),
    image_size=28,
    channels=3,
    batch_size=128,
    device=run_device,
    checkpoint="checkpoints/4-cifar-colored.pt",
)
settings

{
    "results_folder": "PosixPath('results-cifar/1-baseline')",
    "image_size": 28,
    "channels": 3,
    "batch_size": 128,
    "device": "cuda",
    "checkpoint": "checkpoints/4-cifar-colored.pt"
}

In [6]:
# Create the output directory, including missing parents; a no-op if it already exists.
settings.results_folder.mkdir(parents=True, exist_ok=True)

In [7]:
dataset = load_dataset("cifar10")

# Per-image preprocessing: random horizontal flip, PIL image -> tensor, then
# rescale from ToTensor's [0, 1] range to [-1, 1].
# `Compose` is taken from the torchvision namespace already imported as `T`
# (like the other transforms) instead of relying on a bare name leaking in
# through `from mylib import *`.
transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Lambda(lambda t: (t * 2) - 1),
])


def transforms(examples):
    """Batch transform for `with_transform`: turn the "img" column into tensors.

    Applies `transform` to every image in ``examples["img"]``, stores the
    results under ``examples["pixel_values"]`` and removes the ``"img"`` key.
    The batch mapping is modified in place and returned.
    """
    examples["pixel_values"] = [transform(image) for image in examples["img"]]
    del examples["img"]

    return examples


# The transform is attached to the dataset and the label column is dropped, so
# batches only carry "pixel_values".
transformed_dataset = dataset.with_transform(transforms).remove_columns("label")

# create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size=settings.batch_size, shuffle=True)

Found cached dataset cifar10 (/home/akkirr/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)


  0%|          | 0/2 [00:00<?, ?it/s]

# Train baseline

In [None]:
# Seed first so the freshly initialised weights are reproducible.
set_all_seeds()

# Architecture shared by every experiment in this notebook.
unet_config = dict(
    dim=settings.image_size,
    channels=settings.channels,
    dim_mults=(1, 2, 4),
)
model = Unet(**unet_config)

model.to(settings.device)
mylora.model_summary(model)

In [None]:
# Baseline run: optimise every parameter with Adam (lr 1e-3) for 10 epochs.
learning_rate, num_epochs = 1e-3, 10
optimizer = Adam(model.parameters(), lr=learning_rate)
train(model, optimizer, dataloader, sampler, settings, epochs=num_epochs)

In [8]:
folder = str(settings.results_folder)
! /home/akkirr/.cargo/bin/gifski -o $folder/sample.gif -r 7 $folder/sample-*.png

gifski created /home/akkirr/annotated-diffusion/results-cifar/1-baseline/sample.gif

In [None]:
# Persist the baseline weights. Create the checkpoint directory first:
# torch.save does not create missing parent directories and would otherwise
# fail after the whole training run has completed.
Path(settings.checkpoint).parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), settings.checkpoint)

# Train LoRA

In [None]:
# Rebuild the baseline architecture and restore its trained weights.
set_all_seeds()
model = Unet(
    dim=settings.image_size,
    channels=settings.channels,
    dim_mults=(1, 2, 4,)
)
# The model is still on the CPU at this point, so map the checkpoint there.
# Without map_location a checkpoint saved from a GPU cannot be loaded on a
# machine where that device is unavailable.
model.load_state_dict(torch.load(settings.checkpoint, map_location="cpu"))

# Inject LoRA adapters; the target lists name 'LinearAttention' modules and
# Conv2d layers. `torch.nn` is referenced explicitly instead of a bare `nn`
# that only exists through `from mylib import *`.
# NOTE(review): positional args 2 and 0.4 — 2 matches the middle dimension in
# the "Injected lora" log (presumably the rank); the meaning of 0.4 should be
# confirmed in mylora.inject_lora.
mylora.inject_lora(
    model, 2, 0.4,
    ['LinearAttention'],
    [torch.nn.Conv2d]
)
model.to(settings.device)

# NOTE(review): the summary printed below reports 24 trainable and 231 frozen
# layers, so this call presumably leaves only the injected adapters trainable.
mylora.freeze_lora(model)
print()
mylora.model_summary(model)

Injected lora    28 x 2 x 384   in downs.0.2.fn.fn.to_qkv
Injected lora   128 x 2 x 28    in downs.0.2.fn.fn.0
Injected lora    28 x 2 x 384   in downs.1.2.fn.fn.to_qkv
Injected lora   128 x 2 x 28    in downs.1.2.fn.fn.0
Injected lora    56 x 2 x 384   in downs.2.2.fn.fn.to_qkv
Injected lora   128 x 2 x 56    in downs.2.2.fn.fn.0
Injected lora   112 x 2 x 384   in ups.0.2.fn.fn.to_qkv
Injected lora   128 x 2 x 112   in ups.0.2.fn.fn.0
Injected lora    56 x 2 x 384   in ups.1.2.fn.fn.to_qkv
Injected lora   128 x 2 x 56    in ups.1.2.fn.fn.0
Injected lora    28 x 2 x 384   in ups.2.2.fn.fn.to_qkv
Injected lora   128 x 2 x 28    in ups.2.2.fn.fn.0

trainable layers:            24
frozen layers:              231
total params:           2027747


In [None]:
def transforms(examples):
    """Batch transform: grayscale every image, then preprocess it into a tensor.

    Each image in ``examples["img"]`` is converted to mode 'L' and back to
    'RGB' before `transform` is applied. The tensors are stored under
    ``examples["pixel_values"]``, the ``"img"`` key is removed, and the batch
    is returned.
    """
    pixel_values = []
    for image in examples["img"]:
        grayscale_rgb = image.convert('L').convert('RGB')
        pixel_values.append(transform(grayscale_rgb))
    examples["pixel_values"] = pixel_values
    del examples["img"]
    return examples


# Re-wrap the dataset with the grayscale transform and drop the label column.
transformed_dataset = (
    dataset
    .with_transform(transforms)
    .remove_columns("label")
)

# create dataloader
dataloader = DataLoader(
    transformed_dataset["train"],
    batch_size=settings.batch_size,
    shuffle=True,
)

In [None]:
# Send the outputs of the LoRA fine-tune to their own folder.
# NOTE(review): the folder name says "do=0.25" while inject_lora was called
# with 0.4 — confirm which value is correct before comparing runs.
settings.results_folder = Path("./results-cifar/2-rank=2_do=0.25")
# parents=True (as in the baseline setup) so this cell does not fail with
# FileNotFoundError when ./results-cifar has not been created yet.
settings.results_folder.mkdir(parents=True, exist_ok=True)

optimizer = Adam(model.parameters(), lr=1e-3)
train(model, optimizer, dataloader, sampler, settings, epochs=10)

sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.021473567932844162
Loss: 0.024918898940086365
Loss: 0.02179855853319168
Loss: 0.024604542180895805


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020823504775762558
Loss: 0.024760788306593895
Loss: 0.021668439731001854
Loss: 0.02450842224061489


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.02077310159802437
Loss: 0.024704869836568832
Loss: 0.02161295711994171
Loss: 0.024459373205900192


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020757650956511497
Loss: 0.024674801155924797
Loss: 0.021580655127763748
Loss: 0.024433648213744164


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020747873932123184
Loss: 0.024654865264892578
Loss: 0.02155657857656479
Loss: 0.024414589628577232


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020738670602440834
Loss: 0.024639418348670006
Loss: 0.021537702530622482
Loss: 0.024399619549512863


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020729348063468933
Loss: 0.024627139791846275
Loss: 0.021523132920265198
Loss: 0.02438758686184883


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020720016211271286
Loss: 0.024616152048110962
Loss: 0.02151203155517578
Loss: 0.024378424510359764


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.02071075513958931
Loss: 0.024606024846434593
Loss: 0.021503090858459473
Loss: 0.024371594190597534


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020701846107840538
Loss: 0.024596666917204857
Loss: 0.021495385095477104
Loss: 0.024366270750761032


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

In [9]:
folder = str(settings.results_folder)
! /home/akkirr/.cargo/bin/gifski -o $folder/sample.gif -r 7 $folder/sample-*.png

gifski created /home/akkirr/annotated-diffusion/results-cifar/1-baseline/sample.gif

### Open-source LoRA (`lora_diffusion`)

In [8]:
# Rebuild the baseline architecture and restore its trained weights.
set_all_seeds()
model = Unet(
    dim=settings.image_size,
    channels=settings.channels,
    dim_mults=(1, 2, 4,)
)
# The model is still on the CPU here, so map the checkpoint there; without
# map_location a GPU-saved checkpoint cannot be loaded where that device is
# unavailable.
model.load_state_dict(torch.load(settings.checkpoint, map_location="cpu"))

# Freeze every existing weight, then let lora_diffusion inject its adapters
# (r=1) into the 'LinearAttention' modules.
model.requires_grad_(False)
# NOTE(review): the returned parameter groups and names are not used below;
# the optimizer is built from model.parameters() instead.
unet_lora_params, train_names = lora_diffusion.inject_trainable_lora_extended(
    model,
    target_replace_module=['LinearAttention'],
    r=1,
)
model.to(settings.device)

mylora.model_summary(model)

trainable layers:            24
frozen layers:              231
total params:           2024059


In [9]:
def transforms(examples):
    """Batch transform: grayscale every image, then preprocess it into a tensor.

    Each image in ``examples["img"]`` is converted to mode 'L' and back to
    'RGB' before `transform` is applied. The tensors are stored under
    ``examples["pixel_values"]``, the ``"img"`` key is removed, and the batch
    is returned.
    """
    pixel_values = []
    for image in examples["img"]:
        grayscale_rgb = image.convert('L').convert('RGB')
        pixel_values.append(transform(grayscale_rgb))
    examples["pixel_values"] = pixel_values
    del examples["img"]
    return examples


# Re-wrap the dataset with the grayscale transform and drop the label column.
transformed_dataset = (
    dataset
    .with_transform(transforms)
    .remove_columns("label")
)

# create dataloader
dataloader = DataLoader(
    transformed_dataset["train"],
    batch_size=settings.batch_size,
    shuffle=True,
)

In [10]:
# Send the outputs of the lora_diffusion fine-tune to their own folder.
settings.results_folder = Path("./results-cifar/5-colored2bw-os-lora")
# parents=True (as in the baseline setup) so this cell does not fail with
# FileNotFoundError when ./results-cifar has not been created yet.
settings.results_folder.mkdir(parents=True, exist_ok=True)

optimizer = Adam(model.parameters(), lr=1e-3)
train(model, optimizer, dataloader, sampler, settings, epochs=10)

sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.021473567932844162
Loss: 0.02498686872422695
Loss: 0.021899407729506493
Loss: 0.02471090480685234


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020921876654028893
Loss: 0.02482043020427227
Loss: 0.021768230944871902
Loss: 0.02462949976325035


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020863980054855347
Loss: 0.024779578670859337
Loss: 0.021728238090872765
Loss: 0.024595197290182114


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.02083469182252884
Loss: 0.024757003411650658
Loss: 0.02170713059604168
Loss: 0.02457159012556076


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020816795527935028
Loss: 0.02474113367497921
Loss: 0.021693086251616478
Loss: 0.024559326469898224


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020802360028028488
Loss: 0.024728070944547653
Loss: 0.021683409810066223
Loss: 0.024551287293434143


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.02078893594443798
Loss: 0.024715717881917953
Loss: 0.021675661206245422
Loss: 0.024544311687350273


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020777206867933273
Loss: 0.024702858179807663
Loss: 0.02166832983493805
Loss: 0.024538110941648483


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.02076786383986473
Loss: 0.024689164012670517
Loss: 0.021660275757312775
Loss: 0.024532657116651535


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020760994404554367
Loss: 0.02467610314488411
Loss: 0.021651385352015495
Loss: 0.024527501314878464


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

In [11]:
folder = str(settings.results_folder)
! /home/akkirr/.cargo/bin/gifski -o $folder/sample.gif -r 7 $folder/sample-*.png

gifski created /home/akkirr/annotated-diffusion/results-cifar/5-colored2bw-os-lora/sample.gif

### Open-source LoRA: bw2colored

In [13]:
# Point the shared settings at a different base checkpoint for the bw2colored run
# (presumably weights trained on grayscale images — confirm).
# NOTE(review): no visible cell saves this file; it must already exist on disk.
settings.checkpoint = 'checkpoints/5-cifar-bw.pt'

In [14]:
# Rebuild the architecture and restore the weights named by settings.checkpoint.
set_all_seeds()
model = Unet(
    dim=settings.image_size,
    channels=settings.channels,
    dim_mults=(1, 2, 4,)
)
# The model is still on the CPU here, so map the checkpoint there; without
# map_location a GPU-saved checkpoint cannot be loaded where that device is
# unavailable.
model.load_state_dict(torch.load(settings.checkpoint, map_location="cpu"))

# Freeze every existing weight, then let lora_diffusion inject its adapters
# (r=1) into the 'LinearAttention' modules.
model.requires_grad_(False)
# NOTE(review): the returned parameter groups and names are not used below;
# the optimizer is built from model.parameters() instead.
unet_lora_params, train_names = lora_diffusion.inject_trainable_lora_extended(
    model,
    target_replace_module=['LinearAttention'],
    r=1,
)
model.to(settings.device)

mylora.model_summary(model)

trainable layers:            24
frozen layers:              231
total params:           2024059


In [15]:
# define function
def transforms(examples):
    """Batch transform: preprocess every (colour) image into a tensor.

    Applies `transform` to each image in ``examples["img"]``, stores the
    tensors under ``examples["pixel_values"]``, removes the ``"img"`` key and
    returns the batch.
    """
    pixel_values = []
    for image in examples["img"]:
        pixel_values.append(transform(image))
    examples["pixel_values"] = pixel_values
    del examples["img"]
    return examples


# Re-wrap the dataset with the colour transform and drop the label column.
transformed_dataset = (
    dataset
    .with_transform(transforms)
    .remove_columns("label")
)

# create dataloader
dataloader = DataLoader(
    transformed_dataset["train"],
    batch_size=settings.batch_size,
    shuffle=True,
)

In [16]:
# Re-seed right before training so this run starts from a known RNG state.
set_all_seeds()
settings.results_folder = Path("./results-cifar/6-bw2colored-os-lora")
# parents=True (as in the baseline setup) so this cell does not fail with
# FileNotFoundError when ./results-cifar has not been created yet.
settings.results_folder.mkdir(parents=True, exist_ok=True)

optimizer = Adam(model.parameters(), lr=1e-3)
train(model, optimizer, dataloader, sampler, settings, epochs=10)

sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.05254998803138733
Loss: 0.07035643607378006
Loss: 0.06077330559492111
Loss: 0.06180128455162048


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.04559360072016716
Loss: 0.06622746586799622
Loss: 0.05924654006958008
Loss: 0.060522809624671936


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.04515715688467026
Loss: 0.06404444575309753
Loss: 0.05824681743979454
Loss: 0.05751175805926323


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.04380150884389877
Loss: 0.06191066652536392
Loss: 0.05649096891283989
Loss: 0.056486230343580246


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.04224162921309471
Loss: 0.061619531363248825
Loss: 0.05541720241308212
Loss: 0.05552754923701286


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.042103759944438934
Loss: 0.060648731887340546
Loss: 0.05456799268722534
Loss: 0.05487517639994621


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.041910506784915924
Loss: 0.05988053232431412
Loss: 0.05392260104417801
Loss: 0.05455612763762474


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.04188649356365204
Loss: 0.05927160382270813
Loss: 0.053231414407491684
Loss: 0.05386975407600403


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.04172809422016144
Loss: 0.058676764369010925
Loss: 0.05286147817969322
Loss: 0.053213782608509064


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.041416339576244354
Loss: 0.05828488618135452
Loss: 0.05257154628634453
Loss: 0.05259702727198601


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

In [17]:
folder = str(settings.results_folder)
! /home/akkirr/.cargo/bin/gifski -o $folder/sample.gif -r 7 $folder/sample-*.png

gifski created /home/akkirr/annotated-diffusion/results-cifar/6-bw2colored-os-lora/sample.gif