In [2]:
import torch
import torchtext
from torchtext.data.utils import get_tokenizer
from torch.utils.data import Dataset
from torchtext.data.functional import numericalize_tokens_from_iterator

In [3]:
import docx

def getText(filename):
    doc = docx.Document(filename)
    fullText = []
    for para in doc.paragraphs:
        fullText.append(para.text)
    return '\n'.join(fullText)

In [124]:
notes = [getText("/Users/etashguha/Downloads/Sample Note.docx")]
labels = [0]

In [125]:
icd_codes = ["Unspecified atrial fibrillation","Acute rheumatic fever", "Minor rheumatic fever"]
en_tokenizer = get_tokenizer('spacy', language='en')

vocab = {}
vocab_index = 0
for note in notes:
    words = en_tokenizer(note)
    for word in words:
        if word not in vocab:
            vocab[word] = vocab_index 
            vocab_index += 1
for code in icd_codes:
    words = en_tokenizer(code)
    for word in words:
        if word not in vocab:
            vocab[word] = vocab_index 
            vocab_index += 1

In [126]:
# print(en_tokenizer("left shoulder"))
# id_iters = numericalize_tokens_from_iterator(vocab, [["left", "shoulder"], ["no", "effusion"]])
# for ids in id_iters:
#     print([num for num in ids])

class SATDataSet(Dataset):
    def __init__(self, texts, labels):
        super().__init__()
        self.data = []
        for i, text in enumerate(texts):
            id_iters = numericalize_tokens_from_iterator(vocab, [en_tokenizer(text)])
            for ids in id_iters:
                self.data.append(([num for num in ids], labels[i]))
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx][0]), torch.tensor(self.data[idx][1])
train_dataset = SATDataSet(notes, labels)



In [116]:
index_codes = []
for code in icd_codes:
    id_iters = numericalize_tokens_from_iterator(vocab, [en_tokenizer(code)])
    for ids in id_iters:
        index_codes.append([num for num in ids])

def n_gram(codes, n):
    return [n_gram_helper(code, n) for code in codes]

def n_gram_helper(code, n):
    l = []
    for i in range(len(code) - n + 1):
        l.append(code[i:i + n])
    return l


code_lengths = []
for cod in index_codes:
    code_lengths.append(len(cod))
index_codes = [n_gram(index_codes, i) for i in range(2,4)]


In [101]:
def get_code_ngrams(index_codes):
    freq_codes = []
    for icd10 in index_codes:
        d = set()
        for gram in icd10:
            
            gram = tuple(gram)
            d.add(gram)
        freq_codes.append(d)
    return freq_codes

def get_frequency_of_codes(index_codes):
    freq_codes = {}
    for icd10 in index_codes:
        for gram in icd10:
            print(gram)
            gram = tuple(gram)
            print(gram)
            if gram in freq_codes:
                freq_codes[gram] += 1
            else:
                freq_codes[gram] = 1
    return freq_codes
print(index_codes)
freq_codes = [get_frequency_of_codes(index_codes[i]) for i in range(2)]
code_ngrams = [get_code_ngrams(index_codes[i]) for i in range(2)]
print(code_ngrams)
print(freq_codes)

[[[[363, 244], [244, 245]], [[364, 365], [365, 366]], [[367, 365], [365, 366]]], [[[363, 244, 245]], [[364, 365, 366]], [[367, 365, 366]]]]
[363, 244]
(363, 244)
[244, 245]
(244, 245)
[364, 365]
(364, 365)
[365, 366]
(365, 366)
[367, 365]
(367, 365)
[365, 366]
(365, 366)
[363, 244, 245]
(363, 244, 245)
[364, 365, 366]
(364, 365, 366)
[367, 365, 366]
(367, 365, 366)
[[{(244, 245), (363, 244)}, {(364, 365), (365, 366)}, {(365, 366), (367, 365)}], [{(363, 244, 245)}, {(364, 365, 366)}, {(367, 365, 366)}]]
[{(363, 244): 1, (244, 245): 1, (364, 365): 1, (365, 366): 2, (367, 365): 1}, {(363, 244, 245): 1, (364, 365, 366): 1, (367, 365, 366): 1}]


In [102]:
for text in train_dataset:
    print(text)

