In [None]:
#Pytorch
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

#Torchvision
import torchvision
from torchvision import datasets, models, transforms, utils
from torch.utils.data import Dataset, DataLoader

#Image Processing
import matplotlib.pyplot as plt
from skimage import io, transform, color
import PIL
from PIL import Image
import augmentations
from augmentations import *

#Others
import sklearn.metrics
from sklearn.metrics import *
import numpy as np
import pandas as pd
import cv2
import time
import os
import copy
from model_summary import *
import pretrainedmodels
import tqdm
from tqdm import tqdm_notebook as tqdm
import warnings
warnings.filterwarnings("ignore")

import dataloaders
from dataloaders import *

'''Dataloader'''
class dataset(Dataset):
    """Paired image / mask dataset indexed by a CSV file.

    Each CSV row is expected to provide a ``name`` column (image file name,
    resolved against ``root_dir``) and a ``category`` column (the label).
    The matching mask is read from ``mask_dir`` under the same file name with
    the last ``.j`` replaced by ``_mask.j`` (e.g. ``x.jpg`` -> ``x_mask.jpg``).

    ``transform`` is called as ``transform(image, mask)`` and must return the
    pair as tensors (``__getitem__`` indexes ``mask[0, :, :]``).
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.data_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.mask_dir = '../Data/mask_orient_cropped/'

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Return a dict with keys 'image', 'category', 'mask' (binary, first channel) and 'name'."""
        row = self.data_frame.iloc[idx]
        name = row['name']

        image = Image.open(os.path.join(self.root_dir, name))

        # Only the extension's ".j" is rewritten, so names containing ".j"
        # elsewhere still map to the right mask file.
        mask_name = os.path.join(self.mask_dir, '_mask.j'.join(name.rsplit('.j', 1)))
        mask = io.imread(mask_name)
        # The mask is a 2-D array; replicate it to 3 channels so the joint
        # (image, mask) transforms can handle both inputs the same way.
        mask = Image.fromarray(np.stack([mask, mask, mask], axis=-1))

        label = row['category']

        if self.transform:
            image, mask = self.transform(image, mask)

        # Binarise the first channel: < 0.5 -> 0, >= 0.5 -> 1 (no value is
        # left in between).
        mask_final = mask[0, :, :]
        mask_final[mask_final < 0.5] = 0
        mask_final[mask_final >= 0.5] = 1

        return {'image': image, 'category': label, 'mask': mask_final, 'name': name}
    
def get_dataloader(data_dir, train_csv_path, image_size, img_mean, img_std, batch_size=1, num_workers=8):
    """Build the train / valid / test datasets and DataLoaders.

    Args:
        data_dir: image root directory handed to every ``dataset``.
        train_csv_path: path of the training CSV. The valid/test CSV paths are
            derived from it by replacing the LAST ``'train'`` in the path with
            ``'valid'`` / ``'test'`` (directories named e.g. ``train_x`` are left alone).
        image_size: not used by this function (kept for interface compatibility).
        img_mean, img_std: per-channel normalisation statistics.
        batch_size: batch size for the train and valid loaders; test always uses 1.
        num_workers: worker processes per DataLoader.

    Returns:
        ``(dataloaders, dataset_sizes, image_datasets, device)``: three dicts
        keyed by split name, plus the torch device to use (CUDA if available).
    """
    data_transforms = {
        'train': Compose([
            RandomHorizontallyFlip(0.5),
            #RandomVerticallyFlip(0.5),
            RandomTranslate((0.2, 0.2)),
            RandomRotate(15),
            ToTensor(),
            Normalize(img_mean, img_std),
        ]),
        'valid': Compose([
            ToTensor(),
            Normalize(img_mean, img_std),
        ]),
        'test': Compose([
            ToTensor(),
            Normalize(img_mean, img_std),
        ]),
    }

    # (batch size, shuffle) per split; test is evaluated one image at a time.
    loader_opts = {
        'train': (batch_size, True),
        'valid': (batch_size, False),
        'test': (1, False),
    }

    head, sep, tail = train_csv_path.rpartition('train')

    image_datasets = {}
    loaders = {}
    dataset_sizes = {}

    for split, (bs, shuffle) in loader_opts.items():
        # Without a 'train' token every split reads the same CSV.
        csv_path = head + split + tail if sep else train_csv_path
        image_datasets[split] = dataset(csv_path, root_dir=data_dir, transform=data_transforms[split])
        loaders[split] = torch.utils.data.DataLoader(
            image_datasets[split], batch_size=bs, shuffle=shuffle, num_workers=num_workers)
        dataset_sizes[split] = len(image_datasets[split])

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    return loaders, dataset_sizes, image_datasets, device

