# Deep Learning in Biomedicine Project

In [136]:
# model imports
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import os
from tqdm import tqdm

# backbone imports
from backbones.blocks import Conv2d_fw, BatchNorm2d_fw, init_layer, Flatten, SimpleBlock, BottleneckBlock
from backbones.resnet import ResNet, ResNet10
from backbones.fcnet import FCNet, EnFCNet

# run imports
from hydra.utils import instantiate

## 1. Block definitions

In [267]:
class CausalConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, 
                 stride=1, dilation=1, groups=1, bias=True):
        super(CausalConv1d, self).__init__()
        self.dilation = dilation
        padding = dilation * (kernel_size - 1)
        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, stride,
                                padding, dilation, groups, bias)

    def forward(self, input):
        # Takes something of shape (N, in_channels, T),
        # returns (N, out_channels, T)
        out = self.conv1d(input)
        return out[:, :, :-self.dilation] # TODO: make this correct for different strides/padding

In [268]:
batch = torch.randn(16, 256, 11)
conv = CausalConv1d(256, 128, kernel_size=2, dilation=2)
out = conv(batch)
print(out.shape)

torch.Size([16, 128, 11])


In [269]:
class DenseBlock(nn.Module):
    def __init__(self, in_channels, dilation, filters, kernel_size=2):
        super(DenseBlock, self).__init__()
        self.casualconv1 = CausalConv1d(in_channels, filters, kernel_size, dilation=dilation)
        self.casualconv2 = CausalConv1d(in_channels, filters, kernel_size, dilation=dilation)

    def forward(self, input):
        # input is dimensions (N, in_channels, T)
        xf = self.casualconv1(input)
        xg = self.casualconv2(input)
        activations = F.tanh(xf) * F.sigmoid(xg) # shape: (N, filters, T)
        return torch.cat((input, activations), dim=1)

In [270]:
dense = DenseBlock(256, 2, 128)
out = dense(batch)
print(out.shape)

torch.Size([16, 384, 11])


In [271]:
class TCBlock(nn.Module):
    def __init__(self, in_channels, seq_length, filters):
        super(TCBlock, self).__init__()
        self.dense_blocks = nn.ModuleList([DenseBlock(in_channels + i * filters, 2 ** (i+1), filters)
                                           for i in range(int(math.ceil(math.log(seq_length, 2))))])

    def forward(self, input):
        # input is dimensions (N, T, in_channels)
        input = torch.transpose(input, 1, 2)
        for block in self.dense_blocks:
            input = block(input)
        return torch.transpose(input, 1, 2)

In [272]:
batch = batch.permute(0, 2, 1)
print(f"batch shape: {batch.shape}")

batch shape: torch.Size([16, 11, 256])


In [273]:
tc = TCBlock(256, 11, 128)
out = tc(batch)
print(f"out shape: {out.shape}")

out shape: torch.Size([16, 11, 768])


In [274]:
class AttentionBlock(nn.Module):
    def __init__(self, in_channels, key_size, value_size):
        super(AttentionBlock, self).__init__()
        self.linear_query = nn.Linear(in_channels, key_size)
        self.linear_keys = nn.Linear(in_channels, key_size)
        self.linear_values = nn.Linear(in_channels, value_size)
        self.sqrt_key_size = math.sqrt(key_size)

    def forward(self, input):
        # input is dim (N, T, in_channels) where N is the batch_size, and T is
        # the sequence length
        mask = np.array([[1 if i>j else 0 for i in range(input.shape[1])] for j in range(input.shape[1])])
        mask = torch.BoolTensor(mask)

        #import pdb; pdb.set_trace()
        keys = self.linear_keys(input) # shape: (N, T, key_size)
        query = self.linear_query(input) # shape: (N, T, key_size)
        values = self.linear_values(input) # shape: (N, T, value_size)
        temp = torch.bmm(query, torch.transpose(keys, 1, 2)) # shape: (N, T, T)
        temp.data.masked_fill_(mask, -float('inf'))
        temp = F.softmax(temp / self.sqrt_key_size, dim=1) # shape: (N, T, T), broadcasting over any slice [:, x, :], each row of the matrix
        temp = torch.bmm(temp, values) # shape: (N, T, value_size)
        return torch.cat((input, temp), dim=2) # shape: (N, T, in_channels + value_size)

