### Loading the iPhone 16 dataset from the Hugging Face Hub

In [1]:
from datasets import load_dataset

# Download the dataset (it provides "train" and "test" splits) from the Hugging Face Hub.
ds = load_dataset(path="ArkaMukherjee/iphone16-dataset")

Downloading readme:   0%|          | 0.00/415 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/15.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.37M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/255 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/30 [00:00<?, ? examples/s]

### LLaVA setup

In [None]:
import os

if not os.path.isdir("LLaVA"):
    !git clone https://github.com/haotian-liu/LLaVA
else:
    print("LLaVA directory already exists. Skipping clone.")

In [None]:
import torch
from peft import LoraConfig, get_peft_model
from PIL import Image
import transformers
from transformers import AutoProcessor, Trainer, TrainingArguments, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, pipeline, logging
import torchvision.transforms as transforms

In [None]:
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model
import torch.nn as nn

# Target device. To pin a specific GPU instead, use e.g.:
#   cuda_idx = 2
#   device = f'cuda:{cuda_idx}'
device = 'cuda'

# Base checkpoint. Alternatives:
#   "liuhaotian/llava-v1.5-7b"
#   "liuhaotian/llava-v1.6-vicuna-7b"
model_path = "liuhaotian/llava-v1.6-mistral-7b"

# Short checkpoint name; later cells use it to pick the conversation template.
model_name = get_model_name_from_path(model_path)
print(model_name)

# Load tokenizer, model, image processor and context length in one call.
# 8-bit loading is not supported here; 4-bit (load_4bit=True) is left disabled.
# NOTE(review): cache_dir='' is forwarded as-is — confirm where the weights end up cached.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_name=model_name,
    model_base=None,
    device=device,
    use_flash_attn=True,
    cache_dir='',
)

In [None]:
# Inspect the loaded model's module tree; these module names are what LoRA `target_modules` match against.
print(model)

### Inference

In [None]:
# Inference
import re
import torch
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.transforms.functional import to_pil_image, to_tensor
from PIL import Image
import requests
from io import BytesIO

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)

In [None]:
# Common function to create prompts
def create_prompt(query, model, model_name=model_name, caption=None):
    """Wrap `query` (plus an image token) in the conversation template for `model_name`.

    Args:
        query: user text. If it contains IMAGE_PLACEHOLDER, that placeholder is replaced by the
            image token; otherwise the image token is prepended on its own line.
        model: LLaVA model; `model.config.mm_use_im_start_end` selects whether the image token is
            wrapped in start/end markers.
        model_name: checkpoint name used to select the conversation template.
        caption: assistant reply to append. None presumably leaves the assistant turn open for
            generation; a caption fills it in (used to build training targets).

    Returns:
        The fully formatted prompt string.
    """
    wrap_with_markers = model.config.mm_use_im_start_end
    image_token = (
        DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        if wrap_with_markers
        else DEFAULT_IMAGE_TOKEN
    )

    if IMAGE_PLACEHOLDER not in query:
        query = image_token + "\n" + query
    else:
        query = re.sub(IMAGE_PLACEHOLDER, image_token, query)

    template = conv_templates[infer_conv_mode(model_name)].copy()
    user_role, assistant_role = template.roles[0], template.roles[1]
    template.append_message(user_role, query)
    template.append_message(assistant_role, caption)
    return template.get_prompt()

# Common function to infer conversation mode
def infer_conv_mode(model_name):
    """Map a LLaVA checkpoint name to its conversation-template key.

    Matching is case-insensitive substring search, tried in order; the first hit wins.
    Order matters: "v1.6-34b" must be checked before the broader "v1".
    Falls back to "llava_v0" when nothing matches.
    """
    lowered_name = model_name.lower()
    ordered_rules = (
        ("llama-2", "llava_llama_2"),
        ("mistral", "mistral_instruct"),
        ("v1.6-34b", "chatml_direct"),
        ("v1", "llava_v1"),
        ("mpt", "mpt"),
    )
    return next(
        (conv_mode for marker, conv_mode in ordered_rules if marker in lowered_name),
        "llava_v0",
    )


