In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file found under /kaggle/input (the attached input datasets).
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
from torch.utils.data import Dataset, DataLoader
import pdb
import torch
from torch import cuda
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import LabelEncoder
from tqdm.auto import tqdm
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import precision_recall_fscore_support
import datasets
from functools import partial
from ast import literal_eval
from datetime import datetime
import gc


import warnings
warnings.filterwarnings('ignore')

In [None]:
config = {'model_name': '/kaggle/input/deberta-v3-base/deberta-v3-base/',
         'max_length': 512,
         'train_batch_size':4,
         'valid_batch_size':8,
         'epochs':3,
         'learning_rate':2e-05,
         'max_grad_norm':10,
          'warmup':0.1,
          "grad_acc":8,
          "model_save_path":"deberta-trained",
          "folds":5,
          "seed":42,
          'num_proc' : 2,
          'dropout':0.2,
          "no_unlabelled":1000,
         'device': 'cuda' if cuda.is_available() else 'cpu'}

In [None]:
df_features = pd.read_csv("/kaggle/input/nbme-score-clinical-patient-notes/features.csv")
df_patients = pd.read_csv("/kaggle/input/nbme-score-clinical-patient-notes/patient_notes.csv")
df_train = pd.read_csv("/kaggle/input/nbme-score-clinical-patient-notes/train.csv")

In [None]:
df_patients.head()

In [None]:
df_patients['pn_num'].nunique(), df_patients['case_num'].nunique()

In [None]:
df_train.head()

In [None]:
df_patients.head()

In [None]:
# The following is necessary if you want to use the fast tokenizer for deberta v2 or v3
import shutil
from pathlib import Path

# Location of the installed transformers package.
# NOTE(review): hard-codes the Python 3.7 site-packages of the Kaggle image — confirm if the image changes.
transformers_path = Path("/opt/conda/lib/python3.7/site-packages/transformers")

# Input dataset holding replacement tokenizer source files.
input_dir = Path("../input/deberta-v2-3-fast-tokenizer")

# Replace transformers' convert_slow_tokenizer.py with the copy from input_dir.
convert_file = input_dir / "convert_slow_tokenizer.py"
conversion_path = transformers_path / convert_file.name

if conversion_path.exists():
    conversion_path.unlink()

shutil.copy(convert_file, transformers_path)
deberta_v2_path = transformers_path / "models" / "deberta_v2"

# Replace the deberta_v2 tokenization modules. "deberta__init__.py" is installed as
# "__init__.py": the "deberta" prefix only exists to keep the file name distinct in input_dir.
for filename in [
    "tokenization_deberta_v2.py",
    "tokenization_deberta_v2_fast.py",
    "deberta__init__.py",
]:
    if str(filename).startswith("deberta"):
        filepath = deberta_v2_path / str(filename).replace("deberta", "")
    else:
        filepath = deberta_v2_path / filename
    if filepath.exists():
        filepath.unlink()

    shutil.copy(input_dir / filename, filepath)

In [None]:
from transformers.models.deberta_v2.tokenization_deberta_v2_fast import DebertaV2TokenizerFast
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer, AutoModelForTokenClassification


tokenizer = DebertaV2TokenizerFast.from_pretrained(config['model_name'])

In [None]:
df_train.head()

In [None]:
df_patients.head()

In [None]:
df_features.head()

In [None]:
def pre_process_data(df_train, patients=None, features=None):
    """Parse the stringified annotation/location columns and join note and feature text.

    Args:
        df_train: frame with ``annotation`` and ``location`` columns holding
            Python-literal list strings such as ``"['dad with recent heart attack']"``.
        patients: patient-notes frame left-joined on the columns it shares with
            ``df_train``; defaults to the notebook-level ``df_patients``.
        features: features frame left-joined on the columns it shares with the
            result; defaults to the notebook-level ``df_features``.

    Returns:
        A new frame with parsed ``anno_list``/``loc_list`` list columns plus the
        joined columns. ``df_train`` itself is left unmodified.
    """
    if patients is None:
        patients = df_patients
    if features is None:
        features = df_features
    # Work on a copy so the caller's frame is not mutated (keeps the cell re-runnable).
    df_train = df_train.copy()
    has_rows = len(df_train) > 0
    if has_rows:
        print(f"before converting annotations of type :{type(df_train['annotation'].iloc[0])}, {df_train['annotation'].iloc[0]}, location of type: {type(df_train['location'].iloc[0])}, {df_train['location'].iloc[0]}")
    # literal_eval parses the list literals without executing arbitrary code.
    df_train['anno_list'] = [literal_eval(x) for x in df_train['annotation']]
    df_train['loc_list'] = [literal_eval(x) for x in df_train['location']]
    if has_rows:
        # Report the parsed columns (not the raw strings) so the conversion is visible.
        print(f"after converting annotations of type :{type(df_train['anno_list'].iloc[0])}, {df_train['anno_list'].iloc[0]}, location of type: {type(df_train['loc_list'].iloc[0])}, {df_train['loc_list'].iloc[0]}")
    print(f"column names of df_train : {df_train.columns}")
    # how='left' keeps every training row; join keys are the columns the frames share.
    merged = df_train.merge(patients, how='left')
    print(f"column names of df_train after merging with patients: {merged.columns}")
    merged = merged.merge(features, how='left')
    print(f"column names of df_train after merging with features: {merged.columns}")
    return merged

