In [1]:
import torch

# Choose the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    backend_name = "cuda"
elif torch.backends.mps.is_available():
    backend_name = "mps"
else:
    backend_name = "cpu"

# Every later cell moves tensors and the model to this device.
device = torch.device(backend_name)

# Report the choice.
if backend_name == "cuda":
    print("There are %d GPU(s) available." % torch.cuda.device_count())

    print("We will use the GPU:", torch.cuda.get_device_name(0))
elif backend_name == "mps":
    print("Using mps backend")
else:
    print("No GPU available, using the CPU instead.")

There are 1 GPU(s) available.
We will use the GPU: NVIDIA A100-SXM4-80GB


In [2]:
from transformers import AutoTokenizer, AutoProcessor
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import requests
from PIL import Image
import numpy as np
from io import BytesIO
from diffusers import (
    KandinskyV22Pipeline,
    KandinskyV22PriorEmb2EmbPipeline,
    KandinskyV22PriorPipeline,
)
from diffusers.utils import load_image
from torchvision.transforms import ToPILImage

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection, BertModel
import torch.nn as nn
from transformers import AdamW
import os


class T2IModel(nn.Module):
    """Predict an image embedding from an input image plus a text instruction.

    The text and image encoders are loaded from the
    ``kandinsky-community/kandinsky-2-2-prior`` checkpoint. Their features are
    concatenated and passed through one linear layer (``fc``), which is the only
    part handed to the optimizer built by :meth:`initialize_optimizer`. The
    Kandinsky 2.2 decoder pipeline (``pipe``) is used by :meth:`visualization`
    to render a predicted embedding as an image.
    """

    def __init__(self):
        super().__init__()
        self.text_model = CLIPTextModelWithProjection.from_pretrained(
            "kandinsky-community/kandinsky-2-2-prior", subfolder="text_encoder"
        )
        self.vision_model = CLIPVisionModelWithProjection.from_pretrained(
            "kandinsky-community/kandinsky-2-2-prior", subfolder="image_encoder"
        )
        # `forward` concatenates [vision_embeds, text_embeds] along the feature axis, so the
        # input width is the sum of both feature widths; the output width is fixed at 1280.
        self.fc = nn.Linear(self.text_model.config.hidden_size + self.vision_model.config.projection_dim, 1280)

        # Kandinsky V2.2 decoder pipeline, loaded in half precision.
        self.pipe = KandinskyV22Pipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
        )

    def initialize_optimizer(self):
        """Return an AdamW optimizer (lr=1e-4) over the parameters of ``fc`` only."""
        return AdamW(list(self.fc.parameters()), lr=1e-4)

    def forward(self, input_imgs, input_txt, attention_mask=None):
        """Encode the image and the instruction and project them to one embedding.

        Args:
            input_imgs: Input for the vision encoder (the notebook passes the
                processor's ``pixel_values``).
            input_txt: Token ids of the instruction.
            attention_mask: Optional attention mask forwarded to the text encoder.

        Returns:
            The output of ``fc`` applied to the concatenated vision and text features.
        """
        text_outputs = self.text_model(input_txt, attention_mask=attention_mask)
        # Text feature = hidden state at sequence position 0.
        # NOTE(review): this was described as the "[CLS] token", but the encoder is a CLIP
        # text model; confirm that position 0 actually carries the instruction content
        # (CLIP text encoders normally pool at the end-of-text token). Changing the pooling
        # would invalidate checkpoints trained with the current layout.
        text_embeds = text_outputs.last_hidden_state[:, 0, :]

        vision_embeds = self.vision_model(input_imgs).image_embeds

        # Order matters: vision features first, then text features.
        combined_embeds = torch.cat((vision_embeds, text_embeds), dim=1)
        return self.fc(combined_embeds)

    def output_embedding(self, target_images):
        """Return the vision encoder's ``image_embeds`` for ``target_images``."""
        return self.vision_model(target_images).image_embeds

    def custom_loss(self, output_embeddings, target_embeddings):
        """Return the mean-squared error between predicted and target embeddings."""
        return nn.functional.mse_loss(output_embeddings, target_embeddings)

    def save_model(self, output_dir="../model_save/", filename="model_checkpoint.pt"):
        """Save this module's ``state_dict`` to ``filename`` inside ``output_dir``.

        Args:
            output_dir: Destination directory; created if it does not exist.
            filename: Name of the checkpoint file.
        """
        os.makedirs(output_dir, exist_ok=True)
        # os.path.join works whether or not `output_dir` ends with a separator.
        file_path = os.path.join(output_dir, filename)
        print("Saving model to %s" % file_path)

        # Serialise this instance, not whatever a notebook-level `model` global points to.
        torch.save(self.state_dict(), file_path)

    def get_cos(self, input1, input2):
        """Return the cosine similarity along dim 1, summed and divided by the first-dim length."""
        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        similarity = cos(input1, input2)
        return torch.sum(similarity) / len(similarity)

    def metrics(self, input1, input2):
        """Return the evaluation metrics as a list; currently only the average cosine similarity."""
        return [self.get_cos(input1, input2)]

    def visualization(self, input_img, instruction, instruction_attention_mask, filename, negative_instruction=None, negative_instruction_attention_mask=None, height=768, width=768, num_inference_steps=100):
        """Render the predicted embedding with the decoder pipeline and save the image.

        Args:
            input_img: Image input for :meth:`forward`.
            instruction: Token ids of the instruction.
            instruction_attention_mask: Attention mask for ``instruction``.
            filename: Path the first generated image is saved to.
            negative_instruction: Optional token ids used to build the negative embedding.
            negative_instruction_attention_mask: Attention mask for ``negative_instruction``.
                When either negative argument is ``None`` a zero embedding is used instead.
            height: Height passed to the pipeline.
            width: Width passed to the pipeline.
            num_inference_steps: Number of inference steps passed to the pipeline.
        """
        # Inference only: no autograd bookkeeping is needed for the encoders or `fc`.
        with torch.no_grad():
            output_embeddings = self.forward(input_img, instruction, attention_mask=instruction_attention_mask)

            if negative_instruction is not None and negative_instruction_attention_mask is not None:
                neg_image_embed = self.forward(input_img, negative_instruction, attention_mask=negative_instruction_attention_mask)
            else:
                # No negative instruction given: fall back to an all-zero embedding.
                neg_image_embed = torch.zeros_like(output_embeddings)

        # `pipe` is a plain attribute in this class, so it is moved explicitly here; follow
        # the device of the embeddings instead of relying on a notebook-level global.
        self.pipe.to(output_embeddings.device)
        image = self.pipe(
            image_embeds=output_embeddings,
            negative_image_embeds=neg_image_embed,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
        ).images

        # Only the first generated image is written out.
        image[0].save(filename)

