In [1]:
import torch

# Pick the compute backend once, preferring CUDA, then Apple's MPS, then the CPU,
# and report the choice. Later cells read the resulting global ``device``.
if torch.cuda.is_available():
    backend_name = "cuda"
elif torch.backends.mps.is_available():
    backend_name = "mps"
else:
    backend_name = "cpu"

if backend_name == "cuda":
    print("There are %d GPU(s) available." % torch.cuda.device_count())
    print("We will use the GPU:", torch.cuda.get_device_name(0))
elif backend_name == "mps":
    print("Using mps backend")
else:
    print("No GPU available, using the CPU instead.")

device = torch.device(backend_name)

There are 1 GPU(s) available.
We will use the GPU: NVIDIA A100-SXM4-80GB


In [2]:
from transformers import AutoTokenizer, AutoProcessor
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import requests
from PIL import Image
import numpy as np
from io import BytesIO
from diffusers import (
    KandinskyV22Pipeline,
    KandinskyV22PriorEmb2EmbPipeline,
    KandinskyV22PriorPipeline,
)
from diffusers.utils import load_image
from torchvision.transforms import ToPILImage

In [10]:
from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection, BertForSequenceClassification
import torch.nn as nn
from transformers import AdamW
import os


class T2IModel(nn.Module):
    """Map (input image, edit instruction) to a predicted CLIP embedding of the edited image.

    Components:
    - ``vision_model``: CLIP vision encoder with projection from the Kandinsky 2.2 prior.
    - ``text_model``: CLIP text encoder with projection from the same prior.
    - ``fc``: a single linear layer mapping the concatenated 2560 features to 1280.

    Predictions are compared against ``vision_model``'s embedding of the target image
    (``output_embedding`` / ``custom_loss``) and can be decoded to pixels with
    ``KandinskyV22Pipeline`` (``visualization``).
    """

    def __init__(self):
        super(T2IModel, self).__init__()
        # ``forward`` reads ``.text_embeds`` and ``fc`` expects 2560 = image + text
        # features, so the text tower must be a CLIP text model with a projection head.
        # BertForSequenceClassification returns logits (no ``text_embeds``) and would not
        # match checkpoints saved from this class. The prior's projections are presumably
        # 1280-d each — TODO confirm against the checkpoint's config.
        self.text_model = CLIPTextModelWithProjection.from_pretrained(
            "kandinsky-community/kandinsky-2-2-prior", subfolder="text_encoder"
        )
        self.vision_model = CLIPVisionModelWithProjection.from_pretrained(
            "kandinsky-community/kandinsky-2-2-prior", subfolder="image_encoder"
        )
        # Fuses [image_embeds, text_embeds] (concatenated on dim 1) into one embedding.
        self.fc = nn.Linear(2560, 1280)

    def initialize_optimizer(self):
        """Create an AdamW optimizer (lr=1e-4) over the fusion layer ``fc`` only.

        The encoders' parameters are not handed to the optimizer, so stepping it
        updates only ``fc``.
        """
        params = list(self.fc.parameters())
        optimizer = AdamW(params, lr=1e-4)
        return optimizer

    def forward(self, input_imgs, instructions):
        """Predict the edited-image embedding.

        Args:
            input_imgs: pixel tensor for the CLIP vision encoder (presumably
                (batch, 3, H, W) from a CLIP image processor — TODO confirm).
            instructions: token-id tensor for the CLIP text encoder (presumably
                (batch, seq_len)).

        Returns:
            ``fc`` applied to the concatenation of the image and text embeddings.
        """
        text_embeds = self.text_model(instructions).text_embeds
        vision_embeds = self.vision_model(input_imgs).image_embeds
        x = torch.cat((vision_embeds, text_embeds), 1)
        x = self.fc(x)
        return x

    def output_embedding(self, target_images):
        """Return the vision encoder's projected embedding of ``target_images`` (the regression target)."""
        target_image_output = self.vision_model(target_images)
        return target_image_output.image_embeds

    def custom_loss(self, output_embeddings, target_embeddings):
        """Mean squared error between predicted and target embeddings (mean reduction)."""
        mse_loss = nn.MSELoss()
        return mse_loss(output_embeddings, target_embeddings)

    def save_model(self, output_dir="../model_save/", filename="model_checkpoint.pt"):
        """Save this module's ``state_dict`` to ``output_dir/filename``.

        Creates ``output_dir`` if needed and returns the path written.
        """
        os.makedirs(output_dir, exist_ok=True)
        # os.path.join also works when ``output_dir`` lacks a trailing separator.
        file_path = os.path.join(output_dir, filename)
        print("Saving model to %s" % file_path)

        # Save *this* instance (previously the notebook-global ``model`` was saved).
        torch.save(self.state_dict(), file_path)
        return file_path

    def get_cos(self, input1, input2):
        """Average cosine similarity between matching rows of ``input1`` and ``input2``.

        Returns a 0-d tensor.
        """
        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        similarity = cos(input1, input2)
        return similarity.mean()

    def metrics(self, input1, input2):
        """Return the list of evaluation metrics; currently ``[average cosine similarity]``."""
        cos = self.get_cos(input1, input2)
        return [cos]

    def visualization(self, input_img, instruction, filename, negative_instruction=""):
        """Decode predicted embeddings to an image with Kandinsky 2.2 and save it.

        Args:
            input_img: pixel tensor accepted by ``forward``.
            instruction: token ids of the edit instruction.
            filename: path the first generated image is saved to.
            negative_instruction: token ids used to build ``negative_image_embeds``.
                NOTE(review): the default ``""`` is passed straight to the text encoder,
                which is fed token ids elsewhere; callers should pass an encoded tensor.

        Returns:
            The first generated image (also written to ``filename``).
        """
        # Inference only: no autograd graph is needed for these embeddings.
        with torch.no_grad():
            output_embeddings = self.forward(input_img, instruction)
            neg_image_embed = self.forward(input_img, negative_instruction)

        pipe = KandinskyV22Pipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
        )
        # Run the decoder on the same device as the embeddings (i.e. this model).
        pipe.to(output_embeddings.device)
        images = pipe(
            image_embeds=output_embeddings,
            negative_image_embeds=neg_image_embed,
            height=768,
            width=768,
            num_inference_steps=100,
        ).images

        images[0].save(filename)
        return images[0]


