In [None]:
#Pytorch
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

#Torchvision
import torchvision
from torchvision import datasets, models, transforms, utils
  
#Pytorch
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

#Torchvision
import torchvision
from torchvision import datasets, models, transforms, utils
from torch.utils.data import Dataset, DataLoader

#Image Processing
import matplotlib.pyplot as plt
from skimage import io, transform, color
import PIL
from PIL import Image
import augmentations
from augmentations import *

#Others
import sklearn.metrics
from sklearn.metrics import *
import numpy as np
import pandas as pd
import cv2
import time
import os
import copy
from model_summary import *
import pretrainedmodels
import tqdm
from tqdm import tqdm_notebook as tqdm
import warnings
warnings.filterwarnings("ignore")



## Dataloader

class dataset(Dataset):
    """Image / mask dataset driven by a CSV file.

    The CSV is loaded with pandas. Each row supplies a ``name`` column (the
    file name, looked up under both the image and the mask directory) and a
    ``category`` column (returned unchanged as the label).

    Args:
        csv_file: path of the CSV listing the samples.
        root_dir: directory holding the images; the mask directory is the
            same path with every 'images' substring replaced by 'masks'.
        transform: optional callable invoked as ``transform(image, mask)``
            that returns the transformed ``(image, mask)`` pair.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.transform = transform
        self.root_dir = root_dir
        self.mask_dir = self.root_dir.replace('images','masks')
        self.data_frame = pd.read_csv(csv_file)

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Return a dict with keys 'image', 'category', 'mask' and 'name'."""
        row = self.data_frame.iloc[idx]
        file_name = row['name']

        image = Image.open(os.path.join(self.root_dir, file_name))

        # Replicate the mask into three identical channels so that it can go
        # through the same joint transform as the image.
        raw_mask = io.imread(os.path.join(self.mask_dir, file_name))
        stacked = np.array([raw_mask, raw_mask, raw_mask]).transpose((1, 2, 0))
        mask = Image.fromarray(stacked)

        label = row['category']

        if self.transform:
            image, mask = self.transform(image, mask)

        # Keep one channel and binarise it in place; a value of exactly 0.5
        # is left untouched by both assignments.
        # NOTE(review): without a transform `mask` is still a PIL image at
        # this point and this indexing raises - confirm a transform ending
        # in a tensor conversion is always supplied.
        mask_final = mask[0, :, :]
        mask_final[mask_final < 0.5] = 0
        mask_final[mask_final > 0.5] = 1

        return {
            'image': image,
            'category': label,
            'mask': mask_final,
            'name': file_name,
        }
    

def get_dataloader(data_dir, train_csv_path, image_size, img_mean, img_std, batch_size=1):
    """Build the train / valid / test datasets and their DataLoaders.

    The CSV for each split is derived from ``train_csv_path`` by replacing
    every occurrence of the substring 'train' with the split name, so the
    three CSV files must follow that naming scheme.

    Args:
        data_dir: image root directory handed to ``dataset``.
        train_csv_path: path of the training CSV (see above).
        image_size: target size passed to ``Resize``.
        img_mean: per-channel mean passed to ``Normalize``.
        img_std: per-channel std passed to ``Normalize``.
        batch_size: batch size of the 'train' and 'valid' loaders; the
            'test' loader always uses a batch size of 1.

    Returns:
        Tuple ``(dataloaders, dataset_sizes, image_datasets, device)``; the
        first three are dicts keyed by 'train', 'valid' and 'test'.
    """

    def _eval_transform():
        # Deterministic pipeline shared by the 'valid' and 'test' splits.
        return Compose([
            Resize(image_size),
            ToTensor(),
            Normalize(img_mean,img_std)
        ])

    data_transforms = {
        'train': Compose([
            Resize(image_size),
            RandomHorizontallyFlip(0.5),
            RandomVerticallyFlip(0.5),
            RandomTranslate((0.2,0.2)),
            RandomRotate(15),
            ToTensor(),
            Normalize(img_mean,img_std)
        ]),
        'valid': _eval_transform(),
        'test': _eval_transform(),
    }

    # split -> (batch size, shuffle). The training split previously ignored
    # `batch_size` and was never shuffled, which feeds the optimiser the
    # samples in CSV order every epoch; it now honours the argument and
    # reshuffles each epoch.
    loader_settings = {
        'train': (batch_size, True),
        'valid': (batch_size, False),
        'test': (1, False),
    }

    image_datasets = {}
    dataloaders = {}
    dataset_sizes = {}

    for split in ['train', 'valid', 'test']:
        bs, sh = loader_settings[split]
        image_datasets[split] = dataset(train_csv_path.replace('train',split),root_dir=data_dir,transform=data_transforms[split])
        dataloaders[split] = torch.utils.data.DataLoader(image_datasets[split], batch_size=bs,shuffle=sh, num_workers=8)
        dataset_sizes[split] = len(image_datasets[split])

    # Fall back to the CPU instead of returning an unusable CUDA device.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    return dataloaders,dataset_sizes,image_datasets,device

## Selector network (VGG-UNet)

