In [1]:
from transformers import AutoModelWithLMHead,BertForSequenceClassification, AutoTokenizer, AutoModel,AutoModelForMaskedLM,AutoModelForSequenceClassification
import torch
from torch import nn
import json
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split,StratifiedShuffleSplit
from torch.utils.data import DataLoader,TensorDataset
from transformers import Trainer, TrainingArguments
import pickle
from sklearn.metrics import confusion_matrix,classification_report
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score,roc_curve
import pandas as pd
import matplotlib.pyplot as plt

from transformers import AdamW,get_scheduler


In [2]:
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
base_kidneyBert = AutoModel.from_pretrained("./mlm_results_largeData/checkpoint-1100")

Some weights of the model checkpoint at ./mlm_results_largeData/checkpoint-1100 were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ./mlm_results_largeData/checkpoint-1100 and are newly initialized: ['bert.pooler.dens

In [3]:
base_kidneyBert.to("cuda")

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [4]:
class RejClassifier(nn.Module):
    def __init__(self):
        super(RejClassifier, self).__init__()
        
#         self.base_model = base_model
        self.dropout = nn.Dropout()
        self.linear = nn.Linear(768, 256) # output features from bert is 768 and 2 is ur number of labels
        self.linear2 = nn.Linear(256, 2)
        self.relu = nn.ReLU()
        
    def forward(self, outputs):
        #input_ids, attention_mask = inputs["input_ids"],inputs["attention_mask"]
        #print("a")
        #outputs = self.base_model(input_ids, attention_mask=attention_mask)
#         del inputs["labels"]
#         outputs = self.base_model(**inputs)
        #print("b")
        
        #print("c",outputs,outputs.shape)
        outputs = self.linear(outputs[1])
        outputs = self.dropout(self.relu(outputs))
        outputs = self.linear2(outputs)
        #print("d",outputs,outputs.shape)
        
        return outputs
    
class IFTAClassifier(nn.Module):
    def __init__(self):
        super(IFTAClassifier, self).__init__()
        
#         self.base_model = base_model
        self.dropout = nn.Dropout()
        self.linear = nn.Linear(768, 256) # output features from bert is 768 and 2 is ur number of labels
        self.linear2 = nn.Linear(256, 4)
        self.relu = nn.ReLU()
        
    def forward(self, outputs):
        #input_ids, attention_mask = inputs["input_ids"],inputs["attention_mask"]
#         del inputs["labels"]
#         outputs = self.base_model(**inputs)
        #print("e",outputs,outputs.shape)
        #outputs = self.base_model(input_ids, attention_mask=attention_mask)
#         outputs = self.dropout(outputs[1])
        #print("f",outputs,outputs.shape)
#         outputs = self.linear(outputs)
        #print("g",outputs,outputs.shape)
        outputs = self.linear(outputs[1])
        outputs = self.dropout(self.relu(outputs))
        outputs = self.linear2(outputs)
        
        return outputs

In [5]:
class RenalDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels,task_name=None):
        self.encodings = encodings
        self.labels = labels
        self.task_name = task_name

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        #item["task_name"] = self.task_name
        return item

    def __len__(self):
        return len(self.labels)
def compute_metrics(p):    
    pred, labels = p
    #pred = np.argmax(pred, axis=1)
    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred,average="micro")
    precision = precision_score(y_true=labels, y_pred=pred,average="micro")
    f1 = f1_score(y_true=labels, y_pred=pred,average="micro")
    print("accuracy: {}, precision: {}, recall: {}, f1: {}".format(accuracy,precision,recall,f1))
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1} 

In [6]:
def gen_datasets(q,train_text,test_text,tokenizer=tokenizer):
    train_q = [q for i in range(len(train_text))]
    test_q = [q for i in range(len(test_text))]

    train_encodings = tokenizer(train_q,train_text,padding="max_length", truncation=True, 
                                return_tensors="pt",max_length=512)
    test_encodings = tokenizer(test_q,test_text,padding="max_length", truncation=True, 
                                return_tensors="pt",max_length=512)
    train_dataset = RenalDataset(train_encodings, train_labels)
    test_dataset = RenalDataset(test_encodings, test_labels)
    return train_dataset,test_dataset

In [7]:
batch_size = 14

In [8]:
data = pd.read_csv("data.csv")
inputs1 = data["train_rej"].tolist()
label1 = data["isRejection"].tolist()
label = [l for i,l in zip(inputs1,label1) if str(i)!="nan"]
inputs = [i for i in inputs1 if str(i)!="nan"]
train_text, test_text, train_labels, test_labels = train_test_split(
    inputs, label,random_state = 1,stratify=label,test_size=0.2)
# train_encodings = tokenizer(train_text,padding="max_length", truncation=True, 
#                             return_tensors="pt",max_length=512)
# test_encodings = tokenizer(test_text,padding="max_length", truncation=True, 
#                             return_tensors="pt",max_length=512)
# train_dataset = RenalDataset(train_encodings, train_labels,task_name="isrej")
# test_dataset = RenalDataset(test_encodings, test_labels,task_name="isrej")
q_rej = "Is there any rejection?"
train_dataset,test_dataset = gen_datasets(q_rej,train_text,test_text)
isrej_train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = batch_size)
isrej_test_loader = torch.utils.data.DataLoader(test_dataset,batch_size = batch_size)

In [9]:
data = pd.read_csv("data.csv")
inputs1 = data["train_ifta"].tolist()
label1 = data["IFTA"].tolist()
label2 = [l for i,l in zip(inputs1,label1) if str(i)!="nan"]
label = [0 if l in ["nosig","minimal","noinfo"] else (1 if l=="mild" else (2 if l=="moderate" else 3)) for l in label2]
inputs = [i for i in inputs1 if str(i)!="nan"]
train_text, test_text, train_labels, test_labels = train_test_split(
    inputs, label,random_state = 1,stratify=label,test_size=0.2)
