In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        self.linear_subjective = nn.Linear(768,10)  
        self.linear_medhx = nn.Linear(768, 10)
        self.combined = nn.Linear(30, 10)
        self.out = nn.Linear(10,4)

    def forward(self, subj,medhx,cont):
        nlp1 = F.relu(self.linear_subjective(subj))
        #print (nlp1.shape)
        nlp2 = F.relu(self.linear_medhx(medhx))
        #print (nlp2.shape)
        combined = torch.cat((nlp1,nlp2,cont), axis = 1)
        #print (combined.shape)
        x = self.combined(combined)
        return x



In [4]:

net = Net()
print(net)

Net(
  (linear_subjective): Linear(in_features=768, out_features=10, bias=True)
  (linear_medhx): Linear(in_features=768, out_features=10, bias=True)
  (combined): Linear(in_features=30, out_features=10, bias=True)
  (out): Linear(in_features=10, out_features=4, bias=True)
)


In [60]:
subj = torch.rand(160,768).numpy()
medhx = torch.rand(160,768).numpy()
cont = torch.rand(160,10).numpy()

In [61]:
labels = torch.randint(0,3,(160,)).numpy()

In [62]:
subj.shape, medhx.shape, cont.shape, labels.shape

((160, 768), (160, 768), (160, 10), (160,))

In [66]:
train_subj, validation_subj = train_test_split(subj, random_state = 42, test_size=0.1)
train_medhx, validation_medhx = train_test_split(medhx, random_state = 42, test_size=0.1)
train_cont, validation_cont = train_test_split(cont, random_state = 42, test_size=0.1)
train_labels, validation_labels = train_test_split(labels, random_state = 42, test_size=0.1)

In [71]:
net(torch.tensor(subj),torch.tensor(medhx),torch.tensor(cont))

tensor([[ 2.0086,  1.9166, -0.0743,  ..., -1.5197, -0.1110, -1.7344],
        [ 1.9272,  1.7760,  0.0819,  ..., -1.6422, -0.1861, -1.5834],
        [ 1.7860,  1.9397, -0.1915,  ..., -1.6480, -0.0814, -1.4577],
        ...,
        [ 1.7268,  1.7883,  0.0732,  ..., -1.3953, -0.1415, -1.5651],
        [ 1.7398,  1.9710, -0.2882,  ..., -1.3591, -0.3295, -1.5359],
        [ 1.5939,  1.8365, -0.1336,  ..., -1.4517, -0.2982, -1.5540]],
       grad_fn=<AddmmBackward>)

In [74]:
train_subj = torch.tensor(train_subj)
validation_subj = torch.tensor(validation_subj)
train_medhx = torch.tensor(train_medhx)
validation_medhx = torch.tensor(validation_medhx)
train_cont = torch.tensor(train_cont)
validation_cont = torch.tensor(validation_cont)
train_labels = torch.tensor(train_labels)
validation_labels = torch.tensor(validation_labels)

In [76]:
# Select a batch size for training. For fine-tuning BERT on a specific task, the authors recommend a batch size of 16 or 32
batch_size = 4

# Create an iterator of our data with torch DataLoader. This helps save on memory during training because, unlike a for loop, 
# with an iterator the entire dataset does not need to be loaded into memory

train_data = TensorDataset(train_subj, train_medhx, train_cont, train_labels)
trainloader = DataLoader(train_data, batch_size=batch_size)

valid_data = TensorDataset(validation_subj, validation_medhx, validation_cont, validation_labels)
validloader = DataLoader(valid_data, batch_size=batch_size)

In [77]:
#hyperparameters
lr = 1e-3
optimizer = Adam(net.parameters(), lr = lr)
epochs = 3
loss_func = nn.CrossEntropyLoss()

In [78]:
def train_model(model):
    for epoch_num in range(epochs):
        model.train()
        train_loss = 0
        for step_num, batch_data in enumerate(trainloader):
        
            cont_var, subj_notes, medhx, labels = tuple(t.to(device) for t in batch_data)
        
            optimizer.zero_grad()
        
            logits = model(cont_var, subj_notes, medhx)
        
            batch_loss = loss_func(logits, labels)
        
            train_loss += batch_loss.item()
        
            batch_loss.backward()
        

            clip_grad_norm_(parameters=model.parameters(), max_norm=1.0)
        
            optimizer.step()
        
            print('Epoch: ', epoch_num + 1)
            print("\r" + "{0}/{1} loss: {2} ".format(step_num, len(train_data) / batch_size, train_loss / (step_num + 1)))
        

In [79]:
train_model(net)

Epoch:  1
0/36.0 loss: 1.503239631652832 
Epoch:  1
1/36.0 loss: 1.2603518962860107 
Epoch:  1
2/36.0 loss: 1.2707217931747437 
Epoch:  1
3/36.0 loss: 1.3145799040794373 
Epoch:  1
4/36.0 loss: 1.3825627088546752 
Epoch:  1
5/36.0 loss: 1.3989834785461426 
Epoch:  1
6/36.0 loss: 1.3172224930354528 
Epoch:  1
7/36.0 loss: 1.4326949417591095 
Epoch:  1
8/36.0 loss: 1.378023472097185 
Epoch:  1
9/36.0 loss: 1.4096169173717499 
Epoch:  1
10/36.0 loss: 1.3842078284783796 
Epoch:  1
11/36.0 loss: 1.3899749666452408 
Epoch:  1
12/36.0 loss: 1.468306683577024 
Epoch:  1
13/36.0 loss: 1.4340163256440843 
Epoch:  1
14/36.0 loss: 1.4428147037823995 
Epoch:  1
15/36.0 loss: 1.422057244926691 
Epoch:  1
16/36.0 loss: 1.4581853396752302 
Epoch:  1
17/36.0 loss: 1.439686957332823 
Epoch:  1
18/36.0 loss: 1.4475009786455255 
Epoch:  1
19/36.0 loss: 1.4394472807645797 
Epoch:  1
20/36.0 loss: 1.4438557880265372 
Epoch:  1
21/36.0 loss: 1.4377187571742318 
Epoch:  1
22/36.0 loss: 1.4256151370380237 
Epo