class SaveFeatures():
    """Forward hook that keeps the most recent output of a module.

    After every forward pass of the hooked module ``m`` its output is
    available as ``self.features`` (``None`` until the first pass).
    """

    features = None

    def __init__(self, m):
        # The returned handle is kept so that the hook can be detached later.
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Hook callback: remember ``output``."""
        self.features = output

    def remove(self):
        """Detach the hook from the module."""
        self.hook.remove()

class UnetBlock(nn.Module):
    """Decoder block: upsample, merge with a skip tensor, normalise, drop out.

    Args:
        up_in: channels of the tensor coming from the deeper layer.
        x_in: channels of the skip-connection tensor.
        n_out: channels of the block output; each branch contributes
            ``n_out // 2`` channels, so ``n_out`` is expected to be even.
    """

    def __init__(self, up_in, x_in, n_out):
        super().__init__()
        up_out = x_out = n_out//2
        self.x_conv  = nn.Conv2d(x_in,  x_out,  1)
        self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2)
        self.bn = nn.BatchNorm2d(n_out)

    def forward(self, up_p, x_p):
        """Return the merged feature map.

        ``up_p`` is upsampled by a factor of two with the transposed
        convolution and ``x_p`` is projected with a 1x1 convolution; both
        are concatenated along the channel axis, so their spatial sizes
        must match after the upsampling.
        """
        up_p = self.tr_conv(up_p)
        x_p = self.x_conv(x_p)
        cat_p = torch.cat([up_p,x_p], dim=1)
        # Dropout is applied after the non-linearity and batch norm. The
        # functional form defaults to training=True, so the module's own
        # mode has to be forwarded explicitly; otherwise channels are still
        # dropped at random after model.eval().
        return F.dropout2d(self.bn(F.relu(cat_p)), p=0.5, training=self.training)

class Unet34(nn.Module):
    """U-Net style decoder stacked on top of an encoder ``rn``.

    Forward hooks on ``rn[0][12]``, ``rn[0][22]``, ``rn[0][32]`` and
    ``rn[0][42]`` capture the intermediate activations that serve as skip
    connections for the four ``UnetBlock`` stages.

    Args:
        rn: encoder whose first child is indexable up to position 42.
    """

    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        tap_positions = [12, 22, 32, 42]
        self.sfs = [SaveFeatures(rn[0][pos]) for pos in tap_positions]
        self.up1 = UnetBlock(512,512,256)
        self.up2 = UnetBlock(256,512,256)
        self.up3 = UnetBlock(256,256,256)
        self.up4 = UnetBlock(256,128,256)
        self.up5 = nn.ConvTranspose2d(256,16, 2, stride=2)

    def forward(self,x):
        """Encode ``x``, then decode it using the hooked activations."""
        decoded = F.relu(self.rn(x))
        # The deepest hooked activation is consumed first.
        stages = (self.up1, self.up2, self.up3, self.up4)
        for stage, tap in zip(stages, reversed(self.sfs)):
            decoded = stage(decoded, tap.features)
        return self.up5(decoded)

    def close(self):
        """Detach every forward hook registered on the encoder."""
        for tap in self.sfs:
            tap.remove()
            
class mdl(nn.Module):
    """Classifier head: backbone -> global average pooling -> linear layer.

    Args:
        base_model: module whose output has 16 channels (the linear layer
            is created with 16 input features and 2 outputs).
    """

    def __init__(self,base_model):
        super().__init__()
        self.base = base_model
        self.gap = nn.AdaptiveAvgPool2d((1,1))
        self.fc1 = nn.Linear(16,2)

    def forward(self, x):
        """Return the two logits for every sample of the batch ``x``."""
        pooled = self.gap(self.base(x))
        # Collapse the pooled map to one feature vector per sample.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc1(flat)

# def build_model():
#     v = models.vgg16_bn(pretrained=False)
#     v1 = nn.Sequential(*list(v.children())[:-1])
#     m = Unet34(v1)
#     model_final = mdl(m)
#     return model_final

def build_model():
    """Build the VGG16-BN based classifier used for class activation maps.

    Returns:
        A module whose forward pass returns ``(logits, feature_map)``:
        the two-class logits and the backbone output they were pooled from.
    """

    class mdl(nn.Module):
        """Backbone -> global average pooling -> linear layer (512 -> 2)."""

        def __init__(self,base_model):
            super().__init__()
            self.base = base_model
            self.gap = nn.AdaptiveAvgPool2d((1,1))
            self.fc1 = nn.Linear(512,2)

        def forward(self, x):
            x_base = self.base(x)
            x = self.gap(x_base)
            x = x.view(x.size(0), -1)
            x = self.fc1(x)
            # The un-pooled feature map is returned as well; the CAM is
            # computed from it.
            return x,x_base

    v = models.vgg16_bn(pretrained=True)
    # Use the convolutional stack without its last layer. It is addressed
    # by name: picking "the last child before the classifier" only works on
    # torchvision builds without an `avgpool` child, and on newer builds it
    # would try to slice a pooling layer instead.
    model = mdl(v.features[:-1])

    return model

# a = build_selector()
# summary(a.cuda(),(3,224,224))

    
def get_IoU(pred, targs, device):
    """Return the intersection over union of ``pred`` and ``targs``.

    ``targs`` is converted with ``torch.Tensor`` and moved to ``device``.
    The intersection is the sum of the element-wise product; the union is
    the sum of both inputs minus that intersection. No guard is applied
    when the union is zero.
    """
    targs = torch.Tensor(targs).to(device)
    intersection = (pred * targs).sum()
    union = (pred + targs).sum() - intersection
    return intersection / union


class grad_cam():
    """Train a two-class classifier and derive class activation maps from it.

    The model comes from ``build_model()`` and returns ``(logits,
    feature_map)``; the data loaders come from ``get_dataloader``. The best
    weights (by validation F1) are written to ``<exp_name>.pt`` and the
    weights after the last epoch to ``<exp_name>_final.pt``.
    """

    def __init__(self):
        """Set the configuration and create model, loaders, optimiser and loss."""

        #Initialization
        self.data_dir = '../Data/oxford_pets/sparse_images/'
        self.train_csv = '../CSV/oxford_pet_train.csv'
        self.num_epochs = 25
        self.input_shape = (224,224)
        self.batch_size = 1
        self.img_mean = [0.485, 0.456, 0.406]
        self.img_std = [0.229, 0.224, 0.225]

        self.exp_name = 'Weights/high_res_grad_cam_vgg_16_oxford'

        #Define the model and put it on the GPU
        self.model = build_model()
        self.model = self.model.cuda()

        #Get the dataloaders
        self.dataloaders,self.dataset_sizes,self.dataset,self.device = get_dataloader(self.data_dir,self.train_csv,\
                                                        self.input_shape,self.img_mean,self.img_std,self.batch_size)

        #Get the optimizer
        self.optimizer = optim.Adam(self.model.parameters(),lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

        self.loss_fn = nn.CrossEntropyLoss()

    def train(self):
        """Run the training / validation loop for ``self.num_epochs`` epochs.

        After every validation phase the weights are saved to
        ``<exp_name>.pt`` if the validation F1 improved; the weights after
        the last epoch are saved to ``<exp_name>_final.pt``.
        """

        since = time.time()
        best_epoch_f1 = 0.0

        for epoch in range(self.num_epochs):
            print('Epoch {}/{}'.format(epoch, self.num_epochs - 1),flush=True)
            print('-' * 10,flush=True)
            # Each epoch has a training and validation phase
            for phase in ['train', 'valid']:

                labels_list = []
                preds_list = []

                if phase == 'train':
                    self.model.train()
                else:
                    self.model.eval()

                running_loss = 0.0
                running_acc = 0

                #tqdm bar
                pbar = tqdm(total=self.dataset_sizes[phase])

                # Iterate over data.
                for sampled_batch in self.dataloaders[phase]:

                    #Input needs to be float and labels long
                    inputs = sampled_batch['image'].float().to(self.device)
                    labels = sampled_batch['category'].long().to(self.device)

                    # zero the parameter gradients
                    self.optimizer.zero_grad()

                    # forward; track history only in the training phase
                    with torch.set_grad_enabled(phase == 'train'):

                        outputs,_ = self.model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = self.loss_fn(outputs, labels)

                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            self.optimizer.step()

                    # statistics
                    running_loss += loss.item() * inputs.size(0)
                    running_acc += torch.sum(preds == labels).item()
                    # Collected on the CPU: the metric below works on numpy
                    # arrays and cannot read tensors living on the GPU.
                    labels_list.append(labels.detach().view(-1).cpu())
                    preds_list.append(preds.detach().view(-1).cpu())

                    pbar.update(inputs.shape[0])
                pbar.close()

                epoch_loss = running_loss / self.dataset_sizes[phase]
                epoch_acc = running_acc / self.dataset_sizes[phase]
                epoch_f1 = f1_score(torch.cat(labels_list).numpy(),
                                    torch.cat(preds_list).numpy())

                print('{} Sel_Loss: {:.4f} Acc: {:.4f} F1: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc,  epoch_f1))

                # keep the weights with the best validation F1
                if phase == 'valid' and epoch_f1 > best_epoch_f1:
                    best_epoch_f1 = epoch_f1
                    torch.save(self.model.state_dict(),self.exp_name+'.pt')

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print('Best F1: {:4f}'.format(best_epoch_f1))

        torch.save(self.model.state_dict(),self.exp_name+'_final.pt')

        print('Training completed finally !!!!!')

    def test_model_auc(self):
        """Print ROC AUC and accuracy of the best checkpoint on the test split."""

        self.model.load_state_dict(torch.load(self.exp_name+'.pt'))
        self.model.eval()

        correct = 0
        total = 0
        mode = 'test'

        predictions = []
        ground_truth = []

        with torch.no_grad():

            pbar = tqdm(total=self.dataset_sizes[mode])
            for data in self.dataloaders[mode]:

                inputs = data['image'].float().to(self.device)
                labels = data['category'].long().to(self.device)

                # The model returns (logits, feature_map); only the logits
                # are needed here.
                output, _ = self.model(inputs)
                _,out = torch.max(output,1)

                # Score with class probabilities: the raw logit of class 1
                # alone does not rank samples the way its probability does.
                predictions.append(F.softmax(output, dim=1).cpu().numpy())
                ground_truth.append(labels.cpu().numpy())

                total += labels.size(0)
                correct += torch.sum(out == labels).item()
                pbar.update(inputs.shape[0])
            pbar.close()

        pred = np.concatenate(predictions, axis=0)
        gt = np.concatenate(ground_truth, axis=0)

        auc = roc_auc_score(gt,pred[:,1],average='weighted')

        print("AUC:", auc)
        print("ACC:", correct / total)

    def get_cam(self):
        """Write a CAM-masked copy of every test image below ``base_path``.

        For each sample the feature map is weighted with the linear-layer
        weights of the predicted class, rectified, upsampled to
        ``self.input_shape`` and rescaled to [0, 1]. Pixels above
        ``mean + std`` of that map form a binary mask that is multiplied
        with the (normalised) network input before saving.

        The code handles one sample per batch, which matches the batch size
        of 1 used by the test loader.
        """

        self.model.load_state_dict(torch.load(self.exp_name+'.pt'))
        self.model.eval()

        mode = 'test'
        base_path = '../Experiments/Oxford_pets/'
        # cv2.imwrite fails silently when the target directory is missing.
        os.makedirs(base_path, exist_ok=True)

        # Weights of the final linear layer, one row per class.
        weight_softmax = torch.squeeze(self.model.fc1.weight.data)

        with torch.no_grad():

            pbar = tqdm(total=self.dataset_sizes[mode])
            for data in self.dataloaders[mode]:

                inputs = data['image'].float().to(self.device)

                # Channels-last copy of the network input for the overlay.
                # NOTE(review): cv2.imwrite expects BGR channel order; if
                # the loader yields RGB the saved colours are swapped -
                # confirm against the augmentation pipeline.
                input_write = inputs.cpu().numpy().squeeze().transpose((1,2,0))

                output, feat = self.model(inputs)
                _,out = torch.max(output,1)

                #Get the CAM, which is used as the probability map
                fmap = feat[0]
                n_ch, f_h, f_w = fmap.shape
                cam = torch.matmul(weight_softmax[out[0]], fmap.reshape(n_ch, f_h * f_w))
                cam = F.relu(cam.reshape(f_h, f_w))
                cam_img = F.interpolate(cam.unsqueeze(dim=0).unsqueeze(dim=0),(self.input_shape[0],self.input_shape[1]),mode='bilinear')
                cam_img = cam_img - cam_img.min()
                peak = cam_img.max()
                if peak > 0:
                    # A constant map would otherwise turn into NaNs.
                    cam_img = cam_img / peak

                name = data['name'][0]

                heatmap = cam_img.cpu().numpy().squeeze()
                # The threshold is fixed before binarising; recomputing it
                # after the first assignment would use statistics of the
                # partially binarised map.
                threshold = heatmap.mean() + heatmap.std()
                heatmap = (heatmap > threshold).astype(np.float32)
                # (H, W) -> (H, W, 1) so that it broadcasts over channels.
                heatmap = np.expand_dims(heatmap, axis=2)

                cam_f = heatmap * input_write
                cam_max = np.max(cam_f)
                if cam_max > 0:
                    cam_f = cam_f / cam_max

                pr = name.replace('.j','_cam.j')
                cv2.imwrite(base_path+pr,cam_f*255)

                pbar.update(inputs.shape[0])

            pbar.close()

    def return_model(self):
        """Load ``<exp_name>_sel.pt`` and return ``(model, test_loader)``.

        NOTE(review): no method of this class writes a ``_sel.pt`` file
        (``train`` saves ``.pt`` and ``_final.pt``) - confirm the file is
        produced elsewhere or switch to one of those checkpoints.
        """
        self.model.load_state_dict(torch.load(self.exp_name+'_sel.pt'))
        self.model.eval()
        mode = 'test'
        return self.model,self.dataloaders[mode]

In [17]:
#Pytorch
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

#Torchvision
import torchvision
from torchvision import datasets, models, transforms, utils
  
#Pytorch
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

#Torchvision
import torchvision
from torchvision import datasets, models, transforms, utils
from torch.utils.data import Dataset, DataLoader

#Image Processing
import matplotlib.pyplot as plt
from skimage import io, transform, color
import PIL
from PIL import Image

#Others
import sklearn.metrics
from sklearn.metrics import *
import numpy as np
import pandas as pd
import cv2
import time
import os
import copy
from model_summary import *
import pretrainedmodels
import tqdm
from tqdm import tqdm_notebook as tqdm
import warnings
warnings.filterwarnings("ignore")



class dataset(Dataset):
    """Image / mask dataset driven by a CSV file.

    The CSV is loaded with pandas. Each row supplies a ``name`` column (the
    file name, looked up under both the image and the mask directory) and a
    ``category`` column (returned unchanged as the label).

    Args:
        csv_file: path of the CSV listing the samples.
        root_dir: directory holding the images; the mask directory is the
            same path with every 'images' substring replaced by 'masks'.
        transform: optional single-argument callable applied to the image
            and to the mask.
    """

    def __init__(self, csv_file, root_dir, transform=None):

        self.data_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.mask_dir = self.root_dir.replace('images','masks')

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Return a dict with keys 'image', 'category', 'mask' and 'name'."""
        import random

        row = self.data_frame.iloc[idx]
        file_name = row['name']

        image = Image.open(os.path.join(self.root_dir, file_name))

        # Replicate the mask into three identical channels so that it can
        # run through the same pipeline as the image.
        mask = io.imread(os.path.join(self.mask_dir, file_name))
        mask = np.array([mask,mask,mask]).transpose((1,2,0))
        mask = Image.fromarray(mask)

        label = row['category']

        if self.transform:
            # The pipeline may contain random steps. Replaying it from the
            # same generator state makes the mask receive exactly the
            # parameters drawn for the image; otherwise the two would be
            # flipped / warped independently and no longer line up. Both
            # generators are restored because torchvision has drawn from
            # Python's `random` or from torch depending on its version.
            py_state = random.getstate()
            torch_state = torch.get_rng_state()
            image = self.transform(image)
            random.setstate(py_state)
            torch.set_rng_state(torch_state)
            # NOTE(review): any normalisation contained in the pipeline is
            # applied to the mask as well.
            mask = self.transform(mask)

        return {'image':image, 'category':label, 'mask':mask, 'name':file_name}
    

