In [1]:
#pra conseguir algum erro util da porra do torch
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [2]:
import random
import time
import datetime


import torch
import torch.nn as nn
from torchsummary import summary
from transformers import AutoModel, AutoConfig, AutoTokenizer, get_constant_schedule_with_warmup, BertForSequenceClassification

from torch.utils.data import TensorDataset, DataLoader, SequentialSampler, RandomSampler



import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from skmultilearn.skmultilearn.model_selection.iterative_stratification import iterative_train_test_split, IterativeStratification
from sklearn.metrics import classification_report, average_precision_score, f1_score

# Utils

In [3]:
class EarlyStopper:
    def __init__(self, patience=4, min_delta=0,is_loss=True):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.is_loss = is_loss
        if is_loss:
            self.min_validation_loss = np.inf
        else:
            self.min_validation_loss = -np.inf

    def early_stop(self, validation_loss):
        if self.is_loss:
            if validation_loss < self.min_validation_loss:
                self.min_validation_loss = validation_loss
                self.counter = 0
                print("NEW LOWEST LOSS ",self.min_validation_loss)
            elif validation_loss >= (self.min_validation_loss + self.min_delta):
                self.counter += 1
                if self.counter >= self.patience:
                    return True
            return False
        
        else:
            if validation_loss > self.min_validation_loss:
                self.min_validation_loss = validation_loss
                self.counter = 0
                print("NEW HIGHEST SCORE ",self.min_validation_loss)
            elif validation_loss <= (self.min_validation_loss + self.min_delta):
                self.counter += 1
                if self.counter >= self.patience:
                    return True
            return False

In [4]:
def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # Round to the nearest second.
    elapsed_rounded = int(round((elapsed)))
    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))

# Hiperparâmetros

In [5]:
seed_val = 2023
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
if torch.cuda.is_available():
  torch.cuda.manual_seed_all(seed_val)
  device = 0

In [6]:
learning_rate = 5e-5
num_train_epochs = 20
print_each_n_step = 100
batch_size= 16
warmup_proportion = 0.2
apply_scheduler = True
max_length = 128

# Dados

## Tarefa primária

In [7]:
def process_data(df):
    df['Alteração na eficiência/funcionalidade'] = (df['Alteração na eficiência/funcionalidade'] == 1) | (df["Alteração da funcionalidade"] == 1) | (df["Alteração na eficiência"] == 1)
    df = df.drop(["Postagem com possível perfil depressivo","Alteração na eficiência",
         "Alteração da funcionalidade","*",'Agitação/inquietação','Sintoma obsessivo e compulsivo','Déficit de atenção/Memória',
              'Perda/Diminuição do prazer/ Perda/Diminuição da libido'],axis=1)
    df['Neutro'] =  df.iloc[:,1:].sum(axis=1) == 0
    df = df.replace({True:1,False:0})

    df = df.reset_index(drop=True)
    return df

In [8]:
train_df = process_data(pd.read_csv("data/segredos_sentenças_multitask_train_clean.csv",index_col=0))
test_df = process_data(pd.read_csv("data/segredos_sentenças_multitask_test_clean.csv",index_col=0))


In [9]:

symptom_num = train_df.iloc[:,1:].shape[1]
target_names_primary = list(train_df.iloc[:,1:].columns)

## Tarefa auxiliar

In [10]:
go_emotions_path = "data/goemotions"

auxiliary_train = pd.read_csv(f"{go_emotions_path}/train.tsv",sep='\t')
auxiliary_val = pd.read_csv(f"{go_emotions_path}/dev.tsv",sep='\t')

auxiliary_test = pd.read_csv(f"{go_emotions_path}/test.tsv",sep='\t')


In [11]:
def change_label_encoding(row):
    #mudar encoding para ser multirrótulo (preferível para o cálculo da entropia)
    labels = row['labels'].replace(" ","").split(",")
    new_row = {}
    for emotion in emotion_dict:
        if str(emotion_dict[emotion]) in labels:
            new_row[emotion] = True
        else:
            new_row[emotion] = False
    return new_row

In [12]:
emotion_dict = {
    "admiração": 0,
    "diversão": 1,
    "raiva": 2,
    "aborrecimento": 3,
    "aprovação": 4,
    "zelo": 5,
    "confusão": 6,
    "curiosidade": 7,
    "desejo": 8,
    "decepção": 9,
    "desaprovação": 10,
    "nojo": 11,
    "constrangimento": 12,
    "entusiasmo": 13,
    "medo": 14,
    "gratidão": 15,
    "luto": 16,
    "alegria": 17,
    "amor": 18,
    "nervosismo": 19,
    "otimismo": 20,
    "orgulho": 21,
    "percepção": 22,
    "alívio": 23,
    "remorso": 24,
    "tristeza": 25,
    "surpresa": 26,
    "neutro": 27
}


In [13]:
emotion_labels_test = pd.DataFrame(list(auxiliary_test.apply(change_label_encoding,axis=1)))
emotion_labels_train = pd.DataFrame(list(auxiliary_train.apply(change_label_encoding,axis=1)))
emotion_labels_val = pd.DataFrame(list(auxiliary_val.apply(change_label_encoding,axis=1)))


