In [1]:
import os
import sys
import math
import random
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
import pycocotools
from torch.utils.tensorboard import SummaryWriter

In [2]:
# if torch.cuda.is_available():  
#     device = "cuda:0" 
# else:  
#     device = "cpu" 
    
# torch.cuda.current_device()

In [3]:
# Code from Maastricht to access GPU of choice

import torch
import os
# Number GPUs by PCI bus ID and expose only GPU 0 to this process.
# NOTE: these variables only take effect if CUDA has not been initialized yet in this kernel.
os.environ.update(CUDA_DEVICE_ORDER='PCI_BUS_ID', CUDA_VISIBLE_DEVICES='0')


In [4]:
# Use the first visible GPU when there is one, otherwise fall back to the CPU.
# torch.cuda.set_device raises on machines without CUDA, so only call it when a GPU is used.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    torch.cuda.set_device(device)


In [5]:
# Report which GPU is active; guarded so the cell also runs on CPU-only machines
gpu_info = ((torch.cuda.current_device(), torch.cuda.get_device_name())
            if torch.cuda.is_available() else 'CUDA not available, using CPU')
gpu_info

## Import data

In [6]:
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode


# Class for a customized dataset
# In this case preprocessed CEM images combined in a 3-channel RGB .jpg format
# and the corresponding mask of present lesions in a 1-channel .png format
class CustomImageDataset(Dataset):
    """Quadrant-augmented dataset of preprocessed CEM images and lesion masks.

    Every CSV row is exposed ``N_QUADRANTS`` (4) times. Dataset index ``augidx``
    maps to CSV row ``augidx % n_rows`` and quadrant ``augidx // n_rows``:
    0 upper left, 1 bottom left, 2 upper right, 3 bottom right. In this set,
    CC and MLO of the same breast are considered different cases.

    CSV columns are read by position: 0 image path, 1 xmin, 2 xmax, 3 ymin,
    4 ymax, 5 label (stored +1 so that 0 stays background), 6 mask path.
    "x" indexes image rows (dim 1) and "y" indexes image columns (dim 2).
    Output boxes are ``[ymin, xmin, ymax, xmax]``, i.e. ``[col1, row1, col2, row2]``.

    ``__getitem__`` returns ``(image, target)``:
    - ``image`` is min-max rescaled per channel to [0, 1].
    - ``target`` holds ``image_id``, ``masks``, ``boxes``, ``area``, ``labels`` and ``iscrowd``.
    A quadrant that does not overlap the lesion box yields an empty target
    (zero boxes, masks, labels).
    """

    # Number of quadrant patches produced per CSV row
    N_QUADRANTS = 4

    def __init__(self, root, annotations_file, img_dir, mask_dir, train=False, transform=None,
                 target_transform=None, min_side=400, max_side=650):
        # Read the .csv file with all the information
        self.img_labels = pd.read_csv(os.path.join(root, annotations_file))
        # Directories of the images and masks (the CSV paths are used as-is in __getitem__)
        self.img_dir = os.path.join(root, img_dir)
        self.mask_dir = os.path.join(root, mask_dir)
        # Whether training-time augmentation (random horizontal flip) is applied
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        # Resize limits: the short side is capped at min_side, then the long side at max_side
        self.min_side = min_side
        self.max_side = max_side
        # Set here so __getitem__ works even if __len__ was never called
        self.setlen = len(self.img_labels)
        self.auglen = self.setlen * self.N_QUADRANTS

    def __len__(self):
        """Return the number of quadrant patches (N_QUADRANTS per CSV row)."""
        self.setlen = len(self.img_labels)
        self.auglen = self.setlen * self.N_QUADRANTS
        return self.auglen

    def _locate(self, augidx):
        """Map a dataset index to ``(csv_row, quadrant)``; raise IndexError when out of range."""
        augidx = int(augidx)
        n_rows = len(self.img_labels)
        n_total = n_rows * self.N_QUADRANTS
        if not 0 <= augidx < n_total:
            raise IndexError(f"index {augidx} out of range for dataset of length {n_total}")
        quadrant, row = divmod(augidx, n_rows)
        return row, quadrant

    @staticmethod
    def _clip_interval(lo, hi, start, stop):
        """Clip ``[lo, hi]`` to the window ``[start, stop)`` and shift it to window coordinates.

        Returns ``(inside, new_lo, new_hi)``. ``inside`` is True only if the clipped
        interval has positive length, because torchvision rejects zero-width boxes.
        """
        new_lo = max(lo, start) - start
        new_hi = min(hi, stop) - start
        return bool(new_hi > new_lo), new_lo, new_hi

    def _target_size(self, height, width):
        """Return ``[h, w]`` after capping the short side at min_side, then the long side at max_side."""
        if height <= width:  # ties count as "height is the short side"
            if height > self.min_side:
                height, width = self.min_side, int(self.min_side * width / height)
            if width > self.max_side:
                height, width = int(self.max_side * height / width), self.max_side
        else:
            if width > self.min_side:
                height, width = int(self.min_side * height / width), self.min_side
            if height > self.max_side:
                height, width = self.max_side, int(self.max_side * width / height)
        return [height, width]

    def __getitem__(self, augidx):
        idx, quadrant = self._locate(augidx)

        # Read the image and the mask for a case from the paths in the CSV
        img_path = self.img_labels.iloc[idx, 0]
        mask_path = self.img_labels.iloc[idx, 6]
        image = read_image(img_path).float()
        mask = read_image(mask_path)
        xmin, xmax, ymin, ymax = (self.img_labels.iloc[idx, col] for col in (1, 2, 3, 4))

        # Crop the quadrant (augmentation that also limits resizing) and move the box into it
        height, width = image.shape[1], image.shape[2]
        half_h, half_w = height // 2, width // 2
        row_start, row_stop = (half_h, height) if quadrant in (1, 3) else (0, half_h)
        col_start, col_stop = (half_w, width) if quadrant in (2, 3) else (0, half_w)
        rows_inside, xmin, xmax = self._clip_interval(xmin, xmax, row_start, row_stop)
        cols_inside, ymin, ymax = self._clip_interval(ymin, ymax, col_start, col_stop)
        box_in_patch = rows_inside and cols_inside
        image = image[:, row_start:row_stop, col_start:col_stop]
        mask = mask[:, row_start:row_stop, col_start:col_stop]

        # Random horizontal flip along the width axis (same as T.RandomHorizontalFlip(p=1.0))
        flipint = random.random()
        if self.train and flipint > 0.5:
            image = image.flip(-1)
            mask = mask.flip(-1)
            ymin, ymax = image.shape[2] - ymax, image.shape[2] - ymin

        # Resize once to the final size; scale box rows/columns by the actual per-axis factors
        orig_h, orig_w = image.shape[1], image.shape[2]
        new_size = self._target_size(orig_h, orig_w)
        if new_size != [orig_h, orig_w]:
            image = T.Resize(new_size)(image)
            mask = T.Resize(new_size, interpolation=InterpolationMode.NEAREST)(mask)
        row_scale = image.shape[1] / orig_h
        col_scale = image.shape[2] / orig_w

        # Per-channel min-max rescale to [0, 1]. Standardizing first (as done previously) is a
        # positive affine map per channel and cancels out; constant channels map to 0, not NaN.
        ch_min = image.amin(dim=(1, 2), keepdim=True)
        ch_range = image.amax(dim=(1, 2), keepdim=True) - ch_min
        image = (image - ch_min) / torch.where(ch_range > 0, ch_range, torch.ones_like(ch_range))

        # 0 represents background class, thus 1=benign, 2=malignant
        label_value = int(self.img_labels.iloc[idx, 5]) + 1
        if box_in_patch:
            # Single binary mask channel (alternative for wrong masks with only one lesion)
            masks = (mask[0] > 0).to(torch.uint8).unsqueeze(0)
            boxes = torch.tensor([[ymin * col_scale, xmin * row_scale, ymax * col_scale, xmax * row_scale]],
                                 dtype=torch.float32)
            labels = torch.tensor([label_value], dtype=torch.int64)
        else:
            # Quadrant without lesion: empty target in the (0, ...) shapes torchvision accepts
            masks = torch.zeros((0, mask.shape[-2], mask.shape[-1]), dtype=torch.uint8)
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {}
        target["image_id"] = torch.tensor([idx])
        target["masks"] = masks
        target["boxes"] = boxes
        target["area"] = area
        target["labels"] = labels
        # One entry per box; none of the lesions is a crowd annotation
        target["iscrowd"] = torch.zeros((boxes.shape[0],), dtype=torch.int64)

        return image, target


