# Libraries

In [None]:
import os
import warnings
import math
import random
import sqlite3
import numpy as np
import pandas as pd
from tqdm import tqdm
import wandb
import matplotlib.pyplot as plt

from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForSequenceClassification,
    BertConfig,
    BertForSequenceClassification,
    DebertaConfig,
    DebertaForSequenceClassification,
    DebertaV2Config,
    DebertaV2ForSequenceClassification,
    get_cosine_schedule_with_warmup,
)

# Single seed shared by every random number generator used in this notebook.
seed = 42
os.environ["PYTHONHASHSEED"] = str(seed)

# Seed Python's `random`, NumPy and PyTorch (CPU plus the CUDA seeding calls).
for seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_fn(seed)

# Silence every warning for the remainder of the session.
warnings.filterwarnings("ignore")

# Hyperparameters

In [None]:
# Root directory used for the input CSV files, the wandb run directory and
# the saved model.
file_path = "/data-new/yangzihao/DDI/"
# Alternative locations:
# file_path= '/kaggle/input/ddidatasets'
# file_path = "./"

# Drug feature names (not referenced by the code visible in this notebook).
feature_list = ["smile", "target", "enzyme"]

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Run configuration; the same dict is also handed to wandb as the run config.
cfg = dict(
    model_name="bert-base-cased",
    # model_name="microsoft/deberta-v3-base",
    epo_num=5,  # number of training epochs
    lr=5e-5,  # optimizer learning rate
    patience=0,  # 0 means no early stopping
    batch_size=16,
    max_len=256,  # tokenized sequence length
    use_amp=False,  # AMP is off: it caused gradient underflow -> NaN
    forzen=False,  # when True the base model is frozen (key spelling is relied on elsewhere)
)

### Early stopping

