In [1]:
import numpy as np
import torch
import torch.nn as nn
import math
import fastai
from PIL import Image
from diffusers import DiffusionPipeline
from diffusers.utils import pt_to_pil
from dataloader import get_imagenette_dataloader
from quantize import quantize_img, plot_imgs
from ddpm import DDPMCB
from preprocessing import clip_preprocess, conditioning_transform
from functools import partial
from fastai.vision.all import (ImageDataLoaders, Resize, TensorImage, Learner, 
                               Callback, Normalize)
from encoder import ViTImageEncoder
import fastcore.all as fc

# Target device for every model/tensor moved with `.to(device)` below. The notebook assumes a
# CUDA GPU: later cells also cast inputs to fp16 and reference "cuda" directly.
device = "cuda"

In [2]:
def method_helper(o):
    """List the public attribute names of `o` (names not starting with "_"), in `dir()` order."""
    return [name for name in dir(o) if name[0] != "_"]

In [3]:
# DeepFloyd IF stage II (IFSuperResolutionPipeline) loaded with fp16 weights and no text encoder.
# A `class_labels=None` kwarg used to be passed here; the pipeline does not accept it and only
# warned that it was being ignored, so it has been removed.
stage_2 = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0",
    text_encoder=None,
    variant="fp16",
    torch_dtype=torch.float16,
)


