In [12]:
from transformers import BartTokenizer, BartForConditionalGeneration
import textwrap
import glob
import itertools
import json
import pandas as pd
import time
import numpy as np
import sys

In [2]:
# Tokenizer and model must come from the same checkpoint, so it is named once.
MODEL_CHECKPOINT = "facebook/bart-large-cnn"
tokenizer = BartTokenizer.from_pretrained(MODEL_CHECKPOINT)
condi_gen = BartForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)

In [93]:
class Summary:
    """Summarise one piece of text with the module-level BART ``tokenizer`` and ``condi_gen``.

    Typical call order: ``new_model()`` -> ``loc_full()`` -> ``cut()`` ->
    ``remove_newline()``, then read the result with ``get_output()``.
    """

    def __init__(self, text, l, u, beams, stop, skip):
        """
        text  - the string to summarise
        l     - lower bound of the summary length, as a fraction (below 1) of the
                number of input tokens
        u     - upper bound of the summary length, as a fraction (below 1) of the
                number of input tokens; should be greater than l
        beams - number of beams for beam search, as an int
        stop  - True or False for early stopping
        skip  - True or False for skipping special tokens when decoding
        """
        self.text = text
        self.output = ""
        self.l = l
        self.u = u
        self.beams = beams
        self.stop = stop
        self.skip = skip
        # Label for this hyper-parameter combination, e.g. "0.3_0.6_2_True_True".
        self.name = "_".join(str(param) for param in (l, u, beams, stop, skip))
        # (character, index) pairs of every full stop in self.output; set by loc_full().
        self.full_stops = []
        # True once self.output holds model-generated text (set by new_model()).
        self.summarised = False

    def new_model(self):
        """Summarise self.text and store the result in self.output.

        Texts of 200 characters or fewer need no summarisation: they are copied to
        self.output unchanged and also returned. Longer texts are tokenised
        (truncated to 1024 tokens) and summarised with beam search. The summary
        length is bounded by ``l`` and ``u`` times the number of input tokens.
        """
        if len(self.text) <= 200:
            # Set output too, so get_output() does not return "" for short texts.
            self.output = self.text
            self.summarised = False
            return self.text

        input_tokens = tokenizer.batch_encode_plus(
            [self.text], return_tensors="pt", max_length=1024, truncation=True
        )["input_ids"]
        num_token = input_tokens.shape[1]
        min_ = int(self.l * num_token)
        max_ = int(self.u * num_token)

        encoded_ids = condi_gen.generate(
            input_tokens,
            max_length=max_,
            min_length=min_,
            num_beams=self.beams,
            early_stopping=self.stop,
        )
        summary = tokenizer.decode(encoded_ids.squeeze(), skip_special_tokens=self.skip)
        # textwrap.fill (width max_) inserts line breaks; remove_newline() flattens them.
        self.output = textwrap.fill(summary, max_)
        self.summarised = True

    def loc_full(self):
        """Record the location of every full stop in self.output.

        Rebuilds self.full_stops from scratch as (character, index) pairs, so a
        repeated call after self.output changes leaves no stale positions.
        NOTE(review): decimal points (e.g. "2.5") also count as full stops.
        """
        symbols = {"."}
        self.full_stops = [
            (char, index) for index, char in enumerate(self.output) if char in symbols
        ]

    def cut(self):
        """Trim a generated summary so it ends at (and includes) its last full stop.

        Falls back to the original text when nothing was generated (short texts are
        never summarised, so there is no dangling sentence to remove) or when the
        summary contains no full stop at all.
        """
        if not self.summarised or not self.full_stops:
            self.output = self.text
            return
        last_stop = self.full_stops[-1][1]
        # +1 keeps the terminating full stop itself.
        self.output = self.output[: last_stop + 1]

    def remove_newline(self):
        """Replace every newline in self.output with a space."""
        self.output = self.output.replace('\n', ' ')

    def get_output(self):
        """Return the current summary."""
        return self.output

    def get_text(self):
        """Return the text being summarised."""
        return self.text

    def set_text(self, new_text):
        """Replace the text to summarise (e.g. to summarise a summary again)."""
        self.text = new_text

def get_filenames(files):
    """Expand each glob pattern in ``files`` and return every match in one flat list.

    Matches are listed pattern by pattern, each in the order glob returns them.
    """
    return [match for pattern in files for match in glob.glob(pattern)]

In [94]:
#only run this once
# Run this cell only once: re-running it throws away every row collected so far.
df = pd.DataFrame(
    columns=[
        "input",
        "output",
        "output_length",
        "input_length",
        "percentage_decrease",
        "model_name",
        "time",
    ]
)

