In [1]:
%load_ext autoreload
%autoreload 2
%cd ..

/home/akkirr/annotated-diffusion


In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from pathlib import Path

import torch
from torch import nan_to_num
from torchvision import transforms as T
from torch.utils.data import DataLoader

import numpy as np
from PIL import Image
import requests

from datasets import load_dataset
from torchvision.utils import save_image
from torch.optim import Adam

from copy import deepcopy
import os

In [3]:
# Pin the run to GPU 3 unless CUDA_VISIBLE_DEVICES was already chosen externally
# (e.g. `CUDA_VISIBLE_DEVICES=0 jupyter ...`). This only takes effect if it runs
# before CUDA is initialised, so keep it ahead of any `.to("cuda")` / CUDA call.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

In [4]:
from mylib import *
import mylora

In [5]:
# Noise schedule for the diffusion process: linear betas over 300 steps
# (presumably the number of timesteps — the sampling loop below runs 300 iterations).
sampler = Sampler(
    linear_beta_schedule,
    300,
)

In [6]:
# Run configuration for the colour CIFAR-10 baseline.
# NOTE(review): CIFAR-10 images are 32x32 and the transform pipeline below does not
# resize them, yet image_size is 28 (it is also reused as the Unet base width).
# Confirm how mylib uses image_size before changing it — a different value would
# change the architecture and invalidate the saved checkpoint.
settings = Settings(
    results_folder=Path("./results-cifar/1-baseline"),  # where sample-*.png images are read from
    image_size=28,
    channels=3,  # RGB
    batch_size=128,
    device=("cuda" if torch.cuda.is_available() else "cpu"),
    checkpoint="checkpoints/4-cifar-colored.pt",  # baseline weights, reloaded for the LoRA run
)
settings

{
    "results_folder": "PosixPath('results-cifar/1-baseline')",
    "image_size": 28,
    "channels": 3,
    "batch_size": 128,
    "device": "cuda",
    "checkpoint": "checkpoints/4-cifar-colored.pt"
}

In [68]:
# Create the output folder (and any missing parents); a no-op if it already exists.
settings.results_folder.mkdir(parents=True, exist_ok=True)

In [69]:
dataset = load_dataset("cifar10")

# Image pipeline: random horizontal flip (augmentation), PIL image -> float tensor,
# then (t * 2) - 1 to map ToTensor's [0, 1] range onto [-1, 1].
# Use the explicit torchvision namespace rather than relying on `Compose`
# leaking in through `from mylib import *`.
transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Lambda(lambda t: (t * 2) - 1),
])


def transforms(examples):
    """Batch transform for `Dataset.with_transform`.

    Replaces the PIL images in `examples["img"]` with transformed tensors stored
    under `examples["pixel_values"]` and returns the same (mutated) batch dict.
    """
    examples["pixel_values"] = [transform(image) for image in examples["img"]]
    del examples["img"]  # raw PIL images are no longer needed downstream

    return examples


# The transform is applied on the fly when items are accessed; the label column
# is dropped because the model below is not given class labels.
transformed_dataset = dataset.with_transform(transforms).remove_columns("label")

# create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size=settings.batch_size, shuffle=True)

# Train baseline

In [None]:
# Baseline denoiser, trained from scratch on colour CIFAR-10.
set_all_seeds()
model = Unet(
    dim=settings.image_size,  # base channel width (reuses image_size)
    channels=settings.channels,
    dim_mults=(1, 2, 4),
).to(settings.device)  # nn.Module.to returns the module itself
mylora.model_summary(model)

In [None]:
# Optimise all parameters of the freshly initialised Unet for 10 epochs.
optimizer = Adam(params=model.parameters(), lr=0.001)
train(model, optimizer, dataloader, sampler, settings, epochs=10)

In [8]:
folder = str(settings.results_folder)
! /home/akkirr/.cargo/bin/gifski -o $folder/sample.gif -r 7 $folder/sample-*.png

gifski created /home/akkirr/annotated-diffusion/results-cifar/1-baseline/sample.gif

In [None]:
# Persist the baseline weights; the LoRA section reloads them from this path.
# torch.save does not create missing directories, so ensure the parent exists.
Path(settings.checkpoint).parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), settings.checkpoint)

# Train LoRA adapters

In [None]:
# Rebuild the baseline architecture, load its trained weights, then attach LoRA adapters.
set_all_seeds()
model = Unet(
    dim=settings.image_size,
    channels=settings.channels,
    dim_mults=(1, 2, 4,)
)
# The model is still on CPU here, while the checkpoint was saved from the (CUDA) model;
# mapping to CPU avoids a temporary GPU copy and also works on machines without a GPU.
model.load_state_dict(torch.load(settings.checkpoint, map_location="cpu"))