In [None]:
merged = pre_process_data(df_train)
merged.shape

In [None]:
merged.head()

In [None]:
# incorrect annotations
merged.loc[338, "anno_list"] =  '["father heart attack"]'
merged.loc[338, "loc_list"] =  '["764 783"]'

merged.loc[621, "anno_list"] =  '["for the last 2-3 months", "over the last 2 months"]'
merged.loc[621, "loc_list"] =  '["77 100", "398 420"]'

merged.loc[655, "anno_list"] =  '["no heat intolerance", "no cold intolerance"]'
merged.loc[655, "loc_list"] =  '["285 292;301 312", "285 287;296 312"]'

merged.loc[1262, "anno_list"] =  '["mother thyroid problem"]'
merged.loc[1262, "loc_list"] =  '["551 557;565 580"]'

merged.loc[1265, "anno_list"] =  '[\'felt like he was going to "pass out"\']'
merged.loc[1265, "loc_list"] =  '["131 135;181 212"]'

merged.loc[1396, "anno_list"] =  '["stool , with no blood"]'
merged.loc[1396, "loc_list"] =  '["259 280"]'

merged.loc[1591, "anno_list"] =  '["diarrhoe non blooody"]'
merged.loc[1591, "loc_list"] =  '["176 184;201 212"]'

merged.loc[1615, "anno_list"] =  '["diarrhea for last 2-3 days"]'
merged.loc[1615, "loc_list"] =  '["249 257;271 288"]'

merged.loc[1664, "anno_list"] =  '["no vaginal discharge"]'
merged.loc[1664, "loc_list"] =  '["822 824;907 924"]'

merged.loc[1714, "anno_list"] =  '["started about 8-10 hours ago"]'
merged.loc[1714, "loc_list"] =  '["101 129"]'

merged.loc[1929, "anno_list"] =  '["no blood in the stool"]'
merged.loc[1929, "loc_list"] =  '["531 539;549 561"]'

merged.loc[2134, "anno_list"] =  '["last sexually active 9 months ago"]'
merged.loc[2134, "loc_list"] =  '["540 560;581 593"]'

merged.loc[2191, "anno_list"] =  '["right lower quadrant pain"]'
merged.loc[2191, "loc_list"] =  '["32 57"]'

merged.loc[2553, "anno_list"] =  '["diarrhoea no blood"]'
merged.loc[2553, "loc_list"] =  '["308 317;376 384"]'

merged.loc[3124, "anno_list"] =  '["sweating"]'
merged.loc[3124, "loc_list"] =  '["549 557"]'

merged.loc[3858, "anno_list"] =  '["previously as regular", "previously eveyr 28-29 days", "previously lasting 5 days", "previously regular flow"]'
merged.loc[3858, "loc_list"] =  '["102 123", "102 112;125 141", "102 112;143 157", "102 112;159 171"]'

merged.loc[4373, "anno_list"] =  '["for 2 months"]'
merged.loc[4373, "loc_list"] =  '["33 45"]'

merged.loc[4763, "anno_list"] =  '["35 year old"]'
merged.loc[4763, "loc_list"] =  '["5 16"]'

merged.loc[4782, "anno_list"] =  '["darker brown stools"]'
merged.loc[4782, "loc_list"] =  '["175 194"]'

merged.loc[4908, "anno_list"] =  '["uncle with peptic ulcer"]'
merged.loc[4908, "loc_list"] =  '["700 723"]'

merged.loc[6016, "anno_list"] =  '["difficulty falling asleep"]'
merged.loc[6016, "loc_list"] =  '["225 250"]'

merged.loc[6192, "anno_list"] =  '["helps to take care of aging mother and in-laws"]'
merged.loc[6192, "loc_list"] =  '["197 218;236 260"]'

merged.loc[6380, "anno_list"] =  '["No hair changes", "No skin changes", "No GI changes", "No palpitations", "No excessive sweating"]'
merged.loc[6380, "loc_list"] =  '["480 482;507 519", "480 482;499 503;512 519", "480 482;521 531", "480 482;533 545", "480 482;564 582"]'

merged.loc[6562, "anno_list"] =  '["stressed due to taking care of her mother", "stressed due to taking care of husbands parents"]'
merged.loc[6562, "loc_list"] =  '["290 320;327 337", "290 320;342 358"]'

