In [1]:
from transformers import AutoModelWithLMHead,BertForSequenceClassification, AutoTokenizer, AutoModel,AutoModelForMaskedLM,AutoModelForSequenceClassification
import torch
from torch import nn
import json
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split,StratifiedShuffleSplit
from torch.utils.data import DataLoader,TensorDataset
from transformers import Trainer, TrainingArguments
import pickle
from sklearn.metrics import confusion_matrix,classification_report
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score,roc_curve
import pandas as pd
import matplotlib.pyplot as plt

from transformers import AdamW,get_scheduler


In [2]:
# Load the Bio_ClinicalBERT tokenizer and register six extra domain terms as whole tokens.
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
new_tokens = ["interstitial", "fibrosis", "tubular", "atrophy","antibody","T-cell"]
tokenizer.add_tokens(new_tokens)
# Encoder that every QA head in this notebook consumes, loaded from a local MLM checkpoint.
# NOTE(review): the checkpoint's embedding table presumably already includes the six added
# tokens (directory name says "extended_tokenizer") — confirm the vocabulary sizes match.
base_kidneyBert = AutoModel.from_pretrained("./mlm_results_largeData_extended_tokenizer/checkpoint-1100")

Some weights of the model checkpoint at ./mlm_results_largeData_extended_tokenizer/checkpoint-1100 were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ./mlm_results_largeData_extended_tokenizer/checkpoint-1100 and are

In [3]:
# Move the shared encoder to the GPU; the device string is hard-coded, so a CUDA device is required.
base_kidneyBert.to("cuda")

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(29002, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [4]:
class ABMRQA(nn.Module):
    """Span-extraction head used for the ABMR (antibody-mediated rejection) question.

    A single linear layer maps every token representation to two scores:
    a start-of-answer logit and an end-of-answer logit.

    Args:
        hidden_size: Width of the encoder's token representations. Defaults to
            768, the value that was previously hard-coded.
    """

    def __init__(self, hidden_size=768):
        super().__init__()
        # Attribute name is kept so existing state_dict checkpoints still load.
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(self, outputs):
        """Return ``(start_logits, end_logits)`` for an encoder output.

        Args:
            outputs: Indexable encoder output whose element 0 is the per-token
                hidden-state tensor (last dimension == ``hidden_size``).

        Returns:
            Tuple of two contiguous tensors with the last (size-2) dimension
            of the projection removed.
        """
        sequence_output = outputs[0]
        logits = self.qa_outputs(sequence_output)
        # Channel 0 scores answer starts, channel 1 scores answer ends.
        start_logits, end_logits = logits.unbind(dim=-1)
        return (start_logits.contiguous(), end_logits.contiguous())
    
class TCMRQA(nn.Module):
    """Span-extraction head used for the TCMR (T-cell-mediated rejection) question.

    A single linear layer maps every token representation to two scores:
    a start-of-answer logit and an end-of-answer logit.

    Args:
        hidden_size: Width of the encoder's token representations. Defaults to
            768, the value that was previously hard-coded.
    """

    def __init__(self, hidden_size=768):
        super().__init__()
        # Attribute name is kept so existing state_dict checkpoints still load.
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(self, outputs):
        """Return ``(start_logits, end_logits)`` for an encoder output.

        Args:
            outputs: Indexable encoder output whose element 0 is the per-token
                hidden-state tensor (last dimension == ``hidden_size``).

        Returns:
            Tuple of two contiguous tensors with the last (size-2) dimension
            of the projection removed.
        """
        sequence_output = outputs[0]
        logits = self.qa_outputs(sequence_output)
        # Channel 0 scores answer starts, channel 1 scores answer ends.
        start_logits, end_logits = logits.unbind(dim=-1)
        return (start_logits.contiguous(), end_logits.contiguous())
    
class IFTAQA(nn.Module):
    """Span-extraction head used for the IFTA (interstitial fibrosis / tubular atrophy) question.

    A single linear layer maps every token representation to two scores:
    a start-of-answer logit and an end-of-answer logit.

    Args:
        hidden_size: Width of the encoder's token representations. Defaults to
            768, the value that was previously hard-coded.
    """

    def __init__(self, hidden_size=768):
        super().__init__()
        # Attribute name is kept so existing state_dict checkpoints still load.
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(self, outputs):
        """Return ``(start_logits, end_logits)`` for an encoder output.

        Args:
            outputs: Indexable encoder output whose element 0 is the per-token
                hidden-state tensor (last dimension == ``hidden_size``).

        Returns:
            Tuple of two contiguous tensors with the last (size-2) dimension
            of the projection removed.
        """
        sequence_output = outputs[0]
        logits = self.qa_outputs(sequence_output)
        # Channel 0 scores answer starts, channel 1 scores answer ends.
        start_logits, end_logits = logits.unbind(dim=-1)
        return (start_logits.contiguous(), end_logits.contiguous())
    

In [5]:
def cal_loss(start_positions,end_positions,start_logits,end_logits):
    """Average the cross-entropy of the start and end position predictions.

    Args:
        start_positions: Gold start token indices, shape ``(batch,)`` or
            ``(batch, 1)``.
        end_positions: Gold end token indices, same shape convention.
        start_logits: Start scores, one row of per-token logits per example.
        end_logits: End scores, same shape as ``start_logits``.

    Returns:
        Scalar tensor ``(start_loss + end_loss) / 2``.
    """
    # Only drop a trailing singleton column. Squeezing unconditionally would
    # collapse a batch of one, shape (1,), into a 0-d tensor, which no longer
    # matches the (1, seq_len) logits and makes CrossEntropyLoss fail.
    if start_positions.dim() > 1:
        start_positions = start_positions.squeeze(-1)
    if end_positions.dim() > 1:
        end_positions = end_positions.squeeze(-1)

    loss_fct = nn.CrossEntropyLoss()
    start_loss = loss_fct(start_logits, start_positions)
    end_loss = loss_fct(end_logits, end_positions)
    return (start_loss + end_loss) / 2

In [6]:
class RenalDataset(torch.utils.data.Dataset):
    """Question-answering dataset that turns character answer spans into token spans.

    Each item is the tokenizer encoding of one (question, report) pair, minus
    ``offset_mapping``, plus ``start_positions`` / ``end_positions`` tensors.
    Examples with no answer (``answer[1] == 0``), or whose answer does not lie
    inside the tokenized context, are labelled with the index of the [CLS] token.

    Args:
        encodings: Mapping with at least ``input_ids``, ``token_type_ids`` and
            ``offset_mapping``; each value is indexable by example.
        labels: Sequence of ``(start_char, end_char)`` answer spans.
        task_name: Stored on the instance; not otherwise used by this class.
        cls_token_id: Id of the [CLS] token. When ``None`` (the default) the
            notebook-level ``tokenizer.cls_token_id`` is used, as before.
    """

    def __init__(self, encodings, labels,task_name=None, cls_token_id=None):
        self.encodings = encodings
        self.answers = labels
        self.task_name = task_name
        self.cls_token_id = cls_token_id

    def __getitem__(self, idx):
        inputs = {key: val[idx] for key, val in self.encodings.items()}
        answer = self.answers[idx]
        # Offsets are only needed to compute the labels; the model must not receive them.
        offsets = inputs.pop("offset_mapping")
        input_ids = inputs["input_ids"]
        token_type_ids = inputs["token_type_ids"]

        cls_token_id = self.cls_token_id
        if cls_token_id is None:
            cls_token_id = tokenizer.cls_token_id
        # Search plain ints when possible instead of a list of 0-d tensors.
        id_list = input_ids.tolist() if hasattr(input_ids, "tolist") else list(input_ids)
        cls_index = id_list.index(cls_token_id)

        start_char, end_char = answer[0], answer[1]

        if end_char == 0:
            # No answer in this report: point both labels at [CLS].
            start_token = end_token = cls_index
        else:
            # First token of the second segment (the report text).
            context_first = 0
            while token_type_ids[context_first] != 1:
                context_first += 1

            # Last token with a non-empty character span (skips padding / [SEP]).
            context_last = len(input_ids) - 1
            while offsets[context_last][1] == 0:
                context_last -= 1

            if offsets[context_first][0] > start_char or offsets[context_last][1] < end_char:
                # Answer is not fully inside the (possibly truncated) context.
                start_token = end_token = cls_index
            else:
                # Walk forward to the last context token starting at or before start_char.
                # The scan stops at the end of the context: special and padding tokens
                # have offset 0, so without this bound an answer starting in the last
                # context token would be labelled with the final padding position.
                start_token = context_first
                while start_token <= context_last and offsets[start_token][0] <= start_char:
                    start_token += 1
                start_token -= 1

                # Walk backward to the first context token ending at or after end_char,
                # never leaving the context segment.
                end_token = context_last
                while end_token >= context_first and offsets[end_token][1] >= end_char:
                    end_token -= 1
                end_token += 1

        inputs["start_positions"] = torch.tensor(start_token)
        inputs["end_positions"] = torch.tensor(end_token)
        return inputs

    def __len__(self):
        return len(self.answers)
    
import difflib

def get_overlap_ratio(s1, s2):
    """Measure how much of ``s2`` is covered by its longest common substring with ``s1``.

    Returns:
        ``(char_ratio, word_ratio)``: the length of the longest contiguous
        match divided by ``len(s2)``, and the number of whitespace-separated
        words in that match divided by the number of words in ``s2``.
    """
    matcher = difflib.SequenceMatcher(None, s1, s2)
    longest = matcher.find_longest_match(0, len(s1), 0, len(s2))
    shared_text = s1[longest.a:longest.a + longest.size]
    char_ratio = longest.size / len(s2)
    word_ratio = len(shared_text.split()) / len(s2.split())
    return char_ratio, word_ratio

def compute_metrics(pred,test_ans,test_ids):
    """Score predicted answer spans against the reference answer strings.

    Args:
        pred: ``(start_scores, end_scores)``; each is a 2-D array-like with one
            row of per-token scores per example.
        test_ans: Reference answer strings; ``""`` means "no answer".
        test_ids: Per-example token id sequences used to decode the predicted span.

    Returns:
        Dict with ``accuracy`` (correct no-answer predictions plus exact string
        matches, over all examples), ``accuracy_info`` (exact matches over all
        examples) and the mean character/word overlap ratios over examples where
        both prediction and reference contain an answer. Ratios are ``nan`` when
        no such example exists; accuracies are ``nan`` for an empty evaluation set.
        The dict is also printed.
    """
    answer_start_scores, answer_end_scores = pred
    # Most likely start token, and (exclusive) end token, per example.
    answer_start = np.argmax(answer_start_scores, axis=1)
    answer_end = np.argmax(answer_end_scores, axis=1)+1

    total = 0
    correct,correct_with_info = 0,0
    overlap_ratio_char = []
    overlap_ratio_word = []
    for s,e,t,ids in zip(answer_start,answer_end,test_ans,test_ids):
        total += 1
        pred_ans = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(ids[s:e]))
        # A span covering only position 0 is the "no answer" prediction
        # (the dataset labels unanswerable examples with the [CLS] index).
        predicts_no_answer = (s == 0 and e == 1)
        if predicts_no_answer and t == "":
            correct += 1
        elif not predicts_no_answer and t != "":
            normalized_pred = pred_ans.lower().replace('\n', ' ')
            # NOTE(review): the reference is only lower-cased, and the prediction is a
            # tokenizer decode rather than a slice of the source text, so exact matches
            # may be missed on spacing/punctuation differences — confirm intended.
            if normalized_pred == t.lower():
                correct_with_info += 1
            char_ratio,word_ratio = get_overlap_ratio(normalized_pred,t.lower())
            overlap_ratio_char.append(char_ratio)
            overlap_ratio_word.append(word_ratio)

    # np.mean([]) returns nan but emits a RuntimeWarning; return nan explicitly.
    mean_char = np.mean(overlap_ratio_char) if overlap_ratio_char else np.nan
    mean_word = np.mean(overlap_ratio_word) if overlap_ratio_word else np.nan

    result_dict = {
        "accuracy": (correct+correct_with_info)/total if total else np.nan,
        "accuracy_info": correct_with_info/total if total else np.nan,
        "overlap_ratio_char": mean_char,
        "overlap_ratio_word": mean_word,
    }
    print(result_dict)
    return result_dict

