In [1]:
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_
from torch.nn.utils import weight_norm
import numpy as np
import torchvision.utils as vutils
from torchvision.utils import save_image
import random
import os
import shutil
import pdb
from logger import Logger
from PIL import Image

In [None]:
# Initialization
num_channels = 3
num_classes = 2
num_epochs = 300
image_size = 32
batch_size = 64
epsilon = 1e-8 # used to avoid NAN loss
logger = Logger('./logs')

# Initialize parameters
lr = 1e-5
b1 = 0.5 # adam: decay of first order momentum of gradient
b2 = 0.999 # adam: decay of first order momentum of gradient

model_path ='./baseline.tar'
image_dir = 'baseline_images'

os.makedirs(image_dir, exist_ok=True)

In [None]:
# Create Dataset
class TCGADataset(Dataset):
    def __init__(self, image_size, split):
        self.split = split
        self.tcga_dataset = self._create_dataset(image_size, split)
        self.patches, self.labels = self.tcga_dataset
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((.5, .5, .5), (.5, .5, .5))
        ])
        
    def _create_dataset(self, image_size, split):
        data_dir = '../dataset/patch_data'
        if self.split == 'train':
            data_dir = os.path.join(data_dir, 'train')
        else:
            data_dir = os.path.join(data_dir, 'dev')
            
        all_files = ['5.npz', '6.npz', '7.npz', '8.npz', '9.npz', '10.npz'] #os.listdir(data_dir)
        images = []
        labels = []
        
        # Iterate over all files
        for file in all_files:
            if '.npz' not in file:
                continue
            file_path = os.path.join(data_dir, file)
            data = np.load(file_path)
            X = data['arr_0']
            y = data['arr_1']
            images.append(X)
            labels.append(y)
            
        images = np.concatenate(images)
        labels = np.concatenate(labels)
        labels = np.asarray([1 if x in [330.0,331.0] else 0 for x in labels])
        
#         # Balance dataset
#         cancer = np.count_nonzero(labels)
#         noncancer = (labels.shape[0]-cancer)
#         minimum = min(cancer,noncancer)
#         sample_idxs_cancer = random.sample(list(np.where(labels == 1)[0]), minimum)
#         sample_idxs_nocancer = random.sample(list(np.where(labels == 0)[0]), minimum)
        
#         new_idxs = []
#         new_idxs.extend(sample_idxs_cancer)
#         new_idxs.extend(sample_idxs_nocancer)
#         random.shuffle(new_idxs)
#         images = images[new_idxs]
#         labels = labels[new_idxs]
        
        return images, labels
    

    def __getitem__(self, idx):
        data, label = self.patches[idx], self.labels[idx]
        return self.transform(Image.fromarray(data)), label

    def __len__(self):
        return len(self.labels)

In [None]:
# Get dataloaders
def get_loader(image_size, batch_size):
    num_workers = 2
    tcga_train = TCGADataset(image_size=image_size, split='train')
#     tcga_test = TCGADataset(image_size=image_size, split='test')

    train_loader = DataLoader(
        dataset=tcga_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

#     test_loader = DataLoader(
#         dataset=tcga_test,
#         batch_size=batch_size,
#         shuffle=True
#         #num_workers=num_workers
#     )

    return train_loader#, test_loader

In [6]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.zero_()
        
        
def initializer(m):
    # Run xavier on all weights and zero all biases
    if hasattr(m, 'weight'):
        if m.weight.ndimension() > 1:
            xavier_uniform_(m.weight.data)

    if hasattr(m, 'bias') and m.bias is not None:
        m.bias.data.zero_() 

In [7]:
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
          
        dropout_rate = 0.5
        filter1 = 96
        filter2 = 192
        
        # Conv operations
        # CNNBlock 1
        self.wn_conv1 = nn.Sequential(
            weight_norm(nn.Conv2d(in_channels=num_channels, out_channels=filter1, kernel_size=3, stride=1, padding=1), name='weight'),
            nn.LeakyReLU(0.2),
            weight_norm(nn.Conv2d(in_channels=filter1, out_channels=filter1, kernel_size=3, stride=1, padding=1), name='weight'),
            nn.LeakyReLU(0.2),
            weight_norm(nn.Conv2d(in_channels=filter1, out_channels=filter1, kernel_size=3, stride=2, padding=1), name='weight'),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(dropout_rate)
        )
        
        # CNNBlock 2
        self.wn_conv2 = nn.Sequential(
            weight_norm(nn.Conv2d(in_channels=filter1, out_channels=filter2, kernel_size=3, stride=1, padding=1), name='weight'),
            nn.LeakyReLU(0.2),
            weight_norm(nn.Conv2d(in_channels=filter2, out_channels=filter2, kernel_size=3, stride=1, padding=1), name='weight'),
            nn.LeakyReLU(0.2),
            weight_norm(nn.Conv2d(in_channels=filter2, out_channels=filter2, kernel_size=3, stride=2, padding=1), name='weight'),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(dropout_rate)
        )
        
        # CNNBlock 3
        self.wn_conv3 = nn.Sequential(
            weight_norm(nn.Conv2d(in_channels=filter2, out_channels=filter2, kernel_size=3, stride=1, padding=0), name='weight'),
            nn.LeakyReLU(0.2),
            weight_norm(nn.Conv2d(in_channels=filter2, out_channels=filter2, kernel_size=1, stride=1, padding=0), name='weight'),
            nn.LeakyReLU(0.2),
            weight_norm(nn.Conv2d(in_channels=filter2, out_channels=filter2, kernel_size=1, stride=1, padding=0), name='weight'),
            nn.LeakyReLU(0.2)
        )
                
        # Linear 
        self.wn_linear = weight_norm(nn.Linear(in_features=filter2, out_features=(num_classes + 1)), name='weight')
        self.apply(initializer)
        
    def forward(self, x):
        # Convolutional Operations
        pdb.set_trace()
        x = self.wn_conv1(x)
        x = self.wn_conv2(x)
        x = self.wn_conv3(x)
        
        # Linear
        x = x.mean(dim=3).mean(dim=2)
        x = self.wn_linear(flatten)
        x = F.sigmoid(x)
        
        return x

In [8]:
# Initialize loss and model
criterion = nn.BCELoss()
model = Model()

# Data Loader
train_loader = get_loader(image_size, batch_size)

# Initialize weights
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(b1, b2))