# Build the model; the constructor loads the pretrained text/image encoders and the decoder
# pipeline. Then move the module to the device selected at the top of the notebook.
model = T2IModel()
# As the last expression of the cell, this also makes the notebook print the module repr.
model.to(device=device)

Loading pipeline components...: 100%|█████████████| 3/3 [00:01<00:00,  2.34it/s]


T2IModel(
  (text_model): CLIPTextModelWithProjection(
    (text_model): CLIPTextTransformer(
      (embeddings): CLIPTextEmbeddings(
        (token_embedding): Embedding(49408, 1280)
        (position_embedding): Embedding(77, 1280)
      )
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-31): 32 x CLIPEncoderLayer(
            (self_attn): CLIPAttention(
              (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
              (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
              (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
              (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
            )
            (layer_norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): GELUActivation()
              (fc1): Linear(in_features=1280, out_features=5120, bias=True)
              (fc2): Linear(in_features=5120, out_f

In [4]:
import torch

def load_model_from_checkpoint(model, checkpoint_path, device='cuda'):
    """
    Restore saved weights into an existing model and place it on a device.

    Parameters:
    - model (torch.nn.Module): Instantiated architecture that receives the weights.
    - checkpoint_path (str): File holding the saved state dictionary.
    - device (str): Target device; also used as ``map_location`` while reading the file.

    Returns:
    - model (torch.nn.Module): The same ``model`` object, with weights loaded, on ``device``.
    """
    # Read the tensors directly onto the target device.
    weights = torch.load(checkpoint_path, map_location=device)

    # Copy them into the model's parameters and buffers.
    model.load_state_dict(weights)

    # Make sure the module itself lives on the target device.
    model.to(device)

    return model

# Usage: restore the fine-tuned weights onto the device selected in the first cell.
# A hard-coded 'cuda' here would fail on machines where the notebook fell back to MPS or CPU.
# `loaded_model` is the same object as `model`: the loader returns the instance it was given.
loaded_model = load_model_from_checkpoint(model, '/scratch/nkusumba/magicbrush_kadinsky_imagewithinstruction_10epochs_full_v1.pth', device=device)


In [5]:
def compute_max_instruction_length(dataloader, tokenizer):
    """
    Find the longest instruction, measured in tokens, across all batches.

    Each instruction is decoded back to text with ``tokenizer.decode`` and then
    re-tokenized with ``tokenizer.tokenize``; the length of that token list is
    what gets compared.

    Args:
    - dataloader (DataLoader): Iterable of batches; instructions are read from index 1 of each batch.
    - tokenizer: Object providing ``decode`` and ``tokenize``.

    Returns:
    - int: Largest token count found, or 0 when there are no instructions.
    """
    token_counts = (
        len(tokenizer.tokenize(tokenizer.decode(encoded)))
        for batch in dataloader
        for encoded in batch[1]  # instructions sit at position 1 of the batch
    )
    return max(token_counts, default=0)

def custom_encode(instruction, tokenizer):
    """Tokenize one instruction and return only its ``input_ids``.

    The tokenizer is asked to add its special tokens and to return PyTorch
    tensors; every other field of the encoding is discarded.
    """
    encode_options = {
        "add_special_tokens": True,  # let the tokenizer wrap the text in its special tokens
        "return_tensors": "pt",  # PyTorch tensors
    }
    encoding = tokenizer.encode_plus(instruction, **encode_options)
    return encoding["input_ids"]
    

In [None]:
import os
import json

# Evaluation over edit sessions: for each image folder, render the model's output for the
# session's first instruction and store it next to the input and ground-truth images.
file_path = '/scratch/nkusumba/test/edit_sessions.json'

with open(file_path, 'r') as file:
    json_data = json.load(file)

# Map every top-level JSON key to the instruction of its first entry. The keys are matched
# against the sub-directory names of `images_path` below.
dic = {key: value[0]['instruction'] for key, value in json_data.items()}

images_path = '/scratch/nkusumba/test/images/'
outputs_root = '/scratch/nkusumba/test/outputs'
os.makedirs(outputs_root, exist_ok=True)
max_samples = 100  # stop once this many folders have been processed

# NOTE(review): another evaluation cell in this notebook loads "bert-base-uncased" as the
# tokenizer instead; confirm which tokenizer the checkpoint was trained with.
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

# The negative prompt is the same empty string for every sample, so encode it once.
# NOTE(review): the mask marks every id different from `pad_token_id`; if this tokenizer
# shares one id between padding and end-of-text, that position is masked out too. Confirm
# this matches how the model was trained.
instruction_1 = ""
encoded_instruction_1 = custom_encode(instruction_1, tokenizer).to(device)
instruction_1_attention_mask = (encoded_instruction_1 != tokenizer.pad_token_id).long()

print('Started Evaluation')
count = 0
for dirpath, _, filenames in os.walk(images_path):
    if count == max_samples:
        print('Process done!!!')
        break

    # Locate the source image and the first ground-truth edit inside this folder.
    input_path = ''
    output_path = ''
    for name in filenames:
        filepath = os.path.join(dirpath, name)
        if filepath.endswith('input.png'):
            input_path = filepath
        elif filepath.endswith('output1.png'):
            output_path = filepath
    if input_path == '':
        continue

    # Look the instruction up by the folder's own name. Indexing a directory list with the
    # running counter goes out of step as soon as one folder is skipped.
    session_id = os.path.basename(os.path.normpath(dirpath))
    if session_id not in dic or output_path == '':
        print(f'Skipping {dirpath}: no instruction or ground-truth image found')
        continue

    print(f'Processing {count+1}th image')
    sample_dir = f'{outputs_root}/{count+1}'
    os.makedirs(sample_dir, exist_ok=True)
    with Image.open(output_path) as out_img:
        out_img.save(f'{sample_dir}/groundtruth.png')

    # Process the image
    with Image.open(input_path) as img:
        inputs = processor(images=img, return_tensors="pt")
        img.save(f'{sample_dir}/input_image.png')
    input_image = inputs["pixel_values"].to(device)

    # Process the instruction
    instruction = dic[session_id]
    encoded_instruction = custom_encode(instruction, tokenizer).to(device)
    instruction_attention_mask = (encoded_instruction != tokenizer.pad_token_id).long()
    with open(f'{sample_dir}/instruction.txt', 'w') as f:
        f.write(instruction)

    # Visualize the output (inference only, so autograd is switched off)
    with torch.no_grad():
        loaded_model.visualization(
            input_img=input_image,
            instruction=encoded_instruction,
            instruction_attention_mask=instruction_attention_mask,
            filename=f'{sample_dir}/output.png',
            negative_instruction=encoded_instruction_1,
            negative_instruction_attention_mask=instruction_1_attention_mask
        )
    print(f'Finished processing {count+1}th image')
    count += 1


In [6]:
import json

# Evaluation over single edit turns: render the model's output for a sample of "add" and
# "remove" instructions and store it next to the input and ground-truth images.
json_file_path = '/scratch/nkusumba/test/edit_turns.json'
# NOTE(review): the model's text encoder is a CLIP text model, and another evaluation cell in
# this notebook uses the "openai/clip-vit-base-patch32" tokenizer; confirm that a BERT
# tokenizer is what the checkpoint was trained with.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

with open(json_file_path, 'r') as file:
    data = json.load(file)

# Each entry: (folder id, instruction, input file name, output file name).
# NOTE(review): the classification is a plain substring test, so a word that merely contains
# "add" (e.g. "saddle") also counts as an add instruction; confirm this is intended.
add_ids = []
remove_ids = []
for d in data:
    # The folder id is the number before the first '-' in the input file name.
    folder_id = int(d['input'].split('-')[0])
    record = (folder_id, d['instruction'], d['input'], d['output'])
    lowered = d['instruction'].lower()
    if 'add' in lowered:
        add_ids.append(record)
    elif 'remove' in lowered:
        remove_ids.append(record)


dir_path = '/scratch/nkusumba/test/images'
out_dir = '/scratch/nkusumba/test/aux_outputs'
os.makedirs(out_dir, exist_ok=True)
count = 0

# Evaluate the first 25 "add" and the first 50 "remove" instructions.
samples = add_ids[0:25] + remove_ids[0:50]

# The negative prompt is the same empty string for every sample, so encode it once.
instruction_1 = ""
encoded_instruction_1 = custom_encode(instruction_1, tokenizer).to(device)
instruction_1_attention_mask = (encoded_instruction_1 != tokenizer.pad_token_id).long()

for folder_id, instruction, input_name, output_name in samples:
    input_img_path = f'{dir_path}/{folder_id}/{input_name}'
    output_img_path = f'{dir_path}/{folder_id}/{output_name}'

    res_out_dir = f'{out_dir}/{count+1}'
    os.makedirs(res_out_dir, exist_ok=True)

    with Image.open(output_img_path) as out_img:
        out_img.save(f'{res_out_dir}/groundtruth.png')

    # The list being iterated is called `samples`, so rebinding `inputs` here is harmless.
    with Image.open(input_img_path) as img:
        inputs = processor(images=img, return_tensors="pt")
        img.save(f'{res_out_dir}/input_image.png')
    input_image = inputs["pixel_values"].to(device)

    # Process the instruction
    encoded_instruction = custom_encode(instruction, tokenizer).to(device)
    instruction_attention_mask = (encoded_instruction != tokenizer.pad_token_id).long()
    with open(f'{res_out_dir}/instruction.txt', 'w') as f:
        f.write(instruction)

    # Visualize the output (inference only, so autograd is switched off)
    with torch.no_grad():
        loaded_model.visualization(
            input_img=input_image,
            instruction=encoded_instruction,
            instruction_attention_mask=instruction_attention_mask,
            filename=f'{res_out_dir}/output.png',
            negative_instruction=encoded_instruction_1,
            negative_instruction_attention_mask=instruction_1_attention_mask
        )
    print(f'Finished processing {count+1}th image')
    count += 1


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 18.05it/s]


Finished processing 1th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.06it/s]


Finished processing 2th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.14it/s]


Finished processing 3th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.32it/s]


Finished processing 4th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 17.61it/s]


Finished processing 5th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.85it/s]


Finished processing 6th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.81it/s]


Finished processing 7th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.02it/s]