In [7]:
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode

from monai.transforms import CropForeground, RandSpatialCropSamplesd, Flipd, RandCropByPosNegLabeld, Resized, FillHoles, NormalizeIntensityd
from monai.utils.enums import InterpolateMode

# MONAI dictionary transforms used later on.
# Fixed-size resizes: bilinear for the image, nearest for the mask so label values stay intact.
resize_fcn_large = Resized(keys=["image", "mask"],
                           spatial_size=(1300, 800),
                           mode=[InterpolateMode.BILINEAR, InterpolateMode.NEAREST])
resize_fcn_small = Resized(keys=["image", "mask"],
                           spatial_size=(650, 400),
                           mode=[InterpolateMode.BILINEAR, InterpolateMode.NEAREST])
# Ten random (650, 400) crops; pos/neg weight crop centres on mask-positive vs background pixels
crop_fcn_small = RandCropByPosNegLabeld(keys=["image", "mask"],
                                        label_key="mask",
                                        spatial_size=(650, 400),
                                        pos=0.1,
                                        neg=1.0,
                                        num_samples=10)
# Mirror image and mask along spatial axis 1 (presumably the width axis for (C, H, W) inputs)
flip_fcn = Flipd(keys=["image", "mask"], spatial_axis=1)
fill_fcn = FillHoles()
# Per-channel intensity normalization of the image over all pixels
norm_fcn = NormalizeIntensityd(keys=["image"], channel_wise=True, nonzero=False)