In [14]:
emotion_num = len(emotion_dict)
target_names_auxiliary = list(auxiliary_train.columns)

In [15]:
emotion_labels_train

Unnamed: 0,admiração,diversão,raiva,aborrecimento,aprovação,zelo,confusão,curiosidade,desejo,decepção,...,amor,nervosismo,otimismo,orgulho,percepção,alívio,remorso,tristeza,surpresa,neutro
0,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,True
1,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,True
2,False,False,True,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
3,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
4,False,False,False,True,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
43405,False,False,False,False,False,False,False,False,False,False,...,True,False,False,False,False,False,False,False,False,False
43406,False,False,False,False,False,False,True,False,False,False,...,False,False,False,False,False,False,False,False,False,False
43407,False,False,False,True,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
43408,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False


# Conjunto de logits

In [16]:
logits_df = pd.read_csv("features/both tasks facebook + reddit + goemotions.csv",index_col = 0)


In [17]:
logits_reddit = logits_df.loc[train_df.shape[0]: logits_df.shape[0] - emotion_labels_train.shape[0]]


In [18]:

from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
#pegando só uma parte dos dados do goemotions. Não é necessário treinar no conjunto completo
strat = MultilabelStratifiedKFold(n_splits=10, shuffle=True, random_state=seed_val).split(auxiliary_train,emotion_labels_train)

for train_idx, test_idx in strat:
    
    logits_emotions = logits_df.tail(emotion_labels_train.shape[0]).reset_index(drop=True)
    emotion_labels_train = emotion_labels_train.loc[test_idx]
    auxiliary_train = auxiliary_train.loc[test_idx]
    logits_emotions = logits_df.loc[test_idx]
    break


In [19]:
logits_symptoms = logits_df.head(train_df.shape[0])



# Modelo

In [20]:
class Mixup(torch.nn.Module):
    def __init__(self,mixup_alpha=1):
        
        super(Mixup,self).__init__()
    
        self.mixup_alpha = mixup_alpha
        
        
    def mixup(self,batch_ids,batch_labels,alpha=1):
        '''Returns mixed inputs, pairs of targets, and lambda'''
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1

        batch_size = batch_ids.size()[0]

        index = torch.randperm(batch_size).cuda()


        mixed_x = lam * batch_ids + (1 - lam) * batch_ids[index, :]
        #using bigger input mask
        #mixed_masks = batch_masks | batch_masks[index,:]
        y_a, y_b = batch_labels, batch_labels[index]
        return mixed_x, y_a, y_b, lam
    
    def forward(self,outputs,labels):
       
        if self.training:
            outputs, labels_a, labels_b, lamb = self.mixup(outputs,labels,self.mixup_alpha)

            return outputs,labels_a,labels_b,lamb
        else:
            return outputs


In [21]:
class BertSoftSharingLayer(nn.Module):
    def __init__(self,encoders):
        super(BertSoftSharingLayer,self).__init__()

        self.encoders = nn.ModuleList(encoders)
        self.n_tasks = len(encoders)
        #linear combination weights between tasks
        #self.alpha = nn.Parameter(torch.eye(len(encoders),len(encoders),requires_grad=True))
        #self.alpha = nn.Parameter(torch.full((len(encoders),len(encoders)),0.5,requires_grad=True))
        
        self.alpha_prim = nn.Parameter(torch.tensor(0.5,requires_grad=True))
        self.beta_prim = nn.Parameter(torch.tensor(0.5,requires_grad=True))
        
        self.alpha_aux = nn.Parameter(torch.tensor(0.5,requires_grad=True))
        self.beta_aux = nn.Parameter(torch.tensor(0.5,requires_grad=True))
        
        self.alphas = [self.alpha_prim,self.alpha_aux]
        self.betas = [self.beta_prim,self.beta_aux]

    def prepare_for_task(self,task_idx):
        #notar que os alphas não foram tocados. Estes são sempre treinados
        for n in range(self.n_tasks):
            if n != task_idx:
                for param in self.encoders[n].parameters():
                    param.requires_grad=False
                    self.alphas[n].requires_grad=False
                    self.betas[n].requires_grad=False
            else:
                for param in self.encoders[n].parameters():
                    param.requires_grad=True
                    self.alphas[n].requires_grad=True
                    self.betas[n].requires_grad=True

        
    def forward(self,input_arr,attention_mask):
        outputs = []
        for i,encoder in enumerate(self.encoders):
            outputs.append(encoder(input_arr[i],attention_mask = attention_mask)[0])
        
        outputs[0] = outputs[0] * self.alpha_prim + outputs[1] * self.beta_prim
        outputs[1] = outputs[1] * self.alpha_aux + outputs[0] * self.beta_aux

        outputs = torch.stack(outputs)

        #linear_combination = torch.matmul(self.alpha,outputs.reshape(self.n_tasks,torch.prod(torch.tensor(outputs.size()[1:]))))
        #return linear_combination.reshape(outputs.size())
        return outputs
        
        

