In [20]:
#Pytorch
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

#Torchvision
import torchvision
from torchvision import datasets, models, transforms, utils
from torch.utils.data import Dataset, DataLoader

#Image Processing
import matplotlib.pyplot as plt
from skimage import io, transform, color
import PIL
from PIL import Image
import augmentations
from augmentations import *

#Others
import sklearn.metrics
from sklearn.metrics import *
import numpy as np
import pandas as pd
import cv2
import time
import os
import copy
from model_summary import *
import pretrainedmodels
import tqdm
from tqdm import tqdm_notebook as tqdm
import warnings
warnings.filterwarnings("ignore")

## Dataloader

class dataset(Dataset):

    def __init__(self, csv_file, root_dir, transform=None):

        self.data_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.mask_dir = '../Data/mask_orient_cropped/'
        self.dist_mask_dir = '../Data/distance_mask_orient_cropped/'
        
    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):

        #get the image
        img_name = os.path.join(self.root_dir,self.data_frame.iloc[idx]['name'])
        image = Image.open(img_name)
        
        #get the mask
        mask_name = os.path.join(self.mask_dir,self.data_frame.iloc[idx]['name'].replace('.j','_mask.j'))
        mask = io.imread(mask_name)
        
        #get the distance mask
        dist_mask_name = os.path.join(self.dist_mask_dir,self.data_frame.iloc[idx]['name'].replace('.j','_mask.j'))
        dist_mask = io.imread(dist_mask_name)

        #add the distance mask as the final channel of the 
        mask = np.array([mask,mask,dist_mask]).transpose((1,2,0))
        mask = Image.fromarray(mask)     

        if self.transform:
            image,mask = self.transform(image,mask)
        
        mask_final = mask[0,:,:]
        mask_final[mask_final<0.5] = 0
        mask_final[mask_final>0.5] = 1
        
        dist_mask = mask[2,:,:]
        
        return {'image':image, 'mask':mask_final, 'dist_mask':dist_mask, 'name':self.data_frame.iloc[idx]['name']}
    

def get_dataloader(data_dir, train_csv_path, img_mean, img_std, batch_size=1):

    data_transforms = {
        'train': Compose([
            RandomHorizontallyFlip(0.5),
            RandomVerticallyFlip(0.5),
            RandomTranslate((0.1,0.1)),
            RandomRotate(90),
            ToTensor(),
            Normalize(img_mean,img_std)
        ]),
        'valid': Compose([
            ToTensor(),
            Normalize(img_mean,img_std)
        ]),
        'test': Compose([
            ToTensor(),
            Normalize(img_mean,img_std)        
        ])
    }

    image_datasets = {}
    dataloaders = {}
    dataset_sizes = {}

    for x in ['train', 'valid', 'test']:
        if x == 'train':
            bs = batch_size
            sh = True
        elif x == 'valid':
            bs = batch_size
            sh = False
        else:
            bs = 1
            sh = False
        image_datasets[x] = dataset(train_csv_path.replace('train',x),root_dir=data_dir,transform=data_transforms[x])
        dataloaders[x] = torch.utils.data.DataLoader(image_datasets[x], batch_size=bs,shuffle=sh, num_workers=8)    
        dataset_sizes[x] = len(image_datasets[x])

    device = torch.device("cuda:0")

    return dataloaders,dataset_sizes,image_datasets,device

# a,_,_,_ = get_dataloader('../Data/CBIS-DDSM_classification_orient_cropped/','../CSV/gain_train.csv',[0,0,0],[1,1,1],1)

class SaveFeatures():
    features=None
    def __init__(self, m): self.hook = m.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output): self.features = output
    def remove(self): self.hook.remove()

class UnetBlock(nn.Module):
    def __init__(self, up_in, x_in, n_out):
        super().__init__()
        up_out = x_out = n_out//2
        self.x_conv  = nn.Conv2d(x_in,  x_out,  1)
        self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2)
        self.bn = nn.BatchNorm2d(n_out)
        
    def forward(self, up_p, x_p):
        up_p = self.tr_conv(up_p)
        x_p = self.x_conv(x_p)
        cat_p = torch.cat([up_p,x_p], dim=1)
        return self.bn(F.relu(cat_p))