In [None]:
class EarlyStopping:
    """Flag when a monitored loss has stopped improving.

    Call the instance once per evaluation with the current loss. A call
    counts as "no improvement" when ``-val_loss`` is below
    ``best_score + delta``; after ``patience`` consecutive such calls
    ``early_stop`` becomes True. An improving call resets the counter.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            ``early_stop`` is set. A value of 0 (or less) disables early
            stopping, matching the ``cfg["patience"]`` convention.
        verbose: When True, ``save_checkpoint`` prints a message.
        delta: Margin added to the best score in the improvement comparison.
    """

    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # `np.Inf` was removed in NumPy 2.0; `np.inf` exists in every version.
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model, path=None):
        """Record one loss value and update ``counter`` / ``early_stop``.

        ``model`` and ``path`` are accepted for the (currently disabled)
        checkpointing calls and are not used here.
        """
        # Negate so that "higher is better" for the comparisons below.
        score = -val_loss
        if self.best_score is None:
            # First observation: just remember it.
            self.best_score = score
            # self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            # patience <= 0 means "never stop early".
            if self.patience > 0 and self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            # self.save_checkpoint(val_loss, model, path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path):
        """Save ``model.state_dict()`` to ``<path>/model_checkpoint.pth``.

        Also stores ``val_loss`` as the new ``val_loss_min``.
        """
        if self.verbose:
            print(f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...")
        torch.save(model.state_dict(), path + "/" + "model_checkpoint.pth")
        self.val_loss_min = val_loss

# Data

#### Dataset_big

In [None]:
# Locations of the two CSV tables of the large dataset.
events_csv = f'{file_path}/fusion_data/events.csv'
drugs_csv = f'{file_path}/fusion_data/drugs.csv'
# events_csv / drugs_csv alternatives without the fusion_data sub-directory:
# events = pd.read_csv(f'{file_path}/events.csv', index_col=0)
# df_drug = pd.read_csv(f'{file_path}/drugs.csv', index_col=0)

# Both tables use their first column as the index.
events = pd.read_csv(events_csv, index_col=0)
df_drug = pd.read_csv(drugs_csv, index_col=0)

# Disabled preprocessing, kept for reference (presumably already applied to
# the fusion_data CSVs -- TODO confirm):
# events.mechanism = events.mechanism + " " + events.action
# counts = events.mechanism.value_counts()

# events['label'] = LabelEncoder().fit_transform(events.mechanism)
# events = events.drop(['index'], axis=1)
# df_drug = df_drug.drop(['id', 'index'], axis=1)
# df_drug = df_drug.set_index('name')

# events.shape, df_drug.shape
# Show the first row of each table as a quick sanity check.
for frame in (events, df_drug):
    display(frame.head(1))

#### Dataset_small

In [None]:
# conn = sqlite3.connect(f"{file_path}/fusion_data/event.db")
# df_drug = pd.read_sql("select * from drug;", conn)
# extraction = pd.read_sql("select * from extraction;", conn)
# mechanism = extraction["mechanism"]
# action = extraction["action"]
# drugA = extraction["drugA"]
# drugB = extraction["drugB"]

# Preprocess

Sample every class equally (downsample each class to the size of the smallest one).

In [None]:
# # Count the number of samples in each class
# label_counts = events["label"].value_counts()

# # Find the class with the fewest samples
# min_label = label_counts.idxmin()
# min_count = label_counts.min()

# # Sample every class down to that minimum count
# sample_data = pd.concat(
#     [
#         resample(
#             events[events["label"] == i],
#             n_samples=min_count,
#             replace=False,
#         )
#         for i in events["label"].unique()
#     ]
# )

In [None]:
class DDI_Dataset(Dataset):
    """Dataset that turns one drug-drug interaction event into a text prompt.

    Args:
        ev_df: DataFrame whose rows unpack into ``(drug_a, drug_b, label)``.
        drug_df: DataFrame indexed by drug name; the values of a drug's row
            are joined with commas to describe that drug.
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer) invoked
            with keyword arguments and returning ``input_ids`` and
            ``attention_mask`` tensors with a leading batch dimension of 1.
        max_len: Length every sequence is padded / truncated to.
    """

    def __init__(self, ev_df, drug_df, tokenizer, max_len=256):
        self.events = ev_df
        self.drugs = drug_df
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of interaction events."""
        return self.events.shape[0]

    def _drug_info(self, drug):
        """Return the comma-joined feature values stored for ``drug``."""
        # str() keeps the join from failing on non-string values (numbers, NaN).
        return ",".join(str(value) for value in self.drugs.loc[drug].values)

    def __getitem__(self, index):
        """Return ``{"ids", "masks", "labels"}`` for the event at ``index``."""
        d_a, d_b, labels = self.events.iloc[index]

        # Prompt built from adjacent f-strings: a backslash continuation inside
        # one literal would embed the source indentation in the text.
        text = (
            f"The drug {d_a} interacts with the drug {d_b}. "
            f"The drug {d_a}'s information is: {self._drug_info(d_a)}. "
            f"The drug {d_b}'s information is: {self._drug_info(d_b)}."
        )

        # Calling the tokenizer directly is the supported replacement for the
        # deprecated `encode_plus`.
        encode_dict = self.tokenizer(
            text=text,
            add_special_tokens=True,
            max_length=self.max_len,
            truncation=True,
            padding="max_length",
            return_attention_mask=True,
            return_tensors="pt",
        )
        # Drop the batch dimension added by return_tensors="pt"; the
        # DataLoader re-adds it when collating.
        ids = encode_dict["input_ids"].squeeze(0)
        masks = encode_dict["attention_mask"].squeeze(0)

        return {"ids": ids, "masks": masks, "labels": labels}

# Model

In [None]:
# Checkpoint identifier comes from the shared run configuration.
model_name = cfg["model_name"]

# Tokenizer and classifier are both loaded from that checkpoint; the
# classification head is created with 100 output labels.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=100,
)


# scratch bert
# config = BertConfig()
# model = BertForSequenceClassification(config)

# scratch debertaV3
# config = DebertaV2Config.from_pretrained('microsoft/deberta-v3-base')
# config.num_labels = 100
# model = DebertaV2ForSequenceClassification(config)
# model.init_weights()

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``-alpha * (1 - p_t) ** gamma * log(p_t)``.

    ``p_t`` is the softmax probability assigned to the target class; the
    per-sample losses are averaged over the batch.

    Args:
        num_classes: Expected size of the class dimension of ``logits``.
        alpha: Constant scale applied to every sample's loss.
        gamma: Focusing exponent; larger values down-weight samples that are
            already classified with high confidence.
    """

    def __init__(self, num_classes, alpha=0.25, gamma=2):
        super().__init__()
        self.num_classes = num_classes
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        """Return the mean focal loss.

        Args:
            logits: Tensor of shape ``(batch, num_classes)``.
            targets: Integer tensor of shape ``(batch,)`` with class indices.

        Raises:
            ValueError: If the class dimension of ``logits`` differs from
                ``num_classes``.
        """
        if logits.shape[1] != self.num_classes:
            raise ValueError(
                f"expected logits with {self.num_classes} classes, got {logits.shape[1]}"
            )
        # log_softmax is computed in a numerically stable way. The previous
        # softmax -> log(p + 1e-6) capped the loss and produced zero gradients
        # once the target probability underflowed, and let p_t exceed 1.
        log_probs = F.log_softmax(logits, dim=1)
        # Pick the log-probability of the target class for every sample.
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        focal_loss = -self.alpha * (1 - pt) ** self.gamma * log_pt
        return focal_loss.mean()

# Training

In [None]:
# SECURITY: an API key used to be hard-coded on this line. Credentials must
# not live in source; that key should be treated as leaked and revoked.
# The key is now taken from the WANDB_API_KEY environment variable. When it
# is unset, None is passed and wandb is expected to fall back to its own
# credential lookup (see the wandb.login documentation).
wandb_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=wandb_key)

