In [1]:
import pandas as pd
import re
import os
from torch.utils.data import Dataset
from google.colab import drive
# Mount Google Drive at /content/drive; the input CSV and the output directory
# used later in this notebook both live under /content/drive/MyDrive.
drive.mount('/content/drive')

In [None]:
def remove(text):
    """Strip the leading run of non-ASCII-letter characters from ``text``.

    Everything before the first character in ``[a-zA-Z]`` is discarded (for
    example digits, punctuation and spaces such as ``"1. "``); the remainder is
    returned unchanged. A string with no ASCII letters becomes ``""``.
    """
    leading_non_letters = re.compile(r'^[^a-zA-Z]+')
    return leading_non_letters.sub('', text)

def remove_around(text):
    """Strip non-ASCII-letter characters from both ends of ``text``.

    The leading run before the first ``[a-zA-Z]`` character is removed first,
    then the trailing run after the last one. Interior characters are kept.
    A string with no ASCII letters becomes ``""``.
    """
    without_prefix = re.sub(r'^[^a-zA-Z]+', '', text)
    return re.sub(r'[^a-zA-Z]+$', '', without_prefix)

class GenerationDatasetforPatent(Dataset):
    """Line-per-record text dataset for patent text generation.

    Each input string is one record whose fields are separated by tab
    characters. Tabs are turned into blank lines (``'\\n\\n'``) and the
    record is whitespace-stripped. Items are returned as ``{'x': text}``
    with ``<|endoftext|>`` markers added (see ``__getitem__``).

    Attributes:
        x: list of cleaned record strings.
        text_len: whitespace-token count of each raw record, parallel to ``x``.
        length: number of records, fixed at construction time.
    """

    def __init__(self, dl: list):
        """Build the dataset from ``dl``, a list of raw record strings."""
        self.x = []
        self.text_len = []
        self.init_data(dl)
        self.length = len(self.x)

    def init_data(self, dl):
        """Append each record in ``dl`` to ``self.x`` and its word count to ``self.text_len``."""
        for inst in dl:
            # Tab-separated fields become paragraph-separated blocks of text.
            self.x.append(inst.replace('\t', '\n\n').strip())
            # Counted on the raw record; str.split() ignores tabs/newlines anyway.
            self.text_len.append(len(inst.split()))

    def __getitem__(self, index: int) -> dict:
        """Return ``{'x': text}`` for record ``index``.

        ``<|endoftext|>`` is always prepended. It is appended only when the
        record ends with ``'.'``. Records without a final period get no
        closing marker (presumably because they are treated as incomplete;
        confirm against the training code).
        """
        txt = self.x[index]
        if not txt.endswith('.'):
            x = '<|endoftext|> ' + txt
        else:
            x = '<|endoftext|> ' + txt + ' <|endoftext|>'

        return {'x': str(x)}

    def __len__(self):
        return self.length

    @classmethod
    def from_file(cls, file_path: str, encoding=None):
        """Build a dataset from a text file with one record per line.

        Args:
            file_path: path of the file to read.
            encoding: text encoding passed to ``open``. ``None`` (default)
                keeps the platform default encoding, as before.

        Returns:
            An instance of ``cls``, so subclasses get their own type back.
        """
        with open(file_path, 'r', encoding=encoding) as f:
            dl = f.readlines()
        return cls(dl)

In [3]:
# Load the raw patent CSV, drop one row by its index label (equal to its row
# position, since read_csv assigns a default RangeIndex), re-number, drop rows
# with any missing value, then shuffle with a fixed seed so the split below is
# reproducible.
# NOTE(review): row 260199 is dropped without explanation -- presumably a known
# malformed record; document why.
data = (
    pd.read_csv('/content/drive/MyDrive/innovae-revision/data/google-patents/raw/ai-patents-date-cpc-text.csv')
    .drop([260199])
    .reset_index(drop=True)
    .dropna()
    .sample(frac=1, replace=False, random_state=1)
)

In [None]:
def _to_single_line(text: str) -> str:
    """Replace tab, carriage-return and newline characters with spaces.

    The output files store one record per line, with a tab as the field
    separator. Tabs and line breaks inside a field would therefore corrupt
    the format. A bare '\\r' must go too: reading the files back with
    ``open(..., 'r')`` uses universal newlines, which treats '\\r' as a line
    end and would split one record into two.
    """
    return text.replace('\t', ' ').replace('\r', ' ').replace('\n', ' ')


# str() guards against pandas having parsed an occasional value as a number.
claims = [remove(_to_single_line(str(c))) for c in data['primary_claim'].tolist()]
years = [_to_single_line(str(d))[:4] for d in data['priority_date'].tolist()]
titles = [remove_around(_to_single_line(str(t))) for t in data['title'].tolist()]

In [None]:
# Contiguous 90% / 5% / 5% split boundaries over the shuffled rows:
# train = [0, train_idx), valid = [train_idx, val_idx), test = [val_idx, test_idx).
test_idx = len(data)
train_idx, val_idx = (int(test_idx * frac) for frac in (0.9, 0.95))
print(train_idx, val_idx, test_idx)

In [None]:
OPTIMUS_DATA_DIR = '/content/drive/MyDrive/innovae-revision/innovae-adavae/adavae/data/optimus_dataset/patent_claim'


def _write_split(file_name, indices):
    """Write the records at ``indices`` to ``file_name``, one record per line.

    Each line is ``Year: ...<TAB>Title: ...<TAB>Claim: ...``; the tab is the
    field separator. Reads the notebook-level ``years``, ``titles`` and
    ``claims`` lists. The file is opened with the platform default encoding,
    as before.
    """
    with open(file_name, 'w') as f:
        for idx in indices:
            f.write(f'Year: {years[idx]}\tTitle: {titles[idx]}\tClaim: {claims[idx]}\n')


# Create the output directory if it is missing (os.chdir would otherwise raise).
# Then make it the working directory; the split files are written relative to it.
os.makedirs(OPTIMUS_DATA_DIR, exist_ok=True)
os.chdir(OPTIMUS_DATA_DIR)
for split_file, split_range in (
    ('train.txt', range(train_idx)),
    ('valid.txt', range(train_idx, val_idx)),
    ('test.txt', range(val_idx, test_idx)),
):
    _write_split(split_file, split_range)