In [None]:
import pandas as pd
import requests
import os
import numpy as np
from io import StringIO

from PIL import Image
import matplotlib.pyplot as plt


from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import torch
import torchvision
from torchvision.models import ResNet34_Weights


In [None]:
def load_labeled_features(path, drop_non_labelled=True,
                          potential_features=('rods', 'clumped', 'planktonic', 'filaments',
                                              'positive', 'negative', 'intermediate'),
                          base_url='http://localhost:8080'):
    '''
    Download the feature table of one dataset and flag which rows carry a manual label.

    We assume that the server is locally hosted as described in the settings, but the user can also
    write SQL code to read from the database directly.

    Args:
        path (str): dataset path passed to the server's ``download_csv_by_path`` endpoint.
        drop_non_labelled (bool): if True, only rows whose ``labeled`` flag is True are returned.
        potential_features (iterable of str): feature columns that may mark a row as labelled;
            names absent from the downloaded table are ignored.
        base_url (str): root URL of the annotation server.

    Returns:
        pandas.DataFrame: the table without trash rows and without rows missing a ``chip`` or
        ``label`` key, with an added boolean ``labeled`` column and a fresh RangeIndex.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    '''
    # params= lets requests URL-encode the path (it may contain '&', '#', spaces, ...)
    response = requests.get(f"{base_url}/download_csv_by_path", params={'path': path})
    response.raise_for_status()  # otherwise an HTML error page would be parsed as CSV
    df = pd.read_csv(StringIO(response.text), index_col=0)

    if 'trash' in df.columns:
        # explicit `== True` keeps NaN / non-boolean entries (they are not trash)
        df = df[~(df['trash'] == True)]  # drop all trash

    features = [f for f in potential_features if f in df.columns]

    # rows without a chip/label key cannot be identified as a crop; they are discarded
    df = df.dropna(subset=['chip', 'label'])

    # A row counts as labelled when any of its feature columns is truthy (NaN counts as False).
    # This is computed row by row: grouping by (chip, label) and merging the per-row result
    # back on those keys multiplied rows whenever a key occurred more than once, and failed
    # when the table contained a single (chip, label) group.
    df = df.assign(labeled=df[features].any(axis=1))

    if drop_non_labelled:
        df = df[df['labeled']]  # drop all non labelled
    return df.reset_index(drop=True)

In [None]:
def prepare_data_for_classification(path, columns_to_keep=('chip', 'label', 'concentration', 'positive', 'clumped', 'planktonic', 'filaments', 'rods', 'n_cells_1', 'n_cells_2', 'labeled', 'filename')):
    '''
    Load one dataset's feature table (labelled and unlabelled rows) and bring it into the column
    layout used for classification.

    Args:
        path (str): dataset path such as 'Cirpofloxacin/20220531-MIC-e.coli-cipro/2ndexp/day2';
            it needs at least two '/'-separated components (three for the sub-folder layouts).
        columns_to_keep (iterable of str): columns retained in the result, in the frame's own
            column order; names that are not present are ignored.

    Returns:
        pandas.DataFrame: one row per crop with a ``filename`` of the form
        ``{prefix}_Crop_{chip}_{label}.tiff``, a ``positive`` column derived from ``negative``
        when only the latter exists, and 0.0-filled ``rods``/``filaments``/``planktonic``/
        ``clumped`` columns when those are absent.
    '''
    data = load_labeled_features(path, drop_non_labelled=False)

    # Some datasets (e.g. Cipro) are inconsistently organised with sub-folders such as
    # 1stexp/2ndexp; for those the sub-folder is appended to the dataset folder name.
    path_split = path.split('/')
    if path_split[-2] in ('1stexp', '2ndexp', 'set-1', 'set-2', 'set1', 'set2'):
        prefix = path_split[-3] + '_' + path_split[-2]
    else:
        prefix = path_split[-2]

    data['path_short'] = prefix
    data['path'] = path
    # One crop image per (chip, label). Built in one pass instead of one .loc mask per key,
    # and without reading the prefix back from the frame (which fails on an empty table).
    data['filename'] = [
        '' if pd.isna(chip) or pd.isna(label) else f'{prefix}_Crop_{chip}_{label}.tiff'
        for chip, label in zip(data['chip'], data['label'])
    ]

    # Tables that only carry 'negative' get a complementary 'positive' column.
    if ('negative' in data.columns) and ('positive' not in data.columns):
        print('adding positive!')
        data['positive'] = 1 - data['negative'].values

    # Morphology features missing from this dataset are filled with 0.
    for column in ('rods', 'filaments', 'planktonic', 'clumped'):
        if column not in data.columns:
            data[column] = 0.0

    data = data.drop(columns=[c for c in data.columns if c not in columns_to_keep])
    return data
    