#Selector network

def build_selector():
    """Build the patch-selector network.

    Architecture: the ImageNet-pretrained VGG16-BN convolutional trunk
    (``features``) followed by a 3x3, stride-2, unpadded convolution with
    5 output channels, i.e. one score map per anchor position within a grid
    cell (5 anchors per cell).

    The spatial size of the output depends on the input: the VGG trunk
    downsamples by 32 and the head shrinks it further, so a 288x256 input
    yields a 5x4x3 score map per image.
    """

    class mdl(nn.Module):
        def __init__(self, base_model):
            super().__init__()
            self.base = base_model
            self.l1 = nn.Conv2d(512, 5, 3, stride=2)

        def forward(self, x):
            return self.l1(self.base(x))

    v = models.vgg16_bn(pretrained=True)

    # Take the conv trunk by attribute: position-based indexing of
    # ``children()`` depends on the torchvision version (newer VGGs insert an
    # ``avgpool`` child before the classifier). Parameter names stay
    # ``base.<idx>...`` so existing checkpoints still load.
    model = mdl(v.features)

    return model

## Predictor-Discriminator-Baseline
def build_baseline_predictor(weights_path='classification_bc_vgg_16_balanced_mass_sel.pt'):
    """Build the 2-class VGG16-BN classifier and load its weights.

    Architecture: VGG16-BN conv trunk -> global average pool -> Linear(512, 2).
    ``forward`` returns ``(logits, trunk_feature_map)``.

    Args:
        weights_path: checkpoint containing a full state dict for this
            architecture (loaded with strict key matching).

    Returns:
        The model on CPU; the caller moves it to a device.
    """

    class mdl(nn.Module):
        def __init__(self, base_model):
            super().__init__()
            self.base = base_model
            self.gap = nn.AdaptiveAvgPool2d((1, 1))
            self.fc1 = nn.Linear(512, 2)

        def forward(self, x):
            x_base = self.base(x)
            x = self.gap(x_base)
            x = x.view(x.size(0), -1)
            x = self.fc1(x)
            return x, x_base

    # No ImageNet download: the strict load below overwrites every parameter
    # and buffer anyway.
    v = models.vgg16_bn(pretrained=False)

    # Take the conv trunk by attribute; ``children()`` order differs across
    # torchvision versions (newer VGGs have an extra ``avgpool`` child).
    model = mdl(v.features)
    # map_location='cpu' lets checkpoints saved on GPU load on any machine.
    model.load_state_dict(torch.load(weights_path, map_location='cpu'))

    return model

def get_sample(target):
    """Draw one position from the softmax over all entries of ``target``.

    Args:
        target: tensor of scores (any shape); flattened, soft-maxed and
            treated as a single categorical distribution.

    Returns:
        A NumPy integer array with ``target``'s shape, containing a single 1
        at the sampled position and 0 elsewhere.

    Raises:
        ValueError: if the probabilities are not finite (e.g. NaN scores).
    """
    prob_vector = F.softmax(target.reshape(-1), dim=0)
    probs = prob_vector.detach().cpu().numpy().astype('float64')

    total = probs.sum()
    if not np.all(np.isfinite(probs)) or total <= 0:
        raise ValueError('get_sample: selector produced invalid probabilities (NaN/inf or zero mass)')

    # Renormalise in float64 so np.random.multinomial's sum check passes
    # despite float32 rounding in the softmax.
    probs = probs / total

    prob_sample = np.random.multinomial(1, probs, 1)
    return prob_sample.reshape(target.shape)