merged.loc[6862, "anno_list"] =  '["stressor taking care of many sick family members"]'
merged.loc[6862, "loc_list"] =  '["288 296;324 363"]'

merged.loc[7022, "anno_list"] =  '["heart started racing and felt numbness for the 1st time in her finger tips"]'
merged.loc[7022, "loc_list"] =  '["108 182"]'

merged.loc[7422, "anno_list"] =  '["first started 5 yrs"]'
merged.loc[7422, "loc_list"] =  '["102 121"]'

merged.loc[8876, "anno_list"] =  '["No shortness of breath"]'
merged.loc[8876, "loc_list"] =  '["481 483;533 552"]'

merged.loc[9027, "anno_list"] =  '["recent URI", "nasal stuffines, rhinorrhea, for 3-4 days"]'
merged.loc[9027, "loc_list"] =  '["92 102", "123 164"]'

merged.loc[9938, "anno_list"] =  '["irregularity with her cycles", "heavier bleeding", "changes her pad every couple hours"]'
merged.loc[9938, "loc_list"] =  '["89 117", "122 138", "368 402"]'

merged.loc[9973, "anno_list"] =  '["gaining 10-15 lbs"]'
merged.loc[9973, "loc_list"] =  '["344 361"]'

merged.loc[10513, "anno_list"] =  '["weight gain", "gain of 10-16lbs"]'
merged.loc[10513, "loc_list"] =  '["600 611", "607 623"]'

merged.loc[11551, "anno_list"] =  '["seeing her son knows are not real"]'
merged.loc[11551, "loc_list"] =  '["386 400;443 461"]'

merged.loc[11677, "anno_list"] =  '["saw him once in the kitchen after he died"]'
merged.loc[11677, "loc_list"] =  '["160 201"]'

merged.loc[12124, "anno_list"] =  '["tried Ambien but it didnt work"]'
merged.loc[12124, "loc_list"] =  '["325 337;349 366"]'

merged.loc[12279, "anno_list"] =  '["heard what she described as a party later than evening these things did not actually happen"]'
merged.loc[12279, "loc_list"] =  '["405 459;488 524"]'

merged.loc[12289, "anno_list"] =  '["experienced seeing her son at the kitchen table these things did not actually happen"]'
merged.loc[12289, "loc_list"] =  '["353 400;488 524"]'

merged.loc[13238, "anno_list"] =  '["SCRACHY THROAT", "RUNNY NOSE"]'
merged.loc[13238, "loc_list"] =  '["293 307", "321 331"]'

merged.loc[13297, "anno_list"] =  '["without improvement when taking tylenol", "without improvement when taking ibuprofen"]'
merged.loc[13297, "loc_list"] =  '["182 221", "182 213;225 234"]'

merged.loc[13299, "anno_list"] =  '["yesterday", "yesterday"]'
merged.loc[13299, "loc_list"] =  '["79 88", "409 418"]'

merged.loc[13845, "anno_list"] =  '["headache global", "headache throughout her head"]'
merged.loc[13845, "loc_list"] =  '["86 94;230 236", "86 94;237 256"]'

merged.loc[14083, "anno_list"] =  '["headache generalized in her head"]'
merged.loc[14083, "loc_list"] =  '["56 64;156 179"]'

# The hand corrections above were written as Python-literal strings; parse them back
# into lists so every row of anno_list/loc_list holds a list. Rows that were not
# corrected already hold lists (parsed in pre_process_data) and pass through unchanged.
merged["anno_list"] = [
    literal_eval(x) if isinstance(x, str) else x for x in merged["anno_list"]
]
merged["loc_list"] = [
    literal_eval(x) if isinstance(x, str) else x for x in merged["loc_list"]
]

In [None]:
merged = merged[~merged['pn_history'].isnull()]

In [None]:
def clean_data(merged):
    """Drop rows without annotations and normalise the text columns.

    Args:
        merged: frame with ``anno_list`` (parsed list of annotation strings),
            ``feature_text`` and ``pn_history`` columns.

    Returns:
        A new frame containing only rows whose ``anno_list`` is non-empty, with the
        old index kept in an ``index`` column. ``feature_text`` has "-OR-" rewritten
        to ";-", every remaining "-" replaced by a space, and is lower-cased.
        ``pn_history`` is lower-cased.
    """
    # Filter on the parsed list rather than the raw ``annotation`` string: rows whose
    # annotations were corrected by hand after parsing can still read "[]" in the raw
    # column and would otherwise be dropped together with their corrections.
    has_annotation = merged['anno_list'].map(len) > 0
    print(f"before cleaning: count of empty annotations :{merged.loc[~has_annotation].shape}")
    merged = merged.loc[has_annotation].copy().reset_index(drop=False)
    print(f"after cleaning: count of empty annotations :{merged.loc[merged['anno_list'].map(len) == 0].shape}")
    print(f"before cleaning: count of '-OR-' in feature text: {merged[merged['feature_text'].str.contains('-OR-')].shape}")
    merged['feature_text'] = merged['feature_text'].apply(lambda x: x.replace("-OR-", ';-').replace("-", " ").lower())
    print(f"after cleaning: count of '-OR-' in feature text: {merged[merged['feature_text'].str.contains('-OR-')].shape}")
    # The sample print below peeks at row 1, so only show it when that row exists.
    if len(merged) > 1:
        print(f"before cleaning: lower pn_history {merged['pn_history'].values[1]}")
    merged['pn_history'] = merged['pn_history'].apply(lambda x: x.lower())
    if len(merged) > 1:
        print(f"after cleaning: lower pn_history {merged['pn_history'].values[1]}")
    return merged

