In [1]:
#Pytorch
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

#Torchvision
import torchvision
from torchvision import datasets, models, transforms, utils
  
#Pytorch
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

#Torchvision
import torchvision
from torchvision import datasets, models, transforms, utils
from torch.utils.data import Dataset, DataLoader

#Image Processing
import matplotlib.pyplot as plt
from skimage import io, transform, color
import PIL
from PIL import Image

#Others
import sklearn.metrics
from sklearn.metrics import *
import numpy as np
import pandas as pd
import cv2
import time
import os
import copy
from model_summary import *
import pretrainedmodels
import tqdm
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

In [2]:
class dataset(Dataset):
    """Image / mask / label samples listed in a CSV file.

    The CSV must provide a ``name`` column (image file name relative to
    ``root_dir``) and a ``category`` column (the label). The mask for each image
    is read from ``mask_dir``, which is ``root_dir`` with
    'CBIS-DDSM_classification' replaced by 'masks'; its file name is the image
    name with '.j' replaced by '_mask.j' (e.g. 'x.jpg' -> 'x_mask.jpg').
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.data_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.mask_dir = self.root_dir.replace('CBIS-DDSM_classification', 'masks')

    def __len__(self):
        return len(self.data_frame)

    @staticmethod
    def _paired_transform(transform, image, mask):
        """Apply ``transform`` to ``image`` and ``mask`` with identical random draws.

        Random augmentations (flip / rotation / affine) must move the mask exactly
        like the image. The torch and Python RNG states are saved before the image
        call and restored before the mask call, so both sample the same parameters.
        """
        import random  # local import: the notebook's import cell does not import it

        torch_state = torch.get_rng_state()
        py_state = random.getstate()
        image = transform(image)
        torch.set_rng_state(torch_state)
        random.setstate(py_state)
        mask = transform(mask)
        return image, mask

    def __getitem__(self, idx):
        """Return a dict with keys 'image', 'category', 'mask' and 'name' for row ``idx``."""
        row = self.data_frame.iloc[idx]
        img_name = os.path.join(self.root_dir, row['name'])
        # Load eagerly so the file handle is closed, and force 3 channels so the
        # 3-entry Normalize in the transform pipeline always matches.
        with Image.open(img_name) as img:
            image = img.convert('RGB')

        label = row['category']

        mask_name = os.path.join(self.mask_dir, row['name'].replace('.j', '_mask.j'))
        mask = io.imread(mask_name)
        # Replicate the single-channel mask into 3 channels so it can share the
        # image's transform pipeline.
        mask = np.array([mask, mask, mask]).transpose((1, 2, 0))
        mask = Image.fromarray(mask)

        if self.transform:
            image, mask = self._paired_transform(self.transform, image, mask)

        return {'image': image, 'category': label, 'mask': mask, 'name': img_name}
    

def get_dataloader(data_dir, train_csv_path, image_size, img_mean, img_std, batch_size=1, num_workers=4):
    """Build train/valid/test datasets and loaders.

    Args:
        data_dir: image root directory passed to ``dataset``.
        train_csv_path: path of the training CSV. The valid/test CSV paths are
            derived by replacing every occurrence of 'train' in this path with
            the split name.
        image_size: target size for ``transforms.Resize``.
        img_mean, img_std: per-channel normalization statistics.
        batch_size: loader batch size.
        num_workers: loader worker processes.

    Returns:
        (dataloaders, dataset_sizes, image_datasets, device), where the first
        three are dicts keyed by 'train' / 'valid' / 'test'. ``device`` is
        'cuda:0' when CUDA is available, otherwise CPU.
    """
    # Deterministic pipeline shared by the evaluation splits.
    eval_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(img_mean, img_std),
    ])

    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(image_size),  # row to column ratio should be 1.69
            transforms.RandomVerticalFlip(0.5),
            transforms.RandomRotation(15),
            transforms.RandomAffine(translate=(0, 0.2), degrees=15, shear=15),
            transforms.ToTensor(),
            transforms.Normalize(img_mean, img_std),
        ]),
        'valid': eval_transform,
        'test': eval_transform,
    }

    image_datasets = {}
    dataloaders = {}
    dataset_sizes = {}

    for split in ['train', 'valid', 'test']:
        image_datasets[split] = dataset(train_csv_path.replace('train', split),
                                        root_dir=data_dir, transform=data_transforms[split])
        # Only the training split needs shuffling; evaluation order stays reproducible.
        dataloaders[split] = torch.utils.data.DataLoader(image_datasets[split], batch_size=batch_size,
                                                         shuffle=(split == 'train'),
                                                         num_workers=num_workers)
        dataset_sizes[split] = len(image_datasets[split])

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    return dataloaders, dataset_sizes, image_datasets, device

In [None]:
def denorm_img(img_ten, img_mean, img_std):
    """Undo per-channel normalization, average the channels and binarize.

    Args:
        img_ten: tensor of shape (bz, nc, h, w) that was normalized as
            (x - mean) / std. It is not modified.
        img_mean, img_std: per-channel sequences with ``nc`` entries.

    Returns:
        float array of shape (bz, h, w) with values in {0, 1}. A pixel is 1 when
        its channel mean is >= half of its image's maximum. An image whose
        maximum is <= 0 (e.g. an empty mask) maps to all zeros.
    """
    bz, nc, h, w = img_ten.shape
    # Copy to numpy so the caller's tensor is never written through.
    imgs = img_ten.detach().cpu().numpy().copy()

    std = np.asarray(img_std, dtype=imgs.dtype).reshape(1, -1, 1, 1)
    mean = np.asarray(img_mean, dtype=imgs.dtype).reshape(1, -1, 1, 1)
    imgs = imgs * std + mean

    gray = imgs.mean(axis=1)  # (bz, h, w)
    peak = gray.max(axis=(1, 2), keepdims=True)
    # The threshold is computed once per image; recomputing it after partial
    # assignment can leave non-binary values behind.
    binary = (gray >= 0.5 * peak) & (peak > 0)
    return binary.astype(gray.dtype)

def get_IoU(pred, targs, device):
    """Binarize a saliency map and score it against a target mask.

    Despite its name, this does not return intersection-over-union.

    Args:
        pred: float tensor whose last two dims are (H, W). It is not modified.
            Pixels strictly above ``mean + 2 * std`` (statistics over the whole
            tensor) count as predicted.
        targs: target mask (array-like or tensor) broadcastable to ``pred``.
        device: device on which the comparison is done.

    Returns:
        (hit_rate, pred_fraction) as 0-dim tensors:
            hit_rate = sum(pred_bin * targs) / sum(targs) (nan/inf if targs is empty);
            pred_fraction = sum(pred_bin) / (H * W).
    """
    # Threshold is computed once, from the untouched input.
    threshold = pred.mean() + 2 * pred.std()
    pred_bin = (pred > threshold).to(pred.dtype).to(device)

    targs = torch.as_tensor(targs, dtype=pred_bin.dtype, device=device)

    hit_rate = (pred_bin * targs).sum() / targs.sum()
    pred_fraction = pred_bin.sum() / (pred.shape[-1] * pred.shape[-2])
    return hit_rate, pred_fraction

#ir2 = pretrainedmodels.__dict__['inceptionresnetv2'](num_classes=1000, pretrained='imagenet')
#ir1 = nn.Sequential(*list(ir2.children())[:-1])
#summary(ir1.cuda(),(3,540,320))

#vggnet = models.vgg11_bn(pretrained=True)
#vgg_conv = nn.Sequential(*list(vggnet.children())[0][:-1])

def build_model(num_classes=2, pretrained=True):
    """Build a VGG11-BN convolutional backbone with a GAP + linear classifier head.

    Args:
        num_classes: number of output logits (default 2, as before).
        pretrained: passed to ``models.vgg11_bn``.

    Returns:
        A ``vgg_gain`` module whose ``forward`` returns ``(logits, feature_map)``.
        ``feature_map`` is the backbone output, which the CAM code consumes.
    """
    class vgg_gain(nn.Module):
        """Backbone -> global average pool -> linear layer; also exposes the feature map."""

        def __init__(self, vgg_base, num_classes=2):
            super().__init__()
            self.vgg_base = vgg_base
            self.gap = nn.AdaptiveAvgPool2d((1, 1))
            # 512 = channel count of the backbone's last conv block.
            self.fc = nn.Linear(512, num_classes)

        def forward(self, x):
            x_base = self.vgg_base(x)
            x = self.gap(x_base)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            return x, x_base

    # These statements belong to build_model, not to the class body. Inside the
    # class, `vgg_gain` is not yet defined and `v1` would be out of scope at return.
    vggnet = models.vgg11_bn(pretrained=pretrained)
    # Keep the `features` stack minus its last layer (the final max-pool in
    # torchvision's VGG11-BN).
    vgg_conv = nn.Sequential(*list(vggnet.children())[0][:-1])
    v1 = vgg_gain(vgg_conv, num_classes)

    return v1


def returnCAM(feature_conv, weight_softmax, class_idx, output_shape):
    """Compute class activation maps for a batch and resize them to ``output_shape``.

    Args:
        feature_conv: array of shape (bz, nc, h, w) holding backbone feature maps.
        weight_softmax: array indexed by class whose rows have ``nc`` weights.
        class_idx: sequence whose first element holds one class index per sample.
        output_shape: (width, height) given to ``cv2.resize``.

    Returns:
        float64 array of shape (bz, 3, height, width). Each sample's CAM has
        negative values clipped to 0 and is copied into all three channels.
    """
    batch, channels, fh, fw = feature_conv.shape
    width = output_shape[0]
    height = output_shape[1]
    chosen = class_idx[0]

    maps = []
    for sample in range(batch):
        flat_features = feature_conv[sample].reshape((channels, fh * fw))
        raw_map = weight_softmax[chosen[sample]].dot(flat_features).reshape(fh, fw)
        resized = cv2.resize(raw_map, (width, height))
        resized[resized < 0] = 0
        maps.append(resized)
    stacked = np.array(maps)

    result = np.zeros((batch, 3, height, width))
    for channel in range(3):
        result[:, channel, :, :] = stacked

    return result


class gain():
    """GAIN-style trainer: classification stream plus an attention-mining stream.

    Each batch is classified, and a class activation map (CAM) is built from the
    backbone feature map. The input is then erased by the soft mask
    T(A) = sigmoid(w * (A - sigma)) and classified again. The total loss is
    ``stream_loss + alpha * mining_loss``.

    NOTE(review): ``stream_loss``, ``mining_loss`` and ``dice`` are not defined
    in the visible notebook. They must exist at module level before ``train()``
    is called.
    """

    def __init__(self):
        self.data_dir = '../Data/CBIS-DDSM_classification_1/'
        self.train_csv = '../CSV/gain_train.csv'
        self.image_size = (320, 192)
        self.num_classes = 2
        self.num_epochs = 50
        self.batch_size = 1
        # Soft-mask parameters of T(A) = sigmoid(w * (A - sigma)).
        self.sigma = 0
        self.w = 1
        # Weight of the mining loss in the total loss.
        self.alpha = 1
        self.img_mean = [0.223, 0.231, 0.243]
        self.img_std = [0.266, 0.270, 0.274]

        self.dataloaders, self.dataset_sizes, self.dataset, self.device = get_dataloader(
            self.data_dir, self.train_csv, self.image_size, self.img_mean, self.img_std, self.batch_size)

        self.model = build_model().to(self.device)
        # Cross-entropy criterion; train() itself uses stream_loss / mining_loss.
        self.ce = nn.CrossEntropyLoss()

        self.optimizer_ft = optim.Adam(self.model.parameters(), lr=0.01, betas=(0.9, 0.999),
                                       eps=1e-08, weight_decay=0, amsgrad=False)
        # Optional lr_scheduler; when set, it is stepped once after each training epoch.
        self.scheduler = None

    def _gain_forward(self, inputs):
        """Run both GAIN streams on ``inputs``.

        Returns:
            (outputs, preds, cam_orig, mining_output): ``outputs`` and
            ``mining_output`` are ``torch.exp`` of the classifier logits for the
            original and the CAM-erased inputs, ``preds`` are argmax classes, and
            ``cam_orig`` is the numpy CAM stack from ``returnCAM``.
        """
        logits, feature_maps = self.model(inputs)
        outputs = torch.exp(logits)
        _, preds = torch.max(outputs, 1)

        # The model returns the backbone output directly, which is the feature map
        # a forward hook on the backbone's last layer would capture.
        features = feature_maps.detach().cpu().numpy()
        weight_softmax = self.model.fc.weight.detach().cpu().numpy()
        cam_orig = np.array(returnCAM(features, weight_softmax, [preds.cpu()],
                                      (inputs.size(-1), inputs.size(-2))))

        # The CAM comes from detached numpy arrays, so no gradient flows through it.
        cam = torch.from_numpy(cam_orig).float().to(self.device)
        # T(A) as defined in the paper.
        t_cam = torch.sigmoid(self.w * (cam - self.sigma))
        mining_input = inputs - t_cam * inputs

        mining_logits, _ = self.model(mining_input)
        return outputs, preds, cam_orig, torch.exp(mining_logits)

    def _run_phase(self, phase):
        """Make one pass over the ``phase`` loader; return (mean loss, accuracy, mean dice)."""
        is_train = phase == 'train'
        self.model.train(is_train)

        running_loss = 0.0
        running_corrects = 0
        running_dice = 0
        n_samples = self.dataset_sizes[phase]

        pbar = tqdm(total=n_samples)
        for sampled_batch in self.dataloaders[phase]:
            inputs = sampled_batch['image'].float().to(self.device)
            labels = sampled_batch['category'].long().to(self.device)
            mask = denorm_img(sampled_batch['mask'], self.img_mean, self.img_std).squeeze()
            mask[mask > 0.1] = 1
            mask[mask < 0.1] = 0

            self.optimizer_ft.zero_grad()

            # Track history only while training.
            with torch.set_grad_enabled(is_train):
                outputs, preds, cam_orig, mining_output = self._gain_forward(inputs)
                loss_stream = stream_loss(outputs, labels)
                loss_mining = mining_loss(mining_output, labels)
                loss = loss_stream + self.alpha * loss_mining

                if is_train:
                    loss.backward()
                    self.optimizer_ft.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels).item()
            running_dice += dice(cam_orig, mask)

            pbar.update(inputs.shape[0])
        pbar.close()

        # Step the scheduler after the epoch's optimizer steps.
        if is_train and self.scheduler is not None:
            self.scheduler.step()

        denom = max(n_samples, 1)
        return running_loss / denom, running_corrects / denom, running_dice / denom

    def train(self):
        """Train for ``num_epochs``; restore and return the best-validation-accuracy model.

        Each time validation accuracy improves, the weights are saved to
        'gain_vgg_<acc>_acc.pt'.
        """
        since = time.time()
        best_model_wts = copy.deepcopy(self.model.state_dict())
        best_acc = 0.0

        for epoch in range(self.num_epochs):
            print('Epoch {}/{}'.format(epoch, self.num_epochs - 1), flush=True)
            print('-' * 10, flush=True)

            # Each epoch has a training and a validation phase.
            for phase in ['train', 'valid']:
                epoch_loss, epoch_acc, epoch_dice = self._run_phase(phase)

                print('{} Loss: {:.4f} Acc: {:.4f} Dice: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc, epoch_dice))

                if phase == 'valid' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(self.model.state_dict())
                    torch.save(self.model.state_dict(),
                               'gain_vgg_' + '{:.4f}'.format(epoch_acc) + '_acc.pt')

            print()

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print('Best val Acc: {:.4f}'.format(best_acc))

        self.model.load_state_dict(best_model_wts)
        return self.model
