In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import data_utils
import train_utils
import models

%reload_ext autoreload
%autoreload 2

In [2]:
dev = data_utils.COAD_dataset(data_utils.COAD_DEV)
dev_loader = torch.utils.data.DataLoader(dev, batch_size=1, shuffle=True , pin_memory=True)

In [303]:
# generator
class Generator(nn.Module):
    def __init__(self, n_conv_layers, kernel_size, n_conv_filters, hidden_size, n_rnn_layers, dropout=0.5):
        super(Generator, self).__init__()
        self.n_conv_layers = n_conv_layers
        self.kernel_size = kernel_size
        self.n_conv_filters = n_conv_filters
        self.hidden_size = hidden_size
        self.n_rnn_layers = n_rnn_layers
        self.conv_layers = []
        self.m = nn.MaxPool2d(2, stride=2)
        self.relu = nn.ReLU()
         
        in_channels = 3        
        for layer in range(self.n_conv_layers):
            self.conv_layers.append(nn.Conv2d(in_channels, self.n_conv_filters[layer], self.kernel_size[layer]))
            self.conv_layers.append(self.relu)
            self.conv_layers.append(self.m)
            in_channels = self.n_conv_filters[layer]
        self.conv = nn.Sequential(*self.conv_layers)
        in_channels = in_channels * 25

        self.lstm = nn.LSTM(in_channels, self.hidden_size, self.n_rnn_layers, batch_first=True, 
                            dropout=dropout, bidirectional=True) 
        in_channels = hidden_size * 2
        self.classification_layer = nn.Linear(in_channels, 2)
        
    def forward(self, x):
        embed = self.conv(x)
        embed = embed.view(1,x.shape[0],-1)
        self.lstm.flatten_parameters()
        output, hidden = self.lstm(embed)
        y = self.classification_layer(output)
        return y
    
    def zero_grad(self):
        """Sets gradients of all model parameters to zero."""
        for p in self.parameters():
            if p.grad is not None:
                p.grad.data.zero_()

In [6]:
for slide,label in dev_loader:
    break

In [8]:
slide.squeeze_(0)

