In [43]:
import os
import gc
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, AutoModelForTokenClassification,
    AutoConfig
)

# --- Hardware -------------------------------------------------------------------
# Restrict the process to GPU 0; set before the first CUDA query below so it applies.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print("CUDA Available:", torch.cuda.is_available())

# --- Paths and run identifiers --------------------------------------------------
VERSION = 26                                # suffix of the fine-tuned checkpoint file name
MODEL_NAME = 'google/bigbird-roberta-base'  # pretrained backbone identifier
DATA_DIR = './data'
MODEL_DIR = f'{DATA_DIR}/bigbird'           # local copy of tokenizer/config/weights

# --- Training hyper-parameters --------------------------------------------------
CONFIG = dict(
    model_name=MODEL_NAME,
    max_length=1024,
    train_batch_size=4,
    valid_batch_size=4,
    epochs=5,
    learning_rates=[2.5e-5, 2.5e-5, 2.5e-6, 2.5e-6, 2.5e-7],  # one learning rate per epoch
    max_grad_norm=10,
    device=DEVICE,
)

# True when the test folder holds at most 5 essays (presumably the public sample only).
# NOTE(review): not read by any cell shown here.
COMPUTE_VAL_SCORE = len(os.listdir(f'{DATA_DIR}/test')) <= 5


CUDA Available: True


In [44]:
# Lower-case alias of CONFIG used by the inference cells below. Copying CONFIG instead of
# repeating the literal keeps both in sync. It also inherits DEVICE's CPU fallback rather
# than hard-coding 'cuda', which breaks on machines without a GPU.
config = dict(CONFIG)

In [45]:
# Download the pretrained backbone once and keep tokenizer, config and weights in MODEL_DIR.
# The guard looks for the weights file, which is saved last, rather than for the directory.
# That way a run interrupted mid-download is redone instead of leaving a half-filled cache.
# NOTE(review): a sharded checkpoint would be named differently (*.index.json) — not
# expected for this base-size model.
if not any(os.path.exists(os.path.join(MODEL_DIR, weights_name))
           for weights_name in ('model.safetensors', 'pytorch_model.bin')):
    os.makedirs(MODEL_DIR, exist_ok=True)

    # Load and save tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, add_prefix_space=True)
    tokenizer.save_pretrained(MODEL_DIR)

    # Load and save config; 15 = number of BIO tags in `output_labels`
    config_model = AutoConfig.from_pretrained(MODEL_NAME)
    config_model.num_labels = 15
    config_model.save_pretrained(MODEL_DIR)

    # Load and save model (classification head sized by num_labels)
    model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME, config=config_model)
    model.save_pretrained(MODEL_DIR)


In [46]:
# Span annotations; later cells read its `id`, `discourse_type` and `predictionstring` columns.
train_df = pd.read_csv(filepath_or_buffer=f'{DATA_DIR}/train.csv')

def read_texts(folder, limit=100):
    """Load up to ``limit`` essay files from ``folder`` into a DataFrame.

    Files are read in sorted filename order. That makes the subset chosen by ``limit``
    (and the row order) identical on every machine; ``os.listdir`` order is
    filesystem-dependent.

    Args:
        folder: directory holding the essay files (presumably one ``.txt`` per essay).
        limit: maximum number of files to read; ``None`` reads all of them.

    Returns:
        DataFrame with columns ``id`` (filename with ``.txt`` removed) and ``text``.
    """
    filenames = sorted(os.listdir(folder))[:limit]
    ids, texts = [], []
    for name in filenames:
        # `with` closes each handle immediately instead of leaving it to the garbage collector.
        with open(os.path.join(folder, name)) as fh:
            texts.append(fh.read())
        ids.append(name.replace('.txt', ''))
    return pd.DataFrame({'id': ids, 'text': texts})

# Both splits are capped at 100 essays (read_texts' default cap, spelled out here),
# presumably to keep iteration fast.
# NOTE(review): test_text_df is not used by the cells shown here.
train_text_df = read_texts(folder=f'{DATA_DIR}/train', limit=100)
test_text_df = read_texts(folder=f'{DATA_DIR}/test', limit=100)