# Class for a customized dataset
# In this case preprocessed CEM images combined in a 3-channel RGB .jpg format
# and the corresponding mask of present lesions in a 1-channel .png format
class CustomImageDatasetMonai(Dataset):
    """Dataset of preprocessed CEM images and lesion masks built with MONAI transforms.

    For each CSV row, batches of random crops of half the image height and width
    are drawn until one contains enough lesion (nonzero mask) pixels. The crop
    with most lesion pixels is then processed as follows:
    - resized with ``resize_fcn_small``;
    - flipped with ``flip_fcn`` at random when ``train`` is set;
    - min-max rescaled per channel to [0, 1].

    CSV columns are read by position: 0 image path, 5 label, 6 mask path.
    All lesion pixels of a mask are merged into one instance. If no lesion pixel
    survives cropping and resizing, the target is empty (zero boxes, masks, labels).
    """

    def __init__(self, root, annotations_file, img_dir, mask_dir, train=False, transform=None,
                 target_transform=None, num_crop_samples=10, max_crop_attempts=50):
        # Read the .csv file with all the information
        self.img_labels = pd.read_csv(os.path.join(root, annotations_file))
        # Directories of the images and masks (the CSV paths are used as-is in __getitem__)
        self.img_dir = os.path.join(root, img_dir)
        self.mask_dir = os.path.join(root, mask_dir)
        # Whether training-time augmentation (random horizontal flip) is applied
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        # Crop search: crops per batch, and batches tried before falling back
        self.num_crop_samples = num_crop_samples
        self.max_crop_attempts = max_crop_attempts

    def __len__(self):
        # Return the number of cases in the dataset
        # In this set, CC and MLO of the same breast are considered different cases
        return len(self.img_labels)

    def _select_lesion_crop(self, init_dict):
        """Return the random half-size crop of ``init_dict`` with the most lesion pixels.

        The method draws batches of ``num_crop_samples`` crops and stops as soon as
        a crop has more than ``mask.nelement() / 40000`` nonzero mask pixels. After
        ``max_crop_attempts`` batches without success, it returns the best crop that
        has any lesion pixel. If no crop has one, it returns ``init_dict`` itself
        (the whole image). The previous unbounded loop hung on empty or tiny lesions.
        """
        mask = init_dict["mask"]
        crop_fcn_large = RandCropByPosNegLabeld(["image", "mask"],
                                                "mask",
                                                (int(mask.shape[1]/2), int(mask.shape[2]/2)),
                                                pos=0.1, neg=1.0,
                                                num_samples=self.num_crop_samples)
        min_pixels = mask.nelement() / 40000
        best_crop, best_count = None, 0
        for _ in range(self.max_crop_attempts):
            for crop in crop_fcn_large(init_dict):
                count = np.count_nonzero(crop["mask"])
                if count > best_count:
                    best_crop, best_count = crop, count
            if best_count > min_pixels:
                return best_crop
        return best_crop if best_crop is not None else init_dict

    def __getitem__(self, idx):
        # Read the image and the mask for a case from the paths in the CSV
        img_path = self.img_labels.iloc[idx, 0]
        mask_path = self.img_labels.iloc[idx, 6]
        init_dict = {"image": read_image(img_path).float(), "mask": read_image(mask_path)}

        # Lesion-containing crop, resized to the fixed training size
        resize_dict = resize_fcn_small(self._select_lesion_crop(init_dict))

        # Data augmentation by horizontal flipping
        flipint = random.random()
        if self.train and flipint > 0.5:
            resize_dict = flip_fcn(resize_dict)

        # Min-max rescale every channel to [0, 1]; a constant channel gets divisor 1
        image = resize_dict["image"]
        min_vals = [torch.min(image[c]).item() for c in range(image.shape[0])]
        max_vals = [torch.max(image[c]).item() for c in range(image.shape[0])]
        divisors = [hi - lo if hi > lo else 1.0 for lo, hi in zip(min_vals, max_vals)]
        norm_minmax_fcn = NormalizeIntensityd(["image"], min_vals, divisors,
                                              channel_wise=True, nonzero=False)
        resize_dict = norm_minmax_fcn(resize_dict)

        # Single binary instance mask (alternative for wrong masks with only one lesion)
        mask_out = np.zeros((1, resize_dict["mask"].shape[-2], resize_dict["mask"].shape[-1]))
        mask_out[0][resize_dict["mask"][0] > 0] = 1

        labels = self.img_labels.iloc[idx, 5] + 1   # 0 represents background class, thus 1=benign, 2=malignant
        if mask_out.any():
            # Bounding box of the transformed mask, reordered to [col1, row1, col2, row2]
            # (compute_bounding_box is indexed per spatial dim, rows first)
            bbox = CropForeground().compute_bounding_box(resize_dict["mask"])
            boxes = torch.as_tensor([[bbox[0][1], bbox[0][0], bbox[1][1], bbox[1][0]]], dtype=torch.float32)
            masks = torch.from_numpy(mask_out).to(torch.uint8)
            label_tensor = torch.tensor([int(labels)], dtype=torch.int64)
        else:
            # No lesion left after cropping/resizing: empty target instead of a zero-size box
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            masks = torch.zeros((0,) + mask_out.shape[1:], dtype=torch.uint8)
            label_tensor = torch.zeros((0,), dtype=torch.int64)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        # Create the target dictionary
        target = {}
        target["image_id"] = torch.tensor([idx])
        target["masks"] = masks
        target["boxes"] = boxes
        target["area"] = area
        target["labels"] = label_tensor
        # One entry per box; none of the lesions is a crowd annotation
        target["iscrowd"] = torch.zeros((boxes.shape[0],), dtype=torch.int64)

        return resize_dict["image"], target