def get_dataloader(data_dir, train_csv_path, image_size, img_mean, img_std, batch_size=1):
    """Create datasets and DataLoaders for the train / valid / test splits.

    The CSV of each split is obtained by replacing every 'train' substring
    of ``train_csv_path`` with the split name.

    Args:
        data_dir: image root directory handed to ``dataset``.
        train_csv_path: path of the training CSV.
        image_size: target size for ``transforms.Resize``.
        img_mean: per-channel mean for ``transforms.Normalize``.
        img_std: per-channel std for ``transforms.Normalize``.
        batch_size: batch size of the 'train' and 'valid' loaders; the
            'test' loader uses 1.

    Returns:
        Tuple ``(dataloaders, dataset_sizes, image_datasets, device)``; the
        first three are dicts keyed by split name and ``device`` is
        ``cuda:0``.
    """

    def _plain_pipeline():
        # Resize + tensor conversion + normalisation, no augmentation.
        return transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(img_mean,img_std),
        ])

    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(image_size),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomAffine(translate=(0,0.2),degrees=15,shear=15),
            transforms.ToTensor(),
            transforms.Normalize(img_mean,img_std),
        ]),
        'valid': _plain_pipeline(),
        'test': _plain_pipeline(),
    }

    # split -> (batch size, shuffle); only the training data is shuffled.
    loader_options = {
        'train': (batch_size, True),
        'valid': (batch_size, False),
        'test': (1, False),
    }

    image_datasets, dataloaders, dataset_sizes = {}, {}, {}

    for split in ('train', 'valid', 'test'):
        split_bs, split_shuffle = loader_options[split]
        split_data = dataset(train_csv_path.replace('train',split),root_dir=data_dir,transform=data_transforms[split])
        image_datasets[split] = split_data
        dataloaders[split] = torch.utils.data.DataLoader(split_data, batch_size=split_bs,shuffle=split_shuffle, num_workers=8)
        dataset_sizes[split] = len(split_data)

    return dataloaders,dataset_sizes,image_datasets,torch.device("cuda:0")