In [47]:
from ast import literal_eval

NER_PATH = f'{MODEL_DIR}/train_NER.csv'

if os.path.exists(NER_PATH):
    # Reuse cached word-level tags. The CSV stores each tag list as its Python repr,
    # so literal_eval turns it back into a list.
    # NOTE(review): the cache is not invalidated when the set of loaded essays changes
    # (e.g. read_texts' limit) — delete the file to rebuild it.
    train_text_df = pd.read_csv(NER_PATH)
    train_text_df['entities'] = train_text_df['entities'].apply(literal_eval)
else:
    # One BIO tag per whitespace-separated word. The first word of each annotated span gets
    # B-<type>, its remaining words I-<type>, and every other word 'O'.
    # Annotations are grouped by essay once (O(N + M)) instead of re-filtering train_df
    # for every essay.
    spans_by_essay = {essay_id: group for essay_id, group in train_df.groupby('id')}
    all_entities = []
    for _, row in tqdm(train_text_df.iterrows(), total=len(train_text_df)):
        labels = ['O'] * len(row['text'].split())
        essay_spans = spans_by_essay.get(row['id'])
        if essay_spans is not None:
            for _, ann in essay_spans.iterrows():
                idxs = list(map(int, ann['predictionstring'].split()))
                labels[idxs[0]] = f'B-{ann["discourse_type"]}'
                for idx in idxs[1:]:
                    labels[idx] = f'I-{ann["discourse_type"]}'
        all_entities.append(labels)
    train_text_df['entities'] = all_entities
    train_text_df.to_csv(NER_PATH, index=False)


In [48]:
# BIO tag vocabulary: 'O' plus a B-/I- pair for each of the seven discourse types.
# A tag's list position is the class id the model predicts (15 ids, matching num_labels).
output_labels = ['O', 'B-Lead', 'I-Lead', 'B-Position', 'I-Position', 'B-Claim', 'I-Claim',
                 'B-Counterclaim', 'I-Counterclaim', 'B-Rebuttal', 'I-Rebuttal', 'B-Evidence',
                 'I-Evidence', 'B-Concluding Statement', 'I-Concluding Statement']

# Decodes argmax class ids back into tag strings. (This cell previously defined both
# names twice, identically.)
ids_to_labels = {i: label for i, label in enumerate(output_labels)}


In [49]:
# Tag set used by NERDataset: 'O' followed by a B-/I- pair for each discourse type.
LABELS = ['O'] + [
    f'{prefix}-{kind}'
    for kind in ('Lead', 'Position', 'Claim', 'Counterclaim',
                 'Rebuttal', 'Evidence', 'Concluding Statement')
    for prefix in ('B', 'I')
]

id2label = dict(enumerate(LABELS))
label2id = {tag: idx for idx, tag in id2label.items()}
# True: every sub-token of a word carries the word's tag.
# False: only the first sub-token does; the rest get the -100 "ignore" id.
LABEL_ALL_SUBTOKENS = True

class NERDataset(Dataset):
    """Tokenize whitespace-split essays for token classification.

    Args:
        df: DataFrame with a ``text`` column. It also needs an ``entities`` column (one tag
            per word) unless ``return_wids`` is set. Rows are looked up by index label.
        tokenizer: tokenizer callable with ``is_split_into_words``, returning an encoding
            that exposes ``word_ids()``.
        max_len: length every example is padded/truncated to.
        return_wids: if True, skip target creation and add ``wids`` instead. ``wids`` holds
            the word index of each token, with -1 for special/padding tokens.
    """

    def __init__(self, df, tokenizer, max_len, return_wids=False):
        self.data, self.tokenizer = df, tokenizer
        self.max_len, self.return_wids = max_len, return_wids

    def __len__(self):
        return len(self.data)

    def _align_labels(self, word_ids, word_tags):
        # Each token takes its word's tag id. Special/padding tokens get -100. Continuation
        # sub-tokens also get -100 unless LABEL_ALL_SUBTOKENS is set.
        aligned = []
        prev = None
        for wid in word_ids:
            if wid is None:
                aligned.append(-100)
            elif wid != prev or LABEL_ALL_SUBTOKENS:
                aligned.append(label2id[word_tags[wid]])
            else:
                aligned.append(-100)
            prev = wid
        return aligned

    def __getitem__(self, idx):
        words = self.data.loc[idx, 'text'].split()
        encoding = self.tokenizer(words, is_split_into_words=True, padding='max_length',
                                  truncation=True, max_length=self.max_len)
        word_ids = encoding.word_ids()

        if self.return_wids:
            item = {name: torch.tensor(values) for name, values in encoding.items()}
            item['wids'] = torch.tensor([-1 if w is None else w for w in word_ids])
            return item

        encoding['labels'] = self._align_labels(word_ids, self.data.loc[idx, 'entities'])
        return {name: torch.tensor(values) for name, values in encoding.items()}