In [None]:
merged = clean_data(merged)

In [None]:
# skf = StratifiedKFold(n_splits=config['folds'], random_state=config['seed'], shuffle=True)

# merged["fold"] = -1

# for fold, (_, val_idx) in enumerate(skf.split(merged, y=merged["case_num"])):
#     merged.loc[val_idx, "fold"] = fold
    
# counts = merged.groupby(["fold", "pn_num"], as_index=False).count()

# # If the number of rows is the same as the number of 
# # unique pn_num, then each pn_num is only in one fold.
# # Also if all the counts=1
# print(counts.shape, counts.pn_num.nunique(), counts.case_num.unique(), merged['pn_num'].nunique())
# merged['fold'].value_counts()

In [None]:
merged.head()

In [None]:
labelled_patients = merged['pn_num'].unique()
labelled_patients[:10], merged['pn_num'].nunique()

In [None]:
ul_patients = df_patients[~df_patients['pn_num'].isin(labelled_patients)][:config['no_unlabelled']]
ul_patients.shape

In [None]:
df_features.head()

In [None]:
# Expand each unlabelled note into one record per feature of its case:
# [pn_num, pn_history, case_num, feature_text].
# The same text normalisation clean_data applies to the labelled rows is applied here
# (lower-cased history; "-OR-" -> ";-", "-" -> " " and lower-cased feature text) so
# pseudo-labelling and scoring see inputs in the form the model was trained on.
# str.lower() leaves character offsets unchanged for ASCII text.
features_by_case = {
    case: list(group['feature_text'])
    for case, group in df_features.groupby('case_num')
}
ul_final = []
for _, patient in ul_patients.iterrows():
    history = patient['pn_history'].lower()
    for feat in features_by_case.get(patient['case_num'], []):
        ul_final.append([
            patient['pn_num'],
            history,
            patient['case_num'],
            feat.replace("-OR-", ';-').replace("-", " ").lower(),
        ])

In [None]:
ul_final_df = pd.DataFrame.from_records(ul_final, columns=['pn_num', 'pn_history', 'case_num', "feature_text"])

In [None]:
ul_final_df.tail(), ul_final_df.columns

In [None]:
print(df_patients['pn_num'].nunique())
ul_final_df['pn_num'].nunique()

In [None]:
first = merged.loc[35]

example = {"feature_text": first.feature_text,
          "pn_history": first.pn_history,
          "loc_list": first.loc_list,
          "annotation_list": first.anno_list}

for key in example.keys():
    print(key)
    print(example[key])
    print('='*10)

In [None]:
def loc_list_to_tuples(loc_list):
    """Flatten location strings into a list of ``(start, end)`` integer spans.

    Each entry of ``loc_list`` looks like ``"85 99;126 138"``: spans are separated
    by ``;`` and each span is a whitespace-separated ``start end`` pair. A piece
    that does not split into exactly two fields raises ``ValueError``.
    """
    def _to_span(piece):
        lo, hi = piece.split()
        return int(lo), int(hi)

    return [_to_span(piece) for entry in loc_list for piece in entry.split(";")]

print(example['loc_list'])
example_loc_ints = loc_list_to_tuples(example['loc_list'])
print(example_loc_ints)
for loc in example_loc_ints:
    print(example['pn_history'][loc[0] : loc[1]])

