In [None]:
import numpy as np
import pandas as pd
import re
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import RobertaTokenizer, RobertaForSequenceClassification, RobertaModel
from torch.optim import Adam, lr_scheduler
import torch.nn.functional as F
import sys
sys.path.append('../')

import os
import argparse
import json
import torch.nn as nn

from util import *
from losses import SupConLoss
from augment import *

from torch.utils.data.dataset import ConcatDataset
# from torch_model import SupConRobertaNet, SupConMultiRobertaNet
from torch.utils.data.sampler import RandomSampler

from torch_model import SupConRobertaNet, SupConMultiRobertaNet


In [2]:
# !pip install transformers

In [3]:
# df = pd.read_csv('data/files/diags_id.csv')
# len(df[df.columns[10]].value_counts())

In [4]:
class PetDataset(Dataset):
    """Row-wise dataset over a DataFrame yielding ``(text, label, task_id)``.

    Columns are located by substring match on the column name, always taking
    the first match: ``"SE"`` for the text, ``"label_id"`` for the label and,
    when present, ``"task_id"`` for the task name.

    Attributes:
        df: the wrapped DataFrame (not copied).
        SE_index / label_index: positional indices of the text / label columns.
        task_index: positional index of the task column, or ``None`` when the
            frame has no such column.
        Num_class: number of distinct values in the label column.
    """

    @staticmethod
    def _matching_columns(df, key):
        """Return the positional indices of all columns whose name contains ``key``."""
        return [i for i, c in enumerate(df.columns) if key in c]

    def __init__(self, df):
        self.df = df
        # [0] raises IndexError when a required column is missing.
        self.SE_index = self._matching_columns(df, "SE")[0]
        self.label_index = self._matching_columns(df, "label_id")[0]
        self.Num_class = len(df[df.columns[self.label_index]].value_counts())
        # Always define task_index so __getitem__ cannot hit an AttributeError
        # on frames that carry no task column.
        task_columns = self._matching_columns(df, "task_id")
        self.task_index = task_columns[0] if task_columns else None

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(text, label, task_id)`` for row ``idx``.

        ``task_id`` is an empty string when the frame has no task column; a
        string (rather than ``None``) keeps the item collatable by a DataLoader.
        """
        text = self.df.iloc[idx, self.SE_index]
        label = self.df.iloc[idx, self.label_index]
        if self.task_index is None:
            task_id = ""
        else:
            task_id = self.df.iloc[idx, self.task_index]
        return text, label, task_id
    

In [5]:
import random

# Auxiliary task names; each one has a CSV with the same name under data/files.
task_ids = [
    'diags_id_1st', 'diags_id_2nd', 'diags_id_3rd', 'diags_id_4th',
    'disease_id_1st', 'disease_id_2nd',
    'symptoms_id_1st', 'symptoms_id_2nd', 'symptoms_id_3rd',
]
random.shuffle(task_ids)  # unseeded: the task order differs from run to run

task_label_dict = {}  # task name -> number of distinct labels in its CSV
dataset_all = []      # one PetDataset per task, in shuffled order
files = [task + '.csv' for task in task_ids]

for file in files:
    print(file)
    df = pd.read_csv('data/files' + '/' + file)
    label_index = [i for i, c in enumerate(df.columns) if "label_id" in c][0]
    # Labels are counted before rows with a missing text are dropped below.
    n_labels = len(df[df.columns[label_index]].value_counts())
    print(n_labels)
    task_name = file.split('.')[0]
    df['task_id'] = task_name
    task_label_dict[task_name] = n_labels
    df = df.dropna(subset=['SE'])
    dataset_all.append(PetDataset(df))

concat_dataset = ConcatDataset(dataset_all)


diags_id_1st.csv
9
symptoms_id_1st.csv
7
diags_id_2nd.csv
11
disease_id_1st.csv
5
diags_id_3rd.csv
30
symptoms_id_3rd.csv
21
disease_id_2nd.csv
10
symptoms_id_2nd.csv
12
diags_id_4th.csv
40


In [14]:
# Show the number of distinct labels recorded for each auxiliary task.
print(task_label_dict)

{'diags_id_2nd': 11, 'disease_id_2nd': 10, 'diags_id_3rd': 30, 'symptoms_id_3rd': 21, 'symptoms_id_2nd': 12, 'disease_id_1st': 5, 'symptoms_id_1st': 7, 'diags_id_1st': 9, 'diags_id_4th': 40}


In [15]:
class BatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """
    Iterate over tasks and provide a random batch per task in each mini-batch.

    ``dataset`` must expose ``datasets`` and ``cumulative_sizes`` (as a
    ``ConcatDataset`` does). The sampler emits indices into the concatenated
    dataset in runs of ``batch_size`` that all come from the same sub-dataset,
    cycling through the sub-datasets in order. Smaller sub-datasets are
    re-sampled so that every sub-dataset contributes the same number of runs
    as the largest one.
    """
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)
        self.largest_dataset_size = max([len(cur_dataset) for cur_dataset in dataset.datasets])

    def _batches_per_dataset(self):
        """Number of ``batch_size`` runs needed to cover the largest sub-dataset."""
        # Integer ceiling division: exact for any size and needs no `math` import.
        return -(-self.largest_dataset_size // self.batch_size)

    def __len__(self):
        return self.batch_size * self._batches_per_dataset() * self.number_of_datasets

    def __iter__(self):
        samplers_list = [
            RandomSampler(self.dataset.datasets[dataset_idx])
            for dataset_idx in range(self.number_of_datasets)
        ]
        sampler_iterators = [iter(sampler) for sampler in samplers_list]

        # Offset that maps a sub-dataset-local index to its concatenated index.
        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]

        final_samples_list = []  # this is a list of indexes from the combined dataset
        for _ in range(self._batches_per_dataset()):
            for i in range(self.number_of_datasets):
                for _ in range(self.batch_size):
                    try:
                        cur_sample_org = next(sampler_iterators[i])
                    except StopIteration:
                        # Sub-dataset exhausted: restart its sampler so it keeps
                        # supplying samples until the largest one is covered.
                        sampler_iterators[i] = iter(samplers_list[i])
                        try:
                            cur_sample_org = next(sampler_iterators[i])
                        except StopIteration:
                            raise ValueError(
                                "cannot draw samples from empty sub-dataset %d" % i
                            ) from None
                    final_samples_list.append(cur_sample_org + push_index_val[i])

        return iter(final_samples_list)

In [16]:
# test_df.head()

In [6]:
def model_eval(test_df, ContraMutliFlag=2, batch_size=8, max_len=512):
    """Evaluate the global ``model`` on ``test_df`` and print the accuracy.

    Relies on the notebook-level ``model``, ``tokenizer`` and ``device``.
    The model is switched to eval mode and is left in eval mode on return.

    Args:
        test_df: DataFrame accepted by ``PetDataset``.
        ContraMutliFlag: forwarded unchanged to the model's forward call.
        batch_size: evaluation mini-batch size.
        max_len: token length every sample is truncated / padded to.

    Returns:
        The accuracy as a float (``nan`` when ``test_df`` has no rows).
    """
    model.eval()

    test_dataset = PetDataset(test_df)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    total_len = 0
    total_correct = 0

    # No gradients are needed for evaluation; this avoids building the autograd graph.
    with torch.no_grad():
        for text, label, task_id in test_loader:
            encoded_list = [tokenizer.encode(t, max_length=max_len, truncation=True) for t in text]
            # Right-pad with id 0 so every row has exactly max_len tokens.
            # NOTE(review): RoBERTa's pad token id is usually 1, not 0; kept as 0
            # to match the padding used by the training cells - confirm.
            padded_list = [e[:max_len] + [0] * (max_len - len(e[:max_len])) for e in encoded_list]
            sample = torch.tensor(padded_list).to(device)
            # `label` is already a tensor here, so it only needs moving to the device.
            label = label.to(device)
            logits = model(ContraMutliFlag=ContraMutliFlag, task_id='downstream', sample=sample)

            # argmax over the logits picks the same class as argmax over softmax.
            pred = torch.argmax(logits, dim=1)
            total_correct += pred.eq(label).sum().item()
            total_len += len(label)

    accuracy = total_correct / total_len if total_len else float('nan')
    print('Test accuracy: ', accuracy)
    return accuracy

In [18]:
# Display the task name -> label count mapping.
task_label_dict

{'diags_id_2nd': 11,
 'disease_id_2nd': 10,
 'diags_id_3rd': 30,
 'symptoms_id_3rd': 21,
 'symptoms_id_2nd': 12,
 'disease_id_1st': 5,
 'symptoms_id_1st': 7,
 'diags_id_1st': 9,
 'diags_id_4th': 40}

In [7]:
# Downstream training set: tag every row with the 'downstream' task and wrap
# it in a shuffled DataLoader.
train_df = pd.read_csv('files/train3.csv')
train_df['task_id'] = 'downstream'
# Drop rows without text, as is done for the auxiliary-task and test frames;
# a missing SE value cannot be tokenized.
train_df = train_df.dropna(subset=['SE'])
train_dataset = PetDataset(train_df)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
# Number of distinct labels in the training set; sizes the downstream head.
donwstream_class_num = train_dataset.Num_class
print(donwstream_class_num)

18


In [8]:
# Prefer the GPU but fall back to the CPU so the notebook also runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# pretrained_path = './pretrained_without_wiki'
pretrained_path = './pretrained_without_wiki/'
tokenizer = RobertaTokenizer.from_pretrained(pretrained_path, do_lower_case=False)
# donwstream_class_num = task_label_dict['diags_id']
# The model receives the per-task label counts and the downstream class count.
model = SupConMultiRobertaNet(path=pretrained_path,
                              embedding_dim=768,
                              feat_dim=64,
                              task_label_dict=task_label_dict,
                              num_class=donwstream_class_num)
model.to(device)
# `criterion` (SupConLoss) is created here but the training cells shown in
# this notebook only use `criterion1` (cross-entropy).
criterion = SupConLoss(temperature=0.07)
criterion1 = torch.nn.CrossEntropyLoss()
criterion1 = criterion1.to(device)

In [21]:
# model_eval(test_df, ContraMutliFlag=2)

In [9]:
# Optimizer for the multi-task stage: Adam at 8e-5, with the learning rate
# scaled by 1 / (epoch/4 + 1) on every scheduler.step().
optimizer = Adam(params=model.parameters(), lr=8e-5)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (epoch / 4 + 1))

In [10]:
# Mini-batch size for the multi-task DataLoader and its sampler.
BATCH_SIZE = 8
# Maximum token length per sample; the tokenization code in this notebook
# truncates and pads to 512.
MAX_SEQ_LEN = 512

In [24]:
import sys

# Plain shuffled loader, kept for reference:
# dataloader = DataLoader(dataset=concat_dataset,
#                          batch_size=BATCH_SIZE,
#                          shuffle=True)

# Multi-task loader: BatchSchedulerSampler fixes the index order, so the
# loader itself must not shuffle. Sampler and loader share BATCH_SIZE.
dataloader = DataLoader(
    concat_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    sampler=BatchSchedulerSampler(concat_dataset, BATCH_SIZE),
)


In [25]:
# def one_iteration(text, label, task_id) :    
#     encoded_list = [tokenizer.encode(t, add_special_tokens=True, max_length=512, truncation=True) for t in text]
#     padded_list = [e[:512] + [0] * (512-len(e[:512])) for e in encoded_list]
#     sample = torch.tensor(padded_list)
#     sample, label = sample.to(device), label.to(device)
#     label = torch.tensor(label)
#     outputs = model(ContraMutliFlag=2, task_id=task_id, sample=sample)
    
#     loss = criterion1(outputs, label)

#     pred = torch.argmax(F.softmax(outputs), dim=1)
#     correct = pred.eq(label)
# #     total_correct += correct.sum().item()
# #     total_len += len(labels)
# #     total_loss += loss.item()
# #     total_count += 1
    
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step() 
    
#     return loss, correct

In [None]:
# Multi-task stage: train on the interleaved auxiliary tasks.
# task_id[0] selects the task for the whole batch, which assumes all samples
# of a batch belong to the same task.
model.train()
epochs = 5
for epoch in range(epochs):
    losses = AverageMeter()
    total_loss = 0
    total_len = 0
    total_correct = 0
    total_count = 0
    for text, label, task_id in dataloader:
        encoded_list = [tokenizer.encode(t, add_special_tokens=True, max_length=MAX_SEQ_LEN, truncation=True) for t in text]
        # Right-pad with id 0 so every row has exactly MAX_SEQ_LEN tokens.
        padded_list = [e[:MAX_SEQ_LEN] + [0] * (MAX_SEQ_LEN - len(e[:MAX_SEQ_LEN])) for e in encoded_list]
        sample = torch.tensor(padded_list)
        # `label` is already a tensor from the DataLoader; it only needs moving.
        sample, label = sample.to(device), label.to(device)
        outputs = model(ContraMutliFlag=2, task_id=task_id[0], sample=sample)
        # argmax over the logits picks the same class as argmax over softmax.
        pred = torch.argmax(outputs, dim=1)
        correct = pred.eq(label)
        loss = criterion1(outputs, label)
        # Weight the running average by the real batch size.
        # Assumes the second argument of AverageMeter.update is the sample count - confirm in util.
        losses.update(loss.item(), len(label))

        total_correct += correct.sum().item()
        total_len += len(label)
        total_loss += loss.item()
        total_count += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # total_count already includes this step, so report it as-is.
        if total_count % 100 == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'loss {loss.avg:.5f}'.format(
                   epoch, total_count, len(dataloader), loss=losses))

    scheduler.step()
    print('***********************************')

In [27]:
# Persist the weights of the multi-task-trained model.
PATH = 'finetune/multitasked'
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(model.state_dict(), PATH)

In [28]:
# for name in (model.state_dict()) :
#     print(name)

In [29]:
# Freeze every parameter whose name does not contain 'fc.', leaving only the
# 'fc.' layers trainable.
# named_parameters() yields each parameter together with its own name.
# Zipping state_dict() with parameters() can pair a name with the wrong
# tensor, because state_dict() also lists buffers.
for name, param in model.named_parameters():
    if 'fc.' not in name:
        param.requires_grad = False

In [11]:
# Downstream test set, tagged with the 'downstream' task.
test_df = pd.read_csv('files/test3.csv')
test_df['task_id'] = 'downstream'
print(test_df.shape)
# test_df.SE = test_df.SE.apply(lambda x : np.nan if x == "Nan" else x)
# Remove rows without text; the two printed shapes show how many were dropped.
test_df = test_df.dropna(subset=['SE'])
print(test_df.shape)
# test_df.diags_id = test_df.diags_id.apply(lambda x : int(x))

(813, 16)
(813, 16)


In [12]:
# Fresh optimizer for the first downstream stage: Adam at 6e-5, with the
# learning rate scaled by 1 / (epoch/4 + 1) on every scheduler.step().
optimizer = Adam(params=model.parameters(), lr=6e-5)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (epoch / 4 + 1))

In [15]:
# downstream task
# Train on the downstream set and report test accuracy after every epoch.
epochs = 3
model.train()
for epoch in range(epochs):
    losses = AverageMeter()
    total_loss = 0
    total_len = 0
    total_correct = 0
    total_count = 0
    # model_eval() switches the model to eval mode, so re-enable training each epoch.
    model.train()
    for text, label, task_id in train_loader:
        encoded_list = [tokenizer.encode(t, add_special_tokens=True, max_length=MAX_SEQ_LEN, truncation=True) for t in text]
        # Right-pad with id 0 so every row has exactly MAX_SEQ_LEN tokens.
        padded_list = [e[:MAX_SEQ_LEN] + [0] * (MAX_SEQ_LEN - len(e[:MAX_SEQ_LEN])) for e in encoded_list]
        sample = torch.tensor(padded_list)
        # `label` is already a tensor from the DataLoader; it only needs moving.
        sample, label = sample.to(device), label.to(device)
        outputs = model(ContraMutliFlag=3, task_id=task_id[0], sample=sample)
        # argmax over the logits picks the same class as argmax over softmax.
        pred = torch.argmax(outputs, dim=1)
        correct = pred.eq(label)
        loss = criterion1(outputs, label)
        # Weight the running average by the real batch size (the last batch can be smaller).
        # Assumes the second argument of AverageMeter.update is the sample count - confirm in util.
        losses.update(loss.item(), len(label))

        total_correct += correct.sum().item()
        total_len += len(label)
        total_loss += loss.item()
        total_count += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # total_count already includes this step, so report it as-is.
        if total_count % 100 == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'loss {loss.avg:.5f}'.format(
                   epoch, total_count, len(train_loader), loss=losses))

    scheduler.step()
    model_eval(test_df, ContraMutliFlag=3)



Train: [0][100/585]	loss 0.42408
Train: [0][200/585]	loss 0.42219
Train: [0][300/585]	loss 0.41543
Train: [0][400/585]	loss 0.40937
Train: [0][500/585]	loss 0.41571




Test accuracy:  0.6519065190651907
Train: [1][100/585]	loss 0.30868
Train: [1][200/585]	loss 0.32071
Train: [1][300/585]	loss 0.31875
Train: [1][400/585]	loss 0.32189
Train: [1][500/585]	loss 0.32247
Test accuracy:  0.6494464944649446
Train: [2][100/585]	loss 0.23931
Train: [2][200/585]	loss 0.24318
Train: [2][300/585]	loss 0.24904
Train: [2][400/585]	loss 0.25221
Train: [2][500/585]	loss 0.24680
Test accuracy:  0.6691266912669127


In [33]:
# Make every model parameter trainable again.
for name, param in model.named_parameters():
    param.requires_grad = True

In [34]:
# Fresh optimizer for full-model fine-tuning: Adam at 1e-5, with the learning
# rate scaled by 1 / (epoch/4 + 1) on every scheduler.step().
optimizer = Adam(params=model.parameters(), lr=1e-5)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (epoch / 4 + 1))

In [35]:
# downstream task
# Train on the downstream set and report test accuracy after every epoch.
epochs = 10
model.train()
for epoch in range(epochs):
    losses = AverageMeter()
    total_loss = 0
    total_len = 0
    total_correct = 0
    total_count = 0
    # model_eval() switches the model to eval mode, so re-enable training each epoch.
    model.train()
    for text, label, task_id in train_loader:
        encoded_list = [tokenizer.encode(t, add_special_tokens=True, max_length=MAX_SEQ_LEN, truncation=True) for t in text]
        # Right-pad with id 0 so every row has exactly MAX_SEQ_LEN tokens.
        padded_list = [e[:MAX_SEQ_LEN] + [0] * (MAX_SEQ_LEN - len(e[:MAX_SEQ_LEN])) for e in encoded_list]
        sample = torch.tensor(padded_list)
        # `label` is already a tensor from the DataLoader; it only needs moving.
        sample, label = sample.to(device), label.to(device)
        outputs = model(ContraMutliFlag=3, task_id=task_id[0], sample=sample)
        # argmax over the logits picks the same class as argmax over softmax.
        pred = torch.argmax(outputs, dim=1)
        correct = pred.eq(label)
        loss = criterion1(outputs, label)
        # Weight the running average by the real batch size (the last batch can be smaller).
        # Assumes the second argument of AverageMeter.update is the sample count - confirm in util.
        losses.update(loss.item(), len(label))

        total_correct += correct.sum().item()
        total_len += len(label)
        total_loss += loss.item()
        total_count += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # total_count already includes this step, so report it as-is.
        if total_count % 100 == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'loss {loss.avg:.5f}'.format(
                   epoch, total_count, len(train_loader), loss=losses))

    scheduler.step()
    model_eval(test_df, ContraMutliFlag=3)



Train: [0][100/585]	loss 0.94873
Train: [0][200/585]	loss 0.98815
Train: [0][300/585]	loss 1.00950
Train: [0][400/585]	loss 0.99730
Train: [0][500/585]	loss 0.99981




Test accuracy:  0.45264452644526443
Train: [1][100/585]	loss 0.97639
Train: [1][200/585]	loss 0.93414
Train: [1][300/585]	loss 0.91222
Train: [1][400/585]	loss 0.91201
Train: [1][500/585]	loss 0.91100
Test accuracy:  0.45264452644526443
Train: [2][100/585]	loss 0.85521
Train: [2][200/585]	loss 0.85377
Train: [2][300/585]	loss 0.84502
Train: [2][400/585]	loss 0.85407
Train: [2][500/585]	loss 0.85727
Test accuracy:  0.45141451414514144
Train: [3][100/585]	loss 0.84787
Train: [3][200/585]	loss 0.83428
Train: [3][300/585]	loss 0.82715
Train: [3][400/585]	loss 0.81493
Train: [3][500/585]	loss 0.83187
Test accuracy:  0.46248462484624847
Train: [4][100/585]	loss 0.80919
Train: [4][200/585]	loss 0.80257
Train: [4][300/585]	loss 0.79508
Train: [4][400/585]	loss 0.78590
Train: [4][500/585]	loss 0.78304
Test accuracy:  0.46002460024600245
Train: [5][100/585]	loss 0.78508
Train: [5][200/585]	loss 0.78564
Train: [5][300/585]	loss 0.77301
Train: [5][400/585]	loss 0.75899
Train: [5][500/585]	loss 0.7

In [36]:
# Fresh optimizer for the next fine-tuning run: Adam at 2e-5, with a faster
# decay of 1 / (epoch/2 + 1) on every scheduler.step().
optimizer = Adam(params=model.parameters(), lr=2e-5)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (epoch / 2 + 1))

In [37]:
# downstream task
# Train on the downstream set and report test accuracy after every epoch.
epochs = 10
model.train()
for epoch in range(epochs):
    losses = AverageMeter()
    total_loss = 0
    total_len = 0
    total_correct = 0
    total_count = 0
    # model_eval() switches the model to eval mode, so re-enable training each epoch.
    model.train()
    for text, label, task_id in train_loader:
        encoded_list = [tokenizer.encode(t, add_special_tokens=True, max_length=MAX_SEQ_LEN, truncation=True) for t in text]
        # Right-pad with id 0 so every row has exactly MAX_SEQ_LEN tokens.
        padded_list = [e[:MAX_SEQ_LEN] + [0] * (MAX_SEQ_LEN - len(e[:MAX_SEQ_LEN])) for e in encoded_list]
        sample = torch.tensor(padded_list)
        # `label` is already a tensor from the DataLoader; it only needs moving.
        sample, label = sample.to(device), label.to(device)
        outputs = model(ContraMutliFlag=3, task_id=task_id[0], sample=sample)
        # argmax over the logits picks the same class as argmax over softmax.
        pred = torch.argmax(outputs, dim=1)
        correct = pred.eq(label)
        loss = criterion1(outputs, label)
        # Weight the running average by the real batch size (the last batch can be smaller).
        # Assumes the second argument of AverageMeter.update is the sample count - confirm in util.
        losses.update(loss.item(), len(label))

        total_correct += correct.sum().item()
        total_len += len(label)
        total_loss += loss.item()
        total_count += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # total_count already includes this step, so report it as-is.
        if total_count % 100 == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'loss {loss.avg:.5f}'.format(
                   epoch, total_count, len(train_loader), loss=losses))

    scheduler.step()
    model_eval(test_df, ContraMutliFlag=3)



Train: [0][100/585]	loss 0.75904
Train: [0][200/585]	loss 0.72472
Train: [0][300/585]	loss 0.71838
Train: [0][400/585]	loss 0.71168
Train: [0][500/585]	loss 0.70548




Test accuracy:  0.45264452644526443
Train: [1][100/585]	loss 0.67009
Train: [1][200/585]	loss 0.65710
Train: [1][300/585]	loss 0.65178
Train: [1][400/585]	loss 0.65292
Train: [1][500/585]	loss 0.65085
Test accuracy:  0.46863468634686345
Train: [2][100/585]	loss 0.61587
Train: [2][200/585]	loss 0.57734
Train: [2][300/585]	loss 0.60297
Train: [2][400/585]	loss 0.60041
Train: [2][500/585]	loss 0.59767
Test accuracy:  0.46248462484624847
Train: [3][100/585]	loss 0.58885
Train: [3][200/585]	loss 0.56309
Train: [3][300/585]	loss 0.56188
Train: [3][400/585]	loss 0.56236
Train: [3][500/585]	loss 0.56195
Test accuracy:  0.45387453874538747
Train: [4][100/585]	loss 0.52531
Train: [4][200/585]	loss 0.51730
Train: [4][300/585]	loss 0.53436
Train: [4][400/585]	loss 0.53223
Train: [4][500/585]	loss 0.53697
Test accuracy:  0.46002460024600245
Train: [5][100/585]	loss 0.53618
Train: [5][200/585]	loss 0.54476
Train: [5][300/585]	loss 0.52584
Train: [5][400/585]	loss 0.52598
Train: [5][500/585]	loss 0.5