In [50]:
# Essay-level 90/10 train/validation split, seeded so it is reproducible.
np.random.seed(42)
all_ids = train_df['id'].unique()
train_ids = np.random.choice(all_ids, size=int(0.9 * len(all_ids)), replace=False)
valid_ids = np.setdiff1d(all_ids, train_ids)

# Only essays whose text was actually loaded survive (read_texts may cap the count).
train_data = train_text_df.loc[train_text_df['id'].isin(train_ids)].reset_index(drop=True)
valid_data = train_text_df.loc[train_text_df['id'].isin(valid_ids)].reset_index(drop=True)

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
# Training examples carry label targets; validation examples carry word ids for decoding.
train_dataset = NERDataset(train_data, tokenizer, CONFIG['max_length'], return_wids=False)
valid_dataset = NERDataset(valid_data, tokenizer, CONFIG['max_length'], return_wids=True)

train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['train_batch_size'],
    shuffle=True,
    num_workers=0,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=CONFIG['valid_batch_size'],
    shuffle=False,
    num_workers=0,
)


In [51]:
def train_one_epoch(model, optimizer, epoch, loader=None):
    """Run one optimisation pass and print the mean per-batch loss and token accuracy.

    Args:
        model: token-classification model, called as ``model(input_ids=..., attention_mask=...,
            labels=..., return_dict=False)`` and returning ``(loss, logits)``. It must expose
            ``num_labels``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: zero-based epoch index; used only in the progress message.
        loader: iterable of batches with ``input_ids``, ``attention_mask`` and ``labels``.
            Defaults to the notebook-level ``train_loader`` so existing calls keep working.

    Returns:
        ``(mean_loss, mean_accuracy)`` over batches, or ``(nan, nan)`` when the loader yields
        no batches (previously a ZeroDivisionError).
    """
    if loader is None:
        loader = train_loader
    model.train()
    total_loss, total_acc, steps = 0.0, 0.0, 0

    for batch in loader:
        ids = batch['input_ids'].to(DEVICE)
        mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        optimizer.zero_grad()
        loss, logits = model(input_ids=ids, attention_mask=mask, labels=labels, return_dict=False)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['max_grad_norm'])
        optimizer.step()

        total_loss += loss.item()
        steps += 1

        # Token accuracy over labelled positions only. -100 marks special/padding tokens
        # (and, optionally, continuation sub-tokens).
        flat_labels = labels.view(-1)
        active = flat_labels != -100
        preds = torch.argmax(logits.view(-1, model.num_labels), dim=1)
        total_acc += accuracy_score(flat_labels[active].cpu(), preds[active].cpu())

    if steps == 0:
        print(f"Epoch {epoch+1} | no training batches")
        return float('nan'), float('nan')

    mean_loss, mean_acc = total_loss / steps, total_acc / steps
    print(f"Epoch {epoch+1} | Loss: {mean_loss:.4f} | Accuracy: {mean_acc:.4f}")
    return mean_loss, mean_acc


In [52]:
# Rebuild the classifier from the local cache. Passing the directory (not a weights file)
# lets transformers pick whichever format save_pretrained wrote: model.safetensors in recent
# versions, pytorch_model.bin in older ones.
config_model = AutoConfig.from_pretrained(MODEL_DIR)
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR, config=config_model)
model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['learning_rates'][0])

