In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW


import matplotlib.pyplot as plt

import json
from statistics import mean
import pickle


In [None]:
import utils_generic as generic
import mt_dep as mt
import model_confs as confs
import train_datamaps as train 
from train import eval_func_multi

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Task heads handled by the multitask datasets (order matters for the dataset wrappers).
tasks = ['to', 'as', 'about']

In [None]:
# Backbone selection; model_name and encoding_type also form the vocab / checkpoint file names.
model_name = 'distilbert'
model_conf = confs.distilbert_conf
encoding_type = ''  # suffix used in f'vocab_{encoding_type}.pkl' and in the checkpoint paths
num_labels = 2      # two classes per task (reported below as 'female' / 'male')

# Datos ConvAI

In [None]:
# Load the complete ConvAI2 dump (train / validation splits).
# Forward slashes are accepted by open() on Windows as well as POSIX, and avoid the invalid
# "\C" / "\c" escape sequences of the original backslash literal (SyntaxWarning on 3.12+).
# Read explicitly as UTF-8 (like the md_gender file) instead of the platform default encoding.
with open('Datasets/ConvAI2/convai2_complete.json', 'r', encoding="utf8") as f:
    data = json.load(f)

In [None]:


convai_train = data['train']
convai_val = data['validation']

# Load the pickled vocabulary object.
# NOTE(review): pickle.load can execute arbitrary code; only load vocab files you produced.
with open(f'vocab_{encoding_type}.pkl', 'rb') as f:
    vocab = pickle.load(f)

# Single source for the task list handed to the tokenizer, so train and validation
# are always tokenized with the same tasks in the same order.
tokenize_tasks = ['about', 'to', 'as']
convai_train_token = generic.tokenize_dataset_with_dependencies(
    convai_train, tokenize_tasks, vocab, model_conf
)
convai_val_token = generic.tokenize_dataset_with_dependencies(
    convai_val, tokenize_tasks, vocab, model_conf
)

In [None]:
# Wrap both tokenized ConvAI splits as multitask datasets over `tasks` (eval=False for both).
convai_train_dataset, convai_val_dataset = (
    mt.DatasetMultitaskDep(split_tokens, tasks, eval=False)
    for split_tokens in (convai_train_token, convai_val_token)
)

# Datos md_gender

In [None]:
# Load the md_gender corpus used for evaluation.
# Forward slashes are accepted by open() on Windows as well as POSIX, and avoid the invalid
# "\m" escape sequences of the original backslash literal (SyntaxWarning on Python 3.12+).
with open('Datasets/md_gender/md_complete.json', 'r', encoding="utf8") as f:
    md_data = json.load(f)

In [None]:
# md_gender is the evaluation set (fed to dl_eval), hence eval=True on the dataset wrapper.
md_tokenized = generic.tokenize_dataset_with_dependencies(
    md_data, ['about', 'to', 'as'], vocab, model_conf
)
md_dataset = mt.DatasetMultitaskDep(md_tokenized, tasks, eval=True)

# Creación de dataloaders

In [None]:
batch_size = 128

# Only the training loader needs shuffling. Validation order does not change aggregate
# metrics (assuming train_function_multi only aggregates them — TODO confirm), so it is
# kept fixed for determinism, like the evaluation loader.
# NOTE(review): data maps track per-example statistics across epochs; confirm that
# train_function_multi keys them by example id rather than batch position, since
# training batches are reshuffled every epoch.
dl_train = DataLoader(convai_train_dataset, batch_size=batch_size, shuffle=True, collate_fn=mt.collate_fn)
dl_val = DataLoader(convai_val_dataset, batch_size=batch_size, shuffle=False, collate_fn=mt.collate_fn)
dl_eval = DataLoader(md_dataset, batch_size=batch_size, shuffle=False, collate_fn=mt.collate_fn)

# Modelo multitask (3 entrenamientos independientes, evaluados en md_gender)

In [None]:
# Training hyper-parameters shared by every run below.
num_epochs = 100
learning_rate = 1e-6

# Accumulator for the evaluation metrics of each run, per task:
# recall / precision / f1 hold one list per label group, 'acc' is a flat list.
# The comprehension creates a fresh, independent list for every leaf.
global_metrics = {
    task_name: {
        **{
            metric_name: {group: [] for group in ('weighted_avg', 'average', 'female', 'male')}
            for metric_name in ('recall', 'precision', 'f1')
        },
        'acc': [],
    }
    for task_name in ('about', 'to', 'as')
}