In [None]:
def tokenize_and_label(example):
    """Tokenise a (feature_text, pn_history) pair and build per-token span labels.

    Args:
        example: mapping with ``feature_text``, ``pn_history`` and ``loc_list``
            (list of ``"start end[;start end]"`` strings indexing into ``pn_history``).

    Returns:
        The tokenizer output extended with:
        - ``location``: list of (start, end) character spans;
        - ``sequence_ids``: per-token sequence ids;
        - ``labels``: 1.0 for history tokens overlapping a span and 0.0 for other
          history tokens. Feature-text, special and padding tokens get -100.0 so the
          training loop can mask them out. This masking is applied even when the
          example has no spans.
    """
    tokenized_inputs = tokenizer(example['feature_text'],
                                 example['pn_history'],
                                 truncation='only_second',
                                 max_length=config['max_length'],
                                 padding='max_length',
                                 return_offsets_mapping=True)
    tokenized_inputs['location'] = loc_list_to_tuples(example['loc_list'])
    tokenized_inputs['sequence_ids'] = tokenized_inputs.sequence_ids()

    labels = []
    for seq_id, (token_start, token_end) in zip(tokenized_inputs["sequence_ids"],
                                                 tokenized_inputs["offset_mapping"]):
        # sequence id 1 is the patient history (second tokenizer input); everything
        # else (feature text = 0, special/padding = None) is excluded from the loss
        if seq_id != 1:
            labels.append(-100.0)
            continue
        overlaps = any(
            token_start <= label_start < token_end
            or token_start < label_end <= token_end
            or label_start <= token_start < label_end
            for label_start, label_end in tokenized_inputs["location"]
        )
        labels.append(1.0 if overlaps else 0.0)  # labels should be float for BCE

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [None]:
tokenized_inputs = tokenize_and_label(example)
tokenized_inputs.keys()

In [None]:
merged = merged[["pn_history", "feature_text", "loc_list"]]
merged.head()

In [None]:
def convert_to_dataset(merged, func_):
    """Build a Hugging Face dataset from ``merged`` and map ``func_`` over its rows.

    The mapped dataset is formatted to return torch tensors for the model inputs,
    offsets and sequence ids, plus ``labels`` when ``func_`` produced them; every
    other column is hidden.
    """
    raw = datasets.Dataset.from_pandas(merged)
    print(f"keys before applying tokenization: {raw[0].keys()}")
    mapped = raw.map(func_, num_proc=config['num_proc'])

    torch_columns = ['input_ids', 'attention_mask', 'token_type_ids']
    if 'labels' in mapped.features:
        torch_columns.append('labels')
    torch_columns.extend(['offset_mapping', 'sequence_ids'])
    mapped.set_format(type='torch', columns=torch_columns, output_all_columns=False)

    print(f"keys after applying tokenization: {mapped[0].keys()}")
    return mapped

In [None]:
dataset_mapped = convert_to_dataset(merged[:], tokenize_and_label)

In [None]:
# Hold out 20% of the labelled rows for validation. Seeding the split with the
# configured seed makes the validation set, and therefore the reported scores,
# reproducible across runs.
dataset_mapped = dataset_mapped.train_test_split(test_size=0.2, seed=config['seed'])

In [None]:
dataset_mapped

In [None]:
def collate_fn(examples):
    """Collate a list of dataset rows into a batch of PyTorch tensors via ``tokenizer.pad``.

    Rows are already padded to ``max_length`` at tokenisation time
    (``padding='max_length'``), so this presumably just stacks them into tensors.
    """
    batch = tokenizer.pad(examples, return_tensors='pt')
    return batch

In [None]:
from torch.utils.data import DataLoader

l_dataloader_train = DataLoader(dataset_mapped['train'], batch_size=config['train_batch_size'], shuffle=True, collate_fn=collate_fn)
# Evaluation does not benefit from shuffling; a fixed order keeps evaluation runs comparable.
l_dataloader_test = DataLoader(dataset_mapped['test'], batch_size=config['train_batch_size'], shuffle=False, collate_fn=collate_fn)

In [None]:
l_dataloader_train, l_dataloader_test

In [None]:
def predict_tokenize(example, tokenizer=tokenizer):
    """Tokenise a (feature_text, pn_history) pair for inference (no labels).

    Uses the same settings as training-time tokenisation and also stores the
    per-token ``sequence_ids`` so spans can be mapped back onto ``pn_history``.
    """
    encoded = tokenizer(
        example['feature_text'],
        example['pn_history'],
        truncation='only_second',
        max_length=config['max_length'],
        padding='max_length',
        return_offsets_mapping=True,
    )
    encoded['sequence_ids'] = encoded.sequence_ids()
    return encoded

In [None]:
ul_mapped = convert_to_dataset(ul_final_df, predict_tokenize)

In [None]:
ul_mapped[0]['input_ids'].shape, ul_mapped[0]['token_type_ids'].shape

In [None]:
print(ul_mapped)
ul_dataloader = DataLoader(ul_mapped, batch_size=config['train_batch_size'], shuffle=True, collate_fn=collate_fn)

In [None]:
device = config['device']

In [None]:
import pdb

In [None]:
from transformers import AutoModel