In [7]:
def gen_datasets(q,train_text,test_text,tokenizer=tokenizer,train_labels=None,test_labels=None):
    """Tokenize (question, report) pairs and wrap them in ``RenalDataset`` objects.

    Args:
        q: Question string paired with every report.
        train_text: Training report texts.
        test_text: Test report texts.
        tokenizer: Tokenizer to use; defaults to the notebook-level tokenizer.
        train_labels: Answer spans for ``train_text``. When ``None`` the
            notebook-level ``train_labels`` variable is used (the previous,
            implicit behaviour).
        test_labels: Answer spans for ``test_text``; same fallback.

    Returns:
        ``(train_dataset, test_dataset)``.

    Raises:
        NameError: If labels are neither passed nor defined at notebook level.
    """
    # The original body read the labels from globals without declaring it; keep that
    # as a fallback so existing three-argument calls continue to work.
    notebook_globals = globals()
    if train_labels is None:
        if "train_labels" not in notebook_globals:
            raise NameError("name 'train_labels' is not defined")
        train_labels = notebook_globals["train_labels"]
    if test_labels is None:
        if "test_labels" not in notebook_globals:
            raise NameError("name 'test_labels' is not defined")
        test_labels = notebook_globals["test_labels"]

    # One copy of the question per report so the tokenizer builds sentence pairs.
    train_q = [q] * len(train_text)
    test_q = [q] * len(test_text)

    # Offsets are requested because RenalDataset maps character spans to token spans.
    train_encodings = tokenizer(train_q,train_text,padding="max_length", truncation=True,
                                return_tensors="pt",max_length=512,return_offsets_mapping=True)
    test_encodings = tokenizer(test_q,test_text,padding="max_length", truncation=True,
                                return_tensors="pt",max_length=512,return_offsets_mapping=True)
    train_dataset = RenalDataset(train_encodings, train_labels)
    test_dataset = RenalDataset(test_encodings, test_labels)
    return train_dataset,test_dataset

In [8]:
# Mini-batch size used by every train/test DataLoader created in the cells below.
batch_size = 12

In [9]:
# load abmr data

data = pd.read_csv("data.csv")
inputs1 = data["train_report_qa"].tolist()
label1 = data["abmr_pos_qa"].tolist()
# Rows whose report is missing (NaN, stringified as "nan") are dropped everywhere below.
# NOTE(review): eval() executes whatever expression is stored in the CSV cell. If the
# column only holds literal (start, end) pairs, ast.literal_eval would be a safer parser.
label = [eval(l) for i,l in zip(inputs1,label1) if str(i)!="nan"]
inputs = [i for i in inputs1 if str(i)!="nan"]

# Class labels are only used to stratify the train/test split.
label_class_help = data["abmr_class"].tolist()
label_class = [l for i,l in zip(inputs1,label_class_help) if str(i)!="nan"]


train_text, test_text, train_labels, test_labels = train_test_split(
    inputs, label,random_state = 1,stratify=label_class,test_size=0.2)


q_abmr = "How is the antibody-mediated rejection?"
# gen_datasets picks up train_labels / test_labels from the notebook namespace.
train_dataset,test_dataset = gen_datasets(q_abmr,train_text,test_text)
abmr_train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = batch_size,shuffle=True)
abmr_test_loader = torch.utils.data.DataLoader(test_dataset,batch_size = batch_size)

# Reference answer strings, sliced out of the raw report by character span.
abmr_test_ans = [text[span[0]:span[1]] for text,span in zip(test_text,test_labels)]

# Token ids of every test example, in loader order. Concatenating once keeps the
# integer dtype; growing an empty float tensor promoted the ids to float32 and
# re-copied the accumulated tensor on every batch.
abmr_test_ids = torch.cat([batch["input_ids"] for batch in abmr_test_loader],0)

In [10]:
# load tcmr data

data = pd.read_csv("data.csv")
inputs1 = data["train_report_qa"].tolist()
label1 = data["tcmr_pos_qa"].tolist()
# Rows whose report is missing (NaN, stringified as "nan") are dropped everywhere below.
# NOTE(review): eval() executes whatever expression is stored in the CSV cell. If the
# column only holds literal (start, end) pairs, ast.literal_eval would be a safer parser.
label = [eval(l) for i,l in zip(inputs1,label1) if str(i)!="nan"]
inputs = [i for i in inputs1 if str(i)!="nan"]

# Class labels are only used to stratify the train/test split.
label_class_help = data["tcmr_class"].tolist()
label_class = [l for i,l in zip(inputs1,label_class_help) if str(i)!="nan"]


