In [2]:
%matplotlib inline

import torch
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration, InstructBlipConfig
import h5py
import time
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

'''
PREPROCESSING: BLIP encoder outputs to multiple-context, moment-wise state-action tokens

MODEL 9:
    Behavioral-cloning Renas transformer (camera-lidar)
    1. TEXT-Image camera or (camera+map concatenation) ENCODER using InstructBLIP (frozen) 
    2. TEXT-Image camera or (camera+map concatenation) DECODER using InstructBLIP (frozen) for text generation
    3. Cross-attention middle tokens to cls driving token MID TRANSFORMER
    4. (im_prompt)-(action) history-aware causal driving Transformer GPT
    Loss: cross-attention metrics fed into CrossEntropyLoss
    Similarity metric: First half of cross-attention

DATA:
    1. Behavioral cloning correct demonstrations (state-action episodes) 

    State: (image) or (im-map concatenation) (reworked h5), prompt 

    Actions in ROS: position (x, y) and orientation quaternion components (z, w)
    Actions for the model are explored (image-prompt description) and set as a token vocabulary

    2. Actions annotations
    (Im) or (Im-map), prompt
'''

# Source HDF5 file; the script reads its 'states' and 'actions' groups,
# which hold one dataset per episode ('data_0', 'data_1', ...).
DATASET = (
    '/data/renas/pythonprogv2/phd_xiaor_project/TSA_dataset/'
    'real/2A724_may/tsa_combined.h5'
)
# Preferred CUDA device when CUDA is available (the script otherwise uses CPU).
DEVICE = "cuda:0"
# Question sent to InstructBLIP together with every frame.
PROMPT = "Do you see green cone on the image? Answer only that question"

def look_im(im_i, generated_text):
    """Show one image with the generated text as its title.

    Args:
        im_i: Image as a ``torch.Tensor`` (on any device) or an array-like.
            Presumably already channels-last (H, W, C), since it is passed to
            ``plt.imshow`` without transposing — confirm against the dataset.
            Pixel values are expected on a 0-255 scale.
        generated_text: Caption shown as the plot title.
    """
    if isinstance(im_i, torch.Tensor):
        # Tensor.numpy() raises for CUDA tensors and tensors requiring grad,
        # so detach and move to host memory first.
        im_i = im_i.detach().cpu().numpy()
    im_i_np = np.asarray(im_i)
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap
    # around (or hit undefined casting behaviour) and display as noise.
    plt.imshow(np.clip(im_i_np, 0, 255).astype(np.uint8))
    plt.title(generated_text)
    plt.axis('off')
    plt.show()

class Renas9(torch.nn.Module):
    """Frozen InstructBLIP backbone (config, processor and model) for MODEL 9.

    Loads the InstructBLIP config, processor and conditional-generation model
    from ``model_name`` and freezes every model parameter.

    Args:
        device: Device identifier stored on ``self.device``. The weights are
            NOT moved here; callers are expected to call ``.to(device)``.
        model_name: Hugging Face hub id or local path used for the config,
            the processor and the weights. Defaults to ``MODEL_NAME``.
        torch_dtype: dtype the model weights are loaded in.
    """

    MODEL_NAME = "Salesforce/instructblip-flan-t5-xl"

    def __init__(self, device, model_name=MODEL_NAME, torch_dtype=torch.bfloat16):
        super().__init__()
        self.device = device
        self.blip_config = InstructBlipConfig.from_pretrained(model_name)
        # Hidden size of the language-model side of InstructBLIP.
        self.d_model = self.blip_config.text_config.d_model

        self.processor = InstructBlipProcessor.from_pretrained(model_name)
        image_processor = self.processor.image_processor
        # Let the processor rescale raw pixel values and resize frames to the
        # vision encoder's input size.
        image_processor.do_rescale = True
        image_processor.do_resize = True
        # NOTE(review): mean/std normalization is disabled; the pretrained
        # vision encoder presumably expects normalized pixels — confirm this
        # is intentional.
        image_processor.do_normalize = False

        self.blip_model = InstructBlipForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch_dtype
        )
        # Freeze the whole backbone: no InstructBLIP parameter receives gradients.
        self.blip_model.requires_grad_(False)

        