In [22]:
class BertSharedParametersModel(nn.Module):
    def __init__(self,models,attention_n,hidden_size,num_labels,dropout_rate=0.3,mixup_alphas=[1,1],mixup_layer=12):
        super(BertSharedParametersModel,self).__init__()
        self.n_tasks = len(models)
        self.embeddings = nn.ModuleList([model.bert.embeddings for model in models])
        soft_sharing_layers = []
        #assumindo que os modelos têm o mesmo n° de camadas de atenção
        for i in range(attention_n):
            attention_layers = [model.bert.encoder.layer[i] for model in models]
            soft_sharing_layers.append(BertSoftSharingLayer(attention_layers))
        self.soft_sharing_layers = nn.ModuleList(soft_sharing_layers)
        self.poolers = nn.ModuleList([model.bert.pooler for model in models])
        self.dropouts = nn.ModuleList([model.dropout for model in models])
        self.classification_heads = nn.ModuleList([model.classifier for model in models])
        self.mixups = nn.ModuleList([Mixup(alpha) for alpha in mixup_alphas])
        self.mixup_layer = mixup_layer
        for i,_ in enumerate(self.dropouts):
            self.dropouts.p = dropout_rate
            
            
    def prepare_for_task(self,task_idx):
        for n in range(self.n_tasks):
            if n == task_idx:
                for param in self.embeddings[n].parameters():
                    param.requires_grad= True
                
                for i in range (len(self.soft_sharing_layers)):
                    self.soft_sharing_layers[i].prepare_for_task(task_idx)
                
                for param in self.poolers[n].parameters():
                    param.requires_grad= True
                
                for param in self.classification_heads[n].parameters():
                    param_requires_grad = True
            else:
                
                for param in self.embeddings[n].parameters():
                    param.requires_grad= False
                
                for param in self.poolers[n].parameters():
                    param.requires_grad= False
                
                for param in self.classification_heads[n].parameters():
                    param_requires_grad = False

    
    
    def forward(self,input_arr,attention_mask,labels,task_idx):
        
        self.prepare_for_task(task_idx)

        
        
        extended_attention_mask: extended_attention_mask = attention_mask[:, None, None, :]
        extended_attention_mask = extended_attention_mask.to(dtype=torch.float)  
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(torch.float).min
        outputs = []
        for i, embedding in enumerate(self.embeddings):
           
            outputs.append(embedding(input_arr))
            
        outputs = torch.stack(outputs)
        

        for i,sharing_layer in enumerate(self.soft_sharing_layers):
            outputs = sharing_layer(outputs,attention_mask=extended_attention_mask)
            if i == (self.mixup_layer-1) and self.training:
                outputs_cur_task,labels_a,labels_b,lamb = self.mixups[task_idx](outputs[task_idx],labels)
                outputs[task_idx] = outputs_cur_task
        outputs = self.poolers[task_idx](outputs[task_idx])
        
        
        outputs = self.dropouts[task_idx](outputs)
            

        outputs = self.classification_heads[task_idx](outputs)
        if self.training:
            return outputs,labels_a,labels_b,lamb
        
        return outputs


In [23]:
#combined_model.prepare_for_task(5)
#for param in combined_model.soft_sharing_layers[2].encoders[0].parameters():
#    print(param.requires_grad)

In [24]:
#primary_model_path = "models/BERT baselines"
primary_model_path = "neuralmind/bert-base-portuguese-cased"
#auxiliary_model_path = "models/go_emotions"
auxiliary_model_path = "neuralmind/bert-base-portuguese-cased"

tokenizer_path = "neuralmind/bert-base-portuguese-cased"

In [25]:
primary_model = BertForSequenceClassification.from_pretrained(primary_model_path,num_labels=symptom_num)
auxiliary_model = BertForSequenceClassification.from_pretrained(auxiliary_model_path,num_labels=emotion_num)

primary_config = AutoConfig.from_pretrained(primary_model_path)
auxiliary_config = AutoConfig.from_pretrained(auxiliary_model_path)

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

Some weights of the model checkpoint at neuralmind/bert-base-portuguese-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the

In [26]:
assert primary_config.num_hidden_layers == auxiliary_config.num_hidden_layers
combined_model = BertSharedParametersModel([primary_model,auxiliary_model],primary_config.num_hidden_layers,
                                          primary_config.hidden_size,[symptom_num,emotion_num])

In [27]:
summary(combined_model,input_size=(768,),depth=5,batch_dim=2, dtypes=[torch.IntTensor]) 

Layer (type:depth-idx)                                  Param #
├─ModuleList: 1-1                                       --
|    └─BertEmbeddings: 2-1                              --
|    |    └─Embedding: 3-1                              22,881,792
|    |    └─Embedding: 3-2                              393,216
|    |    └─Embedding: 3-3                              1,536
|    |    └─LayerNorm: 3-4                              1,536
|    |    └─Dropout: 3-5                                --
|    └─BertEmbeddings: 2-2                              --
|    |    └─Embedding: 3-6                              22,881,792
|    |    └─Embedding: 3-7                              393,216
|    |    └─Embedding: 3-8                              1,536
|    |    └─LayerNorm: 3-9                              1,536
|    |    └─Dropout: 3-10                               --
├─ModuleList: 1-2                                       --
|    └─BertSoftSharingLayer: 2-3                        --
|    |    └─M