# Build the editing model and move it to the device chosen in the first cell;
# the bare ``model`` on the last line displays its module tree.
model = T2IModel().to(device=device)
model

T2IModel(
  (text_model): CLIPTextModelWithProjection(
    (text_model): CLIPTextTransformer(
      (embeddings): CLIPTextEmbeddings(
        (token_embedding): Embedding(49408, 1280)
        (position_embedding): Embedding(77, 1280)
      )
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-31): 32 x CLIPEncoderLayer(
            (self_attn): CLIPAttention(
              (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
              (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
              (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
              (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
            )
            (layer_norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): GELUActivation()
              (fc1): Linear(in_features=1280, out_features=5120, bias=True)
              (fc2): Linear(in_features=5120, out_f

In [4]:
# Only the fusion layer's (``fc``) parameters are registered with this optimizer.
# NOTE(review): no cell shown here uses ``optimizer``; it is only needed for training.
optimizer = model.initialize_optimizer()



In [11]:
import torch

def load_model_from_checkpoint(model, checkpoint_path, device='cuda'):
    """
    Restore saved weights into ``model`` and place it on ``device``.

    The model is modified in place: the returned object is the same instance
    that was passed in.

    Parameters:
    - model (torch.nn.Module): Module whose architecture matches the checkpoint.
    - checkpoint_path (str): Path to a file written with ``torch.save(module.state_dict(), ...)``.
    - device (str or torch.device): Where tensors are mapped while loading and where
      the model ends up.

    Returns:
    - torch.nn.Module: ``model`` carrying the loaded weights, on ``device``.
    """
    # NOTE(review): torch.load unpickles the file; only use it with trusted checkpoints.
    weights = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(weights)
    # nn.Module.to returns the module itself, so this hands back the same instance.
    return model.to(device)

# Usage: restore the trained weights onto the device selected in the first cell
# (hard-coding 'cuda' fails on MPS/CPU-only machines). Weights are loaded into
# ``model`` in place, so ``loaded_model`` is the same object as ``model``.
CHECKPOINT_PATH = 'magicbrush_kadinsky_imagewithinstruction_10epochs_full_v1.pth'
loaded_model = load_model_from_checkpoint(model, CHECKPOINT_PATH, device=device)


In [None]:
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

def test(model, dataloader, device):
    """
    Run the model on test data and compute average loss and cosine similarity.

    Both metrics are averaged per sample (each batch weighted by its size), so a
    smaller final batch does not skew the result and batch order does not matter.

    Args:
    - model (T2IModel): Model to test; must provide ``output_embedding``,
      ``custom_loss`` and ``get_cos``.
    - dataloader (DataLoader): Yields (input_images, instructions, target_images) batches.
    - device (str or torch.device): Device on which to run the model.

    Returns:
    - tuple: (average loss as a float, average cosine similarity as a 0-d tensor).

    Raises:
    - ValueError: if the dataloader yields no samples.
    """
    # Put the model in evaluation mode.
    model.eval()

    total_loss = 0.0
    total_cos = 0.0
    n_samples = 0

    with torch.no_grad():
        val_progress = tqdm(dataloader, desc="Test Set Validation", leave=False)
        for batch in val_progress:
            input_images = batch[0].to(device=device)
            instructions = batch[1].to(device=device)
            target_images = batch[2].to(device=device)
            batch_size = input_images.size(0)

            # Forward pass and target embeddings from the same vision encoder.
            output_embeddings = model(input_images, instructions)
            target_img_embeddings = model.output_embedding(target_images)

            loss = model.custom_loss(output_embeddings, target_img_embeddings)
            cos = model.get_cos(output_embeddings, target_img_embeddings)

            # Both are batch means; re-weight by batch size for a per-sample average.
            total_loss += loss.item() * batch_size
            total_cos = total_cos + cos * batch_size
            n_samples += batch_size

            # The loss is already a mean; ``len(batch)`` would be the number of
            # fields in the batch tuple, not the batch size.
            val_progress.set_postfix({'validation_loss': '{:.3f}'.format(loss.item())})

    if n_samples == 0:
        raise ValueError("test(): dataloader yielded no samples")

    avg_val_loss = total_loss / n_samples
    print(" Test Set Validation Loss: {0:.2f}".format(avg_val_loss))

    avg_cos = total_cos / n_samples
    print(" Test Set Validation cosine similarity: {0:.2f}".format(avg_cos))

    return avg_val_loss, avg_cos




In [6]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# Load the saved MagicBrush test split (presumably a pickled TensorDataset).
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files. On
# newer torch versions the default ``weights_only=True`` may reject a pickled Dataset;
# if so, pass ``weights_only=False`` for this trusted local file.
test_dataset = torch.load('test_dataset_magicbrush.pth')
# Evaluation needs no shuffling; a sequential sampler keeps batch order reproducible.
test_dataloader = DataLoader(
    test_dataset, sampler=SequentialSampler(test_dataset), batch_size=32
)

In [None]:
# Display the module tree of the restored model (the same object as ``model``).
loaded_model

In [None]:
# Evaluate on the device the model's parameters actually live on. Redefining the
# global ``device`` as "cuda or cpu" here would discard the MPS choice from the first
# cell and could put the batches on a different device than the model.
eval_device = next(loaded_model.parameters()).device

avg_loss, avg_cos_sim = test(loaded_model, test_dataloader, eval_device)

In [12]:
def compute_max_instruction_length(dataloader, tokenizer):
    """
    Find the longest instruction (in tokens) across every batch of ``dataloader``.

    Each stored instruction is decoded back to text and re-tokenized, and the
    token count of that text is measured.

    Args:
    - dataloader (DataLoader): Batches whose element 1 holds the encoded instructions.
    - tokenizer: Tokenizer providing ``decode`` and ``tokenize``.

    Returns:
    - int: Largest token count seen, or 0 if there are no instructions.
    """
    token_counts = (
        len(tokenizer.tokenize(tokenizer.decode(encoded)))
        for batch in dataloader
        for encoded in batch[1]  # instructions live at position 1 of each batch
    )
    return max(token_counts, default=0)

def custom_encode(instruction, max_len, tokenizer, extra_length=10):
    """
    Tokenize one instruction into a fixed-length tensor of token ids.

    Args:
    - instruction (str): Text to encode.
    - max_len (int): Longest instruction length, e.g. from compute_max_instruction_length.
    - tokenizer: Hugging Face tokenizer providing ``encode_plus``.
    - extra_length (int): Headroom added to ``max_len`` (default 10, the original margin).

    Returns:
    - torch.Tensor: the ``input_ids`` padded/truncated to ``max_len + extra_length``
      tokens (presumably shape (1, max_len + extra_length) for a single string).
    """
    encoded_inst = tokenizer.encode_plus(
        instruction,  # Sentence to encode.
        add_special_tokens=True,  # Add the tokenizer's start/end special tokens.
        max_length=max_len + extra_length,
        padding="max_length",  # Pad short inputs up to max_length ...
        truncation=True,  # ... and cut long ones, so every result has the same length.
        return_tensors="pt",  # Return pytorch tensors
    )

    return encoded_inst["input_ids"]
    

In [16]:
from PIL import Image

# Edit one example image: "<fileSuffix>-input.png" -> "<fileSuffix>-output_generated.png".
fileSuffix = "6197"
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

# The context manager closes the PNG file handle once the pixels are processed.
with Image.open(fileSuffix + "-input.png") as img:
    input_image = processor(images=img, return_tensors="pt")["pixel_values"].to(device)

# Size the padding from the longest instruction in the test set (custom_encode adds a margin).
max_len = compute_max_instruction_length(test_dataloader, tokenizer)

instruction_text = "It should be a bus not a truck"
negative_instruction_text = ""  # empty prompt used to build the negative embedding

instruction_ids = custom_encode(instruction_text, max_len, tokenizer).to(device)
negative_instruction_ids = custom_encode(negative_instruction_text, max_len, tokenizer).to(device)

model.visualization(
    input_image,
    instruction_ids,
    fileSuffix + "-output_generated.png",
    negative_instruction=negative_instruction_ids,
)

Loading pipeline components...:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# Scratch check of an instruction containing double quotes; single-quoting the
# literal avoids the escapes while producing the identical string.
str_check = 'Change the stop sign to say "GO"'
str_check