# 📚 Import Libraries

In [1]:
!cp -r ../input/timm-pytorch-image-models . 
#!cp -r ../input/openslide .

In [2]:
!pip install -qU ./timm-pytorch-image-models/pytorch-image-models-master
#!pip install -qU ./openslide

[0m

In [3]:
!conda install ../input/how-to-use-pyvips-offline/*.tar.bz2 


Downloading and Extracting Packages
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
###########################################################

In [4]:
import pandas as pd
import numpy as np
from glob import glob
from collections import defaultdict
from tqdm import tqdm
import time
import os 
import copy
import gc
from openslide import OpenSlide
from PIL import Image
# visualization
import cv2
import matplotlib.pyplot as plt

# Sklearn
from sklearn.model_selection import StratifiedKFold, KFold, StratifiedGroupKFold,GroupKFold 

# PyTorch 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

import timm

import zipfile
import pyvips
# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Metrics 
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score


# For colored terminal text
from colorama import Fore, Back, Style
c_  = Fore.GREEN
sr_ = Style.RESET_ALL

import warnings
warnings.filterwarnings('ignore')

# ⚙️ Configuration

In [5]:
# True when the competition test folder holds exactly 4 entries.
# NOTE(review): presumably this detects the small public placeholder test set;
# the flag is not read by any later cell shown here -- confirm it is still needed.
fake_inf = len(glob("/kaggle/input/mayo-clinic-strip-ai/test/*")) == 4 
print(fake_inf)

True


In [6]:
class CFG:
    """Notebook-wide settings, read as class attributes (e.g. ``CFG.device``).

    Only ``seed``, ``backbone``, ``valid_bs``, ``device`` and ``N_slides`` are
    read by the cells shown in this notebook; the remaining fields look like
    training-time settings -- TODO confirm they are still needed here.
    """
    seed          = 2307  # handed to set_seed() right after this class
    debug         = False # set debug=False for Full Training
    # Free-text run note. NOTE(review): it mentions "eff b7" while `backbone`
    # below is "convnext_tiny" -- looks stale, confirm.
    comment       = "eff b7 more more 0.15 satruration on data"
    n_flods       = 5  # presumably the number of CV folds ("n_folds" misspelt) -- confirm
    backbone      = "convnext_tiny"  # model name forwarded to timm.create_model via StripModel
    train_bs      = 1  # only used below to derive T_max
    valid_bs      = 1  # batch size of the test DataLoader built in prepare_loaders
    epochs        = 25
    lr            = 1e-4
    scheduler     = 'CosineAnnealingLR'
    min_lr        = 1e-6
    # NOTE(review): 30000 is presumably the number of training samples -- confirm.
    T_max         = int(30000/train_bs*epochs)+50
    T_0           = 25
    warmup_epochs = 1
    wd            = 1e-6
    n_accumulate  = 2  # previously: max(1, 32//train_bs)
    device        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tile_size     = (1024,1024)
    t_thr         = 0.4 # fraction of coloured pixels required to keep a tile
    faster_inf    = True # only take N_slides tiles from each image
    N_slides      = 45  # default `num_instances` of StripModel (tiles fed per image)

In [7]:
def set_seed(seed = 42):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), and switches cuDNN to deterministic mode.

    Args:
        seed: integer seed applied to all generators.
    """
    # Local import: the notebook's import cell does not bring in `random`.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every GPU and is silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Only affects Python processes started after this point; the hash seed of
    # the running interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')
    
set_seed(CFG.seed)

> SEEDING DONE


# ❗ Data

In [8]:
def get_data_info(paths,df):
    img_prop = defaultdict(list)
    
    for i, path in enumerate(paths):
        img_path = paths[i]
        slide = OpenSlide(img_path)    
        img_prop['image_id'].append(img_path[-12:-4])
        img_prop['width'].append(slide.dimensions[0])
        img_prop['height'].append(slide.dimensions[1])
        img_prop['size'].append(round(os.path.getsize(img_path) / 1e6, 2))
        img_prop['path'].append(img_path)
    
    image_data = pd.DataFrame(img_prop)
    image_data['img_aspect_ratio'] = image_data['width']/image_data['height']
    image_data.sort_values(by='image_id', inplace=True)
    image_data.reset_index(inplace=True, drop=True)
    
    image_data = image_data.merge(df, on='image_id')
    return image_data

In [9]:
# List the raw test slides and join their on-disk properties with test.csv.
test_images = glob("/kaggle/input/mayo-clinic-strip-ai/test/*")
test_df = pd.read_csv('../input/mayo-clinic-strip-ai/test.csv')
# `test_df` is rebound here: it goes in as the CSV table and comes back merged
# with the per-slide width/height/size/path columns.
test_df= get_data_info(test_images,test_df)
test_df.head()

Unnamed: 0,image_id,width,height,size,path,img_aspect_ratio,center_id,patient_id,image_num
0,006388_0,34007,60797,1312.94,/kaggle/input/mayo-clinic-strip-ai/test/006388...,0.559353,11,006388,0
1,008e5c_0,5946,29694,109.57,/kaggle/input/mayo-clinic-strip-ai/test/008e5c...,0.200242,11,008e5c,0
2,00c058_0,15255,61801,351.76,/kaggle/input/mayo-clinic-strip-ai/test/00c058...,0.246841,11,00c058,0
3,01adc5_0,55831,26553,679.17,/kaggle/input/mayo-clinic-strip-ai/test/01adc5...,2.102625,11,01adc5,0


In [10]:
# Create the tile output folder; exist_ok keeps the cell re-runnable.
os.makedirs("./test", exist_ok=True)

In [11]:


def tile(img, w,h, N=16):
    """Crop flat margins, cut the image into tiles and keep the busiest ones.

    Steps: drop columns/rows whose summed neighbour difference is low, pad with
    255 to a multiple of the tile size, split into tiles, rank tiles by the sum
    of their differences along tile axis 1 and return the top ``N``.

    NOTE(review): if ``img`` is an unsigned integer array (e.g. uint8),
    ``np.diff`` wraps around instead of going negative, so these are not
    absolute differences -- confirm this is the intended score.

    Args:
        img: array indexed as (axis0, axis1, 3).
        w: tile extent along axis 0.
        h: tile extent along axis 1.
        N: maximum number of tiles returned.

    Returns:
        Array of shape (k, w, h, 3) with k <= N, ordered by ascending score.
    """
    n_axis0, n_axis1 = img.shape[0], img.shape[1]

    # One flag per column (axis 1) / per row (axis 0): True means "flat, remove".
    flat_along_axis1 = np.sum(np.diff(img, axis=0), axis=(0, 2)) / (3 * n_axis0) < 3
    flat_along_axis0 = np.sum(np.diff(img, axis=1), axis=(1, 2)) / (3 * n_axis1) < 3
    img = np.delete(img, flat_along_axis1, axis=1)
    img = np.delete(img, flat_along_axis0, axis=0)

    # Pad symmetrically so both extents become multiples of the tile size.
    extra0 = (w - img.shape[0] % w) % w
    extra1 = (h - img.shape[1] % h) % h
    pad_widths = [
        [extra0 // 2, extra0 - extra0 // 2],
        [extra1 // 2, extra1 - extra1 // 2],
        [0, 0],
    ]
    img = np.pad(img, pad_widths, constant_values=255)

    # (grid0, w, grid1, h, 3) -> (grid0, grid1, w, h, 3) -> (tiles, w, h, 3)
    grid0, grid1 = img.shape[0] // w, img.shape[1] // h
    tiles = img.reshape(grid0, w, grid1, h, 3)
    tiles = tiles.transpose(0, 2, 1, 3, 4).reshape(-1, w, h, 3)

    print(tiles.shape)

    detail = np.diff(tiles, axis=1)
    scores = detail.reshape(detail.shape[0], -1).sum(-1)
    top = np.argsort(scores)[-N:]
    return tiles[top]


def make_tiles(img_info , save_dir):
    """Downscale one slide, tile it and write the selected tiles as JPEGs.

    Args:
        img_info: an ``(index, row)`` pair as produced by
            ``DataFrame.iterrows()``; the row must expose ``image_id``,
            ``path``, ``center_id``, ``width`` and ``height``.
        save_dir: output directory (with or without a trailing separator).

    Files are named ``{image_id}_{center_id}_{i}.jpg``.
    """
    row = img_info[1]
    image_id, img_path, center_id = row.image_id, row.path, row.center_id

    # Scale so the longer side becomes 25000 px ("biggest image we can load in
    # mem?" per the original author).
    # NOTE(review): slides whose longer side is below 25000 px are scaled UP by
    # this formula -- confirm that is intended.
    ds_factor = max(row.height, row.width) / 25000
    print(1/ds_factor)
    slide = pyvips.Image.new_from_file(img_path).resize(1/ds_factor).numpy()

    patch_size = (384,384)
    # Keeps up to 50 tiles per slide.
    # NOTE(review): CFG.N_slides is 45 -- confirm the two are meant to differ.
    tiles = tile(slide, patch_size[0], patch_size[1], 50)
    del slide  # drop the full-resolution array before writing files

    for i in range(tiles.shape[0]):
        img_name = f"{image_id}_{center_id}_{i}.jpg"
        # os.path.join works whether or not save_dir ends with a separator.
        Image.fromarray(tiles[i]).save(os.path.join(save_dir, img_name))
    del tiles
    gc.collect()
    

In [12]:
# Tile every test slide into ./test/. make_tiles indexes its argument with [1],
# so it must receive the (index, row) pairs that iterrows() yields.
for image_info in tqdm(test_df.iterrows(), total = len(test_df), desc = "splliting images"):

    make_tiles(image_info, "./test/")

splliting images:   0%|          | 0/4 [00:00<?, ?it/s]

0.4112045002220504
(1836, 384, 384, 3)


splliting images:  25%|██▌       | 1/4 [01:11<03:35, 71.81s/it]

0.8419209267865563
(348, 384, 384, 3)


splliting images:  50%|█████     | 2/4 [01:25<01:15, 37.60s/it]

0.40452419863756256
(560, 384, 384, 3)


splliting images:  75%|███████▌  | 3/4 [01:51<00:32, 32.42s/it]

0.44777990722000327
(2046, 384, 384, 3)


splliting images: 100%|██████████| 4/4 [02:50<00:00, 42.56s/it]


In [13]:
def get_data_info(paths , train = False):
    """Build a per-tile DataFrame from tile file names.

    NOTE: this redefines (and shadows) the earlier ``get_data_info(paths, df)``
    from the data cell above.

    File names must look like ``{patient_id}_{image_num}_{centre_id}_{slice_num}.ext``
    or ``{patient_id}_{image_num}_{slice_num}.ext``.

    Args:
        paths: sequence of tile file paths.
        train: when True, a ``label`` column is looked up in a module-level
            ``train_data`` DataFrame (columns ``image_id`` and ``label``).
            ``train_data`` is not defined in any cell shown here, so it must
            already exist as a global before calling with ``train=True``.

    Returns:
        DataFrame with columns ``image_id``, ``patient_id``, ``image_num``,
        ``slice_num`` (string), ``path`` and, if ``train``, ``label``; sorted by
        ``image_id``. Empty ``paths`` gives an empty frame with those columns.

    Raises:
        ValueError: if a file name does not split into 3 or 4 ``_`` parts.
    """
    columns = ['image_id', 'patient_id', 'image_num', 'slice_num', 'path']
    if train:
        columns.append('label')

    img_prop = defaultdict(list)

    for path in tqdm(paths, total = len(paths),desc = "making dataframe"):
        # Strip directory and extension before splitting the name into fields.
        stem = os.path.splitext(os.path.basename(path))[0]
        parts = stem.split("_")
        if len(parts) not in (3, 4):
            raise ValueError(f"unexpected tile file name: {path!r}")
        # The optional centre id (third of four parts) is not stored.
        patient_id, image_num, slice_num = parts[0], parts[1], parts[-1]
        image_id = f"{patient_id}_{image_num}"

        img_prop['image_id'].append(image_id)
        img_prop['patient_id'].append(patient_id)
        img_prop['image_num'].append(image_num)
        img_prop['slice_num'].append(slice_num)
        img_prop['path'].append(path)

        if train:
            label = train_data[train_data["image_id"]==image_id].label.item()
            img_prop['label'].append(label)

    # Explicit columns keep sort_values working when `paths` is empty.
    image_data = pd.DataFrame(img_prop, columns=columns)
    image_data = image_data.sort_values(by='image_id').reset_index(drop=True)

    return image_data

In [14]:
# glob returns files in arbitrary order; sort so the tile table (and the seeded
# random tile sampling that later draws from it) is reproducible across runs.
test_images = sorted(glob("./test/*"))
df = get_data_info(test_images)
df.head()

making dataframe: 100%|██████████| 200/200 [00:00<00:00, 443841.69it/s]


Unnamed: 0,image_id,patient_id,image_num,slice_num,path
0,006388_0,6388,0,33,./test/006388_0_11_33.jpg
1,006388_0,6388,0,28,./test/006388_0_11_28.jpg
2,006388_0,6388,0,12,./test/006388_0_11_12.jpg
3,006388_0,6388,0,7,./test/006388_0_11_7.jpg
4,006388_0,6388,0,48,./test/006388_0_11_48.jpg


# 🔨 Utility

In [15]:
def load_img(path):
    """Read an image file with OpenCV and return it as a float32 array.

    The file is read with ``cv2.IMREAD_UNCHANGED`` and no channel reordering is
    applied, so colour images stay in OpenCV's channel order.

    Args:
        path: image file path.

    Returns:
        ``numpy.ndarray`` of dtype float32 holding the decoded pixels.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (``cv2.imread``
            signals failure by returning ``None`` instead of raising).
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {path}")

    return img.astype('float32')

def show_img(img, ground_truth, pred = "", conf = ""):
    """Draw ``img`` on the current matplotlib axes with a summary title.

    Args:
        img: image accepted by ``plt.imshow``.
        ground_truth: truthy is shown as "CE", falsy as "LAA".
        pred: predicted label text placed in the title.
        conf: confidence text placed in the title.
    """
    truth_name = "CE" if ground_truth else "LAA"
    plt.imshow(img)
    plt.title(f'true: {truth_name} | predicted: {pred} | conf: {conf}')
    plt.axis('off')
    

# ❗ DataLoaders

In [16]:
class StripAiDataset(Dataset):
    """One item per ``image_id``: a stack of randomly drawn tiles of that image.

    Args:
        df: tile table with ``image_id``, ``patient_id`` and ``path`` columns.
        N_slides: number of tiles drawn (with replacement) per item.
        label: accepted for compatibility but ignored; ``self.label`` is
            always ``None``.
        transforms: optional callable used as ``transforms(image=arr)["image"]``
            (albumentations style). ``None`` leaves tiles untransformed.
    """

    def __init__(self, df, N_slides, label=False ,transforms = None):
        self.df = df
        self.transforms = transforms
        self.image_ids = df['image_id'].unique().tolist()
        self.file_names = df['path'].tolist()
        self.patient_id = df['patient_id'].tolist()
        self.label = None
        self.N_slides = N_slides

    def __len__(self):
        """Number of distinct images (not tiles)."""
        return len(self.image_ids)

    def __getitem__(self,index):
        """Return ``(tiles, patient_id)`` for the image at ``index``.

        ``tiles`` is a float32 tensor of shape (N_slides, C, H, W). Tiles are
        sampled with replacement via pandas, so the draw follows NumPy's global
        random state.
        """
        img_id = self.image_ids[index]
        tile_rows = self.df[self.df["image_id"] == img_id]
        patient_id = tile_rows['patient_id'].tolist()
        tile_paths = tile_rows["path"].sample(self.N_slides,replace=True).tolist()

        stack = []
        for path in tile_paths:
            tile_img = load_img(path)
            # The constructor default is transforms=None; skip instead of crashing.
            if self.transforms is not None:
                tile_img = self.transforms(image=tile_img)["image"]
            # HWC -> CHW, the layout the model consumes.
            stack.append(np.transpose(tile_img, (2, 0, 1)))

        batch = np.stack(stack , axis = 0).astype(np.float32)

        return torch.tensor(batch), patient_id[0]

In [17]:
# Albumentations pipelines keyed by usage.
# "train" and "valid" are the same pipeline (Normalize only); "tta" applies
# A.Flip with p=1.0 before Normalize. The inference loop later in this notebook
# reads "valid" and "tta"; "train" is not read by the cells shown.
data_transforms = {
    "train": A.Compose([A.Normalize()], p=1.0),
    "tta": A.Compose([A.Flip(p=1.0),
                      A.Normalize()], p=1.0),
    "valid": A.Compose([A.Normalize()], p=1.0)
}

In [18]:
def prepare_loaders(test_df, N_slides,aug,debug=False):
    """Wrap the tile table in a ``StripAiDataset`` and return its DataLoader.

    Args:
        test_df: tile table passed to ``StripAiDataset``.
        N_slides: tiles drawn per image.
        aug: transform pipeline handed to the dataset as ``transforms``.
        debug: unused.

    Returns:
        Non-shuffling ``DataLoader`` with batch size ``CFG.valid_bs`` and one
        worker.
    """
    dataset = StripAiDataset(test_df, N_slides, transforms=aug)
    loader_options = dict(
        batch_size=CFG.valid_bs,
        num_workers=1,
        shuffle=False,
        pin_memory=False,
    )
    return DataLoader(dataset, **loader_options)

In [19]:
class Flatten(nn.Module):
    """Keep the first ``dim`` axes and collapse all remaining ones into one."""

    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``x`` viewed as ``(*x.shape[:dim], -1)``."""
        kept_axes = tuple(x.shape[axis] for axis in range(self.dim))
        return x.view(*kept_axes, -1)


In [20]:
class StripModel(nn.Module):
    """Multi-tile classifier: a timm encoder per tile plus a Conv3d pooling head.

    Args:
        model_name: timm model name. The encoder must expose ``head.fc``, which
            is replaced by ``nn.Identity``.
        num_classes: number of output classes.
        pretrained: forwarded to ``timm.create_model``.
        num_instances: tiles per image; fixes the Conv3d ``in_channels``.
        path: unused.

    The head's Conv3d kernel is (1, 12, 12) and is followed by
    ``Linear(feature_dim, ...)``, so the encoder feature map must be exactly
    12x12 spatially (presumably 384x384 tiles for this backbone -- confirm).
    """

    def __init__(self, model_name, num_classes=2, pretrained=True ,num_instances=CFG.N_slides , path=""):
        super().__init__()
        self.num_instances = num_instances
        self.encoder = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)

        feature_dim = self.encoder.get_classifier().in_features
        # Drop the encoder's own classifier so encoder.head yields features.
        self.encoder.head.fc = nn.Identity()
        self.feature_dim = feature_dim
        print(feature_dim)

        self.head = nn.Sequential(
            nn.Conv3d(self.num_instances ,1,(1,12,12)), nn.ReLU(inplace=True), Flatten(),
            nn.Linear(feature_dim, 256), nn.ReLU(inplace=True),
            nn.Linear(256, 64), nn.ReLU(inplace=True),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        """Classify a batch of tile stacks.

        Args:
            x: tensor of shape (bs, num_instances, C, W, H).

        Returns:
            ``(emb, feats)``: ``emb`` is the head output, one row of
            ``num_classes`` logits per image; ``feats`` is ``encoder.head``
            applied to the per-tile feature maps.

        Raises:
            ValueError: if the tile axis does not match ``num_instances``.
        """
        bs, n_tiles, ch, w, h = x.shape
        if n_tiles != self.num_instances:
            raise ValueError(
                f"expected {self.num_instances} tiles per image, got {n_tiles}"
            )

        # Fold tiles into the batch axis so the 2-D encoder sees one tile per row.
        x = x.reshape(bs * self.num_instances, ch, w, h)
        feats = self.encoder.forward_features(x)  # (bs*N, C', W', H')

        # Restore the batch axis before the Conv3d head. Feeding the 4-D tensor
        # directly makes Conv3d treat it as a single unbatched sample, which
        # only works when bs == 1.
        per_image = feats.reshape(bs, self.num_instances, *feats.shape[1:])
        emb = self.head(per_image)

        return emb,self.encoder.head(feats)
    

In [21]:
def get_model(path):
    """Build a ``StripModel`` for ``CFG.backbone`` and load weights from ``path``.

    Args:
        path: checkpoint file holding a state dict for ``StripModel``.

    Returns:
        The model on ``CFG.device`` in eval mode.
    """
    target_device = torch.device(CFG.device)
    model = StripModel(CFG.backbone, pretrained=False).to(CFG.device)
    state_dict = torch.load(path, map_location=target_device)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# 🔧 Loss Function

# 🚄 Inference Function

In [22]:
@torch.no_grad()
def infer(model_ptah, test_loader, preds):
    """Run one checkpoint over a loader and append class probabilities.

    Args:
        model_ptah: checkpoint path passed to ``get_model``.
        test_loader: iterable yielding ``(img, pid)`` batches.
        preds: mapping of lists; one entry per sample is appended under
            ``"patient_id"``, ``"CE"`` (class 1) and ``"LAA"`` (class 0).

    Returns:
        The same ``preds`` mapping.

    If the forward pass raises, the batch gets a neutral 0.5/0.5 prediction and
    the error is printed, so one bad batch does not abort the whole run.
    """
    model = get_model(model_ptah)
    # Logits are (batch, classes); normalise over the class axis explicitly.
    softmax = nn.Softmax(dim=1)

    for img, pid in tqdm(test_loader, total=len(test_loader), desc='Infer '):
        img = img.to(CFG.device, dtype=torch.float)
        try:
            logits, _ = model(img)
            out = softmax(logits).cpu().numpy()
        except Exception as err:
            # Best-effort fallback: one neutral row per sample, as a NumPy array
            # like the normal path, so `preds` never mixes in torch tensors.
            out = np.full((len(pid), 2), 0.5, dtype=np.float32)
            print("except", repr(err))

        preds["patient_id"].extend(pid)
        preds["CE"].extend(out[:,1])
        preds["LAA"].extend(out[:,0])

        del img, out
        gc.collect()
        torch.cuda.empty_cache()

    return preds

In [23]:
# glob returns paths in arbitrary order; sort so checkpoints are always
# evaluated in the same order.
model_ptahs = sorted(glob("../input/all5fold/*"))

In [24]:
#pd.concat([df]*100)

In [25]:
preds = defaultdict(list)
for model_ptah in model_ptahs:
    # Each checkpoint is run twice: plain tiles first, then flipped tiles (TTA).
    for aug_name in ("valid", "tta"):
        # CFG.N_slides also sizes the model's Conv3d input, so the loader must
        # use the same value rather than a separate literal.
        test_loader = prepare_loaders(df, N_slides=CFG.N_slides, aug=data_transforms[aug_name])
        preds = infer(model_ptah, test_loader, preds)


768


Infer : 100%|██████████| 4/4 [00:09<00:00,  2.33s/it]


768


Infer : 100%|██████████| 4/4 [00:03<00:00,  1.20it/s]


768


Infer : 100%|██████████| 4/4 [00:03<00:00,  1.26it/s]


768


Infer : 100%|██████████| 4/4 [00:03<00:00,  1.11it/s]


768


Infer : 100%|██████████| 4/4 [00:03<00:00,  1.26it/s]


768


Infer : 100%|██████████| 4/4 [00:03<00:00,  1.17it/s]


768


Infer : 100%|██████████| 4/4 [00:03<00:00,  1.20it/s]


768


Infer : 100%|██████████| 4/4 [00:03<00:00,  1.16it/s]


768


Infer : 100%|██████████| 4/4 [00:03<00:00,  1.25it/s]


768


Infer : 100%|██████████| 4/4 [00:03<00:00,  1.05it/s]


In [26]:
preds = pd.DataFrame(preds)

In [27]:
# Average CE/LAA over all rows sharing a patient_id. The inference loop above
# appended rows once per checkpoint and per augmentation pass.
preds = preds.groupby(by="patient_id").mean().reset_index()

In [28]:
preds

Unnamed: 0,patient_id,CE,LAA
0,006388,0.443729,0.556271
1,008e5c,0.618345,0.381655
2,00c058,0.576882,0.423118
3,01adc5,0.435397,0.564603


In [29]:
preds

Unnamed: 0,patient_id,CE,LAA
0,006388,0.443729,0.556271
1,008e5c,0.618345,0.381655
2,00c058,0.576882,0.423118
3,01adc5,0.435397,0.564603


In [30]:
preds.to_csv("submission.csv", index = False)