In [1]:
import torch
import torchvision
from transformers import AutoTokenizer
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
import random
random.seed(0)
from PIL import Image

# Restrict this process to physical GPU 2. CUDA reads this variable when it is first
# initialised, so it has to be set before any CUDA call is made.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# With a single GPU exposed it is addressed as cuda:0; fall back to CPU when CUDA is unavailable.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# SECURITY: a literal Hugging Face access token used to be embedded in this comment.
# It has been redacted; revoke the old token and pass credentials through the
# environment (e.g. an HF_TOKEN variable) instead of committing them to the notebook.
# !huggingface-cli login --token "$HF_TOKEN"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

from phi import PhiForCausalLM, PhiConfig
from phi.modeling_phi import *
from seq2seq_autoencoder import Seq2SeqAutoEncoderModel
from transformers.utils import ModelOutput


class PhiModel(PhiPreTrainedModel):
    """Phi transformer backbone that can inject extra embeddings at chosen positions.

    The model embeds ``input_ids``, optionally adds per-sample "segment token"
    embeddings onto selected sequence positions, then runs the result through
    every block in ``self.h`` and returns the final hidden states.
    """

    _keys_to_ignore_on_load_missing = [""]
    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]

    def __init__(self, config: PhiConfig) -> None:
        super().__init__(config)

        self.embd = Embedding(config)
        blocks = [ParallelBlock(config, block_idx=layer_idx) for layer_idx in range(config.n_layer)]
        self.h = nn.ModuleList(blocks)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding table."""
        return self.embd.wte

    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
        """Replace the token embedding table."""
        self.embd.wte = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        segment_tokens: Optional[torch.FloatTensor] = None,
        segment_position_ids: Optional[list] = None,
        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.FloatTensor:
        """Embed ``input_ids``, add segment embeddings, and apply all blocks.

        Args:
            input_ids: Token ids to embed.
            segment_tokens: Optional per-sample collection of embeddings to add.
                An entry of length 0 means "nothing to add" for that sample.
            segment_position_ids: Per-sample positions that receive the matching
                entry of ``segment_tokens``. Required when ``segment_tokens`` is given.
            past_key_values: Forwarded unchanged to every block.
            attention_mask: Forwarded unchanged to every block.

        Returns:
            The hidden states produced by the last block.
        """
        hidden_states = self.embd(input_ids)

        if segment_tokens is not None:
            assert segment_position_ids is not None
            for row, sample_tokens in enumerate(segment_tokens):
                if len(sample_tokens) == 0:
                    # This sample carries no segment tokens.
                    continue
                # In-place add onto the embedded positions of this sample.
                hidden_states[row, segment_position_ids[row]] += sample_tokens

        for block in self.h:
            hidden_states = block(
                hidden_states,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
            )

        return hidden_states


class SegmentHead(nn.Module):
    """Prediction head mapping LLM hidden states to a segment latent and a bounding box.

    The input is layer-normalised and then fed to two independent linear
    projections: one of width ``d_segment_latent`` and one of width 4.

    Args:
        d_llm: Size of the last dimension of the incoming hidden states.
        d_segment_latent: Size of the predicted segment latent.
    """

    def __init__(self, d_llm, d_segment_latent) -> None:
        super().__init__()

        self.ln = nn.LayerNorm(d_llm)
        self.linear_to_latent = nn.Linear(d_llm, d_segment_latent)
        self.linear_to_bbox = nn.Linear(d_llm, 4)

    def forward(self, hidden_states: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Project hidden states to ``(latents, bboxes)``.

        Args:
            hidden_states: Tensor whose last dimension is ``d_llm``.

        Returns:
            A 2-tuple ``(latents, bboxes)``; both are cast to ``torch.float32``.
            ``latents`` has last dimension ``d_segment_latent`` and ``bboxes``
            has last dimension 4.
        """
        normed = self.ln(hidden_states)
        # Outputs are forced to float32 regardless of the dtype the head runs in.
        latents = self.linear_to_latent(normed).to(torch.float32)
        bboxes = self.linear_to_bbox(normed).to(torch.float32)

        return latents, bboxes