# train_encodings = tokenizer(train_text,padding="max_length", truncation=True, 
#                             return_tensors="pt",max_length=512)
# test_encodings = tokenizer(test_text,padding="max_length", truncation=True, 
#                             return_tensors="pt",max_length=512)
# train_dataset = RenalDataset(train_encodings, train_labels,task_name = "ifta")
# test_dataset = RenalDataset(test_encodings, test_labels,task_name = "ifta")
q_ifta = "What is the grade of interstitial fibrosis and tubular atrophy?"
train_dataset,test_dataset = gen_datasets(q_ifta,train_text,test_text)
ifta_train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = batch_size)
ifta_test_loader = torch.utils.data.DataLoader(test_dataset,batch_size = batch_size)



In [10]:
model_isrej = RejClassifier()
model_ifta = IFTAClassifier()

In [11]:
optimizer_isrej = AdamW(list(base_kidneyBert.parameters())+list(model_isrej.parameters()), lr=5e-5)
optimizer_ifta = AdamW(list(base_kidneyBert.parameters())+list(model_ifta.parameters()), lr=5e-5)
num_epochs = 20
num_training_steps = num_epochs * min(len(isrej_train_loader),len(ifta_train_loader))
lr_scheduler_isrej = get_scheduler("linear", optimizer=optimizer_isrej, num_warmup_steps=30, num_training_steps=num_training_steps)
lr_scheduler_ifta = get_scheduler("linear", optimizer=optimizer_ifta, num_warmup_steps=30, num_training_steps=num_training_steps)






In [12]:
device = "cuda"
model_isrej.to(device)
model_ifta.to(device)

IFTAClassifier(
  (dropout): Dropout(p=0.5, inplace=False)
  (linear): Linear(in_features=768, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=4, bias=True)
  (relu): ReLU()
)

In [13]:
# model_renal = AutoModelForSequenceClassification.from_pretrained("./mlm_results_largeData/checkpoint-1100",num_labels=4)
# model_renal

In [15]:
def get_pred(model,dataloader):
    pred,labels = [],[]
    model.eval() 
    base_kidneyBert.eval()
    with torch.no_grad(): 
        for batch in dataloader:
           
            batch = {k: v.to(device) for k, v in batch.items()}
            cur_labels = batch["labels"].to("cpu").flatten().tolist()
            del batch["labels"]
            outputs = model(base_kidneyBert(**batch))
            predicted = torch.argmax(outputs, axis=1).to("cpu").flatten().tolist() #torch.max(outputs.data,1) 
            pred+=predicted
            labels+=cur_labels
            
            
#             if pred == 0:
#                 pred = predicted
#                 labels = cur_labels
#             else:
#                 pred = torch.cat((pred,predicted),-1)
#                 labels = torch.cat((labels,cur_labels),-1)
#             total += labels.size(0)
            
#             correct += (predicted == labels).sum().item()
            
#     print('the accuracy is {:.4f}'.format(correct/total))
    return pred,labels

In [16]:
from tqdm.auto import tqdm
import copy

progress_bar = tqdm(range(num_training_steps))

best_f1_isrej = 0
best_f1_ifta = 0
step = 0
try:
    for epoch in range(num_epochs):
        for batch_isrej,batch_ifta in zip(isrej_train_loader,ifta_train_loader):
            model_isrej.train()
            model_ifta.train()
            base_kidneyBert.train()
            step+=1
            #print(step)
            #print(1)
            
            
            batch_isrej = {k: v.to(device) for k, v in batch_isrej.items()}
            labels = batch_isrej["labels"]
#             labels = copy.deepcopy(batch_isrej["labels"]).to(device)
            #print(2)
            del batch_isrej["labels"]
            feat = base_kidneyBert(**batch_isrej)
#             print(feat[0].shape,feat[1].shape)
            outputs = model_isrej(feat)
            #print(3)
            #print(4)
            loss_fct = torch.nn.CrossEntropyLoss().to(device)
            #print(outputs,outputs.shape)
            #print(labels,labels.shape)
            loss_isrej = loss_fct(outputs.view(-1, 2), labels.view(-1))
            #print(6)
            loss_isrej.backward()
            #print(7)

            optimizer_isrej.step()
            lr_scheduler_isrej.step()
            optimizer_isrej.zero_grad()
#             for k, v in batch_isrej.items():
#                 v.detach()
# #                 del v
#             batch_isrej.clear()
                
#             batch_isrej = 0
            
#             torch.cuda.empty_cache()
            
            

            batch_ifta = {k: v.to(device) for k, v in batch_ifta.items()}
            labels = batch_ifta["labels"]
            del batch_ifta["labels"]
            feat = base_kidneyBert(**batch_ifta)
            outputs = model_ifta(feat)
#             outputs = model_ifta(batch_ifta)
            #print(10)
            
            #print(11)
            loss_fct = torch.nn.CrossEntropyLoss().to(device)
            #print(12)
            #print("q",outputs,outputs.shape)
            #print("z",labels,labels.shape)
            loss_ifta = loss_fct(outputs.view(-1, 4), labels.view(-1))
            loss_ifta.backward()

            optimizer_ifta.step()
            lr_scheduler_ifta.step()
            optimizer_ifta.zero_grad()
            progress_bar.update(1)
    
            if step % 300 == 0:
                #test the accuracy

                print('STEP:{}, EPOCHS : {}/{}'.format(step,epoch+1,num_epochs),
                      'Loss : {:.4f},{:.4f}'.format(loss_isrej,loss_ifta))
                
                res_isrej = compute_metrics(get_pred(model_isrej,isrej_test_loader))
                res_ifta = compute_metrics(get_pred(model_ifta,ifta_test_loader))
                if res_isrej["f1"] > best_f1_isrej and res_ifta["f1"] > best_f1_ifta:
                    best_f1_isrej = res_isrej["f1"]
                    best_f1_ifta = res_ifta["f1"]
                    base_kidneyBert.save_pretrained(f"./fine_both/f1_{best_f1_isrej}_{best_f1_ifta}")
                    torch.save(model_isrej_.state_dict(),f"./fine_both/model_isrej_{best_f1_isrej}.pth")
                    torch.save(model_ifta.state_dict(),f"./fine_both/model_ifta_{best_f1_ifta}.pth")
