In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report, f1_score, accuracy_score
from sklearn.model_selection import train_test_split
import collections
import time
import pathlib
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import random

import torch
from torch import nn, cuda
from torch.utils.data import DataLoader
import torch.nn.functional as F
from scipy.stats import spearmanr, pearsonr

import transformers
from transformers import BertModel, BertTokenizerFast, DistilBertModel, DistilBertTokenizerFast
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, ConcatDataset
import copy

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cpu'
if cuda.is_available():
    device = 'cuda'
    # Release memory held by PyTorch's CUDA caching allocator.
    torch.cuda.empty_cache()
print(device)

In [None]:
# Select the classification task:
#   'relnorel'     - 'NoArgument' vs. any argument label (UKP sentential AM corpus)
#   'attsup'       - 'Argument_against' vs. other argument labels (UKP, NoArgument rows removed)
#   'premiseclaim' - 'claim' vs. other labels (argument-annotated essays, 'majorclaim' removed)
# TASK = 'relnorel'
# TASK = 'attsup'
TASK = 'premiseclaim'

if TASK in ('relnorel', 'attsup'):
    DATA_PATH = pathlib.Path('/kaggle/input/ukp-sentential-am-corpus')
    dataframe = pd.read_csv(DATA_PATH / 'ukp-data.csv')

    if TASK == 'attsup':
        # Stance is only classified for argumentative sentences.
        dataframe = dataframe.loc[dataframe.label != 'NoArgument']
elif TASK == 'premiseclaim':
    DATA_PATH = pathlib.Path('/kaggle/input/argumentannotatedessays')
    dataframe = pd.read_csv(DATA_PATH / 'data.csv')
    # Keep the task binary: drop major claims.
    dataframe = dataframe[dataframe.label != 'majorclaim']
    # Downstream code reads the text from a 'sentence' column, as in the UKP data.
    dataframe = dataframe.rename(columns={'text': 'sentence'})
else:
    # Fail loudly instead of continuing with an undefined dataframe.
    raise ValueError(f"Unknown TASK: {TASK!r}")

print(len(dataframe.topic.unique()))
print(collections.Counter(dataframe.set))
print(collections.Counter(dataframe.label))
dataframe.head()

In [None]:
"""
Dataset class
"""
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [None]:
class BERTClassifier(nn.Module):
    """Classifier head on top of a pretrained BERT encoder.

    forward() takes BERT's `pooler_output`, applies LeakyReLU and dropout (p=0.1),
    and projects the result to `num_labels` logits.

    Args:
        num_labels: number of output classes (default 2).
        model_name: model id or path passed to `BertModel.from_pretrained`
            (defaults to the 'bert-base-uncased' checkpoint used in this notebook).
    """
    def __init__(self, num_labels = 2, model_name = "bert-base-uncased"):
        super(BERTClassifier, self).__init__()
        self.num_labels = num_labels
        self.bert = BertModel.from_pretrained(model_name)
        # Size the head from the encoder config rather than hard-coding 768 (bert-base).
        self.classifier = nn.Linear(self.bert.config.hidden_size, self.num_labels)
        self.dropout = torch.nn.Dropout(0.1)
        # Created once instead of on every forward call. LeakyReLU has no parameters,
        # so the state_dict (and previously saved checkpoints) is unchanged.
        self.activation = nn.LeakyReLU()

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return unnormalised class logits, one row per example (batch x num_labels)."""
        x = self.bert(input_ids,
                      token_type_ids = token_type_ids,
                      attention_mask = attention_mask)
        hidden = self.dropout(self.activation(x.pooler_output))
        return self.classifier(hidden)

In [None]:
# Set the random seed
def set_seed(seed):
    """Seed Python's `random`, NumPy's global RNG, and torch (CPU and every CUDA device)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

def prepare_input(df, labels, tokenizer, max_len = 256, shuffle = True, bs = 32):
    """Tokenize (sentence, topic) pairs from `df` and wrap them in a DataLoader.

    Args:
        df: frame with 'sentence' and 'topic' columns.
        labels: per-row labels, aligned by position with `df`.
        tokenizer: callable tokenizer accepting a text list and a text-pair list.
        max_len: every pair is truncated/padded to exactly this many tokens.
        shuffle: whether the DataLoader reshuffles each epoch.
        bs: batch size.
    """
    sentences = list(df['sentence'].values)
    topics = list(df['topic'].values)
    encodings = tokenizer(sentences, topics,
                          truncation=True,
                          padding='max_length',
                          max_length=max_len)
    return DataLoader(CustomDataset(encodings, labels), batch_size=bs, shuffle=shuffle)

