In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import data_utils
import train_utils
import models

%reload_ext autoreload
%autoreload 2

In [2]:
# Development split wrapped in a loader that yields one shuffled item per batch.
dev = data_utils.COAD_dataset(data_utils.COAD_DEV)
dev_loader = torch.utils.data.DataLoader(
    dataset=dev,
    batch_size=1,
    shuffle=True,
    pin_memory=True,
)

In [18]:
# generator
class Generator(nn.Module):
    """Conv feature extractor followed by a bidirectional LSTM and a 2-way linear head.

    Every element along the first dimension of the input is embedded by the conv
    stack; the embeddings are then read by the LSTM as ONE sequence, and a pair
    of scores is produced for each element.

    Args:
        n_conv_layers: number of Conv2d -> ReLU -> MaxPool2d(2, stride=2) stages.
        kernel_size: per-stage conv kernel sizes (indexed by stage).
        n_conv_filters: per-stage conv output channel counts (indexed by stage).
        hidden_size: LSTM hidden size per direction.
        n_rnn_layers: number of stacked LSTM layers.
        dropout: dropout value handed to nn.LSTM.
        input_size: spatial size of the input images, either an int (square) or
            a (height, width) pair. When given, the LSTM input width is inferred
            by pushing a dummy image through the conv stack. When None
            (default), the legacy assumption of a 5x5 final feature map
            (25 positions per channel) is kept, which matches 27x27 inputs with
            kernels [4, 3].
    """

    def __init__(self, n_conv_layers, kernel_size, n_conv_filters, hidden_size, n_rnn_layers,
                 dropout=0.5, input_size=None):
        super(Generator, self).__init__()
        self.n_conv_layers = n_conv_layers
        self.kernel_size = kernel_size
        self.n_conv_filters = n_conv_filters
        self.hidden_size = hidden_size
        self.n_rnn_layers = n_rnn_layers
        self.input_size = input_size
        self.conv_layers = []
        # One pooling module and one activation module are shared by all stages.
        self.m = nn.MaxPool2d(2, stride=2)
        self.relu = nn.ReLU()

        in_channels = 3  # the first conv stage expects 3-channel images
        for layer in range(self.n_conv_layers):
            self.conv_layers.append(nn.Conv2d(in_channels, self.n_conv_filters[layer], self.kernel_size[layer]))
            self.conv_layers.append(self.relu)
            self.conv_layers.append(self.m)
            in_channels = self.n_conv_filters[layer]
        self.conv = nn.Sequential(*self.conv_layers)

        if input_size is None:
            # Legacy behaviour: assume the conv stack ends in a 5x5 feature map.
            positions = 25
        else:
            if isinstance(input_size, int):
                height, width = input_size, input_size
            else:
                height, width = input_size
            # Measure the final feature map instead of hard-coding it, so other
            # kernel sizes / depths / image sizes produce a matching LSTM width.
            with torch.no_grad():
                probe = self.conv(torch.zeros(1, 3, height, width))
            positions = probe.shape[2] * probe.shape[3]
        in_channels = in_channels * positions

        self.lstm = nn.LSTM(in_channels, self.hidden_size, self.n_rnn_layers, batch_first=True,
                            dropout=dropout, bidirectional=True)
        # Forward and backward LSTM states are concatenated.
        in_channels = hidden_size * 2
        self.classification_layer = nn.Linear(in_channels, 2)

    def forward(self, x):
        """Score every element of ``x``.

        Args:
            x: tensor of images; its first dimension indexes the elements that
                form the sequence (shape (N, 3, H, W) as consumed by the conv stack).

        Returns:
            Tensor of shape (1, N, 2): two scores per element.
        """
        embed = self.conv(x)
        # Flatten each element's feature map and treat the N elements as one
        # length-N sequence in a batch of size 1 (the LSTM is batch_first).
        embed = embed.view(1, x.shape[0], -1)
        output, hidden = self.lstm(embed)
        y = self.classification_layer(output)
        return y

In [4]:
# TODO: encoder — placeholder cell; `models.ConvNet` is instantiated as `enc` further down.

In [None]:
# TODO: loss function — not defined in this cell yet.

In [6]:
# Pull a single (slide, label) batch for interactive inspection.
# next(iter(...)) raises StopIteration on an empty loader, whereas the
# for/break form would silently leave `slide` and `label` unset (or stale
# from an earlier run of this cell).
slide, label = next(iter(dev_loader))

In [8]:
# Drop the leading dimension in place when it has size 1 (dev_loader uses batch_size=1).
# NOTE(review): as the last expression this prints the whole tensor; `slide.shape`
# would be a lighter check.
slide.squeeze_(0)

