In [1]:
import pandas as pd
from transformers import pipeline
from os.path import join
from collections import Counter
from sklearn.metrics import classification_report
from sklearn.base import BaseEstimator, ClassifierMixin
from pathlib import Path
import random
from collections import defaultdict
from itertools import chain, groupby
from typing import Any, List, Optional, Union
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB, BernoulliNB, ComplementNB
import joblib
import numpy as np
import torch
from sentence_transformers import InputExample, SentenceTransformer, losses
from sentence_transformers.losses import BatchHardTripletLossDistanceFunction as LossDistances
from sklearn.naive_bayes import MultinomialNB
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
import lightgbm as lgb
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool

from time import sleep

import sys
import json
import os
from glob import glob
from sklearn import preprocessing
import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm
# Type alias for arguments that accept either a pathlib.Path or a plain string path.
StrOrPath = Union[Path, str]

In [2]:
class Params():
    """Attribute bag holding pipeline configuration (model path, loss settings, classifier, ...).

    Attributes can be assigned after construction (``args.model_path = ...``) or
    passed as keyword arguments, e.g. ``Params(model='all-MiniLM-L6-v2', batch_size=16)``.
    ``Params()`` with no arguments yields an empty bag.
    """
    def __init__(self, **kwargs):
        # Every keyword argument becomes an attribute of the instance.
        self.__dict__.update(kwargs)

In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Show the installed PyTorch version for the record (saved output: '1.10.0+cu113').
torch.__version__

'1.10.0+cu113'

In [5]:
# Factuality class name -> integer code used as the classifier target.
FACTUALITY_INT = {'Uncommitted': 1, 'Fact': 2,
                  'Probable': 3, 'Possible': 4, 'Counterfact': 5,
                  'Doubtful': 6, 'Conditional': 7}
# Inverse mapping: integer code -> class name.
FACTUALITY_INT_REV = {code: name for name, code in FACTUALITY_INT.items()}

# Loss-name groups. SentTrans uses them to decide which extra arguments a loss constructor gets.
# NOTE(review): MARGIN_LOSSES holds only empty-string placeholders and is not referenced in the cells shown here.
MARGIN_LOSSES = ['', '']
# Losses constructed as loss(model, distance_metric).
DISTANCE_LOSSES = ['BatchHardSoftMarginTripletLoss']
# Losses constructed as loss(model, distance_metric, margin).
MARG_DIST_LOSSES = ['BatchAllTripletLoss', 'BatchHardTripletLoss',
                    'BatchSemiHardTripletLoss', 'ContrastiveLoss', 'TripletLoss']

# Seed shared by every RNG in the notebook.
random_state = 1234

# Root of the factuality data (models live in FACT_DIR/MODELS). Set the FACT_DIR
# environment variable to run on another machine; the default is the original location.
FACT_DIR = os.environ.get(
    'FACT_DIR',
    '/home/pc/Desktop/AdilStuff/Projects/SemRepMed/semmed_data/FactualityData',
)

In [6]:
# Seed every RNG the notebook uses (Python, NumPy, PyTorch) so runs are reproducible.
random.seed(random_state)
np.random.seed(random_state)
# The trailing ';' hides the Generator object that torch.manual_seed returns.
torch.manual_seed(random_state);

<torch._C.Generator at 0x7f5227e4d350>

## Utils

In [66]:
def write_json_lines(file_name,dict_data):
    """Append ``dict_data`` to ``file_name`` as a single JSON line (JSON Lines format).

    The file is opened in append mode, so earlier records are preserved.
    """
    encoded = json.dumps(dict_data)
    with open(file_name, 'a') as handle:
        handle.write(f"{encoded}\n")
        
def read_json_lines(file_name):
    """Read a JSON Lines file and return a list with one decoded value per line."""
    with open(file_name) as handle:
        return [json.loads(raw_line) for raw_line in handle]