# Save a model state to disk
def save_model_state(model_state, PATH):
    """Serialize `model_state` (e.g. a state dict) to the file `PATH` via torch.save."""
    torch.save(obj=model_state, f=PATH)
    
# Load and return a previously saved model
def load_model(PATH):
    """Build a fresh BERTClassifier and load the state dict stored at `PATH`.

    The checkpoint tensors are mapped to the CPU so that a file saved from a GPU
    run also loads on a CPU-only machine. The returned model is on the CPU; move
    it with `.to(device)` as needed.
    """
    model = BERTClassifier()
    model.load_state_dict(torch.load(PATH, map_location='cpu'))
    return model

def get_batch_predictions(model, batch, loss_fn = None):
    """Run `model` on one mini-batch and optionally compute its loss.

    Args:
        model: called as model(input_ids, attention_mask, token_type_ids).
        batch: dict with 'input_ids', 'attention_mask', 'token_type_ids' and 'labels'
            tensors (as produced by CustomDataset + DataLoader).
        loss_fn: optional callable applied to (outputs, labels).

    Returns:
        (outputs, labels, loss): model outputs, labels as int64 on `device`,
        and the loss (None when `loss_fn` is None).
    """
    input_ids, attention_mask, input_type_ids = (
        batch[key].to(device) for key in ('input_ids', 'attention_mask', 'token_type_ids')
    )
    # Integer class indices, as expected by classification losses.
    labels = batch['labels'].long().to(device)

    outputs = model(input_ids, attention_mask, input_type_ids)
    loss = loss_fn(outputs, labels) if loss_fn is not None else None
    return outputs, labels, loss

# Return the number of correct predictions for a given batch of samples
def batch_hits(logits, labels):
    """Count correct predictions in a batch.

    Supports two output layouts:
      * multi-column logits (N, C) with C > 1: predicted class = argmax over the last dim;
      * single-logit binary scores, shape (N,) or (N, 1): predicted class 1 when logit > 0.

    Thresholding two-column logits (the previous behaviour) produced 2N predictions
    for N labels, so the comparison was a shape error rather than an accuracy count.
    """
    if logits.dim() > 1 and logits.size(-1) > 1:
        predictions = torch.argmax(logits, dim=-1)
    else:
        predictions = (logits > 0.0).float()
    return (torch.flatten(predictions) == torch.flatten(labels)).sum().item()

def _train_one_epoch(model, train_loader, loss_fn, optim, scheduler, progress_bar):
    """Run one optimisation pass over `train_loader`; return (sum of batch losses, correct predictions)."""
    model.train()
    loss_sum, hits = 0.0, 0
    for batch in train_loader:
        optim.zero_grad()

        outputs, labels, loss = get_batch_predictions(model, batch, loss_fn)
        loss_sum += loss.item()
        hits += (torch.argmax(outputs, dim = 1) == labels).sum().item()

        loss.backward()
        # Clip the global gradient norm before the update.
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=1.0)
        # PyTorch requires optimizer.step() before scheduler.step(); the reverse
        # order skips the first value of the learning-rate schedule.
        optim.step()
        scheduler.step()

        progress_bar.update(1)
    return loss_sum, hits


def _evaluate(model, loader, loss_fn):
    """Score `loader` without gradients; return (sum of batch losses, predictions, labels)."""
    model.eval()
    loss_sum = 0.0
    prediction_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch in loader:
            outputs, labels, loss = get_batch_predictions(model, batch, loss_fn)
            loss_sum += loss.item()
            prediction_chunks.append(torch.argmax(outputs, dim = 1).int().cpu().numpy())
            label_chunks.append(labels.int().cpu().numpy())
    # Concatenate once at the end (re-concatenating after every batch is quadratic).
    predictions = np.concatenate([np.array([], int), *prediction_chunks])
    labels = np.concatenate([np.array([], int), *label_chunks])
    return loss_sum, predictions, labels