# Load weights before injection: the state dict matches the un-injected module tree
# (presumably injection wraps/replaces the target modules — confirm in mylora).
mylora.inject_lora(
    model, 2, 0.4,  # 2 = LoRA rank (matches the "x 2 x" in the summary output).
                    # NOTE(review): the results folder below is named "do=0.25";
                    # confirm whether 0.4 or 0.25 is the intended value.
    ['LinearAttention'],
    [torch.nn.Conv2d]  # explicit torch.nn rather than `nn` from `from mylib import *`
)
model.to(settings.device)

# After this call only the adapter layers are reported as trainable (see summary output).
mylora.freeze_lora(model)
print()
mylora.model_summary(model)

Injected lora    28 x 2 x 384   in downs.0.2.fn.fn.to_qkv
Injected lora   128 x 2 x 28    in downs.0.2.fn.fn.0
Injected lora    28 x 2 x 384   in downs.1.2.fn.fn.to_qkv
Injected lora   128 x 2 x 28    in downs.1.2.fn.fn.0
Injected lora    56 x 2 x 384   in downs.2.2.fn.fn.to_qkv
Injected lora   128 x 2 x 56    in downs.2.2.fn.fn.0
Injected lora   112 x 2 x 384   in ups.0.2.fn.fn.to_qkv
Injected lora   128 x 2 x 112   in ups.0.2.fn.fn.0
Injected lora    56 x 2 x 384   in ups.1.2.fn.fn.to_qkv
Injected lora   128 x 2 x 56    in ups.1.2.fn.fn.0
Injected lora    28 x 2 x 384   in ups.2.2.fn.fn.to_qkv
Injected lora   128 x 2 x 28    in ups.2.2.fn.fn.0

trainable layers:            24
frozen layers:              231
total params:           2027747


In [None]:
def transforms(examples):
    """Grayscale variant of the batch transform for `Dataset.with_transform`.

    Each PIL image is collapsed to luminance ('L') and expanded back to three
    identical RGB channels, so tensors keep the 3-channel shape of the colour
    model, then passed through the global `transform` pipeline. The results go
    into `examples["pixel_values"]`, `examples["img"]` is removed, and the same
    (mutated) batch dict is returned.

    NOTE(review): this redefines (shadows) the colour `transforms` from the data
    cell above; the colour dataset built there keeps its own function object.
    """
    source_images = examples["img"]
    examples["pixel_values"] = [
        transform(picture.convert('L').convert('RGB'))
        for picture in source_images
    ]
    del examples["img"]
    return examples

transformed_dataset = dataset.with_transform(transforms).remove_columns("label")

# Grayscale dataloader used for the LoRA fine-tuning run.
dataloader = DataLoader(
    transformed_dataset["train"],
    batch_size=settings.batch_size,
    shuffle=True,
)

In [None]:
# LoRA fine-tuning on the grayscale data, written to its own results folder.
# NOTE(review): the folder name says do=0.25 but the inject_lora call passes 0.4;
# confirm which value is intended and keep the name in sync with it.
settings.results_folder = Path("./results-cifar/2-rank=2_do=0.25")
settings.results_folder.mkdir(parents=True, exist_ok=True)  # same as the baseline folder

# Give the optimizer only the parameters left trainable by freeze_lora (the adapters),
# so the fine-tuning scope is explicit. This also fails loudly if nothing is trainable.
optimizer = Adam((p for p in model.parameters() if p.requires_grad), lr=1e-3)
train(model, optimizer, dataloader, sampler, settings, epochs=10)

sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.021473567932844162
Loss: 0.024918898940086365
Loss: 0.02179855853319168
Loss: 0.024604542180895805


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020823504775762558
Loss: 0.024760788306593895
Loss: 0.021668439731001854
Loss: 0.02450842224061489


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.02077310159802437
Loss: 0.024704869836568832
Loss: 0.02161295711994171
Loss: 0.024459373205900192


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020757650956511497
Loss: 0.024674801155924797
Loss: 0.021580655127763748
Loss: 0.024433648213744164


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020747873932123184
Loss: 0.024654865264892578
Loss: 0.02155657857656479
Loss: 0.024414589628577232


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020738670602440834
Loss: 0.024639418348670006
Loss: 0.021537702530622482
Loss: 0.024399619549512863


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020729348063468933
Loss: 0.024627139791846275
Loss: 0.021523132920265198
Loss: 0.02438758686184883


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020720016211271286
Loss: 0.024616152048110962
Loss: 0.02151203155517578
Loss: 0.024378424510359764


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.02071075513958931
Loss: 0.024606024846434593
Loss: 0.021503090858459473
Loss: 0.024371594190597534


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020701846107840538
Loss: 0.024596666917204857
Loss: 0.021495385095477104
Loss: 0.024366270750761032


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

In [9]:
folder = str(settings.results_folder)
! /home/akkirr/.cargo/bin/gifski -o $folder/sample.gif -r 7 $folder/sample-*.png

gifski created /home/akkirr/annotated-diffusion/results-cifar/1-baseline/sample.gif