In [1]:
import torch
from transformers import CLIPVisionModel, CLIPImageProcessor, AutoModel, AutoTokenizer, BitsAndBytesConfig
from accelerate import Accelerator
from accelerate.utils import gather_object
from parallelformers import parallelize
import warnings
from PIL import Image
import torchvision
from torchvision import transforms
from torch.nn import DataParallel
from torch.cuda.amp import autocast
import os
import numpy as np
import bitsandbytes as bnb

d_type=torch.float16

def get_image_encoder_internlm():
    vision_tower_name="internlm/internlm-xcomposer2d5-clip"
    warnings.filterwarnings("ignore")

    vision_tower = CLIPVisionModel.from_pretrained(vision_tower_name).eval().cuda()
    vision_tower.half()
    vision_tower.requires_grad_(True)

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs")
        vision_tower = DataParallel(vision_tower)

    img_size=560

    return vision_tower, None, None, img_size

def encode_image_internlm(image_encoder, X_adv, img_size, bs, diff_aug, orig_sizes):
    images = []
    for j in range(bs):
        orig_w, orig_h = orig_sizes[j]

        img = X_adv[j][:, :orig_h, :orig_w]
        img = __resize_img__(img)
        img = torch.nn.functional.interpolate(img.unsqueeze(0), size=(img_size, img_size), mode='bicubic')

        if diff_aug:
            img = diff_aug(img).cuda()
        else:
            img = img.cuda()


        images.append(img)

    images = torch.cat(images, dim=0).cuda()

    with torch.autocast(device_type='cuda', dtype=d_type):
        image_embeds = image_encoder(images)

    return image_embeds.last_hidden_state

def __resize_img__(img):
    """Resize the image with padding to maintain aspect ratio."""
    _, h, w = img.shape
    target_size = max(h, w)
    
    # Calculate padding
    top_padding = (target_size - h) // 2
    bottom_padding = target_size - h - top_padding
    left_padding = (target_size - w) // 2
    right_padding = target_size - w - left_padding

    # Apply padding to make the image square
    padded_img = torchvision.transforms.functional.pad(
        img, [left_padding, top_padding, right_padding, bottom_padding]
    )

    return padded_img

def get_internlm_model():
    model_name="internlm/internlm-xcomposer2d5-7b"
    d_type=torch.float16
    warnings.filterwarnings("ignore")
    # self.accelerator = Accelerator()

    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,  # Use 8-bit quantization
        llm_int8_threshold=200.0  # Adjust threshold for 8-bit quantization if necessary
    )

    model = AutoModel.from_pretrained(
        model_name, 
        torch_dtype=d_type, 
        # quantization_config=quantization_config,
        trust_remote_code=True, 
        device_map='auto'
    )
    # self.model = self.accelerator.prepare(self.model)
    # parallelize(self.model, num_gpus=2, fp16=True, verbose='detail')
    model.eval()
    
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model.tokenizer = tokenizer
    num_beams = 3
    return model, tokenizer

def Image_transform(img, hd_num=25):
    width, height = img.size
    trans = False
    if width < height:
        img = img.transpose(Image.TRANSPOSE)
        trans = True
        width, height = img.size
    ratio = (width/ height)
    scale = 1
    while scale*np.ceil(scale/ratio) <= hd_num:
        scale += 1
    scale -= 1
    scale = min(np.ceil(width / 560), scale)
    new_w = int(scale * 560)
    new_h = int(new_w / ratio)
    #print (scale, f'{height}/{new_h}, {width}/{new_w}')

    img = transforms.functional.resize(img, [new_h, new_w],)
    img = padding_336(img, 560)
    width, height = img.size
    if trans:
        img = img.transpose(Image.TRANSPOSE)

    return img

def padding_336(b, pad=336):
    width, height = b.size
    tar = int(np.ceil(height / pad) * pad)
    top_padding = 0 # int((tar - height)/2)
    bottom_padding = tar - height - top_padding
    left_padding = 0
    right_padding = 0
    b = transforms.functional.pad(b, [left_padding, top_padding, right_padding, bottom_padding], fill=[255,255,255])

    return b

def get_response_internlm(image, text_prompt, tokenizer, model):
        query = text_prompt
        with torch.no_grad():
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                print(transforms.ToTensor()(Image_transform(image)).shape)
                response, his = model.chat(tokenizer, query, transforms.ToTensor()(Image_transform(image)).unsqueeze(0).unsqueeze(0), do_sample=False, num_beams=3, use_meta=True)
        prediction = response.strip()
        return prediction

[2024-12-06 15:14:23,491] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
model, tokenizer = get_internlm_model()

Set max length to 16384


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [33]:
chat = lambda image: get_response_internlm(transforms.ToPILImage()(image), "what is in this image", tokenizer, model)

In [4]:
image = transforms.ToTensor()(Image.open("../data/task_data/Mini_MathVista_base_hamburgerFries_target/target_train/0.png").convert('RGB'))

In [36]:
chat(image)

torch.Size([3, 1120, 1120])


'The image is a compilation of various data visualization charts and graphs, each representing a different type of data representation. It includes a Bar Graph, a Pie Chart, a Line Graph, a Scatter Plot, a Histogram, a Network Graph, a Heat Map, and a Box Plot. These visualizations are commonly used in the field of data science and analytics to represent data in a way that is easy to understand and analyze.'

In [2]:
def list_image_files(directory):
        """
        Lists all JPG and PNG files in a given directory.

        Args:
            directory (str): Path to the directory.

        Returns:
            list: A list of image file paths.
        """
        image_files = []
        for root, _, files in os.walk(directory):
            for file in files:
                if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    image_files.append(os.path.join(root, file))
        return image_files

# Get image files from both directories
image_files_dir1 = list_image_files("../data/poisons/llava/Mini_MathVista_base_hamburgerFries_target/image")
image_files_dir2 = list_image_files("../data/task_data/Mini_MathVista_base_hamburgerFries_target/base_train")

In [12]:
import re

def extract_number(file_path):
    match = re.search(r'(\d+)', file_path)  # Extract the number using regex
    return int(match.group(1)) if match else float('inf')  # Return the number or inf if no number is found


image_files_dir1 = sorted(image_files_dir1, key=extract_number)
image_files_dir2 = sorted(image_files_dir2, key=extract_number)


In [None]:
import matplotlib.pyplot as plt
from PIL import Image

i = 0
for poison, base in zip(image_files_dir1, image_files_dir2):
    img1 = Image.open(poison)
    img2 = Image.open(base)

    # Set a fixed height for the images and adjust their aspect ratio
    fixed_height = 600
    img1_resized = img1.resize((int(img1.width * fixed_height / img1.height), fixed_height))
    img2_resized = img2.resize((int(img2.width * fixed_height / img2.height), fixed_height))

    # Create a figure with two subplots side by side
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))  # Adjust figsize as needed

    # Display the first image
    axes[0].imshow(img1_resized)
    axes[0].axis("off")  # Turn off axes for a cleaner look

    # Display the second image
    axes[1].imshow(img2_resized)
    axes[1].axis("off")  # Turn off axes for a cleaner look

    # Adjust layout
    plt.tight_layout()
    plt.show()



    i+=1
    if i > 50:
        break