In [1]:
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import Linear
from transformers import BertModel
from transformers import BertTokenizer

# Load datasets (paths are relative to the notebook's working directory).
prompts_test = pd.read_csv("../data/prompts_test.csv")
prompts_train = pd.read_csv("../data/prompts_train.csv")
summaries_test = pd.read_csv("../data/summaries_test.csv")
summaries_train = pd.read_csv("../data/summaries_train.csv")

# Drop student_id column from summaries_train and summaries_test
summaries_train = summaries_train.drop(columns=['student_id'])
summaries_test = summaries_test.drop(columns=['student_id'])

# Keep only the first N_TRAIN_SAMPLES rows (presumably a quick-iteration subset;
# set to len(summaries_train) for a full run). `.copy()` makes this an
# independent frame so the column assignment below does not hit pandas'
# SettingWithCopyWarning / chained-assignment behaviour.
N_TRAIN_SAMPLES = 32
summaries_train = summaries_train.iloc[:N_TRAIN_SAMPLES].copy()

# Encode each prompt_id string as an integer class index, in order of first
# appearance in prompts_train.
id_mapping = {id_val: idx for idx, id_val in enumerate(prompts_train['prompt_id'].unique())}

summaries_train['prompt_id'] = summaries_train['prompt_id'].replace(id_mapping)
# NOTE(review): the mapping is built from prompts_train only; any prompt_id in
# summaries_test that is not in prompts_train is left unchanged by `.replace`.
summaries_test['prompt_id'] = summaries_test['prompt_id'].replace(id_mapping)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize the 'text' column: pad to the longest summary in the batch,
# truncate at 128 tokens, and return PyTorch tensors.
texts = summaries_train['text'].tolist()
tokens = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=128)

input_ids = tokens["input_ids"]
attention_mask = tokens["attention_mask"]