class Unet34(nn.Module):
    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        self.sfs = [SaveFeatures(rn[0][i]) for i in [12,22,32,42]]
        self.up1 = UnetBlock(512,512,256)
        self.up2 = UnetBlock(256,512,256)
        self.up3 = UnetBlock(256,256,256)
        self.up4 = UnetBlock(256,128,256)
        self.up5 = nn.ConvTranspose2d(256, 1, 2, stride=2)
        
    def forward(self,x):
        x = F.relu(self.rn(x))
        x = self.up1(x, self.sfs[3].features)
        x = self.up2(x, self.sfs[2].features)
        x = self.up3(x, self.sfs[1].features)
        x = self.up4(x, self.sfs[0].features)
        x = self.up5(x)
        return x
    
    def close(self):
        for sf in self.sfs: sf.remove()

def build_model():
    v = models.vgg16_bn(pretrained=True)
    v1 = nn.Sequential(*list(v.children())[:-1])
    m = Unet34(v1)
    return m

# a = build_model()

# summary(a.cuda(),(3,256,256))

def get_IoU(pred, targs):
    return 2*(pred*targs).sum() / ((pred+targs).sum())# - (pred*targs).sum())

def new_metric(preds,targs,device,patch_size=64):
    
    mask = targs
    b,c,h,w = preds.size()
    
