In [1]:
from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
import torch
from tqdm import tqdm
import datetime
import torchvision
import torchvision.transforms as T
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from torch.utils.data import TensorDataset
from PIL import Image
from diffusers import AutoencoderKL

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def create_model(num_artists, checkpoint_path='./detector/artist/artist_ckp/state_dict.dat.von_gogh'):
    """Build the ResNet-18 artist classifier and load its saved weights.

    The network is a torchvision ResNet-18 whose final fully connected layer is
    replaced by a ``num_artists``-way linear layer; every weight is then
    restored from ``checkpoint_path``.

    Args:
        num_artists: Output size of the replacement ``fc`` layer. It has to
            match the head stored in the checkpoint, otherwise
            ``load_state_dict`` raises.
        checkpoint_path: Location of the saved ``state_dict``. Defaults to the
            path this notebook has always used.

    Returns:
        The ResNet-18 with the checkpoint loaded, left on the CPU. The backbone
        parameters have ``requires_grad=False``; the replacement ``fc`` layer
        keeps PyTorch's default ``requires_grad=True``.
    """
    import torchvision  # local import kept so the function stays self-contained

    # No ImageNet weights are requested: the strict load_state_dict() below
    # overwrites every parameter and buffer, so downloading them was wasted work.
    model_conv = torchvision.models.resnet18()
    # Freeze the backbone. Modules created afterwards (the new fc) are trainable by default.
    for param in model_conv.parameters():
        param.requires_grad = False
    num_ftrs = model_conv.fc.in_features
    model_conv.fc = nn.Linear(num_ftrs, num_artists)
    # map_location='cpu' lets a checkpoint saved on any GPU load on any machine;
    # the caller moves the model to its target device afterwards.
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model_conv.load_state_dict(state_dict)
    return model_conv

In [3]:
from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
import torch
from tqdm import tqdm
import datetime
import torchvision
import torchvision.transforms as T
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from torch.utils.data import TensorDataset
from PIL import Image
from diffusers import AutoencoderKL

In [4]:

class RGBConverter(nn.Module):
    """Rescale an image tensor and apply the detector's resize/normalise step.

    ``forward`` min-max rescales its input to the range 0..255, drops all
    size-1 dimensions and then applies ``Resize(224)`` followed by a
    per-channel ``Normalize``.
    """

    def __init__(self):
        super(RGBConverter, self).__init__()
        # Per-channel mean / std used by the detector's normalisation.
        mean_resnet = np.array([0.485, 0.456, 0.406])
        std_resnet = np.array([0.229, 0.224, 0.225])
        self.val_transform = T.Compose([T.Resize(224), T.Normalize(mean_resnet, std_resnet)])

    def toRGB(self, RGBA, background=(255,255,255)):
        """Alpha-composite a 4-channel batch onto a solid background.

        Args:
            RGBA: Tensor indexed as (batch, channel, row, column). A 3-channel
                input is returned unchanged. For 4 channels the last one is
                treated as alpha on a 0..255 scale.
            background: Three per-channel background values.

        Returns:
            The input itself for 3 channels; otherwise a float32 tensor with 3
            channels on the same device as ``RGBA``, covering the whole batch.

        Raises:
            ValueError: If the channel dimension is neither 3 nor 4.
        """
        channels = RGBA.shape[1]
        if channels == 3:
            return RGBA
        if channels != 4:
            raise ValueError(f"expected 3 or 4 channels, got {channels}")
        colour = RGBA[:, :3].float()
        alpha = RGBA[:, 3:4].float() / 255
        # Shaped (1, 3, 1, 1) so it broadcasts over batch, rows and columns.
        backdrop = torch.tensor(
            [float(value) for value in background],
            dtype=torch.float32,
            device=RGBA.device,
        ).view(1, 3, 1, 1)
        return colour * alpha + (1 - alpha) * backdrop

    def forward(self, input):
        """Min-max rescale ``input`` to 0..255, squeeze it and apply ``val_transform``.

        The parameter keeps the name ``input`` (although it shadows a builtin)
        because it is part of the call signature.
        """
        # The extrema are taken from a detached view, so they act as constants
        # and gradients flow only through `input` itself.
        lowest = torch.min(input.detach())
        highest = torch.max(input.detach())
        span = highest - lowest
        # A constant image has zero span; dividing by it would fill the result
        # with NaN. Use 1 instead so such an image maps to all zeros.
        span = torch.where(span == 0, torch.ones_like(span), span)
        scaled = (input - lowest) / span * 255
        #scaled = self.toRGB(scaled)
        # NOTE(review): values reach Normalize on a 0..255 scale, while the
        # evaluation transform later in this notebook normalises ToTensor output;
        # confirm which range the classifier was trained on.
        return self.val_transform(scaled.squeeze())