Finished processing 8th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 18.03it/s]


Finished processing 9th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.23it/s]


Finished processing 10th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.23it/s]


Finished processing 11th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.02it/s]


Finished processing 12th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.86it/s]


Finished processing 13th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.16it/s]


Finished processing 14th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.13it/s]


Finished processing 15th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.25it/s]


Finished processing 16th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 18.95it/s]


Finished processing 17th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.80it/s]


Finished processing 18th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.26it/s]


Finished processing 19th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.80it/s]


Finished processing 20th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.14it/s]


Finished processing 21th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.35it/s]


Finished processing 22th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.44it/s]


Finished processing 23th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.09it/s]


Finished processing 24th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.58it/s]


Finished processing 25th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.70it/s]


Finished processing 26th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.38it/s]


Finished processing 27th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.32it/s]


Finished processing 28th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.86it/s]


Finished processing 29th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 18.66it/s]


Finished processing 30th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.41it/s]


Finished processing 31th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.38it/s]


Finished processing 32th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.45it/s]


Finished processing 33th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.28it/s]


Finished processing 34th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.73it/s]


Finished processing 35th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 17.79it/s]


Finished processing 36th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.64it/s]


Finished processing 37th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.18it/s]


Finished processing 38th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.06it/s]


Finished processing 39th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.17it/s]


