In [1]:
%load_ext autoreload
%autoreload 2
%cd ..

/home/akkirr/annotated-diffusion


In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from pathlib import Path

import torch
from torch import nan_to_num
from torchvision import transforms as T
from torch.utils.data import DataLoader

import numpy as np
from PIL import Image
import requests

from datasets import load_dataset
from torchvision.utils import save_image
from torch.optim import Adam

from copy import deepcopy
import os

In [4]:
from mylib import *
import mylora
import lora_diffusion

In [5]:
sampler = Sampler(linear_beta_schedule, 300)

In [6]:
settings = Settings(
    results_folder = Path("./results-mnist/1-baseline"),
    image_size = 28,
    channels = 1,
    batch_size = 128,
    device = "cuda" if torch.cuda.is_available() else "cpu",
    checkpoint = 'checkpoints/3-mnist-0:1,3:9.pt'
)
settings

{
    "results_folder": "PosixPath('results-mnist/1-baseline')",
    "image_size": 28,
    "channels": 1,
    "batch_size": 128,
    "device": "cuda",
    "checkpoint": "checkpoints/3-mnist-0:1,3:9.pt"
}

In [7]:
settings.results_folder.mkdir(exist_ok=True, parents=True)

In [8]:
dataset = load_dataset("mnist")
# define image transformations (e.g. using torchvision)
transform = Compose([
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Lambda(lambda t: (t * 2) - 1)
])

# define function
def transforms(examples):
   examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
   del examples["image"]

   return examples

transformed_dataset = dataset.with_transform(transforms).filter(lambda x: x['label'] != 2).remove_columns("label")

# create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size=settings.batch_size, shuffle=True)

Found cached dataset mnist (/home/akkirr/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at /home/akkirr/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332/cache-d18d5dcedca7f7f9.arrow
Loading cached processed dataset at /home/akkirr/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332/cache-01fe572201716064.arrow


# Train baseline

In [15]:
set_all_seeds()
model = Unet(
    dim=settings.image_size,
    channels=settings.channels,
    dim_mults=(1, 2, 4,)
)

model.to(settings.device)
mylora.model_summary(model)

trainable layers:           231
frozen layers:                0
total params:           2020257


In [16]:
optimizer = Adam(model.parameters(), lr=1e-3)
train(model, optimizer, dataloader, sampler, settings, epochs=10)

sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.38870003819465637
Loss: 0.040315043181180954
Loss: 0.02934751659631729
Loss: 0.028124278411269188
Loss: 0.02748536504805088


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.025776633992791176
Loss: 0.022529777139425278
Loss: 0.022077929228544235
Loss: 0.022946104407310486
Loss: 0.023017099127173424


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.02249987982213497
Loss: 0.020714594051241875
Loss: 0.020471815019845963
Loss: 0.021730223670601845
Loss: 0.020951353013515472


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.02081342041492462
Loss: 0.019031308591365814
Loss: 0.01930704154074192
Loss: 0.020853281021118164
Loss: 0.0201482642441988


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.02025529555976391
Loss: 0.0182176623493433
Loss: 0.018395226448774338
Loss: 0.0201185904443264
Loss: 0.019668594002723694


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.019672637805342674
Loss: 0.01785200461745262
Loss: 0.018006399273872375
Loss: 0.019611923024058342
Loss: 0.018802886828780174


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.019213758409023285
Loss: 0.01745474711060524
Loss: 0.017739497125148773
Loss: 0.019340617582201958
Loss: 0.01843217946588993


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.01878400892019272
Loss: 0.017494192346930504
Loss: 0.01738337241113186
Loss: 0.019208043813705444
Loss: 0.018164334818720818


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.018438026309013367
Loss: 0.017084669321775436
Loss: 0.01701967976987362
Loss: 0.018673282116651535
Loss: 0.01828204281628132


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.018175700679421425
Loss: 0.016921713948249817
Loss: 0.01689371094107628
Loss: 0.018432723358273506
Loss: 0.017681581899523735


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

In [17]:
settings

{
    "results_folder": "PosixPath('results-mnist/1-baseline')",
    "image_size": 28,
    "channels": 1,
    "batch_size": 128,
    "device": "cuda",
    "checkpoint": "checkpoints/3-mnist-0:1,3:9.pt"
}

In [18]:
os.system(f"ffmpeg -f image2 -framerate 7 -i {str(settings.results_folder)}/sample-%d.png -loop -0 {str(settings.results_folder)}/sample.gif -y")

ffmpeg version 4.3 Copyright (c) 2000-2020 the FFmpeg developers
  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)
  configuration: --prefix=/opt/conda/conda-bld/ffmpeg_1597178665428/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeh --cc=/opt/conda/conda-bld/ffmpeg_1597178665428/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame
  libavutil      56. 51.100 / 56. 51.100
  libavcodec     58. 91.100 / 58. 91.100
  libavformat    58. 45.100 / 58. 45.100
  libavdevice    58. 10.100 / 58. 10.100
  libavfilter     7. 85.100 /  7. 85.100
  libavresample   4.  0.  0 /  4.  0.  0
  libsw