def infer_fact(in_file_name, out_file_name, batch_size=1, total=375866742):
    """Label every record of a JSON Lines file with its predicted factuality class.

    Each input line must be a JSON object with at least 'PREDICATION_AUX_ID' and
    'FORMATED_SENTENCE'. For every record one line
    ``{"PREDICATION_AUX_ID": ..., "LABEL": <class name>}`` is appended to
    ``out_file_name``. Any existing content of that file is kept.

    Args:
        in_file_name: path of the input JSON Lines file.
        out_file_name: path of the output JSON Lines file (opened in append mode).
        batch_size: number of sentences passed to the model per ``predict`` call.
            The default of 1 keeps the one-sentence-per-call behaviour.
        total: expected number of input lines; used only to size the progress bar.

    Returns:
        The last record written, or None when the input file has no lines.

    Raises:
        ValueError: if batch_size is smaller than 1.
    """
    from itertools import islice

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    args = Params()
    args.model_path = join(FACT_DIR, 'MODELS')
    pip = FullPipe(args,
        x_train= None,
        y_train= None,
        x_test= None,
        y_test= None, mode = 'inference')

    last_record = None
    # The output is opened once rather than once per record. The context managers
    # close both files and the progress bar even if the loop is interrupted.
    with open(in_file_name) as file_in, open(out_file_name, 'a') as file_out:
        with tqdm(total=total) as pbar:
            records = (json.loads(line) for line in file_in)
            while True:
                batch = list(islice(records, batch_size))
                if not batch:
                    break
                aux_ids = [rec['PREDICATION_AUX_ID'] for rec in batch]
                # predict returns (codes, class names, y); keep the class names.
                labels = pip.predict([rec['FORMATED_SENTENCE'] for rec in batch])[1]
                for aux_id, label in zip(aux_ids, labels):
                    last_record = {'PREDICATION_AUX_ID': aux_id, 'LABEL': label}
                    file_out.write(json.dumps(last_record) + "\n")
                # Flush each batch so finished work is on disk if the kernel dies mid-run.
                file_out.flush()
                pbar.update(len(batch))
    return last_record

def infer_fact_csv(in_file_name, skip_existing=False):
    """Label every ``form_sent_*.csv`` shard in a directory with factuality classes.

    Each shard is read as a gzip-compressed CSV with 'PREDICATION_AUX_ID' and
    'FORMATED_SENTENCE' columns. Predictions are written to the same directory as
    ``labeled_sent_<num>.csv``, gzip-compressed, with columns PREDICATION_AUX_ID and
    label. ``<num>`` is the part after the last '_' in the input file's stem.

    Args:
        in_file_name: directory holding the ``form_sent_*.csv`` shards.
        skip_existing: when True, a shard whose labeled output already exists is
            not recomputed, so an interrupted run can be resumed.

    Returns:
        List of output paths, one per input shard, in processing order.
    """
    args = Params()
    args.model_path = join(FACT_DIR, 'MODELS')
    pip = FullPipe(args,
        x_train= None,
        y_train= None,
        x_test= None,
        y_test= None, mode = 'inference')
    # Sorted so the processing order (and the progress bar) is the same on every run.
    all_paths = sorted(glob(join(in_file_name, 'form_sent_*.csv')))
    print('file count: ', len(all_paths))
    out_paths = []
    for f_file in tqdm(all_paths):
        fil_num = Path(f_file).name.split('.')[0].split('_')[-1]
        out_path = join(in_file_name, 'labeled_sent_' + fil_num + '.csv')
        out_paths.append(out_path)
        if skip_existing and Path(out_path).exists():
            continue
        df = pd.read_csv(f_file, compression = 'gzip')
        sentences = df['FORMATED_SENTENCE'].tolist()
        aux_ids = df['PREDICATION_AUX_ID'].tolist()
        # FullPipe.predict reshapes a 1-D encoding into one row, so an empty shard
        # would presumably not yield an empty prediction. Skip the model for it.
        labels = pip.predict(sentences)[1] if sentences else []
        df_res = pd.DataFrame({'PREDICATION_AUX_ID': aux_ids, 'label': labels})
        df_res.to_csv(out_path, index = False, compression = 'gzip')
    return out_paths
        