Layer (type:depth-idx)                                  Param #
├─ModuleList: 1-1                                       --
|    └─BertEmbeddings: 2-1                              --
|    |    └─Embedding: 3-1                              22,881,792
|    |    └─Embedding: 3-2                              393,216
|    |    └─Embedding: 3-3                              1,536
|    |    └─LayerNorm: 3-4                              1,536
|    |    └─Dropout: 3-5                                --
|    └─BertEmbeddings: 2-2                              --
|    |    └─Embedding: 3-6                              22,881,792
|    |    └─Embedding: 3-7                              393,216
|    |    └─Embedding: 3-8                              1,536
|    |    └─LayerNorm: 3-9                              1,536
|    |    └─Dropout: 3-10                               --
├─ModuleList: 1-2                                       --
|    └─BertSoftSharingLayer: 2-3                        --
|    |    └─M

# Dataloaders

In [28]:
from torch.utils.data.dataset import ConcatDataset

def get_dataloader(texts,labels,idx,batch_size,shuffle):
    
    input_ids = []
    attention_masks = []
    for text in texts:
        encoded = tokenizer.encode_plus(text,return_attention_mask=True,add_special_tokens=True,max_length = max_length,
    padding="max_length",truncation=True)
        input_ids.append(encoded['input_ids'])
        attention_masks.append(encoded['attention_mask'])
    

    input_ids = torch.tensor(input_ids)
    attention_masks = torch.tensor(attention_masks)
    lbs = torch.tensor(labels.to_numpy(),dtype=torch.float32)
    task_idx = torch.zeros(lbs.size()).fill_(idx)
    dataset = TensorDataset(input_ids,attention_masks,lbs,task_idx)
    
    return torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size,shuffle=shuffle)


def fill_unlabeled(symptoms,df,symptom_num):
    data = []
    for row in range(df.shape[0] - symptoms.shape[0]):
        data.append([-1] * symptom_num)
    data = pd.DataFrame(data)
    data.columns = symptoms.columns
    symptoms = pd.concat([symptoms,data],axis=0)
    return symptoms.reset_index(drop=True)


def get_logit_dataloader(texts,logits,labels,idx,batch_size,shuffle):
    
    input_ids = []
    attention_masks = []
    for text in texts:
        encoded = tokenizer.encode_plus(text,return_attention_mask=True,add_special_tokens=True,max_length = max_length,
    padding="max_length",truncation=True)
        input_ids.append(encoded['input_ids'])
        attention_masks.append(encoded['attention_mask'])
    
    labels = fill_unlabeled(labels,logits,labels.shape[1])
    input_ids = torch.tensor(input_ids)
    attention_masks = torch.tensor(attention_masks)
    logits = torch.tensor(logits.to_numpy(),dtype=torch.float32)
    lbs = torch.tensor(labels.to_numpy(),dtype=torch.float32)
    task_idx = torch.zeros(lbs.size()).fill_(idx)
    dataset = TensorDataset(input_ids,attention_masks,logits,lbs,task_idx)
    
    return torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size,shuffle=shuffle)
    

In [29]:
temp = pd.concat([logits_symptoms,logits_reddit])
prim_train_dataloader = get_logit_dataloader(temp.text,temp.iloc[:,:symptom_num],train_df.iloc[:,1:],0,batch_size,shuffle=True)

temp = pd.concat([logits_emotions,logits_reddit])
aux_train_dataloader = get_logit_dataloader(temp.text,temp.iloc[:,symptom_num+1:symptom_num+emotion_num+1],emotion_labels_train,1,batch_size,shuffle=True)
train_dataloaders = [prim_train_dataloader,aux_train_dataloader]

prim_test_dataloader = get_dataloader(test_df.text,test_df.iloc[:,1:],0,batch_size,shuffle=True)
aux_test_dataloader = get_dataloader(auxiliary_val.text,emotion_labels_val,1,batch_size,shuffle=True)
test_dataloaders = [prim_test_dataloader,aux_test_dataloader]

# Treinamento

In [30]:
def focal_loss(p,targets,gamma=3,alpha=0.8):
    bce = torch.nn.functional.binary_cross_entropy(p,targets,reduction='none')
    p = torch.where(targets == 1,p,1-p)
    alpha = targets * alpha + (1-targets) * (1 - alpha)
    loss = (bce * alpha * (1 - p) ** gamma)
    return loss.mean()

In [31]:
def kl_loss(student,teacher,T=8):
    loss = torch.nn.KLDivLoss(reduction="batchmean")
    student = torch.nn.functional.log_softmax(student/T)
    teacher = torch.nn.functional.softmax(teacher/T)

    
    return loss(student,teacher)

In [32]:
def focal_kl_loss(student,teacher,targets,weight=0.5):
    kl = kl_loss(student,teacher)
    student = torch.masked_select(student,targets!=-1)
    targets = torch.masked_select(targets,targets!=-1)
    if targets.size(0) > 0:

        focal = focal_loss(sigmoid(student),targets)
    else:
        focal = kl
    return focal * weight + kl * (1-weight)

In [33]:
def naive_combination(p_loss,a_loss,alpha=0.5):
    return p_loss * alpha + a_loss * (1-alpha)

