# Setup

In [38]:
from PIL import Image
import requests
from transformers import AutoProcessor, LlavaForConditionalGeneration, Blip2Model, Blip2Processor, Blip2ForConditionalGeneration, Blip2Config
import time
import torch
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss
from typing import Any, Optional, Tuple, Union
from dataclasses import dataclass
import pandas as pd

# Load Model Weights

## LLaVA

In [None]:
# model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True,).to(0)
# processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

## BLIP2

In [39]:
# Download BLIP2 weights
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16)
# model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Inference

In [None]:
prompt = "<image>\nUSER: What is in the image?\nASSISTANT:"
label_text = '</s> The image is of two cats laying on a couch with remotes on the couch\n'
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# url = "https://i.redd.it/s2fi2wr7ibe21.jpg"
image = Image.open(requests.get(url, stream=True).raw)

In [None]:
display(image)

In [None]:
# inputs = processor(image, return_tensors="pt").to(0, torch.float16)
inputs = processor(image, prompt, return_tensors="pt").to(0, torch.float16)
label_input_ids = processor.tokenizer.encode(label_text, return_tensors="pt")
label_input_ids

In [None]:
configuration = Blip2Config()
configuration.text_config.vocab_size
labels = torch.randint(0, configuration.text_config.vocab_size, size=(1, 11))
labels

In [None]:
# Generate
start_time = time.time()
generated_ids = model.generate(**inputs, max_length=200, do_sample=False)
processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
end_time = time.time()
print(f'Took {end_time - start_time} seconds')
processor.decode(generated_ids[0])

In [None]:
# Forward
print(label_input_ids.shape)
print(inputs.input_ids.shape)
start_time = time.time()
generated_ids = model(pixel_values=inputs.pixel_values, input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, labels=label_input_ids)
end_time = time.time()
print(f'Took {end_time - start_time} seconds')
generated_ids.keys()

In [None]:
generated_ids.loss.shape
logits = generated_ids.logits
labels = label_input_ids
print(f'LOGITS: {logits.shape}')
print(f'{inputs.input_ids.shape}')
print(f'{label_input_ids.shape}')
logits = logits[:, -label_input_ids.size(1):, :]
print(f'NEW LOGITS: {logits.shape}')

# From CODE
shift_logits = logits[..., :-1, :].contiguous()
print(f'SHIFT LOGITS: {shift_logits.shape}')
print(shift_logits)
shift_labels = labels[..., 1:].contiguous().to(logits.device)
print(f'SHIFT LABELS: {shift_logits.shape}')
print(shift_labels)
print(shift_labels.shape)
print(processor.batch_decode(shift_labels, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])

final_logit_layer = logits[:,-1,:]
print(final_logit_layer.shape)
max_token_prob = F.softmax(final_logit_layer).argmax()
prediction = processor.batch_decode(max_token_prob.unsqueeze(0), skip_special_tokens=True, clean_up_tokenization_spaces=False)

In [None]:
# One loss
yb = ['</s> The image is of two cats laying on a couch with remotes on the couch\n', ' The image is of two cats laying on a couch with remotes on the couch\n', '</s>The image is of two cats laying on a couch with remotes on the couch\n', 'Swagger', ' The', '</s> The']

for label_text in yb: 
    inputs = processor(image, prompt, return_tensors="pt").to(0, torch.float16)
    label_input_ids = processor.tokenizer.encode(label_text, return_tensors="pt")
    print(inputs.input_ids.shape)
    generated_ids = model(pixel_values=inputs.pixel_values, input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, labels=label_input_ids)

    final_logit_layer = generated_ids.logits[:,-1,:]
    max_token_prob = F.softmax(final_logit_layer).argmax()
    prediction = processor.batch_decode(max_token_prob.unsqueeze(0), skip_special_tokens=True, clean_up_tokenization_spaces=False)
    print(f'Prediction {prediction} vs. {label_text} has loss: {generated_ids.loss}')

In [None]:
# One loss
yb = ['<image>\nUSER: What is in the image?\nASSISTANT:</s> The image is of two cats laying on a couch with remotes on the couch\n', '</s> The image is of two cats laying on a couch with remotes on the couch\n', ' The image is of two cats laying on a couch with remotes on the couch\n', '</s>The image is of two cats laying on a couch with remotes on the couch\n', 'Swagger', ' The', '</s> The']

