In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision
from torchvision import models
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
from tqdm import tqdm
import shutil
import glob
import cv2 as cv
import matplotlib.pyplot as plt

In [2]:
torch.cuda.get_device_name(0)

'Tesla P100-PCIE-16GB'

In [3]:
def getListOfFiles(dirName):
    # create a list of file and sub directories 
    # names in the given directory 
    listOfFile = os.listdir(dirName)
    allFiles = list()
    # Iterate over all the entries
    for entry in listOfFile:
        # Create full path
        fullPath = os.path.join(dirName, entry)
        # If entry is a directory then get the list of files in this directory 
        if os.path.isdir(fullPath):
            allFiles = allFiles + getListOfFiles(fullPath)
        else:
            allFiles.append(fullPath)

    return allFiles

In [4]:
files=getListOfFiles('../input/breakhis/BreaKHis_v1/BreaKHis_v1/histology_slides/breast')
imgs=[]
for f in files:
    if f.endswith('.png'):
        imgs.append(f)

In [5]:
# Create custom dataset class
# Note : Using torch.Tensor() creates memory overhead, as it has grad_required = True
# Hence use torch.new_tensor() which has required_grad set to false
class Breakhis(Dataset):
    def __init__(self,img_list,transform=None):
        self.img_list=img_list
        self.transform=transform
    def __len__(self):
        return len(self.img_list)
    def __getitem__(self,index):
        img_path=self.img_list[index]
        image=Image.open(img_path)
        temp=img_path.split('/')
        if temp[7]=='benign':
            ylabel=torch.tensor(0,dtype=torch.float32,requires_grad=False)
        elif temp[7]=='malignant':
            ylabel=torch.tensor(1,dtype=torch.float32,requires_grad=False)
        else:
            print(temp,temp[7])
        if self.transform:
            image=self.transform(image)
        return [image,ylabel]

In [6]:
# Check accuracy function
def check_accuracy(output,labels):
    _,predpos=output.max(1)
    num_samples=len(labels)
    num_correct=(predpos==labels).sum()
    return (num_correct/num_samples)*100

In [7]:
# Function to calc mean and std across dataset
def mean_std(loader,device):
    # V(X) = E(X**2)-E(X)**2
    channels_sum,channels_squared_sum,num_batches = 0, 0, 0
    for data,_ in loader:
        data.to(device)
        channels_sum+=torch.mean(data,[0,2,3])
        channels_squared_sum+=torch.mean(data**2,[0,2,3])
        num_batches+=1
    mean=channels_sum/num_batches
    std=(channels_squared_sum/num_batches-mean**2)**0.5
    return mean,std

In [8]:
def save_checkpoint(state,filename='clahe.pth.tar'):
    print('Saving weights-->')
    torch.save(state,filename)

In [9]:
def load_checkpoint(filename):
    print('Loading weights-->')
    model.load_state_dict(checkpoint['state_dict'])
    optim.load_state_dict(checkpoint['optimizer'])

In [10]:
# Parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32
shuffle_dataset = True
random_seed= 42
num_workers=2
learning_rate=0.001
print(device)
num_epochs=25
load_model=False

cuda


In [11]:
# Create resnet model, with respecitve transform
model = models.resnet50(pretrained=False)
model.fc=nn.Sequential(nn.Linear(2048,1024),
                      nn.LeakyReLU(),
                      nn.Linear(1024,512),
                      nn.LeakyReLU(),
                      nn.Linear(512,2))
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [12]:
# Loss and Optimizer
criterion=nn.CrossEntropyLoss()
optim=torch.optim.Adam(model.parameters(),lr=learning_rate)

In [14]:
if load_model:
    load_checkpoint(torch.load('weights.pth.tar'))

In [80]:
# Preprocessing functions

# Contrast Limited Adaptive Histogram Equilization
def clahe(img): 
    # Takes a PIL instance input, and returns nd.array
    # Here we recieve PIL instance as input shape [H,W,C]
    img=np.array(img,dtype=np.uint8)

    lab = cv.cvtColor(img, cv.COLOR_RGB2LAB)

    lab_planes = cv.split(lab)

    clahe = cv.createCLAHE(clipLimit=2.0,tileGridSize=(8,8))

    lab_planes[0] = clahe.apply(lab_planes[0])

    lab = cv.merge(lab_planes)

    rgb = cv.cvtColor(lab, cv.COLOR_LAB2RGB)

    return rgb