In [8]:
def get_transform(train=False):
    """Return the training augmentation pipeline, or None when not training.

    For a truthy ``train`` this is a ``torch.nn.Sequential`` containing a single
    ``RandomHorizontalFlip(p=0.5)``.
    """
    if not train:
        return None
    return torch.nn.Sequential(T.RandomHorizontalFlip(p=0.5))
    

In [9]:

# Train/validation and test annotations live in the same preprocessed folder
traindatadir = testdatadir = r'B:\Astrid\Preprocessed\220615_preprocessed\realAll'

trainval_data = CustomImageDatasetMonai(traindatadir, 'annotations_train_real.csv', 'colored_to_jpg', 'mask_to_png', train=True)
test_data = CustomImageDatasetMonai(testdatadir, 'annotations_test_real.csv', 'colored_to_jpg', 'mask_to_png')

# Folder intended for saved models (only used by the commented-out checkpoint code)
savedir = r'B:\Astrid\Preprocessed\ModelsAll'

In [10]:
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import utils
from torch.utils.data import random_split

# Split the training CSV 80/20 into training and validation subsets with a fixed seed,
# so the split is identical on every run. The validation size is the remainder, which
# guarantees the lengths sum to len(trainval_data) as random_split requires.
n_train = int(np.floor(len(trainval_data) * 0.8))
n_val = len(trainval_data) - n_train
train_data, val_data = random_split(trainval_data, [n_train, n_val], generator=torch.Generator().manual_seed(0))
# NOTE(review): val_data is a view on trainval_data (train=True), so validation samples
# are also randomly flipped; consider a train=False copy with the same indices.

train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True, collate_fn=utils.collate_fn)

# Evaluation loaders keep a fixed order so their batches are reproducible between runs
val_dataloader = DataLoader(val_data, batch_size=2, shuffle=False, collate_fn=utils.collate_fn)
test_dataloader = DataLoader(test_data, batch_size=2, shuffle=False, collate_fn=utils.collate_fn)


In [11]:
# Number of samples in the training, validation and test splits
tuple(map(len, (train_data, val_data, test_data)))

In [12]:
# Persistent iterator: each next(train_iter) call below returns the following training batch
train_iter = iter(train_dataloader)

In [13]:
# For Training
train_images, train_targets = next(train_iter)
train_image_list = list(image for image in train_images)
train_target_list = [{k: v.to('cpu') for k, v in t.items()} for t in train_targets]

# For Validation
val_images, val_targets = next(iter(val_dataloader))
val_image_list = list(image for image in val_images)
val_target_list = [{k: v for k, v in t.items()} for t in val_targets]

# For Testing
test_images, test_targets = next(iter(test_dataloader))
test_image_list = list(image for image in test_images)
test_target_list = [{k: v for k, v in t.items()} for t in test_targets]