Finished processing 40th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.96it/s]


Finished processing 41th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.32it/s]


Finished processing 42th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.86it/s]


Finished processing 43th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.15it/s]


Finished processing 44th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.72it/s]


Finished processing 45th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.06it/s]


Finished processing 46th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.18it/s]


Finished processing 47th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.18it/s]


Finished processing 48th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.10it/s]


Finished processing 49th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.12it/s]


Finished processing 50th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.01it/s]


Finished processing 51th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.92it/s]


Finished processing 52th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.16it/s]


Finished processing 53th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.99it/s]


Finished processing 54th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.94it/s]


Finished processing 55th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.52it/s]


Finished processing 56th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.26it/s]


Finished processing 57th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.31it/s]


Finished processing 58th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 18.32it/s]


Finished processing 59th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.80it/s]


Finished processing 60th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.72it/s]


Finished processing 61th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 20.00it/s]


Finished processing 62th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.99it/s]


Finished processing 63th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.93it/s]


Finished processing 64th image


100%|█████████████████████████████████████████| 100/100 [00:04<00:00, 20.02it/s]


Finished processing 65th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.93it/s]


Finished processing 66th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.97it/s]


Finished processing 67th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.94it/s]


Finished processing 68th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.84it/s]


Finished processing 69th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.81it/s]


