In [3]:
! pip install git+https://github.com/openai/CLIP.git
! pip install pytorch_lightning

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-q7awzu6a
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-q7awzu6a
  Resolved https://github.com/openai/CLIP.git to commit a9b1bf5920416aaeaec965c25dd9e8f98c864f16
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting ftfy
  Using cached ftfy-6.1.1-py3-none-any.whl (53 kB)
Collecting torch
  Using cached torch-1.13.1-cp37-cp37m-manylinux1_x86_64.whl (887.5 MB)
Collecting torchvision
  Using cached torchvision-0.14.1-cp37-cp37m-manylinux1_x86_64.whl (24.2 MB)
Collecting wcwidth>=0.2.5
  Using cached wcwidth-0.2.6-py2.py3-none-any.whl (29 kB)
Collecting nvidia-cuda-runtime-cu11==11.7.99
  Using cached nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl (849 kB)
Collecting nvidia-cublas-cu11==11.10.3.66
  Using cached nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl (317.1 MB)
Coll

In [4]:
from pathlib import Path
from random import randint, choice

import PIL
import argparse
import clip
import torch
import json
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from pytorch_lightning import LightningDataModule
# Build the CLIP ViT-B/32 model together with its image preprocessing transform.
model, preprocess = clip.load("ViT-B/32")

# Replace the stock weights with the fine-tuned checkpoint. The state dict is
# deserialized onto CPU first; load_state_dict then copies the values into the
# model's existing parameters.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
model.load_state_dict(
    torch.load(
        './model_checkpoint/model_lr_1e-08_bs_128_20230303_005035_34',
        map_location='cpu',
    )
)



<All keys matched successfully>

In [7]:
class TextImageDataModule(LightningDataModule):
    """Lightning data module that batches the samples described by a JSON file.

    The samples are produced by ``TextImageDataset``, which is not defined in
    this cell and must already exist in the notebook namespace when ``setup``
    is called.
    """

    def __init__(self,
                 data: str,
                 batch_size: int,
                 num_workers=0,
                 shuffle=False,
                 custom_tokenizer=None,
                 eval=False
                 ):
        """Store the loader configuration; nothing is read until ``setup``.

        Args:
            data (str): Json file containing images and text pairs
            batch_size (int): The batch size of each dataloader.
            num_workers (int, optional): The number of workers in the DataLoader. Defaults to 0.
            shuffle (bool, optional): Whether or not to have shuffling behavior during sampling. Defaults to False.
            custom_tokenizer (transformers.AutoTokenizer, optional): The tokenizer to use on the text. Defaults to None.
            eval (bool, optional): Eval mode or not. In eval mode the last,
                possibly smaller, batch is kept; otherwise it is dropped.
                Defaults to False.
        """
        super().__init__()
        self.data = data
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.custom_tokenizer = custom_tokenizer
        # `eval` shadows the builtin, but the name is kept because callers pass
        # it by keyword.
        self.drop_last = not eval

    @staticmethod
    def _str2bool(value):
        """Convert a command-line token such as ``true``/``False``/``0`` to a bool.

        Raises:
            argparse.ArgumentTypeError: If the token is not a recognised boolean.
        """
        if isinstance(value, bool):
            return value
        token = str(value).strip().lower()
        if token in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        # The empty string is treated as False, matching what ``bool('')`` gave
        # with the previous ``type=bool`` declaration.
        if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    # Used later for scripting
    @staticmethod
    def add_argparse_args(parent_parser):
        """Return a new parser that extends ``parent_parser`` with the data options.

        Args:
            parent_parser (argparse.ArgumentParser): Parser whose arguments are inherited.

        Returns:
            argparse.ArgumentParser: Parser with ``--data``, ``--batch_size``,
            ``--num_workers`` and ``--shuffle`` added.
        """
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--data', type=str, required=True, help='json file of the text/vision pair')
        parser.add_argument('--batch_size', type=int, help='size of the batch')
        parser.add_argument('--num_workers', type=int, default=0, help='number of workers for the dataloaders')
        # ``type=bool`` would turn every non-empty string (including "False")
        # into True, so the value is parsed explicitly instead.
        parser.add_argument('--shuffle', type=TextImageDataModule._str2bool, default=False,
                            help='whether to use shuffling during sampling')
        return parser

    def setup(self, stage=None):
        """Create the underlying dataset.

        Args:
            stage: Accepted for compatibility with the Lightning hook; unused here.
        """
        # The dataset only receives a flag saying whether a custom tokenizer is
        # in use; the tokenizer itself is applied in ``dl_collate_fn``.
        self.dataset = TextImageDataset(
            self.data,
            shuffle=self.shuffle,
            custom_tokenizer=self.custom_tokenizer is not None,
        )

    def train_dataloader(self):
        """Return a DataLoader over ``self.dataset`` using the stored configuration."""
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            drop_last=self.drop_last,
            collate_fn=self.dl_collate_fn,
        )

    def dl_collate_fn(self, batch):
        """Collate dataset rows into a batch triple.

        Args:
            batch: Sequence of rows; each row is indexed as ``row[0]``,
                ``row[1]`` and ``row[2]``.

        Returns:
            tuple: ``(stacked row[0], texts, [row[2], ...])``. Without a custom
            tokenizer ``texts`` is ``torch.stack`` of every ``row[1]``; with
            one, it is the tokenizer's output for the list of ``row[1]``
            values (padded, truncated, returned as PyTorch tensors).
        """
        images = torch.stack([row[0] for row in batch])
        if self.custom_tokenizer is None:
            texts = torch.stack([row[1] for row in batch])
        else:
            texts = self.custom_tokenizer(
                [row[1] for row in batch], padding=True, truncation=True, return_tensors="pt"
            )
        return images, texts, [row[2] for row in batch]

In [25]:
# JSON file whose samples will be embedded.
# Alternative input: './data/multiclass_test.json'
INFER_JSON = './data/multiclass_train.json'
NUM_WORKERS = 2
BATCH_SIZE = 16

# eval=True keeps the final partial batch, and shuffle=False keeps the sample
# order deterministic.
DataModule = TextImageDataModule(
    INFER_JSON,
    BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
    eval=True,
)
DataModule.setup()
# train_dataloader() is the only loader the module exposes; it is used here for inference.
loader = DataModule.train_dataloader()

In [26]:
import os

SAVED_FOLDER = './feature_classification/'
# torch.save does not create missing directories.
os.makedirs(SAVED_FOLDER, exist_ok=True)

# Batches come out of the DataLoader on CPU, while the model may live on
# another device; move the inputs to wherever the model's parameters are.
device = next(model.parameters()).device

# Inference only: skip autograd bookkeeping so no graph is kept per batch.
with torch.no_grad():
    for batch in loader:
        # Each batch is (images, texts, identifiers) as produced by the collate function.
        images, texts, contracts = batch

        text_embed = model.encode_text(texts.to(device)).cpu()
        image_embed = model.encode_image(images.to(device)).cpu()

        # NOTE(review): assumes every `contract` is a string usable as a file name - confirm.
        for contract, t_e, i_e in zip(contracts, text_embed, image_embed):
            save_path = os.path.join(SAVED_FOLDER, contract)
            # Each row is a view into the batch tensor; torch.save serializes the
            # whole underlying storage of a view, so clone to write only this row.
            torch.save(t_e.clone(), save_path+'_text')
            torch.save(i_e.clone(), save_path+'_img')