# **Dependencies:**

In [None]:
# Runtime dependencies for the MGIE demo.
# transformers is installed from a pinned git commit; tokenizers, gradio,
# gradio_client and the torch/torchvision/torchaudio stack are version-pinned.
# NOTE(review): sentencepiece, diffusers, datasets, accelerate and evaluate are
# unpinned, so a fresh run may resolve different versions — consider pinning.
!pip install sentencepiece
!pip install git+https://github.com/huggingface/transformers.git@cae78c46
!pip install diffusers
!pip install tokenizers==0.12.1
!pip install datasets
!pip install accelerate
!pip install evaluate
!pip install gradio==4.12.0
!pip install gradio_client==0.8.0
!pip install -i https://download.pytorch.org/whl/cu118 torch==2.0 torchvision==0.15 torchaudio==2.0

# **Checkpoints / Pretrained Embeddings:**



In [None]:
# Download the whole `tsujuifu/ml-mgie` model repo into ./_ckpt as real files (no symlinks).
# NOTE(review): the later cells read checkpoints from '/content/drive/My Drive/_ckpt',
# not from ./_ckpt — confirm the download is copied (or CKPT_DIR is pointed) there.
!python -c "import huggingface_hub; huggingface_hub.snapshot_download(repo_id='tsujuifu/ml-mgie', repo_type='model', local_dir='_ckpt', local_dir_use_symlinks=False)"

For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.
Fetching 14 files:   0% 0/14 [00:00<?, ?it/s]
LLaVA-7B-v1/config.json: 100% 834/834 [00:00<00:00, 5.33MB/s]

.gitattributes: 100% 1.52k/1.52k [00:00<00:00, 11.0MB/s]
Fetching 14 files:   7% 1/14 [00:00<00:07,  1.77it/s]
LLaVA-7B-v1/generation_config.json: 100% 132/132 [00:00<00:00, 971kB/s]

LLaVA-7B-v1/pytorch_model.bin.index.json: 100% 26.9k/26.9k [00:00<00:00, 73.0MB/s]

LLaVA-7B-v1/special_tokens_map.json: 100% 435/435 [00:00<00:00, 3.25MB/s]