Finished processing 70th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 20.00it/s]


Finished processing 71th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.72it/s]


Finished processing 72th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.89it/s]


Finished processing 73th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.79it/s]


Finished processing 74th image


100%|█████████████████████████████████████████| 100/100 [00:05<00:00, 19.76it/s]


Finished processing 75th image


In [7]:
# Bundle the generated comparison folders into a single archive.
# NOTE(review): the captured log shows "updating:", i.e. the archive already existed and was
# updated in place; entries from earlier runs may therefore still be inside. Confirm.
!zip -r /scratch/nkusumba/test/kandinsky_aux_outputs.zip /scratch/nkusumba/test/aux_outputs

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
updating: scratch/nkusumba/test/aux_outputs/ (stored 0%)
updating: scratch/nkusumba/test/aux_outputs/50/ (stored 0%)
updating: scratch/nkusumba/test/aux_outputs/50/groundtruth.png (deflated 0%)
updating: scratch/nkusumba/test/aux_outputs/50/output.png (deflated 1%)
updating: scratch/nkusumba/test/aux_outputs/50/instruction.txt (stored 0%)
updating: scratch/nkusumba/test/aux_outputs/50/input_image.png (deflated 0%)
updating: scratch/nkusumba/test/aux_outputs/16/ (stored 0%)
updating: scratch/nkusumba/test/aux_outputs/16/groundtruth.png (deflated 0%)
updating: scratch/nkusumba/test/aux_outputs/16/input_image.png (deflated 0%)
updating: scratch/nkusumba/test/aux_outputs/16/output.png (deflated 0%)
updating: scr