# def build_model():

#     class mdl(nn.Module):
#         def __init__(self,base_model):
#             super().__init__()
#             self.base = base_model 
#             self.gap = nn.AdaptiveAvgPool2d((1,1))
#             self.fc1 = nn.Linear(512,2)

#         def forward(self, x):
#             x_base = self.base(x)
#             x = self.gap(x_base)
#             x = x.view(x.size(0), -1)
#             x = self.fc1(x)
#             return x,x_base 

#     v = models.vgg16_bn(pretrained=True)
#     v1 = nn.Sequential(*list(v.children())[:-1])
#     model = mdl(v1[-1][:-1])
    
#     return model
def build_base():
    """Build a small U-Net and return it.

    With the defaults used here (``feature_scale=4``) the channel widths
    are 8, 16, 32, 64 and 128, and the forward pass returns the last
    decoder feature map (8 channels) rather than the output of the
    ``final`` 1x1 convolution.
    """

    class unetConv2(nn.Module):
        """Two 3x3 convolutions, each followed by (optional) BN and ReLU."""

        def __init__(self, in_size, out_size, is_batchnorm):
            super(unetConv2, self).__init__()

            if is_batchnorm:
                self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 1),
                                           nn.BatchNorm2d(out_size),
                                           nn.ReLU(),)
                self.conv2 = nn.Sequential(nn.Conv2d(out_size, out_size, 3, 1, 1),
                                           nn.BatchNorm2d(out_size),
                                           nn.ReLU(),)
            else:
                self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 1),
                                           nn.ReLU(),)
                self.conv2 = nn.Sequential(nn.Conv2d(out_size, out_size, 3, 1, 1),
                                           nn.ReLU(),)

        def forward(self, inputs):
            outputs = self.conv1(inputs)
            outputs = self.conv2(outputs)
            return outputs

    class unetUp(nn.Module):
        """Upsample the deeper tensor, merge it with the skip tensor, convolve."""

        def __init__(self, in_size, out_size, is_deconv):
            super(unetUp, self).__init__()
            self.conv = unetConv2(in_size, out_size, False)
            if is_deconv:
                self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
            else:
                # NOTE(review): this branch keeps `in_size` channels, so
                # the concatenation below carries more channels than
                # `self.conv` accepts; only is_deconv=True is used here.
                self.up = nn.UpsamplingBilinear2d(scale_factor=2)

        def forward(self, inputs1, inputs2):
            """``inputs1`` is the skip tensor, ``inputs2`` the deeper tensor."""
            outputs2 = self.up(inputs2)
            # Bring the skip tensor to the size of the upsampled tensor.
            # Height and width are handled separately and an odd difference
            # is split unevenly, so the sizes always match afterwards; a
            # single symmetric offset taken from the height alone leaves a
            # mismatch for odd or non-square inputs. Negative values crop.
            diff_h = outputs2.size()[2] - inputs1.size()[2]
            diff_w = outputs2.size()[3] - inputs1.size()[3]
            padding = [diff_w // 2, diff_w - diff_w // 2,
                       diff_h // 2, diff_h - diff_h // 2]
            outputs1 = F.pad(inputs1, padding)
            return self.conv(torch.cat([outputs1, outputs2], 1))

    class unet(nn.Module):
        """Four-level U-Net encoder / decoder."""

        def __init__(self, feature_scale=4, n_classes=1, is_deconv=True, in_channels=3, is_batchnorm=True):
            super(unet, self).__init__()
            self.is_deconv = is_deconv
            self.in_channels = in_channels
            self.is_batchnorm = is_batchnorm
            self.feature_scale = feature_scale

            filters = [32, 64, 128, 256, 512]
            filters = [int(x / self.feature_scale) for x in filters]

            #downsampling
            self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
            self.maxpool1 = nn.MaxPool2d(kernel_size=2)

            self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
            self.maxpool2 = nn.MaxPool2d(kernel_size=2)

            self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
            self.maxpool3 = nn.MaxPool2d(kernel_size=2)

            self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
            self.maxpool4 = nn.MaxPool2d(kernel_size=2)

            self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)

            # upsampling
            self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv)
            self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv)
            self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv)
            self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv)

            # final conv (without any concat); not used by forward() but
            # kept so that the parameter set of the module stays the same.
            self.final = nn.Conv2d(filters[0], n_classes, 1)

        def forward(self, inputs):
            conv1 = self.conv1(inputs)
            maxpool1 = self.maxpool1(conv1)

            conv2 = self.conv2(maxpool1)
            maxpool2 = self.maxpool2(conv2)

            conv3 = self.conv3(maxpool2)
            maxpool3 = self.maxpool3(conv3)

            conv4 = self.conv4(maxpool3)
            maxpool4 = self.maxpool4(conv4)

            center = self.center(maxpool4)
            up4 = self.up_concat4(conv4, center)
            up3 = self.up_concat3(conv3, up4)
            up2 = self.up_concat2(conv2, up3)
            up1 = self.up_concat1(conv1, up2)

            return up1

    model = unet()
    return model

            