def process_record(line):
    """Debug helper for the thread-pool experiment: decode one JSON line and print its id.

    Prints 'process' and then the record's PREDICATION_AUX_ID. Raises KeyError if the
    record lacks PREDICATION_AUX_ID, SENTENCE or FORMATED_SENTENCE. Returns None.
    """
    print('process')
    record = json.loads(line)
    aux_id = record['PREDICATION_AUX_ID']
    # Looked up only so a malformed record fails here; the values are not used.
    _sentence = record['SENTENCE']
    _formatted = record['FORMATED_SENTENCE']
    print(aux_id)

def infer_fact_parallel(in_file_name, num_threads=3, num_lines=100):
    """Run ``process_record`` over every line of a JSON Lines file using a thread pool.

    The file is read ``num_lines`` lines at a time and each chunk is handed to the
    pool, so only one chunk is held in memory. Calling ``Pool.map`` directly on the
    open file would first build a list of every line. The file and the pool are
    closed on exit, including on KeyboardInterrupt.

    Args:
        in_file_name: path of the JSON Lines file.
        num_threads: number of worker threads.
        num_lines: number of lines dispatched to the pool per ``map`` call.

    Returns:
        The number of lines processed.
    """
    from itertools import islice

    processed = 0
    with open(in_file_name) as file_in, ThreadPool(num_threads) as pool:
        print('map')
        while True:
            chunk = list(islice(file_in, num_lines))
            if not chunk:
                break
            pool.map(process_record, chunk)
            processed += len(chunk)
    return processed

## Model

