In [1]:
import pandas as pd
from pathlib import Path
import pathlib
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from pathlib import Path
from shapely.geometry import Polygon
import pdb
import  cv2
import PIL

In [2]:
# Folder (relative to the working directory) holding the unpacked Windows OpenSlide build.
OPENSLIDE_PATH = Path('openslide-win64-20230414')

# The binaries sit in a nested folder of the same name; the DLLs are in its `bin` subfolder.
OPENSLIDE_FOLDER = os.path.join(Path().resolve(), OPENSLIDE_PATH, OPENSLIDE_PATH.name, 'bin')

In [3]:
# os.add_dll_directory only exists on Windows: there the OpenSlide DLL folder must be on
# the DLL search path while the module is imported; elsewhere a plain import is used.
if not hasattr(os, 'add_dll_directory'):
    import openslide
else:
    with os.add_dll_directory(OPENSLIDE_FOLDER):
        import openslide
    
    

In [4]:
openslide  # display the module repr (shows the file path of the imported openslide package)

<module 'openslide' from 'A:\\mini\\envs\\aipath\\Lib\\site-packages\\openslide\\__init__.py'>

In [5]:

from dataloading.src.annotation_parser import AnnotationParser


In [11]:
from dataloading.src.annotation_parser import AnnotationParser
from dataloading.src.wsi_datasets_tst import WSI_Pyramid

In [5]:
#%%writefile src/annotation_parser.py
from __future__ import annotations
from shapely.geometry import Polygon
import pandas as pd
from pathlib import Path
import pathlib
import os
import numpy as np
from typing import Tuple
import json
import pdb







    
class AnnotationParser():
    """ Parses WSI Files and associated GeoJson annotations to create a composite data frame
         from both.Also returns the annotations as line items for which th eparsing didn't work 
         Requires folder containing WSI files and Annotations"""
    
    def __init__(self,image_path:pathlib.Path,labels_path:pathlib.Path)->None:
        self.image_path=image_path
        self.labels_path=labels_path
        
        
        
    
    
    
    def get_img_df(self)->pd.DataFrame:
        """use openslide to get properties,shape,levels etc.for each image"""
        img_paths=[self.image_path/fn for fn in os.listdir(self.image_path)  if (self.image_path/fn).suffix == '.tiff']
        img_df=pd.DataFrame({'image_path':img_paths})
        img_df=img_df.assign(image_name=[img_path.stem for img_path in img_paths],
                             WSI_size=[openslide.OpenSlide(img_path).dimensions for img_path in img_paths],
                      levels=[openslide.OpenSlide(img_path).level_count for img_path in img_paths],
                    downsample_levels=[{level:downsample for level,downsample in enumerate(openslide.OpenSlide(img_path).level_downsamples)} 
                           for img_path in img_paths])
        return img_df


    
    def get_coordinates_array(self,anno_row:pd.Series)->np.ndarray:
        """parse the annotation json (as a series of rows to get coordinates
           of the annotation as an array of dim n_points X 2 """
        
       
        
        try:
            geom=anno_row['geometry']
            coord_list=geom['coordinates']
            geom_type=geom['type']

             ### last 2 dimensions of every poly are n_points X 2
            # for multipolygon the list of coordinates i nested 1 level deep

             ## to get largest polygon if more than one get marked by mistake in  every annotation
            
            if geom_type=='Polygon':
                largest_poly=max([np.array(poly).reshape((np.array(poly).shape)[-2:]) for poly in coord_list],key=len)

            if geom_type=='MultiPolygon':
                largest_poly= max([max([np.array(poly).reshape((np.array(poly).shape)[-2:]) for poly in coords],key=len) 
                                   for coords in coord_list],key=len)

            return largest_poly

        
        except KeyError:
            return 'coordinates_key_error'
        
       
    
    def get_class_name_and_color(self,anno_row:pd.Series)->Tuple[str,int]:
        """Parse annotation to return class name for each anno and color assigned in QuPath"""
        
     
        
        try:
            classification=anno_row['properties']['classification']
            class_name=classification['name']
            class_color=classification['color']
            return class_name,np.array(class_color)
        
        except KeyError:
            
            
            return 'class_name_error',0
        
            
        
    
    def parse_json_file(self,json_path:pathlib.Path):
        # need to be list of dicts (anno_list)
        with open(json_path) as json_file:
            anno_list = json.load(json_file)
            if not isinstance(anno_list, list):
                anno_list=[anno_list]
            
        
        anno_df=pd.DataFrame(anno_list)
        anno_df=anno_df.assign(image_name=json_path.stem)
        #pdb.set_trace()
        return anno_df

        
    
    def get_anno_df(self,img_df:pd.DataFrame)->Tuple[pd.DataFrame,pd.DataFrame]:
        
        """ get annotation df including annotation for which parsing did not work (errored)"""
        
        jsons=[self.labels_path/fn for fn in os.listdir(self.labels_path)  if (self.labels_path/fn).suffix == '.geojson']
        anno_df=pd.concat([self.parse_json_file(json_path) for json_path in jsons])
        #pdb.set_trace()
        #anno_df=pd.concat([pd.read_json(json,orient='records').assign(image_name=json.stem) for json in jsons],ignore_index=True)
        
        ## enrich with image specific attrs, total size,zoom levels etc.
        anno_df=anno_df.merge(img_df,on='image_name')
        ## add coordinates as np array and shapely polygon with transormed origin 
        anno_df['coordinates']=anno_df.apply(self.get_coordinates_array,1)
        #pdb.set_trace()
        anno_df['class_name'], anno_df['colour_RGB']=zip(*anno_df.apply(self.get_class_name_and_color,1))
        
        ## select the annos with errored coordinates (due to annotation issues)
        errored=np.logical_or(anno_df['coordinates'].isin(['coordinates_key_error']),
                              anno_df['class_name']=='class_name_error')
        
        errored_df=anno_df[errored]
        anno_df=anno_df[~errored]
        
        #pdb.set_trace()
        ## use shapely to compute polygon attrs
        anno_df['polygon']=anno_df.apply(lambda x:Polygon(x['coordinates']),1)
        anno_df['area']=anno_df.apply(lambda x:x['polygon'].area,1)
        anno_df['circumference']=anno_df.apply(lambda x:x['polygon'].length,1)
        anno_df['bounds']=anno_df.apply(lambda x:np.array(x['polygon'].bounds).reshape((2,2)),1)
       
        #pdb.set_trace()
        
        return anno_df,errored_df
    
    
    
    def parse_annotations(self)->Tuple[pd.DataFrame,pd.DataFrame,pd.DataFrame]:
        """ returns anno_df,img_df and errored df in that order """ 
        img_df=self.get_img_df()
        anno_df,errored_df=self.get_anno_df(img_df)
        
        
        return anno_df,img_df,errored_df
        


