In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import data_utils
import train_utils
import models

%reload_ext autoreload
%autoreload 2

In [2]:
dev = data_utils.COAD_dataset(data_utils.COAD_DEV)
dev_loader = torch.utils.data.DataLoader(dev, batch_size=1, shuffle=True , pin_memory=True)

In [3]:
# generator
class Generator(nn.Module):
    def __init__(self, n_conv_layers, kernel_size, n_conv_filters, hidden_size, n_rnn_layers, dropout=0.5):
        super(Generator, self).__init__()
        self.n_conv_layers = n_conv_layers
        self.kernel_size = kernel_size
        self.n_conv_filters = n_conv_filters
        self.hidden_size = hidden_size
        self.n_rnn_layers = n_rnn_layers
        self.conv_layers = []
        self.m = nn.MaxPool2d(2, stride=2)
        self.relu = nn.ReLU()
         
        in_channels = 3        
        for layer in range(self.n_conv_layers):
            self.conv_layers.append(nn.Conv2d(in_channels, self.n_conv_filters[layer], self.kernel_size[layer]))
            self.conv_layers.append(self.relu)
            self.conv_layers.append(self.m)
            in_channels = self.n_conv_filters[layer]
        self.conv = nn.Sequential(*self.conv_layers)
        in_channels = in_channels * 25

        self.lstm = nn.LSTM(in_channels, self.hidden_size, self.n_rnn_layers, batch_first=True, 
                            dropout=dropout, bidirectional=True) 
        in_channels = hidden_size * 2
        self.classification_layer = nn.Linear(in_channels, 2)
        
    def forward(self, x):
        embed = self.conv(x)
        embed = embed.view(1,x.shape[0],-1)
        self.lstm.flatten_parameters()
        output, hidden = self.lstm(embed)
        y = self.classification_layer(output)
        return y
    
    def zero_grad(self):
        """Sets gradients of all model parameters to zero."""
        for p in self.parameters():
            if p.grad is not None:
                p.grad.data.zero_()

In [28]:
n_conv_layers = 2
kernel_size = [4,3]
n_conv_filters = [36,48]
hidden_size = 512
n_rnn_layers = 2
dropout=0.5
gen = Generator(n_conv_layers, kernel_size, n_conv_filters, hidden_size, n_rnn_layers, dropout=dropout)
gen.cuda()

Generator(
  (m): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (relu): ReLU()
  (conv): Sequential(
    (0): Conv2d(3, 36, kernel_size=(4, 4), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(36, 48, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (lstm): LSTM(1200, 512, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)
  (classification_layer): Linear(in_features=1024, out_features=2, bias=True)
)

In [29]:
n_conv_layers = 2
n_fc_layers = 2
kernel_size = [4,3]
n_conv_filters = [36,48]
hidden_size = [512,512]
dropout=0.5
enc = models.ConvNet(n_conv_layers, n_fc_layers, kernel_size, n_conv_filters, hidden_size, dropout=dropout)
enc.cuda()

ConvNet(
  (m): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (n): Dropout(p=0.5)
  (relu): ReLU()
  (conv): Sequential(
    (0): Conv2d(3, 36, kernel_size=(4, 4), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(36, 48, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=1200, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5)
  )
  (classification_layer): Linear(in_features=512, out_features=2, bias=True)
)

In [6]:
def pool_fn(x):
    #v,a = torch.max(x,0)
    v = torch.mean(x,0)
    return v

In [7]:
lamb1 = 0.01
lamb2 = 0.01
xent = nn.CrossEntropyLoss()

In [8]:
lsm = nn.LogSoftmax(dim=2)

In [9]:
num_samples = 10

In [11]:
learning_rate = 0.1
optimizer = torch.optim.Adam(enc.parameters(), lr = learning_rate)

In [32]:
def sampler(slide, gen, num_samples):
    zis = []
    grads = []
    all_grads = []
    for p in gen.parameters():
        start = [num_samples]
        start.extend(list(p.shape))
        all_grads.append(torch.zeros(start, device='cuda'))
        grads.append(torch.zeros(p.shape, device='cuda'))
        
    for sample in range(num_samples):
        preds = gen(slide)
        logits = lsm(preds).squeeze(0)
        b = torch.distributions.bernoulli.Bernoulli(logits=logits[:,1])
        zi = b.sample() #zis = b.sample(torch.Size([batch_size]))
        zis.append(zi)

        logprobs = b.log_prob(zi).sum()
        logprobs.backward()

        for idx,p in enumerate(gen.parameters()):
            all_grads[idx][sample] = p.grad

        gen.zero_grad()
        
    return zis, grads, all_grads

In [33]:
def rationales_training(e, train_loader, gen, enc, pool_fn, num_samples, lamb1, lamb2, xent,
                        learning_rate, optimizer):
    gen.train()
    enc.train()
    
    total_loss = 0
    for slide,label in train_loader:
        slide,label = slide.squeeze(0).cuda(),label.cuda()
        zis, grads, all_grads = sampler(slide, gen, num_samples)
        zis = torch.stack(zis)
        
        rationales = [slide[zi==1,:,:,:] for zi in zis]
        sampled_rationales = torch.cat(rationales,dim=0)
        outputs = enc(sampled_rationales)
        
        lens = zis.sum(dim=1)
        indexs = torch.cat([torch.zeros(1),torch.cumsum(lens,0).cpu()]).int()
        outputs = [outputs[indexs[n]:indexs[n+1]] for n,ix in enumerate(indexs[:-1])]
        
        pool = torch.stack([pool_fn(o).unsqueeze(0) for o in outputs])
        y_hat = enc.classification_layer(pool.squeeze(1))
        
        znorm = torch.norm(zis.float(), p=1, dim=1)
        zdist = torch.sum(torch.abs(zis[:,:-1] - zis[:,1:]), dim=1)
        omega = (lamb1 * znorm) + (lamb2 * zdist)
        cost = xent(y_hat, label.repeat(num_samples)) + omega
        
        for sample in range(num_samples):
            for idx,p in enumerate(gen.parameters()):
                grads[idx] += cost[sample] * all_grads[idx][sample] 
        
        for idx,p in enumerate(gen.parameters()):
            p.data = p.data - learning_rate * (grads[idx] / float(num_samples))
            
        loss = cost.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        total_loss += loss.detach().cpu().numpy()

    print('Epoch: {0}, Train Loss: {1:0.4f}'.format(e, total_loss))

In [34]:
e = 0
rationales_training(e, dev_loader, gen, enc, pool_fn, num_samples, lamb1, lamb2, xent, learning_rate, optimizer)

Epoch: 0, Train Loss: 318.7721