In [67]:
class SentTrans():
    """SentenceTransformer encoder plus its contrastive fine-tuning set-up.

    Configuration is read from ``args`` (a Params-like object):
      * train_iter, warmup_steps, data_iter, n_neg, loss_margin are copied when present;
      * loss_name (optional) selects a loss through ``get_loss``, and loss_distance
        (optional) a distance function through ``get_distance``;
      * the encoder is loaded from ``args.model`` if present, else ``args.model_path``.
    When a loss is configured, a training DataLoader is generated from x_train/y_train.
    """

    # Optional hyper-parameters copied from args only when args defines them.
    _OPTIONAL_ARGS = ('train_iter', 'warmup_steps', 'data_iter', 'n_neg', 'loss_margin')

    def __init__(self, args, x_train, y_train, x_test, y_test):
        self.args = args
        self.x_train = x_train
        self.y_train = y_train
        self.x_test = x_test
        self.y_test = y_test
        for attr_name in self._OPTIONAL_ARGS:
            if hasattr(args, attr_name):
                setattr(self, attr_name, getattr(args, attr_name))
        if hasattr(args, 'loss_name'):
            self.loss_name = args.loss_name
            loss = get_loss(args.loss_name, args)
        else:
            self.loss_name = None
            loss = None
        if hasattr(args, 'loss_distance'):
            self.loss_distance = get_distance(args.loss_distance)
        if hasattr(args, 'model'):
            self.model = SentenceTransformer(args.model).to(device)
        elif hasattr(args, 'model_path'):
            self.model = SentenceTransformer(args.model_path).to(device)
        if loss is not None:
            # Loss constructors differ in which of (distance, margin) they accept.
            if self.loss_name in MARG_DIST_LOSSES:
                self.loss = loss(self.model, self.loss_distance, self.loss_margin)
            elif self.loss_name in DISTANCE_LOSSES:
                self.loss = loss(self.model, self.loss_distance)
            else:
                self.loss = loss(self.model)
        if self.loss_name is not None:
            # Triplet losses get examples from the *_triples generator; all others from *_pairs.
            if 'Triplet' in self.loss_name:
                train_examples = mult_neg_weighted_generate_multiple_sentence_triples(
                    self.x_train, self.y_train, self.n_neg, self.data_iter)
            else:
                train_examples = mult_neg_weighted_generate_multiple_sentence_pairs(
                    self.x_train, self.y_train, self.n_neg, self.data_iter)
            self.train_dataloader = DataLoader(
                    train_examples,
                    shuffle=True,
                    batch_size=args.batch_size,
                    generator=torch.Generator(device='cpu'),
            )

        # Embeddings before fine-tuning, used for the "No Fit" panels of plot_.
        if self.x_train is not None:
            self.X_train_noFT = self.model.encode(self.x_train)
        if self.x_test is not None:
            self.X_test_noFT = self.model.encode(self.x_test)

    def fit(self, show_progress_bar=True):
        """Fine-tune the encoder with the configured loss for ``train_iter`` epochs.

        Raises:
            AttributeError: if no training DataLoader was built (args had no loss_name).
        """
        if not hasattr(self, 'train_dataloader'):
            raise AttributeError(
                "SentTrans.fit needs a training DataLoader; set args.loss_name when constructing SentTrans")
        self.model.fit(
            train_objectives=[(self.train_dataloader, self.loss)],
            epochs=self.train_iter,
            warmup_steps=self.warmup_steps,
            show_progress_bar=show_progress_bar,
        )

    def get_train_test_features(self):
        """Encode x_train and x_test with the current (possibly fine-tuned) encoder."""
        return self.model.encode(self.x_train), self.model.encode(self.x_test)

    def get_embeddings(self, x):
        """Encode arbitrary sentences with the current encoder."""
        return self.model.encode(x)

    def _plot_tsne_panel(self, position, title, embeddings, labels):
        """Draw one subplot: a 2-D t-SNE projection of ``embeddings``, one scatter series per label."""
        projected = TSNE(init='pca', n_components=2).fit_transform(np.array(embeddings))
        plt.subplot(position)
        plt.title(title)
        label_array = np.array(labels)
        for label in set(label_array):
            mask = label_array == label
            plt.scatter(projected[mask, 0], projected[mask, 1], label=FACTUALITY_INT_REV[label])
        plt.legend(bbox_to_anchor=(1, 1))

    def plot_(self, save_path='embedding_distribution.pdf'):
        """Plot t-SNE views of train/test embeddings before (left) and after (right) fine-tuning.

        The 2x2 figure is saved to ``save_path`` and then shown.
        """
        plt.figure(figsize=(20, 10))
        self._plot_tsne_panel(221, 'X_train No Fit', self.X_train_noFT, self.y_train)
        self._plot_tsne_panel(223, 'X_test No Fit', self.X_test_noFT, self.y_test)
        X_train, X_test = self.get_train_test_features()
        self._plot_tsne_panel(222, 'X_train SetFit', X_train, self.y_train)
        self._plot_tsne_panel(224, 'X_test SetFit', X_test, self.y_test)
        plt.savefig(save_path, bbox_inches='tight')
        plt.show()