tensor([[[[0.5216, 0.3098, 0.5961,  ..., 0.5294, 0.3098, 0.6157],
          [0.5647, 0.3451, 0.6471,  ..., 0.6706, 0.4392, 0.6902],
          [0.6510, 0.4157, 0.6706,  ..., 0.6510, 0.4078, 0.7137],
          ...,
          [0.6118, 0.3843, 0.7255,  ..., 0.5765, 0.3490, 0.6667],
          [0.5412, 0.3098, 0.6353,  ..., 0.6118, 0.3882, 0.6588],
          [0.6863, 0.4588, 0.7176,  ..., 0.8706, 0.6078, 0.8353]],

         [[0.5843, 0.3608, 0.7059,  ..., 0.7451, 0.5098, 0.8275],
          [0.6235, 0.3961, 0.7098,  ..., 0.6745, 0.4588, 0.7216],
          [0.6431, 0.4235, 0.6941,  ..., 0.8118, 0.5490, 0.7804],
          ...,
          [0.8314, 0.6392, 0.8510,  ..., 0.8196, 0.5804, 0.8353],
          [0.7961, 0.5608, 0.8157,  ..., 0.8196, 0.5725, 0.8196],
          [0.7843, 0.5412, 0.7882,  ..., 0.7490, 0.5098, 0.7373]],

         [[0.5804, 0.3765, 0.6392,  ..., 0.8000, 0.5569, 0.8196],
          [0.7843, 0.5451, 0.8078,  ..., 0.9725, 0.7333, 0.9294],
          [0.9490, 0.7059, 0.9059,  ..., 0

In [9]:
slide.shape

torch.Size([128, 3, 27, 27])

In [304]:
n_conv_layers = 2
kernel_size = [4,3]
n_conv_filters = [36,48]
hidden_size = 512
n_rnn_layers = 2
dropout=0.5
gen = Generator(n_conv_layers, kernel_size, n_conv_filters, hidden_size, n_rnn_layers, dropout=dropout)
gen.cuda()

Generator(
  (m): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (relu): ReLU()
  (conv): Sequential(
    (0): Conv2d(3, 36, kernel_size=(4, 4), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(36, 48, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (lstm): LSTM(1200, 512, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)
  (classification_layer): Linear(in_features=1024, out_features=2, bias=True)
)

In [305]:
gen.zero_grad()

In [237]:
slide = slide.cuda()
preds = gen(slide)
preds.shape

torch.Size([1, 128, 2])

In [21]:
slide.shape

torch.Size([128, 3, 27, 27])

In [25]:
selector = torch.argmax(preds, dim=2)
selector.shape

torch.Size([1, 128])

In [30]:
rationale = slide[selector.squeeze_(0)==1,:,:,:]

In [29]:
n_conv_layers = 2
n_fc_layers = 2
kernel_size = [4,3]
n_conv_filters = [36,48]
hidden_size = [512,512]
dropout=0.5
enc = models.ConvNet(n_conv_layers, n_fc_layers, kernel_size, n_conv_filters, hidden_size, dropout=dropout)
enc.cuda()

ConvNet(
  (m): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (n): Dropout(p=0.5)
  (relu): ReLU()
  (conv): Sequential(
    (0): Conv2d(3, 36, kernel_size=(4, 4), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(36, 48, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=1200, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5)
  )
  (classification_layer): Linear(in_features=512, out_features=2, bias=True)
)

In [31]:
output = enc(rationale)
output.shape

torch.Size([48, 512])

In [32]:
def pool_fn(x):
    #v,a = torch.max(x,0)
    v = torch.mean(x,0)
    return v

pool = pool_fn(output).unsqueeze(0)
output = enc.classification_layer(pool)
output.shape

torch.Size([1, 2])

In [35]:
# loss fn
lamb1 = 0.01
lamb2 = 0.01

xent = nn.CrossEntropyLoss()
znorm = torch.norm(selector.float(), p=1)
zdist = torch.sum(torch.abs(selector[:-1] - selector[1:]))
omega = (lamb1 * znorm) + (lamb2 * zdist)

In [38]:
cost = xent(output, label.cuda()) + omega
cost

tensor(1.1499, device='cuda:0', grad_fn=<AddBackward0>)

In [253]:
lsm = nn.LogSoftmax(dim=2)

In [263]:
batch_size = 10

In [339]:
# storage
zis = []
grads = []
all_grads = []
for p in gen.parameters():
    grads.append(torch.zeros(p.shape, device='cuda'))
    start = [batch_size]
    start.extend(list(p.shape))
    all_grads.append(torch.zeros(start, device='cuda'))

In [340]:
# forward pass, sample, backward pass
for sample in range(batch_size):
    preds = gen(slide)
    logits = lsm(preds).squeeze(0)
    b = torch.distributions.bernoulli.Bernoulli(logits=logits[:,1])
    zi = b.sample() #zis = b.sample(torch.Size([batch_size]))
    zis.append(zi)

    logprobs = b.log_prob(zi).sum()
    logprobs.backward()

    idx = 0
    for p in gen.parameters():
        all_grads[idx][sample] = p.grad
        idx += 1

    gen.zero_grad()

In [341]:
zis = torch.stack(zis)
zis.shape

torch.Size([10, 128])

In [342]:
# get rationales
rationales = [slide[zi==1,:,:,:] for zi in zis]

In [343]:
#max_len = zis.sum(dim=1).max().cpu().numpy().astype(int)
#padded_rationales = torch.stack([F.pad(r,(0,0,0,0,0,0,0,max_len - r.shape[0])) for r in rationales])

In [344]:
# concatenate rationales
lens = zis.sum(dim=1)
sampled_rationales = torch.cat(rationales,dim=0)

In [345]:
assert sampled_rationales.shape[0] == lens.sum()

In [346]:
# feed rationales to encoder
outputs = enc(sampled_rationales)

In [347]:
# unwind rationales
indexs = torch.cat([torch.zeros(1),torch.cumsum(lens,0).cpu()]).int()
outputs = [outputs[indexs[n]:indexs[n+1]] for n,ix in enumerate(indexs[:-1])]

In [348]:
# pool
pool = torch.stack([pool_fn(o).unsqueeze(0) for o in outputs])

In [349]:
# get preds
y_hat = enc.classification_layer(pool.squeeze(1))

In [350]:
# calc cost
znorm = torch.norm(zis.float(), p=1, dim=1)
zdist = torch.sum(torch.abs(zis[:,:-1] - zis[:,1:]), dim=1)
omega = (lamb1 * znorm) + (lamb2 * zdist)
cost = xent(y_hat, label.repeat(batch_size).cuda()) + omega

In [351]:
cost.shape

torch.Size([10])

In [353]:
for sample in range(batch_size):
    idx = 0
    for p in gen.parameters():
        grads[idx] += cost[sample] * all_grads[idx][sample]
        idx += 1

In [355]:
for p in gen.parameters():
    print(p)

Parameter containing:
tensor([[[[-1.4194e-01, -5.9658e-02,  5.4201e-02, -8.4589e-02],
          [-6.6064e-02, -8.9865e-02,  9.8630e-02, -5.9396e-02],
          [-6.5374e-02, -4.2674e-03,  8.2360e-02,  1.3028e-01],
          [-6.7901e-02,  1.2268e-01, -8.8947e-02,  7.0122e-02]],

         [[ 6.9880e-02, -1.3379e-01, -1.0102e-02,  9.1821e-02],
          [ 6.5529e-02, -5.4471e-02,  6.7287e-02,  7.2666e-02],
          [-5.4700e-02,  1.1422e-01,  7.5434e-02, -3.3597e-02],
          [ 1.2686e-01, -5.9633e-02,  7.1653e-02,  3.5697e-02]],

         [[ 5.5489e-02, -2.7086e-02,  1.2634e-01,  2.9274e-02],
          [-9.8619e-02,  3.9167e-02,  8.0053e-02, -1.0697e-02],
          [-2.4472e-03, -7.5699e-02,  1.7157e-02,  1.5214e-05],
          [-1.4202e-01, -1.7843e-02, -1.1125e-01, -2.7665e-02]]],


        [[[ 8.4006e-02,  1.1912e-01, -4.5461e-02, -5.5560e-02],
          [ 8.0556e-02, -5.6804e-02,  3.4504e-02,  8.6827e-02],
          [-1.0442e-02, -9.6858e-03, -6.3741e-02, -9.8955e-02],
          

In [358]:
learning_rate = 0.1

idx = 0
for p in gen.parameters():
    p.data = p.data - learning_rate * grads[idx]
    idx += 1

In [359]:
for p in gen.parameters():
    print(p)

Parameter containing:
tensor([[[[-0.1390, -0.0565,  0.0568, -0.0819],
          [-0.0637, -0.0872,  0.1006, -0.0571],
          [-0.0638, -0.0020,  0.0839,  0.1321],
          [-0.0653,  0.1253, -0.0868,  0.0723]],

         [[ 0.0715, -0.1326, -0.0098,  0.0935],
          [ 0.0680, -0.0518,  0.0690,  0.0746],
          [-0.0507,  0.1184,  0.0782, -0.0297],
          [ 0.1292, -0.0578,  0.0723,  0.0378]],

         [[ 0.0592, -0.0238,  0.1287,  0.0327],
          [-0.0977,  0.0402,  0.0812, -0.0083],
          [ 0.0002, -0.0726,  0.0189,  0.0023],
          [-0.1379, -0.0140, -0.1093, -0.0241]]],


        [[[ 0.0924,  0.1287, -0.0379, -0.0464],
          [ 0.0885, -0.0483,  0.0414,  0.0947],
          [-0.0019, -0.0006, -0.0574, -0.0910],
          [-0.0224, -0.1283,  0.0978, -0.0164]],

         [[ 0.0303,  0.0913, -0.0181,  0.1313],
          [-0.0338,  0.0789,  0.0453,  0.1369],
          [-0.0435, -0.0459, -0.0392,  0.0706],
          [-0.1145,  0.1112,  0.1486,  0.1212]],

      

In [362]:
optimizer = torch.optim.Adam(enc.parameters(), lr = learning_rate)

In [363]:
loss = cost.sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()

In [367]:
def sampler(slide, gen, num_samples):
    zis = []
    grads = []
    all_grads = []
    for p in gen.parameters():
        start = [num_samples]
        start.extend(list(p.shape))
        all_grads.append(torch.zeros(start, device='cuda'))
        grads.append(torch.zeros(p.shape, device='cuda'))
        
    for sample in range(num_samples):
        preds = gen(slide)
        logits = lsm(preds).squeeze(0)
        b = torch.distributions.bernoulli.Bernoulli(logits=logits[:,1])
        zi = b.sample() #zis = b.sample(torch.Size([batch_size]))
        zis.append(zi)

        logprobs = b.log_prob(zi).sum()
        logprobs.backward()

        for idx,p in enumerate(gen.parameters()):
            all_grads[idx][sample] = p.grad

        gen.zero_grad()
        
    return zis, grads, all_grads

In [374]:
def rationales_training(e, train_loader, gen, enc, num_samples, learning_rate, optimizer):
    gen.train()
    enc.train()
    
    total_loss = 0
    for slide,label in train_loader:
        slide,label = slide.squeeze(0).cuda(),label.cuda()
        zis, grads, all_grads = sampler(slide, gen, num_samples)
        zis = torch.stack(zis)
        
        rationales = [slide[zi==1,:,:,:] for zi in zis]
        sampled_rationales = torch.cat(rationales,dim=0)
        outputs = enc(sampled_rationales)
        
        lens = zis.sum(dim=1)
        indexs = torch.cat([torch.zeros(1),torch.cumsum(lens,0).cpu()]).int()
        outputs = [outputs[indexs[n]:indexs[n+1]] for n,ix in enumerate(indexs[:-1])]
        
        pool = torch.stack([pool_fn(o).unsqueeze(0) for o in outputs])
        y_hat = enc.classification_layer(pool.squeeze(1))
        
        znorm = torch.norm(zis.float(), p=1, dim=1)
        zdist = torch.sum(torch.abs(zis[:,:-1] - zis[:,1:]), dim=1)
        omega = (lamb1 * znorm) + (lamb2 * zdist)
        cost = xent(y_hat, label.repeat(batch_size).cuda()) + omega
        
        for sample in range(num_samples):
            for idx,p in enumerate(gen.parameters()):
                grads[idx] += cost[sample] * all_grads[idx][sample] 
        
        for idx,p in enumerate(gen.parameters()):
            p.data = p.data - learning_rate * (grads[idx] / float(num_samples))
            
        loss = cost.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        total_loss += loss.detach().cpu().numpy()

    print('Epoch: {0}, Train Loss: {1:0.4f}'.format(e, total_loss))

In [375]:
e = 0
rationales_training(e, dev_loader, gen, enc, batch_size, learning_rate, optimizer)

  self.dropout, self.training, self.bidirectional, self.batch_first)


RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM

In [377]:
slide.dtype

torch.float32