In [127]:
attn = AttentionBlock(256, 128, 128)
out = attn(batch)
print(f"out shape: {out.shape}")

out shape: torch.Size([16, 11, 384])


## 2. Model definition

In [298]:
class SnailFewShot(nn.Module):
    def __init__(self, N, K, backbone, use_cuda=False):
        # N-way, K-shot
        super(SnailFewShot, self).__init__()
        self.encoder = backbone
        num_channels = 64 + N # change with actual backbone output size
        
        num_filters = int(math.ceil(math.log(N * K + 1, 2)))
        self.attention1 = AttentionBlock(num_channels, 64, 32)
        num_channels += 32
        self.tc1 = TCBlock(num_channels, N * K + 1, 128)
        num_channels += num_filters * 128
        self.attention2 = AttentionBlock(num_channels, 256, 128)
        num_channels += 128
        self.tc2 = TCBlock(num_channels, N * K + 1, 128)
        num_channels += num_filters * 128
        self.attention3 = AttentionBlock(num_channels, 512, 256)
        num_channels += 256
        print(f"final num_channels: {num_channels}")
        self.fc = nn.Linear(num_channels, N)
        self.N = N
        self.K = K
        self.use_cuda = use_cuda
        
    def forward(self, input, labels):
        x = self.encoder(input)
        # print(f"x shape after encoding: {x.shape}")
        batch_size = int(labels.size()[0] / (self.N * self.K + 1))
        last_idxs = [(i + 1) * (self.N * self.K + 1) - 1 for i in range(batch_size)]
        labels[last_idxs] = torch.Tensor(np.zeros((batch_size, labels.size()[1])))
        if self.use_cuda:
            labels[last_idxs] = labels[last_idxs].cuda()
        x = torch.cat((x, labels), 1)
        x = x.view((batch_size, self.N * self.K + 1, -1))
        x = self.attention1(x)
        x = self.tc1(x)
        x = self.attention2(x)
        x = self.tc2(x)
        x = self.attention3(x)
        x = self.fc(x)
        return x

## Dataset loading

In [276]:
n_classes = 5
n_examples = 3
n_query = 4

In [261]:
ds_info = {
    '_target_': 'datasets.cell.tabula_muris.TMSetDataset',
    'n_way': n_classes,
    'n_support': n_examples,
    'n_query': n_query,
}
#train_batch = 16
#val_batch = 16

In [262]:
train_dataset = instantiate(ds_info, mode='train')
train_loader = train_dataset.get_data_loader()

val_dataset = instantiate(ds_info, mode='val')
val_loader = val_dataset.get_data_loader()

  self.adata.obs['label'] = pd.Categorical(values=truth_labels)
  view_to_actual(adata)
  self.adata.obs['label'] = pd.Categorical(values=truth_labels)
  view_to_actual(adata)


In [214]:
def labels_to_one_hot(opt, labels):
    #if opt.cuda:
    #    labels = labels.cpu()
    labels = labels.numpy()
    unique = np.unique(labels)
    map = {label:idx for idx, label in enumerate(unique)}
    idxs = [map[labels[i]] for i in range(labels.size)]
    one_hot = np.zeros((labels.size, unique.size))
    one_hot[np.arange(labels.size), idxs] = 1
    return one_hot, idxs

In [88]:
def batch_for_few_shot(opt, x, y):
    seq_size = opt["num_cls"] * opt["num_samples"] + 1
    print(f"batch_for_few_shot: x.shape: {x.shape}, y.shape: {y.shape}, seq_size: {seq_size}")
    one_hots = []
    last_targets = []
    for i in range(opt["batch_size"]):
        if (i + 1) * seq_size > y.shape[0]:
            break
        one_hot, idxs = labels_to_one_hot(opt, y[i * seq_size: (i + 1) * seq_size])
        print(f"one_hot: {one_hot}, idxs: {idxs}")
        one_hots.append(one_hot)
        last_targets.append(idxs[-1])
    last_targets = Variable(torch.Tensor(last_targets).long())
    one_hots = [torch.Tensor(temp) for temp in one_hots]
    y = torch.cat(one_hots, dim=0)
    x, y = Variable(x), Variable(y)
    if opt["cuda"]:
        x, y = x.cuda(), y.cuda()
        last_targets = last_targets.cuda()
    return x, y, last_targets