class RegreessionLoss(nn.Module):
    """Mean-squared-error loss with optional next-step shifting.

    The (misspelled) class name is kept as-is because other code refers to it.

    Args:
        shift_labels: When True, the prediction at position ``t`` is compared
            with the target at position ``t + 1`` along the sequence axis
            (second-to-last axis of ``prediction``).
    """

    def __init__(self, shift_labels: bool = True) -> None:
        super().__init__()

        self.shift_labels = shift_labels
        self.loss_fct = nn.MSELoss()

    def forward(self, prediction: torch.FloatTensor, target: torch.FloatTensor) -> torch.FloatTensor:
        """Compute the MSE between ``prediction`` and ``target``.

        Args:
            prediction: Tensor of shape ``(..., seq, dim)``.
            target: Regression target. Normally the same shape as ``prediction``.

        Returns:
            Scalar MSE loss.
        """
        if self.shift_labels:
            prediction = prediction[..., :-1, :].contiguous()
            if target.dim() == prediction.dim():
                # Same rank as the prediction: the sequence axis is second to
                # last, so shift there. Slicing the last axis (as a token-id
                # label loss would) drops a feature channel instead of a time
                # step and leaves the shapes incompatible.
                target = target[..., 1:, :].contiguous()
            else:
                # Lower-rank target: keep the historical last-axis shift.
                target = target[..., 1:].contiguous()

        return self.loss_fct(prediction, target)


@dataclass
class MultimodalCausalLMOutputWithPast(ModelOutput):
    """Output container returned by ``PhiForMultimodal.forward``.

    Carries the language-model outputs plus the per-sample segment predictions.
    Every field is optional and defaults to ``None``.
    """

    # Total training loss as assembled by PhiForMultimodal.forward.
    loss: Optional[torch.FloatTensor] = None
    # Output of the language-model head.
    logits: torch.FloatTensor = None
    # The past_key_values object that was passed into forward, handed back unchanged.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # PhiForMultimodal.forward stores the transformer's final hidden states here
    # (a single tensor there, despite the Tuple annotation).
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Not populated by PhiForMultimodal.forward in this notebook.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Per-sample segment-head latent predictions ([] for samples without segments),
    # or None when forward was called without segment_masks.
    predicted_segment_latents: Optional[torch.FloatTensor] = None
    # Per-sample segment-head bounding-box predictions, same layout as above.
    predicted_segment_bboxes: Optional[torch.FloatTensor] = None
    

