# Libraries

In [None]:
import os
import warnings
import random
import sqlite3
import numpy as np
import pandas as pd
from tqdm import tqdm
import wandb

from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    average_precision_score,
)

import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForSequenceClassification,
    BertConfig,
    BertForSequenceClassification,
    get_cosine_schedule_with_warmup,
)

# Fix every random number generator used in the notebook to one seed so runs
# are repeatable, then silence Python warnings for the rest of the session.
seed = 42
os.environ["PYTHONHASHSEED"] = str(seed)
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(seed)
warnings.filterwarnings("ignore")

# Hyperparameters

In [None]:
# Root directory that the data files and the wandb run directory are resolved against.
file_path = "./"

feature_list = ["smile", "target", "enzyme"]

# Prefer the second GPU when the host has one; otherwise fall back to the first
# GPU (a hard-coded "cuda:1" is an invalid device on single-GPU machines), and
# to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    _gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{_gpu_index}")
else:
    device = torch.device("cpu")
print(device)

cfg = {
    "model_name": "bert-base-cased",
    # "model_name": "microsoft/deberta-v3-base",
    "num_classes": 100,
    "epo_num": 2,
    "lr": 5e-5,
    "patience": 0,  # 0 means no early stopping
    "batch_size": 16,
    "max_len": 256,
    "use_amp": False,  # AMP caused gradient underflow -> NaN
    # NOTE: the key is spelled "forzen" (i.e. "frozen"); it is kept as-is
    # because the training-configuration cell reads cfg["forzen"].
    "forzen": False,
}

# Data

#### large dataset

In [None]:
# Load the interaction events and the per-drug feature table; both CSV files
# use their first column as the index.
data_dir = f"{file_path}/data"
# Flat-layout alternative: data_dir = f"{file_path}"
events, df_drug = (
    pd.read_csv(f"{data_dir}/{csv_name}", index_col=0)
    for csv_name in ("events.csv", "drugs.csv")
)

# events.shape, df_drug.shape
# Preview the first row of each frame (display is supplied by IPython/Jupyter).
for frame in (events, df_drug):
    display(frame.head(1))

#### small dataset

In [None]:
# conn = sqlite3.connect(f"{file_path}/event.db")
# df_drug = pd.read_sql("select * from drug;", conn)
# extraction = pd.read_sql("select * from extraction;", conn)
# mechanism = extraction["mechanism"]
# action = extraction["action"]
# drugA = extraction["drugA"]
# drugB = extraction["drugB"]

In [None]:
# extraction["label_text"] = extraction.mechanism + " " + extraction.action

# extraction["label"] = LabelEncoder().fit_transform(extraction["label_text"])
# extraction["label_text"] = extraction["label_text"].apply(str.lower)
# extraction = extraction.drop(["index"], axis=1)
# df_drug = df_drug.drop(["id", "index", "pathway"], axis=1)
# df_drug = df_drug.set_index("name")

# # ## check number of classes
# # extraction['label'].nunique()
# # ## check number of drugs
# # # df_drug.index

# display(df_drug.head(2))
# display(extraction.head(2))

# Preprocess