def get_anchor_box(m_r, m_c, img_shape, patch_shape):
    """Return a ``patch_shape`` box centred on ``(m_r, m_c)``, shifted to fit inside the image.

    Args:
        m_r, m_c: desired centre row / column (integers).
        img_shape: ``(rows, cols)`` of the image.
        patch_shape: ``(rows, cols)`` of the box.

    Returns:
        ``(r1, r2, c1, c2)`` half-open bounds with ``r2 - r1 == patch_shape[0]``
        and ``c2 - c1 == patch_shape[1]``, all within the image.

    Raises:
        ValueError: if the patch is larger than the image in either dimension.
    """

    def _span(center, size, limit):
        # Centre the span, then clamp its start to [0, limit - size] so the
        # span keeps its full length while staying inside [0, limit).
        if size > limit:
            raise ValueError('get_anchor_box: patch size {} exceeds image size {}'.format(size, limit))
        start = min(max(center - size // 2, 0), limit - size)
        return start, start + size

    r1, r2 = _span(m_r, patch_shape[0], img_shape[0])
    c1, c2 = _span(m_c, patch_shape[1], img_shape[1])

    return r1, r2, c1, c2

def get_patch_center(inp, img_shape):
    """Convert a one-hot selector sample into a pixel location.

    Args:
        inp: 4-D array shaped like the selector output,
            ``(N, anchors, grid_h, grid_w)``, with a 1 marking the sampled
            (anchor, grid row, grid column). Only the first 1 found is used;
            the first axis is ignored.
        img_shape: ``(rows, cols)`` of the image the grid is laid over.

    Returns:
        ``(m_r, m_c)`` as ints: the centre of the sampled grid cell, shifted
        by a quarter cell according to the anchor index
        (0 = centre, 1 = up-left, 2 = up-right, 3 = down-right, 4 = down-left).

    Raises:
        ValueError: if ``inp`` contains no entry equal to 1.
    """
    # Quarter-cell (row, col) offsets per anchor index; unknown anchors -> no shift.
    anchor_offsets = {0: (0, 0), 1: (-1, -1), 2: (-1, 1), 3: (1, 1), 4: (1, -1)}

    _, _, grid_h, grid_w = inp.shape
    hits = np.argwhere(np.asarray(inp) == 1)
    if hits.size == 0:
        raise ValueError('get_patch_center: inp has no selected cell (no entry equal to 1)')
    _, anchor, row, col = (int(v) for v in hits[0])

    cell_h = img_shape[0] // grid_h
    cell_w = img_shape[1] // grid_w

    # Centre of the selected cell; the column comes from the grid column,
    # not from the anchor index.
    m_r = (row * cell_h + (row + 1) * cell_h) // 2
    m_c = (col * cell_w + (col + 1) * cell_w) // 2

    d_r, d_c = anchor_offsets.get(anchor, (0, 0))
    m_r += d_r * (cell_h // 4)
    m_c += d_c * (cell_w // 4)

    return int(m_r), int(m_c)

def intersection_metric(anchor, mask):
    """Fraction of the first sample's mask mass that falls inside ``anchor``.

    Args:
        anchor: ``(r1, r2, c1, c2)`` half-open row/column bounds.
        mask: 3-D tensor/array indexed as ``mask[sample, row, col]``; only
            sample 0 is used, for both numerator and denominator.

    Returns:
        ``mask[0]`` summed inside the box divided by ``mask[0].sum() + 1``.
        The ``+ 1`` avoids division by zero for empty masks, so for
        non-negative masks the value is below 1.
    """
    r1, r2, c1, c2 = anchor
    sample_mask = mask[0]
    return sample_mask[r1:r2, c1:c2].sum() / (sample_mask.sum() + 1)

## DC-INVASE class
class dc_invase():
    """Joint training of a patch selector and a patch predictor (DC-INVASE).

    * ``selector`` scores (anchor, grid cell) positions. One position is
      sampled per image and turned into a ``patch_shape`` crop.
    * ``predictor`` classifies the crop and is trained with cross-entropy.
    * ``baseline`` classifies the full image. It stays in eval mode and has
      no optimizer here.

    The selector loss is the Bernoulli log-likelihood of the sampled position,
    weighted by (predictor CE - baseline CE), plus two box-position terms that
    are plain numbers (they do not depend on selector parameters).
    """

    def __init__(self):
        # Paths and hyper-parameters
        self.data_dir = '../Data/CBIS-DDSM_classification_orient_cropped/'
        self.train_csv = '../CSV/mass_weak_train.csv'
        self.num_epochs = 100
        self.input_shape = (288, 256)   # expected (rows, cols) of the images; they are not resized
        self.patch_shape = (256, 256)   # (rows, cols) of the crop fed to the predictor
        self.batch_size = 1             # only inputs[0] is cropped per batch, so keep this at 1
        self.img_mean = [0.253, 0.238, 0.234]
        self.img_std = [0.272, 0.268, 0.262]

        self.exp_name = './Weights/weak_sup_bc'

        # Data first, so the models can be placed on the device it reports.
        self.dataloaders, self.dataset_sizes, self.dataset, self.device = get_dataloader(
            self.data_dir, self.train_csv, self.input_shape, self.img_mean, self.img_std, self.batch_size)

        # Selector, frozen full-image baseline, and trainable patch predictor
        self.selector = build_selector().to(self.device)
        self.baseline = build_baseline_predictor().to(self.device)
        self.predictor = build_baseline_predictor().to(self.device)

        # One optimizer per trainable model (the baseline is not trained)
        self.optimizer_sel = optim.Adam(self.selector.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-6, amsgrad=False)
        self.optimizer_pred = optim.Adam(self.predictor.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-6, amsgrad=False)

    def train(self):
        """Train for ``num_epochs`` epochs (a train pass then a valid pass each).

        Whenever validation predictor accuracy improves, selector, predictor
        and baseline weights are saved as ``<exp_name>_{sel,pred,base}.pt``.
        Final weights are saved as ``<exp_name>_{sel,pred,base}_final.pt``.
        """
        since = time.time()
        best_acc = 0.0

        for epoch in range(self.num_epochs):
            print('Epoch {}/{}'.format(epoch, self.num_epochs - 1), flush=True)
            print('-' * 10, flush=True)

            for phase in ['train', 'valid']:
                epoch_pred_acc = self._run_epoch(phase)

                # Checkpoint whenever validation predictor accuracy improves.
                if phase == 'valid' and epoch_pred_acc > best_acc:
                    best_acc = epoch_pred_acc
                    torch.save(self.selector.state_dict(), self.exp_name + '_sel.pt')
                    torch.save(self.predictor.state_dict(), self.exp_name + '_pred.pt')
                    torch.save(self.baseline.state_dict(), self.exp_name + '_base.pt')

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print('Best valid acc: {:.4f}'.format(best_acc))

        torch.save(self.baseline.state_dict(), self.exp_name + '_base_final.pt')
        torch.save(self.selector.state_dict(), self.exp_name + '_sel_final.pt')
        torch.save(self.predictor.state_dict(), self.exp_name + '_pred_final.pt')

        print('Training completed finally !!!!!')

    def _run_epoch(self, phase):
        """Run one pass over the ``phase`` split, print its metrics and return predictor accuracy."""
        is_train = phase == 'train'
        # nn.Module.train(False) is the same as .eval()
        self.selector.train(is_train)
        self.predictor.train(is_train)
        self.baseline.eval()

        totals = {'sel_loss': 0.0, 'pred_loss': 0.0, 'pred_correct': 0, 'base_correct': 0, 'int': 0.0}
        y_true = []
        y_pred = []

        pbar = tqdm(total=self.dataset_sizes[phase])
        for sampled_batch in self.dataloaders[phase]:
            stats = self._step(sampled_batch, is_train)
            for key in totals:
                totals[key] += stats[key]
            # Plain Python numbers so sklearn can consume them (CUDA tensors cannot).
            y_true.append(stats['label'])
            y_pred.append(stats['prob'])
            pbar.update(stats['n'])
        pbar.close()

        size = self.dataset_sizes[phase]
        epoch_pred_acc = totals['pred_correct'] / size

        print('{} Sel_Loss: {:.4f} Pred_Loss: {:.4f} BAC: {:.4f} PAC: {:.4f} Int: {:.4f} Auc: {:.4f}'.format(
            phase, totals['sel_loss'] / size, totals['pred_loss'] / size,
            totals['base_correct'] / size, epoch_pred_acc, totals['int'] / size,
            self._safe_auc(y_true, y_pred)))

        return epoch_pred_acc

    def _step(self, sampled_batch, is_train):
        """Process one batch (optionally updating predictor and selector) and return per-batch statistics."""
        inputs = sampled_batch['image'].float().to(self.device)
        labels = sampled_batch['category'].long().to(self.device)
        mask = sampled_batch['mask'].to(self.device)

        self.optimizer_sel.zero_grad()
        self.optimizer_pred.zero_grad()

        # The baseline is frozen and its loss only enters the selector's
        # weighting term, so no autograd graph is needed for it.
        with torch.no_grad():
            base_out, _ = self.baseline(inputs)
            base_ce_loss = F.cross_entropy(base_out, labels)
        _, base_preds = torch.max(base_out, 1)

        with torch.set_grad_enabled(is_train):
            # Sample one (anchor, cell) from the selector and crop that patch.
            sel_prob = self.selector(inputs)
            probs_sample = get_sample(sel_prob)
            m_r, m_c = get_patch_center(probs_sample, self.input_shape)
            r1, r2, c1, c2 = get_anchor_box(m_r, m_c, self.input_shape, self.patch_shape)

            int_met = intersection_metric([r1, r2, c1, c2], mask)

            # Only the first image of the batch is cropped.
            patch = inputs[0, :, r1:r2, c1:c2].unsqueeze(dim=0)

            pred_out, _ = self.predictor(patch)
            pred_prob = F.softmax(pred_out, dim=1)
            _, pred_preds = torch.max(pred_out, 1)
            pred_ce_loss = F.cross_entropy(pred_out, labels)

            # Weight for the selector loss: how much worse the patch predictor
            # is than the full-image baseline. Treated as a constant.
            k_l = (pred_ce_loss - base_ce_loss).detach()

            sample = torch.as_tensor(probs_sample, dtype=torch.float32, device=self.device).view(-1)
            sel_dist = F.softmax(sel_prob.view(-1), dim=0)

            distribution_loss = torch.mean(sample * torch.log(sel_dist + 1e-8) + (1 - sample) * torch.log(1 - sel_dist + 1e-8))

            sel_loss = distribution_loss * k_l + c2 / self.input_shape[1] + abs((r1 + r2) / 2 - self.input_shape[0] // 2) / self.input_shape[0]

            if is_train:
                # The predictor and selector graphs share no tensors, so each
                # loss can be back-propagated independently.
                pred_ce_loss.backward()
                self.optimizer_pred.step()

                sel_loss.backward()
                self.optimizer_sel.step()

        n = inputs.size(0)
        return {
            'sel_loss': sel_loss.item() * n,
            'pred_loss': pred_ce_loss.item() * n,
            'pred_correct': int(torch.sum(pred_preds == labels)),
            'base_correct': int(torch.sum(base_preds == labels)),
            'int': float(int_met) * n,
            'label': int(labels[0]),
            'prob': float(pred_prob[0, 1]),
            'n': n,
        }

    @staticmethod
    def _safe_auc(y_true, y_pred):
        """ROC-AUC of ``y_pred`` against ``y_true``; NaN when it is undefined (a single class present)."""
        try:
            return sklearn.metrics.roc_auc_score(y_true, y_pred, average='weighted')
        except ValueError:
            return float('nan')

    def get_cam(self):
        """For every test image, write the image masked to the patch the selector samples.

        Loads ``<exp_name>_sel_final.pt`` and uses the same sampling pipeline
        as training (get_sample -> get_patch_center -> get_anchor_box). Output
        files go to ``../Experiments/Oxford_pets/`` under the test image's
        name with ``.j`` replaced by ``_bin_8x8_samp_share_1_final.j``.

        NOTE(review): the previous version called undefined helpers
        (``sampler``, ``self.prob_mask``). This rewrite renders the single
        sampled patch instead; confirm this is the intended visualisation.
        """
        self.selector.load_state_dict(torch.load(self.exp_name + '_sel_final.pt', map_location=self.device))
        self.selector.eval()

        mode = 'test'
        base_path = '../Experiments/Oxford_pets/'

        with torch.no_grad():
            pbar = tqdm(total=self.dataset_sizes[mode])
            for data in self.dataloaders[mode]:
                inputs = data['image'].float().to(self.device)

                sel_prob = self.selector(inputs)
                probs_sample = get_sample(sel_prob)
                m_r, m_c = get_patch_center(probs_sample, self.input_shape)
                r1, r2, c1, c2 = get_anchor_box(m_r, m_c, self.input_shape, self.patch_shape)

                # Binary (H, W, 1) mask of the selected patch
                heatmap = np.zeros(tuple(inputs.shape[2:]) + (1,), dtype=np.float32)
                heatmap[r1:r2, c1:c2] = 1

                image = inputs[0].cpu().numpy().transpose((1, 2, 0)).astype(np.float32)
                cam_f = heatmap * image
                cam_f = cam_f / np.max(cam_f)

                name = data['name'][0]
                pr = name.replace('.j', '_bin_8x8_samp_share_1_final.j')
                cv2.imwrite(base_path + pr, cam_f * 255)

                pbar.update(inputs.shape[0])

            pbar.close()

    def return_model(self):
        """Load the best-validation selector weights and return it with the valid loader."""
        self.selector.load_state_dict(torch.load(self.exp_name + '_sel.pt'))
        self.selector.eval()
        return self.selector, self.dataloaders['valid']

In [None]:
# Build the models, dataloaders and optimizers (reads the CSVs and loads checkpoint weights).
dc = dc_invase()

In [None]:
# Train for dc.num_epochs epochs; checkpoints are written under dc.exp_name.
dc.train()

In [None]:
# NOTE(review): `%d` looks like a truncated `%debug` magic (post-mortem debugger) — confirm.
%d

In [None]:
# NOTE(review): same `%d` cell as above — presumably `%debug`; confirm.
%d

In [None]:
# Show one image from the directory dc_invase.get_cam writes to (same file-name suffix).
# cv2.imread returns BGR channel order, so colours appear swapped under plt.imshow.
plt.imshow(cv2.imread('../Experiments/Oxford_pets/cat_Abyssinian_105_bin_8x8_samp_share_1_final.jpg'))

In [None]:
# Train again on the same object: models and optimizers are not re-created, so this continues from the current in-memory weights.
dc.train()

In [None]:
# Shell command: inspect GPU memory / utilisation.
!nvidia-smi