In [68]:
class ClassificationHead():
    """Classifier trained on sentence embeddings, i.e. the head on top of SentTrans.

    ``args.classifier`` names the estimator, which is resolved through
    ``get_classifier_head``. ``args.transformation`` optionally names a feature
    transform: 'normalize' applies row-wise L2 normalisation, anything else
    (or no value) means no transform.
    """
    def __init__(self, args, x_train, y_train, x_test, y_test):
        self.args = args
        self.x_train = x_train
        self.y_train = y_train
        self.x_test = x_test
        self.y_test = y_test
        # Keep the configured name; self.classifier becomes the estimator instance below.
        self.classifier_name = args.classifier
        self.transformation = getattr(args, 'transformation', None)
        self.classifier = get_classifier_head(ch_name = self.classifier_name)
        self.class_fitted = False

    def transform(self, x):
        """Apply the configured feature transform; return ``x`` unchanged when none applies."""
        if self.transformation == 'normalize':
            return preprocessing.normalize(x, norm='l2')
        return x

    def fit(self):
        """Fit the classifier on (x_train, y_train) as stored, without transforming them."""
        self.classifier.fit(self.x_train, self.y_train)
        self.class_fitted = True

    def fit_transform(self):
        """Transform the stored train/test features in place, then fit the classifier."""
        self.x_train = self.transform(self.x_train)
        self.x_test = self.transform(self.x_test)
        self.fit()

    def predict(self, x):
        """Predict integer class codes for the feature matrix ``x``.

        NOTE: ``x`` is passed to the classifier as-is. After fit_transform, callers
        must pass features that are already transformed.
        """
        return self.classifier.predict(x)

    def report_test(self, output_dict = False):
        """Print and return a per-class classification report on the stored test split.

        Returns the report (a str, or a dict when output_dict=True), or None when
        the classifier has not been fitted yet.
        """
        if not self.class_fitted:
            return None
        test_preds = self.predict(self.x_test)
        report = classification_report(self.y_test, test_preds,
                                       labels = list(FACTUALITY_INT.values()),
                                       output_dict=output_dict,
                                       target_names=list(FACTUALITY_INT.keys()),
                                       # Metrics with a zero denominator are reported as 1.
                                       zero_division = True)
        print(report)
        return report

    def confusion_matrix(self, save_path='confusion_matrix.pdf'):
        """Plot the test-set confusion matrix normalised over true labels, save it, show it.

        Returns the ConfusionMatrixDisplay, or None when the classifier is not fitted.
        """
        if not self.class_fitted:
            return None
        predictions = self.predict(self.x_test)
        disp = ConfusionMatrixDisplay.from_predictions(
            self.y_test, predictions, labels = list(FACTUALITY_INT.values()),
            display_labels=list(FACTUALITY_INT.keys()), xticks_rotation = 45, normalize= 'true'
        )
        plt.savefig(save_path, bbox_inches = 'tight')
        plt.show()
        return disp
    

