In [11]:
import pandas as pd
import numpy as np
import openai
from sklearn.base import BaseEstimator, TransformerMixin
from sentence_transformers import SentenceTransformer
from gensim.utils import simple_preprocess
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

class Textcol2Mat(BaseEstimator, TransformerMixin):
    """Scikit-learn style transformer that embeds text columns as a numeric matrix.

    Every row of the input frame is flattened into a single string of the form
    ``"<col1>: <value1><colsep><col2>: <value2>..."`` and that string is embedded
    with one of three backends, selected by ``type``:

    * ``'openai'``  -- ``openai_client.embeddings.create`` (no fitting required)
    * ``'st'``      -- ``SentenceTransformer(...).encode`` (no fitting required)
    * ``'doc2vec'`` -- a gensim ``Doc2Vec`` model trained in ``fit``

    Parameters
    ----------
    type : {'openai', 'st', 'doc2vec'}
        Embedding backend.
    openai_client : object or None
        Client exposing ``embeddings.create(input=..., model=...)``. Only read
        when ``type == 'openai'``.
    embedding_model_openai : str
        Model name forwarded to ``embeddings.create``.
    embedding_model_st : str
        Model name forwarded to ``SentenceTransformer``.
    embedding_model_doc2vec : {'PV-DM', 'PV-DBOW'}
        Doc2Vec training algorithm.
    doc2vec_epochs : int
        Forwarded to ``Doc2Vec`` as ``epochs``.
    doc2vec_vector_size : int
        Forwarded to ``Doc2Vec`` as ``vector_size``.
    colsep : str
        Separator placed between the per-column ``"name: value"`` fragments.
    return_cols_prefix : str
        Prefix of the output column names (``<prefix>0``, ``<prefix>1``, ...).

    Raises
    ------
    ValueError
        If ``type`` or ``embedding_model_doc2vec`` is not one of the accepted values.
    """

    def __init__(
        self
        , type = 'openai'
        , openai_client = None
        , embedding_model_openai = 'text-embedding-3-large'
        , embedding_model_st = 'neuml/pubmedbert-base-embeddings'
        , embedding_model_doc2vec = 'PV-DM'
        , doc2vec_epochs = 40
        , doc2vec_vector_size = 10
        , colsep = ' || '
        , return_cols_prefix = 'X_'
    ):
        if type not in ('openai', 'st', 'doc2vec'):
            raise ValueError('Invalid embedding type')
        if embedding_model_doc2vec not in ('PV-DM', 'PV-DBOW'):
            raise ValueError('Doc2Vec model must be one of "PV-DM" or "PV-DBOW"')

        self.type = type
        self.openai_client = openai_client
        self.embedding_model_openai = embedding_model_openai
        self.embedding_model_st = embedding_model_st
        self.embedding_model_doc2vec = embedding_model_doc2vec
        self.colsep = colsep
        self.return_cols_prefix = return_cols_prefix
        self.doc2vec_epochs = doc2vec_epochs
        self.doc2vec_vector_size = doc2vec_vector_size

    def _rows_to_text(self, X):
        """Flatten each row of ``X`` into one ``"col: value"`` string joined by ``colsep``.

        Missing values are replaced by empty strings before formatting.

        Raises
        ------
        TypeError
            If any column of ``X`` does not have ``object`` dtype.
        """
        if not (X.dtypes == 'object').all():
            raise TypeError('All columns of X must be of string (object) type')

        filled = X.fillna('').astype(str)
        names = list(filled.columns)
        # Plain tuples in column order; avoids building a Series per row as a
        # row-wise ``apply`` would.
        return [
            self.colsep.join(f'{name}: {value}' for name, value in zip(names, values))
            for values in filled.itertuples(index = False, name = None)
        ]

    def _fit_doc2vec(self, X, y = None):
        """Train a Doc2Vec model on the rows of ``X`` and store it as ``doc2vec_model``.

        ``y`` is accepted for API compatibility and ignored. Returns ``self``.
        """
        docs = self._rows_to_text(X)
        corpus = [
            TaggedDocument(words = simple_preprocess(doc), tags = [str(i)])
            for i, doc in enumerate(docs)
        ]
        # gensim's ``dm`` flag: 1 selects PV-DM, 0 selects PV-DBOW.
        alg = 1 if self.embedding_model_doc2vec == 'PV-DM' else 0
        self.doc2vec_model = Doc2Vec(
            corpus
            , vector_size = self.doc2vec_vector_size
            , window = 2
            , min_count = 1
            , epochs = self.doc2vec_epochs
            , dm = alg
        )
        return self

    def fit(self, X, y = None):
        """Fit the transformer; only the ``'doc2vec'`` backend learns anything.

        Returns ``self``.
        """
        if self.type == 'doc2vec':
            return self._fit_doc2vec(X, y)
        return self

    def fit_transform(self, X, y = None):
        """Fit on ``X`` (doc2vec only) and return the embeddings of ``X``."""
        # The original compared against the misspelled 'docv2ec', so the
        # Doc2Vec model was never trained on this code path.
        if self.type == 'doc2vec':
            self._fit_doc2vec(X, y)
        return self.transform(X)

    def transform(self, X):
        """Embed the rows of ``X``.

        Returns
        -------
        pandas.DataFrame
            One row per input row and one column per embedding dimension, named
            ``return_cols_prefix + str(i)``. The frame is built without an
            explicit index, so it carries a default index rather than ``X``'s.

        Raises
        ------
        TypeError
            If any column of ``X`` does not have ``object`` dtype.
        ValueError
            If ``type`` is not a supported backend, or the OpenAI backend is
            selected without a client.
        """
        Xstr = self._rows_to_text(X)

        if self.type == 'openai':
            arr = self._transform_openai(Xstr)
        elif self.type == 'st':
            arr = self._transform_st(Xstr)
        elif self.type == 'doc2vec':
            arr = self._transform_doc2vec(Xstr)
        else:
            # Reachable if ``type`` was changed after construction.
            raise ValueError('Invalid embedding type')
        return pd.DataFrame(
            arr
            , columns = [self.return_cols_prefix + str(i) for i in range(arr.shape[1])]
        )

    def _transform_doc2vec(self, X):
        """Infer one Doc2Vec vector per document string in ``X``; returns an array."""
        if not hasattr(self, 'doc2vec_model'):
            # Same exception type the bare attribute access would raise, with
            # an actionable message.
            raise AttributeError('Doc2Vec model is not fitted; call fit() before transform()')
        out = [self.doc2vec_model.infer_vector(simple_preprocess(doc)) for doc in X]
        return np.array(out)

    def _transform_st(self, X):
        """Encode the document strings in ``X`` with a sentence-transformers model."""
        # The model is instantiated on every call.
        model = SentenceTransformer(self.embedding_model_st)
        return model.encode(X)

    def _transform_openai(self, X):
        """Embed the document strings in ``X`` through the OpenAI embeddings endpoint.

        Raises
        ------
        ValueError
            If no client was supplied.
        """
        if not self.openai_client:
            raise ValueError('Invalid OpenAI client')

        ret = self.openai_client.embeddings.create(
            input = X
            , model = self.embedding_model_openai
        )
        return np.array([item.embedding for item in ret.data])