train_text, test_text, train_labels, test_labels = train_test_split(
    inputs, label,random_state = 1,stratify=label_class,test_size=0.2)


q_tcmr = "How is the t-cell-mediated rejection?"
# gen_datasets picks up train_labels / test_labels from the notebook namespace.
train_dataset,test_dataset = gen_datasets(q_tcmr,train_text,test_text)
tcmr_train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = batch_size,shuffle=True)
tcmr_test_loader = torch.utils.data.DataLoader(test_dataset,batch_size = batch_size)

# Reference answer strings, sliced out of the raw report by character span.
tcmr_test_ans = [text[span[0]:span[1]] for text,span in zip(test_text,test_labels)]

# Token ids of every test example, in loader order. Concatenating once keeps the
# integer dtype; growing an empty float tensor promoted the ids to float32 and
# re-copied the accumulated tensor on every batch.
tcmr_test_ids = torch.cat([batch["input_ids"] for batch in tcmr_test_loader],0)
    

In [11]:
# load ifta data

data = pd.read_csv("data.csv")
inputs1 = data["train_report_qa"].tolist()
label1 = data["ifta_pos_qa"].tolist()
# Rows whose report is missing (NaN, stringified as "nan") are dropped everywhere below.
# NOTE(review): eval() executes whatever expression is stored in the CSV cell. If the
# column only holds literal (start, end) pairs, ast.literal_eval would be a safer parser.
label = [eval(l) for i,l in zip(inputs1,label1) if str(i)!="nan"]
inputs = [i for i in inputs1 if str(i)!="nan"]

# The textual IFTA grade is collapsed into four integer strata
# (0: nosig/minimal/noinfo, 1: mild, 2: moderate, 3: anything else);
# they are only used to stratify the train/test split.
label_class_help1 = data["IFTA"].tolist()
label_class_help2 = [l for i,l in zip(inputs1,label_class_help1) if str(i)!="nan"]
label_class = [0 if l in ["nosig","minimal","noinfo"] else (1 if l=="mild" else (2 if l=="moderate" else 3)) for l in label_class_help2]

train_text, test_text, train_labels, test_labels = train_test_split(
    inputs, label,random_state = 1,stratify=label_class,test_size=0.2)


q_ifta = "What is the grade of interstitial fibrosis and tubular atrophy?"
# gen_datasets picks up train_labels / test_labels from the notebook namespace.
train_dataset,test_dataset = gen_datasets(q_ifta,train_text,test_text)
ifta_train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = batch_size,shuffle=True)
ifta_test_loader = torch.utils.data.DataLoader(test_dataset,batch_size = batch_size)

# Reference answer strings, sliced out of the raw report by character span.
ifta_test_ans = [text[span[0]:span[1]] for text,span in zip(test_text,test_labels)]

# Token ids of every test example, in loader order. Concatenating once keeps the
# integer dtype; growing an empty float tensor promoted the ids to float32 and
# re-copied the accumulated tensor on every batch.
ifta_test_ids = torch.cat([batch["input_ids"] for batch in ifta_test_loader],0)


In [12]:
# One span-extraction head per question; all three consume base_kidneyBert's output.
model_abmr = ABMRQA()
model_tcmr = TCMRQA()
model_ifta = IFTAQA()

In [13]:
# Each optimizer updates the shared encoder's parameters plus one task head.
# NOTE(review): the encoder parameters are therefore registered in three separate
# AdamW instances (three independent moment estimates and LR schedules) — confirm
# this is intended rather than a single optimizer over encoder + all heads.
optimizer_abmr = AdamW(list(base_kidneyBert.parameters())+list(model_abmr.parameters()), lr=5e-5)
optimizer_tcmr = AdamW(list(base_kidneyBert.parameters())+list(model_tcmr.parameters()), lr=5e-5)
optimizer_ifta = AdamW(list(base_kidneyBert.parameters())+list(model_ifta.parameters()), lr=5e-5)
num_epochs = 3
# The training loop zips the three loaders, so an epoch ends with the shortest one.
num_training_steps = num_epochs * min(len(abmr_train_loader),len(tcmr_train_loader),len(ifta_train_loader))
# Linear schedule with 30 warm-up steps, one scheduler per optimizer.
lr_scheduler_abmr = get_scheduler("linear", optimizer=optimizer_abmr, num_warmup_steps=30, num_training_steps=num_training_steps)
lr_scheduler_tcmr = get_scheduler("linear", optimizer=optimizer_tcmr, num_warmup_steps=30, num_training_steps=num_training_steps)
lr_scheduler_ifta = get_scheduler("linear", optimizer=optimizer_ifta, num_warmup_steps=30, num_training_steps=num_training_steps)






In [14]:
# Device used for the heads here and for every batch in get_pred / the training loop.
# The encoder itself was moved to "cuda" in an earlier cell.
device = "cuda"
model_abmr.to(device)
model_tcmr.to(device)
model_ifta.to(device)

IFTAQA(
  (qa_outputs): Linear(in_features=768, out_features=2, bias=True)
)

In [15]:
def get_pred(model,dataloader):
    """Run the shared encoder plus one QA head over a loader without gradients.

    Both ``model`` and the notebook-level ``base_kidneyBert`` are switched to
    eval mode and left that way.

    Args:
        model: QA head that maps the encoder output to ``(start_logits, end_logits)``.
        dataloader: Yields dict batches containing the encoder inputs plus
            ``start_positions`` and ``end_positions``.

    Returns:
        ``(start, end)``: two Python lists holding one row of logits per example,
        in loader order.
    """
    model.eval()
    base_kidneyBert.eval()
    all_start, all_end = [], []
    with torch.no_grad():
        for raw_batch in dataloader:
            features = {name: tensor.to(device) for name, tensor in raw_batch.items()}
            # The gold labels are not encoder inputs; remove them before the forward pass.
            features.pop("start_positions")
            features.pop("end_positions")
            batch_start, batch_end = model(base_kidneyBert(**features))
            all_start.extend(batch_start.squeeze(-1).contiguous().tolist())
            all_end.extend(batch_end.squeeze(-1).contiguous().tolist())

    return (all_start, all_end)

In [16]:
from tqdm.auto import tqdm
import copy

# Multi-task fine-tuning: on every step the shared encoder is updated three times,
# once per task (ABMR, TCMR, IFTA), each with its own head, optimizer and scheduler.
progress_bar = tqdm(range(num_training_steps))

# NOTE(review): the best_f1_* variables are initialised but never read in this cell.
best_f1_abmr = 0
best_f1_tcmr = 0
best_f1_ifta = 0
step = 0

for epoch in range(num_epochs):
    # zip() stops at the shortest loader, so each epoch covers min(len(loader)) batches.
    for batch_abmr,batch_tcmr,batch_ifta in zip(abmr_train_loader,tcmr_train_loader,ifta_train_loader):
        # Re-enable train mode every step: get_pred() switches the modules to eval mode
        # during the periodic evaluation below.
        model_abmr.train()
        model_tcmr.train()
        model_ifta.train()
        base_kidneyBert.train()
        step+=1


        # --- ABMR update ---
        batch_abmr = {k: v.to(device) for k, v in batch_abmr.items()}
        start_positions = batch_abmr["start_positions"]
        end_positions = batch_abmr["end_positions"]

        # Labels are removed so that only encoder inputs are passed to the encoder.
        del batch_abmr["start_positions"]
        del batch_abmr["end_positions"]
        feat = base_kidneyBert(**batch_abmr)
        start_logits,end_logits = model_abmr(feat)
        loss_abmr = cal_loss(start_positions,end_positions,start_logits,end_logits)

        loss_abmr.backward()
        optimizer_abmr.step()
        lr_scheduler_abmr.step()
        # Clears gradients on the shared encoder too, before the next task's backward pass.
        optimizer_abmr.zero_grad()



        # --- TCMR update ---
        batch_tcmr = {k: v.to(device) for k, v in batch_tcmr.items()}
        start_positions = batch_tcmr["start_positions"]
        end_positions = batch_tcmr["end_positions"]

        del batch_tcmr["start_positions"]
        del batch_tcmr["end_positions"]
        feat = base_kidneyBert(**batch_tcmr)
        start_logits,end_logits = model_tcmr(feat)
        loss_tcmr = cal_loss(start_positions,end_positions,start_logits,end_logits)

        loss_tcmr.backward()
        optimizer_tcmr.step()
        lr_scheduler_tcmr.step()
        optimizer_tcmr.zero_grad()


        # --- IFTA update ---
        batch_ifta = {k: v.to(device) for k, v in batch_ifta.items()}
        start_positions = batch_ifta["start_positions"]
        end_positions = batch_ifta["end_positions"]

        del batch_ifta["start_positions"]
        del batch_ifta["end_positions"]
        feat = base_kidneyBert(**batch_ifta)
        start_logits,end_logits = model_ifta(feat)
        loss_ifta = cal_loss(start_positions,end_positions,start_logits,end_logits)

        loss_ifta.backward()
        optimizer_ifta.step()
        lr_scheduler_ifta.step()
        optimizer_ifta.zero_grad()



        progress_bar.update(1)

        if step % 50 == 0:
            # Every 50 steps: report the latest batch losses and evaluate each head
            # on its test loader.

            print('STEP:{}, EPOCHS : {}/{}'.format(step,epoch+1,num_epochs),
                  'Loss : {:.4f},{:.4f},{:.4f}'.format(loss_abmr,loss_tcmr,loss_ifta))

            res_abmr = compute_metrics(get_pred(model_abmr,abmr_test_loader),abmr_test_ans,abmr_test_ids)
            res_tcmr = compute_metrics(get_pred(model_tcmr,tcmr_test_loader),tcmr_test_ans,tcmr_test_ids)
            res_ifta = compute_metrics(get_pred(model_ifta,ifta_test_loader),ifta_test_ans,ifta_test_ids)

            # Checkpointing is disabled in this run:
