In [None]:
import torch
import torch.nn as nn
import transformers
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
import pandas as pd
import numpy as np
import pickle
import time
import ast

In [None]:
# Configuration
device = "cpu"  # the model and every batch tensor are moved here via .to(device)
MAX_LEN = 128  # token length each transcription is padded/truncated to
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 2e-5
MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"

# Load Data
df = pd.read_csv('../data/mtsamples_cleaned.csv')

# Convert string representation of list back to actual list
df['filtered_keywords'] = df['filtered_keywords'].apply(ast.literal_eval)

# One indicator column per distinct keyword; column order is mlb.classes_.
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(df['filtered_keywords'])

# NOTE(security): pickle.load can execute arbitrary code. Only load a file
# produced by this project, never one from an untrusted source.
with open('../models/mlb_classes.pkl', 'rb') as f:
    classes = pickle.load(f)

# `classes` sizes the model's output layer while `y` supplies the targets, so
# the two label lists must be identical (same labels, same order). Fail here
# with a readable message instead of a shape error deep inside the loss, or a
# silently mis-mapped label at inference time.
if list(classes) != list(mlb.classes_):
    raise ValueError(
        "Labels in ../models/mlb_classes.pkl do not match the labels fitted "
        f"from the data ({len(classes)} stored vs {len(mlb.classes_)} fitted). "
        "Regenerate the pickle from the current dataset."
    )

In [None]:
# Dataset Class
# The tokenizer is loaded from the same MODEL_NAME checkpoint as the encoder,
# so token ids line up with the encoder's vocabulary.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

class PMSIDataset(Dataset):
    """Map-style dataset pairing tokenized transcriptions with label vectors.

    Args:
        dataframe: Frame exposing a ``transcription`` column, read positionally.
        tokenizer: Callable tokenizer returning ``input_ids``,
            ``attention_mask`` and ``token_type_ids`` for a single text.
        max_len: Length every sequence is padded or truncated to.
        targets: Positionally indexable collection of label vectors, one per
            row of ``dataframe``.

    Raises:
        ValueError: If ``dataframe`` and ``targets`` differ in length.
    """

    def __init__(self, dataframe, tokenizer, max_len, targets):
        # Rows and labels are matched purely by position, so a length mismatch
        # would otherwise surface as an IndexError mid-epoch or go unnoticed.
        if len(dataframe) != len(targets):
            raise ValueError(
                f"dataframe has {len(dataframe)} rows but targets has "
                f"{len(targets)} entries"
            )
        self.tokenizer = tokenizer
        self.data = dataframe
        self.text = dataframe.transcription
        self.targets = targets
        self.max_len = max_len

    def __len__(self):
        """Return the number of samples."""
        return len(self.text)

    def __getitem__(self, index):
        """Return the tensors for the sample at position ``index``.

        Returns:
            dict with ``ids``, ``mask`` and ``token_type_ids`` (long tensors)
            and ``targets`` (float tensor).
        """
        # .iloc because the frame keeps its pre-split index labels; collapse
        # every run of whitespace (including newlines) to a single space.
        text = " ".join(str(self.text.iloc[index]).split())

        # Call the tokenizer directly: encode_plus is deprecated in favour of
        # __call__, which accepts the same keyword arguments.
        inputs = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_token_type_ids=True,
        )

        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(inputs['token_type_ids'], dtype=torch.long),
            'targets': torch.tensor(self.targets[index], dtype=torch.float)
        }


In [None]:
# Data Loaders
# Hold out 20% of the rows for validation; the fixed random_state keeps the
# split identical from run to run.
train_df, val_df, train_targets, val_targets = train_test_split(
    df,
    y,
    random_state=42,
    test_size=0.2,
)

training_set = PMSIDataset(
    dataframe=train_df, tokenizer=tokenizer, max_len=MAX_LEN, targets=train_targets
)
validation_set = PMSIDataset(
    dataframe=val_df, tokenizer=tokenizer, max_len=MAX_LEN, targets=val_targets
)

# Only the training batches are reshuffled each epoch.
training_loader = DataLoader(
    dataset=training_set, shuffle=True, batch_size=TRAIN_BATCH_SIZE
)
validation_loader = DataLoader(
    dataset=validation_set, shuffle=False, batch_size=VALID_BATCH_SIZE
)


In [None]:
# Model Definition
class PMSIModel(nn.Module):
    """Pretrained encoder followed by dropout and a linear multi-label head.

    Args:
        n_classes: Number of output logits (one per label).
    """

    def __init__(self, n_classes):
        super().__init__()
        self.bert = transformers.AutoModel.from_pretrained(MODEL_NAME)
        self.drop = nn.Dropout(0.3)
        # Size the head from the loaded checkpoint instead of assuming 768, so
        # pointing MODEL_NAME at a different-width encoder keeps working.
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, ids, mask, token_type_ids):
        """Return raw (pre-sigmoid) logits for a batch of token ids.

        Args:
            ids: Token ids passed to the encoder.
            mask: Attention mask passed as ``attention_mask``.
            token_type_ids: Segment ids passed to the encoder.
        """
        # With return_dict=False the encoder returns a tuple; only its second
        # element (the pooled output) feeds the classification head.
        encoder_outputs = self.bert(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids,
            return_dict=False,
        )
        pooled_output = encoder_outputs[1]
        return self.out(self.drop(pooled_output))

# One output logit per entry in `classes`.
model = PMSIModel(n_classes=len(classes))
model.to(device)

In [None]:
# Loss Function (Weighted for Imbalance)
# Class frequencies are taken from the training split only, so statistics of
# the held-out validation rows do not leak into the training objective.
class_counts = train_targets.sum(axis=0)
total_samples = len(train_targets)

# pos_weight = negatives / positives for each label. The denominator is
# floored at 1 so a label with no positive training example gets a finite
# weight instead of one inflated by a tiny epsilon.
pos_weights = (total_samples - class_counts) / np.maximum(class_counts, 1)
pos_weights_tensor = torch.tensor(pos_weights, dtype=torch.float).to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# Takes raw logits; the sigmoid is applied inside the loss.
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weights_tensor)

In [None]:
# Training Loop
def train_model(epoch):
    """Run one optimisation pass over ``training_loader``.

    Uses the module-level ``model``, ``optimizer``, ``loss_fn`` and ``device``.
    Prints the loss every 50 steps and the mean batch loss at the end.

    Args:
        epoch: Zero-based epoch index; shown one-based in the log lines.
    """
    model.train()
    print(f"Epoch {epoch+1}/{EPOCHS}")
    running_loss = 0

    for step, batch in enumerate(training_loader):
        # Move every tensor of the batch onto the training device.
        input_ids = batch['ids'].to(device)
        attention_mask = batch['mask'].to(device)
        segment_ids = batch['token_type_ids'].to(device)
        labels = batch['targets'].to(device)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        logits = model(input_ids, attention_mask, segment_ids)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        running_loss += step_loss

        if step % 50 == 0:
            print(f'Step {step}, Loss: {step_loss:.4f}')

    mean_loss = running_loss / len(training_loader)
    print(f"Epoch {epoch+1} Complete. Avg Loss: {mean_loss:.4f}")

# Train for EPOCHS passes over the training data. validation_loader is not
# consulted here, so no held-out metric is reported during training.
for epoch in range(EPOCHS):
    train_model(epoch)

In [None]:
# Save Model
# Only the weights (state_dict) are written; loading them back requires
# constructing PMSIModel with the same number of classes first.
model_output_path = '../models/pmsi_model_conf.bin'
torch.save(model.state_dict(), model_output_path)
print(f"Model saved to {model_output_path}")