In [2]:
# Load the CSV holding the free-text columns.
# NOTE(review): absolute, machine-specific path -- this cell only runs where that
# file exists; consider a configurable data directory.
dfText = pd.read_csv('C:/Users/alire/OneDrive/data/statman_bitbucket/aki/LLM/March2024/all_text_columns.csv')
# Preview the first rows.
dfText.head()

Unnamed: 0,project_id,operation_no,diagnosis,prevproc,comorbidity,operation
0,PR-00000001,1,cardiac conduit complication; pulmonary atresi...,replacement of cardiac conduit; rv to pulmonar...,22q11 microdeletion with full digeorge sequenc...,replacement of cardiac conduit
1,PR-00000002,2,cardiac conduit failure; aortic regurgitation ...,no previous procedure,no pre-procedural risk factors,replacement of cardiac conduit; 'annuloplasty'...
2,PR-00000003,3,tetralogy of fallot; pulmonary regurgitation,no previous procedure,no pre-procedural risk factors,absent pulmonary valve syndrome (fallot type) ...
3,PR-00000004,4,aortic regurgitation; congenital anomaly of ao...,rv to pulmonary artery conduit construction; p...,no pre-procedural risk factors,aortic root replacement: valve sparing technique
4,PR-00000005,5,cardiac conduit failure; common arterial trunk...,rv to pulmonary artery conduit construction; r...,lung disease; renal failure; 22q11 microdeleti...,replacement of cardiac conduit; pacemaker syst...


In [3]:
from dotenv import load_dotenv
import os

# Load environment variables from .env file
load_dotenv()
# os.getenv returns None when OPENAI_API_KEY is not set.
api_key = os.getenv('OPENAI_API_KEY')

from openai import OpenAI
# OpenAI client that is later passed to Textcol2Mat as `openai_client`.
client = OpenAI(api_key=api_key)

In [20]:
# Build the transformer with the sentence-transformers backend (type = 'st').
# With this type the OpenAI client/model and the doc2vec_* settings are stored
# but not used, and fit() simply returns the transformer itself.
obj = Textcol2Mat(
    type = 'st'
    , openai_client = client
    , embedding_model_openai = 'text-embedding-3-small'
    , doc2vec_epochs = 100
    , doc2vec_vector_size = 5 
).fit(dfText[['diagnosis', 'operation']])

In [22]:
# Embed the 'diagnosis' and 'operation' columns and preview the first rows of
# the resulting frame of X_<i> embedding columns.
obj.transform(dfText[['diagnosis', 'operation']]).head()



Unnamed: 0,X_0,X_1,X_2,X_3,X_4,X_5,X_6,X_7,X_8,X_9,...,X_758,X_759,X_760,X_761,X_762,X_763,X_764,X_765,X_766,X_767
0,-0.865562,-0.373688,-0.658263,-0.38451,0.312594,0.6718,-0.517263,0.092669,0.050474,0.331842,...,-0.704268,0.512846,0.129162,0.374641,0.288703,-0.386418,0.210517,0.99338,0.087783,0.043976
1,-0.781574,-0.28228,-0.714436,0.034094,0.244921,0.393129,-0.284776,0.746228,0.06371,0.409943,...,-0.737984,0.682594,0.50257,0.19444,0.282918,-0.546767,-0.172537,0.779971,0.610607,-0.020707
2,-0.104579,-0.580223,-0.187845,0.177381,0.145783,-0.113661,-0.311051,0.205705,-0.430501,0.020169,...,-0.472311,0.56324,-0.144658,0.454121,-0.266685,-0.342979,-0.272786,0.103753,0.232315,0.134914
3,-0.453933,-0.616355,-0.702995,0.200006,0.174352,-0.047446,-0.179576,0.785386,0.239604,-0.009016,...,-0.131017,0.803993,0.434759,0.173214,-0.090122,0.314879,-0.197242,0.014069,0.455173,0.030475
4,-1.145213,0.216133,-0.312247,0.031895,0.344814,0.028535,-0.524198,0.372269,-0.44491,-0.049628,...,-0.236367,0.105341,0.729378,0.768905,0.225522,-0.641583,0.188161,0.409608,-0.138511,-0.113123


10