for label_text in yb: 
    inputs = processor(image, prompt, return_tensors="pt").to(0, torch.float16)
    label_input_ids = processor.tokenizer.encode(label_text, return_tensors="pt")
    print(f'{inputs.input_ids.shape} --> {label_input_ids.shape}')
    generated_ids = model(pixel_values=inputs.pixel_values, input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, labels=label_input_ids)

    final_logit_layer = generated_ids.logits[:,-1,:]
    max_token_prob = F.softmax(final_logit_layer).argmax()
    prediction = processor.batch_decode(max_token_prob.unsqueeze(0), skip_special_tokens=True, clean_up_tokenization_spaces=False)
    print(f'Prediction {prediction} vs. {label_text} has loss: {generated_ids.loss}')

In [40]:
# Matrix loss
label_texts = ['<image>\nUSER: What is in the image?\nASSISTANT:</s> The image is of two cats laying on a couch with remotes on the couch\n', '</s> The image is of two cats laying on a couch with remotes on the couch\n', ' The image is of two cats laying on a couch with remotes on the couch\n', '</s>The image is of two cats laying on a couch with remotes on the couch\n', 'Swagger', ' The', '</s> The'][:2]
prompt_inputs = [prompt] * len(label_texts)
label_input_ids = torch.stack([processor.tokenizer.encode(s, return_tensors="pt", padding="max_length", max_length=512) for s in label_texts])

inputs = processor(image, prompt_inputs, return_tensors="pt", padding="max_length", max_length=512).to(0, torch.float16)
expanded_pixel_values = inputs.pixel_values.expand(len(label_texts), -1, -1, -1)

print(inputs.pixel_values.shape)
print(expanded_pixel_values.shape)
print(inputs.input_ids.shape)
print(label_input_ids.shape)
print(inputs.attention_mask.shape)
print(label_input_ids.squeeze(1).shape)
print(f'{inputs.input_ids.shape} --> {label_input_ids.squeeze(1).shape}')
print("------")