In [22]:
# Split the first training image into four corner crops plus a centre crop, each half size
first_image = train_images[0]
half_size = (int(first_image.shape[1]/2), int(first_image.shape[2]/2))
top_left, top_right, bottom_left, bottom_right, center = T.FiveCrop(size=half_size)(first_image)

In [71]:
fig, ax = plt.subplots(3,3)

# Show channel 0 of each crop at the matching position of a 3x3 grid
crop_positions = {(0, 0): top_left, (0, 2): top_right, (1, 1): center,
                  (2, 0): bottom_left, (2, 2): bottom_right}
for (grid_row, grid_col), crop in crop_positions.items():
    ax[grid_row][grid_col].imshow(crop[0])


In [123]:
# Scratch check of element counting on a 3 x 4 x 5 tensor
t = torch.empty((3, 4, 5))
t.numel()

In [121]:
# Removed `foobar.count(0)`: `foobar` is not defined anywhere in this notebook,
# so the cell raised NameError on Restart & Run All.

In [16]:
# Display the validation targets grabbed above
val_target_list

In [17]:
# Display the test targets grabbed above
test_target_list

In [74]:
# image_id of every sample in the batch (batch_size is 2, so don't assume 4 targets)
tuple(target['image_id'] for target in train_targets)

In [75]:
# Print the boxes of every sample in the batch (iterating avoids IndexError for batch_size 2)
for target in train_targets :
    print(target['boxes'])

In [76]:
# Inspect the full target dict of the first training sample
train_targets[0]

In [187]:
# Shape of the first training image
train_images[0].shape

In [243]:
# Per-channel minimum of the first training image, all as Python floats
tuple(torch.min(channel).item() for channel in train_images[0])

In [68]:
pat_idx = 0
train_sample_image = train_images[pat_idx]
train_sample_target = train_targets[pat_idx]
print(train_sample_target['image_id'], train_sample_target['boxes'])

fig, ax = plt.subplots(1,4, figsize=(10,15))
ax[0].imshow(train_sample_image[0])

# A patch without a lesion has no boxes or masks; then only the image panel is drawn
if len(train_sample_target['boxes']) > 0:
    ax[1].imshow(train_sample_target['masks'][0])

    # Zoom on the first box with a margin. Boxes are [x1, y1, x2, y2] with x = column, y = row.
    # Clamp starts at 0: a negative slice start wraps around and gives an empty or wrong crop.
    crop_margin = 50
    x1, y1, x2, y2 = (int(coord) for coord in train_sample_target['boxes'][0])
    enlarged_box = [max(y1 - crop_margin, 0), y2 + crop_margin,
                    max(x1 - crop_margin, 0), x2 + crop_margin]
    print(enlarged_box)
    zoom_rows = slice(enlarged_box[0], enlarged_box[1])
    zoom_cols = slice(enlarged_box[2], enlarged_box[3])
    ax[2].imshow(train_sample_image[0, zoom_rows, zoom_cols])
    ax[3].imshow(train_sample_target['masks'][0, zoom_rows, zoom_cols])

In [51]:
pat_idx = 2
val_sample_image = val_images[pat_idx]
val_sample_target = val_targets[pat_idx]
print(val_sample_target['image_id'], val_sample_target['boxes'])

fig, ax = plt.subplots(1,4, figsize=(10,15))
ax[0].imshow(val_sample_image[0])

# A patch without a lesion has no boxes or masks; then only the image panel is drawn
if len(val_sample_target['boxes']) > 0:
    ax[1].imshow(val_sample_target['masks'][0])

    # Zoom on the first box with a margin. Boxes are [x1, y1, x2, y2] with x = column, y = row.
    # Clamp starts at 0: a negative slice start wraps around and gives an empty or wrong crop.
    crop_margin = 50
    x1, y1, x2, y2 = (int(coord) for coord in val_sample_target['boxes'][0])
    enlarged_box = [max(y1 - crop_margin, 0), y2 + crop_margin,
                    max(x1 - crop_margin, 0), x2 + crop_margin]
    print(enlarged_box)
    zoom_rows = slice(enlarged_box[0], enlarged_box[1])
    zoom_cols = slice(enlarged_box[2], enlarged_box[3])
    ax[2].imshow(val_sample_image[0, zoom_rows, zoom_cols])
    ax[3].imshow(val_sample_target['masks'][0, zoom_rows, zoom_cols])

## Define model

In [14]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# Background + benign + malignant (dataset labels are the CSV label + 1)
num_classes = 3

# Box and mask RoI poolers that only sample feature map '0'
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], output_size=7, sampling_ratio=2)
mask_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], output_size=14, sampling_ratio=2)

