In [2]:
import math
import os

import numpy as np
import torch
import torch.nn as nn

In [1]:
from pdf2image import convert_from_path
# Render only the first page -- the only one saved below -- instead of rasterizing the whole PDF.
pages = convert_from_path('recovery.pdf', first_page=1, last_page=1)
pages[0].save('image.jpg', 'JPEG')

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer, pipeline, BitsAndBytesConfig

import gc

# quantization_config = BitsAndBytesConfig(load_in_4bit=True)

# Local snapshot shared by the model and its processor, so both always come from the same files.
# (Alternative used before: /remote/gpu03/schiller/ExecLLM/models/huggingface/microsoft/Phi-3.5-MoE-instruct)
MODEL_DIR = "/remote/gpu03/schiller/skatr/models/huggingface/microsoft/Phi-3.5-vision-instruct"

# When this cell is re-run, drop the previous model first: gc.collect() clears reference cycles
# that may still hold the old weights and empty_cache() hands the cached GPU blocks back to the
# driver. NOTE: the weights are only freed if nothing else (e.g. `Out`) still references the model.
if 'model' in globals():
    del model
    gc.collect()
    torch.cuda.empty_cache()

model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    local_files_only=True,  # same as the processor: never reach out to the Hub
    # quantization_config=quantization_config,
    attn_implementation="flash_attention_2",
)

# The tokenizer is taken from `processor.tokenizer`, so no separate AutoTokenizer is loaded.
# NOTE(review): the Phi-3.5-vision model card suggests num_crops=16 for single images and 4 for
# multi-frame input -- confirm which fits this use.
processor = AutoProcessor.from_pretrained(
    MODEL_DIR,
    trust_remote_code=True,
    num_crops=4,
    local_files_only=True,
)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.52s/it]


In [None]:
# import math
# import os
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from einops import rearrange, repeat
# from einops.layers.torch import Rearrange
# from functools import partial
# from hydra.utils import instantiate
# from torch.utils.checkpoint import checkpoint

# from src.utils import masks
# from src.utils.config import get_prev_config    