LLaVA-7B-v1/tokenizer.json:   0% 0.00/1.84M [00:00<?, ?B/s][A

README.md: 100% 1.15k/1.15k [00:00<00:00, 5.95MB/s]
LLaVA-7B-v1/tokenizer.json: 100% 1.84M/1.84M [00:00<00:00, 25.4MB/s]

pytorch_model-00001-of-00002.bin:   0% 0.00/9.98G [00:00<?, ?B/s][A

pytorch_model-00002-of-00002.bin:   0% 0.00/3.51G [00:00<?, ?B/s][A[A


LLaVA-7B-v1/tokenizer_config.json: 100% 727/727 [00:00<00:00, 3.86MB/s]



LLaVA-7B-v1/add

In [None]:
# Sanity check: list what the download cell placed in ./_ckpt.
!ls _ckpt

# **Original MGIE Implementation:**

In [None]:
# Mount Google Drive so the checkpoints under CKPT_DIR are reachable (Colab only).
from google.colab import drive
drive.mount('/content/drive')

import os
from PIL import Image
import numpy as np
import torch as T
import transformers
import diffusers
import gradio as gr
import huggingface_hub

# Root folder holding 'LLaVA-7B-v1' and 'mgie_7b'; every checkpoint path below is built from it.
CKPT_DIR = '/content/drive/My Drive/_ckpt'




def crop_resize(f, sz=512):
    """Center-crop an image to a square, then resize it to ``sz`` x ``sz``.

    Args:
        f: Image-like object exposing ``size`` (width, height), ``crop(box)``
           and ``resize(size)`` (a PIL image in this notebook).
        sz: Edge length of the square output.

    Returns:
        The object returned by ``f.resize([sz, sz])`` after the optional crop.
    """
    width, height = f.size
    side = min(width, height)
    if width != height:
        # Only the longer axis gets a non-zero offset, so the crop stays centered.
        left = (width - side) // 2
        top = (height - side) // 2
        f = f.crop([left, top, left + side, top + side])
    return f.resize([sz, sz])

def remove_alter(s):
    """Reduce a decoded model reply to at most its first two sentences.

    Keeps the text after the first 'ASSISTANT:' marker, cuts it at '</s>', at the
    first case-insensitive 'alternative', and at '[IMG0]', then joins the first
    two '.'-separated pieces and makes sure the result ends with a period.

    Args:
        s: Decoded text (str).

    Returns:
        The cleaned instruction, or '' when nothing is left after trimming
        (previously that case raised IndexError on ``s[-1]``).
    """
    marker = 'ASSISTANT:'
    if marker in s: s = s[s.index(marker) + len(marker):].strip()
    if '</s>' in s: s = s[:s.index('</s>')].strip()
    lowered = s.lower()
    if 'alternative' in lowered: s = s[:lowered.index('alternative')]
    if '[IMG0]' in s: s = s[:s.index('[IMG0]')]
    s = '.'.join(part.strip() for part in s.split('.')[:2])
    # Nothing survived the trimming (e.g. the reply was only image tokens).
    if not s: return s
    if s[-1] != '.': s += '.'
    return s.strip()

# Placeholder tokens used when building the multimodal prompt.
DEFAULT_IMAGE_TOKEN = '<image>'
DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
DEFAULT_IM_START_TOKEN = '<im_start>'
DEFAULT_IM_END_TOKEN = '<im_end>'
PATH_LLAVA = f'{CKPT_DIR}/LLaVA-7B-v1'

# NOTE(review): `LlavaLlamaForCausalLM` is not imported or defined anywhere in this cell;
# on a fresh kernel this line raises NameError — confirm which import is missing.
tokenizer = transformers.AutoTokenizer.from_pretrained(PATH_LLAVA)
model = LlavaLlamaForCausalLM.from_pretrained(PATH_LLAVA, low_cpu_mem_usage=True, torch_dtype=T.float16, use_cache=True).cuda()
image_processor = transformers.CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=T.float16)

# Add the eight [IMGn] tokens and grow the embedding table BEFORE loading mllm.pt.
# `ckpt` stays bound at module level because go_mgie reads ckpt['emb'].
tokenizer.padding_side = 'left'
tokenizer.add_tokens(['[IMG0]', '[IMG1]', '[IMG2]', '[IMG3]', '[IMG4]', '[IMG5]', '[IMG6]', '[IMG7]'], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
ckpt = T.load(f'{CKPT_DIR}/mgie_7b/mllm.pt', map_location='cpu')
model.load_state_dict(ckpt, strict=False)

# Image placeholder tokens; start/end tokens are only added when the model config asks for them.
mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)

# Replace the model's vision tower with an fp16 CLIP vision model on the GPU, reloaded from
# the same name/path, and record the placeholder token ids on its config.
vision_tower = model.get_model().vision_tower[0]
vision_tower = transformers.CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=T.float16, low_cpu_mem_usage=True).cuda()
model.get_model().vision_tower[0] = vision_tower
vision_config = vision_tower.config
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
vision_config.use_im_start_end = mm_use_im_start_end
if mm_use_im_start_end: vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
# Number of <im_patch> placeholders per image: (image_size // patch_size) squared.
image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2

_ = model.eval()

# InstructPix2Pix pipeline whose UNet weights are overwritten with the MGIE unet.pt checkpoint.
pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained('timbrooks/instruct-pix2pix', torch_dtype=T.float16).to('cuda')
pipe.set_progress_bar_config(disable=True)
pipe.unet.load_state_dict(T.load(f'{CKPT_DIR}/mgie_7b/unet.pt', map_location='cpu'))
print('--init MGIE--')

def go_mgie(img, txt, seed, cfg_txt, cfg_img):
    """Gradio callback: edit `img` according to the instruction `txt`.

    Reads the module-level `ckpt`, `model`, `tokenizer`, `image_processor`, `pipe`,
    `image_token_len` and `conv_templates`.
    NOTE(review): `conv_templates` is not defined in this cell — confirm where it comes from.

    Args:
        img: Input image; passed to `Image.fromarray`, so an array is expected.
        txt: User instruction, interpolated into the prompt.
        seed: Seed for the CUDA generator (cast to int).
        cfg_txt: Forwarded to the pipeline as `guidance_scale`.
        cfg_img: Forwarded to the pipeline as `image_guidance_scale`.

    Returns:
        (res, out): the first image produced by the pipeline and the cleaned
        text decoded from the generated tokens.
    """
    EMB = ckpt['emb'].cuda()
    # Negative prompt embedding: the edit head applied to an all-zero (1, 8, 4096) input.
    with T.inference_mode(): NULL = model.edit_head(T.zeros(1, 8, 4096).half().to('cuda'), EMB)

    img, seed = crop_resize(Image.fromarray(img).convert('RGB')), int(seed)
    inp = img  # the cropped/resized image is what the diffusion pipeline receives

    img = image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0]
    txt = "what will this image be like if '%s'" % (txt)
    # Append the image placeholder: start token, image_token_len patch tokens, end token.
    txt = txt + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
    conv = conv_templates['vicuna_v1_1'].copy()
    conv.append_message(conv.roles[0], txt), conv.append_message(conv.roles[1], None)
    txt = conv.get_prompt()
    txt = tokenizer(txt)
    txt, mask = T.as_tensor(txt['input_ids']), T.as_tensor(txt['attention_mask'])

    with T.inference_mode():
        _ = model.cuda()
        out = model.generate(txt.unsqueeze(dim=0).cuda(), images=img.half().unsqueeze(dim=0).cuda(), attention_mask=mask.unsqueeze(dim=0).cuda(),
                             do_sample=False, max_new_tokens=96, num_beams=1, no_repeat_ngram_size=3,
                             return_dict_in_generate=True, output_hidden_states=True)
        # x[-1] is the last entry of each generation step's hidden-state tuple (presumably the
        # final layer — TODO confirm); the steps are concatenated along dim 1, batch item 0 kept.
        out, hid = out['sequences'][0].tolist(), T.cat([x[-1] for x in out['hidden_states']], dim=1)[0]

        # Pick the 8 hidden states fed to the edit head.
        # NOTE(review): 32003 is presumably the token id of '[IMG0]' — confirm against the tokenizer.
        if 32003 in out: p = out.index(32003) - 1
        else: p = len(hid) - 9
        p = min(p, len(hid) - 9)
        hid = hid[p:p + 8]

        out = remove_alter(tokenizer.decode(out))
        _ = model.cuda()
        emb = model.edit_head(hid.unsqueeze(dim=0), EMB)
        res = pipe(image=inp, prompt_embeds=emb, negative_prompt_embeds=NULL,
                   generator=T.Generator(device='cuda').manual_seed(seed), guidance_scale=cfg_txt, image_guidance_scale=cfg_img).images[0]

    return res, out

# Gradio front-end: two image panes, instruction / expressive-instruction text boxes,
# seed and guidance controls, and a submit button wired to go_mgie.
with gr.Blocks() as app:
    gr.Markdown(
        """
        # MagiX: Edit Personalized Images using Gen AI by Ateeb Taser
        """
    )
    with gr.Row():
        inp = gr.Image(height=384, width=384, label='Input Image', interactive=True)
        res = gr.Image(height=384, width=384, label='Goal Image', interactive=True)
    with gr.Row():
        txt = gr.Textbox(label='Instruction', interactive=True)
        out = gr.Textbox(label='Expressive Instruction', interactive=False)
    with gr.Row():
        seed = gr.Number(value=13331, label='Seed', interactive=True)
        cfg_txt = gr.Number(value=7.5, label='Text CFG', interactive=True)
        cfg_img = gr.Number(value=1.5, label='Image CFG', interactive=True)
    with gr.Row():
        btn_sub = gr.Button('Submit')
    # Inputs follow go_mgie's parameter order; outputs follow its (image, text) return order.
    btn_sub.click(fn=go_mgie, inputs=[inp, txt, seed, cfg_txt, cfg_img], outputs=[res, out])

app.launch()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/4.52k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.71G [00:00<?, ?B/s]

Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.2.self_attn.out_proj.bias', 'text_model.encoder.layers.11.mlp.fc2.weight', 'text_model.encoder.layers.3.layer_norm1.bias', 'text_model.encoder.layers.10.mlp.fc2.weight', 'text_model.final_layer_norm.weight', 'text_model.encoder.layers.4.layer_norm2.weight', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.4.mlp.fc2.weight', 'text_model.encoder.layers.4.self_attn.v_proj.bias', 'text_model.encoder.layers.8.layer_norm1.bias', 'text_model.encoder.layers.10.self_attn.out_proj.bias', 'text_model.encoder.layers.1.mlp.fc2.weight', 'text_model.encoder.layers.6.self_attn.out_proj.bias', 'text_model.encoder.layers.5.self_attn.v_proj.weight', 'text_model.encoder.layers.1.self_attn.k_proj.bias', 'text_model.encoder.layers.11.mlp.fc2.bias', 'text_model.encoder.layers.4.self_attn.v_proj.weight', 'text_model.encoder.layers.1

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of LlavaLlamaForCausalLM were not initialized from the model checkpoint at /content/drive/My Drive/_ckpt/LLaVA-7B-v1 and are newly initialized: ['edit_head.mapper.encoder.layers.0.self_attn.in_proj_bias', 'edit_head.mapper.decoder.layers.0.self_attn.in_proj_bias', 'edit_head.mapper.encoder.layers.1.self_attn.out_proj.bias', 'edit_head.mapper.encoder.layers.0.linear1.bias', 'edit_head.mapper.encoder.layers.2.self_attn.in_proj_bias', 'edit_head.mapper.encoder.layers.2.norm1.weight', 'edit_head.mapper.decoder.layers.2.multihead_attn.out_proj.bias', 'edit_head.mapper.decoder.layers.1.self_attn.out_proj.weight', 'edit_head.mapper.encoder.norm.bias', 'edit_head.mapper.encoder.layers.1.linear1.bias', 'edit_head.mapper.decoder.layers.0.linear1.bias', 'edit_head.mapper.encoder.layers.0.norm1.weight', 'edit_head.mapper.decoder.norm.bias', 'edit_head.mapper.encoder.layers.3.self_attn.in_proj_bias', 'edit_head.mapper.encoder.layers.1.norm2.bias', 'edit_head.mapper.decoder.layers.3.lin

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.2.self_attn.out_proj.bias', 'text_model.encoder.layers.11.mlp.fc2.weight', 'text_model.encoder.layers.3.layer_norm1.bias', 'text_model.encoder.layers.10.mlp.fc2.weight', 'text_model.final_layer_norm.weight', 'text_model.encoder.layers.4.layer_norm2.weight', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.4.mlp.fc2.weight', 'text_model.encoder.layers.4.self_attn.v_proj.bias', 'text_model.encoder.layers.8.layer_norm1.bias', 'text_model.encoder.layers.10.self_attn.out_proj.bias', 'text_model.encoder.layers.1.mlp.fc2.weight', 'text_model.encoder.layers.6.self_attn.out_proj.bias', 'text_model.encoder.layers.5.self_attn.v_proj.weight', 'text_model.encoder.layers.1.self_attn.k_proj.bias', 'text_model.encoder.layers.11.mlp.fc2.bias', 'text_model.encoder.layers.4.self_attn.v_proj.weight', 'text_model.encoder.layers.1

model_index.json:   0%|          | 0.00/616 [00:00<?, ?B/s]

Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

(…)ature_extractor/preprocessor_config.json:   0%|          | 0.00/518 [00:00<?, ?B/s]

text_encoder/config.json:   0%|          | 0.00/617 [00:00<?, ?B/s]

tokenizer/special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

scheduler/scheduler_config.json:   0%|          | 0.00/569 [00:00<?, ?B/s]

tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

safety_checker/config.json:   0%|          | 0.00/4.91k [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

unet/config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

tokenizer/tokenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

vae/config.json:   0%|          | 0.00/553 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


--init MGIE--
Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
IMPORTANT: You are using gradio version 4.12.0, however version 4.29.0 is available, please upgrade.
--------
Running on public URL: https://c4830bfd21daa749a0.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




# **ExtendedMGIE Implementation:**

In [None]:
# Install necessary packages
# NOTE(review): this is the same install list as the "Dependencies" cell at the top of the
# notebook; sentencepiece, diffusers, datasets, accelerate and evaluate are unpinned.
!pip install sentencepiece
!pip install git+https://github.com/huggingface/transformers.git@cae78c46
!pip install diffusers
!pip install tokenizers==0.12.1
!pip install datasets
!pip install accelerate
!pip install evaluate
!pip install gradio==4.12.0
!pip install gradio_client==0.8.0
!pip install -i https://download.pytorch.org/whl/cu118 torch==2.0 torchvision==0.15 torchaudio==2.0

# Mount Google Drive (Colab only) so the checkpoints under CKPT_DIR are reachable.
from google.colab import drive
drive.mount('/content/drive')

import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import transformers
import diffusers
import gradio as gr
import huggingface_hub

# Root folder holding 'LLaVA-7B-v1' and 'mgie_7b'; every checkpoint path below is built from it.
CKPT_DIR = '/content/drive/My Drive/_ckpt'

# Add the necessary classes and functions
class LlavaConfig(transformers.LlamaConfig):
    """LlamaConfig subclass registered under model type "llava".

    The values below are class-level defaults read by the modules defined in
    this notebook (ProgressiveBlend, InstructionTransformer, CrossAttentionMasking).
    """

    model_type = "llava"

    # Language-model dimensions.
    hidden_size = 4096
    intermediate_size = 11008
    num_hidden_layers = 32
    num_attention_heads = 32
    hidden_dropout_prob = 0.1

    # Name of the CLIP checkpoint used for the image processor.
    mm_vision_tower = "openai/clip-vit-large-patch14"

    # Settings for the blend / masking modules.
    num_pfb_blocks = 4
    mllm_dim = 4096
    latent_dim = 4096

class LlavaLlamaModel(transformers.LlamaModel):
    """LlamaModel bound to LlavaConfig; currently adds no behaviour of its own.

    NOTE(review): no `vision_tower` attribute is created here, although the setup
    code later reads `model.get_model().vision_tower[0]` — confirm where it is
    meant to come from.
    """

    config_class = LlavaConfig

    def __init__(self, config):
        # Placeholder constructor: no extra submodules are registered.
        super(LlavaLlamaModel, self).__init__(config)

    def forward(self, *inputs, **options):
        """Delegate unchanged to the parent class's forward."""
        parent_forward = super(LlavaLlamaModel, self).forward
        return parent_forward(*inputs, **options)

class LlavaLlamaForCausalLM(transformers.LlamaForCausalLM):
    """LlamaForCausalLM whose backbone is a LlavaLlamaModel.

    Adds the `get_model()` accessor that the setup code in this notebook calls
    (`model.get_model().vision_tower[0]`); without it that call fails with
    AttributeError because no class in this notebook defines `get_model`.
    """

    config_class = LlavaConfig

    def __init__(self, config):
        super().__init__(config)
        # Swap the backbone built by the parent constructor for the Llava variant.
        self.model = LlavaLlamaModel(config)

    def get_model(self):
        """Return the backbone LlavaLlamaModel."""
        return self.model

    def forward(self, *args, **kwargs):
        """Delegate unchanged to LlamaForCausalLM.forward."""
        return super().forward(*args, **kwargs)

# Make the "llava" model type resolvable through the Auto* factories:
# AutoConfig maps "llava" -> LlavaConfig, AutoModelForCausalLM maps LlavaConfig -> LlavaLlamaForCausalLM.
transformers.AutoConfig.register("llava", LlavaConfig)
transformers.AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)

# Add new modules
class ProgressiveBlend(nn.Module):
    """Learned scalar interpolation between latent features and projected MLLM features.

    Args:
        mllm_dim: Input width of the projection applied to the MLLM features.
        latent_dim: Output width of that projection.
    """

    def __init__(self, mllm_dim, latent_dim):
        super(ProgressiveBlend, self).__init__()
        # Scalar mixing weight; it starts at 0, so initially forward() returns latent_feat.
        self.blend = nn.Parameter(torch.tensor(0.0))
        self.mllm_proj = nn.Linear(mllm_dim, latent_dim)

    def forward(self, mllm_feat, latent_feat):
        """Return (1 - blend) * latent_feat + blend * mllm_proj(mllm_feat)."""
        projected = self.mllm_proj(mllm_feat)
        keep = 1 - self.blend
        return keep * latent_feat + self.blend * projected

class InstructionTransformer(nn.Module):
    """Encoder-decoder transformer over MLLM hidden states, projected to ``latent_dim``.

    Args:
        config: Object providing ``hidden_size``, ``num_attention_heads``,
            ``num_hidden_layers``, ``intermediate_size``, ``hidden_dropout_prob``
            and ``latent_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.transformer = nn.Transformer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            num_encoder_layers=config.num_hidden_layers,
            dim_feedforward=config.intermediate_size,
            dropout=config.hidden_dropout_prob,
            # Callers in this notebook pass (batch, seq, hidden) tensors.
            batch_first=True,
        )
        self.proj = nn.Linear(config.hidden_size, config.latent_dim)

    def forward(self, mllm_outputs):
        """Map (batch, seq, hidden_size) hidden states to (batch, seq, latent_dim).

        ``nn.Transformer.forward`` requires both a source and a target sequence;
        the previous one-argument call raised TypeError. The same sequence is
        used for both so the output keeps the input's sequence length.
        """
        instruction_embeds = self.transformer(mllm_outputs, mllm_outputs)
        return self.proj(instruction_embeds)

class CrossAttentionMasking(nn.Module):
    """Gate ``latents`` with a sigmoid mask predicted from ``mllm_embeds`` by a 1x1 conv.

    Args:
        config: Object providing ``latent_dim`` (input channels of the conv).
    """

    def __init__(self, config):
        super().__init__()
        self.conv = nn.Conv2d(config.latent_dim, 1, kernel_size=1)

    def forward(self, mllm_embeds, latents):
        """Return ``latents * sigmoid(conv(mllm_embeds))``.

        Accepts the channel-first layouts ``nn.Conv2d`` takes natively, plus the
        token layout ``(batch, seq, latent_dim)`` that the callers in this
        notebook pass. The token layout previously failed inside the conv with a
        channel-mismatch error; it is now reshaped so the same 1x1 conv yields
        one gate value per token.
        """
        channels = self.conv.in_channels
        token_layout = (
            mllm_embeds.dim() == 3
            and mllm_embeds.shape[-1] == channels
            # A leading dim equal to `channels` keeps the original unbatched (C, H, W) meaning.
            and mllm_embeds.shape[0] != channels
        )
        if token_layout:
            grid = mllm_embeds.transpose(1, 2).unsqueeze(-1)  # (B, D, S, 1)
            mask = torch.sigmoid(self.conv(grid))             # (B, 1, S, 1)
            mask = mask.squeeze(-1).transpose(1, 2)           # (B, S, 1)
        else:
            mask = torch.sigmoid(self.conv(mllm_embeds))
        return latents * mask

# Modify LlavaPFBLlamaForCausalLM
class LlavaPFBLlamaForCausalLM(LlavaLlamaForCausalLM):
    """LlavaLlamaForCausalLM extended with blend blocks, an instruction transformer and a mask module.

    NOTE(review): `forward` reads `self.unet`, which is not assigned in this class or in
    the parent classes defined in this notebook — confirm where it should be attached.
    NOTE(review): `go_mgie` calls `model.edit_head`, which is not defined here either.
    """
    def __init__(self, config):
        super().__init__(config)
        # A second LlavaLlamaModel, in addition to `self.model` created by the parent constructor.
        self.mllm = LlavaLlamaModel(config)
        # config.num_pfb_blocks independent ProgressiveBlend modules, applied in sequence by forward().
        self.pfb_blocks = nn.ModuleList([
            ProgressiveBlend(config.mllm_dim, config.latent_dim)
            for _ in range(config.num_pfb_blocks)
        ])
        self.instruction_transformer = InstructionTransformer(config)
        self.cross_attn_mask = CrossAttentionMasking(config)

    def forward(self, input_ids, attention_mask, images=None):
        # NOTE(review): LlavaLlamaModel.forward forwards its kwargs to the parent forward
        # unchanged — confirm that the `images` keyword is accepted there.
        mllm_outputs = self.mllm(input_ids, attention_mask, images=images)
        instruction_embeds = self.instruction_transformer(mllm_outputs.last_hidden_state)

        latents = self.unet(mllm_outputs.last_hidden_state)
        # Each block mixes the instruction embeddings into the running latents.
        for pfb_block in self.pfb_blocks:
            latents = pfb_block(instruction_embeds, latents)
        latents = self.cross_attn_mask(instruction_embeds, latents)

        outputs = self.unet(latents)
        return outputs

# Load model and tokenizer
PATH_LLAVA = f'{CKPT_DIR}/LLaVA-7B-v1'
tokenizer = transformers.AutoTokenizer.from_pretrained(PATH_LLAVA)
config = LlavaConfig.from_pretrained(PATH_LLAVA)
# NOTE(review): the model is built from the config only, so apart from what mllm.pt
# provides below its weights are freshly initialised — the original cell used
# `from_pretrained(PATH_LLAVA, ...)`; confirm this is intended.
model = LlavaPFBLlamaForCausalLM(config).half().cuda()
image_processor = transformers.CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)

# Prepare model
tokenizer.padding_side = 'left'
tokenizer.add_tokens(['[IMG0]', '[IMG1]', '[IMG2]', '[IMG3]', '[IMG4]', '[IMG5]', '[IMG6]', '[IMG7]'], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))

# Load the MGIE weights only after the embedding table has been resized, in the same
# order as the original implementation cell, and keep the state dict bound to `ckpt`
# because go_mgie reads ckpt['emb'].
ckpt = torch.load(f'{CKPT_DIR}/mgie_7b/mllm.pt', map_location='cpu')
model.load_state_dict(ckpt, strict=False)

DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
DEFAULT_IM_START_TOKEN = '<im_start>'
DEFAULT_IM_END_TOKEN = '<im_end>'

# Start/end tokens are only added when the model config asks for them.
mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)

# Replace the model's vision tower with an fp16 CLIP vision model on the GPU and
# record the placeholder token ids on its config.
# NOTE(review): LlavaLlamaModel in this notebook does not create `vision_tower` — confirm.
vision_tower = model.get_model().vision_tower[0]
vision_tower = transformers.CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda()
model.get_model().vision_tower[0] = vision_tower
vision_config = vision_tower.config
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
vision_config.use_im_start_end = mm_use_im_start_end
if mm_use_im_start_end: vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
# Number of <im_patch> placeholders per image: (image_size // patch_size) squared.
image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2

_ = model.eval()

# Load diffusion model; its UNet weights are overwritten with the MGIE unet.pt checkpoint.
pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained('timbrooks/instruct-pix2pix', torch_dtype=torch.float16).to('cuda')
pipe.set_progress_bar_config(disable=True)
pipe.unet.load_state_dict(torch.load(f'{CKPT_DIR}/mgie_7b/unet.pt', map_location='cpu'))
print('--init MGIE--')

def crop_resize(f, sz=512):
    """Center-crop an image to a square along its longer axis and resize to ``sz`` x ``sz``.

    Args:
        f: Image-like object exposing ``size`` (width, height), ``crop(box)``
           and ``resize(size)`` (a PIL image in this notebook).
        sz: Edge length of the square output.

    Returns:
        The object returned by ``f.resize([sz, sz])`` after the optional crop.
    """
    w, h = f.size
    margin = abs(w - h) // 2  # pixels trimmed before the kept window on the longer axis
    if w > h:
        box = [margin, 0, margin + h, h]
    elif h > w:
        box = [0, margin, w, margin + w]
    else:
        box = None  # already square: no crop call at all
    if box is not None:
        f = f.crop(box)
    return f.resize([sz, sz])

def remove_alter(s):
    """Reduce a decoded model reply to at most its first two sentences.

    Keeps the text after the first 'ASSISTANT:' marker, cuts it at '</s>', at the
    first case-insensitive 'alternative', and at '[IMG0]', then joins the first
    two '.'-separated pieces and makes sure the result ends with a period.

    Args:
        s: Decoded text (str).

    Returns:
        The cleaned instruction, or '' when nothing is left after trimming
        (previously that case raised IndexError on ``s[-1]``).
    """
    marker = 'ASSISTANT:'
    if marker in s: s = s[s.index(marker) + len(marker):].strip()
    if '</s>' in s: s = s[:s.index('</s>')].strip()
    lowered = s.lower()
    if 'alternative' in lowered: s = s[:lowered.index('alternative')]
    if '[IMG0]' in s: s = s[:s.index('[IMG0]')]
    s = '.'.join(part.strip() for part in s.split('.')[:2])
    # Nothing survived the trimming (e.g. the reply was only image tokens).
    if not s: return s
    if s[-1] != '.': s += '.'
    return s.strip()

class Conversation:
    """Prompt template plus message history for a two-role chat.

    Attributes:
        system: Text placed at the start of every prompt.
        roles: Pair of role names (first speaker, second speaker).
        messages: List of [role, message] entries; a falsy message marks the
            turn the model is expected to complete.
        offset: Number of leading messages skipped by ``to_gradio_chatbot``.
        sep_style: 'single' (one separator) or 'two' (alternating separators).
        sep, sep2: Separators; ``sep2`` falls back to ``sep`` when falsy.
    """

    def __init__(self, system="", roles=(), messages=(), offset=0, sep_style=None, sep=" ", sep2="</s>"):
        self.system, self.roles = system, roles
        self.messages = list(messages)
        self.offset = offset
        self.sep_style = sep_style
        self.sep = sep
        self.sep2 = sep2 if sep2 else self.sep

    def get_prompt(self):
        """Render the system text and all messages into one prompt string.

        Raises:
            ValueError: If ``sep_style`` is neither 'single' nor 'two'.
        """
        if self.sep_style == 'single':
            separators = [self.sep]
        elif self.sep_style == 'two':
            separators = [self.sep, self.sep2]
        else:
            raise ValueError(f"Unsupported style: {self.sep_style}")

        pieces = [self.system + separators[0]]
        for turn, (role, message) in enumerate(self.messages):
            if message:
                # With two separators they alternate per turn; with one it repeats.
                pieces.append(role + ": " + message + separators[turn % len(separators)])
            else:
                pieces.append(role + ":")
        return "".join(pieces)

    def append_message(self, role, message):
        """Append a [role, message] entry to the history."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Pair up the messages after ``offset`` as [first_speaker_msg, second_speaker_msg]."""
        pairs = []
        for position, (_, text) in enumerate(self.messages[self.offset:]):
            if position % 2:
                pairs[-1][-1] = text
            else:
                pairs.append([text, None])
        return pairs

    def copy(self):
        """Return a new Conversation whose message entries are copied one level deep."""
        return Conversation(
            self.system,
            self.roles,
            [entry[:] for entry in self.messages],
            self.offset,
            self.sep_style,
            self.sep,
            self.sep2,
        )

    def dict(self):
        """Return the conversation as a plain dict (``sep_style`` is not included)."""
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": [[role, text] for role, text in self.messages],
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }

# 'single'-style template using "###" as separator, pre-seeded with one example
# Human/Assistant exchange. offset=2 makes to_gradio_chatbot() skip those two seed messages.
conv_v1_2 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style='single',
    sep="###",
)