In [None]:
import torch
import re
import PIL.Image

def load_image(image_input, timeout=30):
    """Return `image_input` as an RGB PIL image.

    Args:
        image_input: an http(s) URL string, a local file path string, or a PIL.Image.Image.
        timeout: seconds to wait for a URL download before giving up.

    Returns:
        A PIL image in RGB mode. An input image that is already RGB is returned unchanged.

    Raises:
        ValueError: for unsupported input types.
        requests.HTTPError: if the URL download returns an error status.
    """
    if isinstance(image_input, str):
        # Require the scheme separator so local paths that merely start with "http" stay local.
        if image_input.startswith(("http://", "https://")):
            response = requests.get(image_input, timeout=timeout)
            # Fail loudly instead of handing an HTML error page to PIL.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            with Image.open(image_input) as source:
                image = source.convert("RGB")
    elif isinstance(image_input, PIL.Image.Image):
        # Normalise to RGB like the URL/path branches (dataset images may be RGBA or grayscale).
        image = image_input if image_input.mode == "RGB" else image_input.convert("RGB")
    else:
        raise ValueError("Unsupported image input type")
    return image

# Common function to process images
def process_and_prepare_images(image_files, image_processor, model, device, dtype=torch.bfloat16):
    """Load images and convert them into model-ready pixel tensors.

    Args:
        image_files: iterable of anything `load_image` accepts (URL, path, PIL image).
        image_processor: LLaVA image processor.
        model: LLaVA model; its config drives preprocessing.
        device: target device for the pixel tensors.
        dtype: floating dtype for the pixel tensors.

    Returns:
        (images_tensor, image_sizes). images_tensor is a tensor, or a list of per-image tensors
        when `process_images` returns a list. image_sizes is a list of PIL `.size` tuples
        (width, height).
    """
    images = [load_image(image_file) for image_file in image_files]
    processed = process_images(images, image_processor, model.config)
    # NOTE(review): process_images can return a list of per-image tensors instead of one stacked
    # tensor (presumably when images yield different patch counts, e.g. "anyres" with varied
    # aspect ratios). A plain list has no `.to`, so move each element separately.
    if isinstance(processed, list):
        images_tensor = [tensor.to(device, dtype=dtype) for tensor in processed]
    else:
        images_tensor = processed.to(device, dtype=dtype)
    image_sizes = [image.size for image in images]
    return images_tensor, image_sizes

def eval_model(tokenizer, model, image_processor, context_len, image_file, query, model_name=model_name, sep=",", temperature=1.0, num_beams=1, max_new_tokens=512):
    """Generate a LLaVA response for one or more images and a text query.

    Prints the formatted prompt and the decoded output, and returns the output text.

    Args:
        tokenizer, model, image_processor: components from `load_pretrained_model`.
        context_len: unused; kept for call compatibility.
        image_file: a URL/path string, a PIL image, or a list of those.
        query: user question; `create_prompt` inserts the image token.
        model_name: checkpoint name used to pick the conversation template.
        sep: unused; kept for call compatibility.
        temperature: 1.0 (default) or any value <= 0 selects greedy decoding. Other positive
            values enable sampling at that temperature.
        num_beams, max_new_tokens: forwarded to `model.generate`.

    Returns:
        The decoded text of the first generated sequence, with special tokens kept.
    """
    # LLaVA's disable_torch_init() is deliberately not called here. It globally patches torch
    # layer initialisers to speed up model *creation*, but the model is already built, so calling
    # it on every inference would only leak that global patch into later cells.

    prompt = create_prompt(query, model, model_name)
    print(f"Prompt: {prompt}")

    # Accept a single image (str / PIL) or a list of them.
    image_files = image_file if isinstance(image_file, list) else [image_file]
    images_tensor, image_sizes = process_and_prepare_images(image_files, image_processor, model, model.device)

    # Tokenize the prompt, mapping the image token to IMAGE_TOKEN_INDEX.
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    # Greedy unless a positive temperature other than the 1.0 default is requested.
    # Non-positive temperatures are treated as greedy because generate() rejects sampling
    # with temperature <= 0.
    do_sample = temperature > 0 and temperature != 1.0
    generation_temperature = temperature if do_sample else 1.0

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            image_sizes=image_sizes,
            do_sample=do_sample,
            temperature=generation_temperature,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0].strip()
    print(outputs)
    return outputs

