In [14]:
import json
from collections import Counter

import pandas as pd
import torch
from sklearn.preprocessing import MultiLabelBinarizer
from torch import nn
from torch.nn import BCELoss
from torch.utils.data import Dataset
from transformers import BertModel
from transformers import BertTokenizer
from transformers import PretrainedConfig

In [11]:
path_to_data = './data'

# Load the session records from sessions_altered.json into a dataframe.
# (A bare `df.head()` in the middle of a cell displays nothing, so the preview
# is left to a cell of its own.)
df = pd.read_json(f'{path_to_data}/sessions_altered.json')

# Read the symptom catalogue. The encoding is given explicitly so non-ASCII
# symptom names decode the same way on every platform instead of depending on
# the locale's default encoding.
with open(f'{path_to_data}/symptoms.json', encoding='utf-8') as json_file:
    data = json.load(json_file)

# Lookup tables for switching between a symptom's id and its name.
id2sym = {}
sym2id = {}
for sym in data:
    id2sym[sym['id']] = sym['name']
    sym2id[sym['name']] = sym['id']
        
        
# Drop labels that occur fewer than `m` times across the whole dataset.
# With m = 0 nothing is removed; raise it to prune rare labels.
m = 0

# Flatten with a comprehension: sum(list_of_lists, []) re-copies the
# accumulated list on every step and is quadratic in the number of labels.
labels_list = [label for labels in df['confirmed'] for label in labels]
c = Counter(labels_list)

# Rebuild each row's list in one pass instead of popping entries through
# chained indexing (df['confirmed'][i].pop(...)) with manual shift bookkeeping.
# The relative order of the surviving labels is unchanged.
df['confirmed'] = df['confirmed'].map(
    lambda ids: [label for label in ids if c[label] >= m]
)
    
        
# Translate each row's confirmed symptom ids into symptom names; a row without
# ids naturally maps to an empty list.
sym_names = [
    [id2sym[sym_id] for sym_id in confirmed_ids]
    for confirmed_ids in df['confirmed']
]
df['labels'] = sym_names

# Keep only rows that still carry at least one confirmed label, then renumber
# the index from zero.
df = df[df['confirmed'].map(len) > 0].reset_index(drop=True)

In [13]:
# Preview the first rows of `df` (text, confirmed/suggested ids, label names).
df.head()

Unnamed: 0,text,confirmed,suggested,labels
0,Slut på medicin.,"[89, 651]",[348],"[Känd astma, Känd lungsjukdom]"
1,Behöver att prata med psykolog angående använd...,"[116, 215]","[215, 348, 446]","[Nedstämdhet, Trötthet]"
2,Har fått besvärlig eksem på händerna,"[2, 141]",[141],"[Hudbesvär, Synliga hudbesvär]"
3,Muskelsvaghet och trötthet känner mig skakig o...,"[606, 215]","[12, 97, 215, 359, 518, 606]","[Muskelsvaghet, Trötthet]"
4,Svår smärta i vänsterhanden/handleden precis n...,"[682, 493]","[15, 28, 48, 54, 114, 148, 246, 313, 333, 339,...","[Smärta i handled eller fingrar, Förvärras av ..."


In [None]:
# Directory of the pretrained BERT checkpoint. It is assigned here as well
# because this cell needs it before the model cell that also defines it;
# without this line a top-to-bottom run stops with a NameError.
path_to_bert = r'./bert/bert-base-swedish-cased'
tok = BertTokenizer.from_pretrained(path_to_bert)

# Fit a MultiLabelBinarizer on the per-row lists of symptom names so every
# distinct label gets its own position in the multi-hot target vector.
labels = df['labels'].tolist()
multilab_bin = MultiLabelBinarizer()
multilab_bin.fit(labels)

class CustomDataset(Dataset):
    """Dataset that tokenizes each row's text and multi-hot encodes its labels.

    Args:
        dataframe: Frame with a 'text' column and a 'labels' column holding one
            list of label names per row.
        tokenizer: Object exposing ``encode_plus`` (a ``BertTokenizer`` in this
            notebook).
        multilab_bin: Fitted binarizer exposing ``transform``.
        max_len: Length passed to the tokenizer as ``max_length``; sequences
            are padded to it.
    """

    def __init__(self, dataframe, tokenizer, multilab_bin, max_len):
        self.tokenizer = tokenizer
        self.multilab_bin = multilab_bin
        self.data = dataframe
        self.text = self.data['text']
        self.labels = self.data['labels']
        # Stored because __getitem__ forwards it to the tokenizer.
        self.max_len = max_len

    def __len__(self):
        """Return the number of rows in the dataset."""
        return len(self.text)

    def __getitem__(self, index):
        """Return the encoded text and target vector for the row at position ``index``.

        Returns:
            dict with long tensors 'ids', 'mask', 'token_type_ids' and a float
            tensor 'labels' (one entry per class known to the binarizer).
        """
        # Positional access, so the frame does not need a 0..n-1 index.
        text = str(self.text.iloc[index])
        # Collapse any run of whitespace into a single space.
        text = ' '.join(text.split())

        # NOTE(review): `pad_to_max_length` is the legacy spelling; newer
        # transformers releases prefer padding='max_length', truncation=True.
        # Kept as-is because the installed version is not visible here.
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            pad_to_max_length=True,
            return_token_type_ids=True
        )
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs['token_type_ids']

        # transform() expects a batch (iterable of label collections). Wrap the
        # single row and take the first result; passing the bare list would
        # treat every label string as its own collection of characters.
        target = self.multilab_bin.transform([self.labels.iloc[index]])[0]

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
            'labels': torch.tensor(target, dtype=torch.float)
        }

In [23]:
# Folder that holds the pretrained BERT checkpoint used in this notebook.
path_to_bert = r'./bert/bert-base-swedish-cased'
# Model configuration parsed from the config.json inside that folder.
config = PretrainedConfig.from_json_file(path_to_bert + '/config.json')

class BERTClass(nn.Module):
    """BERT encoder followed by a dropout -> linear -> sigmoid head.

    Reads the module-level ``path_to_bert`` and ``config`` when constructed.

    Args:
        output_dim: Number of output units of the linear layer (one per label).
    """

    def __init__(self, output_dim):
        # nn.Module.__init__ takes no config argument; it must run before any
        # sub-module is assigned.
        super().__init__()
        self.l1 = BertModel.from_pretrained(path_to_bert)
        self.l2 = torch.nn.Dropout(config.hidden_dropout_prob)
        self.l3 = nn.Linear(config.hidden_size, output_dim)
        self.sigm = nn.Sigmoid()

    def forward(self, ids, mask, token_type_ids):
        """Run the encoder and head; return sigmoid outputs of width ``output_dim``."""
        encoder_out = self.l1(ids, attention_mask=mask, token_type_ids=token_type_ids)
        # Select the second output by position rather than tuple-unpacking:
        # indexing works for a plain tuple and for the dict-like output object
        # of newer transformers releases, whereas unpacking the latter yields
        # its keys. Per the transformers docs this element is the pooled output.
        pooled = encoder_out[1]
        dropped = self.l2(pooled)
        logits = self.l3(dropped)
        return self.sigm(logits)
        