In [None]:
# Run 1: train a fresh multitask model on ConvAI2, early-stopping on the validation split.
run_id = 1
# Seed torch's RNGs (CPU and CUDA) so weight init and training-batch shuffling are repeatable;
# each run uses its own seed so the runs still differ from one another.
torch.manual_seed(run_id)
model = mt.MultiWithDependencies(model_conf, vocab, num_labels=num_labels).to(device)
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)
save_path = f'{model_name}_multitask_{encoding_type}_{run_id}'
# A state_dict is left at save_path (the evaluation cell reloads it from there).
# p and c are indexed by task for the data maps; e is passed to them as num_epochs.
p, c, e = train.train_function_multi(
    model, num_epochs, dl_train, optimizer,
    early_stop=10, dl_val=dl_val, save_path=save_path, es_threshold=0,
)
torch.save(p, save_path + '_probs' + '.pt')
torch.save(c, save_path + '_corr' + '.pt')

In [None]:
# Data map of each task for the run trained just above (p / c / e from that cell).
for task in ('about', 'as', 'to'):
    print(task.upper())
    train.get_datamap_complete_graph(
        p[task],
        correctness_vector=c[task],
        num_epochs=e,
    )

In [None]:
# Same per-task data maps, this time with show_samples=True.
for task in ('about', 'as', 'to'):
    print(task.upper())
    train.get_datamap_complete_graph(
        p[task],
        correctness_vector=c[task],
        num_epochs=e,
        show_samples=True,
    )

In [None]:
# Reload run 1's checkpoint, evaluate it on md_gender and accumulate into global_metrics.
# The path is rebuilt here instead of reusing `save_path`, which is overwritten by every
# training cell (running cells out of order would otherwise evaluate the wrong run).
checkpoint_path = f'{model_name}_multitask_{encoding_type}_1'
model = mt.MultiWithDependencies(model_conf, vocab, num_labels=num_labels).to(device)
# map_location lets a checkpoint written on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()  # disable dropout; a no-op if eval_func_multi already switches modes

metrics_results = eval_func_multi(model, dl_eval, ['about', 'to', 'as'])
# NOTE: re-running this cell appends the same run again and skews the averages below.
for task, task_metrics in metrics_results.items():
    print(f'Resultados en la tarea {task.upper()}:')
    for metric, value in task_metrics.items():
        if metric == 'accuracy':
            global_metrics[task]['acc'].append(value)
        else:
            # value maps a label group ('average', 'weighted_avg', 'female', 'male') to a score
            for group, score in value.items():
                global_metrics[task][metric][group].append(score)
        print(metric, value)
    print('\n')



In [None]:
# Run 2: train a fresh multitask model on ConvAI2, early-stopping on the validation split.
run_id = 2
# Seed torch's RNGs (CPU and CUDA) so weight init and training-batch shuffling are repeatable;
# each run uses its own seed so the runs still differ from one another.
torch.manual_seed(run_id)
model = mt.MultiWithDependencies(model_conf, vocab, num_labels=num_labels).to(device)
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)
save_path = f'{model_name}_multitask_{encoding_type}_{run_id}'
# A state_dict is left at save_path (the evaluation cell reloads it from there).
# p and c are indexed by task for the data maps; e is passed to them as num_epochs.
p, c, e = train.train_function_multi(
    model, num_epochs, dl_train, optimizer,
    early_stop=10, dl_val=dl_val, save_path=save_path, es_threshold=0,
)
torch.save(p, save_path + '_probs' + '.pt')
torch.save(c, save_path + '_corr' + '.pt')

In [None]:
# Data map of each task for the run trained just above (p / c / e from that cell).
for task in ('about', 'as', 'to'):
    print(task.upper())
    train.get_datamap_complete_graph(
        p[task],
        correctness_vector=c[task],
        num_epochs=e,
    )

In [None]:
# Same per-task data maps, this time with show_samples=True.
for task in ('about', 'as', 'to'):
    print(task.upper())
    train.get_datamap_complete_graph(
        p[task],
        correctness_vector=c[task],
        num_epochs=e,
        show_samples=True,
    )

In [None]:
# Reload run 2's checkpoint, evaluate it on md_gender and accumulate into global_metrics.
# The path is rebuilt here instead of reusing `save_path`, which is overwritten by every
# training cell (running cells out of order would otherwise evaluate the wrong run).
checkpoint_path = f'{model_name}_multitask_{encoding_type}_2'
model = mt.MultiWithDependencies(model_conf, vocab, num_labels=num_labels).to(device)
# map_location lets a checkpoint written on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()  # disable dropout; a no-op if eval_func_multi already switches modes