model_path = f'{MODEL_DIR}/bigbird_v{VERSION}.pt'
if not os.path.exists(model_path):
    lr_schedule = CONFIG['learning_rates']
    for epoch in range(CONFIG['epochs']):
        # Step schedule: one rate per epoch; reuse the last rate if epochs outnumber the rates.
        lr = lr_schedule[min(epoch, len(lr_schedule) - 1)]
        for g in optimizer.param_groups:
            g['lr'] = lr
        train_one_epoch(model, optimizer, epoch)
        # Release cached GPU blocks and Python garbage between epochs.
        torch.cuda.empty_cache()
        gc.collect()
    torch.save(model.state_dict(), model_path)
else:
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    print("Model loaded from checkpoint.")


Model loaded from checkpoint.


In [53]:
from torch.utils.data import Dataset
import torch

class dataset(Dataset):
    """Tokenize whitespace-split essays for token classification (inference-cell variant).

    Behaves like ``NERDataset`` but keeps this cell's constructor signature (``get_wids``).

    Args:
        dataframe: DataFrame with a ``text`` column, plus ``entities`` (one tag per word)
            when ``get_wids`` is False. Rows are looked up by index label.
        tokenizer: tokenizer callable with ``is_split_into_words``, returning an encoding
            that exposes ``word_ids()``.
        max_len: length every example is padded/truncated to.
        get_wids: if True, return ``wids`` instead of ``labels``. ``wids`` holds the word
            index per token, with -1 for special/padding tokens.
    """

    def __init__(self, dataframe, tokenizer, max_len, get_wids):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.get_wids = get_wids

    def __getitem__(self, index):
        # Text and, when building targets, the per-word tags.
        text = self.data.text[index]
        word_labels = self.data.entities[index] if not self.get_wids else None

        encoding = self.tokenizer(text.split(),
                                  is_split_into_words=True,
                                  padding='max_length',
                                  truncation=True,
                                  max_length=self.max_len)
        word_ids = encoding.word_ids()

        if not self.get_wids:
            # Each token takes its word's tag id. Special/padding tokens get -100, as do
            # continuation sub-tokens unless LABEL_ALL_SUBTOKENS is set.
            # Uses `label2id`: the former `labels_to_ids` is not defined in any cell shown
            # here, so this branch raised NameError.
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                if word_idx is None:
                    label_ids.append(-100)
                elif word_idx != previous_word_idx or LABEL_ALL_SUBTOKENS:
                    label_ids.append(label2id[word_labels[word_idx]])
                else:
                    label_ids.append(-100)
                previous_word_idx = word_idx
            encoding['labels'] = label_ids

        item = {key: torch.as_tensor(val) for key, val in encoding.items()}
        if self.get_wids:
            item['wids'] = torch.as_tensor([w if w is not None else -1 for w in word_ids])
        return item

    def __len__(self):
        return self.len


In [54]:
# Essays to predict on. Reuses read_texts (same id/text frame as the training loader)
# instead of an ad-hoc loop that left file handles open and hard-coded './data'.
test_texts = read_texts(f'{DATA_DIR}/test', limit=1000)

test_texts_set = dataset(test_texts, tokenizer, CONFIG['max_length'], get_wids=True)

test_params = {
    'batch_size': CONFIG['valid_batch_size'],
    'shuffle': False,  # get_predictions pairs loader output with test_texts rows by order
    'num_workers': 0,
    # Pinned host memory only helps GPU transfers; on CPU it just triggers a warning.
    'pin_memory': torch.cuda.is_available(),
}

test_texts_loader = DataLoader(test_texts_set, **test_params)