#             base_kidneyBert.save_pretrained(f"./fine_both_qa/step_{step}")
#             torch.save(model_abmr.state_dict(),f"./fine_both_qa/model_abmr_{step}.pth")
#             torch.save(model_tcmr.state_dict(),f"./fine_both_qa/model_tcmr_{step}.pth")
#             torch.save(model_ifta.state_dict(),f"./fine_both_qa/model_ifta_{step}.pth")



  0%|          | 0/687 [00:00<?, ?it/s]

STEP:50, EPOCHS : 1/3 Loss : 0.0264,0.0296,0.2343


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


{'accuracy': 0.9678832116788321, 'accuracy_info': 0.0, 'overlap_ratio_char': nan, 'overlap_ratio_word': nan}
{'accuracy': 0.9781021897810219, 'accuracy_info': 0.0, 'overlap_ratio_char': nan, 'overlap_ratio_word': nan}
{'accuracy': 0.8700729927007299, 'accuracy_info': 0.8277372262773722, 'overlap_ratio_char': 0.968720629699248, 'overlap_ratio_word': 0.9917763157894737}
STEP:100, EPOCHS : 1/3 Loss : 0.0237,0.0166,0.3116
{'accuracy': 0.9678832116788321, 'accuracy_info': 0.0, 'overlap_ratio_char': nan, 'overlap_ratio_word': nan}
{'accuracy': 0.9781021897810219, 'accuracy_info': 0.0, 'overlap_ratio_char': nan, 'overlap_ratio_word': nan}
{'accuracy': 0.927007299270073, 'accuracy_info': 0.8671532846715329, 'overlap_ratio_char': 0.9853976073187897, 'overlap_ratio_word': 0.9967159277504105}
STEP:150, EPOCHS : 1/3 Loss : 0.0083,0.0121,0.1903
{'accuracy': 0.9678832116788321, 'accuracy_info': 0.0, 'overlap_ratio_char': nan, 'overlap_ratio_word': nan}
{'accuracy': 0.9781021897810219, 'accuracy_info

In [17]:
# Evaluate each head on its test loader after training; compute_metrics prints the result dict.
res_abmr = compute_metrics(get_pred(model_abmr,abmr_test_loader),abmr_test_ans,abmr_test_ids)
res_tcmr = compute_metrics(get_pred(model_tcmr,tcmr_test_loader),tcmr_test_ans,tcmr_test_ids)
res_ifta = compute_metrics(get_pred(model_ifta,ifta_test_loader),ifta_test_ans,ifta_test_ids)

{'accuracy': 0.9664233576642336, 'accuracy_info': 0.0, 'overlap_ratio_char': 0.3634197501369981, 'overlap_ratio_word': 0.46075757575757575}
{'accuracy': 0.9781021897810219, 'accuracy_info': 0.0, 'overlap_ratio_char': 0.3540271774444162, 'overlap_ratio_word': 0.4444444444444445}
{'accuracy': 0.9445255474452555, 'accuracy_info': 0.8759124087591241, 'overlap_ratio_char': 0.987538118695754, 'overlap_ratio_word': 0.9917898193760263}


In [16]:
from tqdm.auto import tqdm
import copy

# Second copy of the multi-task training loop, this time saving a checkpoint at every
# evaluation. It continues from whatever state the models, optimizers and schedulers
# are in when the cell runs.
# NOTE(review): the recorded output below shows 4580 steps / 20 epochs and reports at
# steps 300 and 600, while the visible configuration sets num_epochs = 3 and this code
# evaluates every 50 steps — the output appears to come from a different configuration.
progress_bar = tqdm(range(num_training_steps))

# NOTE(review): the best_f1_* variables are initialised but never read in this cell.
best_f1_abmr = 0
best_f1_tcmr = 0
best_f1_ifta = 0
step = 0

for epoch in range(num_epochs):
    # zip() stops at the shortest loader, so each epoch covers min(len(loader)) batches.
    for batch_abmr,batch_tcmr,batch_ifta in zip(abmr_train_loader,tcmr_train_loader,ifta_train_loader):
        # Re-enable train mode every step: get_pred() switches the modules to eval mode
        # during the periodic evaluation below.
        model_abmr.train()
        model_tcmr.train()
        model_ifta.train()
        base_kidneyBert.train()
        step+=1


        # --- ABMR update ---
        batch_abmr = {k: v.to(device) for k, v in batch_abmr.items()}
        start_positions = batch_abmr["start_positions"]
        end_positions = batch_abmr["end_positions"]

        # Labels are removed so that only encoder inputs are passed to the encoder.
        del batch_abmr["start_positions"]
        del batch_abmr["end_positions"]
        feat = base_kidneyBert(**batch_abmr)
        start_logits,end_logits = model_abmr(feat)
        loss_abmr = cal_loss(start_positions,end_positions,start_logits,end_logits)

        loss_abmr.backward()
        optimizer_abmr.step()
        lr_scheduler_abmr.step()
        # Clears gradients on the shared encoder too, before the next task's backward pass.
        optimizer_abmr.zero_grad()



        # --- TCMR update ---
        batch_tcmr = {k: v.to(device) for k, v in batch_tcmr.items()}
        start_positions = batch_tcmr["start_positions"]
        end_positions = batch_tcmr["end_positions"]

        del batch_tcmr["start_positions"]
        del batch_tcmr["end_positions"]
        feat = base_kidneyBert(**batch_tcmr)
        start_logits,end_logits = model_tcmr(feat)
        loss_tcmr = cal_loss(start_positions,end_positions,start_logits,end_logits)

        loss_tcmr.backward()
        optimizer_tcmr.step()
        lr_scheduler_tcmr.step()
        optimizer_tcmr.zero_grad()


        # --- IFTA update ---
        batch_ifta = {k: v.to(device) for k, v in batch_ifta.items()}
        start_positions = batch_ifta["start_positions"]
        end_positions = batch_ifta["end_positions"]

        del batch_ifta["start_positions"]
        del batch_ifta["end_positions"]
        feat = base_kidneyBert(**batch_ifta)
        start_logits,end_logits = model_ifta(feat)
        loss_ifta = cal_loss(start_positions,end_positions,start_logits,end_logits)

        loss_ifta.backward()
        optimizer_ifta.step()
        lr_scheduler_ifta.step()
        optimizer_ifta.zero_grad()



        progress_bar.update(1)

        if step % 50 == 0:
            # Every 50 steps: report the latest batch losses, evaluate each head on its
            # test loader, then checkpoint the encoder and the three heads.

            print('STEP:{}, EPOCHS : {}/{}'.format(step,epoch+1,num_epochs),
                  'Loss : {:.4f},{:.4f},{:.4f}'.format(loss_abmr,loss_tcmr,loss_ifta))

            res_abmr = compute_metrics(get_pred(model_abmr,abmr_test_loader),abmr_test_ans,abmr_test_ids)
            res_tcmr = compute_metrics(get_pred(model_tcmr,tcmr_test_loader),tcmr_test_ans,tcmr_test_ids)
            res_ifta = compute_metrics(get_pred(model_ifta,ifta_test_loader),ifta_test_ans,ifta_test_ids)

            # The encoder is saved first; the head files are written into the same
            # ./fine_both_qa directory.
            base_kidneyBert.save_pretrained(f"./fine_both_qa/step_{step}")
            torch.save(model_abmr.state_dict(),f"./fine_both_qa/model_abmr_{step}.pth")
            torch.save(model_tcmr.state_dict(),f"./fine_both_qa/model_tcmr_{step}.pth")
            torch.save(model_ifta.state_dict(),f"./fine_both_qa/model_ifta_{step}.pth")



  0%|          | 0/4580 [00:00<?, ?it/s]

STEP:300, EPOCHS : 2/20 Loss : 0.0131,0.0026,0.1571
{'accuracy': 0.9664233576642336, 'accuracy_info': 0.0, 'overlap_ratio_char': 0.3634197501369981, 'overlap_ratio_word': 0.46075757575757575}
{'accuracy': 0.9781021897810219, 'accuracy_info': 0.0, 'overlap_ratio_char': 0.49437805463739865, 'overlap_ratio_word': 0.6527777777777778}
{'accuracy': 0.927007299270073, 'accuracy_info': 0.8627737226277372, 'overlap_ratio_char': 0.9830254177453518, 'overlap_ratio_word': 0.9925864909390445}
STEP:600, EPOCHS : 3/20 Loss : 0.0005,0.0007,0.0185
{'accuracy': 0.9664233576642336, 'accuracy_info': 0.0, 'overlap_ratio_char': 0.3634197501369981, 'overlap_ratio_word': 0.46075757575757575}
{'accuracy': 0.9781021897810219, 'accuracy_info': 0.0, 'overlap_ratio_char': 0.3524027459954233, 'overlap_ratio_word': 0.5277777777777778}
{'accuracy': 0.945985401459854, 'accuracy_info': 0.8788321167883212, 'overlap_ratio_char': 0.9909394792399719, 'overlap_ratio_word': 0.9958949096880131}


KeyboardInterrupt: 

In [17]:
# Save the encoder and two heads under ./fine_both.
# NOTE(review): model_isrej is not defined in any cell visible in this notebook, and the
# f-string on the first line has no placeholders; this cell looks like a leftover from an
# earlier classification experiment — confirm before relying on it.
base_kidneyBert.save_pretrained(f"./fine_both/final")
torch.save(model_isrej.state_dict(),f"./fine_both/model_isrej_final.pth")
torch.save(model_ifta.state_dict(),f"./fine_both/model_ifta_final.pth")

In [25]:
# Restore the rejection head's weights from the file written by the save cell.
# NOTE(review): model_isrej is not defined in any cell visible in this notebook.
model_isrej.load_state_dict(torch.load("./fine_both/model_isrej_final.pth"))

<All keys matched successfully>

In [18]:
# Print a confusion matrix and classification report for the rejection and IFTA heads.
# NOTE(review): this cell treats get_pred's return value as (predicted labels, gold labels),
# but get_pred as defined earlier in this notebook returns (start_logits, end_logits);
# model_isrej and isrej_test_loader are also not defined in the visible cells. The cell
# appears to belong to an earlier classification version of the notebook.
pred_labels,test_labels = get_pred(model_isrej,isrej_test_loader)
print(confusion_matrix(test_labels,pred_labels))
print(classification_report(test_labels,pred_labels))
pred_labels,test_labels = get_pred(model_ifta,ifta_test_loader)
print(confusion_matrix(test_labels,pred_labels))
print(classification_report(test_labels,pred_labels))

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


[[638   9]
 [ 21  17]]
              precision    recall  f1-score   support

           0       0.97      0.99      0.98       647
           1       0.65      0.45      0.53        38

    accuracy                           0.96       685
   macro avg       0.81      0.72      0.75       685
weighted avg       0.95      0.96      0.95       685



  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


[[324  43   0   0]
 [ 37 164  23   0]
 [  3  31  38   0]
 [  3   4  12   3]]
              precision    recall  f1-score   support

           0       0.88      0.88      0.88       367
           1       0.68      0.73      0.70       224
           2       0.52      0.53      0.52        72
           3       1.00      0.14      0.24        22

    accuracy                           0.77       685
   macro avg       0.77      0.57      0.59       685
weighted avg       0.78      0.77      0.77       685



In [26]:
from tqdm.auto import tqdm
import copy

# Training loop for the binary rejection classifier only; the IFTA update is commented out.
# NOTE(review): model_isrej, optimizer_isrej, lr_scheduler_isrej and the isrej_* loaders are
# not defined in any cell visible in this notebook, and compute_metrics is called here with
# a single argument and indexed with "f1", which does not match the three-argument QA
# version defined earlier. This cell appears to come from an earlier classification variant.
progress_bar = tqdm(range(num_training_steps))

best_f1 = 0
step = 0
try:
    for epoch in range(num_epochs):
        # ifta batches are still drawn from the zip, but only the isrej batch is used.
        for batch_isrej,batch_ifta in zip(isrej_train_loader,ifta_train_loader):
            model_isrej.train()
            model_ifta.train()
            base_kidneyBert.train()
            step+=1
            #print(step)
            #print(1)


            batch_isrej = {k: v.to(device) for k, v in batch_isrej.items()}
            labels = batch_isrej["labels"]
#             labels = copy.deepcopy(batch_isrej["labels"]).to(device)
            #print(2)
            # Labels are removed so that only encoder inputs are passed to the encoder.
            del batch_isrej["labels"]
            feat = base_kidneyBert(**batch_isrej)
#             print(feat[0].shape,feat[1].shape)
            outputs = model_isrej(feat)
            #print(3)
            #print(4)
            loss_fct = torch.nn.CrossEntropyLoss().to(device)
            #print(outputs,outputs.shape)
            #print(labels,labels.shape)
            # Two-class cross-entropy over the flattened logits.
            loss_isrej = loss_fct(outputs.view(-1, 2), labels.view(-1))
            #print(6)
            loss_isrej.backward()
            #print(7)

            optimizer_isrej.step()
            lr_scheduler_isrej.step()
            optimizer_isrej.zero_grad()

#             batch_ifta = {k: v.to(device) for k, v in batch_ifta.items()}
#             labels = batch_ifta["labels"]
#             del batch_ifta["labels"]
#             feat = base_kidneyBert(**batch_ifta)
#             outputs = model_ifta(feat)
# #             outputs = model_ifta(batch_ifta)
#             #print(10)

#             #print(11)
#             loss_fct = torch.nn.CrossEntropyLoss().to(device)
#             #print(12)
#             #print("q",outputs,outputs.shape)
#             #print("z",labels,labels.shape)
#             loss_ifta = loss_fct(outputs.view(-1, 4), labels.view(-1))
#             loss_ifta.backward()

#             optimizer_ifta.step()
#             lr_scheduler_ifta.step()
#             optimizer_ifta.zero_grad()
            progress_bar.update(1)

            # Placeholder so the progress print below can still format two losses.
            loss_ifta = 0
            if step % 300 == 0:
                #test the accuracy

                print('STEP:{}, EPOCHS : {}/{}'.format(step,epoch+1,num_epochs),
                      'Loss : {:.4f},{:.4f}'.format(loss_isrej,loss_ifta))

                res = compute_metrics(get_pred(model_isrej,isrej_test_loader))
                # Keep a checkpoint whenever the test F1 improves.
                if res["f1"] > best_f1:
                    best_f1 = res["f1"]
                    base_kidneyBert.save_pretrained(f"./fine_rej/f1{best_f1}")
                    torch.save(model_isrej.state_dict(),f"./fine_rej/model_isrej_{best_f1}.pth")
#                 compute_metrics(get_pred(model_ifta,ifta_test_loader))
# NOTE(review): any exception (including a NameError from the undefined names above) is
# only printed here, so a failed run ends without a traceback.
except Exception as e:
    print("Exception",e)

  0%|          | 0/3440 [00:00<?, ?it/s]

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:300, EPOCHS : 2/20 Loss : 0.2560,0.0000
accuracy: 0.9445255474452555, precision: 0.9445255474452555, recall: 0.9445255474452555, f1: 0.9445255474452555


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:600, EPOCHS : 4/20 Loss : 0.2050,0.0000
accuracy: 0.9562043795620438, precision: 0.9562043795620438, recall: 0.9562043795620438, f1: 0.9562043795620438


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:900, EPOCHS : 6/20 Loss : 0.0197,0.0000
accuracy: 0.9503649635036496, precision: 0.9503649635036496, recall: 0.9503649635036496, f1: 0.9503649635036496


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:1200, EPOCHS : 7/20 Loss : 0.0123,0.0000
accuracy: 0.9635036496350365, precision: 0.9635036496350365, recall: 0.9635036496350365, f1: 0.9635036496350365


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:1500, EPOCHS : 9/20 Loss : 0.0687,0.0000
accuracy: 0.9547445255474453, precision: 0.9547445255474453, recall: 0.9547445255474453, f1: 0.9547445255474453


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:1800, EPOCHS : 11/20 Loss : 0.0141,0.0000
accuracy: 0.9635036496350365, precision: 0.9635036496350365, recall: 0.9635036496350365, f1: 0.9635036496350365


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:2100, EPOCHS : 13/20 Loss : 0.0038,0.0000
accuracy: 0.964963503649635, precision: 0.964963503649635, recall: 0.964963503649635, f1: 0.964963503649635


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:2400, EPOCHS : 14/20 Loss : 0.0004,0.0000
accuracy: 0.9664233576642336, precision: 0.9664233576642336, recall: 0.9664233576642336, f1: 0.9664233576642337


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:2700, EPOCHS : 16/20 Loss : 0.0024,0.0000
accuracy: 0.9664233576642336, precision: 0.9664233576642336, recall: 0.9664233576642336, f1: 0.9664233576642337


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:3000, EPOCHS : 18/20 Loss : 0.0005,0.0000
accuracy: 0.9693430656934306, precision: 0.9693430656934306, recall: 0.9693430656934306, f1: 0.9693430656934306


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:3300, EPOCHS : 20/20 Loss : 0.0001,0.0000
accuracy: 0.9693430656934306, precision: 0.9693430656934306, recall: 0.9693430656934306, f1: 0.9693430656934306


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


In [28]:
# Save the encoder and the rejection head under ./fine_rej.
# NOTE(review): model_isrej is not defined in any cell visible in this notebook.
base_kidneyBert.save_pretrained(f"./fine_rej/final")
torch.save(model_isrej.state_dict(),f"./fine_rej/model_isrej_final.pth")

In [27]:
# Print a confusion matrix and classification report for the rejection head.
# NOTE(review): this relies on a get_pred that returns (predicted labels, gold labels) and
# on model_isrej / isrej_test_loader, none of which match or appear in the visible cells.
pred_labels,test_labels = get_pred(model_isrej,isrej_test_loader)
print(confusion_matrix(test_labels,pred_labels))
print(classification_report(test_labels,pred_labels))

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


[[641   6]
 [ 15  23]]
              precision    recall  f1-score   support

           0       0.98      0.99      0.98       647
           1       0.79      0.61      0.69        38

    accuracy                           0.97       685
   macro avg       0.89      0.80      0.84       685
weighted avg       0.97      0.97      0.97       685



In [14]:
from tqdm.auto import tqdm
import copy
import traceback

# IFTA-only fine-tuning: in this cell only model_ifta (on top of
# base_kidneyBert) receives gradient updates; the rejection branch is unused.
progress_bar = tqdm(range(num_training_steps))

# CrossEntropyLoss carries no per-step state, so build it once up front
# instead of re-creating it on every iteration.
loss_fct = torch.nn.CrossEntropyLoss().to(device)

best_f1 = 0
step = 0
# The rejection task is not trained here; 0 is only a placeholder for the
# first slot of the "Loss : a,b" log line.
loss_isrej = 0
try:
    for epoch in range(num_epochs):
        # batch_isrej is drawn but never used. The zip is kept so that the
        # number of steps per epoch (bounded by the shorter loader) is unchanged.
        for batch_isrej, batch_ifta in zip(isrej_train_loader, ifta_train_loader):
            # Re-enter training mode every step
            # (presumably get_pred() below switches modules to eval mode — confirm).
            model_isrej.train()
            model_ifta.train()
            base_kidneyBert.train()
            step += 1

            # Move the batch to the target device and split off the targets so
            # the remaining keys can be passed straight to the encoder.
            batch_ifta = {k: v.to(device) for k, v in batch_ifta.items()}
            labels = batch_ifta.pop("labels")
            feat = base_kidneyBert(**batch_ifta)
            outputs = model_ifta(feat)

            # 4-way classification: logits are flattened to (-1, 4).
            loss_ifta = loss_fct(outputs.view(-1, 4), labels.view(-1))
            loss_ifta.backward()

            optimizer_ifta.step()
            lr_scheduler_ifta.step()
            optimizer_ifta.zero_grad()
            progress_bar.update(1)

            if step % 300 == 0:
                # Periodic evaluation on the IFTA test loader.
                print('STEP:{}, EPOCHS : {}/{}'.format(step, epoch + 1, num_epochs),
                      'Loss : {:.4f},{:.4f}'.format(loss_isrej, loss_ifta.item()))

                res = compute_metrics(get_pred(model_ifta, ifta_test_loader))
                if res["f1"] > best_f1:
                    best_f1 = res["f1"]
                    base_kidneyBert.save_pretrained(f"./fine_ifta/f1{best_f1}")
                    # Fixed file name: this checkpoint holds the IFTA head, so it
                    # is stored as model_ifta_* (it was previously written as
                    # model_isrej_*, unlike the final save's model_ifta_final.pth).
                    torch.save(model_ifta.state_dict(), f"./fine_ifta/model_ifta_{best_f1}.pth")
except Exception as e:
    # Keep the notebook running, but do not lose where the failure happened.
    print("Exception", e)
    traceback.print_exc()

  0%|          | 0/3440 [00:00<?, ?it/s]

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:300, EPOCHS : 2/20 Loss : 0.0000,0.3606
accuracy: 0.7051094890510949, precision: 0.7051094890510949, recall: 0.7051094890510949, f1: 0.7051094890510949


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:600, EPOCHS : 4/20 Loss : 0.0000,0.4748
accuracy: 0.7284671532846715, precision: 0.7284671532846715, recall: 0.7284671532846715, f1: 0.7284671532846715


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:900, EPOCHS : 6/20 Loss : 0.0000,0.3987
accuracy: 0.7138686131386861, precision: 0.7138686131386861, recall: 0.7138686131386861, f1: 0.7138686131386861


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:1200, EPOCHS : 7/20 Loss : 0.0000,0.4292
accuracy: 0.7386861313868613, precision: 0.7386861313868613, recall: 0.7386861313868613, f1: 0.7386861313868613


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:1500, EPOCHS : 9/20 Loss : 0.0000,0.1700
accuracy: 0.7343065693430657, precision: 0.7343065693430657, recall: 0.7343065693430657, f1: 0.7343065693430656


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:1800, EPOCHS : 11/20 Loss : 0.0000,0.0980
accuracy: 0.7138686131386861, precision: 0.7138686131386861, recall: 0.7138686131386861, f1: 0.7138686131386861


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:2100, EPOCHS : 13/20 Loss : 0.0000,0.0224
accuracy: 0.7343065693430657, precision: 0.7343065693430657, recall: 0.7343065693430657, f1: 0.7343065693430656


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:2400, EPOCHS : 14/20 Loss : 0.0000,0.0817
accuracy: 0.7386861313868613, precision: 0.7386861313868613, recall: 0.7386861313868613, f1: 0.7386861313868613


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:2700, EPOCHS : 16/20 Loss : 0.0000,0.0653
accuracy: 0.7401459854014598, precision: 0.7401459854014598, recall: 0.7401459854014598, f1: 0.7401459854014598


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:3000, EPOCHS : 18/20 Loss : 0.0000,0.0056
accuracy: 0.7416058394160584, precision: 0.7416058394160584, recall: 0.7416058394160584, f1: 0.7416058394160584


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:3300, EPOCHS : 20/20 Loss : 0.0000,0.0151
accuracy: 0.7386861313868613, precision: 0.7386861313868613, recall: 0.7386861313868613, f1: 0.7386861313868613


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


In [15]:
# Persist the end-of-training state of the IFTA run: the encoder goes through
# its own save_pretrained(), the model_ifta module as a state dict.
base_kidneyBert.save_pretrained("./fine_ifta/final")
torch.save(model_ifta.state_dict(), "./fine_ifta/model_ifta_final.pth")

In [16]:
# Collect predictions of model_ifta on the IFTA test loader, then print the
# confusion matrix followed by the per-class report.
pred_labels, test_labels = get_pred(model_ifta, ifta_test_loader)
print(
    confusion_matrix(test_labels, pred_labels),
    classification_report(test_labels, pred_labels),
    sep="\n",
)

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


[[320  42   3   2]
 [ 49 144  30   1]
 [  2  29  40   1]
 [  2   2  14   4]]
              precision    recall  f1-score   support

           0       0.86      0.87      0.86       367
           1       0.66      0.64      0.65       224
           2       0.46      0.56      0.50        72
           3       0.50      0.18      0.27        22

    accuracy                           0.74       685
   macro avg       0.62      0.56      0.57       685
weighted avg       0.74      0.74      0.74       685



In [15]:
from tqdm.auto import tqdm
import copy
import traceback

# Joint fine-tuning: every step runs one rejection update followed by one IFTA
# update, both passing through the same base_kidneyBert encoder.
progress_bar = tqdm(range(num_training_steps))

# CrossEntropyLoss carries no per-step state, so one instance serves both tasks.
loss_fct = torch.nn.CrossEntropyLoss().to(device)

best_f1_isrej = 0
best_f1_ifta = 0
step = 0
try:
    for epoch in range(num_epochs):
        # zip() stops at the shorter loader, so an epoch is bounded by it.
        for batch_isrej, batch_ifta in zip(isrej_train_loader, ifta_train_loader):
            # Re-enter training mode every step
            # (presumably get_pred() below switches modules to eval mode — confirm).
            model_isrej.train()
            model_ifta.train()
            base_kidneyBert.train()
            step += 1

            # ---- rejection task (2 classes) ----
            batch_isrej = {k: v.to(device) for k, v in batch_isrej.items()}
            labels = batch_isrej.pop("labels")
            feat = base_kidneyBert(**batch_isrej)
            outputs = model_isrej(feat)
            loss_isrej = loss_fct(outputs.view(-1, 2), labels.view(-1))
            loss_isrej.backward()

            optimizer_isrej.step()
            lr_scheduler_isrej.step()
            optimizer_isrej.zero_grad()

            # Keep only a graph-free scalar for logging and drop every reference
            # to this task's tensors before the second forward pass. (The former
            # `for k, v in ...: del v` only unbound the loop variable and
            # released nothing.)
            loss_isrej = loss_isrej.detach()
            del batch_isrej, labels, feat, outputs
            torch.cuda.empty_cache()

            # ---- IFTA task (4 classes) ----
            batch_ifta = {k: v.to(device) for k, v in batch_ifta.items()}
            labels = batch_ifta.pop("labels")
            feat = base_kidneyBert(**batch_ifta)
            outputs = model_ifta(feat)
            loss_ifta = loss_fct(outputs.view(-1, 4), labels.view(-1))
            loss_ifta.backward()

            optimizer_ifta.step()
            lr_scheduler_ifta.step()
            optimizer_ifta.zero_grad()
            progress_bar.update(1)

            if step % 300 == 0:
                # Periodic evaluation of both heads on their test loaders.
                print('STEP:{}, EPOCHS : {}/{}'.format(step, epoch + 1, num_epochs),
                      'Loss : {:.4f},{:.4f}'.format(loss_isrej, loss_ifta))

                res_isrej = compute_metrics(get_pred(model_isrej, isrej_test_loader))
                res_ifta = compute_metrics(get_pred(model_ifta, ifta_test_loader))
                # A checkpoint is written only when BOTH tasks beat their best F1.
                if res_isrej["f1"] > best_f1_isrej and res_ifta["f1"] > best_f1_ifta:
                    best_f1_isrej = res_isrej["f1"]
                    best_f1_ifta = res_ifta["f1"]
                    base_kidneyBert.save_pretrained(f"./fine_both/f1_{best_f1_isrej}_{best_f1_ifta}")
                    # Fixed: the rejection file previously received model_ifta's
                    # weights, so the rejection head was never checkpointed.
                    torch.save(model_isrej.state_dict(), f"./fine_both/model_isrej_{best_f1_isrej}.pth")
                    torch.save(model_ifta.state_dict(), f"./fine_both/model_ifta_{best_f1_ifta}.pth")
except Exception as e:
    # Keep the notebook running, but do not lose where the failure happened.
    print("Exception", e)
    traceback.print_exc()

  0%|          | 0/3440 [00:00<?, ?it/s]

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


Exception CUDA out of memory. Tried to allocate 48.00 MiB (GPU 0; 16.00 GiB total capacity; 14.15 GiB already allocated; 0 bytes free; 14.30 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF


In [16]:
# Ask PyTorch's caching allocator to release its unused cached GPU blocks.
# NOTE(review): presumably run by hand after the CUDA out-of-memory failure of
# the joint training cell above; it cannot free tensors that are still referenced.
torch.cuda.empty_cache()

In [None]:
from tqdm.auto import tqdm
import copy
import traceback

# Joint fine-tuning variant: both heads are updated every step, but only the
# IFTA learning-rate scheduler is advanced and only IFTA is evaluated.
progress_bar = tqdm(range(num_training_steps))

# CrossEntropyLoss carries no per-step state, so one instance serves both tasks.
loss_fct = torch.nn.CrossEntropyLoss().to(device)

step = 0
try:
    for epoch in range(num_epochs):
        # zip() stops at the shorter loader, so an epoch is bounded by it.
        for batch_isrej, batch_ifta in zip(isrej_train_loader, ifta_train_loader):
            # Re-enter training mode every step
            # (presumably get_pred() below switches modules to eval mode — confirm).
            model_isrej.train()
            model_ifta.train()
            base_kidneyBert.train()
            step += 1

            # ---- rejection task (2 classes) ----
            batch_isrej = {k: v.to(device) for k, v in batch_isrej.items()}
            labels = batch_isrej.pop("labels")
            feat = base_kidneyBert(**batch_isrej)
            outputs = model_isrej(feat)
            loss_isrej = loss_fct(outputs.view(-1, 2), labels.view(-1))
            loss_isrej.backward()

            optimizer_isrej.step()
            # NOTE(review): lr_scheduler_isrej.step() was disabled in this cell
            # and is left disabled here — confirm that this is intended.
            optimizer_isrej.zero_grad()

            # Keep a graph-free copy of the real loss for the log line. It used
            # to be overwritten with 0 before printing, so the log always showed
            # 0.0000 for a task that was in fact being trained.
            loss_isrej = loss_isrej.detach()

            # ---- IFTA task (4 classes) ----
            batch_ifta = {k: v.to(device) for k, v in batch_ifta.items()}
            labels = batch_ifta.pop("labels")
            feat = base_kidneyBert(**batch_ifta)
            outputs = model_ifta(feat)
            loss_ifta = loss_fct(outputs.view(-1, 4), labels.view(-1))
            loss_ifta.backward()

            optimizer_ifta.step()
            lr_scheduler_ifta.step()
            optimizer_ifta.zero_grad()
            progress_bar.update(1)

            if step % 300 == 0:
                # Periodic evaluation; only the IFTA head is scored in this cell.
                print('STEP:{}, EPOCHS : {}/{}'.format(step, epoch + 1, num_epochs),
                      'Loss : {:.4f},{:.4f}'.format(loss_isrej, loss_ifta))

                compute_metrics(get_pred(model_ifta, ifta_test_loader))
except Exception as e:
    # Keep the notebook running, but do not lose where the failure happened.
    print("Exception", e)
    traceback.print_exc()

  0%|          | 0/6850 [00:00<?, ?it/s]

  


In [22]:
from tqdm.auto import tqdm
import copy

# IFTA-only run without checkpointing: the rejection branch is switched off,
# so only model_ifta (on top of base_kidneyBert) is optimised in this cell.
progress_bar = tqdm(range(num_training_steps))

step = 0
try:
    for epoch in range(num_epochs):
        # batch_isrej is drawn but not used; zip() still bounds each epoch by
        # the shorter of the two loaders.
        for batch_isrej, batch_ifta in zip(isrej_train_loader, ifta_train_loader):
            for module in (model_isrej, model_ifta, base_kidneyBert):
                module.train()
            step += 1

            # Move the batch to the device and split the targets off the inputs.
            batch_ifta = {name: tensor.to(device) for name, tensor in batch_ifta.items()}
            targets = batch_ifta.pop("labels")
            logits = model_ifta(base_kidneyBert(**batch_ifta))

            loss_fct = torch.nn.CrossEntropyLoss().to(device)
            loss_ifta = loss_fct(logits.view(-1, 4), targets.view(-1))
            loss_ifta.backward()

            optimizer_ifta.step()
            lr_scheduler_ifta.step()
            optimizer_ifta.zero_grad()
            progress_bar.update(1)

            # Placeholder for the untrained rejection task in the log line.
            loss_isrej = 0
            if step % 300 != 0:
                continue

            # Every 300 steps: log the losses and score the IFTA head.
            print('STEP:{}, EPOCHS : {}/{}'.format(step, epoch + 1, num_epochs),
                  'Loss : {:.4f},{:.4f}'.format(loss_isrej, loss_ifta))
            compute_metrics(get_pred(model_ifta, ifta_test_loader))
except Exception as e:
    print("Exception", e)

  0%|          | 0/6850 [00:00<?, ?it/s]

  


STEP:300, EPOCHS : 1/10 Loss : 0.0000,0.5015
accuracy: 0.7167883211678832, precision: 0.7167883211678832, recall: 0.7167883211678832, f1: 0.7167883211678832


  


STEP:600, EPOCHS : 1/10 Loss : 0.0000,0.6111
accuracy: 0.6875912408759124, precision: 0.6875912408759124, recall: 0.6875912408759124, f1: 0.6875912408759124


  


STEP:900, EPOCHS : 2/10 Loss : 0.0000,1.2040
accuracy: 0.7299270072992701, precision: 0.7299270072992701, recall: 0.7299270072992701, f1: 0.72992700729927


  


STEP:1200, EPOCHS : 2/10 Loss : 0.0000,0.4732
accuracy: 0.7094890510948905, precision: 0.7094890510948905, recall: 0.7094890510948905, f1: 0.7094890510948906


  


STEP:1500, EPOCHS : 3/10 Loss : 0.0000,0.6230
accuracy: 0.7124087591240876, precision: 0.7124087591240876, recall: 0.7124087591240876, f1: 0.7124087591240876


  


STEP:1800, EPOCHS : 3/10 Loss : 0.0000,0.6581
accuracy: 0.7299270072992701, precision: 0.7299270072992701, recall: 0.7299270072992701, f1: 0.72992700729927


  


STEP:2100, EPOCHS : 4/10 Loss : 0.0000,0.9544
accuracy: 0.7197080291970803, precision: 0.7197080291970803, recall: 0.7197080291970803, f1: 0.7197080291970803


  


STEP:2400, EPOCHS : 4/10 Loss : 0.0000,0.6450
accuracy: 0.6963503649635037, precision: 0.6963503649635037, recall: 0.6963503649635037, f1: 0.6963503649635037


  


STEP:2700, EPOCHS : 4/10 Loss : 0.0000,0.0922
accuracy: 0.7401459854014598, precision: 0.7401459854014598, recall: 0.7401459854014598, f1: 0.7401459854014598


  


STEP:3000, EPOCHS : 5/10 Loss : 0.0000,0.6349
accuracy: 0.7445255474452555, precision: 0.7445255474452555, recall: 0.7445255474452555, f1: 0.7445255474452555


  


STEP:3300, EPOCHS : 5/10 Loss : 0.0000,0.6144
accuracy: 0.7401459854014598, precision: 0.7401459854014598, recall: 0.7401459854014598, f1: 0.7401459854014598


  


STEP:3600, EPOCHS : 6/10 Loss : 0.0000,0.2996
accuracy: 0.7065693430656934, precision: 0.7065693430656934, recall: 0.7065693430656934, f1: 0.7065693430656934


  


STEP:3900, EPOCHS : 6/10 Loss : 0.0000,0.2006
accuracy: 0.7357664233576642, precision: 0.7357664233576642, recall: 0.7357664233576642, f1: 0.7357664233576642


  


STEP:4200, EPOCHS : 7/10 Loss : 0.0000,0.4186
accuracy: 0.7313868613138687, precision: 0.7313868613138687, recall: 0.7313868613138687, f1: 0.7313868613138687


  


STEP:4500, EPOCHS : 7/10 Loss : 0.0000,0.7996
accuracy: 0.7474452554744525, precision: 0.7474452554744525, recall: 0.7474452554744525, f1: 0.7474452554744525


  


STEP:4800, EPOCHS : 8/10 Loss : 0.0000,0.3397
accuracy: 0.7328467153284671, precision: 0.7328467153284671, recall: 0.7328467153284671, f1: 0.7328467153284671


  


STEP:5100, EPOCHS : 8/10 Loss : 0.0000,0.9969
accuracy: 0.7445255474452555, precision: 0.7445255474452555, recall: 0.7445255474452555, f1: 0.7445255474452555


  


STEP:5400, EPOCHS : 8/10 Loss : 0.0000,0.1258
accuracy: 0.7635036496350365, precision: 0.7635036496350365, recall: 0.7635036496350365, f1: 0.7635036496350364


  


STEP:5700, EPOCHS : 9/10 Loss : 0.0000,0.1847
accuracy: 0.7518248175182481, precision: 0.7518248175182481, recall: 0.7518248175182481, f1: 0.7518248175182483


  


STEP:6000, EPOCHS : 9/10 Loss : 0.0000,0.0526
accuracy: 0.743065693430657, precision: 0.743065693430657, recall: 0.743065693430657, f1: 0.743065693430657


  


STEP:6300, EPOCHS : 10/10 Loss : 0.0000,0.1960
accuracy: 0.743065693430657, precision: 0.743065693430657, recall: 0.743065693430657, f1: 0.743065693430657


  


STEP:6600, EPOCHS : 10/10 Loss : 0.0000,0.0521
accuracy: 0.743065693430657, precision: 0.743065693430657, recall: 0.743065693430657, f1: 0.743065693430657


  


## isRejection one task

In [27]:
# Collect predictions of model_isrej on the rejection test loader, then print
# the confusion matrix followed by the per-class report.
pred_labels, test_labels = get_pred(model_isrej, isrej_test_loader)
print(
    confusion_matrix(test_labels, pred_labels),
    classification_report(test_labels, pred_labels),
    sep="\n",
)

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


[[641   6]
 [ 15  23]]
              precision    recall  f1-score   support

           0       0.98      0.99      0.98       647
           1       0.79      0.61      0.69        38

    accuracy                           0.97       685
   macro avg       0.89      0.80      0.84       685
weighted avg       0.97      0.97      0.97       685



## isRejection multi tasks

In [26]:
# Collect predictions of model_isrej on the rejection test loader, then print
# the confusion matrix followed by the per-class report.
pred_labels, test_labels = get_pred(model_isrej, isrej_test_loader)
print(
    confusion_matrix(test_labels, pred_labels),
    classification_report(test_labels, pred_labels),
    sep="\n",
)

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


[[638   9]
 [ 21  17]]
              precision    recall  f1-score   support

           0       0.97      0.99      0.98       647
           1       0.65      0.45      0.53        38

    accuracy                           0.96       685
   macro avg       0.81      0.72      0.75       685
weighted avg       0.95      0.96      0.95       685



## IFTA one task

In [16]:
# Collect predictions of model_ifta on the IFTA test loader, then print the
# confusion matrix followed by the per-class report.
pred_labels, test_labels = get_pred(model_ifta, ifta_test_loader)
print(
    confusion_matrix(test_labels, pred_labels),
    classification_report(test_labels, pred_labels),
    sep="\n",
)

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


[[320  42   3   2]
 [ 49 144  30   1]
 [  2  29  40   1]
 [  2   2  14   4]]
              precision    recall  f1-score   support

           0       0.86      0.87      0.86       367
           1       0.66      0.64      0.65       224
           2       0.46      0.56      0.50        72
           3       0.50      0.18      0.27        22

    accuracy                           0.74       685
   macro avg       0.62      0.56      0.57       685
weighted avg       0.74      0.74      0.74       685



## IFTA multi tasks

In [27]:
# Collect predictions of model_ifta on the IFTA test loader, then print the
# confusion matrix followed by the per-class report.
pred_labels, test_labels = get_pred(model_ifta, ifta_test_loader)
print(
    confusion_matrix(test_labels, pred_labels),
    classification_report(test_labels, pred_labels),
    sep="\n",
)

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


[[324  43   0   0]
 [ 37 164  23   0]
 [  3  31  38   0]
 [  3   4  12   3]]
              precision    recall  f1-score   support

           0       0.88      0.88      0.88       367
           1       0.68      0.73      0.70       224
           2       0.52      0.53      0.52        72
           3       1.00      0.14      0.24        22

    accuracy                           0.77       685
   macro avg       0.77      0.57      0.59       685
weighted avg       0.78      0.77      0.77       685