class ClassificationModel(torch.nn.Module):
    """Pretrained encoder with a per-token scoring head.

    Args:
        model: name or path passed to ``AutoModel.from_pretrained``.
        num_labels: outputs per token. With the default of 1, the trailing
            dimension is squeezed away by ``forward``.
    """

    def __init__(self, model, num_labels=1):
        super().__init__()
        # attribute name kept as-is so saved state_dict keys stay compatible
        self.dberta = AutoModel.from_pretrained(model)
        self.dropout = torch.nn.Dropout(p=config['dropout'])
        # size the head from the encoder config instead of hard-coding 768 (the
        # base model's width) so larger checkpoints also work
        hidden_size = self.dberta.config.hidden_size
        self.linear1 = torch.nn.Linear(hidden_size, 512)
        # NOTE(review): no non-linearity between linear1 and classifier, so the two
        # layers compose to a single linear map — confirm this is intended.
        self.classifier = torch.nn.Linear(512, num_labels)

    def forward(self, x):
        """Return per-token logits for a batch mapping with input_ids/attention_mask/token_type_ids."""
        output = self.dberta(input_ids=x["input_ids"], attention_mask=x["attention_mask"], token_type_ids=x["token_type_ids"])
        # output[0]: the encoder's per-token hidden states (last_hidden_state)
        out = self.linear1(self.dropout(output[0]))
        logits = self.classifier(out)
        return logits.squeeze(-1)

In [None]:
model = ClassificationModel(config['model_name'], num_labels=1).to(device)
loss_fct = torch.nn.BCEWithLogitsLoss(reduction="none")

In [None]:

# del model
# import gc
# gc.collect()

In [None]:
from torch import optim
from torch import nn
optimizer = optim.AdamW(model.parameters(), lr=config['learning_rate'])

In [None]:
def train_model(model, data_loader, optimizer, criterion, max_grad_norm=1.0):
    """Run one supervised pass over ``data_loader``.

    Args:
        model: module taking the batch dict and returning per-token logits.
        data_loader: yields dicts of tensors including ``labels``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: element-wise loss (``reduction="none"``). Positions whose label
            is -100 are excluded before averaging.
        max_grad_norm: gradient-norm clipping threshold (default 1.0, as before).

    Returns:
        Sample-weighted mean of the per-batch masked losses (NaN for an empty loader).
    """
    model.train()
    total_loss = 0.0
    total_samples = 0
    for batch in tqdm(data_loader):
        batch = {key: val.to(device) for key, val in batch.items()}
        labels = batch['labels']
        # clear gradients from the previous step (or left by the caller) so each
        # optimizer step uses only this batch's gradients
        optimizer.zero_grad()
        logits = model(batch)
        loss = criterion(logits, labels)
        # labels of -100 mark feature-text/special/padding tokens; drop them
        loss = torch.masked_select(loss, labels > -1.0).mean()
        loss.backward()
        # clipping guards against exploding gradients
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
        total_samples += labels.size(0)

    return total_loss / total_samples if total_samples else float('nan')