class PhiForMultimodal(PhiForCausalLM):
    """Phi causal language model extended with image-segment ("<|seg|>") tokens.

    Each image segment is encoded by a sequence auto-encoder (``seqae``) into a
    latent, concatenated with its bounding box, projected to the LLM embedding
    width and added onto the embedding of a ``<|seg|>`` token. A ``SegmentHead``
    on the final hidden states regresses the latent and the bounding box at
    those same positions.

    Setup required before calling ``forward``:
      * ``load_seqae(...)`` must have been called.
      * When ``segment_masks`` is used, ``self.special_token_id_mappinmg`` (the
        attribute name is spelled this way by the code that assigns it) must be
        set externally to a dict containing the id of ``"<|seg|>"``.

    Args:
        config: Phi configuration.
        w_segment_loss: Weight of the segment-latent regression loss.
        w_bbox_loss: Weight of the bounding-box regression loss.
    """

    def __init__(
            self,
            config: PhiConfig,
            w_segment_loss: float = 1.0,
            w_bbox_loss: float = 1.0,
            ) -> None:
        super().__init__(config)

        self.transformer = PhiModel(config)
        self.lm_head = CausalLMHead(config)
        self.lm_loss = CausalLMLoss()
        self.segment_loss = RegreessionLoss(shift_labels=False)
        self.seqae_loaded = False

        self.w_segment_loss = w_segment_loss
        self.w_bbox_loss = w_bbox_loss

        self.post_init()

    def load_seqae(self, seqae_path: str, freeze_seqae_encoder: bool = False, freeze_seqae_decoder: bool = False):
        """Load the segment auto-encoder and create the layers that depend on it.

        Creates ``self.seqae``, ``self.visual_token_embedding`` (latent + 4 bbox
        values -> ``n_embd``) and ``self.segment_head`` on the model's device.

        Args:
            seqae_path: Checkpoint passed to ``Seq2SeqAutoEncoderModel.from_pretrained``.
            freeze_seqae_encoder: Disable gradients for the auto-encoder's encoder.
            freeze_seqae_decoder: Disable gradients for the auto-encoder's decoder.
        """
        self.seqae = Seq2SeqAutoEncoderModel.from_pretrained(seqae_path).to(self.device)
        if freeze_seqae_encoder:
            for param in self.seqae.encoder.parameters():
                param.requires_grad = False
        if freeze_seqae_decoder:
            for param in self.seqae.decoder.parameters():
                param.requires_grad = False

        # +4 accounts for the (x, y, w, h) bounding box appended to each latent.
        self.visual_token_embedding = torch.nn.Linear(self.seqae.config.d_latent+4, self.config.n_embd, bias=False).to(self.device)
        self.segment_head = SegmentHead(self.config.n_embd, self.seqae.config.d_latent).to(self.device)
        self.seqae_loaded = True

    def preprocess_segments(self, segment_masks, images):
        """Turn the segments of one sample into auto-encoder input sequences.

        Args:
            segment_masks: List of dicts with keys ``"image_index"``, ``"mask"``
                (2-D tensor over the full image) and ``"bbox"`` (``x, y, w, h``).
                The dicts are not modified.
            images: List of PIL images indexed by ``"image_index"``.

        Returns:
            ``(segment_sequences, bboxes)``: the stacked per-segment sequences and
            the matching bounding boxes, both on ``self.device`` in ``self.dtype``.
        """
        max_seq_length = self.seqae.config.data_seq_length

        def resize(segment):
            # segment['patch'] is PIL Image object
            w, h = segment['patch'].size
            if h * w > max_seq_length:
                ratio_to_maxlength = np.sqrt(max_seq_length / (h * w))
                # A very thin segment can truncate to 0 here, which Resize rejects,
                # so keep every side at least 1 pixel.
                h = max(1, int(h * ratio_to_maxlength))
                w = max(1, int(w * ratio_to_maxlength))
                # Clamping a side up to 1 may push the area over the budget again;
                # trim the longer side so the sequence still fits.
                if h * w > max_seq_length:
                    if h >= w:
                        h = max_seq_length // w
                    else:
                        w = max_seq_length // h
                segment['patch'] = torchvision.transforms.Resize([h, w])(segment['patch'])
                segment['mask'] = torchvision.transforms.Resize([h, w])(segment['mask'][None, :, :])[0]

            return segment

        def encode_to_sequence(segment):
            # segment['patch'] is torch tensor with shape (C, H, W)
            h, w = segment['patch'].shape[1:]
            sequence = []
            for i in range(h):
                for j in range(w):
                    # NOTE(review): the patch has already gone through ToTensor() when
                    # this runs, so dividing by 255 again may be unintended - confirm
                    # against the preprocessing the auto-encoder was trained with.
                    pixel_data = segment['patch'][:, i, j] / 255.0
                    is_rightmost = 1 if j == w - 1 else 0
                    is_non_masked = int(segment['mask'][i, j])
                    sequence.append(pixel_data.tolist() + [is_rightmost, is_non_masked])
            sequence = np.array(sequence)

            # pad the sequence to max_seq_length with zeros
            if len(sequence) < max_seq_length:
                sequence = np.concatenate((sequence, np.zeros((max_seq_length - len(sequence), self.seqae.config.input_channels))))

            # add the query place holder to the end of the sequence
            sequence = np.concatenate((sequence, np.zeros((self.seqae.config.num_queries, self.seqae.config.input_channels))))
            # add one all-zero row to the start
            sequence = np.concatenate((np.zeros((1, sequence.shape[1])), sequence), axis=0)

            return torch.from_numpy(sequence)

        segment_sequences = []
        bboxes = []
        for segment_mask in segment_masks:
            image = images[segment_mask["image_index"]]
            mask = segment_mask["mask"]

            # Work on a local copy so the caller's bbox is never modified; a
            # zero-sized side is widened to 1 pixel so the crop is not empty.
            x, y, w, h = segment_mask["bbox"]
            w = 1 if w == 0 else w
            h = 1 if h == 0 else h
            bbox = [x, y, w, h]

            segment = {
                "patch": image.crop((x, y, x + w, y + h)),
                "mask": mask[y:y+h, x:x+w],
            }

            segment = resize(segment)
            segment['patch'] = torchvision.transforms.ToTensor()(segment['patch'])
            # Zero out pixels that fall outside the segment mask.
            segment['patch'] = segment['patch'] * segment['mask'][None, :, :]
            segment_sequence = encode_to_sequence(segment)
            segment_sequence = segment_sequence.to(self.device)
            segment_sequences.append(segment_sequence)
            bboxes.append(bbox)

        segment_sequences = torch.stack(segment_sequences, dim=0).to(self.device, dtype=self.dtype)
        bboxes = torch.tensor(bboxes).to(self.device, dtype=self.dtype)
        return segment_sequences, bboxes

    def forward(
        self,
        input_ids: torch.LongTensor,
        segment_masks: Optional[list] = None,
        images: Optional[list] = None,
        seqae_batch_size: Optional[int] = 16,
        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> MultimodalCausalLMOutputWithPast:
        """Run the multimodal model and, when supervision is available, compute the loss.

        Args:
            input_ids: Token ids, one row per sample.
            segment_masks: One list per sample; each list holds dicts with keys
                ``"image_index"`` (int), ``"mask"`` (binary tensor of shape (H, W))
                and ``"bbox"`` (``x, y, w, h``). Use an empty list for a sample
                without segment tokens. The number of dicts must equal the number
                of ``<|seg|>`` tokens in that sample.
            images: List of PIL images referenced by ``"image_index"``.
            seqae_batch_size: Chunk size for auto-encoder encoding; ``-1`` encodes
                all segments of a sample at once.
            past_key_values: Forwarded to the transformer and returned unchanged.
            attention_mask: Forwarded to the transformer.
            labels: Targets for the language-model loss.

        Returns:
            ``MultimodalCausalLMOutputWithPast``. ``loss`` is the LM loss plus the
            weighted segment/bbox losses when ``labels`` is given; only the
            weighted segment/bbox losses when ``labels`` is omitted but segments
            were supplied; and ``None`` when there is nothing to supervise.

        Raises:
            ValueError: If ``load_seqae`` has not been called.
        """
        if not self.seqae_loaded:
            raise ValueError("Please load seqae model first.")

        if segment_masks is not None:
            batch_segment_tokens = []
            batch_segment_bboxes = []
            batch_segment_latents = []
            for sample_segment_masks in segment_masks:
                if len(sample_segment_masks)!=0:
                    segment_sequences, bboxes = self.preprocess_segments(sample_segment_masks, images)

                    if seqae_batch_size != -1:
                        # Encode in chunks to bound auto-encoder memory use.
                        segment_latents = []
                        for i in range(0, len(segment_sequences), seqae_batch_size):
                            segment_latents.append(self.seqae.encode(segment_sequences[i:i+seqae_batch_size]))
                        segment_latents = torch.cat(segment_latents, dim=0)
                    else:
                        segment_latents = self.seqae.encode(segment_sequences)

                    latents_and_bboxes = torch.cat((segment_latents, bboxes), dim=1)
                    segment_tokens = self.visual_token_embedding(latents_and_bboxes)
                    batch_segment_tokens.append(segment_tokens)
                    batch_segment_bboxes.append(bboxes)
                    batch_segment_latents.append(segment_latents)
                else:
                    batch_segment_tokens.append([])
                    batch_segment_bboxes.append([])
                    batch_segment_latents.append([])

            # find the position of <|seg|> in input_ids
            batch_segment_position_ids = []
            for i in range(len(input_ids)):
                seg_pos = torch.where(input_ids[i] == self.special_token_id_mappinmg["<|seg|>"])[0].tolist()
                assert len(seg_pos) == len(segment_masks[i]), f"number of <|seg|> in input_ids ({len(seg_pos)}) does not match number of segments ({len(segment_masks[i])})"
                batch_segment_position_ids.append(seg_pos)
        else:
            batch_segment_tokens = None
            batch_segment_bboxes = None
            batch_segment_latents = None
            batch_segment_position_ids = None

        hidden_states = self.transformer(
            input_ids,
            segment_tokens=batch_segment_tokens,
            segment_position_ids=batch_segment_position_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask
            )
        lm_logits = self.lm_head(hidden_states)

        # LLM Transformer last hidden state predicts next text token
        loss = None
        lm_loss = None
        if labels is not None:
            lm_loss = self.lm_loss(lm_logits, labels)

        # LLM Transformer last hidden state predicts segment latent and bbox
        segment_loss = 0
        bbox_loss = 0
        if segment_masks is not None:
            predicted_segment_latents = []
            predicted_segment_bboxes = []
            for i in range(len(batch_segment_tokens)):
                if len(batch_segment_tokens[i]) != 0:
                    segment_latents, segment_bboxes = self.segment_head(hidden_states[i, batch_segment_position_ids[i]])
                    segment_loss += self.segment_loss(segment_latents, batch_segment_latents[i])
                    bbox_loss += self.segment_loss(segment_bboxes, batch_segment_bboxes[i])

                    predicted_segment_latents.append(segment_latents)
                    predicted_segment_bboxes.append(segment_bboxes)
                else:
                    predicted_segment_latents.append([])
                    predicted_segment_bboxes.append([])
        else:
            predicted_segment_latents = None
            predicted_segment_bboxes = None

        if lm_loss is not None:
            loss = lm_loss + self.w_segment_loss * segment_loss + self.w_bbox_loss * bbox_loss
        elif torch.is_tensor(segment_loss):
            # No labels (e.g. inference) but at least one segment was regressed:
            # report the segment losses instead of failing on `None + ...`.
            loss = self.w_segment_loss * segment_loss + self.w_bbox_loss * bbox_loss
        # Otherwise there is nothing to supervise and loss stays None.

        return MultimodalCausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=past_key_values,
            hidden_states=hidden_states,
            predicted_segment_latents=predicted_segment_latents,
            predicted_segment_bboxes=predicted_segment_bboxes,
            )