In [None]:
class DDI_Dataset(Dataset):
    """Dataset that renders each drug-drug interaction event as a text prompt.

    Each item is built from one row of ``ev_df``: the values at column
    positions 2, 3 and -1 are read as the first drug, the second drug and the
    label. Both drug identifiers are looked up in ``drug_df`` (by index label)
    and all of that row's values are listed in the prompt.

    Args:
        ev_df: DataFrame of interaction events.
        drug_df: DataFrame of drug features, indexed by drug identifier.
        tokenizer: Callable tokenizer accepting the Hugging Face ``__call__``
            keyword arguments used in ``__getitem__``.
        max_len: Length every encoded sequence is padded/truncated to.
    """

    def __init__(self, ev_df, drug_df, tokenizer, max_len=256):
        self.events = ev_df
        self.drugs = drug_df
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.events)

    def _describe(self, drug):
        """Return the comma-separated feature values stored for ``drug``."""
        # str() keeps the join from raising TypeError on non-string values
        # (numbers, NaN) in the feature table.
        return ", ".join(str(value) for value in self.drugs.loc[drug].values)

    def __getitem__(self, index):
        """Return ``{"ids", "masks", "labels"}`` for the event at ``index``."""
        d_a, d_b, labels = self.events.iloc[index, [2, 3, -1]]

        ## without prompt
        # d_a_seq = d_a + "," + ','.join(self.drugs.loc[d_a].values)
        # d_b_seq = d_b + "," + ','.join(self.drugs.loc[d_b].values)

        ## use different modal
        # modal = 'smile'
        # d_a_seq = d_a + f", the drug {d_a}'s {modal} information is: " + self.drugs.loc[d_a].smile
        # d_b_seq = d_b + f", the drug {d_b}'s {modal} information is: " + self.drugs.loc[d_b].smile

        # text = f'{d_a_seq + " " + self.tokenizer.sep_token + " " + d_b_seq}'

        ## use prompt
        # Adjacent f-strings are concatenated at compile time, so the sentences
        # are separated by a single space. A backslash continuation inside one
        # string literal would embed the source indentation in the prompt.
        text = (
            f"The drug {d_a} interacts with the drug {d_b}. "
            f"The drug {d_a}'s information is: {self._describe(d_a)}. "
            f"The drug {d_b}'s information is: {self._describe(d_b)}."
        )

        # Calling the tokenizer directly replaces the deprecated encode_plus.
        encode_dict = self.tokenizer(
            text=text,
            add_special_tokens=True,
            max_length=self.max_len,
            truncation=True,
            padding="max_length",
            return_attention_mask=True,
            return_tensors="pt",
        )
        # return_tensors="pt" yields a leading batch dimension of 1; drop it so
        # the DataLoader can stack items into a batch.
        ids = encode_dict["input_ids"].squeeze(0)
        masks = encode_dict["attention_mask"].squeeze(0)

        return {"ids": ids, "masks": masks, "labels": labels}

# Model

In [None]:
# Checkpoint identifier shared by the tokenizer and the classifier.
model_name = cfg["model_name"]

# Load the tokenizer and the sequence-classification model from the same
# checkpoint; the model is requested with cfg["num_classes"] output labels.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=cfg["num_classes"]
)