tensor([  0,   1,   2,   3,   4,   3,   5,   6,   7,   8,   9,  10,  11,   5,
         12,  13,  14,  15,  16,   2,   3,  17,  18,  19,  20,  21,  22,  23,
         24,  25,  26,  27,  28,   2,   3,  29,  18,  30,  31,  32,  33,   2,
          3,  17,  20,  21,  34,  35,   3,  36,  37,  38,  39,   2,   3,  40,
         41,  42,  43,  44,  45,   2,   3,  46,  47,  48,  49,  25,  50,  19,
         51,  52,  53,  54,  55,   5,  56,  57,  58,  52,  25,  59,   2,   3,
         60,  61,  38,  62,  63,  25,  47,  51,  64,  65,  57,  58,  51,  52,
         25,  66,  32,  65,  67,  68,  51,  69,   2,   3,  70,  47,  71,  56,
         72,  14,  73,  74,  75,  65,  57,  58,   2,   3,  29,  76,  35,   7,
          9,  14,  53,  54,   2,  77,  17,  76,  35,  78,  79,  80,  81,   2,
          3,  17,  76,  35,  82,  83,  80,  84,  85,   2,   3,  17,  76,  35,
         86,  87,  88,  80,  89,   2,  90,  91,  14,  92,  93,   3,  94,  95,
         96,  14,  97,  98,   6,  99,  24,  25, 100, 101, 102, 1

In [140]:
class icd_pred(torch.nn.Module):
    def __init__(self, num_words, word_dim=10, k=5, dilation=2,conv_dim=4, num_labels=128, label_dim=12):
        super().__init__()
        self.embed = torch.nn.Embedding(num_words, word_dim)
        self.dconv = torch.nn.Conv1d(word_dim, conv_dim, k, dilation=2)
        self.conv = torch.nn.Conv1d(word_dim, conv_dim, k)
        self.u = torch.nn.Parameter(torch.rand(conv_dim, num_labels))
        self.smax = torch.nn.Softmax(dim=2)
        self.reduce = torch.nn.Linear(conv_dim * 2 + 2,1)
        
    def forward(self, sentence):
        text_embedded = self.embed(sentence)
        text_embedded = text_embedded.unsqueeze(dim=0).transpose(1,2)
        #dilated conv
        H = self.dconv(text_embedded).transpose(1,2)
        alpha = self.smax(torch.matmul(H, self.u))
        H = H.squeeze(dim=0).transpose(0,1)
        alpha = alpha.squeeze(dim=0)
        dilated_m = torch.matmul(H, alpha)
        
        #normal conv
        H = self.conv(text_embedded).transpose(1,2)
        alpha = self.smax(torch.matmul(H, self.u))
        H = H.squeeze(dim=0).transpose(0,1)
        alpha = alpha.squeeze(dim=0)
        normal_m = torch.matmul(H, alpha)
        
        scores = [self.get_ngram_scores(sentence, i) for i in range(2,4)]
        scores = torch.tensor(scores)
        final_scores = torch.cat([normal_m, dilated_m, scores]).transpose(0,1)
        
        final_scores = self.reduce(final_scores)
        return final_scores.squeeze()
    def get_ngram_scores(self, sentence, n):
        gram_freq_text = self.getfreqgrams(sentence, n)
        listofscores =[]
        for i, code in enumerate(code_ngrams[n - 2]):
            score = 0
            for gram in code:
                if gram in gram_freq_text:
                    score += gram_freq_text[gram] * n/code_lengths[i] * len(code_ngrams[n - 2 ])/freq_codes[n - 2][gram]
            listofscores.append(score)
        return listofscores
            
    def getfreqgrams(self, sentence, n):
        sentence = sentence.tolist()
        freq_gram = {}
        n_grams = n_gram_helper(sentence, n)
        
        for n_gram in n_grams:
            n_gram = tuple(n_gram)
            if n_gram in freq_gram:
                freq_gram[n_gram] += 1
            else:
                freq_gram[n_gram] = 1
        return freq_gram
    
    def n_gram_helper(self, code, n):
        l = []
        for i in range(len(code) - n + 1):
            l.append(code[i:i + n])
        return l
        
model = icd_pred(len(vocab), num_labels=3)
loss_fn = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters())
for _ in range(100):
    for text, label in train_dataset:
        optim.zero_grad()
        answers = model(text)
        answers = answers.unsqueeze(dim=0)
        label = label.unsqueeze(dim=0)

        loss = loss_fn(answers, label)
        loss.backward()
        print(loss.item())
        optim.step()

1.6689286894688848e-06
5.960462772236497e-07
2.3841855067985307e-07
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