In [69]:
class FullPipe():
    """End-to-end factuality classifier: a SentTrans encoder followed by a classifier.

    mode='train': builds the encoder from ``args`` and the given splits; call ``fit``
    to fine-tune it and train a ClassificationHead on its embeddings.
    mode='inference': loads the encoder from ``args.model_path`` and the classifier
    from ``<model_path>/classifier.pkl``; both count as fitted.
    """
    def __init__(self, args, x_train, y_train, x_test, y_test, mode='train'):
        if mode not in ('train', 'inference'):
            raise ValueError(f"mode must be 'train' or 'inference', got {mode!r}")
        # Re-seed so each pipeline build is reproducible on its own.
        random.seed(random_state)
        np.random.seed(random_state)
        torch.manual_seed(random_state)
        # Kept in both modes because save() reads self.args.
        self.args = args
        if mode == 'train':
            self.strans = SentTrans(args, x_train, y_train, x_test, y_test)
            self.strans_fitted = False
            self.class_model_fitted = False
        else:
            self.strans = SentTrans(args, None, None, None, None)
            # SECURITY: joblib.load unpickles arbitrary objects; only load trusted model dirs.
            self.class_model = joblib.load(Path(args.model_path) / "classifier.pkl")
            self.strans_fitted = True
            self.class_model_fitted = True

    def fit(self):
        """Fine-tune the encoder, then train a ClassificationHead on the tuned embeddings."""
        self.strans.fit()
        self.strans_fitted = True
        x_train, x_test = self.strans.get_train_test_features()
        y_train, y_test = self.strans.y_train, self.strans.y_test
        self.class_model = ClassificationHead(self.args, x_train, y_train, x_test, y_test)
        self.class_model.fit()
        self.class_model_fitted = True

    def predict(self, x, y = None):
        """Encode sentences and classify them.

        Returns a tuple (integer codes, class names, y), where ``y`` is passed through
        untouched. Prints a message and returns None if either model is not fitted.
        """
        if not (self.strans_fitted and self.class_model_fitted):
            print('The models should be fitted')
            return None
        features = self.strans.model.encode(x)
        # The classifier needs a 2-D matrix; a 1-D encoding (presumably from a single
        # string input) is turned into one row.
        if len(features.shape) == 1:
            features = features.reshape(1, -1)
        preds = self.class_model.predict(features)
        return preds, [FACTUALITY_INT_REV[i] for i in list(preds)], y

    def evaluate(self):
        """Plot t-SNE embedding panels, print the test report, and show the confusion matrix.

        Only meaningful in train mode after ``fit``, because it needs the stored splits
        and a ClassificationHead.
        """
        self.strans.plot_()
        self.class_model.report_test()
        self.class_model.confusion_matrix()

    def save(
        self,
        path: "StrOrPath",
        model_name: Optional[str] = None,
        create_model_card: bool = False,
    ):
        """Save the encoder to ``path`` and the classifier to ``path/classifier.pkl``.

        Args:
            path: target directory.
            model_name: name passed to SentenceTransformer.save; defaults to
                ``args.model`` when that attribute exists.
            create_model_card: forwarded to SentenceTransformer.save.

        Raises:
            NotFittedError: if the encoder or the classifier is not fitted.
        """
        # Both parts must be fitted; a half-fitted pipeline cannot be saved.
        if not (self.class_model_fitted and self.strans_fitted):
            raise NotFittedError(
                "This SetFitClassifier instance is not fitted yet."
                " Call 'fit' with appropriate arguments before saving this estimator."
            )
        if model_name is None:
            model_name = getattr(self.args, 'model', None)
        self.strans.model.save(str(path), model_name, create_model_card)
        # Train mode wraps the estimator in a ClassificationHead (.classifier);
        # inference mode holds the estimator loaded from classifier.pkl directly.
        classifier = getattr(self.class_model, 'classifier', self.class_model)
        joblib.dump(classifier, Path(path) / "classifier.pkl")

    def load(self, cls, path: "StrOrPath"):
        """Load an encoder and classifier written by ``save`` into this pipeline.

        Args:
            cls: unused; kept so existing ``load(cls, path)`` calls keep working.
            path: directory produced by ``save``.

        Returns:
            self, with both models marked as fitted.
        """
        args = Params()
        args.model = str(path)
        self.args = args
        self.strans = SentTrans(args, None, None, None, None)
        # SECURITY: joblib.load unpickles arbitrary objects; only load trusted model dirs.
        self.class_model = joblib.load(Path(path) / "classifier.pkl")
        self.strans_fitted = True
        self.class_model_fitted = True
        return self
    

## Inference

In [70]:
# infer_fact(in_file_name = 'all_sentences.jsonl', out_file_name = 'all_sentences_facts.jsonl')

In [72]:
# Label every Format_sents/form_sent_*.csv shard; results are written next to them as
# Format_sents/labeled_sent_*.csv.
xx = infer_fact_csv('Format_sents')

file count:  1126785


  0%|          | 0/1126785 [00:00<?, ?it/s]

In [73]:
# Sanity check: every form_sent_*.csv input shard should have a labeled_sent_*.csv output.
n_inputs = len(glob(join('Format_sents', 'form_sent_*.csv')))
n_labeled = len(glob(join('Format_sents', 'labeled_sent_*.csv')))
assert n_labeled == n_inputs, f"{n_labeled} labeled shards for {n_inputs} input shards"
n_labeled

1126785