#                 compute_metrics(get_pred(model_ifta,ifta_test_loader))
except Exception as e:
    print("Exception",e)

  0%|          | 0/3920 [00:00<?, ?it/s]

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:300, EPOCHS : 2/20 Loss : 0.1552,0.5154
accuracy: 0.9445255474452555, precision: 0.9445255474452555, recall: 0.9445255474452555, f1: 0.9445255474452555


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


accuracy: 0.710948905109489, precision: 0.710948905109489, recall: 0.710948905109489, f1: 0.7109489051094892


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:600, EPOCHS : 4/20 Loss : 0.1199,0.9246
accuracy: 0.962043795620438, precision: 0.962043795620438, recall: 0.962043795620438, f1: 0.962043795620438


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


accuracy: 0.7445255474452555, precision: 0.7445255474452555, recall: 0.7445255474452555, f1: 0.7445255474452555


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:900, EPOCHS : 5/20 Loss : 0.0424,0.3793
accuracy: 0.9518248175182482, precision: 0.9518248175182482, recall: 0.9518248175182482, f1: 0.9518248175182482


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


accuracy: 0.710948905109489, precision: 0.710948905109489, recall: 0.710948905109489, f1: 0.7109489051094892


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:1200, EPOCHS : 7/20 Loss : 0.0124,0.5633
accuracy: 0.9635036496350365, precision: 0.9635036496350365, recall: 0.9635036496350365, f1: 0.9635036496350365


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


accuracy: 0.7489051094890511, precision: 0.7489051094890511, recall: 0.7489051094890511, f1: 0.7489051094890511


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:1500, EPOCHS : 8/20 Loss : 0.0013,0.4945
accuracy: 0.9576642335766423, precision: 0.9576642335766423, recall: 0.9576642335766423, f1: 0.9576642335766423


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


accuracy: 0.7445255474452555, precision: 0.7445255474452555, recall: 0.7445255474452555, f1: 0.7445255474452555


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:1800, EPOCHS : 10/20 Loss : 0.0488,0.4588
accuracy: 0.9445255474452555, precision: 0.9445255474452555, recall: 0.9445255474452555, f1: 0.9445255474452555


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


accuracy: 0.708029197080292, precision: 0.708029197080292, recall: 0.708029197080292, f1: 0.708029197080292


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:2100, EPOCHS : 11/20 Loss : 0.0332,0.3511
accuracy: 0.9605839416058394, precision: 0.9605839416058394, recall: 0.9605839416058394, f1: 0.9605839416058394


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


accuracy: 0.7635036496350365, precision: 0.7635036496350365, recall: 0.7635036496350365, f1: 0.7635036496350364


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:2400, EPOCHS : 13/20 Loss : 0.0017,0.3229
accuracy: 0.9547445255474453, precision: 0.9547445255474453, recall: 0.9547445255474453, f1: 0.9547445255474453


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


accuracy: 0.7547445255474453, precision: 0.7547445255474453, recall: 0.7547445255474453, f1: 0.7547445255474453


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:2700, EPOCHS : 14/20 Loss : 0.0013,0.1532
accuracy: 0.9518248175182482, precision: 0.9518248175182482, recall: 0.9518248175182482, f1: 0.9518248175182482


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


accuracy: 0.7211678832116788, precision: 0.7211678832116788, recall: 0.7211678832116788, f1: 0.7211678832116789


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:3000, EPOCHS : 16/20 Loss : 0.0007,0.4339
accuracy: 0.9518248175182482, precision: 0.9518248175182482, recall: 0.9518248175182482, f1: 0.9518248175182482


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


accuracy: 0.7635036496350365, precision: 0.7635036496350365, recall: 0.7635036496350365, f1: 0.7635036496350364


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:3300, EPOCHS : 17/20 Loss : 0.0012,0.0808
accuracy: 0.945985401459854, precision: 0.945985401459854, recall: 0.945985401459854, f1: 0.945985401459854


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


accuracy: 0.7532846715328467, precision: 0.7532846715328467, recall: 0.7532846715328467, f1: 0.7532846715328468


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:3600, EPOCHS : 19/20 Loss : 0.0002,0.0347
accuracy: 0.9576642335766423, precision: 0.9576642335766423, recall: 0.9576642335766423, f1: 0.9576642335766423


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


accuracy: 0.7766423357664234, precision: 0.7766423357664234, recall: 0.7766423357664234, f1: 0.7766423357664234


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:3900, EPOCHS : 20/20 Loss : 0.0002,0.0380
accuracy: 0.9562043795620438, precision: 0.9562043795620438, recall: 0.9562043795620438, f1: 0.9562043795620438


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


accuracy: 0.7737226277372263, precision: 0.7737226277372263, recall: 0.7737226277372263, f1: 0.7737226277372263


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


In [17]:
base_kidneyBert.save_pretrained(f"./fine_both/final")
torch.save(model_isrej.state_dict(),f"./fine_both/model_isrej_final.pth")
torch.save(model_ifta.state_dict(),f"./fine_both/model_ifta_final.pth")

In [25]:
model_isrej.load_state_dict(torch.load("./fine_both/model_isrej_final.pth"))

<All keys matched successfully>

In [18]:
pred_labels,test_labels = get_pred(model_isrej,isrej_test_loader)
print(confusion_matrix(test_labels,pred_labels))
print(classification_report(test_labels,pred_labels))
pred_labels,test_labels = get_pred(model_ifta,ifta_test_loader)
print(confusion_matrix(test_labels,pred_labels))
print(classification_report(test_labels,pred_labels))

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


