In [149]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import data_utils
import train_utils

%reload_ext autoreload
%autoreload 2

In [128]:
class ConvNet(nn.Module):
    """Convolutional feature extractor with a separately-applied classification head.

    The network is a stack of ``Conv2d -> ReLU -> MaxPool2d(2, stride=2)`` stages
    followed by a stack of ``Linear -> ReLU -> Dropout`` stages.

    ``forward`` returns the output of the fully connected stack and does NOT
    apply ``classification_layer``; callers are expected to pool those
    features and then call ``classification_layer`` themselves.

    Args:
        n_conv_layers: Number of convolutional stages.
        n_fc_layers: Number of fully connected stages.
        kernel_size: Sequence with one conv kernel size per convolutional stage.
        n_conv_filters: Sequence with one output-channel count per convolutional stage.
        hidden_size: Sequence with one width per fully connected stage.
        dropout: Dropout probability used after every fully connected stage.
        in_channels: Number of channels of the input images (default 3).
        spatial_size: Number of spatial positions (height * width) left after
            the last pooling stage; the first ``Linear`` layer expects
            ``last_conv_channels * spatial_size`` inputs (default 25).
        n_classes: Output size of ``classification_layer`` (default 2).
    """

    def __init__(self, n_conv_layers, n_fc_layers, kernel_size, n_conv_filters, hidden_size, dropout=0.5,
                 in_channels=3, spatial_size=25, n_classes=2):
        super(ConvNet, self).__init__()
        self.n_conv_layers = n_conv_layers
        self.n_fc_layers = n_fc_layers
        self.kernel_size = kernel_size
        self.n_conv_filters = n_conv_filters
        self.hidden_size = hidden_size
        self.conv_layers = []
        self.fc_layers = []
        # Parameter-free modules: a single instance is reused at every position.
        self.m = nn.MaxPool2d(2, stride=2)
        self.n = nn.Dropout(dropout)
        self.relu = nn.ReLU()

        # Convolutional stages; each stage's output channels feed the next stage.
        for layer in range(self.n_conv_layers):
            self.conv_layers.append(nn.Conv2d(in_channels, self.n_conv_filters[layer], self.kernel_size[layer]))
            self.conv_layers.append(self.relu)
            self.conv_layers.append(self.m)
            in_channels = self.n_conv_filters[layer]

        # Size of the flattened conv output: channels * remaining spatial positions.
        in_features = in_channels * spatial_size
        for layer in range(self.n_fc_layers):
            self.fc_layers.append(nn.Linear(in_features, self.hidden_size[layer]))
            self.fc_layers.append(self.relu)
            self.fc_layers.append(self.n)
            in_features = self.hidden_size[layer]

        # Wrapping the lists in Sequential is what registers the layers as submodules.
        self.conv = nn.Sequential(*self.conv_layers)
        self.fc = nn.Sequential(*self.fc_layers)
        # Created last and kept outside ``forward`` so callers can pool features first.
        self.classification_layer = nn.Linear(in_features, n_classes)

    def forward(self, x):
        """Return the fully connected features for ``x``, one row per leading-axis entry.

        The convolutional output is flattened to ``(x.shape[0], -1)`` before the
        fully connected stack. ``classification_layer`` is not applied here.
        """
        embed = self.conv(x)
        embed = embed.view(x.shape[0], -1)
        y = self.fc(embed)
        return y

In [22]:
# Development split, served one dataset item per batch in shuffled order.
dev = data_utils.COAD_dataset(data_utils.COAD_DEV)
dev_loader = torch.utils.data.DataLoader(
    dataset=dev,
    batch_size=1,
    shuffle=True,
    pin_memory=True,
)

In [129]:
# Architecture hyperparameters: one list entry per layer of the matching stack.
n_conv_layers, n_fc_layers = 2, 2
kernel_size = [4, 3]
n_conv_filters = [36, 48]
hidden_size = [512, 512]
dropout = 0

net = ConvNet(
    n_conv_layers=n_conv_layers,
    n_fc_layers=n_fc_layers,
    kernel_size=kernel_size,
    n_conv_filters=n_conv_filters,
    hidden_size=hidden_size,
    dropout=dropout,
)
# Move the model to the GPU; as the last expression this also displays the module.
net.cuda()