In [83]:
def get_acc(last_model, last_targets):
    _, preds = last_model.max(1)
    acc = torch.eq(preds, last_targets).float().mean()
    return acc.item()

In [191]:
y.shape

torch.Size([5, 6])

In [None]:
def labels_to_one_hot(opt, labels):
    #if opt.cuda:
    #    labels = labels.cpu()
    labels = labels.numpy()
    unique = np.unique(labels)
    map = {label:idx for idx, label in enumerate(unique)}
    idxs = [map[labels[i]] for i in range(labels.size)]
    one_hot = np.zeros((labels.size, unique.size))
    one_hot[np.arange(labels.size), idxs] = 1
    return one_hot, idxs

In [277]:
def get_label_map(labels):
    labels = labels.numpy()
    unique = np.unique(labels)
    map = {label:idx for idx, label in enumerate(unique)}
    return map

In [278]:
support_labels = y[:, :5].flatten() # N * K labels that are fixed for each sequence
label_map = get_label_map(support_labels)
for k, v in label_map.items():
    print(f"key: {k}, value: {v}")

key: 21, value: 0
key: 23, value: 1
key: 39, value: 2
key: 48, value: 3
key: 52, value: 4


In [283]:
def get_one_hots(labels, map):
    labels = labels.numpy()
    idxs = [map[labels[i]] for i in range(labels.size)]
    one_hot = np.zeros((labels.size, len(map)))
    one_hot[np.arange(labels.size), idxs] = 1
    return one_hot, idxs

In [284]:
def seq_batch(opt, x, y):
    sequences = []

    # get support set and labels
    support_set = x[:, :opt["num_samples"], :] # N x K samples that are fixed for each sequence
    support_set = support_set.contiguous().view(-1, x.shape[2]) # flatten to get (N * K, input_dim)
    # print(f"support_set shape: {support_set.shape}")
    n_query = x.shape[1] - opt["num_samples"]
    support_labels = y[:, :opt["num_samples"]].flatten() # N * K labels that are fixed for each sequence

    # get label map
    label_map = get_label_map(support_labels)

    all_labels = []
    pred_targets = []
    for i in range(n_query):
        for j in range(opt["num_cls"]):
            query_sample = x[j, opt["num_samples"] + i, :].unsqueeze(0) # get a new query sample
            seq = torch.cat((support_set, query_sample), dim=0) # sequence of N x K + 1 samples
            # print(f"seq shape: {seq.shape}")
            sequences.append(seq) # append to list of sequences
            
            last_seq_label = y[j, opt["num_samples"] + i]
            # seq_labels = support_labels + [last_seq_label]
            seq_labels = torch.cat((support_labels, torch.Tensor([last_seq_label]).long())) # N x K + 1 labels
            # labels, idxs = labels_to_one_hot(opt, seq_labels)
            labels, idxs = get_one_hots(seq_labels, label_map)

            all_labels.append(labels)
            pred_targets.append(idxs[-1])

    # sequences are of shape (n_query * n_cls, (N * K + 1), input_dim))
    # all_labels are of shape (n_query * n_cls, (N * K + 1), num_cls)
    # pred_tagets are of shape (n_query * n_cls)

    labels = torch.Tensor(all_labels).view(-1, opt["num_cls"])
    # print(f"labels shape: {labels.shape}")
    pred_targets = torch.Tensor(pred_targets).long()
    # print(f"pred_targets shape: {pred_targets.shape}")
    # print(f"idxs: {idxs}")
    
            
    sequences = torch.stack(sequences)
   # print(f"sequences shape: {sequences.shape}")
    sequences = sequences.view(-1, x.shape[2])
    # print(f"final sequences shape: {sequences.shape}")
    return sequences, labels, pred_targets

In [285]:
y[:, :5].flatten()

tensor([48, 48, 48, 48, 48, 39, 39, 39, 39, 39, 21, 21, 21, 21, 21, 23, 23, 23,
        23, 23, 52, 52, 52, 52, 52], dtype=torch.int32)

In [286]:
opt = {
    "num_cls": 5,
    "num_samples": 3
}
seqs, labels, targets = seq_batch(opt, x, y)

In [238]:
targets

tensor([0, 4, 3, 1, 2])