if torch.cuda.is_available():
    criterion.cuda()
    model = nn.DataParallel(model)
    model.cuda()

In [12]:
# Training Function
def train(epoch, num_epochs, optimizer, criterion, dataloader, model):
    model.train()
    
    total_loss = 0
    total_acc = 0
    loader_len = len(dataloader)

    for i, data in enumerate(dataloader):
        
        optimizer.zero_grad()
        img, label = data
        if torch.cuda.is_available():
            img = img.cuda()
            label = label.cuda()
        b_size = img.size(0)
    
        # Loss computation
        probs = model(img)
        loss = criterion(probs, label)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        
        # Train Accuracy Computation
        _, predicted = torch.max(probs, dim=1)
        correct = torch.sum(torch.eq(predicted, label))
        batch_accuracy = correct.item()/float(b_size)
        total_acc += batch_accuracy
        
        # Print stats
        if i%b_size == b_size-1:
            print("Train [Epoch %d/%d] [Batch %d/%d] [loss: %f, acc: %d%%]" % (epoch, num_epochs, i, 
                                       loader_len, loss.item(), 100 * batch_accuracy))
            
    total_loss = total_loss/float(i+1)
    total_acc = total_acc/float(i+1)
    return total_loss, total_acc

In [13]:
# Testing Function
def test(epoch, num_epochs, criterion, dataloader, model):
    model.eval()
    
    total_loss = 0
    total_acc = 0
    loader_len = len(dataloader)

    for i, data in enumerate(dataloader):
        
        img, label = data
        if torch.cuda.is_available():
            img = img.cuda()
            label = label.cuda()
        b_size = img.size(0)
    
        # Loss computation
        probs = model(img)
        loss = criterion(probs, label)
        total_loss += loss.item()
        
        # Test Accuracy Computation
        _, predicted = torch.max(probs, dim=1)
        correct = torch.sum(torch.eq(predicted, label))
        batch_accuracy = correct.item()/float(b_size)
        total_acc += batch_accuracy
        
        # Print stats
        if i%b_size == b_size-1:
            print("Test [Epoch %d/%d] [Batch %d/%d] [loss: %f, acc: %d%%]" % (epoch, num_epochs, i, 
                                       loader_len, loss.item(), 100 * batch_accuracy))
            
    total_loss = total_loss/float(i+1)
    total_acc = total_acc/float(i+1)
    return total_loss, total_acc

In [None]:
def save_checkpoint(state):
    torch.save(state, 'baseline_chkpt.tar')

In [14]:
'''
Call Train and Test and save best model
'''

for epoch in range(num_epochs):
    train_loss, train_acc = train(epoch, num_epochs, optimizer, criterion, train_loader, model)
    print('-----------')
    valid_loss, valid_acc = test(epoch, num_epochs, criterion, valid_loader, model)
    
    # Save best model
    save_checkpoint({
    'epoch': epoch + 1,
    'state_dict': model.state_dict(),
    'optimizer' : optimizer.state_dict()
    }
    
    # Display Progress
    print ("[Epoch %d/%d], Train Loss = %f, Validation Loss = %f, Test Loss = %f, Test Accuracy = %f" % (epoch, num_epochs, train_loss, 
                                                                                    valid_loss, test_loss, test_accuracy))
    

  "Please ensure they have the same size.".format(target.size(), input.size()))
  "Please ensure they have the same size.".format(target.size(), input.size()))
  "Please ensure they have the same size.".format(target.size(), input.size()))


[Epoch 0/10], Train Loss = 725.466133, Validation Loss = 49.491781, Test Loss = 49.491781, Test Accuracy = 50.000000


KeyboardInterrupt: 