In [10]:
#%%writefile src/wsi_datasets.py
from __future__ import annotations
from torch.utils.data import Dataset,DataLoader
from shapely.geometry import Polygon,MultiPolygon
import  cv2
import numpy as np
import pandas as pd
from functools import partial
import torch
from pathlib import Path
import os
import pathlib


    


class WSI_Pyramid(Dataset):
    """PyTorch Dataset of multiscale (pyramidal) crops sampled around WSI annotations.

    ``anno_df`` has one row per annotation, already joined with its slide's info
    (``image_path``, ``downsample_levels``, ...). For each item, a random vertex of the
    indexed annotation's outline becomes the crop centre. ``num_pyramid_levels``
    concentric crops of the same pixel size ``crop_pixel_size`` are then read from
    consecutive pyramid levels, starting at ``pyramid_top_level`` (usually the most
    zoomed-in level, downsample 1). Class masks are rasterised for the first
    ``num_pyramid_mask_levels`` levels.

    ``crop_pixel_size`` uses openslide's ``(width, height)`` order because it is passed
    as the ``size`` of ``read_region``.
    """

    def __init__(self,
                 anno_df: pd.DataFrame,
                 crop_pixel_size: tuple = (512, 512),
                 transform=None,
                 class2num={'Background': 0, 'Tumor': 1},
                 ## default set to show all levels, mostly {0:1} top level is picked
                 pyramid_top_level={0: 1.0},
                 num_pyramid_levels=4,
                 num_pyramid_mask_levels=1,
                 filter_flag=False) -> None:
        # NOTE: the dict defaults are only read, never mutated, so sharing them is safe.
        self.anno_df = anno_df
        # pixel size of each crop; identical at every zoom level so crops can be batched
        self.crop_pixel_size = crop_pixel_size
        # NOTE(review): stored but not applied in __getitem__
        self.item_transform = transform
        self.class2num = class2num
        # half the crop size: offsets from a crop centre to its corners, (x, y) order
        self.offsets = torch.tensor(self.crop_pixel_size) // 2
        # the downsample dict of the slide with the fewest levels is used for all slides
        self.common_downsample_levels = min(self.anno_df['downsample_levels'], key=len)
        self.pyramid_top_level = pyramid_top_level
        self.pyramid_all_levels = max(self.anno_df['downsample_levels'], key=len)

        self.pyramid_top_idx = next(iter(self.pyramid_top_level))
        self.pyramid_top_downsample = next(iter(self.pyramid_top_level.values()))

        self.num_pyramid_levels = num_pyramid_levels
        self.num_pyramid_mask_levels = num_pyramid_mask_levels
        assert self.num_pyramid_levels >= self.num_pyramid_mask_levels, 'num_pyramid_levels used for inputs should be more than num of target maska'

        # level index -> downsample factor of the tiff levels used as model input
        # (raises KeyError if the shallowest slide pyramid has too few levels)
        self.pyramid_zoom_levels = {
            idx: self.common_downsample_levels[idx]
            for idx in range(self.pyramid_top_idx, self.pyramid_top_idx + self.num_pyramid_levels)
        }
        self.filter_flag = filter_flag

    def __len__(self):
        return len(self.anno_df)

    def get_pyramid_crops(self, annotation_row: pd.Series):
        """Sample concentric crops centred on a random vertex of the annotation outline.

        Returns ``(pyramid_crops, pyramid_top_lefts)``. ``pyramid_crops`` is an int32 array
        of shape ``(num_pyramid_levels, 4, 2)`` with each level's crop corners in level-0
        ``(x, y)`` coordinates. ``pyramid_top_lefts`` (``(num_pyramid_levels, 2)``) holds
        each level's minimum corner.
        """
        anno_coordinates = np.array(annotation_row['coordinates'])
        random_crop_index = np.random.randint(0, len(anno_coordinates))
        random_crop_center = torch.tensor(anno_coordinates[random_crop_index])
        # corner directions scaled by half the crop size -> 4 x 2
        offsets_arr = torch.tensor([[-1, 1], [1, 1], [1, -1], [-1, -1]]) * self.offsets
        # (levels, 1, 1): each level's corner offsets are scaled by its downsample factor
        downsample_arr = torch.tensor(list(self.pyramid_zoom_levels.values())).view(-1, 1, 1)
        pyramid_crops = offsets_arr.unsqueeze(0) * downsample_arr + random_crop_center.view(1, 1, -1)
        pyramid_crops = np.array(pyramid_crops).astype(np.int32)
        pyramid_top_lefts = pyramid_crops.min(axis=1)
        return pyramid_crops, pyramid_top_lefts

    def filter_crops_byWSIsize(self, wsi_size: tuple, all_crops: np.ndarray):
        """Keep only the crops whose corners all lie inside the slide.

        ``all_crops`` has shape ``(n_crops, n_corners, 2)`` in level-0 ``(x, y)`` coordinates
        and may be a numpy array or a torch tensor. ``wsi_size`` is the level-0
        ``(width, height)``. The maximum corner is top-left + crop extent, i.e. an exclusive
        end, so corners may touch ``0`` and ``wsi_size``. Returns the filtered crops in the
        input's container type (possibly empty).
        """
        corners = np.asarray(all_crops)
        bounds = np.asarray(wsi_size).reshape(1, -1)
        inside = ((corners.max(axis=1) <= bounds).all(axis=1)
                  & (corners.min(axis=1) >= 0).all(axis=1))
        if isinstance(all_crops, torch.Tensor):
            inside = torch.from_numpy(inside)
        return all_crops[inside]

    def get_mask_per_class(self, class_annotation_data: pd.DataFrame, crop: np.ndarray,
                           downsample_factor: float) -> torch.Tensor:
        """Rasterise one class's annotations that intersect ``crop`` into a mask.

        ``class_annotation_data`` holds annotations of a single class; its first
        ``class_name`` selects the fill value through ``class2num``. ``crop`` holds the
        crop's ``(4, 2)`` corners in level-0 ``(x, y)`` coordinates. ``downsample_factor``
        maps level-0 offsets to crop pixels.

        Returns a ``uint8`` tensor of shape ``(height, width)``: ``class2num[class_name]``
        inside the convex hulls of the polygonal intersections, 0 elsewhere.
        """
        annotation_num = self.class2num[class_annotation_data['class_name'].iloc[0]]
        # minimum x and y of the crop: the image origin is its top-left corner
        top_left = crop.min(axis=0)
        crop_poly = Polygon(crop)

        intersects = []
        for poly in class_annotation_data['polygon']:
            if not crop_poly.intersects(poly):
                continue
            intersect = crop_poly.intersection(poly)
            if isinstance(intersect, Polygon):
                parts = [intersect]
            elif hasattr(intersect, 'geoms'):
                # MultiPolygon or GeometryCollection: keep only the polygonal members
                parts = [geom for geom in intersect.geoms if isinstance(geom, Polygon)]
            else:
                # lines / points produced by touching geometries have nothing to fill
                continue
            for part in parts:
                hull = part.convex_hull
                if not isinstance(hull, Polygon):
                    # degenerate (zero-area) part: its hull is a line or point
                    continue
                ext_coords = np.array(hull.exterior.coords)
                intersects.append(((ext_coords - top_left) // downsample_factor).astype(np.int32))

        # numpy masks are (rows, cols) = (height, width), crop_pixel_size is (width, height)
        width, height = self.crop_pixel_size
        mask = np.zeros((height, width), dtype=np.uint8)
        cv2.fillPoly(mask, intersects, color=annotation_num)
        return torch.tensor(mask, dtype=torch.uint8)

    def read_slide_region(self, slide_obj: 'openslide.OpenSlide', top_left: np.ndarray,
                          level: int):
        """Return the RGBA PIL image of size ``crop_pixel_size`` read at pyramid ``level``.

        ``top_left`` is given in level-0 ``(x, y)`` coordinates.
        """
        # plain Python ints for openslide's C interface
        location = tuple(int(v) for v in np.asarray(top_left).astype(np.int32))
        return slide_obj.read_region(location, level, self.crop_pixel_size)

    def get_dl(self, batch_size, kind, shuffle=True):
        """Wrap the dataset in a DataLoader.

        Only ``kind == 'train'`` loaders shuffle, and only while ``shuffle`` is True, so
        validation loaders stay in a fixed order.
        """
        return DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle and kind == 'train')

    def get_img_T(self, pyramid_top_lefts: np.ndarray, image_path: pathlib.Path):
        """Read one crop per pyramid level and stack them along the channel axis.

        Returns a tensor of shape ``(3 * num_pyramid_levels, height, width)``: the alpha
        channel is dropped and levels follow ``pyramid_zoom_levels`` order.
        """
        # open the slide once per item and always release the handle
        slide = openslide.OpenSlide(image_path)
        try:
            level_crops = [np.array(self.read_slide_region(slide, top_left, level))[:, :, :-1]
                           for level, top_left in zip(self.pyramid_zoom_levels, pyramid_top_lefts)]
        finally:
            slide.close()
        return torch.tensor(np.concatenate(level_crops, axis=2)).permute(2, 0, 1)

    def get_msk_T(self, pyramid_crops: np.ndarray, image_anno_data: pd.DataFrame):
        """Build one composite class mask per mask level.

        For each of the first ``num_pyramid_mask_levels`` crops, per-class masks are
        rasterised and merged with an element-wise max, so higher class numbers win where
        classes overlap. Returns a tensor of shape ``(num_pyramid_mask_levels, height, width)``.
        """
        downsample_factors = list(self.pyramid_zoom_levels.values())
        pyramid_msk = []
        for level_pos in range(self.num_pyramid_mask_levels):
            # iterate the groups directly (sorted by class name, like groupby().apply)
            class_masks = [self.get_mask_per_class(class_data,
                                                   crop=pyramid_crops[level_pos],
                                                   downsample_factor=downsample_factors[level_pos])
                           for _, class_data in image_anno_data.groupby('class_name')]
            pyramid_msk.append(torch.stack(class_masks, dim=0).max(dim=0).values)
        return torch.stack(pyramid_msk)

    def __getitem__(self, index):
        """Return ``(img_T, mask_T)`` for the annotation at position ``index``.

        The mask includes every annotation of the same slide, not only the indexed one.
        """
        anno_row = self.anno_df.iloc[index]
        # all annotations in the same image as the indexed annotation
        image_anno_data = self.anno_df[self.anno_df['image_name'] == anno_row['image_name']]
        pyramid_crops, pyramid_top_lefts = self.get_pyramid_crops(anno_row)
        img_T = self.get_img_T(pyramid_top_lefts, anno_row['image_path'])
        mask_T = self.get_msk_T(pyramid_crops, image_anno_data)
        return img_T, mask_T




class WSI_Inference(WSI_Pyramid):
    """PyTorch Dataset yielding pyramidal crops at precomputed tissue locations of one WSI.

    Each entry of ``wsi_tissue_locs`` is the level-0 ``(x, y)`` top-left corner of a
    top-level crop, as passed to openslide's ``read_region``; the locations come from
    background removal. Other keyword arguments (``anno_df``, ``crop_pixel_size``, pyramid
    settings, ...) are forwarded to ``WSI_Pyramid``, whose pyramid bookkeeping and
    slide-reading helpers are reused.
    """

    def __init__(self,
                 wsi_path: pathlib.Path,
                 wsi_tissue_locs: np.ndarray,
                 **kwargs) -> None:
        # init the pyramidal dataset (needs anno_df for the common downsample levels)
        super().__init__(**kwargs)
        self.wsi_path = wsi_path
        self.wsi_tissue_locs = wsi_tissue_locs

    def __len__(self):
        return len(self.wsi_tissue_locs)

    def __getitem__(self, index):
        """Return ``(top_left, img_T)``: the location as a tensor and the stacked pyramid crops."""
        top_left = torch.tensor(self.wsi_tissue_locs[index])
        downsample_arr = torch.tensor(list(self.pyramid_zoom_levels.values()))
        # the top-level crop spans crop_size * its downsample level-0 pixels, so its centre
        # is offset by half of that (identical to top_left + offsets when the downsample is 1)
        crop_center = top_left + self.offsets * downsample_arr[0]
        pyramid_top_lefts = crop_center - self.offsets.unsqueeze(0) * downsample_arr.unsqueeze(1)
        pyramid_top_lefts = np.array(pyramid_top_lefts).astype(np.int32)
        img_T = self.get_img_T(pyramid_top_lefts, self.wsi_path)
        # return locations and pyramidal images for inference
        return top_left, img_T

    def get_dl(self, batch_size, kind='inference', shuffle=False):
        """DataLoader over the tissue locations (in order unless ``shuffle`` is True).

        ``kind`` is accepted only for signature compatibility with ``WSI_Pyramid.get_dl``.
        """
        return DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle)
        
        

        
        

    
        
       
        
       
      

