In [1]:
# Pretty print
from pprint import pprint
# Datasets load_dataset function
from datasets import load_dataset
# Transformers Autokenizer
from transformers import AutoTokenizer
# DistilBERT (uncased) tokenizer; not referenced by any of the visible cells.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="distilbert-base-uncased"
)
# Standard PyTorch DataLoader
from torch.utils.data import DataLoader
import pandas as pd
import time

from PIL import Image as PILImage

import vertexai
from vertexai.generative_models import Image as VertexImage, GenerativeModel
import urllib.request
import http.client
import typing

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def concat_images(images, verbose=False):
    """Stitch images side by side (left to right) into one RGB image.

    Images are pasted in the order given, top-aligned. Any area below a
    shorter image keeps ``PILImage.new``'s default fill colour.

    Args:
        images: Sequence of PIL images. ``None`` or an empty sequence is
            treated as "nothing to concatenate".
        verbose: If True, print each image's ``width height``. This was
            previously unconditional debug output. Defaults to False.

    Returns:
        ``None`` when there are no images. The single image itself, unchanged
        and not converted to RGB, when exactly one is given. Otherwise a new
        RGB image whose width is the sum of the input widths and whose height
        is the tallest input's height.
    """
    if not images:
        return None
    if len(images) == 1:
        return images[0]

    # Read every size once; both the canvas dimensions and the paste offsets
    # derive from them.
    sizes = [image.size for image in images]
    if verbose:
        for width, height in sizes:
            print(width, height)

    total_width = sum(width for width, _ in sizes)
    max_height = max(height for _, height in sizes)
    canvas = PILImage.new("RGB", (total_width, max_height))

    x_offset = 0
    for image, (width, _) in zip(images, sizes):
        # paste() converts the source to the canvas mode when they differ.
        canvas.paste(image, (x_offset, 0))
        x_offset += width
    return canvas

In [5]:
# Full HUPD metadata table; joined onto df_train in the merge cell below.
df_full = pd.read_feather(path="data/hupd_metadata_2022-02-22.feather")

In [11]:
import os

# Same metadata file the df_full cell reads, resolved against the notebook's
# working directory instead of a hard-coded home directory. load_dataset is
# given an absolute path so it is not resolved relative to the hub repo.
HUPD_METADATA_PATH = os.path.abspath("data/hupd_metadata_2022-02-22.feather")

# NOTE: the 'all' config downloads the full HUPD archive (61.8G in this cell's
# recorded progress bar) before filtering by the filing-date windows below.
dataset_dict = load_dataset(
    "HUPD/hupd",
    name="all",
    data_files=HUPD_METADATA_PATH,
    icpr_label=None,
    force_extract=True,
    train_filing_start_date="2018-01-01",
    train_filing_end_date="2020-01-20",
    val_filing_start_date="2020-01-21",
    val_filing_end_date="2020-01-30",
    trust_remote_code=True,
)
df_train = dataset_dict["train"].to_pandas()
# Use a real datetime dtype so filing_date can serve as a merge key.
df_train["filing_date"] = df_train["filing_date"].astype("datetime64[ns]")
df_train.to_feather("hupd_all_train.feather")

Loading dataset with config: PatentsConfig(name='all', version=0.0.0, data_dir='data', data_files={'train': ['/home/ubuntu/cs477-final-project/data/hupd_metadata_2022-02-22.feather']}, description='Patent data from all years (2004-2018)')
Using metadata file: /home/ubuntu/.cache/huggingface/datasets/downloads/2506d351067aa50a4c90854952cb472d0d806f5f3ca541673fde056d17aec53f


Downloading data:   6%|▌         | 3.42G/61.8G [00:42<12:00, 81.0MB/s]


KeyboardInterrupt: 

In [85]:
# Left-join the training rows to the full metadata table on (title, filing
# date); the metadata calls its title column `invention_title`.
# NOTE(review): df_full is cast to object dtype first, presumably to avoid a
# dtype clash on the join keys or when writing feather later; confirm.
df_join = df_train.merge(
    df_full.astype("object"),
    how="left",
    left_on=["title", "filing_date"],
    right_on=["invention_title", "filing_date"],
)

In [86]:
# Columns of the merged frame; `_x`/`_y` suffixes mark fields present in both inputs.
df_join.keys()

Index(['patent_number_x', 'decision_x', 'title', 'abstract', 'claims',
       'background', 'summary', 'description', 'cpc_label', 'ipc_label',
       'filing_date', 'patent_issue_date_x', 'date_published', 'examiner_id',
       'application_number', 'application_invention_type',
       'examiner_full_name', 'examiner_art_unit', 'uspc_class',
       'uspc_subclass', 'confirm_number', 'atty_docket_number',
       'appl_status_desc', 'appl_status_date', 'file_location',
       'file_location_date', 'earliest_pgpub_number', 'earliest_pgpub_date',
       'wipo_pub_number', 'wipo_pub_date', 'patent_number_y',
       'patent_issue_date_y', 'invention_title', 'small_entity_indicator',
       'aia_first_to_file', 'publication_number', 'date_application_produced',
       'date_application_published', 'main_cpc_label', 'cpc_labels',
       'main_ipcr_label', 'ipcr_labels', 'foreign', 'continuation',
       'decision_y', 'decision_as_of_2020'],
      dtype='object')

In [87]:
# A left join adds columns but must keep exactly one row per training row;
# extra rows would mean (title, filing_date) matched several metadata records.
print(df_train.shape)
print(df_join.shape)
assert len(df_join) == len(df_train), (
    "left join duplicated rows: (title, filing_date) matches more than one "
    "metadata record for some patents"
)