# Medium filter to remove the noisiness of an image 
def median_filter(img):
    img=np.pad(img,[(1,1),(1,1),(0,0)],'constant',constant_values=(0))
    print(img.shape)
    median = cv.medianBlur(np.float32(img),5)
    return median

In [84]:
# Create custom transform using preprocessing function
class custom_transform():
    def __call__(self,img):
        return clahe(img)
    
    def __repr__(self):
        print('CLAHE()')

In [85]:
# Create dummy custom transform without normalization to check if CLAHE and median filter is working is working
dummy_transform =transforms.Compose([
    transforms.Resize((512,512)),
    custom_transform()
])

In [86]:
# Test clahe 
# Refresh ./kaggle/working dir after execution
img=Image.open('../input/breakhis/BreaKHis_v1/BreaKHis_v1/histology_slides/breast/malignant/SOB/ductal_carcinoma/SOB_M_DC_14-13993/400X/SOB_M_DC-14-13993-400-001.png')
img.save('before_transform.png')
result=dummy_transform(img)
plt.imsave('after_transform.png',result)

In [20]:
# Mean and STD for normalize over ALL images(test and train)
dummyset= Breakhis(imgs,transforms.Compose([
    transforms.Resize((521,521)),
    transforms.ToTensor()
]))
dummy_loader = DataLoader(dummyset, batch_size=batch_size,num_workers=num_workers, shuffle=True)
mean,std=mean_std(dummy_loader)

In [21]:
mean.tolist(), std.tolist()

([0.7869861721992493, 0.6267361044883728, 0.7644734382629395],
 [0.12267550081014633, 0.17426468431949615, 0.10706450045108795])

In [23]:
# Include normalization and clahe in final transform
transform =transforms.Compose([
    transforms.Resize((512,512)),
    custom_transform(),
    transforms.ToTensor(),
    transforms.Normalize(mean.tolist(),std.tolist())
])

In [24]:
dataset_normalized= Breakhis(imgs,transform)

In [25]:
# Random split into train test and validation
dataset_size=len(dataset_normalized)
print('Total images : ',dataset_size)
train_set,valid_set,test_set=random_split(dataset_normalized,[5539,1580,790])

Total images :  7909


In [26]:
print('Train, Validation, Test : ',len(train_set),len(valid_set),len(test_set))

Train, Validation, Test :  5539 1580 790


In [27]:
# Create train and validation loader
train_loader = DataLoader(train_set, batch_size=batch_size,num_workers=num_workers, shuffle=True)
validation_loader = DataLoader(valid_set, batch_size=batch_size,num_workers=num_workers,shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size,num_workers=num_workers,shuffle=True)

In [28]:
# Put model on cuda 
model.to(device)
# Put the model on train mode
model.train()
print()




In [29]:
i,y=next(iter(train_loader))
i=i.to(device)
y=y.to(device)
y_pred=model(i)
print(y_pred.shape)

torch.Size([32, 2])


In [30]:
# Training loop for the model
min_loss=None
for epoch in range(num_epochs):
    losses=[]
    accuracies=[]
    loop= tqdm(enumerate(train_loader),total=len(train_loader),leave=False)
    for batch_idx, (data,labels) in loop:
        # Put data on cuda
        data=data.to(device)
        labels=labels.to(device).long()
        
        # Forward pass
        output=model(data)
        
        # Find out loss
        loss=criterion(output,labels)
        accuracy=check_accuracy(output,labels)
        losses.append(loss.detach().item())
        accuracies.append(accuracy.detach().item())
        
        optim.zero_grad()
        
        # Back prop
        loss.backward()
        
        # Step
        optim.step()
        
        # Update TQDM progress bar
        loop.set_description(f"Epoch [{epoch}/{num_epochs}] ")
        loop.set_postfix(loss=loss.detach().item(),accuracy=accuracy.detach().item())
        
    moving_loss=sum(losses)/len(losses)
    moving_accuracy=sum(accuracies)/len(accuracies)
    checkpoint={'state_dict': model.state_dict(),'optimizer': optim.state_dict()}
    # Save check point
    if min_loss==None:
        min_loss=moving_loss
        save_checkpoint(checkpoint)
    elif moving_loss<min_loss:
        min_loss=moving_loss
        save_checkpoint(checkpoint)
    print('Epoch {0} : Loss = {1} , Accuracy={2}'.format(epoch,moving_loss,moving_accuracy))

                                                                                           