# Start a run in the "DDI" project, recording cfg as the run configuration
# and writing run files under file_path.
run = wandb.init(
    project="DDI",
    config=cfg,
    dir=f"{file_path}",
    name="bert-pretrained",
)

In [None]:
# Training configuration

# To experiment on a random half of the data instead of all of it:
# experiment_data = events.sample(int(events.shape[0] * 0.5))
experiment_data = events

# 80/20 split with a fixed random state.
total_train, total_valid = train_test_split(
    experiment_data,
    test_size=0.2,
    random_state=42,
)
print(f"All data size: {experiment_data.shape[0]}")
print(f"train size: {total_train.shape[0]}, test size: {total_valid.shape[0]}")

# Datasets first, then their loaders; both loaders share the same settings.
train_dataset = DDI_Dataset(total_train, df_drug, tokenizer, max_len=cfg["max_len"])
valid_dataset = DDI_Dataset(total_valid, df_drug, tokenizer, max_len=cfg["max_len"])
loader_kwargs = dict(batch_size=cfg["batch_size"], shuffle=True)
train_loader = DataLoader(train_dataset, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, **loader_kwargs)

optimizer = optim.AdamW(model.parameters(), lr=cfg["lr"], eps=1e-3)

# The gradient scaler is constructed either way and only active with AMP.
scaler = torch.cuda.amp.GradScaler(enabled=cfg["use_amp"])
if cfg["use_amp"]:
    print("Using AMP!")

# Optionally stop gradients for every parameter of the base model.
if cfg['forzen']:
    print("Freeze the base model!")
    model.base_model.requires_grad_(False)

# Cosine schedule over all optimizer steps, warming up for the first 30%.
num_training_steps = cfg["epo_num"] * len(train_loader)
num_warmup_steps = int(0.3 * num_training_steps)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

# Loss function; plain cross-entropy is the alternative.
criterion = FocalLoss(num_classes=100)
# criterion = nn.CrossEntropyLoss()