class CustomBERTModel(nn.Module):
    """Multi-task BERT model with single linear-layer heads.

    NOTE(review): this class is redefined (with two-layer MLP heads) in the next
    cell; that later definition shadows this one and is the one instantiated.

    Args:
        config: configuration passed to ``BertModel``; ``config.hidden_size`` is
            the input width of every head.
        num_prompt_classes: number of logits produced by the prompt classifier.
    """

    def __init__(self, config, num_prompt_classes):
        super().__init__()

        # Built from the config only (not ``from_pretrained``), so the encoder
        # weights are not loaded from a pretrained checkpoint.
        self.bert = BertModel(config)

        # Classification head for prompts: one logit per prompt class.
        self.prompt_classifier = nn.Linear(config.hidden_size, num_prompt_classes)

        # Regression head for wording.
        self.wording_regressor = nn.Linear(config.hidden_size, 1)

        # Regression head for content.
        self.content_regressor = nn.Linear(config.hidden_size, 1)

        # Joint regression head predicting (wording, content) together.
        self.combined_regressor = nn.Linear(config.hidden_size, 2)

    def forward(self, input_ids, attention_mask=None):
        """Run BERT and every head on the pooled output.

        Returns:
            Tuple ``(prompt_logits, wording, content, combined)`` with last
            dimensions ``num_prompt_classes``, 1, 1 and 2 respectively.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        pooled_output = outputs[1]  # pooled [CLS] representation

        prompt_class = self.prompt_classifier(pooled_output)
        wording_value = self.wording_regressor(pooled_output)
        content_value = self.content_regressor(pooled_output)
        combined_values = self.combined_regressor(pooled_output)

        return prompt_class, wording_value, content_value, combined_values

In [2]:
from transformers import BertModel, BertConfig
from torch.nn import Linear, CrossEntropyLoss, MSELoss
from torch.optim import AdamW
import torch.nn as nn

class CustomBERTModel(nn.Module):
    """Multi-task BERT model: a prompt classifier plus three regression heads.

    Every head is ``Linear -> ReLU -> Dropout -> Linear`` applied to BERT's
    pooled output. The attribute names (``prompt_classifier_1`` etc.) determine
    the ``state_dict`` keys, so keep them stable to remain able to load weights
    saved by earlier runs.

    Args:
        config: ``BertConfig``; ``config.hidden_size`` is the heads' input width.
        num_prompt_classes: number of logits produced by the prompt classifier.
        hidden_size: width of each head's intermediate layer.
        dropout_prob: dropout rate applied inside every head.
        pretrained_model_name: optional checkpoint name (e.g. ``"bert-base-uncased"``).
            When given, encoder weights are loaded with ``BertModel.from_pretrained``.
            When ``None`` (default, the original behaviour), the encoder is built
            from ``config`` alone and is therefore not pretrained.
    """

    def __init__(self, config, num_prompt_classes, hidden_size=256,
                 dropout_prob=0.1, pretrained_model_name=None):
        super().__init__()

        if pretrained_model_name is None:
            self.bert = BertModel(config)
        else:
            self.bert = BertModel.from_pretrained(pretrained_model_name, config=config)

        # Classification head for prompts
        self.prompt_classifier_1 = nn.Linear(config.hidden_size, hidden_size)
        self.prompt_classifier_2 = nn.Linear(hidden_size, num_prompt_classes)

        # Regression head for wording
        self.wording_regressor_1 = nn.Linear(config.hidden_size, hidden_size)
        self.wording_regressor_2 = nn.Linear(hidden_size, 1)

        # Regression head for content
        self.content_regressor_1 = nn.Linear(config.hidden_size, hidden_size)
        self.content_regressor_2 = nn.Linear(hidden_size, 1)

        # Regression head for combined wording & content
        self.combined_regressor_1 = nn.Linear(config.hidden_size, hidden_size)
        self.combined_regressor_2 = nn.Linear(hidden_size, 2)

        # Activation and dropout shared by all heads
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_prob)

    def _run_head(self, first, second, features):
        """Apply one two-layer head: ``second(dropout(relu(first(features))))``."""
        return second(self.dropout(self.relu(first(features))))

    def forward(self, input_ids, attention_mask=None):
        """Run BERT and all four heads on the pooled output.

        Returns:
            Tuple ``(prompt_logits, wording, content, combined)`` with last
            dimensions ``num_prompt_classes``, 1, 1 and 2 respectively.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        pooled_output = outputs[1]  # [CLS] representation

        # Heads are evaluated in a fixed order, so dropout draws its random
        # masks in the same sequence on every call.
        prompt_output = self._run_head(self.prompt_classifier_1, self.prompt_classifier_2, pooled_output)
        wording_output = self._run_head(self.wording_regressor_1, self.wording_regressor_2, pooled_output)
        content_output = self._run_head(self.content_regressor_1, self.content_regressor_2, pooled_output)
        combined_output = self._run_head(self.combined_regressor_1, self.combined_regressor_2, pooled_output)

        return prompt_output, wording_output, content_output, combined_output


def train_model(model, input_ids, attention_mask, prompt_id_labels, wording_labels,
                content_labels, epochs=3, batch_size=16, lr=2e-5,
                save_path="./saved_model_directory/model_weights.pth"):
    """Jointly train the prompt classifier and regression heads, then save weights.

    Iterates over the rows in their given order (no shuffling) in contiguous
    mini-batches. The per-batch loss is the unweighted sum of:
    cross-entropy on the prompt logits, MSE on wording, MSE on content, and MSE
    on the combined head against ``[wording, content]`` concatenated on dim 1.

    Args:
        model: module whose ``forward(input_ids, attention_mask=...)`` returns
            ``(prompt_logits, wording, content, combined)``.
        input_ids: token ids, one row per example along dim 0.
        attention_mask: attention mask aligned row-for-row with ``input_ids``.
        prompt_id_labels: integer class targets, one per row.
        wording_labels: float targets shaped like the wording head's output
            (presumably ``(N, 1)`` — TODO confirm against the label cell).
        content_labels: float targets shaped like ``wording_labels``.
        epochs: number of passes over the data.
        batch_size: rows per optimisation step.
        lr: AdamW learning rate.
        save_path: file the final ``model.state_dict()`` is written to; missing
            parent directories are created.
    """
    import os

    import torch

    classification_criterion = CrossEntropyLoss()
    regression_criterion = MSELoss()
    optimizer = AdamW(model.parameters(), lr=lr)

    # Targets for the combined head; loop-invariant, so build them once.
    combined_labels = torch.cat((wording_labels, content_labels), dim=1)

    model.train()

    for epoch in range(epochs):
        for start in range(0, len(input_ids), batch_size):
            end = start + batch_size
            optimizer.zero_grad()

            # Forward pass
            prompt_logits, wording, content, combined = model(
                input_ids[start:end], attention_mask=attention_mask[start:end]
            )

            # Total loss: classification + three regression terms
            loss = (
                classification_criterion(prompt_logits, prompt_id_labels[start:end])
                + regression_criterion(wording, wording_labels[start:end])
                + regression_criterion(content, content_labels[start:end])
                + regression_criterion(combined, combined_labels[start:end])
            )
            loss.backward()
            optimizer.step()

            print(f"Epoch {epoch+1}, Batch {start//batch_size + 1}, Loss: {loss.item()}")

    # torch.save does not create directories; ensure the target folder exists.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)