# Build the multimodal model from the public phi-2 weights and move it to DEVICE.
# Both auxiliary loss weights are passed explicitly (they equal the constructor defaults).
model = PhiForMultimodal.from_pretrained(
    "microsoft/phi-2",
    w_segment_loss=1.0,
    w_bbox_loss=1.0,
    ).to(DEVICE)

# Attach the segment auto-encoder; model.forward raises ValueError until this has run.
# NOTE(review): absolute, machine-specific checkpoint path - consider moving it to a
# configurable constant so the notebook can run elsewhere.
model.load_seqae('/home/dchenbs/workspace/Seq2Seq-AutoEncoder/runs/Nov28_20-50-04_host19-SA1B-[327MB-16queries-1024]-[lr1e-05-bs16x1step-8gpu]/checkpoints/checkpoint_ep2_step3200k')

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


In [3]:
# Load the phi-2 tokenizer and register the three image-related marker tokens
# plus a dedicated padding token.
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
tokenizer.add_tokens(["<|startofimage|>", "<|endofimage|>", "<|seg|>"])
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Resize the embedding table to the enlarged vocabulary.
model.resize_token_embeddings(len(tokenizer))
# Token-string -> id lookup; the model's forward reads the "<|seg|>" entry from this
# attribute. The attribute name (including its spelling) must match what the model
# class reads, so do not rename it here alone.
model.special_token_id_mappinmg = {
    "<|startofimage|>": tokenizer.convert_tokens_to_ids("<|startofimage|>"),
    "<|endofimage|>": tokenizer.convert_tokens_to_ids("<|endofimage|>"),
    "<|seg|>": tokenizer.convert_tokens_to_ids("<|seg|>"),
    "<|endoftext|>": tokenizer.convert_tokens_to_ids("<|endoftext|>"),
    "[PAD]": tokenizer.convert_tokens_to_ids("[PAD]"),
}


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:

from segmenter import Segmenter
# Path to the segmenter checkpoint (the load log reports a SAM vit_b model).
# NOTE(review): absolute, machine-specific path - consider making it configurable.
checkpoint = "/home/dchenbs/workspace/cache/sam_vit_b_01ec64.pth"
# Module-level segmenter instance used by segment_one_image below.
segmenter = Segmenter(checkpoint, device=DEVICE)

def segment_one_image(image_path, visualize=False):
    """Segment one image file and package the result for the multimodal model.

    Args:
        image_path: Path of the image to load (converted to RGB).
        visualize: When True, show the image next to a flat-colour rendering
            of its segments.

    Returns:
        ``(segment_masks, [image])``: the segmenter's mask records, each tagged
        with ``'image_index' = 0`` and with ``'mask'`` converted to a torch
        tensor, plus a one-element list holding the PIL image.
    """

    def visualized_masks(masks, image):
        # Paint every segment with its mean colour on a white canvas, largest
        # segments first so that smaller ones are drawn on top of them.
        canvas = np.ones_like(image) * 255
        for record in sorted(masks, key=lambda m: m['area'], reverse=True):
            canvas[record['mask'] == 1] = np.mean(image[record['mask'] == 1], axis=0)

            # visualize segment boundary
            outlines, _ = cv2.findContours(
                record['mask'].astype(np.uint8),
                cv2.RETR_EXTERNAL,
                cv2.CHAIN_APPROX_SIMPLE,
            )
            cv2.drawContours(canvas, outlines, -1, (200, 200, 200), 1)
        return canvas

    image = Image.open(image_path).convert('RGB')
    image_np = np.array(image)
    segment_masks = segmenter(image_np)

    if visualize:
        plt.figure(figsize=(20, 10))
        panels = (image_np, visualized_masks(segment_masks, image_np))
        for slot, picture in enumerate(panels, start=1):
            plt.subplot(1, 2, slot)
            plt.imshow(picture)
            plt.axis('off')

        plt.tight_layout()
        plt.show()

    # Only one image is returned, so every segment refers to index 0.
    for record in segment_masks:
        record['image_index'] = 0
        record['mask'] = torch.from_numpy(record['mask'])

    return segment_masks, [image]

