In [None]:
import pandas as pd
import json
import numpy as np
from transformers import AutoTokenizer, AutoModel, get_cosine_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
# from torchmetrics.functional.classification import auroc
# from torch.amp.autocast_mode import autocast
# from torch.amp.grad_scaler import GradScaler
import torch.nn as nn
import torch
import torch.nn.functional as F
# from tqdm import tqdm
# from sklearn.metrics import f1_score, classification_report
# from pytorch_lamb import Lamb
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score

# Select torch's 'medium' float32 matmul precision setting for the whole session.
torch.set_float32_matmul_precision('medium')
# Fixed length that get_token() pads/truncates every id sequence to.
# NOTE(review): presumably 223 content positions + 2 for the CLS/SEP ids — confirm.
MAX_SEQ_LEN = 223 + 2 # Old value: 195. Decide to keep as original
def read_file_json(fileName):
    """Open `fileName` as UTF-8 text, parse it as JSON and return the parsed object."""
    with open(fileName, mode='r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed
# Character -> index table loaded from disk; only its keys are used to build `vocab`.
data = read_file_json("vimq_data/char2index.json")
# Use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')
batch_size = 32 # Will run value: 67
# Allowed characters, in the iteration order of the loaded mapping.
vocab = list(data.keys())
# Illustration (ids are made up): a b c d => 0 234 5555 5555 2 1 1 1 1 1 1 1
# i.e. CLS id, sub-word ids of every word, SEP id, then PAD ids up to max_seq_len.
def get_token(tokenizer, words, max_seq_len = MAX_SEQ_LEN):
    """Encode `words` into fixed-length input ids plus the matching attention mask.

    Args:
        tokenizer: object exposing `cls_token_id`, `sep_token_id`, `pad_token_id`
            and `encode(text, add_special_tokens=False)`.
        words: iterable of strings; each element is encoded on its own and the
            resulting ids are concatenated. A bare `str` is iterated character by
            character, so callers should pass a list of words.
        max_seq_len: exact length of both returned lists.

    Returns:
        (inputs_id, attention_mask): two lists of length `max_seq_len`. The mask is
        1 for real tokens (CLS, sub-words, SEP) and 0 for padding.
    """
    inputs_id = [tokenizer.cls_token_id]
    for word in words:
        # Ask the tokenizer for the bare sub-word ids instead of slicing the
        # special tokens off by position.
        inputs_id.extend(tokenizer.encode(word, add_special_tokens=False))
    inputs_id.append(tokenizer.sep_token_id)

    if len(inputs_id) > max_seq_len:
        inputs_id = inputs_id[:max_seq_len]
        if inputs_id:
            # Plain slicing would drop the trailing SEP; keep the sequence
            # properly terminated after truncation.
            inputs_id[-1] = tokenizer.sep_token_id

    attention_mask = [1] * len(inputs_id)
    padding = max_seq_len - len(inputs_id)
    if padding > 0:
        inputs_id += [tokenizer.pad_token_id] * padding
        attention_mask += [0] * padding
    return inputs_id, attention_mask
def load_tokenizer():
    """Return the pretrained tokenizer loaded via AutoTokenizer for "vinai/phobert-base-v2"."""
    pretrained_name = "vinai/phobert-base-v2"
    return AutoTokenizer.from_pretrained(pretrained_name)
def clean_text(text: str):
    """Lower-case `text` (after `str()` coercion) and keep only characters found in `vocab`."""
    lowered = str(text).lower()
    return ''.join(symbol for symbol in lowered if symbol in vocab)
def cleanify_sentence(sentence):
    """Apply `clean_text` to every single-space-separated piece of `sentence` and re-join with spaces."""
    cleaned_pieces = map(clean_text, sentence.split(" "))
    return ' '.join(cleaned_pieces)

In [None]:
# Load the question/disease table. NOTE: this rebinds `data`, which earlier held the
# char2index mapping used to build `vocab`.
data = pd.read_csv("vim-med/ViMedical_Disease.csv")
# Clean every question: lower-case it and drop characters that are not in `vocab`.
data["Question"] = data["Question"].apply(cleanify_sentence)

In [None]:
### Exploration only (do not run in the main pipeline): find the longest cleaned
### question to help choose the sequence-length hyperparameter.
list_questions = data["Question"].values.tolist()
# Longest question measured in characters (len of the string), not in tokens.
longest = max(list_questions, key=len)
print(longest)
print(len(longest))

In [None]:
# Distinct disease names in order of first appearance; a name's position is its class id.
all_labels = data["Disease"].drop_duplicates().to_list()
id_to_labels = dict(enumerate(all_labels))
labels_to_id = {name: class_id for class_id, name in id_to_labels.items()}
# Replace the disease names in the table by their integer class ids.
data["Disease"] = data["Disease"].map(labels_to_id)

In [None]:
# Split the table 80% / 20% into training and validation rows.
# Total number of rows in the full table.
IDS_ALL_UNIQUE = data.shape[0]
# Seed NumPy's global RNG so the sampled training indices are repeatable.
np.random.seed(42)
# Draw 80% of the row positions without replacement for training ...
train_idx = np.random.choice(np.arange(IDS_ALL_UNIQUE), int(0.8 * IDS_ALL_UNIQUE), replace=False)
# ... and use every remaining position for validation.
valid_idx = np.setdiff1d(np.arange(IDS_ALL_UNIQUE), train_idx)
# Re-seed the global RNG without a fixed seed so later code is not tied to seed 42.
np.random.seed(None)

# NOTE(review): `.loc` selects by index label; this assumes `data` still has its
# default 0..n-1 index — confirm.
df_train = data.loc[train_idx].reset_index(drop=True)
df_val = data.loc[valid_idx].reset_index(drop=True)

print(f"Full dataset: {data.shape}")
print(f"Train dataset: {df_train.shape}")
print(f"Validate dataset: {df_val.shape}")

In [None]:
tokenizer = load_tokenizer()
def tokenize_data(df, tokenizer_tool):
    """Tokenize every entry of df["Question"] and store the result on `df`.

    Args:
        df: DataFrame with a "Question" column of text.
        tokenizer_tool: tokenizer forwarded to `get_token`.

    Returns:
        The same `df` object (modified in place) with two added columns:
        "input_ids" and "attention_masks", each holding one list per row.
    """
    list_input_ids = []
    list_attention_mask = []

    for text in df["Question"]:
        # get_token iterates over its `words` argument, so hand it the
        # whitespace-separated words; passing the raw string would make it
        # encode the question one character at a time.
        inputs_id, attention_mask = get_token(tokenizer_tool, str(text).split())
        list_input_ids.append(inputs_id)
        list_attention_mask.append(attention_mask)
    df["input_ids"] = list_input_ids
    df["attention_masks"] = list_attention_mask
    return df
# tokenize_data mutates and returns its argument, so these names alias df_train / df_val.
df_train_tokenized = tokenize_data(df_train, tokenizer)
df_val_tokenized = tokenize_data(df_val, tokenizer)

In [None]:
class SentenceDataSet(Dataset):
    """Map-style dataset over a tokenized DataFrame.

    The 'input_ids', 'attention_masks' and 'Disease' columns of `df` are copied
    into plain Python lists once, at construction time.
    """
    def __init__(self, df):
        columns = {name: df[name].tolist() for name in ('input_ids', 'attention_masks', 'Disease')}
        self.input_ids = columns['input_ids']
        self.attention_masks = columns['attention_masks']
        self.labels = columns['Disease']

    def __len__(self):
        # One sample per stored input-id sequence.
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return (input_ids, attention_mask, label) of sample `idx`, each as a new tensor."""
        per_sample = (self.input_ids[idx], self.attention_masks[idx], self.labels[idx])
        return tuple(torch.tensor(value) for value in per_sample)

In [None]:
# Quick check: shape of the label tensor (third element) of the first training sample.
SentenceDataSet(df_train_tokenized)[0][2].shape

In [None]:
class QuickDataModule(LightningDataModule):
    """LightningDataModule serving the tokenized train / validation DataFrames.

    Args:
        df_train: tokenized training DataFrame (consumed by `SentenceDataSet`).
        df_val: tokenized validation DataFrame; also used for prediction.
        batch_size: batch size of every dataloader.
        num_worker: number of DataLoader worker processes.
    """
    def __init__(self, df_train = df_train_tokenized, df_val = df_val_tokenized, batch_size = batch_size, num_worker = 2):
        super().__init__()
        self.df_train = df_train
        self.df_val = df_val
        self.batch_size = batch_size
        self.num_worker = num_worker
    def setup(self, stage = None):
        """Build the datasets required by `stage` (None builds everything)."""
        if stage in (None, "fit"):
            self.train_dataset = SentenceDataSet(self.df_train)
        # The validation set backs fitting, stand-alone validation and prediction;
        # previously the "validate" stage left `val_dataset` undefined.
        if stage in (None, "fit", "validate", "predict"):
            self.val_dataset = SentenceDataSet(self.df_val)
    def _make_loader(self, dataset, shuffle):
        # Single place where the DataLoader options are configured.
        return DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.num_worker)
    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)
    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)
    def predict_dataloader(self):
        # Prediction runs over the validation set, in its original row order.
        return self._make_loader(self.val_dataset, shuffle=False)

In [None]:
### Testing: build the data module by hand and inspect one training batch.
sample_data = QuickDataModule()
# With no stage given, setup() builds both the train and validation datasets.
sample_data.setup()
dl = sample_data.train_dataloader()
# Number of batches in the training dataloader.
print(len(dl))
example = iter(dl)
# One batch: (input_ids, attention_masks, labels) as produced by SentenceDataSet.
print(next(example))

In [None]:
# Version 2 - The BiLSTM
class QuickModelLSTM(LightningModule):
    """PhoBERT encoder -> bidirectional LSTM -> masked mean pooling -> linear classifier.

    Args:
        num_labels: number of target classes (size of the logits' last dimension).
        tokenizer_tool: tokenizer object; not used by the model itself, only stored
            in `self.hparams` by `save_hyperparameters()`.
        hidden_dim: hidden size of each LSTM direction.
        lstm_layer: number of stacked LSTM layers.
        num_epoch: number of training epochs, used to size the LR schedule.
        train_size: number of training samples, used to size the LR schedule.
        batch_size: training batch size, used to size the LR schedule.
    """
    def __init__(self, num_labels, tokenizer_tool,  hidden_dim, lstm_layer, num_epoch, train_size = df_train.shape[0], batch_size = batch_size):
        super(QuickModelLSTM, self).__init__()
        # Stores every constructor argument (including tokenizer_tool) under self.hparams.
        self.save_hyperparameters()
        self.roberta = AutoModel.from_pretrained("vinai/phobert-base-v2", return_dict=True)

        self.lstm = nn.LSTM(
            self.roberta.config.hidden_size,
            self.hparams.hidden_dim,
            self.hparams.lstm_layer,
            batch_first=True,
            bidirectional=True,
            # nn.LSTM only applies dropout between stacked layers and warns when a
            # non-zero value is given for a single layer.
            dropout=0.5 if self.hparams.lstm_layer > 1 else 0.0,
        )
        # Bidirectional LSTM -> forward and backward states are concatenated.
        self.classfication = nn.Linear(self.hparams.hidden_dim * 2, self.hparams.num_labels)
        nn.init.xavier_uniform_(self.classfication.weight)
        self.loss_fnc = nn.CrossEntropyLoss()
        self.dropout = nn.Dropout(0.5)
        self.train_metrics = MetricCollection(
            {
                "accuracy": MulticlassAccuracy(num_classes=self.hparams.num_labels, average='weighted'),
                "f1": MulticlassF1Score(num_classes=self.hparams.num_labels, average='weighted'),
            },
            prefix="train_",
        )
        self.valid_metrics = self.train_metrics.clone(prefix="valid_")
    @staticmethod
    def _masked_mean(hidden_states, attention_mask):
        # Average over the sequence dimension using only positions whose mask is 1,
        # so padding does not dilute the sentence representation.
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)  # guard against an all-zero mask
        return summed / counts
    def forward(self, input_ids, attention_mask, labels=None):
        """Return `(loss, logits)`; `loss` is the int 0 when `labels` is None."""
        output = self.roberta(input_ids = input_ids, attention_mask = attention_mask)
        sequence_output = output.last_hidden_state
        lstm_output, _ = self.lstm(sequence_output)
        pooled_output = self._masked_mean(lstm_output, attention_mask)
        pooled_output = self.dropout(pooled_output)
        logits = self.classfication(pooled_output)
        loss = 0
        if labels is not None:
            loss = self.loss_fnc(logits.view(-1, self.hparams.num_labels), labels.view(-1))
        return loss, logits
    def training_step(self, batch, batch_idx):
        input_ids, mask, label = batch
        loss, logits = self(input_ids, mask, label)
        self.log('train_loss', loss, prog_bar=True, on_step=True)
        # Calling the collection updates its state and returns this batch's values.
        batch_value = self.train_metrics(logits, label)
        self.log_dict(batch_value, prog_bar=True, on_step=True)
        return {"loss": loss, "predictions": logits, "labels": label}
    def validation_step(self, batch, batch_idx):
        input_ids, mask, label = batch
        loss, logits = self(input_ids, mask, label)
        self.log('val_loss', loss, prog_bar=True)
        # Only accumulate here; the values are computed in on_validation_epoch_end.
        self.valid_metrics.update(logits, label)
        return {"loss": loss, "predictions": logits, "labels": label}
    def predict_step(self, batch, batch_idx):
        """Return the raw logits for one batch."""
        input_ids, mask, label = batch
        _, logits = self(input_ids, mask, label)
        return logits
    def configure_optimizers(self):
        """AdamW with a per-step cosine schedule and 10% linear warm-up."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1.5e-6, weight_decay=0.001)
        # The training DataLoader keeps the last partial batch, so an epoch has
        # ceil(train_size / batch_size) optimizer steps; flooring would end the
        # schedule before training does.
        steps_per_epoch = int(-(-self.hparams.train_size // self.hparams.batch_size))
        total_step = steps_per_epoch * self.hparams.num_epoch
        warmup_step = int(total_step * 0.1)
        scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_step, total_step)
        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]
    def on_validation_epoch_end(self):
        self.log_dict(self.valid_metrics.compute(), prog_bar=True)
        self.valid_metrics.reset()
    def on_train_epoch_end(self):
        # Start every epoch with fresh training-metric state.
        self.train_metrics.reset()

In [None]:
# Version 1
class QuickModel(LightningModule):
    """PhoBERT encoder -> masked mean pooling -> hidden linear layer -> linear classifier.

    Args:
        num_labels: number of target classes (size of the logits' last dimension).
        tokenizer_tool: tokenizer object; not used by the model itself, only stored
            in `self.hparams` by `save_hyperparameters()`.
        num_epoch: number of training epochs, used to size the LR schedule.
        train_size: number of training samples, used to size the LR schedule.
        batch_size: training batch size, used to size the LR schedule.
    """
    def __init__(self, num_labels, tokenizer_tool, num_epoch, train_size = df_train.shape[0], batch_size = batch_size):
        super(QuickModel, self).__init__()
        # Stores every constructor argument (including tokenizer_tool) under self.hparams.
        self.save_hyperparameters()
        self.roberta = AutoModel.from_pretrained("vinai/phobert-base-v2", return_dict=True)
        self.hidden = nn.Linear(self.roberta.config.hidden_size, self.roberta.config.hidden_size)
        self.classfication = nn.Linear(self.roberta.config.hidden_size, self.hparams.num_labels)
        nn.init.xavier_uniform_(self.hidden.weight)
        nn.init.xavier_uniform_(self.classfication.weight)
        self.loss_fnc = nn.CrossEntropyLoss()
        self.dropout = nn.Dropout()
    @staticmethod
    def _masked_mean(hidden_states, attention_mask):
        # Average over the sequence dimension using only positions whose mask is 1,
        # so padding does not dilute the sentence representation.
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)  # guard against an all-zero mask
        return summed / counts
    def forward(self, input_ids, attention_mask, labels=None):
        """Return `(loss, logits)`; `loss` is the int 0 when `labels` is None."""
        output = self.roberta(input_ids = input_ids, attention_mask = attention_mask)
        pooled_output = self._masked_mean(output.last_hidden_state, attention_mask)
        pooled_output = self.hidden(pooled_output)
        pooled_output = self.dropout(pooled_output)
        pooled_output = F.relu(pooled_output)
        logits = self.classfication(pooled_output)
        loss = 0
        if labels is not None:
            loss = self.loss_fnc(logits.view(-1, self.hparams.num_labels), labels.view(-1))
        return loss, logits
    def training_step(self, batch, batch_idx):
        input_ids, mask, label = batch
        loss, logits = self(input_ids, mask, label)
        self.log('train_loss', loss, prog_bar=True, on_step=True)
        return {"loss": loss, "predictions": logits, "labels": label}
    def validation_step(self, batch, batch_idx):
        input_ids, mask, label = batch
        loss, logits = self(input_ids, mask, label)
        self.log('val_loss', loss, prog_bar=True)
        return {"loss": loss, "predictions": logits, "labels": label}
    def predict_step(self, batch, batch_idx):
        """Return the raw logits for one batch."""
        input_ids, mask, label = batch
        _, logits = self(input_ids, mask, label)
        return logits
    def configure_optimizers(self):
        """AdamW with a per-step cosine schedule and 10% linear warm-up."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1.5e-6, weight_decay=0.001)
        # The training DataLoader keeps the last partial batch, so an epoch has
        # ceil(train_size / batch_size) optimizer steps; flooring would end the
        # schedule before training does.
        steps_per_epoch = int(-(-self.hparams.train_size // self.hparams.batch_size))
        total_step = steps_per_epoch * self.hparams.num_epoch
        warmup_step = int(total_step * 0.1)
        scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_step, total_step)
        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]

In [None]:
# Checkpoint callback: monitors the logged 'val_loss', keeps a single checkpoint
# (save_top_k=1) and writes it under work_progress/.
# NOTE(review): this one callback object is passed to both Trainer cells below, and
# its filename starts with 'BiLSTM' even when QuickModel is trained — confirm intended.
modelSaver = ModelCheckpoint(monitor='val_loss',dirpath="work_progress/", filename='BiLSTM-Epoch-{epoch}-Val_loss-{val_loss:.4f}', save_top_k=1)

In [None]:
# _, pred = torch.max(output.logits, dim=1)
# print(pred.cpu().numpy())
# print(label.cpu().numpy())
# Pull a single shuffled training batch to drive the model smoke test in the next cell.
train_dataset = SentenceDataSet(df_train_tokenized)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
example = iter(train_loader)
# NOTE: `id` shadows the Python builtin of the same name for the rest of the session.
id, mask, label = next(example)

In [None]:
# Smoke test: one forward pass of the BiLSTM model on the batch fetched above.
model = QuickModelLSTM(
    num_labels=len(all_labels),
    tokenizer_tool=tokenizer,
    hidden_dim=128,
    lstm_layer=2,
    num_epoch=3,
)
loss, logits = model(id, mask, label)
# Predicted class id per sample = index of the largest logit.
preds = logits.argmax(dim=1)
print(preds)

In [None]:
# Train "Version 1" (QuickModel).
epochs = 3
data_module = QuickDataModule()
data_module.setup()
# NOTE(review): fast_dev_run=True is Lightning's debug switch — presumably this runs
# only a single batch and skips checkpointing; confirm and disable for a real run.
trainer = Trainer(max_epochs=epochs, num_sanity_val_steps=50, default_root_dir="my_board/", fast_dev_run=True, callbacks=[modelSaver])
# `epochs` is also passed to the model so its LR schedule spans the same number of epochs.
model = QuickModel(len(all_labels), tokenizer, epochs)
trainer.fit(model, datamodule=data_module)

In [None]:
# Train "Version 2" (QuickModelLSTM: hidden_dim=128, 2 LSTM layers).
# This cell rebinds `trainer`, `model` and `data_module` from the previous cell.
epochs = 3
data_module = QuickDataModule()
data_module.setup()
# NOTE(review): fast_dev_run=True is Lightning's debug switch — presumably this runs
# only a single batch and skips checkpointing; confirm and disable for a real run.
trainer = Trainer(max_epochs=epochs, num_sanity_val_steps=50, default_root_dir="my_board/", fast_dev_run=True, callbacks=[modelSaver])
model = QuickModelLSTM(len(all_labels), tokenizer, 128, 2, epochs)
trainer.fit(model, datamodule=data_module)

In [None]:
def generate_prediction(my_model, dm):
    """Predict a class id for every sample served by `dm`'s predict dataloader.

    Uses the notebook-level `trainer`. Each predict batch is a tensor of logits
    (see `predict_step`); the batches are concatenated and the index of the
    largest logit is taken per sample.

    Args:
        my_model: LightningModule whose `predict_step` returns logits.
        dm: data module passed to `trainer.predict`.

    Returns:
        1-D numpy array of predicted integer class ids, in dataloader order.
    """
    predictions = trainer.predict(my_model, datamodule=dm)
    # The task is single-label multi-class, so the predicted class is the argmax
    # of the logits; a per-class sigmoid does not change which entry is largest.
    logits = torch.cat([torch.as_tensor(batch) for batch in predictions], dim=0)
    return logits.argmax(dim=1).cpu().numpy()
def get_true_label():
    """Return the ground-truth integer class ids of the validation rows.

    Predictions are produced from the validation set (QuickDataModule's
    predict_dataloader serves `val_dataset` unshuffled), so the labels are read
    from `df_val` to line up with them row for row. The "Disease" column already
    holds integer class ids, so no decoding step is needed.
    """
    encoded_labels = df_val["Disease"].to_numpy()
    return encoded_labels

# val_f1 = f1_score(all_val_labels, all_val_preds, average='weighted') # Perform this

In [None]:
# Run prediction with the current `model` on `data_module` and keep the result.
prediction = generate_prediction(model, data_module)

In [None]:
# Ground-truth labels as returned by get_true_label().
true_label = get_true_label()
# One disease name per sample: every entry of true_label is looked up in id_to_labels.
true_name = [id_to_labels[idx] for idx in true_label]

In [None]:
# Weighted F1 of the predicted vs. true class ids. `f1_score` is not in scope (its
# sklearn import is commented out), so the torchmetrics metric that is already
# imported is used instead. Note the argument order: (predictions, targets).
MulticlassF1Score(num_classes=len(all_labels), average='weighted')(
    torch.as_tensor(prediction), torch.as_tensor(true_label)
).item()

In [None]:
# NOTE(review): `classification_report` is not in scope — its sklearn import near the
# top of the notebook is commented out.
# NOTE(review): `true_name` holds one name per sample; `target_names` presumably
# expects one name per class (e.g. `all_labels`) — confirm against sklearn's docs.
classification_report(true_label, prediction, target_names=true_name)

In [None]:
### For software deployment: write the current trainer/model state to a checkpoint file.
trainer.save_checkpoint("work_progress/current_model.ckpt")

## Recovering the tokenizer from a checkpoint: it is stored with the other
## constructor arguments by save_hyperparameters(), i.e. under `hparams`.
# tokenizer = QuickModel.load_from_checkpoint("<path to .ckpt>").hparams.tokenizer_tool

### Failed version

In [None]:
class SentenceModel(nn.Module):
    """Plain-PyTorch variant: PhoBERT encoder -> BiLSTM -> masked mean pooling -> linear head.

    Args:
        num_label: number of target classes.
        hidden_dim: hidden size of each LSTM direction.
        lstm_layer: number of stacked LSTM layers.
    """
    def __init__(self, num_label, hidden_dim, lstm_layer):
        super(SentenceModel, self).__init__()
        self.num_labels = num_label
        # Only the encoder's hidden states are consumed, so the bare AutoModel
        # (already imported by this notebook) is enough; the sequence-classification
        # wrapper was never imported and its extra head was unused.
        self.roberta = AutoModel.from_pretrained("vinai/phobert-base-v2", return_dict=True)
        self.lstm = nn.LSTM(self.roberta.config.hidden_size, hidden_dim, lstm_layer, batch_first=True, bidirectional=True, dropout=0.1)
        # Kept for interface parity; forward() does not apply it.
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_dim * 2, num_label)
    def forward(self, input_ids, attention_mask, label):
        """Return `(loss, logits)` for one batch; `label` holds the integer class ids."""
        encoded = self.roberta(input_ids = input_ids, attention_mask = attention_mask)
        sequence_output = encoded.last_hidden_state
        lstm_output, _ = self.lstm(sequence_output)
        # Mean over real tokens only, so padding positions do not dilute the average.
        # The result already lives on the inputs' device; no explicit move is needed.
        mask = attention_mask.unsqueeze(-1).to(lstm_output.dtype)
        pooled = (lstm_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        logits = self.classifier(pooled)
        loss = F.cross_entropy(logits.view(-1, self.num_labels), label.view(-1))
        return loss, logits

In [None]:
# Datasets and loaders for the manual training loop below.
train_dataset = SentenceDataSet(df_train_tokenized)
val_dataset = SentenceDataSet(df_val_tokenized)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
# Validation does not need shuffling; a fixed order matches
# QuickDataModule.val_dataloader and keeps runs comparable.
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
# Manual mixed-precision training loop for SentenceModel.
# NOTE(review): `Lamb`, `GradScaler`, `autocast`, `tqdm` and `f1_score` are used below,
# but their imports at the top of the notebook are commented out.
n_labels = len(all_labels)
model = SentenceModel(num_label=n_labels, hidden_dim=2048, lstm_layer=2).to(device)
optimizer = Lamb(model.parameters(), lr=1e-4)
scaler = GradScaler()
epochs = 2

# Early-stopping bookkeeping. With patience = 5 and epochs = 2 the counter can reach
# at most 2, so the early-stop branch cannot trigger in this configuration.
patience =5
# NOTE: best_val_loss is assigned here but never read or updated below.
best_val_loss = float('inf')
patience_counter = 0
best_f1 = 0
best_epoch = 0
for epoch in range(epochs):
    # ---- training pass ----
    total_loss = 0
    all_preds = []
    all_true_preds = []
    model.train()
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        id, mask, label = batch
        id, mask, label = id.to(device), mask.to(device), label.to(device)
        # Forward pass under autocast; the model returns (loss, logits).
        with autocast("cuda"):
            model_output = model(id, mask, label)
            loss = model_output[0]
        
        # Scaled backward pass, optimizer step, scaler update, then clear gradients.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        total_loss += loss.item()
        logits = model_output[1]
        print(f"Current loss: {loss}; Total loss: {total_loss}")
        # Predicted class = index of the largest logit.
        preds = torch.argmax(logits, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_true_preds.extend(label.cpu().numpy())
    # Epoch-level training metrics over all batches seen in this epoch.
    f1 = f1_score(all_true_preds, all_preds, average="weighted")
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}, F1 Score: {f1:.4f}")
    # ---- validation pass (no gradients, no autocast) ----
    model.eval()
    total_val_loss = 0
    all_val_preds = []
    all_val_labels = []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Validating Epoch {epoch+1}/{epochs}"):
            id, mask, label = batch
            id, mask, label = id.to(device), mask.to(device), label.to(device)
            model_output = model(id, mask, label)
            loss = model_output[0]

            total_val_loss += loss.item()
            logits = model_output[1]
            preds = torch.argmax(logits, dim=1)
            all_val_preds.extend(preds.cpu().numpy())
            all_val_labels.extend(label.cpu().numpy())
    val_f1 = f1_score(all_val_labels, all_val_preds, average='weighted')
    avg_val_loss = total_val_loss / len(val_loader)
    print(f"Validation - Epoch {epoch+1}/{epochs} - Loss: {avg_val_loss:.4f}, F1 Score: {val_f1:.4f}\n")
    # Early stopping check: driven by validation F1 (not validation loss).
    if val_f1 > best_f1:
        best_f1 = val_f1
        patience_counter = 0
        # NOTE(review): relative path; presumably the 'home/' directory must already
        # exist for torch.save to succeed — confirm.
        torch.save(model.state_dict(), 'home/best_model_weights.pth')
        print(f"Saved best model weights epoch = {epoch+1}")
        best_epoch = epoch+1
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered")
            break