def accuracy(preds, labels):
    """Return the fraction of rows of ``preds`` whose argmax equals the label.

    Args:
        preds: Array of per-class scores, one row per sample.
        labels: CPU torch tensor holding the true class indices.
    """
    predicted_classes = np.argmax(preds, axis=1).ravel()
    true_classes = labels.flatten().numpy()
    n_correct = np.sum(predicted_classes == true_classes)
    return n_correct / len(true_classes)

In [None]:
# Start training

# The model only has to be moved to the target device once, not every epoch.
model.to(device)

for i in range(cfg["epo_num"]):
    print("Epoch {}/{}".format(i + 1, cfg["epo_num"]))
    print("-" * 10)

    # Per-batch metrics collected over this epoch.
    train_epoch_loss = []
    train_epoch_acc = []
    valid_epoch_loss = []
    valid_epoch_acc = []
    # early_stopping = EarlyStopping(patience=cfg["patience"], verbose=True)

    model.train()
    for step, batch in tqdm(
        enumerate(train_loader), total=len(train_loader), desc="Train"
    ):
        batch = {k: v.to(device) for k, v in batch.items()}
        ids, masks, labels = batch["ids"], batch["masks"], batch["labels"]

        # Mixed precision is only active when cfg["use_amp"] is True.
        with torch.autocast(
            device_type="cuda", dtype=torch.float16, enabled=cfg["use_amp"]
        ):
            outputs = model(ids, masks)
            logits = outputs.logits
            loss = criterion(logits, labels)

        train_acc = accuracy(logits.detach().cpu().numpy(), labels.cpu())
        train_epoch_acc.append(train_acc)
        train_epoch_loss.append(loss.item())

        # Log only every 200th batch to wandb.
        if step % 200 == 0:
            wandb.log({"train_acc": train_acc,"train_loss": loss.item()})

        scaler.scale(loss).backward()
        # scaler.step() performs the optimizer step itself, so there must be
        # no additional optimizer.step(): the previous code applied every
        # update twice.
        scaler.step(optimizer)
        scaler.update()
        # Advance the learning-rate schedule after the optimizer update.
        lr_scheduler.step()
        optimizer.zero_grad()

        # early_stopping(loss.item(), model)
        # if early_stopping.early_stop:
        #     print("Early stopping")
        #     break

    # Evaluate after every epoch.
    model.eval()
    with torch.no_grad():
        for step, b in tqdm(
            enumerate(valid_loader), total=len(valid_loader), desc="Valid"
        ):
            b = {k: v.to(device) for k, v in b.items()}
            ids, masks, labels = b["ids"], b["masks"], b["labels"]

            outputs = model(ids, masks)
            logits = outputs.logits
            loss = criterion(logits, labels)

            valid_acc = accuracy(logits.detach().cpu().numpy(), labels.cpu())
            valid_epoch_acc.append(valid_acc)
            valid_epoch_loss.append(loss.item())

            if step % 200 == 0:
                wandb.log({"valid_acc": valid_acc, "valid_loss": loss.item()})

    # Epoch summaries are unweighted means of the per-batch values.
    avg_train_loss = sum(train_epoch_loss) / len(train_epoch_loss)
    avg_train_acc = sum(train_epoch_acc) / len(train_epoch_acc)
    avg_valid_loss = sum(valid_epoch_loss) / len(valid_epoch_loss)
    avg_valid_acc = sum(valid_epoch_acc) / len(valid_epoch_acc)

    print("Training Loss: {:.4f}".format(avg_train_loss))
    print("Training Acc: {:.4f}".format(avg_train_acc))
    print("Valid Loss: {:.4f}".format(avg_valid_loss))
    print("Valid Acc: {:.4f}".format(avg_valid_acc))
    print()

print()
print("Training complete!")

In [None]:
# Output directory is derived from the data root and the model name.
model_save_directory = f"{file_path}/saved_models/{cfg['model_name']}-pretrained"

# Store the tokenizer first, then the model weights, in the same directory.
for artifact in (tokenizer, model):
    artifact.save_pretrained(model_save_directory)

# Close the wandb run.
wandb.finish()