def finetune_model(model, train_loader, validation_loader, loss_fn, optim, scheduler, epochs, n_samples, n_eval_samples):
    """Fine-tune `model` and return the weights with the lowest validation loss.

    Args:
        model: classifier returning (batch, num_classes) logits, already on `device`.
        train_loader / validation_loader: DataLoaders yielding the batch dicts
            consumed by `get_batch_predictions`.
        loss_fn: loss on (logits, labels). Average losses assume it returns a
            per-batch mean (the default reduction of nn.CrossEntropyLoss).
        optim: optimizer over the model parameters.
        scheduler: learning-rate scheduler, stepped once per training batch.
        epochs: number of passes over `train_loader`.
        n_samples: number of training examples (used for training accuracy).
        n_eval_samples: number of validation examples. Kept for interface
            compatibility; validation metrics are computed from the collected predictions.

    Returns:
        A deep copy of `model.state_dict()` from the epoch with the lowest total
        validation loss. If no epoch improves on +inf (e.g. the loss is NaN), this
        is the state the model had when the function was called.
    """
    best_model_state = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    progress_bar = tqdm(range(len(train_loader) * epochs))
    # Guard the averages against empty loaders.
    n_train_batches = max(len(train_loader), 1)
    n_eval_batches = max(len(validation_loader), 1)

    for epoch in range(1, epochs + 1):
        print("Epoch: {}/{}".format(epoch, epochs))
        start_time = time.time()

        ## TRAINING
        train_loss, train_hits = _train_one_epoch(model, train_loader, loss_fn, optim, scheduler, progress_bar)
        # loss_fn yields per-batch means, so average over batches. Dividing by the
        # number of samples under-reports the loss by roughly the batch size.
        avg_train_loss = round(float(train_loss / n_train_batches), 4)
        train_acc = round(float(train_hits / max(n_samples, 1)), 3)

        ## VALIDATION
        eval_loss, eval_predictions, eval_labels = _evaluate(model, validation_loader, loss_fn)
        avg_eval_loss = round(float(eval_loss / n_eval_batches), 4)
        eval_acc = accuracy_score(eval_labels, eval_predictions)
        eval_macro_f1 = f1_score(eval_labels, eval_predictions, average = 'macro')

        ## PRINT INFO
        print(f"Training   ==> Accuracy: {train_acc} | Loss: {avg_train_loss}")
        print(f"Validation ==> Accuracy: {eval_acc}  | Loss: {avg_eval_loss}  | Macro-F1: {eval_macro_f1}")
        print(f"Epoch time: {time.time() - start_time:.1f}s")

        ## KEEP BEST MODEL: snapshot the weights whenever validation loss decreases
        if eval_loss < best_loss:
            best_model_state = copy.deepcopy(model.state_dict())
            best_loss = eval_loss
            print("Validation loss has decreased... saving model at epoch {}".format(epoch))

    progress_bar.close()
    return best_model_state

def test_model(model_state, dataloader, get_logits = False, classes = None, verbose = True):
    """Evaluate a saved BERTClassifier state on `dataloader`.

    Args:
        model_state: state dict loaded into a fresh BERTClassifier before evaluation.
        dataloader: DataLoader yielding the batch dicts consumed by `get_batch_predictions`.
        get_logits: when True, also return the raw (pre-softmax) model outputs.
        classes: optional class names for the classification report.
        verbose: print sklearn's classification_report when True.

    Returns:
        (predictions, logits): argmax class per example, and a float array with
        one row of raw outputs per example (or None when `get_logits` is False).
    """
    model = BERTClassifier()
    model.load_state_dict(model_state)
    model.to(device)
    model.eval()

    logit_chunks, prediction_chunks, label_chunks = [], [], []
    with torch.no_grad():
        for batch in dataloader:
            outputs, labels, _ = get_batch_predictions(model, batch)
            logit_chunks.append(outputs.cpu().numpy())
            prediction_chunks.append(torch.argmax(outputs, dim = 1).int().cpu().numpy())
            label_chunks.append(labels.int().cpu().numpy())

    # Concatenate once at the end (per-batch re-concatenation is quadratic). The
    # empty seed arrays fix dtype and width even for an empty dataloader.
    logits = np.concatenate([np.zeros((0, model.num_labels), dtype=np.float32), *logit_chunks], axis = 0)
    predictions = np.concatenate([np.array([], int), *prediction_chunks])
    test_labels = np.concatenate([np.array([], int), *label_chunks])

    if verbose:
        if classes is not None:
            print(classification_report(test_labels, predictions, target_names = classes))
        else:
            print(classification_report(test_labels, predictions))

    if get_logits:
        return predictions, logits
    else:
        return predictions, None

