In [1]:
import torch

# Pick the compute device once, in priority order CUDA -> MPS -> CPU, and report
# the choice. Later cells read this module-level `device`
# (e.g. `model.to(device=device)`).
if torch.cuda.is_available():
    device = torch.device("cuda")
    n_gpus = torch.cuda.device_count()
    print("There are %d GPU(s) available." % n_gpus)
    gpu_name = torch.cuda.get_device_name(0)
    print("We will use the GPU:", gpu_name)
elif torch.backends.mps.is_available():
    # Apple-silicon Metal backend.
    device = torch.device("mps")
    print("Using mps backend")
else:
    print("No GPU available, using the CPU instead.")
    device = torch.device("cpu")

There are 1 GPU(s) available.
We will use the GPU: NVIDIA A100-SXM4-80GB


In [2]:
from transformers import AutoTokenizer, AutoProcessor
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import requests
from PIL import Image
import numpy as np
from io import BytesIO
from diffusers import (
    KandinskyV22Pipeline,
    KandinskyV22PriorEmb2EmbPipeline,
    KandinskyV22PriorPipeline,
)
from diffusers.utils import load_image
from torchvision.transforms import ToPILImage

In [36]:
from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection, BertModel
import torch.nn as nn
from transformers import AdamW
import os


class T2IModel(nn.Module):
    """Predict a target-image CLIP embedding from an input image and an instruction.

    The input image is encoded by the CLIP vision tower of the Kandinsky 2.2 prior
    (``image_embeds``), and the instruction by BERT (the [CLS] hidden state). The
    two vectors are concatenated and mapped by one linear layer to a 1280-dim
    vector. That vector is compared against ``output_embedding(target_images)``
    and, in ``visualization``, fed to the Kandinsky 2.2 decoder as ``image_embeds``.
    """

    def __init__(self):
        super(T2IModel, self).__init__()
        self.text_model = BertModel.from_pretrained('bert-base-uncased')
        self.vision_model = CLIPVisionModelWithProjection.from_pretrained(
            "kandinsky-community/kandinsky-2-2-prior", subfolder="image_encoder"
        )
        # FC input = BERT [CLS] vector + CLIP image embedding (concatenated in `forward`).
        # The 1280 output must equal the size of `output_embedding`'s result, because
        # the two are compared with MSE in `custom_loss`.
        self.fc = nn.Linear(self.text_model.config.hidden_size + self.vision_model.config.projection_dim, 1280)

    def initialize_optimizer(self):
        """Return an AdamW optimizer (lr=1e-4) over the FC head only.

        The BERT and CLIP encoder parameters are not handed to the optimizer, so
        stepping it never updates them.
        """
        params = list(self.fc.parameters())
        optimizer = AdamW(params, lr=1e-4)
        return optimizer

    def forward(self, input_imgs, input_txt, attention_mask=None):
        """Return the predicted embedding, shape (batch, 1280).

        Args:
            input_imgs: pixel values for the CLIP vision model
                (presumably (batch, 3, H, W) from a CLIP image processor -- confirm).
            input_txt: BERT token ids, (batch, seq_len).
            attention_mask: optional BERT attention mask matching ``input_txt``.
        """
        text_outputs = self.text_model(input_txt, attention_mask=attention_mask)
        # Representation of the [CLS] token: position 0 of the last hidden layer.
        text_embeds = text_outputs.last_hidden_state[:, 0, :]

        vision_outputs = self.vision_model(input_imgs)
        vision_embeds = vision_outputs.image_embeds

        # The (vision, text) order is baked into the trained fc weights; do not swap it.
        combined_embeds = torch.cat((vision_embeds, text_embeds), dim=1)
        return self.fc(combined_embeds)

    def output_embedding(self, target_images):
        """Return the CLIP ``image_embeds`` of ``target_images`` (the regression target)."""
        target_image_output = self.vision_model(target_images)
        return target_image_output.image_embeds

    def custom_loss(self, output_embeddings, target_embeddings):
        """Mean-squared error between predicted and target embeddings (0-d tensor)."""
        mse_loss = nn.MSELoss()
        return mse_loss(output_embeddings, target_embeddings)

    def save_model(self, output_dir="../model_save/", filename="model_checkpoint.pt"):
        """Save this model's ``state_dict`` to ``output_dir/filename``.

        The directory is created if missing. The path is joined with
        ``os.path.join``, so ``output_dir`` may or may not end with a separator.
        """
        os.makedirs(output_dir, exist_ok=True)
        file_path = os.path.join(output_dir, filename)
        print("Saving model to %s" % file_path)

        # Use `self`, not a notebook-level global, so every instance saves its own weights.
        torch.save(self.state_dict(), file_path)

    def get_cos(self, input1, input2):
        """Mean row-wise cosine similarity of two (batch, dim) tensors, as a 0-d tensor."""
        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        return cos(input1, input2).mean()

    def metrics(self, input1, input2):
        """Return the evaluation metrics as a list: currently only ``[mean cosine]``."""
        cos = self.get_cos(input1, input2)
        return [cos]

    def visualization(self, input_img, instruction, instruction_attention_mask, filename, negative_instruction=None, negative_instruction_attention_mask=None):
        """Decode the predicted embedding into an image with Kandinsky 2.2 and save it.

        Args:
            input_img: pixel values passed to ``forward``.
            instruction, instruction_attention_mask: BERT ids and mask of the instruction.
            filename: path where the first generated image is written.
            negative_instruction, negative_instruction_attention_mask: optional. When
                BOTH are given, their prediction is used as ``negative_image_embeds``;
                otherwise a zero tensor of the same shape is used.
        """
        # Inference only: no autograd graph through BERT and CLIP is needed.
        with torch.no_grad():
            output_embeddings = self.forward(input_img, instruction, attention_mask=instruction_attention_mask)
            if negative_instruction is not None and negative_instruction_attention_mask is not None:
                neg_image_embed = self.forward(input_img, negative_instruction, attention_mask=negative_instruction_attention_mask)
            else:
                neg_image_embed = torch.zeros_like(output_embeddings)

        # The fp16 decoder is (re)loaded on every call.
        pipe = KandinskyV22Pipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
        )
        # Run the decoder where this model's weights live, instead of relying on a
        # notebook-level `device` global that a later cell may have rebound.
        pipe.to(next(self.parameters()).device)

        image = pipe(
            image_embeds=output_embeddings,
            negative_image_embeds=neg_image_embed,
            height=768,
            width=768,
            num_inference_steps=100,
        ).images

        image[0].save(filename)