In [34]:
def mixup_criterion(criterion_params_a,criterion_params_b,lamb,criterion):
    return criterion(*criterion_params_a) * lamb + criterion(*criterion_params_b) * (1-lamb)

In [35]:
combined_model = combined_model.to(device)

In [None]:
model_vars = [i[1] for i in combined_model.named_parameters() if i[0].find("alpha") == -1 and i[0].find("beta") == -1]
alpha_vars = [i[1] for i in combined_model.named_parameters() if i[0].find("alpha") != -1 or i[0].find("beta") != -1] 


prim_optimizer = torch.optim.AdamW([{"params":model_vars},{"params":alpha_vars,"lr":0.001}],lr=learning_rate)
sigmoid = torch.sigmoid
aux_optimizer = torch.optim.AdamW([{"params":model_vars},{"params":alpha_vars,"lr":0.001}],lr=learning_rate)


if apply_scheduler:
    num_train_steps = int(len(aux_train_dataloader) * num_train_epochs)
    num_warmup_steps = int(num_train_steps * warmup_proportion)

    prim_scheduler = get_constant_schedule_with_warmup(prim_optimizer, 
                                           num_warmup_steps = num_warmup_steps)
    
    aux_scheduler = get_constant_schedule_with_warmup(aux_optimizer, 
                                           num_warmup_steps = num_warmup_steps)
    

sigmoid = torch.sigmoid


    
train_loss_func = focal_kl_loss
loss_func = focal_loss
multitask_loss = naive_combination

best_loss_prim = np.inf

best_auc_prim = 0

best_loss_aux = np.inf

best_auc_aux = 0



def calculate_loss_batch_train(batch,model,criterion):
    
    input_ids = batch[0].to(device)
    input_masks = batch[1].to(device)
    logits = batch[2].to(device)
    labels = batch[3].to(device)

    task_idx = int(batch[4][0][0].item())
    
    #print(batch)
    
    if model.training:
        
        loss_f = lambda x,y,z: mixup_criterion(x,y,z,criterion=criterion)

        outputs,labels_a,labels_b,lamb = model(input_ids,input_masks,labels,task_idx)

        return loss_f((outputs,logits,labels_a),(outputs,logits,labels_b),lamb)
    else:
        #acho que isso aqui nao importa mais
        outputs = model(input_ids,input_masks,labels,task_idx)
        return criterion(sigmoid(outputs),labels), outputs,labels
    
    
def calculate_loss_batch(batch,model,criterion):
    
    input_ids = batch[0].to(device)
    input_masks = batch[1].to(device)
    labels = batch[2].to(device)

    task_idx = int(batch[3][0][0].item())
    
    
    if model.training:
        
        loss_f = lambda x,y,z: mixup_criterion(x,y,z,criterion=criterion)

        outputs,labels_a,labels_b,lamb = model(input_ids,input_masks,labels,task_idx)

        return loss_f((sigmoid(outputs),labels_a),(sigmoid(outputs),labels_b),lamb)
    else:
        
        outputs = model(input_ids,input_masks,labels,task_idx)
        return criterion(sigmoid(outputs),labels), outputs,labels    
    
    
    