# Instance segmentation model pre-trained on COCO; min_size/max_size bound its internal resize
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    pretrained=True,
    box_roi_pool=roi_pooler,
    mask_roi_pool=mask_roi_pooler,
    min_size=400,
    max_size=650,
    box_fg_iou_thresh=0.5,
    box_bg_iou_thresh=0.5,
)

# Swap the COCO box head for one that predicts num_classes classes
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_channels=in_features, num_classes=num_classes)

# Swap the mask head the same way, with a 256-channel hidden layer
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_channels=in_features_mask,
                                                   dim_reduced=hidden_layer,
                                                   num_classes=num_classes)


In [15]:
# Move the model to the selected device; the returned module is displayed as the cell output
model.to(device)

In [16]:
# Tensorboard tag prefix for each loss meter reported by Mask R-CNN
LOSS_TAGS = {
    "loss": "Loss",
    "loss_classifier": "Loss classifier",
    "loss_box_reg": "Loss box reg",
    "loss_mask": "Loss mask",
    "loss_objectness": "Loss objectness",
    "loss_rpn_box_reg": "Loss rpn box reg",
}


def log_epoch_losses(writer, metric_logger, split, epoch):
    """Write the epoch average of every known loss meter as ``'<tag>/<split>'`` at step ``epoch``.

    Meters that were never updated (e.g. an empty loader) are skipped.
    """
    for meter_name, tag in LOSS_TAGS.items():
        if meter_name in metric_logger.meters:
            writer.add_scalar(f"{tag}/{split}", metric_logger.meters[meter_name].global_avg, epoch)


def train_one_epoch_kulum(model, optimizer, train_data_loader, val_data_loader, writer, device, epoch, print_freq, scaler=None):
    """Train ``model`` for one epoch, then compute its losses on the validation loader.

    Epoch-averaged train and validation losses are written to ``writer`` at global
    step ``epoch``. In epoch 0 the learning rate is warmed up linearly from
    1/1000 of its value.

    Returns ``(metric_logger, metric_logger_val)``.
    Raises RuntimeError if a training loss is not finite.
    """
    # Initialise training
    model.train()

    # Set up logger to save metrics and losses
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"

    # Linear learning-rate warm-up during the first epoch
    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1.0 / 1000
        # At least one step: with total_iters=0 LinearLR never leaves start_factor
        warmup_iters = max(1, min(1000, len(train_data_loader) - 1))

        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    for images, targets in metric_logger.log_every(train_data_loader, print_freq, header):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        # Compute the losses of the model on these training images and targets
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

        # Reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print(loss_dict_reduced)
            raise RuntimeError(f"Loss is {loss_value}, stopping training")

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # One point per epoch (previously every iteration was written at the same step)
    log_epoch_losses(writer, metric_logger, "train", epoch)

    # Set up logger to save metrics and losses of validation
    metric_logger_val = utils.MetricLogger(delimiter="  ")
    metric_logger_val.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))

    # torchvision detection models return losses only in train mode, so the model stays in
    # train mode here; no_grad avoids building autograd graphs for the validation pass.
    with torch.no_grad():
        for images_val, targets_val in metric_logger_val.log_every(val_data_loader, print_freq, header):
            images_val = [image.to(device) for image in images_val]
            targets_val = [{k: v.to(device) for k, v in t.items()} for t in targets_val]
            with torch.cuda.amp.autocast(enabled=scaler is not None):
                loss_dict_val = model(images_val, targets_val)

            # Reduce losses over all GPUs for logging purposes
            loss_dict_reduced_val = utils.reduce_dict(loss_dict_val)
            losses_reduced_val = sum(loss for loss in loss_dict_reduced_val.values())

            metric_logger_val.update(loss=losses_reduced_val, **loss_dict_reduced_val)
            metric_logger_val.update(lr=optimizer.param_groups[0]["lr"])

    log_epoch_losses(writer, metric_logger_val, "val", epoch)

    return metric_logger, metric_logger_val

def early_stopping_kulum(prev_loss, curr_loss, num_epochs):
    """Return the updated count of consecutive epochs without improvement.

    ``num_epochs`` is the current counter. It is incremented when ``curr_loss``
    is strictly larger than ``prev_loss`` and reset to 0 otherwise.
    """
    got_worse = curr_loss > prev_loss
    return num_epochs + 1 if got_worse else 0


In [17]:
from engine import train_one_epoch, evaluate

# Construct the optimizer over the trainable parameters only.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                            momentum=0.9, weight_decay=0.0005)
# Optional learning-rate scheduler (currently disabled):
# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

# TensorBoard run directory for this experiment.
model_writer = SummaryWriter('B:\\Astrid\\Preprocessed\\runs\\102_Monai_resnet_benmal_all')