In [297]:
def train(opt, tr_iter, model, optim, val_dataloader=None):
    if val_dataloader is None:
        best_state = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    best_acc = 0

    #best_model_path = os.path.join(opt.exp, 'best_model.pth')
    #last_model_path = os.path.join(opt.exp, 'last_model.pth')

    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(opt["epochs"]):
        print('=== Epoch: {} ==='.format(epoch))
        # tr_iter = iter(tr_dataloader)
        print("Putting model into train mode...")
        model.train()
        print("Model is ready to train!")
        if opt["cuda"]:
            model = model.cuda()
        for batch in tqdm(tr_iter):
            optim.zero_grad()
            x, y = batch
            
            # process batch
            x, y, last_targets = seq_batch(opt, x, y)

            model_output = model(x, y)
            # print(f"shape after model: model_output -> {model_output.shape}")
            last_model = model_output[:, -1, :]
            # print(f"shape after model: last_model -> {last_model.shape}")
            loss = loss_fn(last_model, last_targets)
            loss.backward()
            optim.step()
            train_loss.append(loss.item())
            train_acc.append(get_acc(last_model, last_targets))
            
        avg_loss = np.mean(train_loss[-opt["iterations"]:])
        avg_acc = np.mean(train_acc[-opt["iterations"]:])
        print('Avg Train Loss: {}, Avg Train Acc: {}'.format(avg_loss, avg_acc))
        if val_dataloader is None:
            continue
        val_iter = iter(val_dataloader)
        model.eval()
        for batch in val_iter:
            x, y = batch
            x, y, last_targets = batch_for_few_shot(opt, x, y)
            model_output = model(x, y)
            last_model = model_output[:, -1, :]
            loss = loss_fn(last_model, last_targets)
            val_loss.append(loss.item())
            val_acc.append(get_acc(last_model, last_targets))
        avg_loss = np.mean(val_loss[-opt["iterations"]:])
        avg_acc = np.mean(val_acc[-opt["iterations"]:])
        postfix = ' (Best)' if avg_acc >= best_acc else ' (Best: {})'.format(
            best_acc)
        print('Avg Val Loss: {}, Avg Val Acc: {}{}'.format(
            avg_loss, avg_acc, postfix))
        if avg_acc >= best_acc:
            torch.save(model.state_dict(), best_model_path)
            best_acc = avg_acc
            best_state = model.state_dict()
        #for name in ['train_loss', 'train_acc', 'val_loss', 'val_acc']:
        #    save_list_to_file(os.path.join(opt.exp, name + '.txt'), locals()[name])

    torch.save(model.state_dict(), last_model_path)

    return best_state, best_acc, train_loss, train_acc, val_loss, val_acc

## Trials

In [263]:
tr_iter = iter(train_loader)

In [294]:
for sample in tr_iter:
    x, y = sample
    print(f"x shape: {x.shape}, y shape: {y.shape}")
    break

x shape: torch.Size([5, 7, 2866]), y shape: torch.Size([5, 7])


x is of shape `n_classes x (n_support + n_query) x input_dim`, y `n_classes x (n_support + n_query)`

In [299]:
n_classes = 5
n_samples = 3
lr = 1e-4
# backbone = ResNet10()
backbone = FCNet(x_dim=x.shape[2])


model = SnailFewShot(n_classes, n_samples, backbone)
optim = torch.optim.Adam(params=model.parameters(), lr=lr)

final num_channels: 1509


In [300]:
options = {
    'num_cls': n_classes,
    'num_samples': n_samples,
    'batch_size': 32,
    'cuda': False,
    'iterations': 10000,
    'epochs': 10
}

res = train(opt=options,
            tr_iter=tr_iter,
            val_dataloader=val_loader,
            model=model,
            optim=optim)

=== Epoch: 0 ===
Putting model into train mode...
Model is ready to train!


 73%|███████▎  | 73/100 [00:13<00:05,  5.33it/s]

Avg Train Loss: 1.5901453609335912, Avg Train Acc: 0.26506849551854067





support_set shape: torch.Size([15, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
seq shape: torch.Size([16, 2866])
sequences shape: torch.Size([20, 16, 2866])


ValueError: too many values to unpack (expected 3)

In [None]:

model = init_model(options)
optim = torch.optim.Adam(params=model.parameters(), lr=options.lr)
res = train(opt=options,
            tr_dataloader=tr_dataloader,
            val_dataloader=val_dataloader,
            model=model,
            optim=optim)