class mdl(nn.Module):
    """Classifier head: backbone -> global average pooling -> linear layer.

    Args:
        base_model: module whose output has 8 channels (the linear layer
            is created with 8 input features and 2 outputs).
    """

    def __init__(self,base_model):
        super().__init__()
        self.base = base_model
        self.gap = nn.AdaptiveAvgPool2d((1,1))
        self.fc1 = nn.Linear(8,2)

    def forward(self, x):
        """Return ``(logits, feature_map)`` for the batch ``x``."""
        feature_map = self.base(x)
        pooled = self.gap(feature_map)
        # One feature vector per sample for the linear layer.
        logits = self.fc1(pooled.view(pooled.size(0), -1))
        return logits, feature_map

def build_model():
    """Wrap the network from ``build_base()`` in the ``mdl`` classifier head."""
    return mdl(build_base())
    
def get_IoU(pred, targs, device):
    """Intersection over union of ``pred`` and ``targs``.

    ``targs`` is converted with ``torch.Tensor`` and moved to ``device``
    first. Intersection = sum of the element-wise product, union = sum of
    both inputs minus the intersection; a zero union is not guarded.
    """
    target = torch.Tensor(targs).to(device)
    overlap = (pred * target).sum()
    return overlap / ((pred + target).sum() - overlap)