In [None]:
class BacteriaEndPointDatasetApplication(Dataset):
    """Dataset over crop images of droplets, used to apply the trained model.

    Samples are addressed by position (``0 .. len-1``), but every sample also carries the
    original DataFrame index label under the key ``'globel_id'`` (spelling kept for callers).
    Predictions can then be written back into the source table with ``.loc``.
    """

    def __init__(self, df, root_dir, transform=None):
        """
        Args:
            df (pandas.DataFrame): table with a ``filename`` column; its index labels are
                returned as ``'globel_id'``. The frame is copied, so later changes to ``df``
                are not seen by the dataset.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.info_table = df.copy()
        self.ids = self.info_table.index.values

        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        '''
        Load the image at position ``idx``.

        Returns:
            dict: ``{'image': <transformed image, or the PIL image if no transform>,
            'globel_id': <index label of the row in the original DataFrame>}``.
        '''
        global_idx = self.ids[idx]
        img_name = os.path.join(self.root_dir, self.info_table.at[global_idx, 'filename'])

        # Image.open is lazy and keeps the file handle open (for formats such as TIFF it may
        # stay open even after the pixels are read). Read the pixels now and release the
        # handle deterministically instead of relying on garbage collection.
        with Image.open(img_name) as image:
            image.load()

        if self.transform:
            image = self.transform(image)
        return {'image': image, 'globel_id': global_idx}

# Set up the trained model

In [None]:
num_classes = 5  # outputs, in order: positive, planktonic, clumped, rods, filaments

# No ImageNet weights: every parameter and buffer is replaced by the strict load_state_dict of
# the trained network that follows, so downloading the pretrained weights would be wasted work
# (and require network access).
model = torchvision.models.resnet34(weights=None)
# replace the 1000-class ImageNet head by an affine layer with one output per feature
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

In [None]:
MODEL_SAVE_PATH = 'trained_networks/'  # where the trained network weights are stored
modelname = 'bacteria_trained_model_resnet34'
# The checkpoint is fed straight to load_state_dict, i.e. it is a plain state dict of tensors;
# weights_only=True refuses to unpickle arbitrary Python objects from the file.
state_dict = torch.load(os.path.join(MODEL_SAVE_PATH, modelname),
                        map_location=torch.device('cpu'), weights_only=True)
model.load_state_dict(state_dict)

In [None]:
# Use the first CUDA GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load relevant data

In [None]:
# Dataset folders to process, one list per antibiotic. Names are kept verbatim (including the
# 'Cirpofloxacin' spelling); presumably they must match the folder names on the server/disk.

Cipro_paths = [  # Ciprofloxacin
    'Cirpofloxacin/20230131-ecoli-cipro-1/day2',
    'Cirpofloxacin/20230131-ecoli-cipro-2/day2',
    'Cirpofloxacin/20220531-MIC-e.coli-cipro/2ndexp/day2',
    'Cirpofloxacin/20220531-MIC-e.coli-cipro/1stexp/day2',
]

Genta_paths = [  # Gentamicin
    'Gentamicin/20230110-e.coli-genta/day2',
    'Gentamicin/20230110-e.coli-genta-2/day2',
    'Gentamicin/20221101-ecoli-genta1/day2',
    'Gentamicin/20221101-ecoli-genta2/day2',
]

Tetra_paths = [  # Tetracycline
    'Tetracycline/20230404-ecoli-Tetracycline/set2/day2',
    'Tetracycline/20230404-ecoli-Tetracycline/set1/day2',
    'Tetracycline/20230315-ecoli/set-2/day2',
    'Tetracycline/20230315-ecoli/set-1/day2',
]

CHP_paths = [  # Chloramphenicol
    'Chloramphenicol/20230313-ecoli-chp-2/day2',
    'Chloramphenicol/20230313-ecoli-chp-1/day2',
    'Chloramphenicol/20230221-ecoli-chp-1/day2',
    'Chloramphenicol/20230111-ecoli-chp-2/day2',
    'Chloramphenicol/20230111-ecoli-chp/day2',
    'Chloramphenicol/20221122-ecoli-chp/day2',
    'Chloramphenicol/20221031-ecoli-chp2/day2',
    'Chloramphenicol/20221031-ecoli-chp1/day2',
    'Chloramphenicol/20221013-ecoli-chp/day2',
    'Chloramphenicol/20221012-ecoli-chp/day2',
    'Chloramphenicol/20220628-MIC-e.coli-chp-LB-2/day2',
    'Chloramphenicol/20220628-MIC-e.coli-chp-LB-1/day2',
    'Chloramphenicol/20220602-MIC-e.coli-chp-LB/day2',
    'Chloramphenicol/20220524-MIC-e.coli-chp-LB/day2',
]

AMP_paths = [  # Ampicillin
    'Ampicillin/20220614-MIC-e.coli-amp-LB-2/day2',
    'Ampicillin/20220614-MIC-e.coli-amp-LB-1/day2',
]

# every dataset, in antibiotic order
all_paths = [*Cipro_paths, *Genta_paths, *Tetra_paths, *CHP_paths, *AMP_paths]

In [None]:
# Where the crops are. The default is a placeholder; set the CROPS_DIR environment variable to
# point at the real crop folder instead of editing the notebook.
mounting_folder = os.environ.get('CROPS_DIR', '/home/your_username/Crops/')

In [None]:
# this cell will take a while to run: Docker slows the performance

# Predict labels for the not-labelled crops and fuse them back into each dataset's table.
PREDICTED_COLUMNS = ['positive', 'planktonic', 'clumped', 'rods', 'filaments']  # order of the model outputs
threshold = 0.5  # sigmoid probability above which a feature is predicted as present
OUTPUT_DIR = 'PredictedLabelsTables'
os.makedirs(OUTPUT_DIR, exist_ok=True)  # creates folder to store the results if it does not exist

# Same preprocessing for every dataset, so build it once.
# NOTE(review): Resize(224) scales the shorter side only; crops are presumably square so that
# batches can be stacked -- TODO confirm this matches the training transform.
application_transform = transforms.Compose([
    transforms.Grayscale(3),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# The weights were loaded onto the CPU (map_location) while the batches are sent to `device`;
# the model has to live on the same device as its inputs.
model.to(device)
model.eval()

for p in all_paths:
    print(p)
    data = prepare_data_for_classification(p)
    if data.empty:
        print(f'{p}: empty table, nothing to write')
        continue
    # filenames look like '<prefix>_Crop_<chip>_<label>.tiff'; the prefix names the output table
    table_name = data.filename.iloc[0].split('_Crop')[0]

    data_to_process = data[data.labeled == False].copy()  # only non labelled data
    if len(data_to_process) > 0:  # np.concatenate below fails on an empty list
        data_set_application = BacteriaEndPointDatasetApplication(data_to_process, mounting_folder, application_transform)
        # drop_last=False keeps the final, incomplete batch so every crop gets a prediction
        dataloader_application = DataLoader(data_set_application, shuffle=False, batch_size=32, drop_last=False)

        pred_logit = []
        id_list = []
        with torch.no_grad():
            for batch in dataloader_application:
                images_to_classify = batch['image'].to(device)
                # bring results back to host memory: numpy cannot read CUDA tensors
                pred_logit.append(model(images_to_classify).cpu().numpy())
                id_list.append(batch['globel_id'].numpy())

        id_list = np.concatenate(id_list, axis=0)
        pred_logit = np.concatenate(pred_logit, axis=0)
        prob = 1 / (1 + np.exp(-pred_logit))  # sigmoid to get probability
        y_pred = (prob > threshold).astype(prob.dtype)

        # id_list holds index labels of `data` (the dataset returns them), so write back via .loc
        for column_idx, column in enumerate(PREDICTED_COLUMNS):
            data.loc[id_list, column] = y_pred[:, column_idx]

    # file name computed outside the f-string: nested same-type quotes need Python >= 3.12
    data.to_csv(os.path.join(OUTPUT_DIR, f'{table_name}_trained.csv'))
    print(f'{table_name}_trained.csv done')
    del data