In [6]:
#%%writefile src/segmodule.py

import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
from torchmetrics import MetricCollection
from torchmetrics.classification import Accuracy,Precision, Recall,JaccardIndex,Dice



# Define your PyTorch Lightning module (inherits from pl.LightningModule)
class SegLightningModule(pl.LightningModule):
    """LightningModule wrapping a segmentation_models_pytorch model.

    Batches are ``(img_b, mask_b)``: ``img_b`` stacks the RGB crops of all pyramid levels
    along the channel axis (hence ``in_channels``), and the first mask level (highest zoom)
    is the target. The loss is class-weighted cross-entropy. Binary accuracy, precision,
    recall and IoU are logged per epoch with ``Train``/``Val`` prefixes.
    """

    def __init__(self, in_channels=12, num_classes=2, arch_name='UnetPlusPlus',
                 encoder_name='resnet34', crossentropy_weights=(3.0, 8.0), lr=1e-3, **kwargs):
        super().__init__()
        # store the init args in checkpoints so load_from_checkpoint can rebuild this exact
        # architecture (e.g. a non-default in_channels) instead of the defaults
        self.save_hyperparameters()

        self.train_metrics = self._make_metrics('Train')
        self.val_metrics = self._make_metrics('Val')

        arch = getattr(smp, arch_name)
        self.segmentation_model = arch(encoder_name=encoder_name, in_channels=in_channels,
                                       classes=num_classes)
        # the weight tensor is a buffer of the loss module, so it follows the model's device
        self.loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(crossentropy_weights))
        self.lr = lr

    @staticmethod
    def _make_metrics(prefix):
        """Binary pixel metrics whose logged names are prefixed with ``prefix``."""
        return MetricCollection(prefix=prefix, metrics=[
            Accuracy(task='binary'),
            Precision(task='binary'),
            Recall(task='binary'),
            JaccardIndex(task='binary')])

    def forward(self, x):
        return self.segmentation_model(x)

    def flattened_cross_entropy_loss(self, inp, tgt):
        """Cross-entropy after flattening the two spatial dims of logits and targets."""
        tgt = tgt.flatten(start_dim=-2)
        inp = inp.flatten(start_dim=-2)
        return self.loss_fn(inp, tgt)

    def _shared_step(self, batch, metrics, stage):
        """Compute the loss for one batch, update ``metrics`` and log epoch-level values."""
        img_b, mask_b = batch
        img_b = img_b.to(torch.float32)
        # the mask at the highest zoom (first mask level) is the target
        mask_b = mask_b[:, 0, :, :].to(torch.int64)

        y_pred_logits = self(img_b)
        loss = self.flattened_cross_entropy_loss(y_pred_logits, mask_b)
        y_pred = y_pred_logits.argmax(axis=1)

        # log the Metric objects themselves: Lightning computes them over the whole epoch
        # (averaging per-batch IoU/precision values would not equal the epoch metric)
        metrics.update(y_pred, mask_b)
        self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=False)
        self.log_dict({f'{stage}_loss': loss,
                       f'{stage}_pct_foreground': mask_b.float().mean()},
                      on_step=False, on_epoch=True, prog_bar=False)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, self.train_metrics, 'Train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, self.val_metrics, 'Val')

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)