# Instantiate the model (loads the pretrained BERT and CLIP encoders) and move it to
# the device picked in the first cell; the bare name below displays its summary.
model = T2IModel().to(device)
model

T2IModel(
  (text_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [38]:
# AdamW over the FC head only (see T2IModel.initialize_optimizer); not referenced by
# the later cells shown in this notebook.
optimizer = model.initialize_optimizer()

In [39]:
import torch

def load_model_from_checkpoint(model, checkpoint_path, device='cuda'):
    """
    Restore a state dict saved with ``torch.save(model.state_dict(), ...)`` into ``model``.

    Parameters:
    - model (torch.nn.Module): Module whose architecture matches the checkpoint.
      It is modified in place and also returned.
    - checkpoint_path (str): Path to the saved state-dict file.
    - device (str or torch.device): Where tensors are mapped while loading, and
      where the model is moved afterwards.

    Returns:
    - model (torch.nn.Module): The same object that was passed in, now holding the
      checkpoint weights and living on ``device``.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    weights = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(weights)
    # nn.Module.to moves parameters in place and returns the module itself.
    return model.to(device)

# Restore the trained weights into `model` (in place; `loaded_model` is the same object).
# Use the device chosen in the first cell rather than a hard-coded 'cuda', so this
# also works on the MPS and CPU paths that cell supports.
loaded_model = load_model_from_checkpoint(model, 'magicbrush_kadinsky_bert_imagewithinstruction_10epochs_full_v1.pth', device=device)


In [30]:
from transformers import BertTokenizer
from tqdm import tqdm
import torch

# BERT tokenizer; `test` below reads it as a module-level global to re-encode instructions.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def test(model, dataloader, device):
    """
    Run the model on test data and compute average loss and cosine similarity.

    Args:
    - model (T2IModel): Model to test. It is switched to eval mode and left there.
    - dataloader (DataLoader): Test data; each batch is indexed as
      batch[0] = input images, batch[1] = instructions, batch[2] = target images.
    - device (str or torch.device): Device on which to run the model.

    Returns:
    - tuple: (average loss, average cosine similarity). Both are per-sample averages
      over the whole test set, so a smaller final batch is weighted by its real size.

    Raises:
    - ValueError: if the dataloader yields no samples.
    """
    model.eval()

    total_eval_loss = 0.0
    total_cos = 0.0
    n_samples = 0

    with torch.no_grad():
        val_progress = tqdm(dataloader, desc="Test Set Validation", leave=False)
        for batch in val_progress:
            # Instructions may arrive as token-id tensors/lists or as raw strings;
            # normalise to strings and re-tokenize with dynamic padding.
            instructions = batch[1]
            if isinstance(instructions, torch.Tensor):
                instructions = instructions.tolist()
            if not isinstance(instructions[0], str):
                # NOTE(review): decode() appears to keep [CLS]/[SEP]/[PAD] as literal text,
                # which the tokenizer call below would re-encode as unmasked tokens. Left
                # as-is so results stay comparable with however the checkpoint was trained.
                instructions = [tokenizer.decode(text_input) for text_input in instructions]

            inputs = tokenizer(instructions, padding=True, truncation=True, return_tensors="pt")
            input_ids = inputs['input_ids'].to(device)
            attention_masks = inputs['attention_mask'].to(device)

            input_images = batch[0].to(device)
            target_images = batch[2].to(device)

            output_embeddings = model(input_images, input_ids, attention_mask=attention_masks)
            target_img_embeddings = model.output_embedding(target_images)

            # custom_loss and get_cos both return batch means; weight them by the batch
            # size so the final figures are true per-sample averages.
            batch_loss = model.custom_loss(output_embeddings, target_img_embeddings).item()
            batch_cos = model.get_cos(output_embeddings, target_img_embeddings).item()
            batch_size = input_images.size(0)
            total_eval_loss += batch_loss * batch_size
            total_cos += batch_cos * batch_size
            n_samples += batch_size

            # The MSE loss is already a batch mean, so show it directly
            # (len(batch) is the number of fields in the batch tuple, not its size).
            val_progress.set_postfix({'validation_loss': '{:.3f}'.format(batch_loss)})

            # Drop references before the next batch is allocated to lower peak memory.
            del input_images, input_ids, attention_masks, target_images, output_embeddings, target_img_embeddings

    # Release cached GPU blocks once instead of after every batch (a no-op without CUDA).
    torch.cuda.empty_cache()

    if n_samples == 0:
        raise ValueError("test(): dataloader yielded no samples")

    avg_val_loss = total_eval_loss / n_samples
    print(" Test Set Validation Loss: {0:.2f}".format(avg_val_loss))

    avg_cos = total_cos / n_samples
    print(" Test Set Validation cosine similarity: {0:.2f}".format(avg_cos))

    return avg_val_loss, avg_cos


In [31]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# Load the pickled test split. On torch>=2.6, torch.load defaults to weights_only=True,
# which cannot unpickle Dataset objects, so it is disabled explicitly.
# NOTE(review): this unpickles arbitrary objects -- only load files you trust.
test_dataset = torch.load('test_dataset_kandisky_bert_magicbrush.pth', weights_only=False)
# Evaluation needs no shuffling; a fixed order makes the reported metrics reproducible.
test_dataloader = DataLoader(
    test_dataset, sampler=SequentialSampler(test_dataset), batch_size=32
)

In [None]:
# Display the module summary (repr) of the model restored from the checkpoint.
loaded_model

In [32]:
# Assuming you have the device defined already
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

avg_loss, avg_cos_sim = test(loaded_model, test_dataloader, device)

                                                                                

 Test Set Validation Loss: 0.25
 Test Set Validation cosine similarity: 0.91




In [33]:
def compute_max_instruction_length(dataloader, tokenizer):
    """
    Compute the maximum instruction length (in tokens) over all dataloader batches.

    Special tokens ([CLS], [SEP], [PAD], ...) are not counted. The result is
    therefore the length of the longest actual instruction, not the stored padded
    sequence length.

    Args:
    - dataloader (DataLoader or iterable of batches): batch[1] holds the
      instructions, either as token-id sequences or as raw strings.
    - tokenizer: Tokenizer used to decode/tokenize the instructions.

    Returns:
    - int: Maximum instruction length; 0 if there are no instructions.
    """
    max_len = 0
    for batch in dataloader:
        for instruction in batch[1]:
            if isinstance(instruction, str):
                text = instruction
            else:
                # Without skip_special_tokens the decoded text contains literal
                # "[CLS]"/"[SEP]"/"[PAD]" strings, which tokenize() would count as
                # instruction tokens.
                text = tokenizer.decode(instruction, skip_special_tokens=True)
            max_len = max(max_len, len(tokenizer.tokenize(text)))
    return max_len

def custom_encode(instruction, max_len, tokenizer):
    """
    Encode one instruction into a fixed-length tensor of token ids.

    Args:
    - instruction (str): Text to encode.
    - max_len (int): Base length. The output is padded/truncated to max_len + 10
      tokens (the extra 10 is presumably headroom over the measured maximum).
    - tokenizer: Hugging Face tokenizer providing ``encode_plus``.

    Returns:
    - torch.Tensor: input_ids of shape (1, max_len + 10).
    """
    encoded_inst = tokenizer.encode_plus(
        instruction,
        add_special_tokens=True,  # add [CLS] and [SEP]
        max_length=max_len + 10,
        padding="max_length",  # pad every sequence up to max_length ...
        truncation=True,  # ... and cut longer ones, so the length is always fixed
        return_tensors="pt",  # return PyTorch tensors
    )
    return encoded_inst["input_ids"]
    

In [41]:
from PIL import Image
from transformers import BertTokenizer, CLIPProcessor
import torch

# Qualitative example for one image. `device`, `test_dataloader`, `compute_max_instruction_length`,
# `custom_encode` and `loaded_model` all come from earlier cells.
file_suffix = "44437"
img = Image.open(file_suffix + "-input.png")

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Pad length for the instructions. NOTE(review): this iterates the whole test
# dataloader (images included) just to obtain a length, while the padded positions
# are masked out below -- a fixed length could be a much cheaper substitute.
max_len = compute_max_instruction_length(test_dataloader, tokenizer)

inputs = processor(images=img, return_tensors="pt")
input_image = inputs["pixel_values"].to(device)


def _encode_with_mask(text):
    """Return (input_ids, attention_mask) on `device`; the mask is 1 wherever the id is not [PAD]."""
    ids = custom_encode(text, max_len, tokenizer).to(device)
    mask = (ids != tokenizer.pad_token_id).long()
    return ids, mask


instruction = "put a face mask on one of the players"
encoded_instruction, instruction_attention_mask = _encode_with_mask(instruction)

# Empty text used for the decoder's negative_image_embeds.
instruction_1 = ""
encoded_instruction_1, instruction_1_attention_mask = _encode_with_mask(instruction_1)

loaded_model.visualization(
    input_img=input_image,
    instruction=encoded_instruction,
    instruction_attention_mask=instruction_attention_mask,
    filename=file_suffix + "-output_generated_bert.png",
    negative_instruction=encoded_instruction_1,
    negative_instruction_attention_mask=instruction_1_attention_mask,
)


Loading pipeline components...:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# Sample instruction with embedded double quotes, displayed as the cell output.
str_check = 'Change the stop sign to say "GO"'
str_check