ConvNet(
  (m): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (n): Dropout(p=0)
  (relu): ReLU()
  (conv): Sequential(
    (0): Conv2d(3, 36, kernel_size=(4, 4), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(36, 48, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=1200, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0)
  )
  (classification_layer): Linear(in_features=512, out_features=2, bias=True)
)

In [84]:
# Loss applied to the output of net.classification_layer and the slide label
# in the training and evaluation loops.
criterion = nn.CrossEntropyLoss()

In [130]:
# Optimisation hyperparameters for Adam over every parameter of the network.
lr = 1e-4
weight_decay = 0
optimizer = torch.optim.Adam(
    params=net.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

In [136]:
# Train on the dev split, printing the loss summed over all slides for each epoch.
epochs = 1000
net.train()

for e in range(epochs):
    total_loss = 0
    for slide, label in dev_loader:
        # Drop only the DataLoader's batch axis (batch_size=1). A bare squeeze()
        # would also remove the tile axis when a slide holds a single tile.
        # Assumes each dataset item is a (n_tiles, C, H, W) stack - TODO confirm.
        slide = slide.squeeze(0)
        slide, label = slide.cuda(), label.cuda()

        # Clear gradients before the forward/backward pass so anything left over
        # from an interrupted run cannot leak into this step.
        optimizer.zero_grad()

        output = net(slide)
        # Average the per-tile features into one slide-level vector, then classify it.
        pool = torch.mean(output, 0).unsqueeze(0)
        output = net.classification_layer(pool)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print('Epoch: {0}, Train NLL: {1:0.4f}'.format(e, total_loss))

Epoch: 0, Train NLL: 0.0000
Epoch: 1, Train NLL: 0.0000
Epoch: 2, Train NLL: 0.0000
Epoch: 3, Train NLL: 0.0000
Epoch: 4, Train NLL: 0.0000
Epoch: 5, Train NLL: 0.0000
Epoch: 6, Train NLL: 0.0000
Epoch: 7, Train NLL: 0.0000
Epoch: 8, Train NLL: 0.0000
Epoch: 9, Train NLL: 0.0000
Epoch: 10, Train NLL: 0.0000
Epoch: 11, Train NLL: 0.0000
Epoch: 12, Train NLL: 0.0000
Epoch: 13, Train NLL: 0.0000
Epoch: 14, Train NLL: 0.0000
Epoch: 15, Train NLL: 0.0000
Epoch: 16, Train NLL: 0.0000
Epoch: 17, Train NLL: 0.0000
Epoch: 18, Train NLL: 0.0000
Epoch: 19, Train NLL: 0.0000
Epoch: 20, Train NLL: 0.0000
Epoch: 21, Train NLL: 0.0000
Epoch: 22, Train NLL: 0.0000
Epoch: 23, Train NLL: 0.0000
Epoch: 24, Train NLL: 0.0000
Epoch: 25, Train NLL: 0.0000
Epoch: 26, Train NLL: 0.0000
Epoch: 27, Train NLL: 0.0000
Epoch: 28, Train NLL: 0.0000
Epoch: 29, Train NLL: 0.0000
Epoch: 30, Train NLL: 0.0000
Epoch: 31, Train NLL: 0.0000
Epoch: 32, Train NLL: 0.0000
Epoch: 33, Train NLL: 0.0000
Epoch: 34, Train NLL: 0.

KeyboardInterrupt: 

In [148]:
# Evaluate on the dev split: summed loss and slide-level accuracy.
net.eval()

# NOTE(review): the pass is repeated `epochs` times although the weights do not
# change between passes; a single pass is probably what was intended.
for e in range(epochs):
    total_loss = 0
    labels = []
    preds = []
    # No gradients are needed for evaluation; this avoids building the autograd graph.
    with torch.no_grad():
        for slide, label in dev_loader:
            # Drop only the DataLoader's batch axis (batch_size=1). A bare squeeze()
            # would also remove the tile axis when a slide holds a single tile.
            # Assumes each dataset item is a (n_tiles, C, H, W) stack - TODO confirm.
            slide = slide.squeeze(0)
            slide, label = slide.cuda(), label.cuda()
            output = net(slide)
            # Average the per-tile features into one slide-level vector, then classify it.
            pool = torch.mean(output, 0).unsqueeze(0)
            output = net.classification_layer(pool)
            loss = criterion(output, label)

            total_loss += loss.item()
            labels.extend(label.cpu().numpy())
            preds.append(torch.argmax(output, dim=1).item())

    # Compare element-wise as arrays: `labels == preds` on two Python lists yields
    # a single bool, which made the accuracy either 0.0 or 1.0.
    acc = np.mean(np.asarray(labels) == np.asarray(preds))
    print('Epoch: {0}, Val NLL: {1:0.4f}, Val Acc: {2:0.4f}'.format(e, total_loss, acc))

Epoch: 0, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 1, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 2, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 3, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 4, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 5, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 6, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 7, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 8, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 9, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 10, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 11, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 12, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 13, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 14, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 15, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 16, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 17, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 18, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 19, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 20, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 21, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 22, Val NLL: 0.0000, Val Acc: 1.000

Epoch: 191, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 192, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 193, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 194, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 195, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 196, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 197, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 198, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 199, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 200, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 201, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 202, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 203, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 204, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 205, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 206, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 207, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 208, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 209, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 210, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 211, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 212, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 213

Epoch: 379, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 380, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 381, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 382, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 383, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 384, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 385, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 386, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 387, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 388, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 389, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 390, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 391, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 392, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 393, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 394, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 395, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 396, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 397, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 398, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 399, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 400, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 401

Epoch: 563, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 564, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 565, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 566, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 567, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 568, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 569, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 570, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 571, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 572, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 573, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 574, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 575, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 576, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 577, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 578, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 579, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 580, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 581, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 582, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 583, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 584, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 585

Epoch: 747, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 748, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 749, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 750, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 751, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 752, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 753, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 754, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 755, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 756, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 757, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 758, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 759, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 760, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 761, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 762, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 763, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 764, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 765, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 766, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 767, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 768, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 769

Epoch: 931, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 932, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 933, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 934, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 935, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 936, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 937, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 938, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 939, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 940, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 941, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 942, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 943, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 944, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 945, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 946, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 947, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 948, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 949, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 950, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 951, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 952, Val NLL: 0.0000, Val Acc: 1.0000
Epoch: 953

In [151]:
# Training split, served one dataset item per batch in shuffled order.
train = data_utils.COAD_dataset(data_utils.COAD_TRAIN)
train_loader = torch.utils.data.DataLoader(
    dataset=train,
    batch_size=1,
    shuffle=True,
    pin_memory=True,
)

# Validation split, configured the same way as the training loader.
valid = data_utils.COAD_dataset(data_utils.COAD_VALID)
valid_loader = torch.utils.data.DataLoader(
    dataset=valid,
    batch_size=1,
    shuffle=True,
    pin_memory=True,
)

In [152]:
# Lower the learning rate when the monitored value (the validation loss passed
# to scheduler.step) has stopped decreasing for more than `patience` epochs.
# NOTE(review): this wraps the `optimizer` created earlier in the notebook, so
# any learning-rate reductions already applied to it carry over; the recorded
# run prints LR = 1e-08 from epoch 0. Presumably the optimizer (and model)
# should be re-created before this cell - confirm.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=4)

In [156]:
def pool_fn(x):
    """Average ``x`` over its first axis, returning a tensor with that axis removed."""
    return x.mean(dim=0)


# Train/validate once per epoch, let the scheduler react to the validation
# loss, and checkpoint the weights whenever that loss improves.
best_loss = 1e8
for e in range(epochs):
    train_utils.training_loop(e, train_loader, net, criterion, optimizer, pool_fn)
    loss = train_utils.validation_loop(e, valid_loader, net, criterion, pool_fn)
    scheduler.step(loss)

    # Report the learning rate of the first parameter group after the scheduler update.
    current_lr = optimizer.param_groups[0]['lr']
    print('LR = {}'.format(current_lr))

    if not (loss < best_loss):
        continue
    # New best validation loss: persist the weights before recording it.
    torch.save(net.state_dict(), 'best_convnet.pt')
    best_loss = loss
    print('WROTE MODEL')

Epoch: 0, Train NLL: 21.5242
Epoch: 0, Val NLL: 2.9926, Val Acc: 1.0000
LR = 1.0000000000000004e-08
WROTE MODEL
Epoch: 1, Train NLL: 21.5242
Epoch: 1, Val NLL: 2.9925, Val Acc: 1.0000
LR = 1.0000000000000004e-08
WROTE MODEL
Epoch: 2, Train NLL: 21.5239
Epoch: 2, Val NLL: 2.9926, Val Acc: 1.0000
LR = 1.0000000000000004e-08
Epoch: 3, Train NLL: 21.5238
Epoch: 3, Val NLL: 2.9927, Val Acc: 1.0000
LR = 1.0000000000000004e-08
Epoch: 4, Train NLL: 21.5240
Epoch: 4, Val NLL: 2.9926, Val Acc: 1.0000
LR = 1.0000000000000004e-08
Epoch: 5, Train NLL: 21.5233
Epoch: 5, Val NLL: 2.9925, Val Acc: 1.0000
LR = 1.0000000000000004e-08
WROTE MODEL
Epoch: 6, Train NLL: 21.5233
Epoch: 6, Val NLL: 2.9926, Val Acc: 1.0000
LR = 1.0000000000000004e-08
Epoch: 7, Train NLL: 21.5231
Epoch: 7, Val NLL: 2.9924, Val Acc: 1.0000
LR = 1.0000000000000004e-08
WROTE MODEL
Epoch: 8, Train NLL: 21.5228
Epoch: 8, Val NLL: 2.9925, Val Acc: 1.0000
LR = 1.0000000000000004e-08
Epoch: 9, Train NLL: 21.5229
Epoch: 9, Val NLL: 2.99

KeyboardInterrupt: 