num_epochs = 100
checkpoint_every = 10        # save a checkpoint every N epochs (the final epoch is always saved)
use_early_stopping = False   # set True to enable the patience-based stop below
early_stop_start = 50        # early stopping is only considered after this epoch
early_stop_patience = 5      # stop once the loss has been worse this many epochs in a row

logs = []
logs_val = []
# Best validation loss seen so far. It must start at +inf: starting at 0.0 meant
# min(min_val_loss, loss) could never record a real (positive) loss.
min_val_loss = float('inf')
stop_epoch = 0


def _save_checkpoint(epoch):
    """Save model and optimizer state for `epoch` into `savedir`.

    NOTE(review): `savedir` is expected to be defined in an earlier cell.
    """
    torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
                }, os.path.join(savedir, 'model102_Monai_clipval_resnet_benmal_all_'+str(epoch)+'.pth'))


try:
    for epoch in range(num_epochs):
        print('Epoch', epoch)
        # Train for one epoch (printing every 10 iterations) and compute the validation losses.
        epoch_loss, epoch_loss_val = train_one_epoch_kulum(model, optimizer, train_dataloader, val_dataloader, model_writer, device, epoch, print_freq=10)
        logs.append(epoch_loss)
        logs_val.append(epoch_loss_val)

        stop_now = False
        if use_early_stopping and epoch > early_stop_start:
            # Epoch-mean validation loss (assumes utils.SmoothedValue provides
            # `global_avg` as in torchvision's detection references; `.value` would
            # only be the last batch's loss).
            val_loss = epoch_loss_val.loss.global_avg
            stop_epoch = early_stopping_kulum(min_val_loss, val_loss, stop_epoch)
            if stop_epoch > early_stop_patience:
                print(f'Early stopping at epoch {epoch} with loss {val_loss}.')
                stop_now = True
            else:
                min_val_loss = min(min_val_loss, val_loss)

        # Periodic checkpoints, plus the last epoch that was actually trained.
        if stop_now or epoch % checkpoint_every == 0 or epoch == num_epochs - 1:
            _save_checkpoint(epoch)
        if stop_now:
            break
finally:
    # Always flush/close the TensorBoard writer, even if training raises.
    model_writer.close()

In [21]:
# Persist the final model twice: its bare state dict, and the whole pickled module.
# NOTE(review): these file names say "100" while the training run above writes
# "102" checkpoints — confirm the intended experiment id.
final_artifacts = {
    'dict100_Monai_resnet_benmal.dict': model.state_dict(),
    'model100_Monai_resnet_benmal_final.pth': model,
}
for artifact_name, artifact in final_artifacts.items():
    torch.save(artifact, os.path.join(savedir, artifact_name))

In [16]:
# Evaluate the current model on the validation loader with `engine.evaluate`
# (presumably COCO-style metrics; pycocotools is imported at the top).
evaluate(model, val_dataloader, device=device)

In [29]:
# Read the synthetic test dataset: annotations .csv, images in 'colored_to_jpg',
# masks in 'mask_to_png' (all relative to the dataset root).
testsynth_data = CustomImageDataset('B:\\Astrid\\Preprocessed\\TestCalcClusterSynthetic', 'annotations_test_calccluster_synthetic.csv', 'colored_to_jpg', 'mask_to_png')
# Evaluation gains nothing from shuffling; a fixed order keeps runs reproducible.
testsynth_dataloader = DataLoader(testsynth_data, batch_size=4, shuffle=False, collate_fn=utils.collate_fn)


In [30]:
# Evaluate the current model on the synthetic test loader with `engine.evaluate`.
evaluate(model, testsynth_dataloader, device=device)

In [296]:
# Copy every training image tensor onto the selected device.
train_images_cuda = [image.to(device) for image in train_images]
    

In [32]:
# Copy every validation image tensor onto the selected device.
val_images_cuda = [image.to(device) for image in val_images]

In [44]:
# Release unused memory held by PyTorch's CUDA caching allocator
# (tensors that are still referenced are not freed by this).
torch.cuda.empty_cache()

In [22]:
# Copy every test image tensor onto the selected device.
test_images_cuda = [image.to(device) for image in test_images]

In [297]:
# Run inference on the device-resident training images.
# torchvision detection models only return predictions in eval mode (in train
# mode they require targets and return losses), and no_grad avoids building an
# autograd graph for a pure forward pass.
model.eval()
with torch.no_grad():
    predictions = model(train_images_cuda)

In [298]:
# Display the predictions computed in an earlier cell.
predictions

In [299]:
# Display the ground-truth training targets (presumably to compare with `predictions`).
train_targets

