In [4]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/ddicorpus/DDICorpus/Test/Test for DrugNER task/MedLine/21728158.xml
/kaggle/input/ddicorpus/DDICorpus/Test/Test for DrugNER task/MedLine/21887897.xml
/kaggle/input/ddicorpus/DDICorpus/Test/Test for DrugNER task/MedLine/21752788.xml
/kaggle/input/ddicorpus/DDICorpus/Test/Test for DrugNER task/MedLine/21751542.xml
/kaggle/input/ddicorpus/DDICorpus/Test/Test for DrugNER task/MedLine/21730162.xml
/kaggle/input/ddicorpus/DDICorpus/Test/Test for DrugNER task/MedLine/21888103.xml
/kaggle/input/ddicorpus/DDICorpus/Test/Test for DrugNER task/MedLine/21888227.xml
/kaggle/input/ddicorpus/DDICorpus/Test/Test for DrugNER task/MedLine/21899534.xml
/kaggle/input/ddicorpus/DDICorpus/Test/Test for DrugNER task/MedLine/21750603.xml
/kaggle/input/ddicorpus/DDICorpus/Test/Test for DrugNER task/MedLine/21777098.xml
/kaggle/input/ddicorpus/DDICorpus/Test/Test for DrugNER task/MedLine/22029224.xml
/kaggle/input/ddicorpus/DDICorpus/Test/Test for DrugNER task/MedLine/21874017.xml
/kaggle/input/dd

Import Libraries

In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizer
import xml.etree.ElementTree as ET
import os
from sklearn.metrics import classification_report, confusion_matrix
import glob
from sklearn.model_selection import train_test_split

STEP 1: PARSE DRUGBANK XML

In [None]:
def parse_drugbank(drugbank_xml_path):
    """ Parses DrugBank XML to extract drug descriptions. """
    tree = ET.parse(drugbank_xml_path)
    root = tree.getroot()
    ns = {"db": "http://www.drugbank.ca"}  # Namespace for XML parsing
    drug_descriptions = {}
    
    print("Parsing DrugBank XML...")
    for drug in root.findall("db:drug", ns):
        drug_id_elem = drug.find("db:drugbank-id", ns)
        drug_name_elem = drug.find("db:name", ns)
        description_elem = drug.find("db:description", ns)
        if drug_id_elem is not None and drug_name_elem is not None:
            drug_id = drug_id_elem.text.strip()
            drug_name = drug_name_elem.text.strip().lower()
            description = description_elem.text.strip() if description_elem is not None and description_elem.text else "No description available."
            drug_descriptions[drug_name] = description
    
    print(f"Processed {len(drug_descriptions)} drugs from DrugBank")
    return drug_descriptions

STEP 2: PARSE DDI DATASET

In [None]:
def parse_ddi_files(ddi_folder_path, drug_descriptions):
    all_data = []
    skipped_entries = 0
    processed_files = 0

    xml_files = glob.glob(os.path.join(ddi_folder_path, '**/*.xml'), recursive=True)
    print(f"Found {len(xml_files)} XML files to process")

    for xml_file in xml_files:
        try:
            tree = ET.parse(xml_file)
            root = tree.getroot()
            processed_files += 1

            for sentence in root.findall("sentence"):
                text = sentence.get("text")
                if not text:
                    skipped_entries += 1
                    continue

                entities = {e.get("id"): e.get("text").lower() for e in sentence.findall("entity")}

                for pair in sentence.findall("pair"):
                    drug1 = entities.get(pair.get("e1"), "").lower()
                    drug2 = entities.get(pair.get("e2"), "").lower()
                    ddi_type = pair.get("ddi")

                    if drug1 and drug2 and ddi_type is not None:
                        interaction_type = "true" if ddi_type == "true" else "false"  # Combined interaction types

                        desc1 = drug_descriptions.get(drug1)  # Get description, None if not found
                        desc2 = drug_descriptions.get(drug2)
                        
                        # Only include if both drugs have descriptions
                        if desc1 and desc2:  
                            all_data.append({
                                "text": text,
                                "desc1": desc1,
                                "desc2": desc2,
                                "label": interaction_type,
                                "file": os.path.basename(xml_file)
                            })
                        else:
                            skipped_entries += 1 # Count as skipped if drug description is missing

                    else:
                        skipped_entries += 1

        except ET.ParseError as e:
            print(f"Error parsing file {xml_file}: {e}")
            continue

        if processed_files % 10 == 0:
            print(f"Processed {processed_files}/{len(xml_files)} files...")

    print(f"Processed {len(all_data)} valid entries, skipped {skipped_entries} invalid entries")
    return all_data




