In [7]:
import os
import sys

# Directory put on the import path so the project-local `tfIntegration` package
# (imported in the next cells) can be found. Override it with the PROJECT_ROOT
# environment variable instead of editing this hardcoded default.
PROJECT_ROOT = os.environ.get("PROJECT_ROOT", "/home/afigueroa")

# Only insert once: re-running this cell in the same kernel must not keep
# prepending duplicate entries to sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

In [21]:
from transformers import AutoModel, AutoConfig
from tfIntegration.baseModel.models.base import BaseClassificationModel
import torch
import torch.nn as nn
import pytorch_lightning as pl



In [26]:
from transformers import AutoModel, AutoConfig
from tfIntegration.baseModel.models.base import BaseClassificationModel
import torch
import torch.nn as nn
import pytorch_lightning as pl

class CombinedModel(nn.Module):
    """Classifier that fuses a trained base model with a pre-trained ERNIE encoder.

    Both encoders receive the same input batch. From each encoder's first output
    the vector at sequence position 0 (``output[0][:, 0, :]``) is taken; the two
    vectors are concatenated and passed through a small MLP classification head.

    Args:
        base_model_checkpoint: Path to a checkpoint for ``BaseClassificationModel``;
            the weights are read from its ``'state_dict'`` entry.
        ernie_model_path: Path or identifier accepted by
            ``AutoConfig.from_pretrained`` / ``AutoModel.from_pretrained``.
        dropout: Dropout probability applied before each linear layer of the head.
        n_classes: Number of logits produced by the head.
        base_hidden_size: Width of the base model's per-position output vectors.
            Defaults to the ERNIE hidden size (the original assumption that both
            encoders have the same width).
    """

    def __init__(self, base_model_checkpoint, ernie_model_path, dropout=0.05, n_classes=2,
                 base_hidden_size=None):
        super().__init__()

        # Rebuild the base model and restore its trained weights.
        # map_location="cpu" lets a checkpoint that was saved from GPU tensors load
        # on a CPU-only machine; load_state_dict copies the values into this
        # module's own parameters either way, so the result is unchanged.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        self.base_model = BaseClassificationModel()
        checkpoint = torch.load(base_model_checkpoint, map_location="cpu")
        self.base_model.load_state_dict(checkpoint['state_dict'])

        # Pre-trained ERNIE encoder.
        ernie_config = AutoConfig.from_pretrained(ernie_model_path)
        self.ernie_model = AutoModel.from_pretrained(ernie_model_path, config=ernie_config)

        # Feature widths of the two encoders; the head input is their sum.
        self.hidden_size = ernie_config.hidden_size
        self.base_hidden_size = self.hidden_size if base_hidden_size is None else base_hidden_size

        # (combined) model classification head
        self.head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.base_hidden_size + self.hidden_size, self.hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(self.hidden_size, n_classes),
        )

    @staticmethod
    def _encode(encoder, x):
        """Run ``encoder`` on ``x`` and return position 0 of its first output.

        A mapping (e.g. a tokenizer's output dict) is passed as keyword
        arguments; any other sequence is unpacked positionally, as before.
        """
        from collections.abc import Mapping

        outputs = encoder(**x) if isinstance(x, Mapping) else encoder(*x)
        # NOTE(review): assumes outputs[0] is (batch, seq, hidden) -- confirm for the base model.
        return outputs[0][:, 0, :]

    def forward(self, x):
        """Return classification logits for the batch ``x``.

        Args:
            x: Encoder inputs, either a sequence unpacked positionally into both
               encoders or a mapping unpacked as keyword arguments.

        Returns:
            The output of ``self.head`` applied to the concatenated features
            (base-model features first, then ERNIE features).
        """
        base_features = self._encode(self.base_model, x)
        ernie_features = self._encode(self.ernie_model, x)
        combined_output = torch.cat((base_features, ernie_features), dim=1)
        return self.head(combined_output)

In [27]:
# Inputs for the combined classifier. Both paths are relative, so they resolve
# against the notebook's current working directory.
base_model_checkpoint = "../baseModel/trained_model/base_model.ckpt"  # trained base-model checkpoint
ernie_model_path = "./models/"  # passed to AutoConfig/AutoModel.from_pretrained

model = CombinedModel(
    base_model_checkpoint=base_model_checkpoint,
    ernie_model_path=ernie_model_path,
)