def eval_model(model, dataloader, criterion):
    """Evaluate ``model`` on ``dataloader``.

    Returns:
        ``(mean_loss, score)``:
        - ``mean_loss`` is the sample-weighted mean of the per-batch masked losses
          (positions with label -100 excluded).
        - ``score`` is the dict returned by ``calculate_char_cv`` for the predicted
          spans.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    preds, offsets, seq_ids, valid_labels, input_ids = [], [], [], [], []
    # evaluation needs no gradients; skipping autograd avoids keeping activations alive
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch['input_ids'] = batch['input_ids'].to(device)
            batch['attention_mask'] = batch["attention_mask"].to(device)
            batch['token_type_ids'] = batch['token_type_ids'].to(device)
            labels = batch['labels'].to(device)

            logits = model(batch)
            loss = criterion(logits, labels)
            loss = torch.masked_select(loss, labels > -1.0).mean()
            total_loss += loss.item() * labels.size(0)
            total_samples += labels.size(0)

            input_ids.append(batch['input_ids'].cpu().numpy())
            preds.append(logits.cpu().numpy())
            offsets.append(batch['offset_mapping'].numpy())
            seq_ids.append(batch['sequence_ids'].numpy())
            valid_labels.append(labels.cpu().numpy())

    preds = np.concatenate(preds, axis=0)
    offsets = np.concatenate(offsets, axis=0)
    seq_ids = np.concatenate(seq_ids, axis=0)
    valid_labels = np.concatenate(valid_labels, axis=0)
    input_ids = np.concatenate(input_ids, axis=0)

    location_preds = get_location_predictions(preds, offsets, seq_ids, input_ids, test=False)
    score = calculate_char_cv(location_preds, offsets, seq_ids, valid_labels)

    return total_loss / total_samples, score

In [None]:
from sklearn.metrics import accuracy_score
from itertools import chain
import re

def get_location_predictions(preds, offset_mapping, sequence_ids, input_ids, test=False, tok=None):
    """Turn per-token logits into character spans over the patient history.

    Consecutive history tokens whose sigmoid probability exceeds 0.5 are merged into
    one span. The span runs from the first token's start offset to the last token's
    end offset. A span still open after the last history token is also emitted.

    Args:
        preds: per-example sequences of logits.
        offset_mapping: per-example sequences of (start, end) character offsets.
        sequence_ids: per-example sequence ids. Only tokens with id 1 (the history,
            i.e. the second tokenizer input) are used. Ids of 0, ``None`` or NaN
            (special tokens may come back as NaN after tensor conversion) are skipped.
        input_ids: per-example token ids, used to detect a leading "▁" word marker.
        test: if True, return one ``"start end; start end"`` string per example;
            otherwise return a list of ``(start, end)`` tuples per example.
        tok: tokenizer providing ``convert_ids_to_tokens``; defaults to the
            notebook-level ``tokenizer``.
    """
    if tok is None:
        tok = tokenizer
    all_predictions = []
    for pred, offsets, seq_ids, input_id in zip(preds, offset_mapping, sequence_ids, input_ids):
        probs = 1 / (1 + np.exp(-np.asarray(pred)))
        spans = []
        start_idx = None
        end_idx = None
        for prob, offset, seq_id, id_ in zip(probs, offsets, seq_ids, input_id):
            # `!= 1` skips feature text (0) and special/padding tokens (None or NaN)
            if seq_id != 1:
                continue
            if prob > 0.5:
                if start_idx is None:
                    start_idx = offset[0]
                    # presumably the offset of a "▁"-prefixed token includes the
                    # preceding space; skip that character
                    if re.match(r'^▁', tok.convert_ids_to_tokens([id_])[0]):
                        start_idx += 1
                end_idx = offset[1]
            elif start_idx is not None:
                spans.append((start_idx, end_idx))
                start_idx = None
        # close a span that extends up to the last history token
        if start_idx is not None:
            spans.append((start_idx, end_idx))

        if test:
            all_predictions.append("; ".join(f"{s} {e}" for s, e in spans))
        else:
            all_predictions.append(spans)

    return all_predictions


def calculate_char_cv(predictions, offset_mapping, sequence_ids, labels):
    """Score predicted spans against token labels at character level.

    For each example, both the gold token labels and the predicted spans are
    projected onto a 0/1 character mask. The mask length is the largest offset in
    that example. All masks are then concatenated for scoring.

    Args:
        predictions: per-example lists of predicted (start, end) spans.
        offset_mapping: per-example (start, end) token offsets.
        sequence_ids: per-example sequence ids. Only id 1 (the history) contributes
            gold labels.
        labels: per-example token labels (1 = inside a span; -100 = ignored).

    Returns:
        dict with "Accuracy", "precision", "recall" and "f1" for the positive class.
    """
    all_labels = []
    all_preds = []
    for spans, offsets, seq_ids, token_labels in zip(predictions, offset_mapping, sequence_ids, labels):
        num_chars = max(chain(*offsets))

        char_labels = np.zeros(num_chars)
        for (start, end), seq_id, label in zip(offsets, seq_ids, token_labels):
            # `!= 1` skips feature text (0) and special/padding tokens (None or NaN)
            if seq_id != 1:
                continue
            if int(label) == 1:
                char_labels[start:end] = 1

        char_preds = np.zeros(num_chars)
        for start, end in spans:
            char_preds[start:end] = 1

        all_labels.extend(char_labels)
        all_preds.extend(char_preds)

    # average="binary" scores only pos_label=1; a `labels=` argument would be ignored
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="binary")
    accuracy = accuracy_score(all_labels, all_preds)

    return {
        "Accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }

In [None]:
# for i in range(config['epochs']):
#     print(f"training epoch {i}")
#     train_loss = train_model(model, l_dataloader_train, optimizer, loss_fct)
    
#     validation_loss, score = eval_model(model, l_dataloader_test, loss_fct)
    
#     print(f"train loss: {train_loss}, validation loss: {validation_loss}, score: {score}")
#     break

In [None]:
T1 = 1
T2 = 5
af = 3


def alpha_weight(step):
    """Weight of the pseudo-label loss at ``step``.

    0.0 before ``T1``, ``af`` after ``T2``, and a linear ramp from 0 to ``af``
    for ``T1 <= step <= T2``.
    """
    if step > T2:
        return af
    if step < T1:
        return 0.0
    ramp = (step - T1) / (T2 - T1)
    return ramp * af

In [None]:
# for batch in l_dataloader_train:
#     batch['input_ids'] = batch['input_ids'].to(device)
#     batch['attention_mask'] = batch["attention_mask"].to(device)
#     batch['token_type_ids'] = batch['token_type_ids'].to(device)
#     res = model(batch)
#     break

In [None]:
for i in range(50):
    print(alpha_weight(i))
    break

In [None]:
%%capture
!pip install wandb --upgrade -qq

In [None]:
import wandb

from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb")

wandb.login(key = secret_value_0)

In [None]:
# os.environ["WANDB_PROJECT"] = "nbme_pseudo_labelling"
# os.environ["WANDB_RUN_GROUP"] = "DEBERTA_" + datetime.now().strftime(
# "%Y-%m-%d %H:%M"
# )

In [None]:
wandb.init(project = "pseudo-labelling")
# wandb.watch(model, loss_fct, log='all', log_freq=100)

In [None]:
def semisuper_train(model, train_loader, unlabelled_loader, test_loader, epochs, optimizer, loss_fct):
    """Pseudo-label training loop with supervised refreshes, logged to wandb.

    For each unlabelled batch:
    1. The model labels the batch in eval mode without gradients (sigmoid > 0.5).
    2. It is trained on those pseudo-labels, restricted to history tokens
       (sequence id 1), with the loss scaled by ``alpha_weight(step)``.

    About three times per pass over ``unlabelled_loader``, a full supervised epoch
    over ``train_loader`` runs and ``step`` advances, ramping up the pseudo-label
    weight. After every epoch the model is evaluated on ``test_loader``.

    Returns:
        dict of per-epoch histories: ``eval_loss``, ``eval_score``, ``ul_loss``
        (mean pseudo-label loss) and ``alpha``.
    """
    step = 1
    eval_loss, eval_score, ul_loss, alpha_list = [], [], [], []
    # supervised refresh roughly three times per unlabelled pass; at least every
    # batch so the modulo below is defined for loaders with fewer than 3 batches
    train_every = max(1, len(unlabelled_loader) // 3)

    for epoch in range(epochs):
        ul_batch_loss = []
        for batch_idx, ul in enumerate(unlabelled_loader):
            ul = ul.to(device)

            # 1) pseudo-label the batch; no autograd graph is needed for the targets
            model.eval()
            with torch.no_grad():
                pseudo_labels = (model(ul).sigmoid() > 0.5).type(torch.float)

            # 2) train on the pseudo-labels, only over pn_history tokens
            model.train()
            output = model(ul)
            history_mask = ul['sequence_ids'] == 1
            token_loss = torch.masked_select(loss_fct(output, pseudo_labels), history_mask)
            unlabelled_loss = alpha_weight(step) * token_loss.mean()

            optimizer.zero_grad()
            unlabelled_loss.backward()
            optimizer.step()
            # store a plain float so no tensor/graph references are retained
            ul_batch_loss.append(unlabelled_loss.item())

            if batch_idx % train_every == 0 and batch_idx != 0:
                train_loss = train_model(model, train_loader, optimizer, loss_fct)
                step += 1
                wandb.log({"train_loss": train_loss})

        loss, score = eval_model(model, test_loader, loss_fct)
        epoch_ul_loss = float(np.mean(ul_batch_loss)) if ul_batch_loss else float('nan')
        wandb.log({"test_f1_score": score['f1']})
        wandb.log({"test_accuracy_score": score['Accuracy']})
        wandb.log({"test_loss": loss})
        wandb.log({"unlabelled_loss": epoch_ul_loss})
        print('Epoch: {} : Alpha Weight : {:.5f} | Test Loss : {:.3f} '.format(epoch, alpha_weight(step), loss))
        print(f"Epoch score : {score}")
        eval_loss.append(loss)
        eval_score.append(score)
        alpha_list.append(alpha_weight(step))
        ul_loss.append(epoch_ul_loss)

    return {"eval_loss": eval_loss, "eval_score": eval_score, "ul_loss": ul_loss, "alpha": alpha_list}
#             model.trai

In [None]:
semisuper_train(model, l_dataloader_train, ul_dataloader, l_dataloader_test, config['epochs'], optimizer, loss_fct)

In [None]:
import torch.nn.functional as F

In [None]:
torch.save(model.state_dict(), "fine-tunned-weights")

In [None]:
score_dataset = convert_to_dataset(ul_final_df, predict_tokenize)

In [None]:
score_dataloder = DataLoader(score_dataset, batch_size=config['train_batch_size'], shuffle=False, collate_fn=collate_fn)

In [None]:
# Predict spans for the unlabelled notes. Inference runs in eval mode (dropout off)
# without autograd. Token ids are moved to the CPU once per batch because
# get_location_predictions converts them to token strings one at a time.
model.eval()
predictions = []
with torch.no_grad():
    for batch in score_dataloder:
        batch['input_ids'] = batch['input_ids'].to(device)
        batch['attention_mask'] = batch["attention_mask"].to(device)
        batch['token_type_ids'] = batch['token_type_ids'].to(device)
        res = model(batch)
        predictions.extend(get_location_predictions(
            res.cpu().numpy(),
            batch['offset_mapping'].numpy(),
            batch['sequence_ids'].numpy(),
            batch['input_ids'].cpu().numpy(),
            test=True,
        ))

In [None]:
ul_final_df['predictions'] = predictions

In [None]:
ul_final_df.to_csv("ul_labelled.csv", index=False)

In [None]:
ul_final_df.head()