In [5]:
import torch.optim as optim

class ArtModel(nn.Module):
    """Decode Stable Diffusion latents with the VAE and classify the image.

    Args:
        device: Device the VAE is loaded onto. Defaults to ``"cuda:1"``, the
            device this notebook uses everywhere.
        num_artists: Number of classes passed to ``create_model``.
    """

    def __init__(self, device="cuda:1", num_artists=5):
        super(ArtModel, self).__init__()
        self.rgb = RGBConverter()
        self.vae = AutoencoderKL.from_pretrained(
            "CompVis/stable-diffusion-v1-4", subfolder="vae"
        ).to(device)
        self.classifier = create_model(num_artists)
        self.classifier.eval()
        for param in self.classifier.parameters():
            param.requires_grad = False
        # The VAE is only used to map latents to pixels and is never trained
        # here. Freezing it stops autograd from tracking its weights; gradients
        # with respect to the *input* latents are unaffected.
        for param in self.vae.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the classifier output for latents ``x``.

        The latents are divided by the VAE's configured scaling factor, decoded,
        passed through ``RGBConverter`` and classified.
        """
        x = self.vae.decode(1 / self.vae.config.scaling_factor * x).sample
        x = self.rgb(x)
        # RGBConverter.forward squeezes size-1 dimensions, so the leading batch
        # axis is restored before the classifier sees the tensor.
        x = self.classifier(x.unsqueeze(0))
        return x


class ObjectDetector():
    """Wrap ``ArtModel`` to obtain the gradient of one class score w.r.t. latents.

    Args:
        device: Device the model and the inputs are moved to. Defaults to
            ``"cuda:1"``.
        target_class: Index into the classifier output whose gradient
            ``get_input_grad`` returns by default.
    """

    def __init__(self, device="cuda:1", target_class=2):
        # A torch.dtype is used instead of the legacy torch.cuda.FloatTensor
        # type: converting through that type first placed the tensor on the
        # default CUDA device before it was moved to `device`.
        self.dtype = torch.float32
        self.device = device
        self.target_class = target_class
        # transfer learning on top of ResNet (only replacing final FC layer)
        self.model = ArtModel(device=device)
        self.model.to(device)
        # self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)

    def get_input_grad(self, x, target_class=None):
        """Return d(score[target_class]) / d(x) for the latents ``x``.

        Args:
            x: Latents accepted by ``ArtModel.forward``. The tensor is detached
                and copied to ``self.device`` as float32, so the caller's tensor
                is never modified.
            target_class: Optional override for the class index; falls back to
                the value given at construction.

        Returns:
            A tensor with the same shape as ``x`` (on ``self.device``) holding
            the gradient of the selected raw classifier output (no softmax is
            applied in this class).
        """
        if target_class is None:
            target_class = self.target_class
        # enable_grad() makes this work even when called under torch.no_grad().
        with torch.enable_grad():
            x_var = x.detach().to(device=self.device, dtype=self.dtype).requires_grad_(True)
            scores = self.model(x_var)
            score = scores[0][target_class]
            # autograd.grad returns the input gradient directly and does not
            # accumulate .grad on any model parameter.
            (input_grad,) = torch.autograd.grad(score, x_var)
        return input_grad
        

In [6]:
# Settings for the fine-tuning run; later cells read these names directly.
prompt = "Alfred Sisley"   # text prompt, embedded once before training starts
modules = ".*attn2$"       # pattern passed to FineTunedModel to pick the modules to fine-tune
iterations = 200           # number of training iterations (length of the progress-bar range)
negative_guidance = 1      # not referenced by any cell visible in this notebook
lr = 0.015                 # learning rate given to the SGD optimizer
save_path = "tmp/test"     # prefix; the final checkpoint is written to save_path + "_<timestamp>.pt"
freeze_modules = []        # passed to FineTunedModel as frozen_modules

In [7]:
# Build everything the training cell needs: the diffusion model, the
# fine-tuning wrapper, the parameter list, the loss and the prompt embedding.
# kwargs = dict(locals())
# print(f"train kwargs: {kwargs}")
print("BEGIN TRAIN")  
# Bound used by the training cell when it draws a random diffusion step.
nsteps = 50

# StableDiffuser with the DDIM scheduler, moved to cuda:1 and put in train mode.
diffuser = StableDiffuser(scheduler='DDIM').to('cuda:1')
diffuser.train()


# Wrap the diffuser so that modules matching `modules` are fine-tuned;
# `freeze_modules` is forwarded as frozen_modules.
finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
# finetuner = FineTunedModel.from_checkpoint(diffuser, "models/vangogh.pt")

# Materialise the parameters once; the optimizer is built from this list in a later cell.
params = list(finetuner.parameters())
criteria = torch.nn.MSELoss()

print("Begin pbar")
# Progress bar over `iterations` steps.
pbar = tqdm(range(iterations))

# The prompt embedding is a fixed input, so it is computed once without autograd tracking.
with torch.no_grad():
    # neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
    positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)

# del diffuser.vae
# del diffuser.text_encoder
# del diffuser.tokenizer
# Drop the safety checker from this diffuser instance, then release cached GPU memory.
del diffuser.safety_checker

torch.cuda.empty_cache()

BEGIN TRAIN


`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.


=> Finetuning unet.down_blocks.0.attentions.0.transformer_blocks.0.attn2
=> Finetuning unet.down_blocks.0.attentions.1.transformer_blocks.0.attn2
=> Finetuning unet.down_blocks.1.attentions.0.transformer_blocks.0.attn2
=> Finetuning unet.down_blocks.1.attentions.1.transformer_blocks.0.attn2
=> Finetuning unet.down_blocks.2.attentions.0.transformer_blocks.0.attn2
=> Finetuning unet.down_blocks.2.attentions.1.transformer_blocks.0.attn2
=> Finetuning unet.up_blocks.1.attentions.0.transformer_blocks.0.attn2
=> Finetuning unet.up_blocks.1.attentions.1.transformer_blocks.0.attn2
=> Finetuning unet.up_blocks.1.attentions.2.transformer_blocks.0.attn2
=> Finetuning unet.up_blocks.2.attentions.0.transformer_blocks.0.attn2
=> Finetuning unet.up_blocks.2.attentions.1.transformer_blocks.0.attn2
=> Finetuning unet.up_blocks.2.attentions.2.transformer_blocks.0.attn2
=> Finetuning unet.up_blocks.3.attentions.0.transformer_blocks.0.attn2
=> Finetuning unet.up_blocks.3.attentions.1.transformer_blocks.0.attn2

  0%|          | 0/200 [00:00<?, ?it/s]

In [8]:
# del detector, optimizer
# Classifier wrapper that supplies gradients with respect to the latents.
detector = ObjectDetector()
# Plain SGD over the fine-tuned parameters collected in `params`.
optimizer = optim.SGD(params=params, lr=lr)



In [None]:
import os  # local to this cell: only needed to create the checkpoint directory

# Weight applied to the detector's input gradient when the regression target is built.
DETECTOR_GRAD_SCALE = 10

# torch.save does not create missing directories. Create the destination up
# front so a missing folder cannot discard the run after training has finished.
checkpoint_dir = os.path.dirname(save_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

diffuser.train()
# A fresh bar per execution keeps this cell re-runnable; a bar that has already
# been iterated to the end would report stale progress.
progress = tqdm(range(iterations))
for i in progress:
    optimizer.zero_grad()

    # Sampling the partially denoised latents and the reference prediction
    # needs no gradients.
    with torch.no_grad():
        # NOTE(review): the scheduler is set to 60 steps while the step index is
        # drawn from [1, nsteps - 1) with nsteps = 50 in the setup cell; confirm
        # the mismatch is intended.
        diffuser.set_scheduler_timesteps(60)

        diffuse_iter = torch.randint(1, nsteps-1, (1,)).item()

        latents = diffuser.get_initial_latents(1, 512, 1)
        # `with finetuner:` presumably activates the fine-tuned modules for the
        # duration of the block - verify against FineTunedModel.
        with finetuner:
            latents_steps, _ = diffuser.diffusion(
                latents,
                positive_text_embeddings,
                start_iteration=0,
                end_iteration=diffuse_iter,
                guidance_scale=3,
                show_progress=False,
            )

        # Alternative kept for reference: take the reference prediction outside
        # the finetuner context.
        # ref_latents = diffuser.predict_noise(diffuse_iter, latents_steps[0], positive_text_embeddings, guidance_scale=1)
        # NOTE(review): as written, ref_latents and negative_latents come from
        # the same call inside the finetuner context. Assuming predict_noise is
        # deterministic they are numerically equal, so only the detector
        # gradient separates prediction from target. Confirm this is intended.
        with finetuner:
            ref_latents = diffuser.predict_noise(diffuse_iter, latents_steps[0], positive_text_embeddings, guidance_scale=1)

    # Same prediction again, this time with autograd enabled so the loss can
    # reach the fine-tuned parameters.
    with finetuner:
        negative_latents = diffuser.predict_noise(diffuse_iter, latents_steps[0], positive_text_embeddings, guidance_scale=1)

    input_x = latents_steps[0]
    detector_grad = detector.get_input_grad(input_x)
    target = ref_latents.detach().float() + DETECTOR_GRAD_SCALE * detector_grad
    loss = criteria(negative_latents.float(), target)
    loss.backward()

    # Report on the progress bar instead of printing four lines per iteration.
    # The gradient of the first parameter is a cheap health indicator; it is
    # None if that parameter did not take part in the loss.
    first_grad = params[0].grad
    grad_norm = float("nan") if first_grad is None else torch.norm(first_grad).item()
    progress.set_postfix(loss=loss.item(), grad_norm=grad_norm)
    optimizer.step()

    # Periodic checkpointing, currently disabled:
    # if i % 10 == 0 and i != 0:
    #     now_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    #     torch.save(
    #         finetuner.state_dict(),
    #         save_path + f'_checkpoint_{i}_{now_str}.pt'
    #     )

# Final checkpoint, named with the completion timestamp.
now_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
torch.save(finetuner.state_dict(), save_path + f'_{now_str}.pt')


torch.cuda.empty_cache()

In [None]:
# Drop the training-time objects; a new diffuser is created in the next cell.
del finetuner, diffuser

In [None]:
# Fresh StableDiffuser with the DDIM scheduler on cuda:1, used for the evaluation cells below.
diffuser = StableDiffuser(scheduler='DDIM').to('cuda:1')

In [None]:
# Restore fine-tuned weights from a saved checkpoint, switch to eval mode and move to cuda:1.
# NOTE(review): this path does not match what the training cell writes
# (save_path + "_<timestamp>.pt"); confirm which checkpoint is meant to be evaluated.
testtuner = FineTunedModel.from_checkpoint(diffuser, "detector_art/test_checkpoint_10_20231130_192118.pt").eval().to("cuda:1")
# testtuner = FineTunedModel.from_checkpoint(diffuser, "models/vangogh.pt").eval().to("cuda:1")

In [None]:
# Free images left over from a previous run of the generation cell.
# On a fresh kernel `images` does not exist yet (it is first assigned further
# down), so the missing name is tolerated instead of aborting "Run All".
try:
    del images
except NameError:
    pass

In [None]:
# Return cached, currently unused GPU memory to the driver before generating images.
torch.cuda.empty_cache()

In [None]:
# Seed the global RNG; manual_seed returns the default generator, which is then
# passed to the sampler so the run is repeatable.
generator = torch.manual_seed(30)
# Generate inside the testtuner context (presumably with the fine-tuned modules
# active - verify against FineTunedModel).
# NOTE(review): the prompt spelling "Von Gohh" looks like a typo for "Van Gogh".
# It is a runtime string that changes what is generated, so confirm before editing.
with testtuner:
    images = diffuser(
        "Von Gohh Painting",
        n_steps=50,
        n_imgs=10,
        generator=generator,
        noise_level=0,
    )

In [None]:
# Display entry [1][0] of the generation result (presumably the second generated image - TODO confirm layout).
images[1][0]

In [None]:
# Display entry [9][0] of the generation result (presumably the tenth generated image - TODO confirm layout).
images[9][0]

In [None]:
from torchvision import transforms

# Classifier preprocessing: resize to 256x256, take the central 224x224 crop,
# convert to a tensor, then normalise each channel with the listed mean and std.
preprocess_steps = [
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(preprocess_steps)

In [None]:
# Standalone 5-class artist classifier used to score the generated images.
model = create_model(5)

In [None]:
# Move the classifier to cuda:1 and switch it to eval mode (the cell also echoes the module).
model.to("cuda:1").eval()

In [None]:
# Count how many of the first 10 generated images the classifier assigns to class index 4.
num_original_van = 0
# Inference only: the classifier's fc layer still has requires_grad=True, so
# without no_grad() every forward pass would build an autograd graph.
with torch.no_grad():
    for i in range(10):
        # Move straight to cuda:1; a bare .cuda() first would copy the tensor to
        # the default CUDA device and then copy it again.
        batch = transform(images[i][0]).unsqueeze(0).float().to("cuda:1")
        if torch.argmax(model(batch)).item() == 4:
            num_original_van += 1

In [None]:
# Show how many generated images were classified as class index 4.
num_original_van

In [None]:
# NOTE(review): this cell relies on `output` already existing in the kernel;
# confirm which cell assigns it, otherwise a fresh "Run All" stops here.
output

In [None]:
# Entry [0][4] of `output` (presumably the class-4 score of the first sample - TODO confirm).
# NOTE(review): relies on `output` already existing in the kernel.
output[0][4]