STEP 3: DATASET CLASS 

In [None]:

class DDIDataset(Dataset):
    def __init__(self, data, label_map):
        self.data = data
        self.label_map = label_map
        
        if len(self.data) == 0:
            raise ValueError("No valid data entries found after filtering!")
        
        print(f"Created dataset with {len(self.data)} valid entries")
        # Print label distribution
        label_dist = {}
        for d in self.data:
            label_dist[d['label']] = label_dist.get(d['label'], 0) + 1
        print("Label distribution:")
        for label, count in label_dist.items():
            print(f"{label}: {count}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        try:
            return {
                'text': sample['text'],
                'desc1': sample['desc1'],
                'desc2': sample['desc2'],
                'label': torch.tensor(self.label_map[sample['label']], dtype=torch.long)
            }
        except KeyError as e:
            print(f"Error processing sample {idx}: {sample}")
            raise e


STEP 4: MODEL CLASS 

In [None]:
class DDIExtractionModel(nn.Module):
    def __init__(self, bert_model_name='dmis-lab/biobert-base-cased-v1.1', hidden_dim=256, num_classes=2): # num_classes is 2 now
        super(DDIExtractionModel, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.hidden_dim = hidden_dim
        self.cnn = nn.Conv1d(in_channels=768, out_channels=hidden_dim, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveMaxPool1d(1)
        self.classifier = nn.Linear(hidden_dim * 3, num_classes)
        
    def encode_text(self, text):
        encoded_input = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=128,
            return_tensors='pt'
        )
        input_ids = encoded_input['input_ids'].to(self.bert.device)
        attention_mask = encoded_input['attention_mask'].to(self.bert.device)
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        hidden_states = hidden_states.permute(0, 2, 1)
        features = self.cnn(hidden_states)
        pooled_features = self.pool(features).squeeze(-1)
        return pooled_features
        
    def forward(self, text, desc1, desc2):
        h_sent = self.encode_text(text)
        h_desc1 = self.encode_text(desc1)
        h_desc2 = self.encode_text(desc2)
        combined = torch.cat([h_sent, h_desc1, h_desc2], dim=1)
        scores = self.classifier(combined)
        return scores


STEP 5: EVALUATION FUNCTION

In [None]:
def evaluate(model, data_loader, criterion, device, label_map):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    
    # Create reverse label map for converting indices back to labels
    rev_label_map = {v: k for k, v in label_map.items()}
    
    with torch.no_grad():
        for batch in data_loader:
            text = batch['text']
            desc1 = batch['desc1']
            desc2 = batch['desc2']
            labels = batch['label'].to(device)
            
            scores = model(text, desc1, desc2)
            loss = criterion(scores, labels)
            total_loss += loss.item()
            
            preds = torch.argmax(scores, dim=1)
            all_preds.extend([rev_label_map[p.item()] for p in preds])
            all_labels.extend([rev_label_map[l.item()] for l in labels])
    
    # Calculate metrics
    report = classification_report(all_labels, all_preds)
    conf_matrix = confusion_matrix(all_labels, all_preds)
    avg_loss = total_loss / len(data_loader)
    
    return avg_loss, report, conf_matrix, all_preds, all_labels


STEP 6: TRAINING FUNCTION 

In [None]:
def train(model, train_loader, val_loader, epochs=5, lr=2e-5, device='cuda'):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    best_val_loss = float('inf')
    best_model_state = None
    
    for epoch in range(epochs):
        # Training phase
        model.train()
        total_train_loss = 0
        
        for i, batch in enumerate(train_loader):
            text = batch['text']
            desc1 = batch['desc1']
            desc2 = batch['desc2']
            labels = batch['label'].to(device)
            
            optimizer.zero_grad()
            scores = model(text, desc1, desc2)
            loss = criterion(scores, labels)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
            
            if i % 10 == 0:
                print(f"Epoch {epoch+1}, Batch {i}, Loss: {loss.item():.4f}")
        
        avg_train_loss = total_train_loss / len(train_loader)
        
        # Validation phase
        if val_loader:
            val_loss, val_report, _, _, _ = evaluate(model, val_loader, criterion, device, label_map)
            print(f"\nEpoch {epoch+1}")
            print(f"Training Loss: {avg_train_loss:.4f}")
            print(f"Validation Loss: {val_loss:.4f}")
            print("Validation Metrics:")
            print(val_report)
            
            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_state = model.state_dict().copy()
                print("New best model saved!")
    
    # Restore best model if we have validation data
    if val_loader and best_model_state:
        model.load_state_dict(best_model_state)
    
    return model


STEP 7: MAIN EXECUTION

In [None]:
if __name__ == "__main__":
    # Set paths
    drugbank_xml_path = "/kaggle/input/drugbank-dataset/full database.xml"
    ddi_train_folder = "/kaggle/input/ddicorpus/DDICorpus/Train/DrugBank"
    ddi_test_folder = "/kaggle/input/ddicorpus/DDICorpus/Test/Test for DDI Extraction task/DrugBank"
    
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Define label mapping
    label_map = {
        "false": 0,  # "false" for no interaction
        "true": 1,   # "true" for any interaction
    }
    
    try:
        # Parse DrugBank data
        print("Parsing DrugBank data...")
        drug_descriptions = parse_drugbank(drugbank_xml_path)
        
        # Parse DDI datasets
        print("\nParsing DDI training data...")
        train_data = parse_ddi_files(ddi_train_folder, drug_descriptions)
        print("\nParsing DDI test data...")
        test_data = parse_ddi_files(ddi_test_folder, drug_descriptions)
        
        # Validate data
        if len(train_data) == 0 or len(test_data) == 0:
            raise ValueError("No valid data found in either train or test sets!")

        print(f"\nTotal training examples: {len(train_data)}")
        print(f"Total test examples: {len(test_data)}")

        # Verify we have examples of each class
        unique_labels = set(d['label'] for d in train_data + test_data)
        print(f"Unique labels found in dataset: {unique_labels}")
        if not all(label in label_map for label in unique_labels):
            print("Warning: Found labels not in label_map:", unique_labels - set(label_map.keys()))
        
        # Split training data into train and validation sets
        train_data, val_data = train_test_split(train_data, test_size=0.1, random_state=42)
        
        # Create datasets
        print("\nCreating datasets...")
        train_dataset = DDIDataset(train_data, label_map)
        val_dataset = DDIDataset(val_data, label_map)
        test_dataset = DDIDataset(test_data, label_map)
        
        # Create dataloaders
        train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=16)
        test_loader = DataLoader(test_dataset, batch_size=16)
        
        # Initialize model
        print("\nInitializing model...")
        model = DDIExtractionModel()
        
        # Train model
        print("\nStarting training...")
        model = train(model, train_loader, val_loader, epochs=5, device=device)
        
        # Evaluate on test set
        print("\nEvaluating on test set...")
        criterion = nn.CrossEntropyLoss()
        test_loss, test_report, test_conf_matrix, test_preds, test_labels = evaluate(
            model, test_loader, criterion, device, label_map
        )
        
        print("\nTest Results:")
        print(f"Test Loss: {test_loss:.4f}")
        print("\nClassification Report:")
        print(test_report)
        print("\nConfusion Matrix:")
        print(test_conf_matrix)
        
        # Save model
        torch.save({
            'model_state_dict': model.state_dict(),
            'label_map': label_map,
        }, 'ddi_model.pth')
        print("\nModel saved to ddi_model.pth")
        
    except Exception as e:
        print(f"\nAn error occurred: {str(e)}")
        raise e

Using device: cuda
Parsing DrugBank data...
Parsing DrugBank XML...
Processed 17430 drugs from DrugBank

Parsing DDI training data...
Found 572 XML files to process
Processed 10/572 files...
Processed 20/572 files...
Processed 30/572 files...
Processed 40/572 files...
Processed 50/572 files...
Processed 60/572 files...
Processed 70/572 files...
Processed 80/572 files...
Processed 90/572 files...
Processed 100/572 files...
Processed 110/572 files...
Processed 120/572 files...
Processed 130/572 files...
Processed 140/572 files...
Processed 150/572 files...
Processed 160/572 files...
Processed 170/572 files...
Processed 180/572 files...
Processed 190/572 files...
Processed 200/572 files...
Processed 210/572 files...
Processed 220/572 files...
Processed 230/572 files...
Processed 240/572 files...
Processed 250/572 files...
Processed 260/572 files...
Processed 270/572 files...
Processed 280/572 files...
Processed 290/572 files...
Processed 300/572 files...
Processed 310/572 files...
Process