Saving weights-->
Epoch 0 : Loss = 0.4553904429763213 , Accuracy=80.98659006754558


                                                                                           

Saving weights-->
Epoch 1 : Loss = 0.36224665330059225 , Accuracy=85.35081420547661


                                                                                           

Epoch 2 : Loss = 0.37326245455221196 , Accuracy=85.52442528735632


                                                                                            

Saving weights-->
Epoch 3 : Loss = 0.3320723494150858 , Accuracy=86.9492337457065


                                                                                           

Saving weights-->
Epoch 4 : Loss = 0.3131623829490152 , Accuracy=87.55986592961453


                                                                                           

Saving weights-->
Epoch 5 : Loss = 0.30722431380344534 , Accuracy=87.80531609195403


                                                                                           

Saving weights-->
Epoch 6 : Loss = 0.29459237024701873 , Accuracy=88.17049811352258


                                                                                           

Saving weights-->
Epoch 7 : Loss = 0.2803262002423577 , Accuracy=88.26628353951992


                                                                                           

Epoch 8 : Loss = 0.2961102917833232 , Accuracy=88.30818965517241


                                                                                            

Saving weights-->
Epoch 9 : Loss = 0.2573140935822465 , Accuracy=89.8647030797498


                                                                                             

Epoch 10 : Loss = 0.2727682651502305 , Accuracy=89.01460730892488


                                                                                            

Epoch 11 : Loss = 0.26448611596106797 , Accuracy=89.79885057471265


                                                                                            

Saving weights-->
Epoch 12 : Loss = 0.22919685541298881 , Accuracy=91.05603448275862


                                                                                            

Saving weights-->
Epoch 13 : Loss = 0.22663577900792675 , Accuracy=90.81058432041914


                                                                                             

Saving weights-->
Epoch 14 : Loss = 0.21964550688434606 , Accuracy=91.45114942528735


                                                                                             

Saving weights-->
Epoch 15 : Loss = 0.2042216917087763 , Accuracy=91.70258620689656


                                                                                             

Epoch 16 : Loss = 0.21161032688600578 , Accuracy=91.60081420547661


                                                                                             

Saving weights-->
Epoch 17 : Loss = 0.2018018094929813 , Accuracy=92.1934866192697


                                                                                             

Saving weights-->
Epoch 18 : Loss = 0.18494745506637397 , Accuracy=92.99568965517241


                                                                                             

Saving weights-->
Epoch 19 : Loss = 0.1847092045107107 , Accuracy=92.83405172413794


                                                                                             

Saving weights-->
Epoch 20 : Loss = 0.17446161265988117 , Accuracy=92.8220785732927


                                                                                             

Epoch 21 : Loss = 0.21377393673976947 , Accuracy=91.52298850574712


                                                                                             

Epoch 22 : Loss = 0.1847167047725498 , Accuracy=93.05555558478696


                                                                                             

Epoch 23 : Loss = 0.2103046165115532 , Accuracy=92.30723181538198


                                                                                             

Epoch 24 : Loss = 0.23813875747480612 , Accuracy=91.83429121697085




In [31]:
# Validation accuracy
correct=0
samples=0
for data,labels in validation_loader:
    data=data.to(device)
    labels=labels.to(device)
    # Forward pass
    y_pred=model(data)
    # Accuracy over entire dataset
    _,predpos=y_pred.max(1)
    samples+=len(labels)
    correct+=(predpos==labels).sum().detach().item()
print('Validation accuracy : ',(correct/samples)*100)

Validation accuracy :  91.01265822784809


In [32]:
# Test accuracy
correct=0
samples=0
for data,labels in test_loader:
    data=data.to(device)
    labels=labels.to(device)
    # Forward pass
    y_pred=model(data)
    # Accuracy over entire dataset
    _,predpos=y_pred.max(1)
    samples+=len(labels)
    correct+=(predpos==labels).sum().detach().item()
print('Test accuracy : ',(correct/samples)*100) 

Test accuracy :  91.89873417721519