[[638   9]
 [ 21  17]]
              precision    recall  f1-score   support

           0       0.97      0.99      0.98       647
           1       0.65      0.45      0.53        38

    accuracy                           0.96       685
   macro avg       0.81      0.72      0.75       685
weighted avg       0.95      0.96      0.95       685



  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


[[324  43   0   0]
 [ 37 164  23   0]
 [  3  31  38   0]
 [  3   4  12   3]]
              precision    recall  f1-score   support

           0       0.88      0.88      0.88       367
           1       0.68      0.73      0.70       224
           2       0.52      0.53      0.52        72
           3       1.00      0.14      0.24        22

    accuracy                           0.77       685
   macro avg       0.77      0.57      0.59       685
weighted avg       0.78      0.77      0.77       685



In [26]:
from tqdm.auto import tqdm
import copy

progress_bar = tqdm(range(num_training_steps))

best_f1 = 0
step = 0
try:
    for epoch in range(num_epochs):
        for batch_isrej,batch_ifta in zip(isrej_train_loader,ifta_train_loader):
            model_isrej.train()
            model_ifta.train()
            base_kidneyBert.train()
            step+=1
            #print(step)
            #print(1)
            
            
            batch_isrej = {k: v.to(device) for k, v in batch_isrej.items()}
            labels = batch_isrej["labels"]
#             labels = copy.deepcopy(batch_isrej["labels"]).to(device)
            #print(2)
            del batch_isrej["labels"]
            feat = base_kidneyBert(**batch_isrej)
#             print(feat[0].shape,feat[1].shape)
            outputs = model_isrej(feat)
            #print(3)
            #print(4)
            loss_fct = torch.nn.CrossEntropyLoss().to(device)
            #print(outputs,outputs.shape)
            #print(labels,labels.shape)
            loss_isrej = loss_fct(outputs.view(-1, 2), labels.view(-1))
            #print(6)
            loss_isrej.backward()
            #print(7)

            optimizer_isrej.step()
            lr_scheduler_isrej.step()
            optimizer_isrej.zero_grad()

#             batch_ifta = {k: v.to(device) for k, v in batch_ifta.items()}
#             labels = batch_ifta["labels"]
#             del batch_ifta["labels"]
#             feat = base_kidneyBert(**batch_ifta)
#             outputs = model_ifta(feat)
# #             outputs = model_ifta(batch_ifta)
#             #print(10)
            
#             #print(11)
#             loss_fct = torch.nn.CrossEntropyLoss().to(device)
#             #print(12)
#             #print("q",outputs,outputs.shape)
#             #print("z",labels,labels.shape)
#             loss_ifta = loss_fct(outputs.view(-1, 4), labels.view(-1))
#             loss_ifta.backward()

#             optimizer_ifta.step()
#             lr_scheduler_ifta.step()
#             optimizer_ifta.zero_grad()
            progress_bar.update(1)
    
            loss_ifta = 0
            if step % 300 == 0:
                #test the accuracy

                print('STEP:{}, EPOCHS : {}/{}'.format(step,epoch+1,num_epochs),
                      'Loss : {:.4f},{:.4f}'.format(loss_isrej,loss_ifta))

                res = compute_metrics(get_pred(model_isrej,isrej_test_loader))
                if res["f1"] > best_f1:
                    best_f1 = res["f1"]
                    base_kidneyBert.save_pretrained(f"./fine_rej/f1{best_f1}")
                    torch.save(model_isrej.state_dict(),f"./fine_rej/model_isrej_{best_f1}.pth")
#                 compute_metrics(get_pred(model_ifta,ifta_test_loader))
except Exception as e:
    print("Exception",e)

  0%|          | 0/3440 [00:00<?, ?it/s]

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:300, EPOCHS : 2/20 Loss : 0.2560,0.0000
accuracy: 0.9445255474452555, precision: 0.9445255474452555, recall: 0.9445255474452555, f1: 0.9445255474452555


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:600, EPOCHS : 4/20 Loss : 0.2050,0.0000
accuracy: 0.9562043795620438, precision: 0.9562043795620438, recall: 0.9562043795620438, f1: 0.9562043795620438


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:900, EPOCHS : 6/20 Loss : 0.0197,0.0000
accuracy: 0.9503649635036496, precision: 0.9503649635036496, recall: 0.9503649635036496, f1: 0.9503649635036496


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:1200, EPOCHS : 7/20 Loss : 0.0123,0.0000
accuracy: 0.9635036496350365, precision: 0.9635036496350365, recall: 0.9635036496350365, f1: 0.9635036496350365


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:1500, EPOCHS : 9/20 Loss : 0.0687,0.0000
accuracy: 0.9547445255474453, precision: 0.9547445255474453, recall: 0.9547445255474453, f1: 0.9547445255474453


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:1800, EPOCHS : 11/20 Loss : 0.0141,0.0000
accuracy: 0.9635036496350365, precision: 0.9635036496350365, recall: 0.9635036496350365, f1: 0.9635036496350365


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:2100, EPOCHS : 13/20 Loss : 0.0038,0.0000
accuracy: 0.964963503649635, precision: 0.964963503649635, recall: 0.964963503649635, f1: 0.964963503649635


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:2400, EPOCHS : 14/20 Loss : 0.0004,0.0000
accuracy: 0.9664233576642336, precision: 0.9664233576642336, recall: 0.9664233576642336, f1: 0.9664233576642337


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:2700, EPOCHS : 16/20 Loss : 0.0024,0.0000
accuracy: 0.9664233576642336, precision: 0.9664233576642336, recall: 0.9664233576642336, f1: 0.9664233576642337


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:3000, EPOCHS : 18/20 Loss : 0.0005,0.0000
accuracy: 0.9693430656934306, precision: 0.9693430656934306, recall: 0.9693430656934306, f1: 0.9693430656934306


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:3300, EPOCHS : 20/20 Loss : 0.0001,0.0000
accuracy: 0.9693430656934306, precision: 0.9693430656934306, recall: 0.9693430656934306, f1: 0.9693430656934306


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