def denorm_img(img_ten,img_mean,img_std):
    """Undo the normalisation of an image batch and binarise it.

    Every image is de-normalised per channel (``x * std + mean``), averaged
    over its channels and thresholded at 20 % of its own maximum.

    Args:
        img_ten: 4-D tensor (batch, channels, height, width). It is not
            modified.
        img_mean: per-channel means, one entry per channel of ``img_ten``.
        img_std: per-channel standard deviations, same length.

    Returns:
        numpy array of shape (batch, height, width) containing only 0 and 1.
    """
    bz, nc, h, w = img_ten.shape  # also rejects inputs that are not 4-D

    # Arithmetic below creates new arrays, so the caller's tensor - which
    # shares memory with its numpy view - is left untouched.
    batch = img_ten.detach().cpu().numpy()
    mean = np.asarray(img_mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(img_std, dtype=np.float32).reshape(-1, 1, 1)

    output = []
    for img in batch:
        gray = (img * std + mean).mean(axis=0)
        # The threshold is fixed before binarising; taking the maximum
        # again after part of the image was already set to 1 would leave
        # intermediate values in the result.
        threshold = 0.2 * gray.max()
        output.append((gray >= threshold).astype(gray.dtype))

    return np.array(output)
    
class grad_cam():
    """Train a two-class classifier and export class-activation-map (CAM) masks.

    The model comes from ``build_model()`` and the data from ``get_dataloader()``.
    The checkpoint with the best validation F1 is written to ``exp_name + '.pt'``
    and the weights after the last epoch to ``exp_name + '_final.pt'``.
    """

    def __init__(self):

        #Initialization
        self.data_dir = '../Data/oxford_pets/sparse_images/'
        self.train_csv = '../CSV/oxford_pet_train.csv'
        self.num_epochs = 25
        self.input_shape = (224,224)
        self.batch_size = 4
        self.img_mean = [0,0,0]
        self.img_std = [1, 1, 1]

        self.exp_name = 'Weights/high_res_grad_cam_vgg_16_oxford'

        #Define the model
        self.model = build_model()

        #Get the dataloaders (this also tells us which device to use)
        self.dataloaders,self.dataset_sizes,self.dataset,self.device = get_dataloader(
            self.data_dir,self.train_csv,
            self.input_shape,self.img_mean,self.img_std,self.batch_size)

        #Put the model on the same device the batches are moved to, instead of
        #calling .cuda() unconditionally
        self.model = self.model.to(self.device)

        #Get the optimizer
        self.optimizer = optim.Adam(self.model.parameters(),lr=0.005, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

        self.loss_fn = nn.CrossEntropyLoss()

    def train(self):
        """Run the train/valid loop for ``num_epochs`` epochs and save checkpoints."""

        since = time.time()
        best_epoch_f1 = 0.0

        for epoch in range(self.num_epochs):
            print('Epoch {}/{}'.format(epoch, self.num_epochs - 1),flush=True)
            print('-' * 10,flush=True)

            # Each epoch has a training and validation phase
            for phase in ['train', 'valid']:
                is_train = phase == 'train'

                #Training mode for 'train', evaluation mode otherwise
                self.model.train(is_train)

                labels_list = []
                preds_list = []
                running_loss = 0.0
                running_corrects = 0

                #tqdm bar
                pbar = tqdm(total=self.dataset_sizes[phase])

                # Iterate over data.
                for sampled_batch in self.dataloaders[phase]:

                    #Input needs to be float and labels long
                    inputs = sampled_batch['image'].float().to(self.device)
                    labels = sampled_batch['category'].long().to(self.device)

                    # zero the parameter gradients
                    self.optimizer.zero_grad()

                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(is_train):

                        outputs,_ = self.model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = self.loss_fn(outputs, labels)

                        # backward + optimize only if in training phase
                        if is_train:
                            loss.backward()
                            self.optimizer.step()

                    # statistics
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels).item()
                    # Keep CPU copies: sklearn needs NumPy arrays, which cannot
                    # be created directly from tensors living on a GPU.
                    labels_list.append(labels.detach().view(-1).cpu())
                    preds_list.append(preds.detach().view(-1).cpu())

                    pbar.update(inputs.shape[0])
                pbar.close()

                y_true = torch.cat(labels_list).numpy()
                y_pred = torch.cat(preds_list).numpy()

                epoch_loss = running_loss / self.dataset_sizes[phase]
                epoch_acc = running_corrects / self.dataset_sizes[phase]
                epoch_f1 = f1_score(y_true, y_pred)

                print('{} Sel_Loss: {:.4f} Acc: {:.4f} F1: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc,  epoch_f1))

                # keep the weights with the best validation F1
                if phase == 'valid' and epoch_f1 > best_epoch_f1:
                    best_epoch_f1 = epoch_f1
                    torch.save(self.model.state_dict(),self.exp_name+'.pt')

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print('Best F1: {:.4f}'.format(best_epoch_f1))

        torch.save(self.model.state_dict(),self.exp_name+'_final.pt')

        print('Training completed finally !!!!!')

    def test_model_auc(self):
        """Load the best checkpoint and print ROC-AUC and accuracy on the test split."""

        self.model.load_state_dict(torch.load(self.exp_name+'.pt', map_location=self.device))
        self.model.eval()

        correct = 0
        total = 0
        mode = 'test'

        scores = []
        ground_truth = []

        with torch.no_grad():

            pbar = tqdm(total=self.dataset_sizes[mode])
            for data in self.dataloaders[mode]:

                #Same casts as in train()
                inputs = data['image'].float().to(self.device)
                labels = data['category'].long().to(self.device)

                output,_ = self.model(inputs)
                _,out = torch.max(output,1)

                # Score of the positive class. The class-1 probability is used
                # rather than the raw class-1 logit, because the logit alone
                # ignores the class-0 logit and so does not order samples the
                # way the classifier's decision does.
                scores.append(F.softmax(output, dim=1)[:, 1].cpu().numpy())
                ground_truth.append(labels.cpu().numpy())

                total += labels.size(0)
                correct += torch.sum(out == labels).item()
                pbar.update(inputs.shape[0])
            pbar.close()

        pred = np.concatenate(scores, axis=0)
        gt = np.concatenate(ground_truth, axis=0)

        auc = roc_auc_score(gt,pred,average='weighted')

        print("AUC:", auc)
        print("ACC:", correct / total)

    def get_cam(self, base_path='../Experiments/Oxford_pets/'):
        """Write a CAM-masked copy of every test image to ``base_path``.

        For each image the feature map returned by the model is weighted with
        the final linear layer's row for the predicted class, passed through
        ReLU, resized to ``input_shape`` and min-max scaled. Pixels above
        ``mean + std`` of that map form a binary mask which is multiplied with
        the input image. The result is saved as ``<name>_cam.<ext>``.

        Args:
            base_path: output directory prefix (created if missing).
        """

        self.model.load_state_dict(torch.load(self.exp_name+'.pt', map_location=self.device))
        self.model.eval()

        mode = 'test'

        # cv2.imwrite does not create directories and fails silently without them.
        os.makedirs(base_path, exist_ok=True)

        # Weights of the final linear layer: one row per class.
        weight_softmax = self.model.fc1.weight.detach()

        with torch.no_grad():

            pbar = tqdm(total=self.dataset_sizes[mode])
            for data in self.dataloaders[mode]:

                inputs = data['image'].float().to(self.device)

                output,feat = self.model(inputs)
                _,out = torch.max(output,1)

                images = inputs.cpu().numpy()

                # Handle every sample of the batch, not only the first one.
                for image, feat_i, cls, name in zip(images, feat, out, data['name']):

                    n_ch, f_h, f_w = feat_i.shape

                    #Get the CAM which will be the prob map
                    cam = torch.matmul(weight_softmax[cls], feat_i.reshape(n_ch, f_h*f_w))
                    cam = F.relu(cam.reshape(f_h, f_w))
                    cam_img = F.interpolate(cam.unsqueeze(dim=0).unsqueeze(dim=0),
                                            (self.input_shape[0],self.input_shape[1]),
                                            mode='bilinear', align_corners=False)
                    cam_img = cam_img - cam_img.min()
                    peak = cam_img.max()
                    if peak > 0:
                        # An all-zero map stays all zero instead of becoming NaN.
                        cam_img = cam_img/peak

                    heatmap = cam_img[0, 0].cpu().numpy()

                    # Fix the threshold before binarising so both halves of the
                    # mask use the same cut-off.
                    threshold = heatmap.mean() + heatmap.std()
                    mask = (heatmap > threshold).astype(heatmap.dtype)
                    mask = np.expand_dims(mask, axis=2)  # (H, W) -> (H, W, 1)

                    input_write = image.transpose((1,2,0))  # (C, H, W) -> (H, W, C)

                    cam_f = mask*input_write
                    cam_max = np.max(cam_f)
                    if cam_max > 0:
                        cam_f = cam_f / cam_max

                    # NOTE(review): cv2.imwrite treats the channels as BGR; if
                    # the loader yields RGB the saved colours are swapped - confirm.
                    pr = name.replace('.j','_cam.j')
                    cv2.imwrite(base_path+pr,cam_f*255)

                pbar.update(inputs.shape[0])

            pbar.close()

    def return_model(self):
        """Load the best checkpoint and return ``(model, test_dataloader)``."""
        # train() only ever writes exp_name + '.pt' and exp_name + '_final.pt';
        # the '_sel.pt' file this used to load is never produced by this class.
        self.model.load_state_dict(torch.load(self.exp_name+'.pt', map_location=self.device))
        self.model.eval()
        mode = 'test'
        return self.model,self.dataloaders[mode]


In [18]:
gc = grad_cam()

In [19]:
gc.train()

Epoch 0/24
----------


HBox(children=(IntProgress(value=0, max=1999), HTML(value='')))

train Sel_Loss: 0.6883 Acc: 0.5338 F1: 0.4537


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

valid Sel_Loss: 0.8663 Acc: 0.5070 F1: 0.6693
Epoch 1/24
----------


HBox(children=(IntProgress(value=0, max=1999), HTML(value='')))

train Sel_Loss: 0.6668 Acc: 0.5893 F1: 0.4948


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

valid Sel_Loss: 0.6557 Acc: 0.5950 F1: 0.4122
Epoch 2/24
----------


HBox(children=(IntProgress(value=0, max=1999), HTML(value='')))

train Sel_Loss: 0.6640 Acc: 0.5938 F1: 0.4925


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

valid Sel_Loss: 0.6618 Acc: 0.5780 F1: 0.4155
Epoch 3/24
----------


HBox(children=(IntProgress(value=0, max=1999), HTML(value='')))

train Sel_Loss: 0.6650 Acc: 0.5748 F1: 0.4572


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

valid Sel_Loss: 0.6628 Acc: 0.5880 F1: 0.5560
Epoch 4/24
----------


HBox(children=(IntProgress(value=0, max=1999), HTML(value='')))

train Sel_Loss: 0.6593 Acc: 0.5998 F1: 0.4785


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

valid Sel_Loss: 0.6580 Acc: 0.6180 F1: 0.4753
Epoch 5/24
----------


HBox(children=(IntProgress(value=0, max=1999), HTML(value='')))

train Sel_Loss: 0.6606 Acc: 0.5918 F1: 0.4855


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

valid Sel_Loss: 0.6472 Acc: 0.6170 F1: 0.5663
Epoch 6/24
----------


HBox(children=(IntProgress(value=0, max=1999), HTML(value='')))

train Sel_Loss: 0.6574 Acc: 0.6048 F1: 0.5006


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

valid Sel_Loss: 0.6488 Acc: 0.6180 F1: 0.4893
Epoch 7/24
----------


HBox(children=(IntProgress(value=0, max=1999), HTML(value='')))

train Sel_Loss: 0.6547 Acc: 0.6078 F1: 0.5000


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

valid Sel_Loss: 0.6526 Acc: 0.6300 F1: 0.5363
Epoch 8/24
----------


HBox(children=(IntProgress(value=0, max=1999), HTML(value='')))

train Sel_Loss: 0.6571 Acc: 0.6108 F1: 0.5174


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

valid Sel_Loss: 0.6487 Acc: 0.6090 F1: 0.4292
Epoch 9/24
----------


HBox(children=(IntProgress(value=0, max=1999), HTML(value='')))

train Sel_Loss: 0.6574 Acc: 0.5998 F1: 0.4872


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

valid Sel_Loss: 0.6459 Acc: 0.6300 F1: 0.5220
Epoch 10/24
----------


HBox(children=(IntProgress(value=0, max=1999), HTML(value='')))

train Sel_Loss: 0.6593 Acc: 0.6023 F1: 0.4920


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

valid Sel_Loss: 0.6503 Acc: 0.6100 F1: 0.5486
Epoch 11/24
----------


HBox(children=(IntProgress(value=0, max=1999), HTML(value='')))

train Sel_Loss: 0.6554 Acc: 0.6063 F1: 0.4945


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

valid Sel_Loss: 0.6546 Acc: 0.6200 F1: 0.6082
Epoch 12/24
----------


HBox(children=(IntProgress(value=0, max=1999), HTML(value='')))

train Sel_Loss: 0.6520 Acc: 0.6008 F1: 0.5146


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

valid Sel_Loss: 0.6494 Acc: 0.6330 F1: 0.5349
Epoch 13/24
----------


HBox(children=(IntProgress(value=0, max=1999), HTML(value='')))

train Sel_Loss: 0.6561 Acc: 0.6008 F1: 0.4764


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

valid Sel_Loss: 0.6439 Acc: 0.6250 F1: 0.5198
Epoch 14/24
----------


HBox(children=(IntProgress(value=0, max=1999), HTML(value='')))

Process Process-559:
Process Process-560:
Process Process-555:
Process Process-554:
Process Process-556:
Process Process-553:
Traceback (most recent call last):
Traceback (most recent call last):
Process Process-558:
Traceback (most recent call last):
  File "/home/vdslab/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/vdslab/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/vdslab/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Process Process-557:
  File "/home/vdslab/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/vdslab/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/vdslab/anaconda3/lib/python3.6/multiprocess

  File "/home/vdslab/anaconda3/lib/python3.6/selectors.py", line 376, in select
    fd_event_list = self._poll.poll(timeout)
KeyboardInterrupt
KeyboardInterrupt
  File "/home/vdslab/anaconda3/lib/python3.6/multiprocessing/connection.py", line 414, in _poll
    r = wait([self], timeout)
  File "/home/vdslab/anaconda3/lib/python3.6/multiprocessing/connection.py", line 911, in wait
    ready = selector.select(timeout)
  File "/home/vdslab/anaconda3/lib/python3.6/selectors.py", line 376, in select
    fd_event_list = self._poll.poll(timeout)


KeyboardInterrupt: 

In [None]:
gc.get_cam()

In [None]:
a = cv2.imread('../Experiments/Oxford_pets/cat_Abyssinian_105_cam.jpg')

In [None]:
plt.imshow(a)

In [None]:
gc.test_model_auc()

In [None]:
gc.get_cam()

In [None]:
gc.test_model_auc()

In [None]:
#gc.train()

In [None]:
# get_cam() writes the CAM images to disk and returns nothing, so there is
# nothing to unpack here.
gc.get_cam()

In [None]:
a = cv2.imread('../Experiments/Sanity_Check/dog.10337_cam.jpg')
plt.imshow(a)

In [None]:
# makedirs with exist_ok=True also creates a missing parent directory and
# does not fail when this cell is run a second time.
os.makedirs('../Experiments/Sanity_Check', exist_ok=True)