if __name__ == '__main__':
    import os

    # Resolve the compute device: the configured CUDA device when CUDA is
    # available (listing every visible GPU), otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device(DEVICE)
        for i in range(torch.cuda.device_count()):
            device_i = torch.device(f'cuda:{i}')
            print(f'Cuda Device {i}: ', device_i, torch.cuda.get_device_name(i))
    else:
        print('No CUDA devices available')
        device = torch.device('cpu')
    print('Current device: ', device)

    # Output file sits next to the input: <stem>_model9_prep.h5
    new_dataset_path = os.path.splitext(DATASET)[0] + '_model9_prep.h5'
    # Use the resolved `device` (not the DEVICE constant) so the CPU fallback
    # above actually takes effect.
    model = Renas9(device).to(device)

    # T5-style decoders start from decoder_start_token_id; a T5 text config is
    # not guaranteed to define bos_token_id (torch.LongTensor([None]) fails).
    text_config = model.blip_config.text_config
    decoder_start_id = text_config.decoder_start_token_id
    if decoder_start_id is None:
        decoder_start_id = text_config.bos_token_id

    # Caption sampling settings, hoisted out of the per-frame loop.
    generate_kwargs = dict(
        do_sample=True,
        num_beams=5,
        max_length=512,
        min_length=1,
        top_p=0.9,
        repetition_penalty=2.5,
        length_penalty=0.5,
        temperature=1,
    )

    with h5py.File(DATASET, 'r') as hdf:
        im_group = hdf['states']
        action_group = hdf['actions']
        num_episodes = len(im_group)
        print('Dataset contains episodes: ', num_episodes)

        preprocess_timer_start = time.time()
        # The backbone is frozen, so skip autograd bookkeeping entirely.
        with h5py.File(new_dataset_path, 'w') as new_hdf, torch.no_grad():
            new_hdf_im_group = new_hdf.create_group('states')
            new_hdf_action_group = new_hdf.create_group('actions')
            for i in range(num_episodes):
                episode = 'data_' + str(i)
                episode_images = im_group[episode]  # look the dataset up once per episode
                for im_num in range(episode_images.shape[0]):
                    im_i = torch.from_numpy(episode_images[im_num]).float()
                    inputs = model.processor(images=im_i, text=PROMPT, return_tensors="pt")
                    inputs = {key: val.to(device) for key, val in inputs.items()}
                    batch_size = inputs['input_ids'].size(0)

                    # Single-step forward pass (decoder fed only the start token).
                    # Kept in its own variable so generate() below does not clobber
                    # it, and decoder_input_ids is NOT passed on to generate().
                    decoder_input_ids = torch.full(
                        (batch_size, 1), decoder_start_id, dtype=torch.long, device=device
                    )
                    forward_outputs = model.blip_model(
                        **inputs, decoder_input_ids=decoder_input_ids, return_dict=True
                    )
                    #print('Last hidden state: ', forward_outputs.language_model_outputs.encoder_last_hidden_state.shape)

                    generated_ids = model.blip_model.generate(**inputs, **generate_kwargs)
                    generated_text = model.processor.batch_decode(
                        generated_ids, skip_special_tokens=True
                    )[0].strip()
                    print('\n' + generated_text)
                    look_im(im_i, generated_text)
                # TODO: write the preprocessed episode into the new file:
                #im_i = im_i.numpy()
                #new_hdf_im_group.create_dataset(episode, data=im_i, dtype = np.float32, compression = 'gzip')
                #a = action_group[episode][:]
                #new_hdf_action_group.create_dataset(episode, data=a, dtype = np.float32, compression = 'gzip')

    print('preprocess full time: ', time.time() - preprocess_timer_start)
    