A mixture of fp16 and non-fp16 filenames will be loaded.
Loaded fp16 filenames:
[text_encoder/model.fp16-00001-of-00002.safetensors, text_encoder/model.fp16-00002-of-00002.safetensors, safety_checker/model.fp16.safetensors, unet/diffusion_pytorch_model.fp16.safetensors]
Loaded non-fp16 filenames:
[watermarker/diffusion_pytorch_model.safetensors
If this behavior is not expected, please check your folder structure.
Keyword arguments {'class_labels': None} are not expected by IFSuperResolutionPipeline and will be ignored.


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [4]:
# Pull the noise scheduler and the fp16 stage-II UNet out of the pipeline; the UNet goes to `device`.
scheduler, unet = stage_2.scheduler, stage_2.unet.to(device)

In [5]:
# Throwaway loader with no item transforms and batch size 1, used only to look at the raw
# (un-resized) image shape. Both `dls` and `one_batch` are reassigned by later cells.
dls = ImageDataLoaders.from_folder( "/mnt/wd/datasets/imagenette2", valid_pct=0.1, bs=1,)
one_batch = dls.one_batch()[0]
one_batch.shape

torch.Size([1, 3, 334, 500])

In [6]:
# Preprocessing callables with their extra arguments pre-bound via functools.partial.
c_preprocess = partial(clip_preprocess, stage_2=stage_2)
cond_transform = partial(conditioning_transform, encode_preprocess=None)

# Stand-alone ViT encoder whose output width matches the UNet's `encoder_hid_dim`.
# NOTE(review): neither `encoder` nor `encoder_preprocess` is used by later cells
# (CTModel builds its own ViTImageEncoder) — consider dropping them to free GPU memory.
encoder = ViTImageEncoder(7, output_dim=unet.config.encoder_hid_dim).to(device)
encoder_preprocess = encoder.feature_extractor

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
def preprocessing(x):
    """Item transform used in the training DataLoaders (`item_tfms`).

    It is called on both images and ``TensorCategory`` labels. Non-label inputs are resized
    to 224 and converted to a channel-first ``(C, H, W)`` tensor. Every input then goes
    through ``cond_transform``, is moved to CPU, and is passed through ``c_preprocess``.
    """
    if not isinstance(x, fastai.vision.core.TensorCategory):
        x = Resize(224)(x)
        # TensorImage(<PIL image>) is (H, W, C); (2, 0, 1) gives (C, H, W).
        # The previous permute(2, 1, 0) produced (C, W, H), silently transposing every image.
        x = TensorImage(x).permute(2, 0, 1)
    # NOTE(review): TensorCategory labels also reach the transforms below — confirm that
    # cond_transform / c_preprocess are meant to handle labels.
    x = cond_transform(x)
    x = x.to("cpu")
    x = c_preprocess(x)
    return x

In [None]:
# Training DataLoaders: every item goes through `preprocessing`, and 10% of the images are held
# out for validation. `seed` pins the random train/valid split so it is the same after every
# kernel restart. Training is resumed from saved checkpoints in later cells, and an unseeded
# split would move previously-trained images into the validation set.
dls = ImageDataLoaders.from_folder(
    "/mnt/wd/datasets/imagenette2",
    valid_pct=0.1,
    seed=42,
    item_tfms=[preprocessing],
    bs=4,
    num_workers=16,
)

> /tmp/ipykernel_16954/4111660116.py(9)preprocessing()
      5     x = cond_transform(x)
      6     x = x.to("cpu")
      7     x = c_preprocess(x)
      8     import pdb; pdb.set_trace()
----> 9     return x

ipdb> x.shape
torch.Size([10, 224, 224])


In [30]:
# Load a single batch once and show both its shape and the std of channel 2 of its first item.
# Only a cell's last expression is displayed, so both values are returned together as a tuple.
xb = dls.one_batch()[0]
xb.shape, xb[0, 2, ...].std()

> /home/artursil/anaconda3/envs/ai/lib/python3.12/site-packages/transformers/image_utils.py(255)infer_channel_dimension_format()
    253     elif image.shape[last_dim] in num_channels:
    254         return ChannelDimension.LAST
--> 255     raise ValueError("Unable to infer channel dimension format")
    256 
    257 

ipdb> u
> /home/artursil/anaconda3/envs/ai/lib/python3.12/site-packages/transformers/models/clip/image_processing_clip.py(320)preprocess()
    318         if input_data_format is None

In [None]:
class CTModel(nn.Module):
    """Trainable ViT image encoder that conditions a frozen stage-II UNet.

    The ViT embedding of ``images`` is repeated along dim 1 and passed to the UNet as its
    encoder hidden states.

    Args:
        base_unet: UNet to wrap. Defaults to the notebook-level ``unet`` (the original
            behaviour). It is modified in place: ``class_embedding`` is set to ``None`` and
            all of its parameters are frozen.
        vit_layers: first positional argument forwarded to ``ViTImageEncoder``
            (previously hard-coded to 7).
        seq_len: length the image embedding is expanded to along dim 1
            (previously hard-coded to 77).
    """

    def __init__(self, base_unet=None, vit_layers=7, seq_len=77):
        super().__init__()
        # The global `unet` is looked up at construction time, exactly as before.
        self.unet = unet if base_unet is None else base_unet
        self.unet.class_embedding = None
        self.vit = ViTImageEncoder(vit_layers, output_dim=self.unet.config.encoder_hid_dim).to(device)
        self.seq_len = seq_len
        # Only the ViT is trained; the UNet weights stay fixed.
        self.unet.requires_grad_(False)

    def forward(self, noisy_images, images, t):
        """Run the frozen UNet on ``noisy_images`` at timesteps ``t``, conditioned on ``images``.

        All UNet inputs are cast to fp16 (the UNet was loaded in fp16). Returns the first
        element of the UNet output.
        """
        # assumes self.vit returns (batch, 1, encoder_hid_dim) so it can be expanded — TODO confirm
        encoded = self.vit(images).expand(-1, self.seq_len, -1).half()
        return self.unet(noisy_images.half(), t.half(), encoded)[0]

In [None]:
# Wraps the global `unet` (class embedding removed, parameters frozen) with a freshly built ViT encoder.
model = CTModel()

In [None]:
# Grab one preprocessed training batch and display the shape of its input tensor.
# (The shape used to be a mid-cell expression, so it was computed but never shown.)
one_batch = dls.one_batch()
images = one_batch[0]
images.shape

In [None]:
# Without DDPM callback it won't work
# with torch.no_grad():
#     x = model(images, one_batch[0], torch.tensor([1.0]*4, dtype=torch.float16, device="cuda"))
# x

In [None]:
# MSE loss; DDPMCB receives the same UNet and scheduler the model wraps.
# `to_fp16()` switches the Learner to fastai mixed-precision training.
learn = Learner(
    dls,
    model,
    loss_func=nn.MSELoss(),
    cbs=[DDPMCB(unet, scheduler)],
).to_fp16()

In [None]:
from fastai.callback.hook import ActivationStats

# Track activation statistics of the ViT encoder: the patch projection, every sub-module of the
# selected encoder layers, then the final layernorm and the pooler.
# Layer indices to inspect; 6 is presumably the last layer if ViTImageEncoder(7, ...) means
# 7 layers — confirm.
TRACKED_VIT_LAYERS = (0, 6)


def _vit_layer_modules(layer):
    """Return the q/k/v projections, MLP denses and both layernorms of one ViT encoder layer."""
    attn = layer.attention.attention
    return [
        attn.query,
        attn.key,
        attn.value,
        layer.intermediate.dense,
        layer.output.dense,
        layer.layernorm_before,
        layer.layernorm_after,
    ]


vit_backbone = learn.model.vit.vit
layers_to_track = [
    vit_backbone.embeddings.patch_embeddings.projection,
    *(m for idx in TRACKED_VIT_LAYERS for m in _vit_layer_modules(vit_backbone.encoder.layer[idx])),
    vit_backbone.layernorm,
    # NOTE(review): the pooler is newly initialised (see the load warning); its hook only
    # records data if the encoder's forward actually uses the pooler.
    vit_backbone.pooler.dense,
]

# Attach the ActivationStats callback to the learner.
astats = ActivationStats(modules=layers_to_track)
learn.add_cb(astats)

In [None]:
# Learning-rate range test; the rate used for training is set by hand in the next cell.
learn.lr_find()

In [None]:
# 1e-3 is exactly the value of the former `10e-04` literal (which is easy to misread as 1e-4).
lr = 1e-3
learn.fit_one_cycle(1, lr)

In [None]:
# Checkpoint after the first epoch.
# NOTE(review): fastai's Learner.save/load append ".pth" themselves, so this probably writes
# "ctransfer_epoch_1.pth.pth" under learn.path/learn.model_dir. Save and load agree, so keep the
# names as-is unless existing checkpoint files are renamed too — confirm on disk.
learn.save("ctransfer_epoch_1.pth")
# learn = learn.load("ctransfer_epoch_1.pth")

In [None]:
# learn2 = learn.add_cb(astats)

In [None]:
# astats.color_dim()

In [None]:
lr = 1e-4  # exactly the value of the former `10e-05` literal
# learn.fit_one_cycle(1, lr)
# Resume from the epoch-2 checkpoint. No active cell in this notebook saves it
# (the matching save is commented out in the next cell).
learn.load("ctransfer_epoch_2.pth")

In [None]:
lr = 1e-4  # exactly the value of the former `10e-05` literal
# Four more epochs on top of the epoch-2 checkpoint loaded above -> epochs 3-6.
learn.fit_one_cycle(4, lr)
learn.save("ctransfer_epoch_3_6.pth")

In [None]:
# Three more epochs after the 3-6 run -> epochs 7-9. The checkpoint used to be named
# "ctransfer_epoch_6_8.pth", which overlapped epoch 6 of the previous save.
learn.fit_one_cycle(3, lr)
learn.save("ctransfer_epoch_7_9.pth")

In [None]:
lr = 1e-5  # exactly the value of the former `10e-06` literal
# One more epoch after epochs 7-9 -> epoch 10 (this checkpoint was previously mislabelled epoch 9).
learn.fit_one_cycle(1, lr)
learn.save("ctransfer_epoch_10.pth")