# Directory to sample a demo image from.
# NOTE(review): absolute, machine-specific path - consider making it configurable.
img_dir = '/home/dchenbs/workspace/datasets/coco2017/images/val2017'
# Pick one file at random; `random` was seeded in the first cell.
img_path = os.path.join(img_dir, random.choice(os.listdir(img_dir)))
# Segment the chosen image without plotting.
segment_masks, images = segment_one_image(img_path, visualize=False)

Loading SAM model vit_b from /home/dchenbs/workspace/cache/sam_vit_b_01ec64.pth


In [5]:
# Three demo prompts containing 0, 1 and 4 <|seg|> tokens respectively, padded to a
# common length. No attention mask is requested from the tokenizer.
inputs = tokenizer([
    'Hellow world!',
    'This is a segment token: <|startofimage|><|seg|><|endofimage|><|endoftext|>',
    'This is an image of a cat with 4 tokens: <|startofimage|><|seg|><|seg|><|seg|><|seg|><|endofimage|>!<|endoftext|>',
    ], return_tensors="pt", return_attention_mask=False, padding=True).to(DEVICE)

# Round-trip each row to confirm the added tokens and padding were tokenized as expected.
print(tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=False))
print(tokenizer.decode(inputs['input_ids'][1], skip_special_tokens=False))
print(tokenizer.decode(inputs['input_ids'][2], skip_special_tokens=False))