In [3]:
import torch

# Label tensors aligned row-for-row with `input_ids` (both built from summaries_train).
# CrossEntropyLoss needs int64 class targets. Cast explicitly: after `.replace`
# the column may stay object dtype (pandas 2.2 deprecates replace's implicit
# downcasting), which torch.tensor cannot convert.
prompt_id_labels = torch.tensor(summaries_train['prompt_id'].to_numpy(dtype="int64"))
# Regression targets as float32 column vectors of shape (N, 1), matching the
# single-output regression heads.
wording_labels = torch.tensor(summaries_train['wording'].to_numpy(dtype="float32")).unsqueeze(1)
content_labels = torch.tensor(summaries_train['content'].to_numpy(dtype="float32")).unsqueeze(1)


In [4]:
import os

# Set the hyperparameters
batch_size = 16  # mini-batch size used for training
# One class per distinct prompt in prompts_train (the mapping used to encode prompt_id).
num_prompt_classes = len(id_mapping)

# Path to the saved model weights
model_weights_path = "./saved_model_directory/model_weights.pth"

# When saved weights exist: True keeps fine-tuning from them, False only loads them.
CONTINUE_TRAINING_FROM_CHECKPOINT = True

# Instantiate model with BERT's configuration (once).
config = BertConfig.from_pretrained("bert-base-uncased")
model = CustomBERTModel(config, num_prompt_classes)

# Check if the model weights exist and load them
weights_loaded = os.path.exists(model_weights_path)
if weights_loaded:
    # The model is not moved to a GPU anywhere in this notebook, so map the
    # checkpoint's tensors onto the CPU regardless of where they were saved.
    model.load_state_dict(torch.load(model_weights_path, map_location="cpu"))
    print("Loaded saved model weights!")

# Train from scratch when no checkpoint exists; otherwise continue training
# unless CONTINUE_TRAINING_FROM_CHECKPOINT is False.
if not weights_loaded or CONTINUE_TRAINING_FROM_CHECKPOINT:
    train_model(
        model,
        input_ids=input_ids,
        attention_mask=attention_mask,
        prompt_id_labels=prompt_id_labels,
        wording_labels=wording_labels,
        content_labels=content_labels,
        epochs=5
    )


Loaded saved model weights!
Epoch 1, Batch 1, Loss: 3.7746429443359375
Epoch 1, Batch 2, Loss: 8.208613395690918
Epoch 2, Batch 1, Loss: 3.9896764755249023
Epoch 2, Batch 2, Loss: 2.7766194343566895
Epoch 3, Batch 1, Loss: 6.492539882659912
Epoch 3, Batch 2, Loss: 3.367208480834961
Epoch 4, Batch 1, Loss: 4.4194135665893555
Epoch 4, Batch 2, Loss: 3.221442222595215
Epoch 5, Batch 1, Loss: 3.6987507343292236
Epoch 5, Batch 2, Loss: 4.050414562225342