# Finetuning

In [None]:
MAX_LEN = 256
EPOCHS = 4
EXPERIMENTS = 5
LR = 2e-5
BATCH_SIZE = 32
WARMUP_STEPS = 300  # linear-warmup steps for get_linear_schedule_with_warmup

SAVE_MODELS = True

# Per-task binary encoding: a predicate selecting class 1, and the names written
# to the results for predicted class 0 and class 1 (in that order).
TASK_SPECS = {
    'relnorel': (lambda label: label != 'NoArgument', ('noarg', 'arg')),
    # Class 1 is every label other than 'Argument_against', so it is named 'for'.
    'attsup': (lambda label: label != 'Argument_against', ('against', 'for')),
    'premiseclaim': (lambda label: label == 'claim', ('premise', 'claim')),
}
if TASK not in TASK_SPECS:
    raise ValueError(f"Invalid task: {TASK!r}")
is_positive, class_names = TASK_SPECS[TASK]
# File-name tag: 'pe' (essays corpus) for premise/claim, 'ukp' otherwise.
corpus_tag = 'pe' if TASK == 'premiseclaim' else 'ukp'

df = dataframe.copy() # copy of dataframe

# The tokenizer is the same for every experiment, so load it once.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

for exp in range(EXPERIMENTS):

    # Use a different, reproducible random seed for each run
    seed = exp*2023 + 1
    set_seed(seed)

    # SPLIT DATA
    X_train = df.loc[df.set == 'train'].sample(frac = 1).reset_index(drop = True)
    X_test = df.loc[df.set == 'test'].reset_index(drop = True)
    if TASK == 'premiseclaim':
        # For the essays corpus, the validation set is carved out of the training split.
        X_train, X_val = train_test_split(X_train, test_size=0.15, random_state=42)
    else:
        X_val = df.loc[df.set == 'val'].sample(frac = 1).reset_index(drop = True)

    print(collections.Counter(X_train.label))
    print(collections.Counter(X_val.label))
    print(collections.Counter(X_test.label))

    # BINARY LABELS
    train_labels = np.asarray([1 if is_positive(x) else 0 for x in X_train.label.values])
    val_labels = np.asarray([1 if is_positive(x) else 0 for x in X_val.label.values])
    test_labels = np.asarray([1 if is_positive(x) else 0 for x in X_test.label.values])

    train_samples = len(train_labels)
    eval_samples = len(val_labels)

    # DATA LOADERS (only the training data needs shuffling)
    train_loader = prepare_input(X_train, train_labels, tokenizer, max_len = MAX_LEN, bs = BATCH_SIZE, shuffle = True)
    validation_loader = prepare_input(X_val, val_labels, tokenizer, max_len = MAX_LEN, bs = BATCH_SIZE, shuffle = False)
    test_loader = prepare_input(X_test, test_labels, tokenizer, max_len = MAX_LEN, bs = BATCH_SIZE, shuffle = False)

    # INITIALIZE MODEL
    model = BERTClassifier()
    model.to(device)

    # OPTIMIZER
    optim = torch.optim.AdamW(model.parameters(), lr=LR)

    # SCHEDULER: finetune_model steps it once per training batch
    num_train_optimization_steps = EPOCHS * len(train_loader)
    print(num_train_optimization_steps)
    scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=WARMUP_STEPS, num_training_steps=num_train_optimization_steps)

    # LOSS FUNCTION
    loss_fn = nn.CrossEntropyLoss()

    # TRAINING
    best_model_state = finetune_model(model,
                                      train_loader,
                                      validation_loader,
                                      loss_fn,
                                      optim,
                                      scheduler,
                                      EPOCHS,
                                      train_samples,
                                      eval_samples)

    # TEST: the 'logits' column stores softmax probabilities of the best checkpoint.
    test_predictions, test_logits = test_model(best_model_state, test_loader, get_logits = True)
    softmax_tensor = F.softmax(torch.from_numpy(test_logits), dim=1)
    df_res = X_test.copy()
    df_res['prediction'] = [class_names[int(x)] for x in test_predictions]
    df_res['logits'] = [list(row) for row in softmax_tensor.numpy()]
    df_res.to_csv(f"binary_{TASK}_{corpus_tag}_exp{exp}.csv", index = False)

    if SAVE_MODELS:
        # Save the best model of this experiment
        best_model_name = f'trained_model_{corpus_tag}_{TASK}_exp{exp}.pt'
        save_model_state(best_model_state, best_model_name)
        print("Best model saved... {}".format(best_model_name))