In [55]:
def inference(batch):
    """Predict one tag per word for every essay in ``batch``.

    Args:
        batch: dict from a ``get_wids=True`` dataset, with ``input_ids``, ``attention_mask``
            and ``wids`` (word index per token, -1 for special/padding tokens).

    Returns:
        One list per essay. Each list holds one tag string per word that survived
        truncation, taken from the word's first sub-token.
    """
    # Send inputs to wherever the model lives, instead of config['device'] (hard-coded 'cuda').
    device = next(model.parameters()).device
    ids = batch['input_ids'].to(device)
    mask = batch['attention_mask'].to(device)
    # Prediction needs no autograd graph; building one would keep every activation alive.
    with torch.no_grad():
        logits = model(ids, attention_mask=mask, return_dict=False)[0]
    all_preds = torch.argmax(logits, dim=-1).cpu().numpy()

    predictions = []
    for token_classes, word_ids in zip(all_preds, batch['wids'].numpy()):
        prediction = []
        previous_word_idx = -1
        for class_id, word_idx in zip(token_classes, word_ids):
            # Skip special/padding tokens and every sub-token after a word's first.
            if word_idx != -1 and word_idx != previous_word_idx:
                prediction.append(ids_to_labels[class_id])
                previous_word_idx = word_idx
        predictions.append(prediction)

    return predictions

def _tag_runs(tags):
    """Yield ``(class, start, end)`` runs (``end`` exclusive) from per-word BIO tags.

    A run opens at any non-'O' tag and continues over the following ``I-<class>`` tags.
    A ``B-`` tag therefore always starts a new run.
    """
    j = 0
    while j < len(tags):
        tag = tags[j]
        if tag == 'O':
            j += 1
            continue
        # Rewrite only the prefix; the old `replace('B', 'I')` changed every capital B.
        run_tag = 'I-' + tag[2:] if tag.startswith('B-') else tag
        end = j + 1
        while end < len(tags) and tags[end] == run_tag:
            end += 1
        if run_tag:
            yield (run_tag[2:] if run_tag.startswith('I-') else run_tag), j, end
        j = end


def get_predictions(df, loader, min_span_words=8):
    """Run the model over ``loader`` and convert per-word tags into span predictions.

    Args:
        df: DataFrame whose ``id`` column is in the same order as the essays yielded by
            ``loader``.
        loader: non-shuffled DataLoader over a ``get_wids=True`` dataset built from ``df``.
        min_span_words: shortest span kept, in words. The default 8 matches the former
            hard-coded ``end - j > 7``.

    Returns:
        DataFrame with columns ``id``, ``class`` and ``predictionstring`` (space-separated
        word indices). The columns are present even when nothing is predicted; previously
        the empty case raised ValueError on the column assignment.
    """
    model.eval()
    word_tags = []
    with torch.no_grad():
        for batch in loader:
            word_tags.extend(inference(batch))

    rows = []
    for i, essay_id in enumerate(df.id.values):
        for cls, start, end in _tag_runs(word_tags[i]):
            if end - start >= min_span_words:
                rows.append((essay_id, cls, ' '.join(map(str, range(start, end)))))

    return pd.DataFrame(rows, columns=['id', 'class', 'predictionstring'])


In [56]:
# Span-level predictions for the test essays (columns: id, class, predictionstring).
sub = get_predictions(df=test_texts, loader=test_texts_loader)


  torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device)


In [57]:
from IPython.display import display, HTML

# Define colors for each class (you can customize)
# Background colour per predicted discourse class; 'O' (outside any span) stays white.
HIGHLIGHT_COLORS = dict([
    ('Lead', '#FFD700'),                  # gold
    ('Position', '#7FFFD4'),              # aquamarine
    ('Claim', '#FF6347'),                 # tomato
    ('Counterclaim', '#FFB6C1'),          # lightpink
    ('Rebuttal', '#87CEFA'),              # lightskyblue
    ('Evidence', '#98FB98'),              # palegreen
    ('Concluding Statement', '#FFA07A'),  # lightsalmon
    ('O', '#FFFFFF'),                     # white = no highlight
])