In [95]:
# Keep re-running this cell (after changing the hyper-parameters below) to append
# one row per summarised section to `df`, so different settings can be compared.
path = ["Json files/*.json"]

# Hyper-parameters passed to Summary(text, l, u, beams, stop, skip).
L_FRAC, U_FRAC, BEAMS, EARLY_STOP, SKIP_SPECIAL = 0.3, 0.6, 2, True, True
# Summaries at least this many characters long are summarised a second time.
RESUMMARISE_AT = 1000
# Report sections that get summarised; every other key is ignored.
SECTIONS = {
    "PROBLEM DESCRIPTION",
    "TARGET CONDITION",
    "CURRENT CONDITION",
    "ROOT CAUSE ANALYSIS",
    "COUNTERMEASURES",
    "EFFECT CONFIRMATION",
    "FOLLOW UP ACTION",
}

# glob order is arbitrary; sorting makes the row order reproducible across runs.
json_files = sorted(get_filenames(path))

for json_file in json_files:
    # `with` guarantees each file handle is closed.
    with open(json_file) as f:
        dicts = json.load(f)
    # presumably each top-level value is a dict of section name -> text; verify against the files
    for vals in dicts.values():
        for section, text in vals.items():
            # The length check also keeps the percentage_decrease division below non-zero.
            if section not in SECTIONS or len(text) < 5:
                continue
            start = time.perf_counter()
            sum_ = Summary(text, L_FRAC, U_FRAC, BEAMS, EARLY_STOP, SKIP_SPECIAL)
            # Initial summary, trimmed to its last full stop, with newlines flattened.
            sum_.new_model()
            sum_.loc_full()
            sum_.cut()
            sum_.remove_newline()
            a = sum_.get_output()
            # If the summary is still long, summarise the summary once more.
            if len(a) >= RESUMMARISE_AT:
                sum_.set_text(a)
                sum_.new_model()
                # Drop first-pass positions so cut() only sees the new summary's full stops.
                sum_.full_stops = []
                sum_.loc_full()
                sum_.cut()
                sum_.remove_newline()
                a = sum_.get_output()
            end = time.perf_counter()
            # Row-wise .loc growth is O(n) per row, but negligible next to model inference.
            df.loc[len(df)] = [
                text,
                a,
                len(a),
                len(text),
                (len(text) - len(a)) / len(text),
                sum_.name,
                end - start,
            ]

In [96]:
# Stable sort: rows with equal input_length (e.g. the same section summarised with
# different hyper-parameters) keep the order in which they were appended.
df = df.sort_values(by=['input_length'], kind="mergesort")

In [97]:
# Persist the accumulated results (index included, as pandas does by default).
df.to_csv(path_or_buf='table2.csv')

In [92]:
# Preview only: the full table repeats every raw input text from the reports.
df.head(10)

Unnamed: 0,input,output,output_length,input_length,percentage_decrease,model_name,time
36,All parts produced within 70°+25° specification,,0,47,1.0,0.3_0.6_2_True_True,0.0
6,Reference 8D ES191216231144 for follow up actions,,0,49,1.0,0.3_0.6_2_True_True,0.0
22,- Product functionality remains post air-air t...,,0,66,1.0,0.3_0.6_2_True_True,0.0
34,"Spray angle ranging from 36,7° to 69,0°O PV pa...",,0,70,1.0,0.3_0.6_2_True_True,0.0
19,Yokoten : history deliver to further project v...,,0,80,1.0,0.3_0.6_2_True_True,0.0
45,Expand the visual defect lesson learn into PFM...,,0,100,1.0,0.3_0.6_2_True_True,0.0
18,The select lever position information via UDS ...,,0,104,1.0,0.3_0.6_2_True_True,0.0
47,Sporadic CAN frame dropouts existing in Pre-SW...,,0,107,1.0,0.3_0.6_2_True_True,0.0
2,Twenty four of twenty-four injectors successfu...,,0,109,1.0,0.3_0.6_2_True_True,0.0
41,Confirm the root cause in supplier side and im...,,0,129,1.0,0.3_0.6_2_True_True,0.0


In [65]:
# Scratch check: 1000 is the re-summarisation threshold and 0.6 the upper-bound
# fraction u used in the driver cell. NOTE(review): presumably estimates the input
# size at which a first summary could reach that threshold (u is a fraction of
# tokens, not characters) — confirm, or delete this cell.
1000/0.6

1666.6666666666667