#### Inference on original model

In [None]:
import requests
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt

# Sample dog photo: a sanity check of the base model before any training.
image_url = 'https://farm7.staticflickr.com/6119/6315804553_050a2d1f4e_z.jpg'

# Download the image; fail loudly on HTTP errors and normalise it to RGB.
response = requests.get(image_url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert("RGB")

# Display the image using matplotlib
plt.imshow(image)
plt.axis('off')  # Turn off axis numbers and ticks
plt.show()

# Ask the base model about the image, using the same wording as the training/eval query.
eval_model(
    tokenizer,
    model,
    image_processor,
    context_len,
    image,
    "What do you see in this picture?"
)

In [None]:
device = "cuda"  # device read by the dataset transform and training cells below

In [None]:
from torch.nn.utils.rnn import pad_sequence

def tokenize_and_create_labels(example_batch, image_processor, tokenizer, model, device, model_name=model_name, ignore_index=-100, image_token_index=IMAGE_TOKEN_INDEX):
    """Turn a raw dataset batch into LLaVA training inputs.

    Each example becomes the conversation "<image> + fixed query", followed by its caption as the
    assistant reply. Only caption tokens are supervised: the shared prompt prefix and all padding
    positions are set to `ignore_index` in `labels`.

    Args:
        example_batch: batch dict with an 'image' list (anything `load_image` accepts) and a
            'captions' list of strings, one per image.
        image_processor, tokenizer, model: LLaVA components from `load_pretrained_model`.
        device: device for the returned tensors.
        model_name: checkpoint name used to choose the conversation template.
        ignore_index: label value skipped by the loss (-100 is the CrossEntropyLoss default).
        image_token_index: id used for the image placeholder in `input_ids`.

    Returns:
        dict with 'input_ids', 'attention_mask', 'labels' (each (batch, max_len)), plus
        'images' and 'image_sizes' from `process_and_prepare_images`.
    """
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # Some tokenizers (e.g. Mistral-based) define no pad token; pad_sequence needs a concrete int.
        pad_token_id = tokenizer.unk_token_id if tokenizer.unk_token_id is not None else tokenizer.eos_token_id

    images_tensor, image_sizes = process_and_prepare_images(example_batch['image'], image_processor, model, device)

    query = "What do you see in this picture?"

    # The caption-free prompt is identical for every example, so tokenize it once.
    # Its length marks where the supervised caption tokens begin.
    prompt_ids = tokenizer_image_token(
        create_prompt(query, model, model_name, None), tokenizer, image_token_index, return_tensors="pt"
    ).squeeze(0)
    prompt_len = len(prompt_ids)

    # Full conversations (prompt + caption), one 1-D tensor per example.
    sequences = [
        tokenizer_image_token(
            create_prompt(query, model, model_name, caption), tokenizer, image_token_index, return_tensors="pt"
        ).squeeze(0)
        for caption in example_batch['captions']
    ]
    seq_lengths = [len(sequence) for sequence in sequences]

    # Right-pad to a common length.
    input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_token_id).to(device)

    # Build the mask from true lengths rather than `input_ids != pad_token_id`. When the pad id is
    # shared with a real token (e.g. eos/unk), comparing by value would also mask genuine tokens.
    positions = torch.arange(input_ids.size(1), device=input_ids.device)
    lengths = torch.tensor(seq_lengths, device=input_ids.device)
    attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

    # Supervise only positions [prompt_len, seq_len); the prompt and the padding stay ignore_index.
    labels = torch.full_like(input_ids, fill_value=ignore_index)
    for i, seq_len in enumerate(seq_lengths):
        labels[i, prompt_len:seq_len] = input_ids[i, prompt_len:seq_len]

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "images": images_tensor,
        "image_sizes": image_sizes,
        "labels": labels,
    }
    