class InferenceLightningModule(pl.LightningModule):
    """Run a trained ``SegLightningModule`` over batches of ``(top_left, img)`` crops.

    ``validation_step`` turns each batch's predictions into slide coordinates of the
    pixels predicted as non-background.
    """

    def __init__(self, seg_module: SegLightningModule, inference_fpath: pathlib.Path, **kwargs):
        # plain super(): this class does not inherit from SegLightningModule, so
        # super(SegLightningModule, self) raised TypeError
        super().__init__()
        self.seg_module = seg_module
        self.inference_fpath = inference_fpath

    def forward(self, x):
        return self.seg_module(x)

    def get_inference_schema(self, sample_json='inference_schema.geojson'):
        """Load the GeoJSON feature template ``sample_json`` stored in ``inference_fpath``."""
        json_path = self.inference_fpath / sample_json
        with open(json_path, encoding='utf-8') as json_file:
            return json.load(json_file)

    def validation_step(self, batch, batch_idx):
        """Return an ``(n, 2)`` long tensor of slide ``(x, y)`` coordinates of all pixels
        predicted as non-background in this batch (possibly empty).

        Assumes the predicted (top) pyramid level has downsample 1, so crop pixels map 1:1
        to level-0 pixels.
        """
        # the inference dataset yields (top_left, img_T), in that order
        top_left_b, img_b = batch
        y_pred = self(img_b.to(torch.float32)).argmax(axis=1)
        # argwhere gives (row, col) = (y, x); flip to (x, y) before adding the (x, y) top-left
        pred_coords = [torch.argwhere(pred).flip(-1) + top_left
                       for pred, top_left in zip(y_pred, top_left_b)]
        if not pred_coords:
            return torch.empty((0, 2), dtype=torch.long, device=y_pred.device)
        return torch.cat(pred_coords)
     
        
       
      
        


    

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
#%%writefile src/runner.py
from __future__ import annotations
from sklearn.model_selection import train_test_split
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import TQDMProgressBar,EarlyStopping