metrics_results = eval_func_multi(model, dl_eval, ['about', 'to', 'as'])
# NOTE: re-running this cell appends the same run again and skews the averages below.
for task, task_metrics in metrics_results.items():
    print(f'Resultados en la tarea {task.upper()}:')
    for metric, value in task_metrics.items():
        if metric == 'accuracy':
            global_metrics[task]['acc'].append(value)
        else:
            # value maps a label group ('average', 'weighted_avg', 'female', 'male') to a score
            for group, score in value.items():
                global_metrics[task][metric][group].append(score)
        print(metric, value)
    print('\n')



In [None]:
# Run 3: train a fresh multitask model on ConvAI2, early-stopping on the validation split.
run_id = 3
# Seed torch's RNGs (CPU and CUDA) so weight init and training-batch shuffling are repeatable;
# each run uses its own seed so the runs still differ from one another.
torch.manual_seed(run_id)
model = mt.MultiWithDependencies(model_conf, vocab, num_labels=num_labels).to(device)
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)
save_path = f'{model_name}_multitask_{encoding_type}_{run_id}'
# A state_dict is left at save_path (the evaluation cell reloads it from there).
# p and c are indexed by task for the data maps; e is passed to them as num_epochs.
p, c, e = train.train_function_multi(
    model, num_epochs, dl_train, optimizer,
    early_stop=10, dl_val=dl_val, save_path=save_path, es_threshold=0,
)
torch.save(p, save_path + '_probs' + '.pt')
torch.save(c, save_path + '_corr' + '.pt')

In [None]:
# Data map of each task for the run trained just above (p / c / e from that cell).
for task in ('about', 'as', 'to'):
    print(task.upper())
    train.get_datamap_complete_graph(
        p[task],
        correctness_vector=c[task],
        num_epochs=e,
    )

In [None]:
# Same per-task data maps, this time with show_samples=True.
for task in ('about', 'as', 'to'):
    print(task.upper())
    train.get_datamap_complete_graph(
        p[task],
        correctness_vector=c[task],
        num_epochs=e,
        show_samples=True,
    )

In [None]:
# Reload run 3's checkpoint, evaluate it on md_gender and accumulate into global_metrics.
# The path is rebuilt here instead of reusing `save_path`, which is overwritten by every
# training cell (running cells out of order would otherwise evaluate the wrong run).
checkpoint_path = f'{model_name}_multitask_{encoding_type}_3'
model = mt.MultiWithDependencies(model_conf, vocab, num_labels=num_labels).to(device)
# map_location lets a checkpoint written on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()  # disable dropout; a no-op if eval_func_multi already switches modes

metrics_results = eval_func_multi(model, dl_eval, ['about', 'to', 'as'])
# NOTE: re-running this cell appends the same run again and skews the averages below.
for task, task_metrics in metrics_results.items():
    print(f'Resultados en la tarea {task.upper()}:')
    for metric, value in task_metrics.items():
        if metric == 'accuracy':
            global_metrics[task]['acc'].append(value)
        else:
            # value maps a label group ('average', 'weighted_avg', 'female', 'male') to a score
            for group, score in value.items():
                global_metrics[task][metric][group].append(score)
        print(metric, value)
    print('\n')



In [None]:
# Average every metric accumulated in global_metrics over all evaluated runs.
summary_tasks = ['about', 'to', 'as']
# Number of evaluated runs, derived from the data rather than hard-coded as 3.
n_models = len(global_metrics[summary_tasks[0]]['acc'])
assert n_models > 0, 'global_metrics is empty: run the evaluation cells first'

for metric in ['f1', 'recall', 'precision']:
    print(f'{metric} medio de los {n_models} modelos: \n')
    for task in summary_tasks:
        scores = global_metrics[task][metric]
        print(task.upper())
        print(f'Resultado global {metric}:', mean(scores['average']))
        print(f'Resultado global ponderado {metric}:', mean(scores['weighted_avg']))
        print(f'{metric} etiqueta male:', mean(scores['male']))
        print(f'{metric} etiqueta female:', mean(scores['female']))
        print('\n')

print(f'Accuracy medio de los {n_models} modelos: \n')
for task in summary_tasks:
    print('\n', task.upper())
    print('Resultado global accuracy:', mean(global_metrics[task]['acc']))