In [1]:
from load import *

splitter = load_idxsplit()

x_train, x_val, x_test = load_x(splitter).values()
y_train, y_val, y_test = load_y(splitter).values()

diag = load_diagnoses(splitter)
left_diag_train, left_diag_val, left_diag_test = diag["left"].values()
right_diag_train, right_diag_val, right_diag_test = diag["right"].values()

diag_mask = load_diagnosis_masks(splitter)
left_diag_mask_train, left_diag_mask_val, left_diag_mask_test = diag_mask["left"].values()
right_diag_mask_train, right_diag_mask_val, right_diag_mask_test = diag_mask["right"].values()

images = load_images(splitter)
left_fundus_images_train, left_fundus_images_val, left_fundus_images_test = images["left"].values()
right_fundus_images_train, right_fundus_images_val, right_fundus_images_test = images["right"].values()

In [2]:
sentences = left_diag_train + right_diag_train

max_len = max([len(sentence) for sentence in sentences])

uniq_words = set()
for sentence in sentences:
    uniq_words.update(sentence)
print("Unique words: ", len(uniq_words))

from m3care.cs224n.vocab import VocabEntry

vocab = VocabEntry()
for word in uniq_words:
  vocab.add(word)

Unique words:  135


In [3]:
import torch
# device = torch.device("cuda:0" if torch.cuda.is_available() == True else 'cpu')
# print("available device: {}".format(device))
device = 'cpu'

In [4]:
from m3care.model import M3Care

model = M3Care(
    input_dim = x_train.shape[-1],
    hidden_dim = 128,
    embed_size = 128,
    output_dim = 3,
    keep_prob = 0.5,
    vocab=vocab,
    device=device
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

In [5]:
import random

RAND_SEED = 42
np.random.seed(RAND_SEED)
random.seed(RAND_SEED)
torch.manual_seed(RAND_SEED)
torch.cuda.manual_seed(RAND_SEED)
torch.backends.cudnn.deterministic = True

epochs = 100
batch_size = 512

total_train_loss = list()
total_val_loss = list()

global_best = 0
_global_best = 0

history = list()

fold_count = 0
fold_train_loss = []
fold_valid_loss = []

best_auc_scores = 0
best_ave_auc_micro = 0
best_ave_auc_macro = 0
best_coverage_error = 0
best_label_ranking_loss = 0


_best_auc_scores = 0
_best_ave_auc_micro = 0
_best_ave_auc_macro = 0
_best_coverage_error = 0
_best_label_ranking_loss = 0

In [6]:
def get_loss(y_pred, y_true, weight=None):
    loss = torch.nn.BCEWithLogitsLoss(weight=weight)
    return loss(y_pred, y_true)

In [7]:
import math

def batch_iter(x, y, left_fundus, right_fundus, left_diag, left_diag_mask, right_diag, right_diag_mask, \
               batch_size, shuffle=False):
    """ Yield batches of source and target sentences reverse sorted by length (largest to smallest).
    @param data (list of (src_sent, tgt_sent)): list of tuples containing source and target sentence
    @param batch_size (int): batch size
    @param shuffle (boolean): whether to randomly shuffle the dataset
    """
    batch_num = math.ceil(len(x) / batch_size)  # 向下取整
    index_array = list(range(len(x)))

    if shuffle:
        np.random.shuffle(index_array)

    for i in range(batch_num):
        indices = index_array[i * batch_size: (i + 1) * batch_size]  # fetch out all the induces

        examples = []
        for idx in indices:
            examples.append((x[idx], y[idx], left_fundus[idx], right_fundus[idx], left_diag[idx], \
                             left_diag_mask[idx], right_diag[idx], right_diag_mask[idx]))

        batch_x = [e[0] for e in examples]
        batch_y = [e[1] for e in examples]
        batch_left_fundus = [e[2] for e in examples]
        batch_right_fundus = [e[3] for e in examples]
        batch_left_diag = [e[4] for e in examples]
        batch_left_diag_mask = [e[5] for e in examples]
        batch_right_diag = [e[6] for e in examples]
        batch_right_diag_mask = [e[7] for e in examples]


        yield batch_x, batch_y, batch_left_fundus, batch_right_fundus, batch_left_diag, \
        batch_right_diag,  [batch_left_diag_mask, batch_right_diag_mask]

In [8]:
for each_epoch in range(epochs):

    epoch_loss = []
    counter_batch = 0
    model.train()
    for step, (batch_x, batch_y, batch_left_fundus_images, batch_right_fundus_images, batch_left_diag, \
    batch_right_diag, l_r_masks) in enumerate(
            batch_iter(x_train, y_train, left_fundus_images_train, right_fundus_images_train, left_diag_train,
                        left_diag_mask_train, right_diag_train, right_diag_mask_train, \
                       batch_size, shuffle=True)):
        optimizer.zero_grad()

        batch_x = torch.tensor(batch_x, dtype=torch.float32).to(device)
        batch_y = torch.tensor(batch_y, dtype=torch.float32).to(device).squeeze(-1)

        batch_left_fundus_images = torch.tensor(batch_left_fundus_images, dtype=torch.float32).to(device)
        batch_right_fundus_images = torch.tensor(batch_right_fundus_images, dtype=torch.float32).to(device)

        opt, sum_of_diff= model(batch_x, batch_left_fundus_images, batch_right_fundus_images, \
                    batch_left_diag, batch_right_diag, l_r_masks)



  batch_x = torch.tensor(batch_x, dtype=torch.float32).to(device)


ENCODE#1:  torch.Size([10, 512])
ENCODE#2:  torch.Size([512, 10, 128])
ENCODE#3:  torch.Size([512, 1, 10])
Layer 0: EncoderLayer(
  (self_attn): MultiHeadedAttention(
    (linears): ModuleList(
      (0-3): 4 x Linear(in_features=128, out_features=128, bias=True)
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (feed_forward): PositionwiseFeedForward(
    (w_1): Linear(in_features=128, out_features=512, bias=True)
    (w_2): Linear(in_features=512, out_features=128, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (sublayer): ModuleList(
    (0-1): 2 x SublayerConnection(
      (norm): LayerNorm()
      (dropout): Dropout(p=0.5, inplace=False)
    )
  )
)
MHA STARTED
NEWMASKSIZE:  torch.Size([512, 1, 1, 10])
Q, K, V found
Q torch.Size([512, 8, 10, 16]) K torch.Size([512, 8, 10, 16]) V torch.Size([512, 8, 10, 16])
PAID ATTENTION
CONCATTING
XOUT: torch.Size([512, 10, 128])
ENCODE#4:  torch.Size([512, 10, 128])
TRANSFORMEROUTPUT: {enc_hiddens.shape}
ENCODE#1:  torch.S

KeyboardInterrupt: 