print(inputs, inputs['input_ids'].shape)

print(model.special_token_id_mappinmg)
print('\n\n\n')
# The number of mask records per sample must equal the number of <|seg|> tokens in
# the matching prompt (0, 1, 4). Every record refers to image index 0.
# NOTE(review): labels are the raw input_ids, so padded positions are included;
# whether the LM loss ignores them depends on CausalLMLoss, which is not visible here.
output = model(
    input_ids=inputs['input_ids'],
    segment_masks=[[], segment_masks[:1], segment_masks[:4]],
    images=images[:4],
    labels=inputs['input_ids'],
    )
# NOTE(review): printing the whole output dumps the full logits tensor into the notebook.
print(output)

# outputs = model.generate(**inputs, max_length=200)
# text = tokenizer.batch_decode(outputs)[0]
# print(text)



Hellow world![PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]
This is a segment token: <|startofimage|><|seg|><|endofimage|><|endoftext|>[PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]
This is an image of a cat with 4 tokens: <|startofimage|><|seg|><|seg|><|seg|><|seg|><|endofimage|>!<|endoftext|>
{'input_ids': tensor([[   39,  5037,   995,     0, 50298, 50298, 50298, 50298, 50298, 50298,
         50298, 50298, 50298, 50298, 50298, 50298, 50298, 50298, 50298, 50298],
        [ 1212,   318,   257, 10618, 11241,    25,   220, 50295, 50297, 50296,
         50256, 50298, 50298, 50298, 50298, 50298, 50298, 50298, 50298, 50298],
        [ 1212,   318,   281,  2939,   286,   257,  3797,   351,   604, 16326,
            25,   220, 50295, 50297, 50297, 50297, 50297, 50296,     0, 50256]],
       device='cuda:0')} torch.Size([3, 20])
{'<|startofimage|>': 50295, '<|endofimage|>': 50296, '<|seg|>': 50297, '<|endoftext|>': 50256, '[PAD]': 50298}








MultimodalCausalLMOutputWithPast(loss=tensor(20469.7910, device='cuda:0', grad_fn=<AddBackward0>), logits=tensor([[[ 5.2736,  7.9187,  4.2404,  ..., -2.9310, -2.9312, -2.9305],
         [10.1601, 10.7517,  3.4917,  ..., -2.3451, -2.3445, -2.3443],
         [17.2196, 14.6495,  8.3551,  ..., -0.6489, -0.6487, -0.6492],
         ...,
         [10.8129, 12.5818,  9.4468,  ..., -0.3477, -0.3487, -0.3476],
         [10.9866, 12.8867,  9.6272,  ..., -0.2689, -0.2699, -0.2688],
         [12.3153, 13.7971, 11.0963,  ..., -0.0814, -0.0826, -0.0816]],

        [[ 6.4830,  6.1644,  3.4055,  ..., -4.0035, -4.0029, -4.0029],
         [ 6.0642,  7.8241,  3.4634,  ..., -2.5955, -2.5946, -2.5961],
         [ 3.9298,  7.1021,  1.9519,  ..., -1.8224, -1.8218, -1.8224],
         ...,
         [ 8.6108, 12.4696,  6.3324,  ...,  0.8407,  0.8396,  0.8404],
         [ 8.0109, 12.4024,  6.1988,  ...,  0.7150,  0.7140,  0.7150],
         [ 9.1022, 13.4377,  7.7072,  ...,  0.5867,  0.5861,  0.5863]],

        [[