In [28]:
base_kidneyBert.save_pretrained(f"./fine_rej/final")
torch.save(model_isrej.state_dict(),f"./fine_rej/model_isrej_final.pth")

In [27]:
pred_labels,test_labels = get_pred(model_isrej,isrej_test_loader)
print(confusion_matrix(test_labels,pred_labels))
print(classification_report(test_labels,pred_labels))

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


[[641   6]
 [ 15  23]]
              precision    recall  f1-score   support

           0       0.98      0.99      0.98       647
           1       0.79      0.61      0.69        38

    accuracy                           0.97       685
   macro avg       0.89      0.80      0.84       685
weighted avg       0.97      0.97      0.97       685



In [14]:
from tqdm.auto import tqdm
import copy

progress_bar = tqdm(range(num_training_steps))

best_f1 = 0
step = 0
try:
    for epoch in range(num_epochs):
        for batch_isrej,batch_ifta in zip(isrej_train_loader,ifta_train_loader):
            model_isrej.train()
            model_ifta.train()
            base_kidneyBert.train()
            step+=1
            #print(step)
            #print(1)
            
            
#             batch_isrej = {k: v.to(device) for k, v in batch_isrej.items()}
#             labels = batch_isrej["labels"]
# #             labels = copy.deepcopy(batch_isrej["labels"]).to(device)
#             #print(2)
#             del batch_isrej["labels"]
#             feat = base_kidneyBert(**batch_isrej)
# #             print(feat[0].shape,feat[1].shape)
#             outputs = model_isrej(feat)
#             #print(3)
#             #print(4)
#             loss_fct = torch.nn.CrossEntropyLoss().to(device)
#             #print(outputs,outputs.shape)
#             #print(labels,labels.shape)
#             loss_isrej = loss_fct(outputs.view(-1, 2), labels.view(-1))
#             #print(6)
#             loss_isrej.backward()
#             #print(7)

#             optimizer_isrej.step()
#             lr_scheduler_isrej.step()
#             optimizer_isrej.zero_grad()

            batch_ifta = {k: v.to(device) for k, v in batch_ifta.items()}
            labels = batch_ifta["labels"]
            del batch_ifta["labels"]
            feat = base_kidneyBert(**batch_ifta)
            outputs = model_ifta(feat)
#             outputs = model_ifta(batch_ifta)
            #print(10)
            
            #print(11)
            loss_fct = torch.nn.CrossEntropyLoss().to(device)
            #print(12)
            #print("q",outputs,outputs.shape)
            #print("z",labels,labels.shape)
            loss_ifta = loss_fct(outputs.view(-1, 4), labels.view(-1))
            loss_ifta.backward()

            optimizer_ifta.step()
            lr_scheduler_ifta.step()
            optimizer_ifta.zero_grad()
            progress_bar.update(1)
    
            loss_isrej = 0
            if step % 300 == 0:
                #test the accuracy

                print('STEP:{}, EPOCHS : {}/{}'.format(step,epoch+1,num_epochs),
                      'Loss : {:.4f},{:.4f}'.format(loss_isrej,loss_ifta))

                res = compute_metrics(get_pred(model_ifta,ifta_test_loader))
                if res["f1"] > best_f1:
                    best_f1 = res["f1"]
                    base_kidneyBert.save_pretrained(f"./fine_ifta/f1{best_f1}")
                    torch.save(model_ifta.state_dict(),f"./fine_ifta/model_isrej_{best_f1}.pth")
#                 compute_metrics(get_pred(model_ifta,ifta_test_loader))
except Exception as e:
    print("Exception",e)

  0%|          | 0/3440 [00:00<?, ?it/s]

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:300, EPOCHS : 2/20 Loss : 0.0000,0.3606
accuracy: 0.7051094890510949, precision: 0.7051094890510949, recall: 0.7051094890510949, f1: 0.7051094890510949


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:600, EPOCHS : 4/20 Loss : 0.0000,0.4748
accuracy: 0.7284671532846715, precision: 0.7284671532846715, recall: 0.7284671532846715, f1: 0.7284671532846715


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:900, EPOCHS : 6/20 Loss : 0.0000,0.3987
accuracy: 0.7138686131386861, precision: 0.7138686131386861, recall: 0.7138686131386861, f1: 0.7138686131386861


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:1200, EPOCHS : 7/20 Loss : 0.0000,0.4292
accuracy: 0.7386861313868613, precision: 0.7386861313868613, recall: 0.7386861313868613, f1: 0.7386861313868613


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:1500, EPOCHS : 9/20 Loss : 0.0000,0.1700
accuracy: 0.7343065693430657, precision: 0.7343065693430657, recall: 0.7343065693430657, f1: 0.7343065693430656


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:1800, EPOCHS : 11/20 Loss : 0.0000,0.0980
accuracy: 0.7138686131386861, precision: 0.7138686131386861, recall: 0.7138686131386861, f1: 0.7138686131386861


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:2100, EPOCHS : 13/20 Loss : 0.0000,0.0224
accuracy: 0.7343065693430657, precision: 0.7343065693430657, recall: 0.7343065693430657, f1: 0.7343065693430656


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:2400, EPOCHS : 14/20 Loss : 0.0000,0.0817
accuracy: 0.7386861313868613, precision: 0.7386861313868613, recall: 0.7386861313868613, f1: 0.7386861313868613


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:2700, EPOCHS : 16/20 Loss : 0.0000,0.0653
accuracy: 0.7401459854014598, precision: 0.7401459854014598, recall: 0.7401459854014598, f1: 0.7401459854014598


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:3000, EPOCHS : 18/20 Loss : 0.0000,0.0056
accuracy: 0.7416058394160584, precision: 0.7416058394160584, recall: 0.7416058394160584, f1: 0.7416058394160584


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