In [289]:
# Collect one value per epoch for each loss component from the per-epoch
# training MetricLoggers stored in `logs`.
# Uses `.global_avg` (mean over all batches of the epoch) rather than `.value`
# (only the last batch), assuming utils.SmoothedValue follows torchvision's
# detection reference implementation.
def _epoch_means(epoch_logs, meter_name):
    """Return the epoch-mean of meter `meter_name` for every logger in `epoch_logs`."""
    return [getattr(epoch_log, meter_name).global_avg for epoch_log in epoch_logs]


losses = _epoch_means(logs, 'loss')
losses_class = _epoch_means(logs, 'loss_classifier')
losses_box_reg = _epoch_means(logs, 'loss_box_reg')
losses_mask = _epoch_means(logs, 'loss_mask')
losses_objectness = _epoch_means(logs, 'loss_objectness')
losses_rpn_box_reg = _epoch_means(logs, 'loss_rpn_box_reg')

In [290]:
# Plot the per-epoch curve of each loss component on a 2x3 grid (row-major order).
epoch_axis = range(len(logs))
loss_panels = [
    (losses, 'Loss'),
    (losses_class, 'Loss_classifier'),
    (losses_box_reg, 'Loss_box_reg'),
    (losses_mask, 'Loss_mask'),
    (losses_objectness, 'Loss_objectness'),
    (losses_rpn_box_reg, 'Loss_rpn_box_reg'),
]

fig, ax = plt.subplots(2, 3, figsize=(15, 10))
for panel_ax, (values, title) in zip(ax.flat, loss_panels):
    panel_ax.plot(epoch_axis, values)
    panel_ax.set(title=title, xlabel='Epoch', ylabel='Loss')
fig.tight_layout()
plt.show()

In [52]:
# Largest value in the 'masks' tensor of the third prediction (index 2).
# NOTE(review): the Faster R-CNN models built elsewhere in this notebook presumably
# produce no 'masks' key — confirm `predictions` came from a Mask R-CNN model.
torch.max(predictions[2]['masks'])

In [None]:
# torch.save({
#             'epoch': epoch,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'loss': loss,
#             ...
#             }, PATH)

In [33]:
# Move the model's parameters and buffers to the CPU (presumably to free GPU memory).
model = model.cpu()

In [7]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# load a model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Number of output classes for the new box predictor, including background:
# 2 classes (lesion benign + lesion malignant) + background.
# (For a single "lesion" class + background, use num_classes = 2 instead.)
num_classes = 3

# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained COCO head with a new one sized for num_classes
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

In [13]:
# Display the predictions computed in an earlier cell.
predictions

In [9]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


def get_model_instance_segmentation(num_classes, hidden_layer=256, pretrained=True):
    """Build a Mask R-CNN (ResNet-50 FPN) with new box and mask heads for `num_classes`.

    Both the box predictor and the mask predictor of the COCO model are replaced,
    so that their outputs match `num_classes`.

    Args:
        num_classes: Number of output classes, including the background class.
        hidden_layer: Number of channels in the hidden layer of the new mask
            predictor (default 256, as before).
        pretrained: Passed to ``maskrcnn_resnet50_fpn``; True (default) loads the
            COCO-pretrained weights.

    Returns:
        The modified torchvision Mask R-CNN model.
    """
    # load an instance segmentation model (COCO-pretrained when `pretrained` is True)
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=pretrained)

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained box head with one sized for num_classes
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    # and replace the mask predictor with one sized for num_classes
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model

In [10]:
import transforms as T

def get_transform(train):
    """Build the transform pipeline from the detection `transforms` module.

    The input is always converted to a tensor. When `train` is truthy, a random
    horizontal flip with probability 0.5 is appended for augmentation.

    NOTE(review): `T` is the local detection `transforms` module imported in this
    cell, which rebinds the `T` alias that earlier referred to
    `torchvision.transforms`.
    """
    pipeline = [T.ToTensor()]
    if train:
        pipeline.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(pipeline)

In [11]:
import utils as vision_utils

In [26]:
# Demo of the torchvision detection API on a fresh COCO-pretrained Faster R-CNN.
# NOTE(review): this rebinds `model`, replacing any model built or trained in
# earlier cells; `image_list` / `target_list` are not defined in this part of the
# notebook — confirm they exist before running.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.to(device)
# Training mode (the default): passing images and targets returns a dict of losses.
output = model(image_list, target_list)

# For inference
model.eval()
# Inputs must live on the same device as the model, otherwise the forward pass fails.
x = [torch.rand(3, 300, 400).to(device), torch.rand(3, 500, 400).to(device)]
with torch.no_grad():
    predictions = model(x)           # eval mode: one prediction dict per image

In [13]:
# Display the predictions computed in an earlier cell.
predictions

In [14]:
# Display the result of the training-mode forward pass computed in an earlier cell.
output