0

In [23]:
torch.save(model.state_dict(), settings.checkpoint)

### Train lora 

In [24]:
set_all_seeds()
model = Unet(
    dim=settings.image_size,
    channels=settings.channels,
    dim_mults=(1, 2, 4,)
)
model.load_state_dict(torch.load(settings.checkpoint))

mylora.inject_lora(
    model, 2, 0.4,
    ['LinearAttention'],
    [nn.Conv2d]
)
model.to(settings.device)

mylora.freeze_lora(model)
print()
mylora.model_summary(model)

Injected lora    28 x 2 x 384   in downs.0.2.fn.fn.to_qkv
Injected lora   128 x 2 x 28    in downs.0.2.fn.fn.0
Injected lora    28 x 2 x 384   in downs.1.2.fn.fn.to_qkv
Injected lora   128 x 2 x 28    in downs.1.2.fn.fn.0
Injected lora    56 x 2 x 384   in downs.2.2.fn.fn.to_qkv
Injected lora   128 x 2 x 56    in downs.2.2.fn.fn.0
Injected lora   112 x 2 x 384   in ups.0.2.fn.fn.to_qkv
Injected lora   128 x 2 x 112   in ups.0.2.fn.fn.0
Injected lora    56 x 2 x 384   in ups.1.2.fn.fn.to_qkv
Injected lora   128 x 2 x 56    in ups.1.2.fn.fn.0
Injected lora    28 x 2 x 384   in ups.2.2.fn.fn.to_qkv
Injected lora   128 x 2 x 28    in ups.2.2.fn.fn.0

trainable layers:            24
frozen layers:              231
total params:           2027633


In [35]:
transformed_dataset = dataset.with_transform(transforms).filter(lambda x: x['label'] == 2).remove_columns("label")
dataloader = DataLoader(transformed_dataset["train"], batch_size=settings.batch_size, shuffle=True)

Loading cached processed dataset at /home/akkirr/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332/cache-4fb5f69a4c536345.arrow
Loading cached processed dataset at /home/akkirr/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332/cache-cc87227bbbbe9a2e.arrow


In [36]:
settings.results_folder = Path("./results-mnist/2-rank=2_do=0.25")
settings.results_folder.mkdir(exist_ok=True, parents=True)

optimizer = Adam(model.parameters(), lr=1e-3)
train(model, optimizer, dataloader, sampler, settings, epochs=10)

sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
os.system(f"ffmpeg -f image2 -framerate 7 -i {str(settings.results_folder)}/sample-%d.png -loop -0 {str(settings.results_folder)}/sample.gif -y")

### Train open-source lora

In [10]:
set_all_seeds()
model = Unet(
    dim=settings.image_size,
    channels=settings.channels,
    dim_mults=(1, 2, 4,)
)
model.load_state_dict(torch.load(settings.checkpoint))

model.requires_grad_(False)
unet_lora_params, train_names = lora_diffusion.inject_trainable_lora_extended(
    model,
    target_replace_module=['LinearAttention'],
    r=1,
)
model.to(settings.device)

mylora.model_summary(model)

trainable layers:            24
frozen layers:              231
total params:           2023945


In [11]:
transformed_dataset = dataset.with_transform(transforms).filter(lambda x: x['label'] == 2).remove_columns("label")
dataloader = DataLoader(transformed_dataset["train"], batch_size=settings.batch_size, shuffle=True)

Loading cached processed dataset at /home/akkirr/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332/cache-4fb5f69a4c536345.arrow
Loading cached processed dataset at /home/akkirr/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332/cache-cc87227bbbbe9a2e.arrow


In [12]:
settings.results_folder = Path("./results-mnist/3-os-lora-default")
settings.results_folder.mkdir(exist_ok=True, parents=True)

optimizer = Adam(model.parameters(), lr=1e-3)
train(model, optimizer, dataloader, sampler, settings, epochs=10)

sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.02309069223701954


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.021512994542717934


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.02117716521024704


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.021008169278502464


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020907269790768623


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020840400829911232


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020795632153749466


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.02075948938727379


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.02072523534297943


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

Loss: 0.020693954080343246


sampling loop time step:   0%|          | 0/300 [00:00<?, ?it/s]

In [13]:
folder = str(settings.results_folder)
! /home/akkirr/.cargo/bin/gifski -o $folder/sample.gif -r 7 $folder/sample-*.png

gifski created /home/akkirr/annotated-diffusion/results-mnist/3-os-lora-default/sample.gif