STEP:3300, EPOCHS : 20/20 Loss : 0.0000,0.0151
accuracy: 0.7386861313868613, precision: 0.7386861313868613, recall: 0.7386861313868613, f1: 0.7386861313868613


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


In [15]:
base_kidneyBert.save_pretrained(f"./fine_ifta/final")
torch.save(model_ifta.state_dict(),f"./fine_ifta/model_ifta_final.pth")

In [16]:
pred_labels,test_labels = get_pred(model_ifta,ifta_test_loader)
print(confusion_matrix(test_labels,pred_labels))
print(classification_report(test_labels,pred_labels))

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


[[320  42   3   2]
 [ 49 144  30   1]
 [  2  29  40   1]
 [  2   2  14   4]]
              precision    recall  f1-score   support

           0       0.86      0.87      0.86       367
           1       0.66      0.64      0.65       224
           2       0.46      0.56      0.50        72
           3       0.50      0.18      0.27        22

    accuracy                           0.74       685
   macro avg       0.62      0.56      0.57       685
weighted avg       0.74      0.74      0.74       685



In [15]:
from tqdm.auto import tqdm
import copy

progress_bar = tqdm(range(num_training_steps))

best_f1_isrej = 0
best_f1_ifta = 0
step = 0
try:
    for epoch in range(num_epochs):
        for batch_isrej,batch_ifta in zip(isrej_train_loader,ifta_train_loader):
            model_isrej.train()
            model_ifta.train()
            base_kidneyBert.train()
            step+=1
            #print(step)
            #print(1)
            
            
            batch_isrej = {k: v.to(device) for k, v in batch_isrej.items()}
            labels = batch_isrej["labels"]
#             labels = copy.deepcopy(batch_isrej["labels"]).to(device)
            #print(2)
            del batch_isrej["labels"]
            feat = base_kidneyBert(**batch_isrej)
#             print(feat[0].shape,feat[1].shape)
            outputs = model_isrej(feat)
            #print(3)
            #print(4)
            loss_fct = torch.nn.CrossEntropyLoss().to(device)
            #print(outputs,outputs.shape)
            #print(labels,labels.shape)
            loss_isrej = loss_fct(outputs.view(-1, 2), labels.view(-1))
            #print(6)
            loss_isrej.backward()
            #print(7)

            optimizer_isrej.step()
            lr_scheduler_isrej.step()
            optimizer_isrej.zero_grad()
            for k, v in batch_isrej.items():
                del v
                
            batch_isrej = 0
            
            torch.cuda.empty_cache()
            
            

            batch_ifta = {k: v.to(device) for k, v in batch_ifta.items()}
            labels = batch_ifta["labels"]
            del batch_ifta["labels"]
            feat = base_kidneyBert(**batch_ifta)
            outputs = model_ifta(feat)
#             outputs = model_ifta(batch_ifta)
            #print(10)
            
            #print(11)
            loss_fct = torch.nn.CrossEntropyLoss().to(device)
            #print(12)
            #print("q",outputs,outputs.shape)
            #print("z",labels,labels.shape)
            loss_ifta = loss_fct(outputs.view(-1, 4), labels.view(-1))
            loss_ifta.backward()

            optimizer_ifta.step()
            lr_scheduler_ifta.step()
            optimizer_ifta.zero_grad()
            progress_bar.update(1)
    
            if step % 300 == 0:
                #test the accuracy

                print('STEP:{}, EPOCHS : {}/{}'.format(step,epoch+1,num_epochs),
                      'Loss : {:.4f},{:.4f}'.format(loss_isrej,loss_ifta))
                
                res_isrej = compute_metrics(get_pred(model_isrej,isrej_test_loader))
                res_ifta = compute_metrics(get_pred(model_ifta,ifta_test_loader))
                if res_isrej["f1"] > best_f1_isrej and res_ifta["f1"] > best_f1_ifta:
                    best_f1_isrej = res_isrej["f1"]
                    best_f1_ifta = res_ifta["f1"]
                    base_kidneyBert.save_pretrained(f"./fine_both/f1_{best_f1_isrej}_{best_f1_ifta}")
                    torch.save(model_ifta.state_dict(),f"./fine_both/model_isrej_{best_f1_isrej}.pth")
                    torch.save(model_ifta.state_dict(),f"./fine_both/model_ifta_{best_f1_ifta}.pth")
#                 compute_metrics(get_pred(model_ifta,ifta_test_loader))
except Exception as e:
    print("Exception",e)

  0%|          | 0/3440 [00:00<?, ?it/s]

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