# train from scratch
# config = BertConfig()
# model = BertForSequenceClassification(config)

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``-alpha * (1 - p_t) ** gamma * log(p_t)``.

    ``p_t`` is the softmax probability assigned to the target class. The loss
    is averaged over the batch.

    Args:
        num_classes: Number of classes. Stored for reference; the class axis
            is taken from ``logits`` at call time.
        alpha: Scalar weight applied to every sample's loss.
        gamma: Focusing exponent; ``gamma=0`` with ``alpha=1`` reduces to
            cross-entropy.
    """

    def __init__(self, num_classes, alpha=0.25, gamma=2):
        super().__init__()
        self.num_classes = num_classes
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        """Return the mean focal loss.

        Args:
            logits: Float tensor of shape ``(batch, classes)``.
            targets: Integer (int64) tensor of shape ``(batch,)`` holding
                class indices.
        """
        # log_softmax gives log(p_t) directly, so a vanishing p_t does not need
        # an epsilon fudge that caps the loss, and p_t = exp(log p_t) can never
        # exceed 1 (which would make (1 - p_t) negative).
        log_probs = F.log_softmax(logits, dim=1)
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        focal_loss = -self.alpha * (1 - pt) ** self.gamma * log_pt
        return focal_loss.mean()

# Training

In [None]:
# wandb_key = "****"
# wandb.login(key=wandb_key)

# Start a Weights & Biases run in the "DDI" project, recording cfg as the run
# configuration and writing local run files under file_path.
run = wandb.init(
    project="DDI",
    config=cfg,
    dir=f"{file_path}",
    name="w/ each",
)

In [None]:
from sklearn.preprocessing import label_binarize


def evaluate_metrics(pred_probs, labels):
    """Compute overall and per-class classification metrics.

    Args:
        pred_probs: Sequence of 2-D arrays of class probabilities
            (``samples x classes``); they are concatenated along axis 0.
        labels: Sequence of 1-D arrays of integer class labels, concatenated
            along axis 0 to align with ``pred_probs``.

    Returns:
        A tuple ``(result_all, result_eve)``. ``result_all`` is a dict of
        overall metrics (accuracy, micro/macro precision, recall and F1,
        macro one-vs-rest AUC, mean per-class AUPR). ``result_eve`` is a
        ``(classes, 6)`` float array whose columns are accuracy, AUPR, AUC,
        F1, precision and recall for each class treated one-vs-rest.

    Raises:
        AssertionError: If the number of predictions and labels differ.
    """
    pred_probs = np.concatenate(pred_probs, axis=0)
    labels = np.concatenate(labels, axis=0)
    print(f"evaluate: {pred_probs.shape}, {labels.shape}")
    assert pred_probs.shape[0] == labels.shape[0]

    # Take the class count from the predictions instead of hard-coding 100.
    n_classes = pred_probs.shape[1]
    class_ids = np.arange(n_classes)

    # Predicted class for each sample.
    predicted_labels = pred_probs.argmax(axis=1)

    # Accuracy.
    accuracy = accuracy_score(labels, predicted_labels)

    # Precision, recall and F1 (micro- and macro-averaged).
    precision_mi = precision_score(labels, predicted_labels, average="micro")
    recall_mi = recall_score(labels, predicted_labels, average="micro")
    f1_mi = f1_score(labels, predicted_labels, average="micro")
    precision_ma = precision_score(labels, predicted_labels, average="macro")
    recall_ma = recall_score(labels, predicted_labels, average="macro")
    f1_ma = f1_score(labels, predicted_labels, average="macro")

    # Macro-averaged one-vs-rest AUC.
    auc_score = roc_auc_score(labels, pred_probs, average="macro", multi_class="ovr")

    # One-hot encodings built by comparison. Unlike label_binarize, this keeps
    # one column per class even when there are only two classes.
    y_one_hot = (labels[:, None] == class_ids).astype(int)
    pred_one_hot = (predicted_labels[:, None] == class_ids).astype(int)

    # AUPR of each class from its probability column, then the plain mean.
    aupr_scores = [
        average_precision_score(y_one_hot[:, class_idx], pred_probs[:, class_idx])
        for class_idx in range(n_classes)
    ]
    average_aupr = sum(aupr_scores) / len(aupr_scores)

    # Per-class one-vs-rest metrics.
    # NOTE(review): AUPR and AUC here are computed from hard 0/1 predictions,
    # not probabilities, as in the original implementation — confirm intended.
    each_eval_type = 6
    result_eve = np.zeros((n_classes, each_eval_type), dtype=float)
    for i in range(n_classes):
        y_true_i = y_one_hot[:, i]
        y_pred_i = pred_one_hot[:, i]
        result_eve[i, 0] = accuracy_score(y_true_i, y_pred_i)
        result_eve[i, 1] = average_precision_score(y_true_i, y_pred_i, average=None)
        result_eve[i, 2] = roc_auc_score(y_true_i, y_pred_i, average=None)
        result_eve[i, 3] = f1_score(y_true_i, y_pred_i, average="binary")
        result_eve[i, 4] = precision_score(y_true_i, y_pred_i, average="binary")
        result_eve[i, 5] = recall_score(y_true_i, y_pred_i, average="binary")

    result_all = {
        "accuracy": accuracy,
        "precision_micro": precision_mi,
        "precision_macro": precision_ma,
        "recall_micro": recall_mi,
        "recall_macro": recall_ma,
        "f1_micro": f1_mi,
        "f1_macro": f1_ma,
        "auc_score": auc_score,
        "aupr_score": average_aupr,
    }
    return result_all, result_eve

In [None]:
# Training configuration

# Optimizer over all model parameters, plus a gradient scaler that is only
# active when mixed precision is switched on in cfg.
# NOTE(review): eps=1e-3 is far larger than the AdamW default — confirm intended.
optimizer = optim.AdamW(params=model.parameters(), lr=cfg["lr"], eps=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=cfg["use_amp"])
if cfg["use_amp"]:
    print("Using AMP!")

# Optionally stop gradient updates for the pretrained backbone.
if cfg["forzen"]:
    print("Freeze the base model!")
    model.base_model.requires_grad_(False)


# Training objective; plain cross-entropy is kept as a commented alternative.
criterion = FocalLoss(num_classes=cfg["num_classes"])
# criterion = nn.CrossEntropyLoss()

In [None]:
def train_fn(model, train_loader, test_loader):
    """Train ``model`` for ``cfg["epo_num"]`` epochs, validating after each.

    Relies on the module-level ``cfg``, ``device``, ``optimizer``, ``scaler``
    and ``criterion``. A cosine learning-rate schedule whose warm-up covers
    30% of all training steps is created for this call. Every 200th batch loss
    is logged to wandb, and the epoch-average losses are printed.

    Args:
        model: Model whose forward output exposes ``.logits``.
        train_loader: Iterable of batches with "ids", "masks" and "labels".
        test_loader: Iterable of validation batches with the same keys.

    Returns:
        ``(total_probs, total_labels)``: softmax probabilities and labels as
        NumPy arrays from the validation pass of the last epoch.
    """
    num_training_steps = cfg["epo_num"] * len(train_loader)
    num_warmup_steps = int(0.3 * num_training_steps)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )

    # Moving the model is idempotent, so do it once instead of every epoch.
    model.to(device)

    for i in range(cfg["epo_num"]):
        print("Epoch {}/{}".format(i + 1, cfg["epo_num"]))
        print("-" * 10)

        train_epoch_loss = []
        valid_epoch_loss = []

        model.train()
        for step, batch in tqdm(
            enumerate(train_loader), total=len(train_loader), desc="Train"
        ):
            batch = {k: v.to(device) for k, v in batch.items()}
            ids, masks, labels = batch["ids"], batch["masks"], batch["labels"]

            with torch.autocast(
                device_type="cuda", dtype=torch.float16, enabled=cfg["use_amp"]
            ):
                outputs = model(ids, masks)
                logits = outputs.logits
                loss = criterion(logits, labels)

            train_epoch_loss.append(loss.item())

            if step % 200 == 0:
                wandb.log({"train_loss": loss.item()})

            scaler.scale(loss).backward()
            # scaler.step() already calls optimizer.step() (after unscaling
            # when AMP is enabled). Calling optimizer.step() again would apply
            # every update twice per batch.
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            optimizer.zero_grad()

        total_probs = []
        total_labels = []
        model.eval()
        with torch.no_grad():
            for step, b in tqdm(
                enumerate(test_loader), total=len(test_loader), desc="Valid"
            ):
                b = {k: v.to(device) for k, v in b.items()}
                ids, masks, labels = b["ids"], b["masks"], b["labels"]

                outputs = model(ids, masks)
                logits = outputs.logits
                loss = criterion(logits, labels)

                valid_epoch_loss.append(loss.item())

                probs = F.softmax(logits, dim=1).cpu().numpy()
                total_probs.append(probs)
                total_labels.append(labels.cpu().numpy())

                if step % 200 == 0:
                    wandb.log({"valid_loss": loss.item()})

        # Rebuilt every epoch, so the values returned come from the last epoch.
        total_probs = np.concatenate(total_probs, axis=0)
        total_labels = np.concatenate(total_labels, axis=0)

        avg_train_loss = sum(train_epoch_loss) / len(train_epoch_loss)
        avg_valid_loss = sum(valid_epoch_loss) / len(valid_epoch_loss)

        print("Training Loss: {:.4f}".format(avg_train_loss))
        print("Valid Loss: {:.4f}".format(avg_valid_loss))
        print()

    return total_probs, total_labels

In [None]:
def cross_val(events, df_drugs, tokenizer, model):
    """Train and evaluate on a stratified 5-fold split of ``events``.

    Folds are stratified on ``events["label"]``. The loop exits after the
    first fold, so exactly one train/test split is used.
    # NOTE(review): the early ``break`` looks deliberate (single hold-out
    # split) — confirm; running more folds would also reuse the same ``model``.

    Returns:
        The value returned by ``evaluate_metrics`` for the collected fold
        predictions and labels.
    """

    def build_loader(row_idx, shuffle):
        # Wrap the selected event rows in a dataset and batch them.
        subset = DDI_Dataset(
            events.iloc[row_idx], df_drugs, tokenizer, max_len=cfg["max_len"]
        )
        return DataLoader(subset, batch_size=cfg["batch_size"], shuffle=shuffle)

    splitter = StratifiedKFold(n_splits=5)
    fold_scores, fold_labels = [], []

    # The splitter only needs the sample count for X; the labels drive the split.
    splits = splitter.split(np.zeros(len(events)), events["label"])
    for fold, (train_index, test_index) in enumerate(splits, start=1):
        print(f"Training fold {fold} start!")

        scores, targets = train_fn(
            model,
            build_loader(train_index, shuffle=True),
            build_loader(test_index, shuffle=False),
        )
        fold_scores.append(scores)
        fold_labels.append(targets)
        break

    cv_results = evaluate_metrics(fold_scores, fold_labels)
    print("results: ", cv_results)

    return cv_results

In [None]:
# Run training and evaluation on the loaded events and drug table with the
# tokenizer and model created above; keep the returned metrics.
cv_results = cross_val(events, df_drug, tokenizer, model)

In [None]:
# model_save_directory = f"{file_path}/saved_models/{cfg['model_name']}-pretrained"
# tokenizer.save_pretrained(model_save_directory)
# model.save_pretrained(model_save_directory)
# wandb.finish()