In [65]:
# Spot-check one labeled output shard.
# NOTE(review): the saved output shows an 'Unnamed: 0' column, which an index=False write
# would not produce, so this shard may predate the current infer_fact_csv. Consider .head().
pd.read_csv('Format_sents/labeled_sent_849503.csv',compression = 'gzip')

Unnamed: 0,PREDICATION_AUX_ID,label
0,167307222,Fact
1,167307223,Fact
2,167307224,Fact
3,167307225,Fact
4,167307226,Fact
...,...,...
96,167307319,Fact
97,167307320,Fact
98,167307321,Fact
99,167307322,Fact


In [37]:
# NOTE(review): this cell ran as In [37], before the cell that assigns `xx` (In [72]).
# The frame below comes from an earlier kernel state and may not reproduce on Restart & Run All.
xx[-1]

Unnamed: 0,PREDICATION_AUX_ID,label
0,167307222,Fact
1,167307223,Fact
2,167307224,Fact
3,167307225,Fact
4,167307226,Fact
...,...,...
96,167307319,Fact
97,167307320,Fact
98,167307321,Fact
99,167307322,Fact


In [38]:
# NOTE(review): this cell ran as In [38], before the cell that assigns `xx` (In [72]).
# The frame below comes from an earlier kernel state and may not reproduce on Restart & Run All.
xx[0]

Unnamed: 0,PREDICATION_AUX_ID,SENTENCE,FORMATED_SENTENCE,file_name
0,167307222,BACKGROUND: Two and a half years after commenc...,BACKGROUND: Two and a half years after commenc...,SENTENCE/split_990577.csv.gz
1,167307223,BACKGROUND: Two and a half years after commenc...,BACKGROUND: Two and a half years after commenc...,SENTENCE/split_990577.csv.gz
2,167307224,This subsequent outbreak provided the opportun...,This subsequent outbreak provided the opportun...,SENTENCE/split_990577.csv.gz
3,167307225,Children with rotavirus-confirmed gastroenteri...,@PREDICAT$ @OBJECT$ rotavirus-confirmed @SUBJE...,SENTENCE/split_990577.csv.gz
4,167307226,Nineteen (46%) of 41 case patients had receive...,Nineteen (46%) of 41 case patients had receive...,SENTENCE/split_990577.csv.gz
...,...,...,...,...
96,167307319,"However, the subtype PsA was more prevalent in...","However, the subtype PsA was more prevalent in...",SENTENCE/split_990577.csv.gz
97,167307320,CONCLUSION: In Sweden the prevalence of spondy...,CONCLUSION: In Sweden the prevalence of @SUBJE...,SENTENCE/split_990577.csv.gz
98,167307321,PsA was the most frequent subtype followed by ...,@SUBJECT$ was the most frequent subtype @OBJEC...,SENTENCE/split_990577.csv.gz
99,167307322,Magnetic resonance imaging of skeletal muscles...,@PREDICAT$ @OBJECT$ @SUBJECT$ in sporadic incl...,SENTENCE/split_990577.csv.gz


## Tests (scratch experiments with parallel inference)

In [27]:
# Scratch run of the thread-pool experiment over the full JSON Lines file;
# the saved output shows it was stopped with KeyboardInterrupt.
infer_fact_parallel(in_file_name = 'all_sentences.jsonl')

map


KeyboardInterrupt: 

In [5]:
from tqdm.contrib.concurrent import process_map  # or thread_map
import time

In [6]:
def _foo(my_number):
    """Toy worker for the process_map demo: square the input, pause one second, return the square."""
    squared = my_number * my_number
    # Artificial delay so the progress bar has something to show.
    time.sleep(1)
    return squared


In [7]:
# Demo of tqdm's process_map: apply _foo to 0..29 with 6 worker processes
# (each call sleeps for 1 s, so the progress bar advances visibly).
r = process_map(_foo, range(0, 30), max_workers=6)

  0%|          | 0/30 [00:00<?, ?it/s]