def highlight_text(text, predictions, colors=None):
    """Render ``text`` as HTML with predicted spans highlighted.

    Args:
        text: full essay; words are its whitespace-separated tokens.
        predictions: iterable of ``(class, start_idx, end_idx)`` tuples with *inclusive*
            word indices.
        colors: optional class -> CSS colour mapping; defaults to ``HIGHLIGHT_COLORS``.
            Classes missing from it are shown in yellow.

    Returns:
        HTML string in which every plain word and every span is followed by one space.

    Words are HTML-escaped, so characters such as ``<`` or ``&`` in an essay display
    literally instead of being parsed as markup. A span overlapping an earlier one is
    clipped to start after it. Spans lying entirely past the end of the text are ignored.
    As a result, no word is emitted twice.
    """
    import html

    if colors is None:
        colors = HIGHLIGHT_COLORS
    words = [html.escape(w) for w in text.split()]
    pieces = []
    pos = 0
    for cls, start, end in sorted(predictions, key=lambda p: p[1]):
        start = max(start, pos)
        end = min(end, len(words) - 1)
        if start > end:
            continue  # already covered by an earlier span, or beyond the text
        pieces.extend(words[pos:start])  # unhighlighted words before this span
        color = colors.get(cls, "#FFFF00")  # default yellow
        span_text = " ".join(words[start:end + 1])
        pieces.append(f'<span style="background-color:{color}; padding:2px; border-radius:3px;">{span_text}</span>')
        pos = end + 1
    pieces.extend(words[pos:])  # remaining words after the last span
    # Joining once avoids quadratic string concatenation on long essays.
    return "".join(piece + " " for piece in pieces)

# Usage example:
# Demo of the renderer on the first validation essay, using hand-written (not model) spans.
sample_row = valid_data.loc[0]
sample_id, sample_text = sample_row['id'], sample_row['text']
sample_predictions = [
    ("Claim", 5, 6),     # words 5-6 shown as Claim
    ("Evidence", 8, 9),  # words 8-9 shown as Evidence
]

display(HTML(highlight_text(sample_text, sample_predictions)))


In [58]:
# Render the span predictions stored in `sub` for one essay.
# `highlight_predictions` is not defined in any cell shown here (NameError on a fresh
# kernel). Use `highlight_text` instead; it expects (class, start, end) with inclusive word
# indices, so each space-separated predictionstring is converted first.
# NOTE(review): `sub` is built from the test folder, so a validation id may have no rows.
sample_row = valid_data.loc[0]
sample_id = sample_row['id']
sample_text = sample_row['text']

sample_preds = sub[sub['id'] == sample_id][['class', 'predictionstring']].values.tolist()
stored_spans = []
for cls, pred_str in sample_preds:
    word_idxs = [int(i) for i in pred_str.split()]
    stored_spans.append((cls, min(word_idxs), max(word_idxs)))

display(HTML(f"<h4>Predicted Entities for ID: {sample_id}</h4>"
             + highlight_text(sample_text, stored_spans)))


In [59]:

# Re-run the model on a single validation essay so its per-word tags can be inspected.
sample_row = valid_data.loc[0]
sample_id, sample_text = sample_row['id'], sample_row['text']

# Wrap the essay in a one-row frame so it can reuse the inference Dataset and DataLoader.
sample_df = pd.DataFrame({'text': [sample_text], 'entities': [None]})
sample_dataset = dataset(sample_df, tokenizer, config['max_length'], get_wids=True)
sample_loader = DataLoader(sample_dataset, batch_size=1, shuffle=False)

model.eval()
with torch.no_grad():
    for batch in sample_loader:  # a one-row frame yields exactly one batch
        preds = inference(batch)

def preds_to_spans(preds):
    """Group per-word BIO tags into ``(class, start, end)`` spans with inclusive indices.

    A span opens at any non-'O' tag (B- or I-) and absorbs the following ``I-<class>``
    tags. Any other tag closes it, so a new ``B-`` tag always starts a new span.
    """
    spans = []
    current = None  # [class, start, end] of the span being extended, if any
    for pos, tag in enumerate(preds):
        if current is not None and tag == f"I-{current[0]}":
            current[2] = pos
            continue
        if current is not None:
            spans.append(tuple(current))
            current = None
        if tag != 'O':
            current = [tag.replace('B-', '').replace('I-', ''), pos, pos]
    if current is not None:
        spans.append(tuple(current))
    return spans

# Turn the essay's word tags into (class, start, end) spans and render them over the text.
sample_spans = preds_to_spans(preds[0])
sample_html = highlight_text(sample_text, sample_spans)
display(HTML(sample_html))


  torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device)