for epoch_i in range(0,num_train_epochs):

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, num_train_epochs))
    print('Training...')
    t0 = time.time()
    
    tr_loss = 0
    pr_loss = 0
    aux_loss = 0
    
    combined_model.train()
    prim_iter = iter(prim_train_dataloader)
    aux_iter = iter(aux_train_dataloader)
    
    #no caso tô assumindo que o aux é sempre maior, porque esse é meu caso agora
    
    for step,aux_batch in enumerate(aux_iter):
        if step % print_each_n_step == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(aux_train_dataloader), elapsed))
        try:
            prim_batch = prim_iter.__next__()
        except StopIteration:
            prim_iter = iter(prim_train_dataloader)
            prim_batch = prim_iter.__next__()
        
        prim_loss = calculate_loss_batch_train(prim_batch,combined_model,train_loss_func)
        prim_optimizer.zero_grad()
        prim_loss.backward()
        prim_optimizer.step()
        
        if apply_scheduler:
            prim_scheduler.step()
        
        
        aux_loss = calculate_loss_batch_train(aux_batch,combined_model,train_loss_func)
        
        aux_optimizer.zero_grad()
        aux_loss.backward()
        aux_optimizer.step()
        
        if apply_scheduler:
            aux_scheduler.step()
        
        
        
        with torch.no_grad():
            batch_loss = multitask_loss(prim_loss,aux_loss)
        

            
            tr_loss += batch_loss.item()
        
            pr_loss += prim_loss.item()
        
            aux_loss += aux_loss.item()

        
    
    avg_prim_loss = pr_loss/len(aux_train_dataloader)    
    avg_train_loss = tr_loss/len(aux_train_dataloader)
    avg_aux_loss = aux_loss/len(aux_train_dataloader)    

    
    training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss: {:}".format(avg_train_loss))
    print("  Average primary loss: {:}".format(avg_prim_loss))
    print("  Average aux loss: {:}".format(avg_aux_loss))

    print("  Training epoch took: {:}".format(training_time))
    
    
    print("  Testing primary...")
    combined_model.eval()
    
    pr_loss = 0
    tt_loss = 0
    aux_loss = 0
    
    all_probs = []
    all_labels_ids = []
    
    for prim_batch in prim_test_dataloader:
        with torch.no_grad():
            prim_loss,outputs,labels = calculate_loss_batch(prim_batch,combined_model,loss_func)
            probs = sigmoid(outputs)
            pr_loss += prim_loss.item()
            
        all_probs += probs.detach().cpu()
        all_labels_ids += labels.detach().cpu()
        
    all_probs = torch.stack(all_probs).numpy()
    all_labels_ids = torch.stack(all_labels_ids).numpy()
    
    cur_auc = average_precision_score(all_labels_ids,all_probs)

    
    avg_prim_test_loss = pr_loss/len(prim_test_dataloader)
    
    if cur_auc > best_auc_prim:
        best_auc_prim = cur_auc
        
    if avg_prim_test_loss < best_loss_prim:
        best_loss = avg_prim_test_loss
        best_probs_prim = all_probs
        best_labels_prim = all_labels_ids
        
    print("  AUC primária:",cur_auc)
    print("  Average test primary loss: {:}".format(avg_prim_test_loss))

    
    
    print("  Testing auxiliary...")
    

    
    aux_loss = 0
    
    all_probs = []
    all_labels_ids = []
    
    for prim_batch in aux_test_dataloader:
        with torch.no_grad():
            prim_loss,outputs,labels = calculate_loss_batch(prim_batch,combined_model,loss_func)
            probs = sigmoid(outputs)
            aux_loss += prim_loss.item()
            
        all_probs += probs.detach().cpu()
        all_labels_ids += labels.detach().cpu()
        
    all_probs = torch.stack(all_probs).numpy()
    all_labels_ids = torch.stack(all_labels_ids).numpy()
    
    cur_auc = average_precision_score(all_labels_ids,all_probs)

    
    avg_aux_test_loss = aux_loss/len(prim_test_dataloader)  
    
    if cur_auc > best_auc_aux:
        best_auc_aux = cur_auc
        
    if avg_aux_test_loss < best_loss_aux:
        best_loss = avg_aux_test_loss
        best_probs_aux = all_probs
        best_labels_aux = all_labels_ids
        
    print("  AUC auxiliar:",cur_auc)
    avg_test_loss = multitask_loss(avg_prim_test_loss,avg_aux_test_loss)
    
    print("  Average test loss: {:}".format(avg_test_loss))
    print("  Average test aux loss: {:}".format(avg_aux_test_loss))
    
    
    
    


Training...


  student = torch.nn.functional.log_softmax(student/T)
  teacher = torch.nn.functional.softmax(teacher/T)


  Batch   100  of    386.    Elapsed: 0:01:33.
  Batch   200  of    386.    Elapsed: 0:03:06.
  Batch   300  of    386.    Elapsed: 0:04:39.

  Average training loss: 0.012773506741468925
  Average primary loss: 0.007991020633462681
  Average aux loss: 7.840731996111572e-05
  Training epoch took: 0:05:59
  Testing primary...
  AUC primária: 0.13955542741388183
  Average test primary loss: 0.010306467214282954
  Testing auxiliary...
  AUC auxiliar: 0.07120043504506529
  Average test loss: 0.04092991093771075
  Average test aux loss: 0.07155335466113855

Training...


  student = torch.nn.functional.log_softmax(student/T)
  teacher = torch.nn.functional.softmax(teacher/T)


  Batch   100  of    386.    Elapsed: 0:01:34.
  Batch   200  of    386.    Elapsed: 0:03:07.
  Batch   300  of    386.    Elapsed: 0:04:41.

  Average training loss: 0.010057473360980654
  Average primary loss: 0.004716574000267051
  Average aux loss: 7.063095836201683e-05
  Training epoch took: 0:06:01
  Testing primary...
  AUC primária: 0.2625204310389268
  Average test primary loss: 0.008594012327211083
  Testing auxiliary...
  AUC auxiliar: 0.16316033008668404
  Average test loss: 0.03613697908344274
  Average test aux loss: 0.0636799458396744

Training...


  student = torch.nn.functional.log_softmax(student/T)
  teacher = torch.nn.functional.softmax(teacher/T)


  Batch   100  of    386.    Elapsed: 0:01:34.
  Batch   200  of    386.    Elapsed: 0:03:07.
  Batch   300  of    386.    Elapsed: 0:04:41.

  Average training loss: 0.009507695911452174
  Average primary loss: 0.003993575738675391
  Average aux loss: 6.964640488149598e-05
  Training epoch took: 0:06:01
  Testing primary...
  AUC primária: 0.34516602730800533
  Average test primary loss: 0.008525640888245038
  Testing auxiliary...
  AUC auxiliar: 0.23313751177388067
  Average test loss: 0.03330399381119828
  Average test aux loss: 0.058082346734151524

Training...


  student = torch.nn.functional.log_softmax(student/T)
  teacher = torch.nn.functional.softmax(teacher/T)


  Batch   100  of    386.    Elapsed: 0:01:34.
  Batch   200  of    386.    Elapsed: 0:03:07.
  Batch   300  of    386.    Elapsed: 0:04:41.

  Average training loss: 0.008950795627252194
  Average primary loss: 0.003419344441921793
  Average aux loss: 9.765219147084281e-05
  Training epoch took: 0:06:01
  Testing primary...
  AUC primária: 0.35969524449839774
  Average test primary loss: 0.007757307733145525
  Testing auxiliary...
  AUC auxiliar: 0.26442542455291934
  Average test loss: 0.031096901538131654
  Average test aux loss: 0.05443649534311778