tensor([[[[0.5216, 0.3098, 0.5961,  ..., 0.5294, 0.3098, 0.6157],
          [0.5647, 0.3451, 0.6471,  ..., 0.6706, 0.4392, 0.6902],
          [0.6510, 0.4157, 0.6706,  ..., 0.6510, 0.4078, 0.7137],
          ...,
          [0.6118, 0.3843, 0.7255,  ..., 0.5765, 0.3490, 0.6667],
          [0.5412, 0.3098, 0.6353,  ..., 0.6118, 0.3882, 0.6588],
          [0.6863, 0.4588, 0.7176,  ..., 0.8706, 0.6078, 0.8353]],

         [[0.5843, 0.3608, 0.7059,  ..., 0.7451, 0.5098, 0.8275],
          [0.6235, 0.3961, 0.7098,  ..., 0.6745, 0.4588, 0.7216],
          [0.6431, 0.4235, 0.6941,  ..., 0.8118, 0.5490, 0.7804],
          ...,
          [0.8314, 0.6392, 0.8510,  ..., 0.8196, 0.5804, 0.8353],
          [0.7961, 0.5608, 0.8157,  ..., 0.8196, 0.5725, 0.8196],
          [0.7843, 0.5412, 0.7882,  ..., 0.7490, 0.5098, 0.7373]],

         [[0.5804, 0.3765, 0.6392,  ..., 0.8000, 0.5569, 0.8196],
          [0.7843, 0.5451, 0.8078,  ..., 0.9725, 0.7333, 0.9294],
          [0.9490, 0.7059, 0.9059,  ..., 0

In [9]:
# Dimensions of the slide tensor after the in-place squeeze.
slide.shape

torch.Size([128, 3, 27, 27])

In [19]:
# Generator hyperparameters; the two lists are indexed per conv stage.
n_conv_layers = 2
kernel_size = [4, 3]
n_conv_filters = [36, 48]
hidden_size = 512
n_rnn_layers = 2
dropout = 0.5

# Prefer the GPU but fall back to the CPU so the cell also runs without CUDA.
gen_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gen = Generator(n_conv_layers, kernel_size, n_conv_filters, hidden_size, n_rnn_layers, dropout=dropout)
# .to() returns the module itself, so the cell still displays the model summary.
gen.to(gen_device)

Generator(
  (m): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (relu): ReLU()
  (conv): Sequential(
    (0): Conv2d(3, 36, kernel_size=(4, 4), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(36, 48, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (lstm): LSTM(1200, 512, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)
  (classification_layer): Linear(in_features=1024, out_features=2, bias=True)
)

In [20]:
# Put the slide on whichever device the generator's weights live on, rather
# than assuming CUDA, then score it.
slide = slide.to(next(gen.parameters()).device)
preds = gen(slide)
preds.shape

torch.Size([1, 128, 2])

In [21]:
# Re-check the slide's dimensions for comparison with the shape of `preds`.
slide.shape

torch.Size([128, 3, 27, 27])

In [25]:
# For every position along dim 1, the index (0 or 1) of the larger score in dim 2.
selector = preds.argmax(dim=2)
selector.shape

torch.Size([1, 128])

In [30]:
# Keep only the entries of `slide` whose selected class is 1.
# squeeze() is used out of place so `selector` keeps the shape displayed when it
# was created and re-running this cell always gives the same result.
rationale = slide[selector.squeeze(0) == 1]

In [29]:
# Encoder hyperparameters; note that `hidden_size` is a per-FC-layer list here.
n_conv_layers = 2
n_fc_layers = 2
kernel_size = [4, 3]
n_conv_filters = [36, 48]
hidden_size = [512, 512]
dropout = 0.5

# Prefer the GPU but fall back to the CPU so the cell also runs without CUDA.
enc_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enc = models.ConvNet(n_conv_layers, n_fc_layers, kernel_size, n_conv_filters, hidden_size, dropout=dropout)
# .to() returns the module itself, so the cell still displays the model summary.
enc.to(enc_device)

ConvNet(
  (m): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (n): Dropout(p=0.5)
  (relu): ReLU()
  (conv): Sequential(
    (0): Conv2d(3, 36, kernel_size=(4, 4), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(36, 48, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=1200, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5)
  )
  (classification_layer): Linear(in_features=512, out_features=2, bias=True)
)

In [31]:
# Run the encoder on the selected subset only; the first dimension of the
# result counts the entries kept by the selector mask.
output = enc(rationale)
output.shape

torch.Size([48, 512])