# Keep the transform a named module-level function (not a lambda) so it stays serializable.
def transform_batch(batch):
    """Tokenize and label `batch` with the notebook-level LLaVA components (looked up at call time)."""
    return tokenize_and_create_labels(
        batch,
        image_processor,
        tokenizer,
        model,
        device,
        model_name,
    )

train_ds, eval_ds = ds["train"], ds["test"]

# Register the transform; datasets applies it on the fly whenever rows are accessed.
train_ds.set_transform(transform_batch)
eval_ds.set_transform(transform_batch)

In [None]:
# Show the training split's features and row count.
print(train_ds)

In [None]:
# Show the evaluation split's features and row count.
print(eval_ds)

#### Inference on entire eval dataset

In [None]:
import matplotlib.pyplot as plt

# Temporarily disable the transformation so rows come back as the raw image and caption text.
eval_ds.reset_format()

# Show each evaluation image, the model's description of it, and the reference caption.
for example in eval_ds:
    image = example['image']
    # The caption column is 'captions' (the name the training transform reads).
    caption = example['captions']

    # Display the image using matplotlib
    plt.imshow(image)
    plt.axis('off')  # Turn off axis numbers and ticks
    plt.show()

    eval_model(
        tokenizer,
        model,
        image_processor,
        context_len,
        image,
        "What do you see in this picture?"
    )

    print(f"\nCorrect caption: {caption}\n\n")

# Re-enable the tokenizing transform using the named, serializable helper.
eval_ds.set_transform(transform_batch)

### LoRA setup

In [None]:
from peft import LoraConfig, get_peft_model
import torch

# The multimodal projector can be an nn.Sequential (e.g. Linear-GELU-Linear), which LoRA cannot
# wrap directly. Target its Linear sub-layers by full module name instead.
mm_projector_targets = [
    module_name
    for module_name, module in model.named_modules()
    if "mm_projector" in module_name and isinstance(module, torch.nn.Linear)
]

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=[
        "q_proj", "k_proj", "v_proj",
        # "fc1", "fc2", # for llama,
        *mm_projector_targets,  # for mistral, also train the multimodal projector
        "up_proj", "down_proj", "gate_proj",  # optionally train the MLP projections too
    ],
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, config)
model.to(device)
#model = torch.nn.DataParallel(model)

In [None]:
# Report trainable vs. total parameter counts of the LoRA-wrapped model.
model.print_trainable_parameters()

In [None]:
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Build a loader over the transformed training split and sanity-check its first batch.
batch_size = 4  # Specify the batch size you want to use
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

inspect_row = 0  # index of the example within the batch to print (0 = first row)

batch = next(iter(train_loader))  # only the first batch is checked
print(batch.keys())  # Print the dictionary keys to see what data is included in a batch

# If 'images' is a key, this indicates that images are being loaded
if 'images' in batch:
    print("Images are included in the DataLoader.")
    print(f"Batch 'images' shape: {batch['images'].shape}")  # Print the shape of the images tensor

# Similarly, check for other expected keys, like 'input_ids' and 'attention_mask'
if 'input_ids' in batch and 'attention_mask' in batch:
    input_ids_first_row = batch['input_ids'][inspect_row]
    print(f"First row of 'input_ids': \n{input_ids_first_row.tolist()}")

    print("Text inputs are included in the DataLoader.")
    print(f"Batch 'input_ids' shape: {batch['input_ids'].shape}")
    print(f"Batch 'attention_mask' shape: {batch['attention_mask'].shape}")

    # Print the row's labels, replacing ignore_index (-100) with the string '[IGNORE]'
    labels = batch['labels'][inspect_row].tolist()
    labels_str = ['[IGNORE]' if label == -100 else str(label) for label in labels]
    print(f"Labels: {labels_str}")
    print(len(labels))

    # Print the row's attention_mask
    attention_mask_str = batch['attention_mask'][inspect_row].tolist()
    print(f"Attention mask: {attention_mask_str}")