Training...


  student = torch.nn.functional.log_softmax(student/T)
  teacher = torch.nn.functional.softmax(teacher/T)


  Batch   100  of    386.    Elapsed: 0:01:34.
  Batch   200  of    386.    Elapsed: 0:03:07.
  Batch   300  of    386.    Elapsed: 0:04:41.

  Average training loss: 0.008313133301686284
  Average primary loss: 0.0028196725389049173
  Average aux loss: 6.705034320475534e-05
  Training epoch took: 0:06:01
  Testing primary...
  AUC primária: 0.3593313398670501
  Average test primary loss: 0.008682322253490676
  Testing auxiliary...
  AUC auxiliar: 0.257725982776983
  Average test loss: 0.031352426156067766
  Average test aux loss: 0.05402253005864485

Training...


  student = torch.nn.functional.log_softmax(student/T)
  teacher = torch.nn.functional.softmax(teacher/T)


  Batch   100  of    386.    Elapsed: 0:01:34.
  Batch   200  of    386.    Elapsed: 0:03:07.
  Batch   300  of    386.    Elapsed: 0:04:41.

  Average training loss: 0.007580795414610693
  Average primary loss: 0.002314415940993171
  Average aux loss: 5.8877460105577484e-05
  Training epoch took: 0:06:01
  Testing primary...
  AUC primária: 0.36483255600805076
  Average test primary loss: 0.00988547848361843
  Testing auxiliary...
  AUC auxiliar: 0.2554741995524733
  Average test loss: 0.03385278141512623
  Average test aux loss: 0.057820084346634035

Training...


  student = torch.nn.functional.log_softmax(student/T)
  teacher = torch.nn.functional.softmax(teacher/T)


  Batch   100  of    386.    Elapsed: 0:01:34.
  Batch   200  of    386.    Elapsed: 0:03:07.
  Batch   300  of    386.    Elapsed: 0:04:40.

  Average training loss: 0.00711793399662011
  Average primary loss: 0.002103492419007403
  Average aux loss: 5.35353938175831e-05
  Training epoch took: 0:06:01
  Testing primary...
  AUC primária: 0.35516625266710083
  Average test primary loss: 0.010875128670859168
  Testing auxiliary...
  AUC auxiliar: 0.26089614055128957
  Average test loss: 0.03481859676452037
  Average test aux loss: 0.05876206485818158

Training...


  student = torch.nn.functional.log_softmax(student/T)
  teacher = torch.nn.functional.softmax(teacher/T)


  Batch   100  of    386.    Elapsed: 0:01:33.
  Batch   200  of    386.    Elapsed: 0:03:07.
  Batch   300  of    386.    Elapsed: 0:04:40.

  Average training loss: 0.006822537853976065
  Average primary loss: 0.0019729913396632943
  Average aux loss: 6.839061825303361e-05
  Training epoch took: 0:06:00
  Testing primary...
  AUC primária: 0.3574536435245241
  Average test primary loss: 0.010619720439690183
  Testing auxiliary...
  AUC auxiliar: 0.27502114019497786
  Average test loss: 0.035711761298237964
  Average test aux loss: 0.06080380215678575

Training...


  student = torch.nn.functional.log_softmax(student/T)
  teacher = torch.nn.functional.softmax(teacher/T)


  Batch   100  of    386.    Elapsed: 0:01:33.
  Batch   200  of    386.    Elapsed: 0:03:07.
  Batch   300  of    386.    Elapsed: 0:04:40.

  Average training loss: 0.006595407833388664
  Average primary loss: 0.0019058523968964757
  Average aux loss: 5.306803359417245e-05
  Training epoch took: 0:06:00
  Testing primary...
  AUC primária: 0.3408742032872066
  Average test primary loss: 0.012165521350601371
  Testing auxiliary...
  AUC auxiliar: 0.27526094388394234
  Average test loss: 0.037292262310829924
  Average test aux loss: 0.06241900327105848

Training...


  student = torch.nn.functional.log_softmax(student/T)
  teacher = torch.nn.functional.softmax(teacher/T)


  Batch   100  of    386.    Elapsed: 0:01:33.
  Batch   200  of    386.    Elapsed: 0:03:07.
  Batch   300  of    386.    Elapsed: 0:04:40.

  Average training loss: 0.006317428558090113
  Average primary loss: 0.0018576079654722074
  Average aux loss: 5.829547444591299e-05
  Training epoch took: 0:06:00
  Testing primary...
  AUC primária: 0.3578953295543679
  Average test primary loss: 0.011597895933280014
  Testing auxiliary...
  AUC auxiliar: 0.2818376654157165
  Average test loss: 0.03748705370773404
  Average test aux loss: 0.06337621148218806