# 'two'-style template with USER/ASSISTANT roles: turns alternate between the
# separators " " and "</s>". This is the template go_mgie copies for every request.
conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style='two',
    sep=" ",
    sep2="</s>",
)

# Lookup table from template name to Conversation prototype; callers should .copy() before mutating.
conv_templates = {
    'v1': conv_v1_2,
    'vicuna_v1_1': conv_vicuna_v1_1,
}

# Modify go_mgie function
def go_mgie(img, txt, seed, cfg_txt, cfg_img):
    """Gradio callback: edit `img` according to the instruction `txt` (extended variant).

    Reads the module-level `ckpt`, `model`, `tokenizer`, `image_processor`, `pipe`,
    `image_token_len` and `conv_templates`; `ckpt` must be bound before the first call.
    NOTE(review): `model.edit_head` is used for the negative embedding, but the
    LlavaPFBLlamaForCausalLM class in this notebook does not define `edit_head` — confirm.

    Args:
        img: Input image; passed to `Image.fromarray`, so an array is expected.
        txt: User instruction, interpolated into the prompt.
        seed: Seed for the CUDA generator (cast to int).
        cfg_txt: Forwarded to the pipeline as `guidance_scale`.
        cfg_img: Forwarded to the pipeline as `image_guidance_scale`.

    Returns:
        (res, out): the first image produced by the pipeline and the cleaned
        text decoded from the generated tokens.
    """
    EMB = ckpt['emb'].cuda()
    # Negative prompt embedding: the edit head applied to an all-zero (1, 8, 4096) input.
    with torch.inference_mode(): NULL = model.edit_head(torch.zeros(1, 8, 4096).half().to('cuda'), EMB)

    img, seed = crop_resize(Image.fromarray(img).convert('RGB')), int(seed)
    inp = img  # the cropped/resized image is what the diffusion pipeline receives

    img = image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0]
    txt = "what will this image be like if '%s'" % (txt)
    # Append the image placeholder: start token, image_token_len patch tokens, end token.
    txt = txt + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
    conv = conv_templates['vicuna_v1_1'].copy()
    conv.append_message(conv.roles[0], txt), conv.append_message(conv.roles[1], None)
    txt = conv.get_prompt()
    txt = tokenizer(txt)
    txt, mask = torch.as_tensor(txt['input_ids']), torch.as_tensor(txt['attention_mask'])

    with torch.inference_mode():
        _ = model.cuda()
        out = model.generate(txt.unsqueeze(dim=0).cuda(), images=img.half().unsqueeze(dim=0).cuda(), attention_mask=mask.unsqueeze(dim=0).cuda(),
                             do_sample=False, max_new_tokens=96, num_beams=1, no_repeat_ngram_size=3,
                             return_dict_in_generate=True, output_hidden_states=True)
        # x[-1] is the last entry of each generation step's hidden-state tuple (presumably the
        # final layer — TODO confirm); the steps are concatenated along dim 1, batch item 0 kept.
        out, hid = out['sequences'][0].tolist(), torch.cat([x[-1] for x in out['hidden_states']], dim=1)[0]

        # Pick the 8 hidden states used to build the prompt embedding.
        # NOTE(review): 32003 is presumably the token id of '[IMG0]' — confirm against the tokenizer.
        if 32003 in out: p = out.index(32003) - 1
        else: p = len(hid) - 9
        p = min(p, len(hid) - 9)
        hid = hid[p:p + 8]

        out = remove_alter(tokenizer.decode(out))
        _ = model.cuda()
        # Extended path: instruction transformer, then every blend block, then the mask module.
        # Each stage receives the running embedding as both of its arguments.
        emb = model.instruction_transformer(hid.unsqueeze(dim=0))
        for pfb_block in model.pfb_blocks:
            emb = pfb_block(emb, emb)
        emb = model.cross_attn_mask(emb, emb)
        res = pipe(image=inp, prompt_embeds=emb, negative_prompt_embeds=NULL,
                   generator=torch.Generator(device='cuda').manual_seed(seed), guidance_scale=cfg_txt, image_guidance_scale=cfg_img).images[0]

    return res, out