from pytorch_lightning.loggers import CSVLogger
from tqdm import tqdm






class experiment_runner():
    """Tie together annotation parsing, dataloaders, training, inference and plotting.

    Training data is read from ``root/labels`` (GeoJSON) and ``root/images`` (``.tiff``).
    Inference inputs are read from ``inference/images``, and predictions are written to
    ``inference/labels``.
    """

    def __init__(self, root=Path('training_data')):
        self.label_path = root / 'labels'
        self.image_path = root / 'images'

        self.inference_path = Path('inference')
        self.inference_image_path = self.inference_path / 'images'
        # path to store predictions (as json files)
        self.inference_label_path = self.inference_path / 'labels'

        self.parser = AnnotationParser(self.image_path, self.label_path)
        self.anno_df, self.img_df, self.errored = self.parser.parse_annotations()

        # precompute tissue locations for the training images. Series.map is used because
        # Series.apply's second positional parameter is the deprecated convert_dtype.
        self.img_df = self.img_df.assign(tissue_location=self.img_df['image_path'].map(self.preprocess))
        self.img_df['inference_len'] = self.img_df['tissue_location'].map(len)
        # device to perform inference on
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def preprocess(self, img_path: pathlib.Path, downsample_factor=512,
                   gray_background=np.array([236, 236, 236]), tol=0) -> np.ndarray:
        """Locate tissue on a slide by removing the uniform grey background.

        A thumbnail about ``downsample_factor`` times smaller than the slide is compared
        with ``gray_background`` (+/- ``tol`` per channel). Connected non-background regions
        that contain more than one distinct colour are kept as tissue.

        Returns an ``(n, 2)`` int64 array of level-0 ``(x, y)`` locations (the order
        openslide's ``read_region`` expects), one per tissue pixel of the thumbnail.
        The array is empty when no tissue is found.
        """
        slide = openslide.OpenSlide(img_path)
        try:
            # openslide dimensions and thumbnail sizes are (width, height)
            width, height = slide.dimensions
            thumbnail_img = np.array(slide.get_thumbnail((width // downsample_factor,
                                                          height // downsample_factor)))
        finally:
            slide.close()
        h, w, c = thumbnail_img.shape
        # the thumbnail keeps the aspect ratio, so use its actual scale on each axis
        scale_x, scale_y = width / w, height / h

        grey_mask = cv2.inRange(thumbnail_img, gray_background - tol, gray_background + tol)
        num_comps, labelled_mask = cv2.connectedComponents(~grey_mask)

        flat_pixels = thumbnail_img.reshape((h * w, c))
        tissue_rows_cols = []
        for label in range(1, num_comps):
            comp_mask = labelled_mask == label
            # regions of a single uniform colour are not treated as tissue
            if len(np.unique(flat_pixels[comp_mask.ravel()], axis=0)) > 1:
                tissue_rows_cols.append(np.argwhere(comp_mask))
        return self._thumbnail_to_slide_locs(tissue_rows_cols, scale_x, scale_y)

    @staticmethod
    def _thumbnail_to_slide_locs(rows_cols_list, scale_x, scale_y):
        """Convert lists of thumbnail ``(row, col)`` indices into level-0 ``(x, y)`` int64 locations."""
        if not rows_cols_list:
            return np.empty((0, 2), dtype=np.int64)
        rows_cols = np.concatenate(rows_cols_list)
        xs = (rows_cols[:, 1] * scale_x).astype(np.int64)
        ys = (rows_cols[:, 0] * scale_y).astype(np.int64)
        return np.stack([xs, ys], axis=1)

    def get_dls(self,
                batch_size=32,
                num_pyramid_levels=4,
                num_pyramid_mask_levels=1,
                crop_pixel_size=(128, 128),
                pyramid_top_level={0: 1.0},
                min_annotations_per_image=50):
        """Return ``(dl_train, dl_val)`` from an 80/20 split stratified by slide.

        Only slides with more than ``min_annotations_per_image`` annotations are kept.
        Stratification needs at least 2 annotations per slide, so keep this >= 1.
        """
        # per-row annotation count of the row's slide; unlike merging
        # value_counts().reset_index(), this does not depend on pandas' column naming
        anno_counts = self.anno_df['image_name'].map(self.anno_df['image_name'].value_counts())
        filt_df = (self.anno_df.assign(count=anno_counts)
                   .loc[lambda frame: frame['count'] > min_annotations_per_image]
                   .reset_index(drop=True))

        df_train, df_val = train_test_split(filt_df, test_size=0.2, random_state=42,
                                            stratify=filt_df['image_name'])
        ds_train, ds_val = (WSI_Pyramid(anno_df=df,
                                         crop_pixel_size=crop_pixel_size,
                                         pyramid_top_level=pyramid_top_level,
                                         num_pyramid_levels=num_pyramid_levels,
                                         num_pyramid_mask_levels=num_pyramid_mask_levels)
                            for df in (df_train, df_val))
        dl_train = ds_train.get_dl(batch_size=batch_size, kind='train')
        dl_val = ds_val.get_dl(batch_size=batch_size, kind='val')
        return dl_train, dl_val

    def get_inference_dl(self,
                         wsi_img_name,
                         batch_size=32,
                         **kwargs):
        """DataLoader over the tissue locations of ``inference/images/<wsi_img_name>``.

        Extra keyword arguments (``anno_df``, pyramid settings, ...) go to ``WSI_Inference``.
        """
        wsi_path = self.inference_image_path / wsi_img_name
        tissue_locs = self.preprocess(wsi_path)
        ds_inference = WSI_Inference(wsi_path=wsi_path, wsi_tissue_locs=tissue_locs, **kwargs)
        return ds_inference.get_dl(batch_size=batch_size)

    def get_inference_schema(self, sample_json='inference_schema.geojson'):
        """Load the GeoJSON feature template ``sample_json`` from the inference folder."""
        json_path = self.inference_path / sample_json
        with open(json_path, encoding='utf-8') as json_file:
            return json.load(json_file)

    @staticmethod
    def _pred_to_slide_coords(pred_b, top_left_b):
        """Level-0 ``(x, y)`` coordinates of all non-zero pixels in a batch of predicted crops.

        ``pred_b`` is ``(batch, height, width)`` and ``top_left_b`` is ``(batch, 2)`` with each
        crop's ``(x, y)`` top-left. ``argwhere`` yields ``(row, col) = (y, x)``, so it is
        flipped before the offset is added. Assumes the predicted level has downsample 1.
        """
        coords = [torch.argwhere(pred).flip(-1) + top_left
                  for pred, top_left in zip(pred_b, top_left_b)]
        if not coords:
            return torch.empty((0, 2), dtype=torch.long, device=pred_b.device)
        return torch.cat(coords)

    def run_inference(self, checkpoint_path: pathlib.Path, wsi_img_name: str, batch_size=32,
                      max_iter=10, **kwargs):
        """Predict foreground pixels on one inference slide and save them as GeoJSON.

        Processes at most ``max_iter`` batches (``None`` processes all). Each batch with any
        foreground becomes one feature: a copy of the ``inference_schema.geojson`` template
        with ``geometry.coordinates`` set to the level-0 ``(x, y)`` pixel coordinates.

        The features are written to ``inference/labels/<wsi_img_name>.geojson`` once, after
        the loop. This also happens if the loop is interrupted, and avoids rewriting the
        whole file after every batch.
        """
        import copy

        inference_dl = self.get_inference_dl(wsi_img_name=wsi_img_name, batch_size=batch_size,
                                             anno_df=self.anno_df, **kwargs)

        inference_model = SegLightningModule.load_from_checkpoint(checkpoint_path,
                                                                  map_location=self.device)
        inference_model.to(self.device)
        inference_model.eval()

        self.inference_label_path.mkdir(parents=True, exist_ok=True)
        inference_filepath = self.inference_label_path / '.'.join((wsi_img_name, 'geojson'))
        # read the template once; every feature gets its own deep copy
        schema_template = self.get_inference_schema()

        inference_json = []
        try:
            with torch.no_grad():
                for i, (top_left_b, img_b) in enumerate(tqdm(inference_dl)):
                    if max_iter is not None and i >= max_iter:
                        break
                    pred_b = inference_model(img_b.to(self.device, dtype=torch.float32)).argmax(dim=1)
                    coords_list = self._pred_to_slide_coords(pred_b, top_left_b.to(self.device)).tolist()
                    if coords_list:
                        feature = copy.deepcopy(schema_template)
                        feature['geometry']['coordinates'] = coords_list
                        inference_json.append(feature)
        finally:
            with open(inference_filepath, 'w', encoding='utf-8') as json_file:
                json.dump(inference_json, json_file, indent=4)

    def show_batch(self, kind='train',
                   mask_color=torch.tensor((0, 128, 255)),
                   num_pyramid_levels=4,
                   crop_pixel_size=(512, 512),
                   pyramid_top_level={0: 1.0},
                   show_batch_size=8, save=False, **kwargs,
                   ):
        """Plot one batch: each grid row is one sample, with its pyramid levels left to right
        and the mask overlaid in ``mask_color``.

        ``kind`` selects the ``'train'`` or ``'val'`` loader. With ``save=True`` the figure is
        also written to ``kwargs['save_path']`` (default ``'show_batch_<kind>.png'``).
        """
        if kind not in ('train', 'val'):
            raise ValueError(f"kind must be 'train' or 'val', got {kind!r}")

        dl_train, dl_val = self.get_dls(batch_size=show_batch_size,
                                        num_pyramid_levels=num_pyramid_levels,
                                        num_pyramid_mask_levels=num_pyramid_levels,
                                        crop_pixel_size=crop_pixel_size,
                                        pyramid_top_level=pyramid_top_level)
        img_b, mask_b = next(iter(dl_train if kind == 'train' else dl_val))

        # (B, 3*L, H, W) -> (B*L, 3, H, W) and (B, L, H, W) -> (B*L, 1, H, W), using the
        # tensors' own shapes (a batch can be smaller than show_batch_size)
        img_b = img_b.reshape(-1, 3, *img_b.shape[-2:])
        mask_b = mask_b.reshape(-1, 1, *mask_b.shape[-2:])

        color_b = mask_b * torch.as_tensor(mask_color).view(1, 3, 1, 1)
        overlay_b = (img_b + color_b).clamp(0, 255).to(torch.uint8)
        grid_img = make_grid(overlay_b, nrow=num_pyramid_levels)

        fig, ax = plt.subplots(figsize=(32, 24))
        ax.imshow(grid_img.permute(1, 2, 0))
        ax.set_axis_off()
        ax.set_title(f'{kind} batch: pyramid levels left to right, mask overlaid')
        if save:
            fig.savefig(kwargs.get('save_path', f'show_batch_{kind}.png'), bbox_inches='tight')

    def run_training(self, name, epochs=10,
                     num_pyramid_levels=4,
                     num_pyramid_mask_levels=1,
                     crop_pixel_size=(512, 512),
                     pyramid_top_level={0: 1.0},
                     batch_size=32,
                     seed=None,
                     **kwargs):
        """Train a ``SegLightningModule`` on the pyramid dataloaders.

        Extra keyword arguments go to ``SegLightningModule``. CSV logs (and Lightning's
        default checkpoints) are written under ``runs/<name>/``. Training stops early when
        ``Val_loss`` has not improved for 5 epochs. A non-None ``seed`` seeds the
        python/numpy/torch RNGs (crop sampling, weight init) before anything is built.
        """
        if seed is not None:
            seed_everything(seed)

        dl_train, dl_val = self.get_dls(batch_size=batch_size,
                                        num_pyramid_levels=num_pyramid_levels,
                                        num_pyramid_mask_levels=num_pyramid_mask_levels,
                                        crop_pixel_size=crop_pixel_size,
                                        pyramid_top_level=pyramid_top_level)

        # every pyramid level contributes 3 (RGB) input channels
        self.pl_module = SegLightningModule(in_channels=3 * num_pyramid_levels, **kwargs)

        accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
        self.trainer = pl.Trainer(
            max_epochs=epochs, accelerator=accelerator,
            callbacks=[EarlyStopping(monitor='Val_loss', patience=5, verbose=True, mode='min'),
                       TQDMProgressBar(refresh_rate=1)],
            logger=CSVLogger(flush_logs_every_n_steps=10, save_dir='runs', name=name))

        # the trainer's default checkpoint callback saves the model automatically
        self.trainer.fit(model=self.pl_module, train_dataloaders=dl_train, val_dataloaders=dl_val)
            

            
            

In [11]:
# Parses all training annotations and runs tissue detection (preprocess) on every training slide.
runner=experiment_runner()

In [12]:
# Short smoke run: one epoch with batch size 4; CSVLogger(save_dir='runs') logs under runs/test_04/.
runner.run_training(name='test_04',epochs=1,batch_size=4)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name               | Type             | Params
--------------------------------------------------------
0 | train_metrics      | MetricCollection | 0     
1 | val_metrics        | MetricCollection | 0     
2 | segmentation_model | UnetPlusPlus     | 26.1 M
3 | loss_fn            | CrossEntropyLoss | 0     
--------------------------------------------------------
26.1 M    Trainable params
0         Non-trainable params
26.1 M    Total params
104.428   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


                                                                                                                       

  rank_zero_warn(


Epoch 0: 100%|████████████████████████████████████████████████████████████████| 98/98 [02:14<00:00,  1.38s/it, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                               | 0/25 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                  | 0/25 [00:00<?, ?it/s][A
Validation DataLoader 0:   4%|██▎                                                       | 1/25 [00:00<00:05,  4.33it/s][A
Validation DataLoader 0:   8%|████▋                                                     | 2/25 [00:01<00:13,  1.72it/s][A
Validation DataLoader 0:  12%|██████▉                                                   | 3/25 [00:02<00:18,  1.16it/s][A
Validation DataLoader 0:  16%|█████████▎                                                | 4/25 [00:03<00:19,  1.09it/s][A
Validation DataLoader 0:  20%|███████████▌                                              | 5/25 [00:04<00:18,

Metric Val_loss improved. New best score: 0.574
`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|████████████████████████████████████████████████████████████████| 98/98 [02:47<00:00,  1.71s/it, v_num=0]


In [12]:
# NOTE(review): run_training logs under runs/<name>/version_N/; confirm this checkpoint path without the 'runs/' prefix is intended.
# The recorded output below was interrupted manually after 10 batches at ~47 s/it.
runner.run_inference(checkpoint_path=Path('test_02')/'version_0'/'checkpoints'/'epoch=3-step=392.ckpt',wsi_img_name='4520_Phi172_S3T1R1.tiff',batch_size=4)

  0%|▎                                                                            | 10/2114 [07:53<27:40:11, 47.34s/it]


KeyboardInterrupt: 