Training...


  student = torch.nn.functional.log_softmax(student/T)
  teacher = torch.nn.functional.softmax(teacher/T)


  Batch   100  of    386.    Elapsed: 0:01:33.
  Batch   200  of    386.    Elapsed: 0:03:07.
  Batch   300  of    386.    Elapsed: 0:04:40.

  Average training loss: 0.006264767114922327
  Average primary loss: 0.0018070738480928223
  Average aux loss: 3.9167276554508135e-05
  Training epoch took: 0:06:00
  Testing primary...
  AUC primária: 0.3528528016565041
  Average test primary loss: 0.011180803575113698
  Testing auxiliary...
  AUC auxiliar: 0.24724038552728334
  Average test loss: 0.038075475399997435
  Average test aux loss: 0.06497014722488118

Training...


  student = torch.nn.functional.log_softmax(student/T)
  teacher = torch.nn.functional.softmax(teacher/T)


  Batch   100  of    386.    Elapsed: 0:01:34.
  Batch   200  of    386.    Elapsed: 0:03:07.
  Batch   300  of    386.    Elapsed: 0:04:40.

  Average training loss: 0.006127146592030755
  Average primary loss: 0.0018060937055013832
  Average aux loss: 5.696051084669307e-05
  Training epoch took: 0:06:01
  Testing primary...
  AUC primária: 0.3636049148595837
  Average test primary loss: 0.011197310969022647
  Testing auxiliary...
  AUC auxiliar: 0.2658630789327418
  Average test loss: 0.03883619212440022
  Average test aux loss: 0.0664750732797778

Training...


  student = torch.nn.functional.log_softmax(student/T)
  teacher = torch.nn.functional.softmax(teacher/T)


  Batch   100  of    386.    Elapsed: 0:01:34.
  Batch   200  of    386.    Elapsed: 0:03:07.
  Batch   300  of    386.    Elapsed: 0:04:41.

  Average training loss: 0.006032092548626924
  Average primary loss: 0.0017374955763717042
  Average aux loss: 7.023163925623521e-05
  Training epoch took: 0:06:01
  Testing primary...
  AUC primária: 0.3743527735849709
  Average test primary loss: 0.01169977577860063
  Testing auxiliary...
  AUC auxiliar: 0.27040489319647887
  Average test loss: 0.03789053309076237
  Average test aux loss: 0.06408129040292411

Training...


  student = torch.nn.functional.log_softmax(student/T)
  teacher = torch.nn.functional.softmax(teacher/T)


  Batch   100  of    386.    Elapsed: 0:01:34.
  Batch   200  of    386.    Elapsed: 0:03:07.
  Batch   300  of    386.    Elapsed: 0:04:41.

  Average training loss: 0.006025023062935417
  Average primary loss: 0.001765633027785255
  Average aux loss: 5.9270478232065216e-05
  Training epoch took: 0:06:01
  Testing primary...
  AUC primária: 0.3672477218422356
  Average test primary loss: 0.012393299151950006
  Testing auxiliary...
  AUC auxiliar: 0.27056002800056783
  Average test loss: 0.03961453987461216
  Average test aux loss: 0.06683578059727431

Training...


  student = torch.nn.functional.log_softmax(student/T)
  teacher = torch.nn.functional.softmax(teacher/T)


  Batch   100  of    386.    Elapsed: 0:01:34.
  Batch   200  of    386.    Elapsed: 0:03:07.
  Batch   300  of    386.    Elapsed: 0:04:41.

  Average training loss: 0.005896423589269307
  Average primary loss: 0.001722249578895081
  Average aux loss: 5.673586929333396e-05
  Training epoch took: 0:06:01
  Testing primary...
  AUC primária: 0.39066879255361053
  Average test primary loss: 0.011945243454801868
  Testing auxiliary...
  AUC auxiliar: 0.28587345701336914
  Average test loss: 0.040183575564714254
  Average test aux loss: 0.06842190767462664

Training...


  student = torch.nn.functional.log_softmax(student/T)
  teacher = torch.nn.functional.softmax(teacher/T)


  Batch   100  of    386.    Elapsed: 0:01:34.


In [None]:
for sharing_layer in combined_model.soft_sharing_layers:
    print(torch.tensor([[sharing_layer.alpha_prim,sharing_layer.beta_prim],[sharing_layer.beta_aux,sharing_layer.alpha_aux]]))

In [None]:
import seaborn as sns

import matplotlib.pyplot as plt



from sklearn.metrics import precision_recall_curve

preds = []

for i in range (best_probs_prim.shape[1]):

    precision,recall,thresholds = precision_recall_curve(best_labels_prim[:,i],best_probs_prim[:,i])

    f1_scores = 2*recall*precision/(recall+precision)

    cur_threshold = thresholds[np.nanargmax(f1_scores)]

    print("melhor f1 para ",train_df.columns[i+1]," ",np.nanmax(f1_scores))

    preds.append(best_probs_prim[:,i] >= cur_threshold)

    

best_preds = np.array(preds).T

best_fine_tuning_report = classification_report(best_labels_prim,best_preds,target_names=train_df.columns[1:], zero_division=0,output_dict=True)



fig, ax = plt.subplots(figsize=(15, 10))

sns.heatmap(pd.DataFrame(best_fine_tuning_report).iloc[:-1, :].T, annot=True)