Exception CUDA out of memory. Tried to allocate 48.00 MiB (GPU 0; 16.00 GiB total capacity; 14.15 GiB already allocated; 0 bytes free; 14.30 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF


In [16]:
torch.cuda.empty_cache()

In [None]:
from tqdm.auto import tqdm
import copy

progress_bar = tqdm(range(num_training_steps))


step = 0
try:
    for epoch in range(num_epochs):
        for batch_isrej,batch_ifta in zip(isrej_train_loader,ifta_train_loader):
            model_isrej.train()
            model_ifta.train()
            base_kidneyBert.train()
            step+=1
            #print(step)
            #print(1)
            
            
            batch_isrej = {k: v.to(device) for k, v in batch_isrej.items()}
            labels = batch_isrej["labels"]
#             labels = copy.deepcopy(batch_isrej["labels"]).to(device)
            #print(2)
            del batch_isrej["labels"]
            feat = base_kidneyBert(**batch_isrej)
#             print(feat[0].shape,feat[1].shape)
            outputs = model_isrej(feat)
            #print(3)
            #print(4)
            loss_fct = torch.nn.CrossEntropyLoss().to(device)
            #print(outputs,outputs.shape)
            #print(labels,labels.shape)
            loss_isrej = loss_fct(outputs.view(-1, 2), labels.view(-1))
            #print(6)
            loss_isrej.backward()
            #print(7)

            optimizer_isrej.step()
            #lr_scheduler_isrej.step()
            optimizer_isrej.zero_grad()

            batch_ifta = {k: v.to(device) for k, v in batch_ifta.items()}
            labels = batch_ifta["labels"]
            del batch_ifta["labels"]
            feat = base_kidneyBert(**batch_ifta)
            outputs = model_ifta(feat)
#             outputs = model_ifta(batch_ifta)
            #print(10)
            
            #print(11)
            loss_fct = torch.nn.CrossEntropyLoss().to(device)
            #print(12)
            #print("q",outputs,outputs.shape)
            #print("z",labels,labels.shape)
            loss_ifta = loss_fct(outputs.view(-1, 4), labels.view(-1))
            loss_ifta.backward()

            optimizer_ifta.step()
            lr_scheduler_ifta.step()
            optimizer_ifta.zero_grad()
            progress_bar.update(1)
    
            loss_isrej = 0
            if step % 300 == 0:
                #test the accuracy

                print('STEP:{}, EPOCHS : {}/{}'.format(step,epoch+1,num_epochs),
                      'Loss : {:.4f},{:.4f}'.format(loss_isrej,loss_ifta))

#                 compute_metrics(get_pred(model_isrej,isrej_test_loader))
                compute_metrics(get_pred(model_ifta,ifta_test_loader))
except Exception as e:
    print("Exception",e)

  0%|          | 0/6850 [00:00<?, ?it/s]

  


In [22]:
from tqdm.auto import tqdm
import copy

progress_bar = tqdm(range(num_training_steps))


step = 0
try:
    for epoch in range(num_epochs):
        for batch_isrej,batch_ifta in zip(isrej_train_loader,ifta_train_loader):
            model_isrej.train()
            model_ifta.train()
            base_kidneyBert.train()
            step+=1
            #print(step)
            #print(1)
            
            
#             batch_isrej = {k: v.to(device) for k, v in batch_isrej.items()}
#             labels = batch_isrej["labels"]
# #             labels = copy.deepcopy(batch_isrej["labels"]).to(device)
#             #print(2)
#             del batch_isrej["labels"]
#             feat = base_kidneyBert(**batch_isrej)
# #             print(feat[0].shape,feat[1].shape)
#             outputs = model_isrej(feat)
#             #print(3)
#             #print(4)
#             loss_fct = torch.nn.CrossEntropyLoss().to(device)
#             #print(outputs,outputs.shape)
#             #print(labels,labels.shape)
#             loss_isrej = loss_fct(outputs.view(-1, 2), labels.view(-1))
#             #print(6)
#             loss_isrej.backward()
#             #print(7)

#             optimizer_isrej.step()
#             #lr_scheduler_isrej.step()
#             optimizer_isrej.zero_grad()

            batch_ifta = {k: v.to(device) for k, v in batch_ifta.items()}
            labels = batch_ifta["labels"]
            del batch_ifta["labels"]
            feat = base_kidneyBert(**batch_ifta)
            outputs = model_ifta(feat)
#             outputs = model_ifta(batch_ifta)
            #print(10)
            
            #print(11)
            loss_fct = torch.nn.CrossEntropyLoss().to(device)
            #print(12)
            #print("q",outputs,outputs.shape)
            #print("z",labels,labels.shape)
            loss_ifta = loss_fct(outputs.view(-1, 4), labels.view(-1))
            loss_ifta.backward()

            optimizer_ifta.step()
            lr_scheduler_ifta.step()
            optimizer_ifta.zero_grad()
            progress_bar.update(1)
    
            loss_isrej = 0
            if step % 300 == 0:
                #test the accuracy

                print('STEP:{}, EPOCHS : {}/{}'.format(step,epoch+1,num_epochs),
                      'Loss : {:.4f},{:.4f}'.format(loss_isrej,loss_ifta))

#                 compute_metrics(get_pred(model_isrej,isrej_test_loader))
                compute_metrics(get_pred(model_ifta,ifta_test_loader))
except Exception as e:
    print("Exception",e)

  0%|          | 0/6850 [00:00<?, ?it/s]

  


STEP:300, EPOCHS : 1/10 Loss : 0.0000,0.5015
accuracy: 0.7167883211678832, precision: 0.7167883211678832, recall: 0.7167883211678832, f1: 0.7167883211678832


  


STEP:600, EPOCHS : 1/10 Loss : 0.0000,0.6111
accuracy: 0.6875912408759124, precision: 0.6875912408759124, recall: 0.6875912408759124, f1: 0.6875912408759124


  


STEP:900, EPOCHS : 2/10 Loss : 0.0000,1.2040
accuracy: 0.7299270072992701, precision: 0.7299270072992701, recall: 0.7299270072992701, f1: 0.72992700729927


  


STEP:1200, EPOCHS : 2/10 Loss : 0.0000,0.4732
accuracy: 0.7094890510948905, precision: 0.7094890510948905, recall: 0.7094890510948905, f1: 0.7094890510948906


  


STEP:1500, EPOCHS : 3/10 Loss : 0.0000,0.6230
accuracy: 0.7124087591240876, precision: 0.7124087591240876, recall: 0.7124087591240876, f1: 0.7124087591240876


  


STEP:1800, EPOCHS : 3/10 Loss : 0.0000,0.6581
accuracy: 0.7299270072992701, precision: 0.7299270072992701, recall: 0.7299270072992701, f1: 0.72992700729927


  


STEP:2100, EPOCHS : 4/10 Loss : 0.0000,0.9544
accuracy: 0.7197080291970803, precision: 0.7197080291970803, recall: 0.7197080291970803, f1: 0.7197080291970803


  


STEP:2400, EPOCHS : 4/10 Loss : 0.0000,0.6450
accuracy: 0.6963503649635037, precision: 0.6963503649635037, recall: 0.6963503649635037, f1: 0.6963503649635037


  


STEP:2700, EPOCHS : 4/10 Loss : 0.0000,0.0922
accuracy: 0.7401459854014598, precision: 0.7401459854014598, recall: 0.7401459854014598, f1: 0.7401459854014598


  


STEP:3000, EPOCHS : 5/10 Loss : 0.0000,0.6349
accuracy: 0.7445255474452555, precision: 0.7445255474452555, recall: 0.7445255474452555, f1: 0.7445255474452555


  


STEP:3300, EPOCHS : 5/10 Loss : 0.0000,0.6144
accuracy: 0.7401459854014598, precision: 0.7401459854014598, recall: 0.7401459854014598, f1: 0.7401459854014598


  


STEP:3600, EPOCHS : 6/10 Loss : 0.0000,0.2996
accuracy: 0.7065693430656934, precision: 0.7065693430656934, recall: 0.7065693430656934, f1: 0.7065693430656934


  


STEP:3900, EPOCHS : 6/10 Loss : 0.0000,0.2006
accuracy: 0.7357664233576642, precision: 0.7357664233576642, recall: 0.7357664233576642, f1: 0.7357664233576642


  


STEP:4200, EPOCHS : 7/10 Loss : 0.0000,0.4186
accuracy: 0.7313868613138687, precision: 0.7313868613138687, recall: 0.7313868613138687, f1: 0.7313868613138687


  


STEP:4500, EPOCHS : 7/10 Loss : 0.0000,0.7996
accuracy: 0.7474452554744525, precision: 0.7474452554744525, recall: 0.7474452554744525, f1: 0.7474452554744525


  


STEP:4800, EPOCHS : 8/10 Loss : 0.0000,0.3397
accuracy: 0.7328467153284671, precision: 0.7328467153284671, recall: 0.7328467153284671, f1: 0.7328467153284671


  


STEP:5100, EPOCHS : 8/10 Loss : 0.0000,0.9969
accuracy: 0.7445255474452555, precision: 0.7445255474452555, recall: 0.7445255474452555, f1: 0.7445255474452555


  


STEP:5400, EPOCHS : 8/10 Loss : 0.0000,0.1258
accuracy: 0.7635036496350365, precision: 0.7635036496350365, recall: 0.7635036496350365, f1: 0.7635036496350364


  


STEP:5700, EPOCHS : 9/10 Loss : 0.0000,0.1847
accuracy: 0.7518248175182481, precision: 0.7518248175182481, recall: 0.7518248175182481, f1: 0.7518248175182483


  


STEP:6000, EPOCHS : 9/10 Loss : 0.0000,0.0526
accuracy: 0.743065693430657, precision: 0.743065693430657, recall: 0.743065693430657, f1: 0.743065693430657


  


STEP:6300, EPOCHS : 10/10 Loss : 0.0000,0.1960
accuracy: 0.743065693430657, precision: 0.743065693430657, recall: 0.743065693430657, f1: 0.743065693430657


  


STEP:6600, EPOCHS : 10/10 Loss : 0.0000,0.0521
accuracy: 0.743065693430657, precision: 0.743065693430657, recall: 0.743065693430657, f1: 0.743065693430657


  


## isRejction one task

In [27]:
pred_labels,test_labels = get_pred(model_isrej,isrej_test_loader)
print(confusion_matrix(test_labels,pred_labels))
print(classification_report(test_labels,pred_labels))

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


[[641   6]
 [ 15  23]]
              precision    recall  f1-score   support

           0       0.98      0.99      0.98       647
           1       0.79      0.61      0.69        38

    accuracy                           0.97       685
   macro avg       0.89      0.80      0.84       685
weighted avg       0.97      0.97      0.97       685



## isRejction multi tasks

In [26]:
pred_labels,test_labels = get_pred(model_isrej,isrej_test_loader)
print(confusion_matrix(test_labels,pred_labels))
print(classification_report(test_labels,pred_labels))

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


[[638   9]
 [ 21  17]]
              precision    recall  f1-score   support

           0       0.97      0.99      0.98       647
           1       0.65      0.45      0.53        38

    accuracy                           0.96       685
   macro avg       0.81      0.72      0.75       685
weighted avg       0.95      0.96      0.95       685



## IFTA one task

In [16]:
pred_labels,test_labels = get_pred(model_ifta,ifta_test_loader)
print(confusion_matrix(test_labels,pred_labels))
print(classification_report(test_labels,pred_labels))

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


[[320  42   3   2]
 [ 49 144  30   1]
 [  2  29  40   1]
 [  2   2  14   4]]
              precision    recall  f1-score   support

           0       0.86      0.87      0.86       367
           1       0.66      0.64      0.65       224
           2       0.46      0.56      0.50        72
           3       0.50      0.18      0.27        22

    accuracy                           0.74       685
   macro avg       0.62      0.56      0.57       685
weighted avg       0.74      0.74      0.74       685



## IFTA multi tasks

In [27]:
pred_labels,test_labels = get_pred(model_ifta,ifta_test_loader)
print(confusion_matrix(test_labels,pred_labels))
print(classification_report(test_labels,pred_labels))

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


[[324  43   0   0]
 [ 37 164  23   0]
 [  3  31  38   0]
 [  3   4  12   3]]
              precision    recall  f1-score   support

           0       0.88      0.88      0.88       367
           1       0.68      0.73      0.70       224
           2       0.52      0.53      0.52        72
           3       1.00      0.14      0.24        22

    accuracy                           0.77       685
   macro avg       0.77      0.57      0.59       685
weighted avg       0.78      0.77      0.77       685