# Inferencia con dataset con perturbaciones

In [None]:
def prepare_input_for_test(df, column_name, labels, tokenizer, max_len = 256, shuffle = True, bs = 32):
    """Build a DataLoader like `prepare_input`, reading the text from `column_name`.

    Delegates to `prepare_input` so both loaders tokenize (text, topic) pairs the
    same way. `DataFrame.assign` returns a copy, so the caller's frame is not modified.

    Args:
        df: frame with `column_name` and 'topic' columns.
        column_name: column holding the text to classify (e.g. 'goldArgument').
        labels, tokenizer, max_len, shuffle, bs: as in `prepare_input`.
    """
    return prepare_input(df.assign(sentence=df[column_name]), labels, tokenizer,
                         max_len=max_len, shuffle=shuffle, bs=bs)

# Directory of the data_for_classification_* files used by the inference cells.
DATA_PATH = pathlib.Path('/kaggle/input') / 'ukp-sentential-am-corpus'
print(DATA_PATH)

## Instancias C1

In [None]:
# TASK = 'relnorel'
# TASK = 'attsup'
TASK = 'premiseclaim'

dataframe = pd.read_csv(DATA_PATH / 'data_for_classification_C1_with_category.csv')

if TASK == 'attsup':
    # Stance is only classified for argumentative sentences.
    dataframe = dataframe.loc[dataframe.relTag != 'NoArgument']

print(len(dataframe.topic.unique()))

X_test = dataframe

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Load the fine-tuned weights of the selected experiment. map_location='cpu' lets a
# GPU-saved checkpoint load on a CPU-only kernel; test_model moves the model to `device`.
best_experiment = 0
corpus_tag = 'pe' if TASK == 'premiseclaim' else 'ukp'
BEST_MODEL_PATH = f'/kaggle/working/trained_model_{corpus_tag}_{TASK}_exp{best_experiment}.pt'
model_state_for_inference = torch.load(BEST_MODEL_PATH, map_location='cpu')

# Reference labels, encoded as in the training cell.
if TASK == 'relnorel':
    test_labels = np.asarray([1 if x != 'NoArgument' else 0 for x in X_test.relTag.values])
elif TASK == 'attsup':
    test_labels = np.asarray([1 if x != 'Argument_against' else 0 for x in X_test.stanceTag.values])
elif TASK == 'premiseclaim':
    # There is no premise/claim label column here, so every row gets placeholder label 1.
    # The classification report printed by test_model is not meaningful for this task.
    test_labels = np.ones(len(X_test), dtype=int)
else:
    raise ValueError(f"Invalid task: {TASK!r}")

# Output names for predicted class 0 and class 1. For 'attsup', class 1 means
# "not Argument_against", i.e. 'for'.
class_names = {
    'relnorel': ('noarg', 'arg'),
    'attsup': ('against', 'for'),
    'premiseclaim': ('premise', 'claim'),
}[TASK]

# Predictions on the gold arguments, then on the instances with errors.
for text_column, output_suffix in [('goldArgument', 'gold'), ('predictedArgument', 'with_errors')]:
    test_loader = prepare_input_for_test(X_test, text_column, test_labels, tokenizer, bs = 32, shuffle = False)
    test_predictions, test_logits = test_model(model_state_for_inference, test_loader, get_logits = True)

    # The 'logits' column stores softmax probabilities.
    softmax_tensor = F.softmax(torch.from_numpy(test_logits), dim=1)
    df_res = X_test.copy()
    df_res['prediction'] = [class_names[int(x)] for x in test_predictions]
    df_res['logits'] = [list(row) for row in softmax_tensor.numpy()]
    df_res.to_csv(f"binary_{TASK}_c1_{output_suffix}.csv", index = False)

## Instancias C2

In [None]:
# TASK = 'relnorel'
# TASK = 'attsup'
TASK = 'premiseclaim'

