In [91]:
import pandas as pd
import numpy as np
import openai
from sklearn.base import BaseEstimator, TransformerMixin
from sentence_transformers import SentenceTransformer
from gensim.utils import simple_preprocess
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

class Textcol2Vec(BaseEstimator, TransformerMixin):
    def __init__(
        self
        , type = 'openai'
        , openai_client = None
        , embedding_model_openai = 'text-embedding-3-large'
        , embedding_model_st = 'neuml/pubmedbert-base-embeddings'
        , colsep = ' || '
        , return_cols_prefix = 'X_'
    ):
        if not (type in ('openai', 'st', 'doc2vec')):
            raise ValueError('Invalid embedding type')
        
        self.type = type
        self.openai_client = openai_client
        self.embedding_model_openai = embedding_model_openai
        self.embedding_model_st = embedding_model_st
        self.colsep = colsep
        self.return_cols_prefix = return_cols_prefix
        
        pass
    
    def _fit_doc2vec(self, X, y = None):
        if not (X.dtypes == 'object').all():
            raise TypeError('All columns of X must be of string (object) type')
        
        Xstr = X.fillna('').astype(str).apply(
            lambda row: self.colsep.join([f'{col}: {row[col]}' for col in X.columns])
            , axis=1
        ).tolist()
        
        corpus = [TaggedDocument(words = simple_preprocess(doc), tags=[str(i)]) for i, doc in enumerate(Xstr)]
        model = Doc2Vec(corpus, vector_size=5, window=2, min_count=1, epochs=40)
        self.doc2vec_model = model
        return self
    
    def fit(self, X, y = None):
        if self.type == 'doc2vec':
            return self._fit_doc2vec(X, y)
        else:
            return self

    def fit_transform(self, X, y = None):
        if self.type == 'docv2ec':
            self._fit_doc2vec(X, y)
        return self.transform(X)
    
    def transform(self, X):
        
        if not (X.dtypes == 'object').all():
            raise TypeError('All columns of X must be of string (object) type')
        
        Xstr = X.fillna('').astype(str).apply(
            lambda row: self.colsep.join([f'{col}: {row[col]}' for col in X.columns])
            , axis=1
        ).tolist()
        
        if self.type == 'openai':
            return self._transform_openai(Xstr)
        elif self.type == 'st':
            return self._transform_st(Xstr)
        elif self.type == 'doc2vec':
            return self._transform_doc2vec(Xstr)
        else:
            raise ValueError('Invalid embedding type')

    def _transform_doc2vec(self, X):
        out = [self.doc2vec_model.infer_vector(simple_preprocess(doc)) for doc in X]
        return np.array(out)

    def _transform_st(self, X):
        model = SentenceTransformer(self.embedding_model_st)
        return model.encode(X) 
    
    def _transform_openai(self, X):
        if not self.openai_client:
            raise ValueError('Invalid OpenAI client')
        
        ret = self.openai_client.embeddings.create(
            input = X
            , model = self.embedding_model_openai
        )
        ret = np.array([ret.data[n].embedding for n in range(len(ret.data))])
        return pd.DataFrame(ret, columns = [self.return_cols_prefix + str(i) for i in range(ret.shape[1])])


In [5]:
dfText = pd.read_csv('C:/Users/alire/OneDrive/data/statman_bitbucket/aki/LLM/March2024/all_text_columns.csv')
dfText.head()

Unnamed: 0,project_id,operation_no,diagnosis,prevproc,comorbidity,operation
0,PR-00000001,1,cardiac conduit complication; pulmonary atresi...,replacement of cardiac conduit; rv to pulmonar...,22q11 microdeletion with full digeorge sequenc...,replacement of cardiac conduit
1,PR-00000002,2,cardiac conduit failure; aortic regurgitation ...,no previous procedure,no pre-procedural risk factors,replacement of cardiac conduit; 'annuloplasty'...
2,PR-00000003,3,tetralogy of fallot; pulmonary regurgitation,no previous procedure,no pre-procedural risk factors,absent pulmonary valve syndrome (fallot type) ...
3,PR-00000004,4,aortic regurgitation; congenital anomaly of ao...,rv to pulmonary artery conduit construction; p...,no pre-procedural risk factors,aortic root replacement: valve sparing technique
4,PR-00000005,5,cardiac conduit failure; common arterial trunk...,rv to pulmonary artery conduit construction; r...,lung disease; renal failure; 22q11 microdeleti...,replacement of cardiac conduit; pacemaker syst...


In [37]:
from dotenv import load_dotenv
import os

# Load environment variables from .env file
load_dotenv()
api_key = os.getenv('OPENAI_API_KEY')

from openai import OpenAI
client = OpenAI(api_key=api_key)

In [92]:
obj = Textcol2Vec(
    type = 'doc2vec'
    , openai_client = client
    , embedding_model_openai = 'text-embedding-3-small'
).fit(dfText[['diagnosis', 'operation']])

In [94]:
obj.transform(dfText[['diagnosis', 'operation']][:5]).shape

(5, 5)

In [88]:
obj._transform_doc2vec(dfText[['diagnosis', 'operation']][:5])#.shape

Unnamed: 0,diagnosis,operation
0,cardiac conduit complication; pulmonary atresi...,replacement of cardiac conduit
1,cardiac conduit failure; aortic regurgitation ...,replacement of cardiac conduit; 'annuloplasty'...
2,tetralogy of fallot; pulmonary regurgitation,absent pulmonary valve syndrome (fallot type) ...
3,aortic regurgitation; congenital anomaly of ao...,aortic root replacement: valve sparing technique
4,cardiac conduit failure; common arterial trunk...,replacement of cardiac conduit; pacemaker syst...