(14802, 14)
(14802, 46)


In [88]:
# Persist the merged training frame to feather.
df_join.to_feather(path="hupd_sample_train_merged.feather")

In [None]:
# Configure the Vertex AI SDK (GCP project and region) used by GenerativeModel below.
vertexai.init(location="us-central1", project="cs477-final-project")

In [71]:
# Candidate Gemini instructions; PROMPT_INDEX selects the one sent with every
# request. Each asks for a 4-character IPC class (e.g. "H04J").
PROMPT_INDEX = 0
PROMPTS = [
    # 0: image context first, then a direct question
    ("This image is taken from a patent whose abstract is delimited by triple "
     "quotes. Which class in the International Patent Classification system "
     "does the following patent belong to? Answer with a 4-character IPC class."),
    # 1: imperative "classify", with the image offered as an aid
    ("Classify the following patent, whose abstract is delimited by triple "
     "quotes, into a class in the International Patent Classification system. "
     "Use the given image from the patent to aid in your answer. Answer with a "
     "4-character IPC class."),
    # 2: image context first, imperative "determine"
    ("This image is taken from a patent whose abstract is delimited by triple "
     "quotes. Determine which class in the International Patent Classification "
     "system the following patent belongs to. Answer with a 4-character IPC "
     "class."),
]

# based on input format of https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts

def load_image_from_url(image_url, timeout=30.0):
    """Download an image and wrap it as a Vertex AI image prompt part.

    Args:
        image_url: URL of the image to fetch.
        timeout: Seconds to wait on the connection before giving up. Without
            one, ``urlopen`` uses the global socket default (no timeout unless
            changed), so a stalled server could hang the evaluation loop.
            Pass ``None`` to block indefinitely.

    Returns:
        A ``VertexImage`` built from the downloaded bytes.

    Raises:
        Whatever ``urllib.request.urlopen`` raises, e.g. ``ValueError`` for a
        malformed URL or ``urllib.error.URLError`` on network/HTTP failure.
    """
    with urllib.request.urlopen(image_url, timeout=timeout) as response:
        # cast() is for type checkers only; it does nothing at runtime.
        http_response = typing.cast(http.client.HTTPResponse, response)
        image_bytes = http_response.read()
    return VertexImage.from_bytes(image_bytes)

def generate_prompt_gemini(patent_dict, image_url=None):
    """Build the Gemini prompt parts and the ground-truth label for one patent.

    Args:
        patent_dict: Row-like mapping with ``"abstract"`` and ``"ipc_label"``
            entries (e.g. a DataFrame row).
        image_url: Optional URL of a figure from the patent. When given, the
            downloaded image is sent before the text instruction.

    Returns:
        ``(messages, ground_truth)``. ``messages`` is the list of prompt parts
        passed to ``generate_content``. ``ground_truth`` is the first four
        characters of the patent's IPC label.
    """
    abstract = patent_dict["abstract"]
    messages = []
    # The image used to be loaded from the placeholder URL "TODO", which
    # urlopen rejects, so every call raised. Until a real image source is
    # wired in, the image part is optional.
    # NOTE(review): every entry in PROMPTS mentions an image; choose a
    # text-only prompt when image_url is None.
    if image_url is not None:
        messages.append(load_image_from_url(image_url))
    # The abstract is wrapped in triple quotes, as the prompts describe.
    messages.append(f'{PROMPTS[PROMPT_INDEX]}\n"""{abstract}"""')
    return messages, patent_dict["ipc_label"][:4]

In [74]:
# Gemini prompt engineering: score 4-character IPC predictions on the first
# NUM_EXAMPLES patents.
NUM_EXAMPLES = 1000
REQUEST_DELAY_S = 0.5  # pause between requests to stay under the API quota
model = GenerativeModel("gemini-1.0-pro-vision")

# Evaluate on df_join from the merge cell; it has the `abstract` and
# `ipc_label` columns generate_prompt_gemini reads.
# NOTE(review): this loop previously read `df`, which no visible cell
# defines (NameError on a fresh kernel, hidden by the except below);
# confirm df_join is the intended evaluation frame.
eval_df = df_join

num_correct = 0
num_total = 0
num_not_generated = 0
num_errors = 0

for i in range(min(NUM_EXAMPLES, len(eval_df))):
    # wait to avoid going over quotas
    time.sleep(REQUEST_DELAY_S)
    try:
        messages, ground_truth = generate_prompt_gemini(eval_df.iloc[i])
        response = model.generate_content(messages)

        # A blocked/empty response may have no candidates or no parts; count
        # it as "not generated" instead of letting an IndexError drop it.
        parts = response.candidates[0].content.parts if response.candidates else []
        if parts:
            # Normalise whitespace and case before comparing label prefixes.
            prediction = parts[0].text.strip()[:4].upper()
            if prediction == ground_truth.upper():
                num_correct += 1
        else:
            num_not_generated += 1
        num_total += 1
    except Exception as e:
        # Keep going on API/data errors, but count them so the summary shows
        # how many examples were actually scored.
        num_errors += 1
        print(i, e)

print(num_correct, num_total, num_not_generated, num_errors)
if num_total:
    print(f"accuracy: {num_correct / num_total:.3f}")



In [37]:
# Inspect the most recent response left over from the evaluation loop:
# how many candidates came back, and the first candidate's text.
candidates = response.candidates
print(len(candidates))
print(candidates[0].content.parts[0].text)

1
H04J