dataframe = pd.read_csv(DATA_PATH / 'data_for_classification_C2_with_category.csv')

print(len(dataframe.topic.unique()))

X_test = dataframe

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Load the fine-tuned weights of the selected experiment. map_location='cpu' lets a
# GPU-saved checkpoint load on a CPU-only kernel; test_model moves the model to `device`.
best_experiment = 0
corpus_tag = 'pe' if TASK == 'premiseclaim' else 'ukp'
BEST_MODEL_PATH = f'/kaggle/working/trained_model_{corpus_tag}_{TASK}_exp{best_experiment}.pt'
model_state_for_inference = torch.load(BEST_MODEL_PATH, map_location='cpu')

# Reference labels, encoded as in the training cell.
# NOTE(review): unlike the C1 cell, NoArgument rows are not filtered out for 'attsup'
# here. Confirm this file only contains argumentative sentences.
if TASK == 'relnorel':
    test_labels = np.asarray([1 if x != 'NoArgument' else 0 for x in X_test.relTag.values])
elif TASK == 'attsup':
    test_labels = np.asarray([1 if x != 'Argument_against' else 0 for x in X_test.stanceTag.values])
elif TASK == 'premiseclaim':
    # There is no premise/claim label column here, so every row gets placeholder label 1.
    # The classification report printed by test_model is not meaningful for this task.
    test_labels = np.ones(len(X_test), dtype=int)
else:
    raise ValueError(f"Invalid task: {TASK!r}")

# Output names for predicted class 0 and class 1. For 'attsup', class 1 means
# "not Argument_against", i.e. 'for'.
class_names = {
    'relnorel': ('noarg', 'arg'),
    'attsup': ('against', 'for'),
    'premiseclaim': ('premise', 'claim'),
}[TASK]

# Predictions on the instances with errors
test_loader = prepare_input_for_test(X_test, 'predictedArgument', test_labels, tokenizer, bs = 32, shuffle = False)
test_predictions, test_logits = test_model(model_state_for_inference, test_loader, get_logits = True)

# The 'logits' column stores softmax probabilities.
softmax_tensor = F.softmax(torch.from_numpy(test_logits), dim=1)
df_res = X_test.copy()
df_res['prediction'] = [class_names[int(x)] for x in test_predictions]
df_res['logits'] = [list(row) for row in softmax_tensor.numpy()]
df_res.to_csv(f"binary_{TASK}_c2_with_errors.csv", index = False)

## Instancias C3 y C4

In [None]:
# TASK = 'relnorel'
# TASK = 'attsup'
TASK = 'premiseclaim'

dataframe = pd.read_csv(DATA_PATH / 'data_for_classification_C34_with_category.csv')

print(len(dataframe.topic.unique()))

X_test = dataframe

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Load the fine-tuned weights of the selected experiment. map_location='cpu' lets a
# GPU-saved checkpoint load on a CPU-only kernel; test_model moves the model to `device`.
best_experiment = 0
corpus_tag = 'pe' if TASK == 'premiseclaim' else 'ukp'
BEST_MODEL_PATH = f'/kaggle/working/trained_model_{corpus_tag}_{TASK}_exp{best_experiment}.pt'
model_state_for_inference = torch.load(BEST_MODEL_PATH, map_location='cpu')

# Reference labels, encoded as in the training cell.
# NOTE(review): unlike the C1 cell, NoArgument rows are not filtered out for 'attsup'
# here. Confirm this file only contains argumentative sentences.
if TASK == 'relnorel':
    test_labels = np.asarray([1 if x != 'NoArgument' else 0 for x in X_test.relTag.values])
elif TASK == 'attsup':
    test_labels = np.asarray([1 if x != 'Argument_against' else 0 for x in X_test.stanceTag.values])
elif TASK == 'premiseclaim':
    # There is no premise/claim label column here, so every row gets placeholder label 1.
    # The classification report printed by test_model is not meaningful for this task.
    test_labels = np.ones(len(X_test), dtype=int)
else:
    raise ValueError(f"Invalid task: {TASK!r}")

# Output names for predicted class 0 and class 1. For 'attsup', class 1 means
# "not Argument_against", i.e. 'for'.
class_names = {
    'relnorel': ('noarg', 'arg'),
    'attsup': ('against', 'for'),
    'premiseclaim': ('premise', 'claim'),
}[TASK]