# Set up Gradio interface
# Layout: two image panes, instruction / expressive-instruction text boxes,
# seed and guidance controls, and a submit button wired to go_mgie.
with gr.Blocks() as app:
    gr.Markdown(
        """
        # MagiX: Edit Personalized Images using Gen AI by Ateeb Taser
        """
    )
    with gr.Row():
        inp = gr.Image(height=384, width=384, label='Input Image', interactive=True)
        res = gr.Image(height=384, width=384, label='Goal Image', interactive=True)
    with gr.Row():
        txt = gr.Textbox(label='Instruction', interactive=True)
        out = gr.Textbox(label='Expressive Instruction', interactive=False)
    with gr.Row():
        seed = gr.Number(value=13331, label='Seed', interactive=True)
        cfg_txt = gr.Number(value=7.5, label='Text CFG', interactive=True)
        cfg_img = gr.Number(value=1.5, label='Image CFG', interactive=True)
    with gr.Row():
        btn_sub = gr.Button('Submit')
    # Inputs follow go_mgie's parameter order; outputs follow its (image, text) return order.
    btn_sub.click(fn=go_mgie, inputs=[inp, txt, seed, cfg_txt, cfg_img], outputs=[res, out])

# Start the Gradio server for the interface built above.
app.launch()

Collecting git+https://github.com/huggingface/transformers.git@cae78c46
  Cloning https://github.com/huggingface/transformers.git (to revision cae78c46) to /tmp/pip-req-build-09kko35c
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-09kko35c
[0m  Running command git checkout -q cae78c46
  Resolved https://github.com/huggingface/transformers.git to commit cae78c46
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.28.0.dev0)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m65.2 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: transformers
  Building wheel for transformers (pypr