# Inspect one image tensor from the batch to confirm dtype, shape and value range
if 'images' in batch:
    image_tensor = batch['images'][inspect_row]
    print(f"First Row Image Data type: {image_tensor.dtype}")
    print(f"First Row Image Shape: {image_tensor.shape}")
    print(f"First Row Image Value range: [{image_tensor.min()}, {image_tensor.max()}]")

In [None]:
from torch.nn import CrossEntropyLoss
from peft import PeftModel
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
ignore_index = -100

# subclass trainer
class CustomTrainer(Trainer):
    """Trainer that maximises the model's loss ("unlearning" by gradient ascent).

    `compute_loss` follows the stock Trainer.compute_loss logic and then negates the result, so
    each optimizer descent step increases the original loss on the training data.
    NOTE(review): evaluation presumably also goes through compute_loss, so the logged eval_loss is
    negated too — confirm before relying on "best model" selection. Gradient ascent is unbounded;
    watch the loss and cap steps before the model degenerates.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the negated loss, plus the model outputs when `return_outputs` is True.

        Args:
            model: the (possibly wrapped) model being trained.
            inputs: batch dict; must include 'labels' so the model returns a loss.
            return_outputs: when True, return (negated_loss, outputs).
            **kwargs: extra arguments passed by newer transformers versions
                (e.g. `num_items_in_batch`); accepted and ignored.

        Raises:
            ValueError: if the model returns a dict without a 'loss' key.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists (mirrors the stock Trainer.compute_loss).
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            # Label-smoothing path: causal LMs need their labels shifted by one position.
            unwrapped_model = self.accelerator.unwrap_model(model)
            if isinstance(unwrapped_model, PeftModel):
                arch_name = unwrapped_model.base_model.model._get_name()
            else:
                arch_name = unwrapped_model._get_name()
            if arch_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        # Gradient ascent: minimising -loss maximises the original loss.
        negative_loss = -loss
        return (negative_loss, outputs) if return_outputs else negative_loss

In [None]:
import torch.nn as nn
import torch

#cuda_idx = 2 # edit device index that you want to track
#device = f'cuda:{cuda_idx}'

device = 'cuda'

# NOTE(review): this plain fine-tuning run uses the same output_dir as the later unlearning run,
# so the unlearning checkpoints may overwrite these.
output_model_name=f"{model_name}-dogs-unlearned"

training_args = TrainingArguments(
    #max_steps=3, # Comment this out after training for TWO STEPS or THREE STEPS
    output_dir=output_model_name,
    learning_rate=1e-4,
    # fp16=True, #for non ampere gpus
    bf16=True,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=32,
    dataloader_pin_memory=False,
    save_total_limit=2,
    evaluation_strategy="steps",  # NOTE(review): newer transformers versions rename this to eval_strategy
    save_strategy="steps",
    save_steps=1,
    eval_steps=1,
    logging_steps=1,
    num_train_epochs=3,
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=["labels"],
    load_best_model_at_end=True,
    # The string "none" disables reporting; None means "all installed integrations" in transformers 4.x.
    report_to="none",
    optim="adamw_torch",
    #gradient_checkpointing=True,
    #gradient_checkpointing_kwargs={'use_reentrant':True}
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)

trainer.train()

### Post fine-tuning evaluation

In [None]:
from peft import PeftModel

# The training cell above writes to the local, relative output_dir f"{model_name}-dogs-unlearned"
# (no "liuhaotian/" Hub prefix). Prefer the checkpoint the Trainer reports as best: with
# save_steps=1 and save_total_limit=2, early checkpoints such as checkpoint-1 may already be deleted.
adapter_path = trainer.state.best_model_checkpoint or f"{model_name}-dogs-unlearned/checkpoint-1"
model_path = "liuhaotian/llava-v1.6-mistral-7b"

# Reload a fresh base model, then attach the trained LoRA adapter to it.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=model_name,
    cache_dir='',
    use_flash_attn=True,
    # load_8bit=True #NOT SUPPORTED YET WITH THIS SCRIPT
    # load_4bit=True #NOT SUPPORTED YET WITH THIS SCRIPT
)

model = PeftModel.from_pretrained(
    model,
    adapter_path,
)

In [None]:
import matplotlib.pyplot as plt

# Temporarily disable the transformation so rows come back as the raw image and caption text.
eval_ds.reset_format()

# Show each evaluation image, the fine-tuned model's description, and the reference caption.
for example in eval_ds:
    image = example['image']
    caption = example['captions']

    # Display the image using matplotlib
    plt.imshow(image)
    plt.axis('off')  # Turn off axis numbers and ticks
    plt.show()

    eval_model(
        tokenizer,
        model,
        image_processor,
        context_len,
        image,
        "What do you see in this picture?"
    )

    print(f"\nCorrect caption: {caption}\n\n")

# Re-enable the tokenizing transform. transform_batch looks up model/tokenizer/etc. as notebook
# globals at call time, so it uses the adapter model loaded above.
eval_ds.set_transform(transform_batch)

### Unlearning via gradient ascent (negated training loss)

#### Epoch 1:

In [None]:
import torch.nn as nn
import torch

#cuda_idx = 2 # edit device index that you want to track
#device = f'cuda:{cuda_idx}'

device = 'cuda'

output_model_name=f"{model_name}-dogs-unlearned"

# `model` was loaded with PeftModel.from_pretrained, which freezes the adapter unless
# is_trainable=True is passed. Re-enable gradients on the LoRA weights so gradient ascent
# has parameters to update.
for param_name, param in model.named_parameters():
    if "lora_" in param_name:
        param.requires_grad_(True)

training_args = TrainingArguments(
    #max_steps=3, # Comment this out after training for TWO STEPS or THREE STEPS
    output_dir=output_model_name,
    learning_rate=1e-4,
    # fp16=True, #for non ampere gpus
    bf16=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=32,
    dataloader_pin_memory=False,
    save_total_limit=2,
    evaluation_strategy="steps",  # NOTE(review): newer transformers versions rename this to eval_strategy
    save_strategy="steps",
    save_steps=1,
    eval_steps=1,
    logging_steps=1,
    num_train_epochs=1,
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=["labels"],
    load_best_model_at_end=True,
    # The string "none" disables reporting; None means "all installed integrations" in transformers 4.x.
    report_to="none",
    optim="adamw_torch",
    #gradient_checkpointing=True,
    #gradient_checkpointing_kwargs={'use_reentrant':True}
)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)

trainer.train()

#### Epoch 1 evaluation:

In [None]:
from peft import PeftModel

# The unlearning cell above writes to the local, relative output_dir f"{model_name}-dogs-unlearned"
# (no "liuhaotian/" Hub prefix). Prefer the checkpoint the Trainer reports as best: with
# save_steps=1 and save_total_limit=2, early checkpoints such as checkpoint-1 may already be deleted.
adapter_path = trainer.state.best_model_checkpoint or f"{model_name}-dogs-unlearned/checkpoint-1"
model_path = "liuhaotian/llava-v1.6-mistral-7b"

# Reload a fresh base model, then attach the unlearned LoRA adapter to it.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=model_name,
    cache_dir='',
    use_flash_attn=True,
    # load_8bit=True #NOT SUPPORTED YET WITH THIS SCRIPT
    # load_4bit=True #NOT SUPPORTED YET WITH THIS SCRIPT
)

model = PeftModel.from_pretrained(
    model,
    adapter_path,
)

In [None]:
import matplotlib.pyplot as plt

# Temporarily disable the transformation so rows come back as the raw image and caption text.
eval_ds.reset_format()

# Show each evaluation image, the unlearned model's description, and the reference caption.
for example in eval_ds:
    image = example['image']
    caption = example['captions']

    # Display the image using matplotlib
    plt.imshow(image)
    plt.axis('off')  # Turn off axis numbers and ticks
    plt.show()

    eval_model(
        tokenizer,
        model,
        image_processor,
        context_len,
        image,
        "What do you see in this picture?"
    )

    print(f"\nCorrect caption: {caption}\n\n")

# Re-enable the tokenizing transform. transform_batch looks up model/tokenizer/etc. as notebook
# globals at call time, so it uses the adapter model loaded above.
eval_ds.set_transform(transform_batch)