class ViT(nn.Module):
    """
    A vision transformer network for inputs with three spatial axes.

    Inputs of shape (batch_size, channels, *axis_sizes) are cut into non-overlapping patches of
    `patch_shape`, linearly embedded to `hidden_dim`, summed with a Fourier position encoding of
    the 3D patch grid and processed by `depth` transformer blocks followed by a LayerNorm.

    NOTE(review): `Block` and `masks` are not imported in this notebook (their imports are
    commented out); they are needed for `depth > 0` and for masking with
    `use_mask_token=False`, respectively.
    """

    def __init__(
        self,
        patch_shape=[7, 7, 94],
        in_shape=[1, 140, 140, 2350],
        hidden_dim=96,
        learn_pos_encoding=True,
        num_heads=4,
        mlp_ratio=2.0,
        mlp_drop=0.,
        checkpoint_grads=False,
        attn_drop=0.,
        proj_drop=0.,
        depth=4,
        use_mask_token=True
    ):
        """
        :param patch_shape       : patch size along each of the three spatial axes
        :param in_shape          : (channels, *axis_sizes) of a single input sample
        :param hidden_dim        : embedding width; must be a multiple of 6 and at least 12
        :param learn_pos_encoding: learn the position-encoding frequencies (kept in log space)
        :param num_heads, mlp_ratio, mlp_drop, checkpoint_grads, attn_drop, proj_drop:
                                   forwarded to every transformer `Block`
        :param depth             : number of transformer blocks
        :param use_mask_token    : replace masked patches by a learned token instead of
                                   gathering tokens with `masks.gather_tokens`
        :raises ValueError       : if `hidden_dim` cannot be split into 6 Fourier groups
        """
        super().__init__()

        # sin/cos features for each of the 3 axes -> 6 groups of `hidden_dim // 6` frequencies.
        # Any other width cannot be added to the embeddings, and a single frequency per group
        # divides by zero below (NaN frequencies).
        if hidden_dim % 6 != 0 or hidden_dim < 12:
            raise ValueError(
                f"hidden_dim must be a multiple of 6 and at least 12, got {hidden_dim}"
            )

        # copy so the (mutable) default list is never shared between instances
        self.patch_shape = list(patch_shape)
        in_channels, *axis_sizes = in_shape
        dim = hidden_dim
        self.learn_pos_encoding = learn_pos_encoding

        # embedding layer: flattened patch voxels of all channels -> hidden_dim
        self.patch_dim = math.prod(self.patch_shape) * in_channels
        self.embedding = nn.Linear(self.patch_dim, dim)

        # position encoding: frequencies spaced log-uniformly in [1e-4, 1], one set per axis
        fourier_dim = dim // 6
        w = torch.arange(fourier_dim) / (fourier_dim - 1)
        w = (1. / (10_000 ** w)).repeat(3)
        # learnable frequencies are stored as logs so they stay positive after exp()
        self.pos_encoding_freqs = nn.Parameter(
            w.log() if learn_pos_encoding else w, requires_grad=learn_pos_encoding
        )
        self.init_pos_grid(axis_sizes)

        # transformer stack
        self.blocks = nn.ModuleList([
            Block(
                dim, num_heads, mlp_ratio=mlp_ratio, mlp_drop=mlp_drop,
                checkpoint_grads=checkpoint_grads, attn_drop=attn_drop,
                proj_drop=proj_drop
            ) for _ in range(depth)
        ])

        # norm layer
        self.out_norm = nn.LayerNorm(dim, eps=1e-6)

        # TODO: optional task head / input adaptor / input conv are not ported yet; `forward`
        # already uses them when `head`, `adaptor` or `input_conv` attributes exist.
        self.use_mask_token = use_mask_token
        if self.use_mask_token:
            self.mask_token = nn.Parameter(torch.randn(dim))

    def init_pos_grid(self, axis_sizes):
        """
        Register the patch-grid coordinates of each axis, scaled to [0, 2*pi), as `grid_{i}`.

        :param axis_sizes: spatial input size along each axis
        """
        self.num_patches = [s // p for s, p in zip(axis_sizes, self.patch_shape)]
        for i, n in enumerate(self.num_patches):
            self.register_buffer(f'grid_{i}', torch.arange(n) * (2 * math.pi / n))

    def pos_encoding(self):
        """
        Fourier position features of the 3D patch grid.

        :return: tensor of shape (prod(num_patches), hidden_dim) with sin/cos features per
                 axis, rows ordered like the patches produced by `to_patches`.
        """
        grids = [getattr(self, f'grid_{i}') for i in range(3)]
        coords = torch.meshgrid(*grids, indexing='ij')

        freqs = self.pos_encoding_freqs
        if self.learn_pos_encoding:
            freqs = freqs.exp()  # learnable frequencies are stored in log space
        freqs = freqs.chunk(3)

        features = [
            trig_fn(x.flatten()[:, None] * w[None, :])
            for (x, w) in zip(coords, freqs) for trig_fn in (torch.sin, torch.cos)
        ]
        return torch.cat(features, dim=1)

    def forward(self, x, mask=None):
        """
        Forward pass of ViT.
        :param x   : tensor of spatial inputs with shape (batch_size, channels, *axis_sizes)
        :param mask: a tensor of patch indices that should be masked out of `x`.
        :return: (batch_size, number_of_patches, hidden_dim) features, or the task-head output
                 on the patch-averaged features when a `head` is attached.
        """

        if hasattr(self, 'adaptor'):
            x = self.adaptor(x)

        if hasattr(self, 'input_conv'):
            x = self.input_conv(x)
        else:
            # patchify input
            # x -> (batch_size, number_of_patches, voxels_per_patch)
            x = self.to_patches(x)

            # embed
            # x -> (batch_size, number_of_patches, embedding_dim)
            if hasattr(self, 'extra_proj'):
                x = self.extra_proj(x)
            x = self.embedding(x)

        # apply mask and position encoding
        if self.use_mask_token:
            if mask is not None:
                x = self.apply_mask_tokens(x, mask)
            x = x + self.pos_encoding()
        else:
            # x -> (batch_size, number_of_masked_patches, embedding_dim)
            x = x + self.pos_encoding()
            if mask is not None:
                x = masks.gather_tokens(x, mask)

        # process patches with transformer blocks
        for block in self.blocks:
            x = block(x)
        x = self.out_norm(x)

        if hasattr(self, 'head'):
            # aggregate patch features and apply task head
            # x -> (batch_size, out_channels)
            x = x.mean(dim=1)
            x = self.head(x)

        return x

    def to_patches(self, x):
        """
        Split a volume into flattened, non-overlapping patches.

        Equivalent to einops
        ``rearrange(x, 'b c (x p1) (y p2) (z p3) -> b (x y z) (p1 p2 p3 c)')``.

        :param x: tensor of shape (B, C, X, Y, Z); X, Y, Z must be divisible by `patch_shape`
        :return: tensor of shape (B, (X/p1)*(Y/p2)*(Z/p3), p1*p2*p3*C)
        :raises ValueError: if a spatial size is not divisible by the patch size
        """
        p1, p2, p3 = self.patch_shape
        b, c, nx, ny, nz = x.shape
        if nx % p1 or ny % p2 or nz % p3:
            raise ValueError(
                f"input spatial shape {(nx, ny, nz)} is not divisible by "
                f"patch_shape {(p1, p2, p3)}"
            )
        gx, gy, gz = nx // p1, ny // p2, nz // p3
        x = x.reshape(b, c, gx, p1, gy, p2, gz, p3)
        # (b, c, gx, p1, gy, p2, gz, p3) -> (b, gx, gy, gz, p1, p2, p3, c)
        x = x.permute(0, 2, 4, 6, 3, 5, 7, 1)
        return x.reshape(b, gx * gy * gz, p1 * p2 * p3 * c)

    def apply_mask_tokens(self, x, mask_idcs):
        """
        Replace patch embeddings in `x` with the mask token at the indices given by `mask_idcs`.

        :param x        : input tensor with shape (B [batch size], T [number of patches], D [embed dim])
        :param mask_idcs: integer tensor with shape (B, M), M <= T, of patch indices in [0, T)
        :return: tensor with the same shape as `x`
        """
        B, T = x.shape[:2]
        # boolean (B, T) mask that is True at every masked patch position
        mask = torch.zeros((B, T), device=x.device).scatter_(-1, mask_idcs, 1).bool()
        # the (D,) mask token broadcasts against the (B, T, D) embeddings
        return torch.where(mask[..., None], self.mask_token, x)


class PretrainedViT(ViT):
    """
    A ViT whose weights are loaded from a pretraining run, optionally frozen.
    """

    def __init__(
        self,
        backbone_dir="runs/pretraining_micro/huge_775",
        drop_head=True,
        frozen=True,
        **vit_kwargs,
    ):
        """
        :param backbone_dir: run directory containing `model.pt`; its "model" entry is a state
                             dict whose network weights are prefixed with "net."
        :param drop_head   : delete the `head` submodule after loading, if there is one
        :param frozen      : disable gradients for all parameters and switch to eval mode
        :param vit_kwargs  : architecture arguments forwarded to `ViT.__init__`; they must match
                             the checkpoint, otherwise `load_state_dict` raises
        """
        # TODO: read the architecture from the run config (`get_prev_config(backbone_dir)`) and
        # support adding a new head, input adaptor, input conv and pos-encoding interpolation.

        # NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
        # map_location='cpu' lets GPU-saved checkpoints load on any machine; the tensors are
        # copied into this module's parameters by load_state_dict below anyway.
        model_state = torch.load(
            os.path.join(backbone_dir, 'model.pt'), map_location='cpu'
        )["model"]

        # Strip only the leading 'net.' prefix; str.replace would also mangle keys of nested
        # modules whose names end in 'net' (e.g. 'blocks.0.subnet.weight').
        prefix = 'net.'
        net_state = {
            k[len(prefix):]: v for k, v in model_state.items() if k.startswith(prefix)
        }

        # initialize network and load weights (strict: every key must match)
        super().__init__(**vit_kwargs)
        self.load_state_dict(net_state)

        # delete the head module used in pretraining
        # NOTE(review): ViT never creates `head`, so a checkpoint with 'head.*' keys fails the
        # strict load above -- confirm what the pretraining checkpoints contain.
        if drop_head and hasattr(self, 'head'):
            del self.head

        # freeze weights and set to eval mode
        if frozen:
            self.requires_grad_(False)
            self.eval()

In [2]:
from PIL import Image 

# Phi-3.5-vision refers to the i-th entry of `images` as <|image_i|> inside the prompt text.
with Image.open("image.jpg") as img:
    # convert() loads the pixels into a new image, so the file handle can be closed here
    images = [img.convert("RGB")]
placeholder = "<|image_1|>"

messages = [
    {"role": "user", "content": placeholder + "Summarize this image."},
]

prompt = processor.tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

# move the processed inputs to the device `model` was placed on by device_map
inputs = processor(prompt, images, return_tensors="pt").to(model.device)

# Greedy decoding. `temperature` is omitted: it has no effect without sampling and only makes
# transformers emit a warning.
generation_args = {
    "max_new_tokens": 1000,
    "do_sample": False,
}

generate_ids = model.generate(
    **inputs,
    eos_token_id=processor.tokenizer.eos_token_id,
    **generation_args,
)

# keep only the newly generated tokens (drop the echoed prompt)
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
response = processor.batch_decode(
    generate_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)[0]

print(response)

The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.


The image contains a graph with two y-axes and one x-axis. The left y-axis is labeled 'Network' and ranges from 0 to 10, while the right y-axis is labeled 'Net - True' and ranges from 10^-2 to 10^0. The x-axis is labeled 'Truth' and ranges from 0 to 10. There are two sets of data points: one set is plotted against the 'Network' axis and the other against the 'Net - True' axis. The data points against the 'Network' axis form a nearly straight line, indicating a linear relationship. The data points against the 'Net - True' axis form a curve that starts high on the left and decreases towards the right, with a notable dip in the middle. The graph is annotated with 'mWDM' at the top and 'MARE=8.3e-01' at the bottom, suggesting a specific metric or model used in the analysis.


In [None]:
data_dir = '/remote/gpu02/ore/data/x2'
index = 1000
data_file = os.path.join(data_dir, f"run{index}.npz")

# Copy the arrays out and close the archive; an NpzFile keeps its file handle open otherwise.
with np.load(data_file) as npz:
    data = {key: npz[key] for key in npz.files}
print(data['image'].shape)

In [30]:
import json

# Serialize the image array and its labels as nested JSON lists so they can be sent as text.
# For quick experiments a crop such as data['image'][:10, :10, 550:600] keeps the string small.
data_dict = {field: data[field].tolist() for field in ('image', 'label')}
data_str = json.dumps(data_dict)

{"image": [[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.

In [31]:
messages = [
    {"role": "user", "content": data_str},
]

# tokenize=True: `prompt` holds token ids, so len(prompt) is the prompt length in tokens
prompt = processor.tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
)
print(len(prompt))

# Flag prompts that cannot fit into the model's context window.
context_len = getattr(model.config, "max_position_embeddings", None)
if context_len is not None and len(prompt) > context_len:
    print(f"prompt has {len(prompt)} tokens, exceeding the context window of {context_len}")

72728172