generated_ids = custom_forward(model, pixel_values=expanded_pixel_values, input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
generated_ids.keys()

torch.Size([1, 3, 224, 224])
torch.Size([2, 3, 224, 224])
torch.Size([2, 512])
torch.Size([2, 1, 512])
torch.Size([2, 512])
torch.Size([2, 512])
torch.Size([2, 512]) --> torch.Size([2, 512])
------


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacty of 21.99 GiB of which 15.69 MiB is free. Including non-PyTorch memory, this process has 21.97 GiB memory in use. Of the allocated memory 20.11 GiB is allocated by PyTorch, and 265.30 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Clear GPU
del model
del processor
del generated_ids
del forward_ids
del final_layer_logits

# Test Section

In [None]:
inputs = processor(image, prompt, return_tensors="pt").to(0, torch.float16)
forward_ids = model(pixel_values=inputs.pixel_values, input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
# generated_ids = model.generate(pixel_values=inputs.pixel_values, input_ids=inputs.input_ids, max_length=200, do_sample=False)

In [None]:
final_layer_logits = forward_ids.logits[:, -1, :]
probs = F.softmax(final_layer_logits)
idx = torch.argmax(probs)
processor.decode(idx)

print(idx)

# Training?

In [None]:
from transformers import TrainingArguments
import numpy as np
import evaluate
import wandb

In [None]:
training_args = TrainingArguments(
    output_dir='./test_trainer',     # output directory
    num_train_epochs=2,              # total # of training epochs
    per_device_train_batch_size=4,   # batch size per device during training
    per_device_eval_batch_size=8,    # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
)

In [None]:
# training_args.num_train_epochs
model.__class__.__name__

# TO-DOs

- [] Compute Metrics. Are we just using NTP loss? AKA cross entropy
- [] Tune/expose hyperparameters for configuration
- [] Setup with W&Bs
- [] Build training and eval datasets.

# W & B

In [None]:
# Already handled
# !wandb login

wandb.init(
    # set the wandb project where this run will be logged
    project="coffee-bot",
    
    # track hyperparameters and run metadata
    config={
        "learning_rate": 0.02,
        "architecture": model.__class__.__name__,
        "dataset": "CIFAR-100",
        "num_train_epoch": training_args.num_train_epochs,
        "warmup_steps": training_args.warmup_steps,
        "weight_decay": training_args.weight_decay,
        "per_device_train_batch_size": training_args.per_device_train_batch_size,
        "per_device_eval_batch_size": training_args.per_device_eval_batch_size,        
    }
)


In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics
)

# I'M NOT CRAZY

In [None]:
inputs = processor(image, prompt, return_tensors="pt").to(0, torch.float16)
print(inputs.input_ids)

test = torch.tensor([[81.0]]).to(0, torch.float16)

inputs.input_ids = inputs.input_ids.to(torch.float16)

print(test)
print(inputs.input_ids)
torch.cat((inputs.input_ids, test), dim=1)
# torch.cat((inputs.input_ids, test, dim=1)

# NEW

In [None]:
pred = None
idx = None
curr_prompt = prompt
input_ids = None
res_ids = None

# While pred != '\n'

start_time = time.time()

# Get image_embedding and project
input_ids, language_model_inputs, vision_outputs = image_encoding(model=model, processor=processor, curr_prompt=curr_prompt)

while idx != 50118:

    # Forward pass
    forward_ids = forward_from_image_query_output(model=model, language_model_inputs=language_model_inputs, input_ids=input_ids)

    # Get the final layer of logits
    final_layer_logits = forward_ids['logits'][:, -1, :]

    # Get prediction
    probs = F.softmax(final_layer_logits)
    idx = torch.argmax(probs)

    # Convert to shape
    idx = torch.tensor([[idx.item()]]).to('cuda')

    if res_ids is None:
        res_ids = idx
    else:
        res_ids = torch.cat([res_ids, idx], dim = 1)
        
    input_ids = torch.cat([input_ids, idx], dim = 1)
    # print(processor.batch_decode(res_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])

end_time = time.time()
print(res_ids)
print(f'Took {end_time - start_time} seconds')
processor.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

# Utils

In [36]:
def forward_from_image_query_output(
    model,
    language_model_inputs: torch.FloatTensor,
    input_ids: torch.FloatTensor,
    attention_mask: Optional[torch.LongTensor] = None,
    labels: Optional[torch.LongTensor] = None,
):
    output_attentions = None
    output_hidden_states = None
    return_dict = model.config.use_return_dict
    
    # step 3: use the language model, conditioned on the query outputs and the prompt
    language_model_attention_mask = torch.ones(
        language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
    )
    inputs_embeds = model.language_model.get_input_embeddings()(input_ids)
    inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    expected_device = language_model_attention_mask.device
    attention_mask = torch.cat([language_model_attention_mask, attention_mask.to(expected_device)], dim=1)

    outputs = model.language_model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    logits = outputs.logits if return_dict else outputs[0]
    loss = None
    # we compute the loss here since we need to take into account the sequence length of the query embeds
    if labels is not None:
        labels = labels.to(logits.device)
        logits = logits[:, -labels.size(1) :, :]
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous().to(logits.device)

        # Flatten the tokens
        loss_fct = CrossEntropyLoss(reduction="mean")

        loss = loss_fct(shift_logits.view(-1, model.config.text_config.vocab_size), shift_labels.view(-1))

    if not return_dict:
        output = (logits, outputs)
        return ((loss,) + output) if loss is not None else output

    return {
        'loss': loss,
        'logits': logits,
        'language_model_outputs': outputs,
    }

def image_encoding(
    model,
    processor,
    image,
    curr_prompt
):  
    inputs = processor(image, curr_prompt, return_tensors="pt").to(0, torch.float16)
    print(inputs)
    pixel_values = inputs.pixel_values
    input_ids = inputs.input_ids
    
    decoder_input_ids = None
    decoder_attention_mask = None
    output_attentions = None
    output_hidden_states = None
    labels = None
    return_dict = model.config.use_return_dict

    # step 1: forward the images through the vision encoder,
    # to get image embeddings of shape (batch_size, seq_len, hidden_size)
    vision_outputs = model.vision_model(
        pixel_values=pixel_values,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    image_embeds = vision_outputs[0]
    print(f'Image Embedds: {image_embeds.shape}')

    # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
    image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

    query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
    query_outputs = model.qformer(
        query_embeds=query_tokens,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    query_output = query_outputs[0]

    # step 3: use the language model, conditioned on the query outputs and the prompt
    return input_ids, model.language_projection(query_output), vision_outputs

def custom_forward(
    model,
    pixel_values: torch.FloatTensor,
    input_ids: torch.FloatTensor,
    attention_mask: Optional[torch.LongTensor] = None,
    decoder_input_ids: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.LongTensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    labels: Optional[torch.LongTensor] = None,
    return_dict: Optional[bool] = None,
):
    vision_outputs = model.vision_model(
        pixel_values=pixel_values,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    image_embeds = vision_outputs[0]
    print(f'Image Embeds: {image_embeds.shape}')

    # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
    image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

    query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
    query_outputs = model.qformer(
        query_embeds=query_tokens,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    query_output = query_outputs[0]

    print(f'Query Output: {query_output.shape}')

    # step 3: use the language model, conditioned on the query outputs and the prompt
    language_model_inputs = model.language_projection(query_output)
    language_model_attention_mask = torch.ones(
        language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
    )
    print(f'Input Ids: {input_ids.shape}')
    inputs_embeds = model.language_model.get_input_embeddings()(input_ids)
    print(f'Input Embeds: {inputs_embeds.shape}')
    print(f'language_model_inputs: {language_model_inputs.shape}')
    inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    expected_device = language_model_attention_mask.device
    attention_mask = torch.cat([language_model_attention_mask, attention_mask.to(expected_device)], dim=1)

    if model.config.use_decoder_only_language_model:
        outputs = model.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]
        loss = None
        # we compute the loss here since we need to take into account the sequence length of the query embeds
        if labels is not None:
            labels = labels.to(logits.device)
            logits = logits[:, -labels.size(1) :, :]
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device)

            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction="mean")

            loss = loss_fct(shift_logits.view(-1, model.config.text_config.vocab_size), shift_labels.view(-1))
    else:
        outputs = model.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
        )
        loss = outputs.loss if return_dict else outputs[0]
        logits = outputs.logits if return_dict else outputs[1]

    if not return_dict:
        output = (logits, vision_outputs, query_outputs, outputs)
        return ((loss,) + output) if loss is not None else output

    return {
        'loss': loss,
        'logits': logits,
        'vision_outputs': vision_outputs,
        'qformer_outputs': query_outputs,
        'language_model_outputs': outputs,
    }

# Need loss
- How to
Read in 

In [None]:
df = pd.read_csv('../datasets/full_ds.csv')
df.head()

In [None]:
tokenized_inputs = df['comment'].astype(str).apply(lambda str: str[:500]).apply(lambda str: processor.tokenizer.encode(str, padding="max_length",max_length=512))
input_ids = torch.tensor(tokenized_inputs)

In [None]:
input_ids.shape

In [None]:
def build_batch(comments):
    xb = [] # B,T,C --> [1, 512, ?]
    yb = [] # B,T,1 

    for comment in comments:
        context = []
        
        # tokenize the comment
        print(comment)
        tokenized = processor.tokenizer.encode(comment, padding="max_length", max_length=512)
        print(tokenized)
        
        # for token
        for idx, t in enumerate(tokenized):
            if idx == 0:
                continue
                
            context = tokenized[0:idx]
            label = tokenized[idx:idx + 1]

            xb.append(context)
            yb.append(label)
            
    return xb, yb

# Batch Dimensions
- Pixel Ids: torch.Size([1, 3, 224, 224])
- Input Ids: torch.Size([1, 512])
- Labels: torch.Size([1])

In [None]:
tokenized_inputs = df['comment'].astype(str).apply(lambda str: str[:500]).apply(lambda str: processor.tokenizer.encode(str, padding="max_length",max_length=512))

In [None]:
tokenized_inputs

In [None]:
xb, yb = build_batch(df['comment'].astype(str))

In [None]:
for context, image_url, label in zip(xb, yb):

    # download image
    image = Image.open(requests.get(url, stream=True).raw)

    # Get input_ids & 
    input_ids, language_model_inputs, vision_outputs = image_encoding(model=model, image=image, processor=processor, curr_prompt=curr_prompt)

    model(pixel_values=pixel_values, input_ids=input_idx, labels=label)
    