# Predictions on the instances with errors
test_loader = prepare_input_for_test(X_test, 'predictedArgument', test_labels, tokenizer, bs = 32, shuffle = False)
test_predictions, test_logits = test_model(model_state_for_inference, test_loader, get_logits = True)

# The 'logits' column stores softmax probabilities.
softmax_tensor = F.softmax(torch.from_numpy(test_logits), dim=1)
df_res = X_test.copy()
df_res['prediction'] = [class_names[int(x)] for x in test_predictions]
df_res['logits'] = [list(row) for row in softmax_tensor.numpy()]
df_res.to_csv(f"binary_{TASK}_c34_with_errors.csv", index = False)

## Instancias C1 SPLIT

In [None]:
# TASK = 'relnorel'
# TASK = 'attsup'
TASK = 'premiseclaim'

dataframe = pd.read_csv(DATA_PATH / 'data_for_classification_C1_SPLIT_with_category.csv')

# Empty cells in these columns are read as NaN; replace them with '' so every
# value passed to the tokenizer is a string.
dataframe['predictedArgument1'] = dataframe['predictedArgument1'].fillna('')
dataframe['predictedArgument2'] = dataframe['predictedArgument2'].fillna('')

print(len(dataframe.topic.unique()))

X_test = dataframe

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Load the fine-tuned weights of the selected experiment. map_location='cpu' lets a
# GPU-saved checkpoint load on a CPU-only kernel; test_model moves the model to `device`.
best_experiment = 0
corpus_tag = 'pe' if TASK == 'premiseclaim' else 'ukp'
BEST_MODEL_PATH = f'/kaggle/working/trained_model_{corpus_tag}_{TASK}_exp{best_experiment}.pt'
model_state_for_inference = torch.load(BEST_MODEL_PATH, map_location='cpu')

# Reference labels, encoded as in the training cell.
# NOTE(review): unlike the C1 cell, NoArgument rows are not filtered out for 'attsup'
# here. Confirm this file only contains argumentative sentences.
if TASK == 'relnorel':
    test_labels = np.asarray([1 if x != 'NoArgument' else 0 for x in X_test.relTag.values])
elif TASK == 'attsup':
    test_labels = np.asarray([1 if x != 'Argument_against' else 0 for x in X_test.stanceTag.values])
elif TASK == 'premiseclaim':
    # There is no premise/claim label column here, so every row gets placeholder label 1.
    # The classification report printed by test_model is not meaningful for this task.
    test_labels = np.ones(len(X_test), dtype=int)
else:
    raise ValueError(f"Invalid task: {TASK!r}")

# Output names for predicted class 0 and class 1. For 'attsup', class 1 means
# "not Argument_against", i.e. 'for'.
class_names = {
    'relnorel': ('noarg', 'arg'),
    'attsup': ('against', 'for'),
    'premiseclaim': ('premise', 'claim'),
}[TASK]

# Predictions on the gold arguments ('logits' columns store softmax probabilities)
test_loader = prepare_input_for_test(X_test, 'goldArgument', test_labels, tokenizer, bs = 32, shuffle = False)
test_predictions, test_logits = test_model(model_state_for_inference, test_loader, get_logits = True)
softmax_tensor = F.softmax(torch.from_numpy(test_logits), dim=1)
df_res = X_test.copy()
df_res['prediction'] = [class_names[int(x)] for x in test_predictions]
df_res['logits'] = [list(row) for row in softmax_tensor.numpy()]
df_res.to_csv(f"binary_{TASK}_c1_SPLIT_gold.csv", index = False)

# Predictions on both halves of each split instance with errors, stored side by side
df_res = X_test.copy()
for part in ('1', '2'):
    test_loader = prepare_input_for_test(X_test, f'predictedArgument{part}', test_labels, tokenizer, bs = 32, shuffle = False)
    test_predictions, test_logits = test_model(model_state_for_inference, test_loader, get_logits = True)
    softmax_tensor = F.softmax(torch.from_numpy(test_logits), dim=1)
    df_res[f'prediction{part}'] = [class_names[int(x)] for x in test_predictions]
    df_res[f'logits{part}'] = [list(row) for row in softmax_tensor.numpy()]

df_res.to_csv(f"binary_{TASK}_c1_SPLIT_with_errors.csv", index = False)