#     conv_filter = torch.ones((1, 1, 5, 5))/25
#     conv_filter = conv_filter.to(device)
#     filtered_preds = F.conv2d(preds, conv_filter, padding=2)
#     max_loc = torch.argmax(filtered_preds)
#     h_loc = int(max_loc%h)
#     w_loc = int(max_loc/h)

    temp = preds
    temp = temp.detach().cpu().numpy()
    temp = temp.squeeze()
    temp[temp>0] = 1
    temp[temp<0] = 0
    #import pdb;pdb.set_trace()
    hs,ws = np.where(temp==1)
    try:
        h_loc = int(hs.mean())
        w_loc = int(ws.mean())
    except:
        return 0
    
    h_up = max(0,h_loc-patch_size//2)
    h_bottom = min(h,h_loc+patch_size//2)
    
    w_left = max(0,w_loc-patch_size//2)
    w_right = min(w,w_loc+patch_size//2)
    
    if h_up == 0:
        h_bottom = min(patch_size,h)
    elif h_bottom == h:
        h_up = max(0,-patch_size+h)
        
    if w_left == 0:
        w_right = min(patch_size,w)
    elif w_right == w:
        w_left = max(0,-patch_size+w)
    
    mask_loc = mask[0,h_up:h_bottom,w_left:w_right]
    return 1.0*mask_loc.sum()/(mask.sum()+1)
    
    
class SoftDiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(SoftDiceLoss, self).__init__()

    def forward(self, logits, targets):
        smooth = 1
        num = targets.size(0)
        probs = F.sigmoid(1*(logits))
        m1 = probs.view(num, -1)
        m2 = targets.view(num, -1)
        intersection = (m1 * m2)

        score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
        score = 1 - score.sum() / num
        return score
    
class DistanceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DistanceLoss, self).__init__()

    def forward(self, logits, targets, mask):
    
        num = targets.size(0)
        probs = F.sigmoid(logits)
        pr = probs.view(num, -1)
        di = targets.view(num, -1)
        ma = mask.view(num,-1)
        
#         intersection = (m1 * m2)
#         score = intersection.sum()/m2.sum()

        #loss = di*((1-pr)*ma + pr*(1-ma))
        loss = -(ma*torch.log(pr+1e-6)*di + (1-ma)*torch.log(1-pr+1e-6)*di)
        
        return torch.mean(loss)
    
class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        
    def forward(self, input, target):
        input = input.view(-1)
        target = target.view(-1)
        target_a = target*self.alpha
        target_1_a = (1-target)*(1-self.alpha)
        
        pt = F.sigmoid(input)
    
        loss = target_a*((1-pt)**self.gamma)*torch.log(pt+1e-8) + target_1_a*((pt)**self.gamma)*torch.log(1-pt+1e-8) 
        loss = -loss.mean()
        
        return loss   
    
class VGG_unet_bc():
    def __init__(self):
        
        #Initialization
        self.data_dir = '../Data/CBIS-DDSM_classification_orient_cropped/'
        self.train_csv = '../CSV/mass_segmentation_train.csv'
        self.num_epochs = 30
        self.batch_size = 1
        self.img_mean = [0.253, 0.238, 0.234]
        self.img_std = [0.272, 0.268, 0.262]
        self.exp_name = 'Weights/VGG_unet_cbis_ddsm_diceloss_distloss'
        self.mean_pixel = 0
        
        #Define the three models
        self.model = build_model()
        
        #Put them on the GPU
        self.model = self.model.cuda()
        #self.model.load_state_dict(torch.load(self.exp_name+'.pt'))
        
        #Get the dataloaders
        self.dataloaders,self.dataset_sizes,self.dataset,self.device = get_dataloader(self.data_dir,self.train_csv,\
                                                        self.img_mean,self.img_std,self.batch_size)
        

        self.optimizer = optim.Adam(self.model.parameters(),lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
        
        #self.lr_scheduler = CosineAnnealingLR(self.optimizer,T_max=10,eta_min=1e-4)
    
        
    def train(self):
        
        since = time.time()
        best_epoch_acc = 0.0
        best_epoch_iou = 0.0
        best_epoch_metric = 0.0
        
        for epoch in range(self.num_epochs):
            print('Epoch {}/{}'.format(epoch, self.num_epochs - 1),flush=True)
            print('-' * 10,flush=True)

            # Each epoch has a training and validation phase
            for phase in ['train', 'valid']:
                if phase == 'train':
                    
                    #Set the models to training mode
                    #self.lr_scheduler.step()
                    self.model.train()
                
                else:
                    #Set the models to evaluation mode
                    self.model.eval()
                    
                #Keep a track of all the three loss
                running_loss = 0.0
                running_dice_loss = 0.0
                running_dist_loss = 0.0
                running_mean_loss = 0.0
                
                #Metrics : predictor auc and selector iou
                running_iou = 0
                running_new_metric = 0
                
                #tqdm bar
                pbar = tqdm(total=self.dataset_sizes[phase])

                # Iterate over data.
                for sampled_batch in self.dataloaders[phase]:

                    inputs = sampled_batch['image']
                    mask = sampled_batch['mask']
                    dist_mask = sampled_batch['dist_mask']
                   
                    #Input needs to be float and labels long
                    inputs = inputs.float().to(self.device)
                    mask = mask.to(self.device)
                    dist_mask = dist_mask.to(self.device)
                    
                    # zero the parameter gradients
                    self.optimizer.zero_grad()
                
                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == 'train'):
                        
                        outputs = self.model(inputs)
                        
                        out4loss = outputs.view(-1)
                        mask4loss = mask.view(-1)
                        dm4loss = dist_mask.view(-1)
                        
                        z_c = float((mask==0).sum())
                        o_c = float((mask==1).sum())
                        alpha = z_c/(z_c+o_c)
                        
                        ce_loss = nn.BCEWithLogitsLoss()(out4loss,mask4loss)
                        #ce_loss = FocalLoss(2,alpha)(outputs,mask)
                        dice_loss = SoftDiceLoss()(outputs,mask)
                        dist_loss = DistanceLoss()(outputs,dist_mask,mask)
                        mean_loss = torch.abs(torch.mean(outputs) - self.mean_pixel)
                        #print(mean_loss,torch.mean(outputs))
                        
                        loss = ce_loss + dist_loss
                        #print(dice_loss,dist_loss)
                        
                        #print((F.sigmoid(out4loss)*dm4loss).sum(),(outputs>0).sum(),maskdm4loss.sum())
                        #import pdb;pdb.set_trace()
                        
                        # backward + optimize only if in training phase
                        if phase == 'train':
                            
                            loss.backward()
                            self.optimizer.step()
                                    
                    out4iou = out4loss
                    out4iou[out4loss>0] = 1
                    out4iou[out4loss<0] = 0
                    # statistics
                    running_loss += loss.item() * inputs.size(0)
                    running_dice_loss += dice_loss.item() * inputs.size(0)
                    running_dist_loss += dist_loss.item() * inputs.size(0)
                    running_mean_loss += mean_loss.item() * inputs.size(0)
                    running_iou += get_IoU(out4loss,mask4loss) * inputs.size(0)
                    running_new_metric += new_metric(outputs,mask,self.device,64) * inputs.size(0)
                    #print(running_new_metric)
                    #print(running_iou)
                    
                    pbar.update(inputs.shape[0])
                pbar.close()

                epoch_loss = running_loss / self.dataset_sizes[phase]
                epoch_dice_loss = running_dice_loss / self.dataset_sizes[phase]
                epoch_dist_loss = running_dist_loss / self.dataset_sizes[phase]
                epoch_mean_loss = running_mean_loss / self.dataset_sizes[phase]
                epoch_iou = running_iou / self.dataset_sizes[phase]
                epoch_new_metric = running_new_metric / self.dataset_sizes[phase]
                
                print('{} Loss: {:.4f} Dice_Loss: {:.4f} Dist_Loss: {:.4f} Mean_Loss: {:.4f} IoU: {:.4f} New Metric: {:.4f}'.format(
                    phase, epoch_loss, epoch_dice_loss, epoch_dist_loss, epoch_mean_loss, epoch_iou, epoch_new_metric))

                # deep copy the model
                if phase == 'valid' and epoch_new_metric > best_epoch_metric:
                    best_epoch_metric = epoch_new_metric
                    torch.save(self.model.state_dict(),self.exp_name+'.pt')
                    
        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print('Best new metric: {:4f}'.format(best_epoch_metric))

        torch.save(self.model.state_dict(),self.exp_name+'_final.pt')
        
        print('Training completed finally !!!!!')
        
        
    def test_model(self):
                
        self.model.load_state_dict(torch.load(self.exp_name+'.pt'))
        self.model.eval()
        
        mIoU = 0
        running_iou = 0
        running_new_metric = 0
        
        total = 0
        mode = 'valid'

        image_list = []
        mask_list = []
        pred_list = []
        i=0
        
        with torch.no_grad():
            
            for data in self.dataloaders[mode]:

                images = data['image'].to(self.device)
                mask = data['mask'].to(self.device)
                dist_mask = data['dist_mask']
                
                preds = self.model(images)
                
                out4metric = preds.view(-1)
                mask4metric = mask.view(-1)
                
#                 prediction = preds[:]
                
#                 #import pdb;pdb.set_trace()
#                 sorted_pixels = sorted(prediction.cpu().numpy().reshape(-1))
                
#                 thresh = float(sorted_pixels[int(0.99*len(sorted_pixels))])
                
#                 i+=1
#                 prediction[prediction>thresh] = 255
#                 prediction[prediction<thresh] = 0
#                 prediction = prediction/255

                prediction = preds
                prediction[prediction>0] = 1
                prediction[prediction<0] = 0

                running_iou += get_IoU(prediction.view(-1),mask4metric) * images.size(0)
                running_new_metric += new_metric(prediction,mask,self.device,64)
                
                base_path = '../Experiments/cbis_ddsm_sup/'
                name = data['name'][0]

                pr = name.replace('.j','_supervised_out_mask_2.j')
                
                preds = preds - preds.min()
                preds = preds/preds.max()
                
                cv2.imwrite(base_path+pr,prediction.cpu().numpy().squeeze()*255)
                #cv2.imwrite(base_path+pr.replace('.j','_probs.j'),preds.cpu().numpy().squeeze()*255)
                
                
        mIoU = running_iou/self.dataset_sizes[mode]
        new_met = running_new_metric/self.dataset_sizes[mode]
        print('mean IoU:',mIoU)
        print('new metric',new_met)

#                 image_list.append(images.squeeze().cpu().numpy().transpose((1,2,0)))
#                 mask_list.append(mask.transpose((1,2,0)).squeeze())
#                 pred_list.append(preds.squeeze().cpu().numpy())

#         return image_list,mask_list,pred_list
    

In [21]:
u = VGG_unet_bc()

In [22]:
u.train()

Epoch 0/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.4906 Dice_Loss: 0.9722 Dist_Loss: 0.1275 Mean_Loss: 1.1094 IoU: 0.2037 New Metric: 0.4298


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0996 Dice_Loss: 0.9528 Dist_Loss: 0.0283 Mean_Loss: 2.8049 IoU: 0.2065 New Metric: 0.3452
Epoch 1/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0552 Dice_Loss: 0.8986 Dist_Loss: 0.0118 Mean_Loss: 3.6419 IoU: nan New Metric: 0.5292


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0348 Dice_Loss: 0.8998 Dist_Loss: 0.0062 Mean_Loss: 4.5155 IoU: 0.1533 New Metric: 0.2538
Epoch 2/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0278 Dice_Loss: 0.8178 Dist_Loss: 0.0042 Mean_Loss: 4.9440 IoU: nan New Metric: 0.5452


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0254 Dice_Loss: 0.8132 Dist_Loss: 0.0033 Mean_Loss: 5.5387 IoU: 0.2428 New Metric: 0.3876
Epoch 3/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0223 Dice_Loss: 0.7549 Dist_Loss: 0.0027 Mean_Loss: 5.7675 IoU: 0.3359 New Metric: 0.5686


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0224 Dice_Loss: 0.7751 Dist_Loss: 0.0022 Mean_Loss: 6.3067 IoU: 0.2569 New Metric: 0.4237
Epoch 4/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0205 Dice_Loss: 0.7257 Dist_Loss: 0.0023 Mean_Loss: 6.3957 IoU: nan New Metric: 0.5747


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0229 Dice_Loss: 0.7715 Dist_Loss: 0.0019 Mean_Loss: 6.8843 IoU: 0.2400 New Metric: 0.3851
Epoch 5/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0197 Dice_Loss: 0.7025 Dist_Loss: 0.0020 Mean_Loss: 6.8925 IoU: nan New Metric: 0.5895


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0249 Dice_Loss: 0.7748 Dist_Loss: 0.0020 Mean_Loss: 7.3350 IoU: 0.2161 New Metric: 0.3356
Epoch 6/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0193 Dice_Loss: 0.6970 Dist_Loss: 0.0019 Mean_Loss: 7.2976 IoU: nan New Metric: 0.5809


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0234 Dice_Loss: 0.7859 Dist_Loss: 0.0016 Mean_Loss: 7.8259 IoU: 0.1852 New Metric: 0.3008
Epoch 7/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0188 Dice_Loss: 0.6819 Dist_Loss: 0.0019 Mean_Loss: 7.6185 IoU: nan New Metric: 0.6121


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0233 Dice_Loss: 0.7511 Dist_Loss: 0.0016 Mean_Loss: 8.2533 IoU: 0.2299 New Metric: 0.3739
Epoch 8/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0184 Dice_Loss: 0.6760 Dist_Loss: 0.0018 Mean_Loss: 7.8734 IoU: nan New Metric: 0.6099


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0237 Dice_Loss: 0.7246 Dist_Loss: 0.0018 Mean_Loss: 8.1604 IoU: 0.2828 New Metric: 0.4362
Epoch 9/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0183 Dice_Loss: 0.6690 Dist_Loss: 0.0018 Mean_Loss: 8.0675 IoU: nan New Metric: 0.6149


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0240 Dice_Loss: 0.7547 Dist_Loss: 0.0016 Mean_Loss: 8.4868 IoU: 0.2251 New Metric: 0.3266
Epoch 10/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0178 Dice_Loss: 0.6551 Dist_Loss: 0.0017 Mean_Loss: 8.2303 IoU: nan New Metric: 0.6346


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0216 Dice_Loss: 0.7219 Dist_Loss: 0.0016 Mean_Loss: 8.5272 IoU: 0.2774 New Metric: 0.4192
Epoch 11/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0175 Dice_Loss: 0.6488 Dist_Loss: 0.0017 Mean_Loss: 8.3517 IoU: nan New Metric: 0.6490


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0238 Dice_Loss: 0.6961 Dist_Loss: 0.0019 Mean_Loss: 8.6771 IoU: 0.3167 New Metric: 0.4543
Epoch 12/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0176 Dice_Loss: 0.6503 Dist_Loss: 0.0017 Mean_Loss: 8.4707 IoU: nan New Metric: 0.6302


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0233 Dice_Loss: 0.7222 Dist_Loss: 0.0022 Mean_Loss: 8.8466 IoU: 0.3004 New Metric: 0.4369
Epoch 13/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0172 Dice_Loss: 0.6478 Dist_Loss: 0.0017 Mean_Loss: 8.5678 IoU: nan New Metric: 0.6396


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0237 Dice_Loss: 0.7369 Dist_Loss: 0.0017 Mean_Loss: 9.0713 IoU: 0.2569 New Metric: 0.3564
Epoch 14/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0173 Dice_Loss: 0.6471 Dist_Loss: 0.0017 Mean_Loss: 8.6512 IoU: nan New Metric: 0.6417


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0231 Dice_Loss: 0.7023 Dist_Loss: 0.0017 Mean_Loss: 9.0783 IoU: 0.3038 New Metric: 0.4105
Epoch 15/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0170 Dice_Loss: 0.6415 Dist_Loss: 0.0016 Mean_Loss: 8.7263 IoU: nan New Metric: 0.6478


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0229 Dice_Loss: 0.7335 Dist_Loss: 0.0015 Mean_Loss: 9.0591 IoU: 0.2513 New Metric: 0.3425
Epoch 16/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0167 Dice_Loss: 0.6304 Dist_Loss: 0.0016 Mean_Loss: 8.7886 IoU: nan New Metric: 0.6633


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0249 Dice_Loss: 0.7303 Dist_Loss: 0.0016 Mean_Loss: 9.3761 IoU: 0.2663 New Metric: 0.3693
Epoch 17/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0163 Dice_Loss: 0.6244 Dist_Loss: 0.0015 Mean_Loss: 8.8599 IoU: nan New Metric: 0.6726


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0269 Dice_Loss: 0.7228 Dist_Loss: 0.0017 Mean_Loss: 9.2665 IoU: 0.2738 New Metric: 0.3715
Epoch 18/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0165 Dice_Loss: 0.6283 Dist_Loss: 0.0016 Mean_Loss: 8.9392 IoU: nan New Metric: 0.6715


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0248 Dice_Loss: 0.7178 Dist_Loss: 0.0018 Mean_Loss: 9.4999 IoU: 0.2723 New Metric: 0.3976
Epoch 19/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0164 Dice_Loss: 0.6236 Dist_Loss: 0.0016 Mean_Loss: 9.0027 IoU: nan New Metric: 0.6716


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0236 Dice_Loss: 0.6975 Dist_Loss: 0.0016 Mean_Loss: 9.4731 IoU: 0.2949 New Metric: 0.4092
Epoch 20/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0158 Dice_Loss: 0.6156 Dist_Loss: 0.0015 Mean_Loss: 9.0261 IoU: nan New Metric: 0.6772


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0257 Dice_Loss: 0.7180 Dist_Loss: 0.0017 Mean_Loss: 9.7466 IoU: 0.2790 New Metric: 0.3805
Epoch 21/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0160 Dice_Loss: 0.6047 Dist_Loss: 0.0015 Mean_Loss: 9.0924 IoU: nan New Metric: 0.7055


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0248 Dice_Loss: 0.7356 Dist_Loss: 0.0016 Mean_Loss: 9.8086 IoU: 0.2534 New Metric: 0.3633
Epoch 22/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0159 Dice_Loss: 0.6103 Dist_Loss: 0.0015 Mean_Loss: 9.1316 IoU: nan New Metric: 0.6916


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0246 Dice_Loss: 0.7269 Dist_Loss: 0.0019 Mean_Loss: 9.5548 IoU: 0.2815 New Metric: 0.3993
Epoch 23/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0155 Dice_Loss: 0.6029 Dist_Loss: 0.0014 Mean_Loss: 9.1916 IoU: nan New Metric: 0.6921


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0232 Dice_Loss: 0.7387 Dist_Loss: 0.0016 Mean_Loss: 9.4522 IoU: 0.2537 New Metric: 0.3684
Epoch 24/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0154 Dice_Loss: 0.6037 Dist_Loss: 0.0014 Mean_Loss: 9.2310 IoU: 0.4531 New Metric: 0.6930


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0254 Dice_Loss: 0.7607 Dist_Loss: 0.0015 Mean_Loss: 9.7552 IoU: 0.2091 New Metric: 0.2985
Epoch 25/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0151 Dice_Loss: 0.5935 Dist_Loss: 0.0014 Mean_Loss: 9.2895 IoU: nan New Metric: 0.6931


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0239 Dice_Loss: 0.7313 Dist_Loss: 0.0016 Mean_Loss: 9.8797 IoU: 0.2530 New Metric: 0.3503
Epoch 26/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0151 Dice_Loss: 0.5926 Dist_Loss: 0.0014 Mean_Loss: 9.3446 IoU: nan New Metric: 0.7188


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0234 Dice_Loss: 0.6926 Dist_Loss: 0.0019 Mean_Loss: 9.5985 IoU: 0.3038 New Metric: 0.4147
Epoch 27/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0151 Dice_Loss: 0.5967 Dist_Loss: 0.0014 Mean_Loss: 9.3811 IoU: nan New Metric: 0.7020


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0242 Dice_Loss: 0.7126 Dist_Loss: 0.0015 Mean_Loss: 9.7101 IoU: 0.2711 New Metric: 0.3822
Epoch 28/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0149 Dice_Loss: 0.5900 Dist_Loss: 0.0014 Mean_Loss: 9.4350 IoU: nan New Metric: 0.7088


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0252 Dice_Loss: 0.7350 Dist_Loss: 0.0017 Mean_Loss: 9.9727 IoU: 0.2630 New Metric: 0.3669
Epoch 29/29
----------


HBox(children=(IntProgress(value=0, max=924), HTML(value='')))


train Loss: 0.0147 Dice_Loss: 0.5842 Dist_Loss: 0.0013 Mean_Loss: 9.4740 IoU: nan New Metric: 0.7168


HBox(children=(IntProgress(value=0, max=307), HTML(value='')))


valid Loss: 0.0232 Dice_Loss: 0.6812 Dist_Loss: 0.0017 Mean_Loss: 9.8432 IoU: 0.3154 New Metric: 0.4401
Training complete in 16m 35s
Best new metric: 0.454344
Training completed finally !!!!!


In [23]:
u.test_model()

mean IoU: tensor(0.3167, device='cuda:0')
new metric tensor(0.4543, device='cuda:0')


In [None]:
plt.imshow(cv2.imread('../Experiments/cbis_ddsm/Mass-Test_P_00056_LEFT_MLO_supervised_out.jpg'))

In [None]:
plt.imshow(cv2.imread('../Experiments/cbis_ddsm/Mass-Test_P_00056_LEFT_MLO_supervised_new_mask.jpg'))

In [None]:
u.train()

In [None]:
%debug

In [None]:
u.test_model()

In [None]:
'../Data/masks_orient/Mass-Training_P_00913'

In [None]:
%debug

In [None]:
a,b,c = u.test_model()

In [None]:
len(a)

In [None]:
plt.imshow(a[1])

In [None]:
b[0].sum()

In [None]:
plt.imshow(b[1])

In [None]:
d = c[1]
d[d>0] = 255
d[d<=0] = 0

In [None]:
plt.imshow(d)

In [None]:
get_IoU(d/255,b[1])

In [None]:
plt.imshow(a[1])

In [None]:
plt.imshow(b[1])

In [None]:
plt.imshow(c[1])

In [None]:
u.train()

In [None]:
%debug

In [None]:
_,_,_ = dci.get_cam()

In [None]:
dci.train()

In [None]:
!nvidia-smi

In [None]:
# dci.test_model_acc()

In [None]:
md,dl = dci.return_model()

In [None]:
a = iter(dl).next()

m = denorm_img(a['mask'],[0.223, 0.231, 0.243],[0.266, 0.270, 0.274]).squeeze()
bm = denorm_img(a['bmask'],[0.223, 0.231, 0.243],[0.266, 0.270, 0.274]).squeeze()

In [None]:
device = torch.device("cuda:0")

p = F.sigmoid(md(a['image'].to(device))).detach().cpu().numpy().squeeze()

In [None]:
m.shape

In [None]:
p.shape

In [None]:
bm.shape

In [None]:
plt.imshow(bm)

In [None]:
plt.imshow(m)

In [None]:
p_m = p*bm
p_m[p_m > p_m.mean() + p_m.std()] = 1
p_m[p_m < p_m.mean() + p_m.std()] = 0

In [None]:
plt.imshow(p_m)

In [None]:
p_m_t = torch.Tensor(p_m)
print(nn.L1Loss()(torch.Tensor(p_m),torch.zeros(p_m_t.shape)))

In [None]:
p*bm

In [None]:
plt.imshow(p)