In [2]:
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import torch
import torchvision
from IPython.display import display # to display images
from PIL import Image
from sklearn.metrics import accuracy_score, precision_score, recall_score
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import ResNet34_Weights


In [3]:
class BacteriaEndPointDataset(Dataset):
    """Image dataset with multi-label annotations read from a CSV table.

    After the saved pandas index column ``'Unnamed: 0'`` is dropped, the first
    table column is used as the image file name (relative to ``root_dir``) and
    all remaining columns are read as numeric labels.

    A single instance serves the train, validation and test samples: ``mode``
    ('train', 'val', anything else is treated as test) selects which id list
    ``__len__`` and ``__getitem__`` operate on.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (dict or callable, optional): Either a dict mapping the
                mode name ('train', 'val', 'test') to a callable, or a single
                callable that is applied in every mode.
        """
        # The index column written by an earlier to_csv is read as a regular
        # column and dropped, so the table keeps a fresh 0..n-1 index (per the
        # original author's note, loading it via index_col=0 made the indices
        # not match). errors='ignore' also accepts a CSV saved without it.
        self.info_table = pd.read_csv(csv_file).drop(columns=['Unnamed: 0'], errors='ignore')
        self.ids = self.info_table.index.values

        # Filled by train_val_split(); test_ids has to be set by the caller.
        self.train_ids = []
        self.val_ids = []
        self.test_ids = []

        self.root_dir = root_dir
        self.transform = transform

        self.mode = 'train'

    def _active_ids(self):
        """Return the id list that belongs to the current ``mode``."""
        if self.mode == 'train':
            return self.train_ids
        if self.mode == 'val':
            return self.val_ids
        return self.test_ids  # any other mode is treated as 'test'

    def __len__(self):
        """Number of samples in the currently selected split."""
        return len(self._active_ids())

    def train_val_split(self, split_div=3, seed=None):
        """Randomly split all ids into a validation and a training set.

        Args:
            split_div (int): ``len(ids) // split_div`` ids are drawn (without
                replacement) for validation; the rest become training ids.
            seed (int, optional): Seed for a dedicated random generator so the
                split is reproducible. With ``None`` NumPy's global random
                state is used, exactly as before.
        """
        ids = np.asarray(self.ids)
        n_val = len(ids) // split_div
        if seed is None:
            self.val_ids = np.random.choice(ids, size=n_val, replace=False)
        else:
            self.val_ids = np.random.default_rng(seed).choice(ids, size=n_val, replace=False)
        # Vectorised complement; keeps the original order of the remaining ids.
        self.train_ids = ids[~np.isin(ids, self.val_ids)]

    def __getitem__(self, idx):
        '''
        Load the image and labels of sample ``idx`` of the current split.

        ``idx`` is first mapped to a global id, which is used as a positional
        row index (``iloc``) into the annotation table.

        Returns:
            dict with keys 'image' (transformed if a transform is set,
            otherwise the PIL image), 'labels' (float array) and 'global_id'.
        '''
        global_idx = self._active_ids()[idx]

        img_name = os.path.join(self.root_dir, self.info_table.iloc[global_idx, 0])
        # copy() loads the pixel data, so the file handle can be released
        # right away instead of staying open for every sample.
        with Image.open(img_name) as img_file:
            image = img_file.copy()

        labels = self.info_table.iloc[global_idx, 1:].values
        labels = labels.astype('float')

        if self.transform:
            # A dict is indexed by mode; a plain callable is used for all modes.
            transform = self.transform if callable(self.transform) else self.transform[self.mode]
            image = transform(image)

        # Same key set in both cases (the un-transformed branch used to return
        # the misspelled key 'globel_id').
        return {'image': image, 'labels': labels, 'global_id': global_idx}
        

In [4]:
# Normalisation constants marked as a ResNet requirement (the standard
# ImageNet channel statistics used with torchvision's pretrained models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _build_transform(augment=False):
    """Compose the preprocessing pipeline for one dataset mode.

    Every pipeline converts to 3-channel grayscale, resizes to 224, converts
    to a tensor (which also scales to [0, 1]) and normalises. With
    ``augment=True`` random horizontal/vertical flips are inserted before the
    tensor conversion.
    """
    steps = [
        transforms.Grayscale(3),  # resnet requirement
        transforms.Resize(224),   # resnet requirement
    ]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
        steps.append(transforms.RandomVerticalFlip())
    steps.append(transforms.ToTensor())  # resnet requirement, but scales also to [0,1]
    steps.append(transforms.Normalize(_NORM_MEAN, _NORM_STD))  # resnet requirement
    return transforms.Compose(steps)


# Only the training pipeline is augmented; 'val' and 'test' are deterministic.
data_transforms = {
    'train': _build_transform(augment=True),
    'val': _build_transform(),
    'test': _build_transform(),
}

# Load Datasets 

Crops are precomputed and stored as TIFs for faster performance.

In [None]:
# Build the dataset from the annotation table and the folder of stored crops;
# the per-mode transform dict is selected inside the dataset via `mode`.
data_set = BacteriaEndPointDataset(
    csv_file='LabelingSetAll.csv',
    root_dir='Crops/',
    transform=data_transforms,
)


In [None]:
# Seed NumPy's global generator before splitting: the split is drawn with
# np.random.choice, so without a fixed seed every kernel restart produces a
# different train/validation partition and the saved results cannot be
# reproduced.
SPLIT_SEED = 0
np.random.seed(SPLIT_SEED)

# Must run before the DataLoader is built: until the split is made the
# dataset's train id list is empty.
data_set.train_val_split(3)  # 1/3 of the samples go to validation

In [None]:
# One loader is shared by all splits; which split it yields is controlled by
# `data_set.mode` at iteration time.
dataloader = DataLoader(
    dataset=data_set,
    batch_size=32,
    shuffle=True,
    drop_last=True,  # drops incomplete sets that don't match batch_size
)

In [None]:
# Read the label matrices straight from the annotation table instead of
# iterating the DataLoader: the loader would open and transform every image
# just to obtain the labels, and with drop_last/shuffle it silently leaves a
# random incomplete batch out of the statistics.
label_table = data_set.info_table.iloc[:, 1:]  # same columns __getitem__ uses as labels

labels_val = label_table.iloc[data_set.val_ids].values.astype('float')
labels_train = label_table.iloc[data_set.train_ids].values.astype('float')

# The original cell left the dataset in train mode; keep that side effect.
data_set.mode = 'train'

In [None]:
# Bin 0: training samples with at least one non-zero label,
# bin 1: training samples whose labels are all zero.
np.bincount((labels_train == 0).all(axis=1))

In [None]:
# Bin 0: validation samples with at least one non-zero label,
# bin 1: validation samples whose labels are all zero.
np.bincount((labels_val == 0).all(axis=1))

In [None]:
# Display names for the label columns, used to label the plots below.
# NOTE(review): the order must match the label columns of LabelingSetAll.csv;
# other cells in this notebook list the names in a different order - confirm.
classes = [
    'positive',
    'planktonic',
    'clumped',
    'rods',
    'filaments',
]

In [None]:
# Pie charts of the per-class label counts in the training and validation set.
#
# The large font is applied only while this figure is built. Previously
# rcParams was updated *after* plotting, so the first run drew the charts with
# the default font while every later figure in the notebook silently inherited
# font size 30 (and a second run of this cell looked different from the first).
with plt.rc_context({'font.size': 30}):
    fig, ax = plt.subplots(1, 2, figsize=(20, 10), dpi=200)

    ax[0].pie(labels_train.sum(axis=0), labels=classes, autopct='%1.0f%%')
    ax[0].set_title('train')

    ax[1].pie(labels_val.sum(axis=0), labels=classes, autopct='%1.0f%%')
    ax[1].set_title('validation')

In [None]:
#dg=pd.DataFrame({'labels_train':labels_train.sum(axis=0),'N_train': labels_train.shape[0] * np.ones_like(labels_train.sum(axis=0)),'labels_val':labels_val.sum(axis=0),'N_val': labels_val.shape[0] * np.ones_like(labels_val.sum(axis=0))})

In [None]:
#dg.to_csv('LabelSplitInfoAll.csv') # save for later use

# Define model and run training

In [5]:
# Define the model: a pretrained ResNet-34 whose final layer is replaced by a
# new head with one output (logit) per label.
num_classes = 5
model = torchvision.models.resnet34(weights=ResNet34_Weights.DEFAULT)
# Affine linear transformation for the last layer. The input width is read
# from the layer being replaced instead of hard-coding 512, so the line stays
# correct if the backbone is swapped.
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /Users/erikmaikranz/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 83.3M/83.3M [00:01<00:00, 46.5MB/s]


In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The batches are moved to `device` during training, so the model has to live
# there too; without this the forward pass fails as soon as CUDA is available.
# Done before the optimizer is created so it is built on the final parameters.
model = model.to(device)

# Define the loss function and optimizer.
# BCEWithLogitsLoss takes the raw logits of the model.
criterion = torch.nn.BCEWithLogitsLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)  # my default 0.0001
# Decay LR by a factor of 0.1 (StepLR's default gamma, spelled out here)
# every 50 epochs; step_size was 25 before.
scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
epochs = 100

In [None]:
# Train the model and store metrics
model.to(device)  # no-op when the model already lives on `device`
threshold = 0.5   # probability above which a label counts as predicted


def _batch_accuracy(logits, targets):
    """Accuracy of the thresholded predictions of one batch.

    Tensors are detached and moved to the CPU before the NumPy conversion, so
    this also works when training runs on a GPU. ``accuracy_score`` receives
    2-D indicator arrays here; see the scikit-learn documentation for how it
    scores multilabel input.
    """
    prob = torch.sigmoid(logits.detach()).cpu().numpy()  # sigmoid to get probability
    y_pred = np.zeros_like(prob)
    y_pred[prob > threshold] = 1
    return accuracy_score(targets.detach().cpu().numpy(), y_pred)


epoch_train_losses = []
epoch_val_losses = []

epoch_train_accuracy = []
epoch_val_accuracy = []

for epoch in range(epochs):
    # ---- training ---------------------------
    train_losses = []
    train_accuracy = []

    data_set.mode = 'train'
    model.train()
    for D in dataloader:
        optimizer.zero_grad()
        data = D['image'].to(device)
        # Keep the targets in float32 like the logits (the dataset yields
        # float64 labels) to avoid mixed-dtype loss computation.
        labels = D['labels'].to(device=device, dtype=torch.float32)

        # Forward pass
        predictions = model(data)
        loss = criterion(predictions, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # collect performance metrics
        train_losses.append(loss.item())
        train_accuracy.append(_batch_accuracy(predictions, labels))

    epoch_train_losses.append(np.mean(train_losses))
    epoch_train_accuracy.append(np.mean(train_accuracy))

    scheduler.step()

    # ---- validation -------------------------
    val_losses = []
    val_accuracy = []

    data_set.mode = 'val'
    model.eval()
    with torch.no_grad():
        for D in dataloader:
            data = D['image'].to(device)
            labels = D['labels'].to(device=device, dtype=torch.float32)

            # Forward pass
            predictions = model(data)
            loss = criterion(predictions, labels)

            val_losses.append(loss.item())
            val_accuracy.append(_batch_accuracy(predictions, labels))

    epoch_val_losses.append(np.mean(val_losses))
    epoch_val_accuracy.append(np.mean(val_accuracy))

    print(f' Trained {epoch} with average loss {np.mean(train_losses)}, {np.mean(val_losses)}')

    

In [None]:
# Per-epoch training history, one row per epoch.
# (The closing brace of the dict was missing, which made this cell a SyntaxError.)
dg = pd.DataFrame({
    'train_losses': np.array(epoch_train_losses),
    'val_losses': np.array(epoch_val_losses),
    'train_acc': np.array(epoch_train_accuracy),
    'val_acc': np.array(epoch_val_accuracy),
})

In [None]:
# Persist the training history built in the previous cell.
dg.to_csv(path_or_buf='TrainingInfo_All.csv')

In [None]:
# initial visualisation: proper visualisation in a different notebook

# Average loss per epoch for both splits on a single axis.
fig, ax = plt.subplots()

for curve, colour, split_name in (
    (epoch_train_losses, 'red', 'train'),
    (epoch_val_losses, 'blue', 'val'),
):
    ax.plot(curve, color=colour, label=split_name)

ax.legend()
ax.set_xlabel('epoch')
ax.set_ylabel('average loss per epoch')

In [None]:
# Average accuracy per epoch for both splits on a single axis.
fig, ax = plt.subplots()

for curve, colour, split_name in (
    (epoch_train_accuracy, 'red', 'train'),
    (epoch_val_accuracy, 'blue', 'val'),
):
    ax.plot(curve, color=colour, label=split_name)

ax.legend()
ax.set_xlabel('epoch')
ax.set_ylabel('average accuracy per epoch')

# Evaluate on Validation Set

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)  # the batches are moved to `device`, so the model must be there too
data_set.mode = 'val'

# Dedicated evaluation loader: the shared training loader shuffles and drops
# the last incomplete batch, which would silently exclude a random handful of
# validation samples from the stored predictions and scores.
val_loader = DataLoader(data_set, batch_size=dataloader.batch_size,
                        shuffle=False, drop_last=False)

pred_logit = []
y_true = []
id_list = []

model.eval()
with torch.no_grad():
    for D in val_loader:
        data = D['image'].to(device)
        # Move results back to the CPU before the NumPy conversion;
        # np.concatenate cannot consume GPU tensors.
        pred_logit.append(model(data).cpu().numpy())
        # Labels and ids are only collected, so they never need to leave the CPU.
        y_true.append(D['labels'].numpy())
        id_list.append(D['global_id'].numpy())

y_true = np.concatenate(y_true, axis=0)
id_list = np.concatenate(id_list, axis=0)

pred_logit = np.concatenate(pred_logit, axis=0)
prob = 1/(1 + np.exp(-pred_logit))  # sigmoid to get probability

In [None]:
# One row per validation sample: the logits, the global id and the true labels.
#
# The label columns are named after the columns of the annotation table the
# labels were read from (everything after the file-name column). The names
# used to be hard-coded here in an order that differs from `classes`, so the
# saved ground-truth columns could be attached to the wrong class.
label_names = list(data_set.info_table.columns[1:])
assert len(label_names) == y_true.shape[1], 'label names do not match the label matrix'

prediction_columns = {f'pred_logit_{k}': pred_logit[:, k] for k in range(pred_logit.shape[1])}
prediction_columns['id_list'] = id_list
prediction_columns.update({name: y_true[:, k] for k, name in enumerate(label_names)})

dg = pd.DataFrame(prediction_columns)


In [None]:
# Persist the validation predictions built in the previous cell.
dg.to_csv(path_or_buf='ValidationPredictions_All.csv')

In [None]:
# Display names for the label columns, used as plot titles and tick labels.
# NOTE(review): the order must match the label columns of LabelingSetAll.csv;
# other cells in this notebook list the names in a different order - confirm.
classes = [
    'positive',
    'planktonic',
    'clumped',
    'rods',
    'filaments',
]


In [None]:
# illustrate logits: one panel per class with the true label plotted against
# the predicted logit, the sigmoid curve (red) and the logit-0 line (gray)
x = np.arange(-5, 5, 0.1)
sigmoid_curve = 1/(1 + np.exp(-x))

fig, axs = plt.subplots(1, len(classes), figsize=[15, 4.5])
for i, ax in enumerate(axs):
    class_name = classes[i]
    ax.scatter(pred_logit[:, i], y_true[:, i])
    ax.plot(x, sigmoid_curve, color='r')
    ax.plot([0, 0], [0, 1], color='gray')
    ax.set_title(class_name)

## precision and recall calculation

In [None]:
# Binarise the probabilities: 1 where the probability exceeds the threshold,
# 0 otherwise (same dtype as `prob`).
threshold = 0.5
y_pred = (prob > threshold).astype(prob.dtype)

In [None]:
# Per-class scores (average=None returns one value per label column).
# zero_division=0 reports 0 without a warning for a class that has no
# predicted (precision) or no true (recall) positives.
precision = precision_score(y_true, y_pred, average=None, zero_division=0)
recall = recall_score(y_true, y_pred, average=None, zero_division=0)

# F1 is the harmonic mean of precision and recall. Where both are 0 the plain
# formula evaluates 0/0 and yields NaN, so those classes are set to 0 instead.
pr_sum = precision + recall
F1score = np.divide(2 * precision * recall, pr_sum,
                    out=np.zeros_like(pr_sum), where=pr_sum > 0)

# precision: true_positive/(true_positive + false_positive): how many of the predicted class labels are correct?
# -> The precision is intuitively the ability of the classifier not to label as positive a sample that is negative.
# recall: true_positive/(true_positive + false_negative): how well is this class detected?
# -> The recall is intuitively the ability of the classifier to find all the positive samples.

In [None]:
# Collect the per-class scores in one table (one row per class, in label order).
dg = pd.DataFrame(data=dict(precision=precision, recall=recall, F1score=F1score))

In [None]:
# Persist the score table built in the previous cell.
dg.to_csv(path_or_buf='Scores_All.csv')

In [None]:
# Illustrate: precision, recall and F1 per class in a single scatter plot
fig = plt.figure()

for score_values, score_name in (
    (precision, 'precision'),
    (recall, 'recall'),
    (F1score, 'F1'),
):
    plt.scatter(classes, score_values, label=score_name)

plt.legend()
plt.yticks(np.arange(0, 1.1, 0.1))
plt.ylim([0, 1])

# Visualise examples

The idea is to look at the logits and identify images of interest.

In [None]:
# Folder with the stored crops and the raw annotation table. Unlike in the
# dataset class, the table is used here exactly as read from the CSV, and the
# cells below take the file name from column 1.
path = 'Crops/'
df = pd.read_csv(filepath_or_buffer='LabelingSetAll.csv')
#classes=['positive','rods','planktonic','filaments','clumped']

In [None]:
# Same logit overview as in the evaluation section, repeated here as a
# reference for picking the selection thresholds below.
x = np.arange(-5, 5, 0.1)
fig, axs = plt.subplots(1, len(classes), figsize=[15, 4.5])
for i, ax in enumerate(axs):
    logits_i, truth_i = pred_logit[:, i], y_true[:, i]
    ax.scatter(logits_i, truth_i)
    ax.plot(x, 1/(1 + np.exp(-x)), color='r')  # sigmoid curve
    ax.plot([0, 0], [0, 1], color='gray')      # logit 0 <=> probability 0.5
    ax.set_title(classes[i])

In [None]:
# Pick validation samples of one class by their logit and true label.
class_index = 3
class_truth = y_true[:, class_index]
class_logit = pred_logit[:, class_index]

# examples with low logit (<= -1): positives are false negatives,
# negatives are true negatives
select_thres = -1
is_low = class_logit <= select_thres
indices_fneg = np.flatnonzero((class_truth == 1) & is_low).tolist()
indices_tneg = np.flatnonzero((class_truth == 0) & is_low).tolist()

# examples with high logit (>= 1): negatives are false positives,
# positives are true positives
select_thres = 1
is_high = class_logit >= select_thres
indices_fpos = np.flatnonzero((class_truth == 0) & is_high).tolist()
indices_tpos = np.flatnonzero((class_truth == 1) & is_high).tolist()

## Selection

In [None]:
# Take the I-th example of each group (false negative, true negative, false
# positive, true positive). Groups with fewer than I+1 entries are skipped
# instead of aborting the cell with an IndexError.
I = 2
col = [
    group[I]
    for group in (indices_fneg, indices_tneg, indices_fpos, indices_tpos)
    if len(group) > I
]

In [None]:
# Show the crop and its table row for every selected example.
# No matplotlib figure is created here: the image is rendered with display(),
# and the empty figure previously opened per iteration was never used and
# only accumulated open figures.
for i in col:
    # `with` closes the image file once it has been rendered.
    with Image.open(os.path.join(path, df.iloc[id_list[i], 1])) as image:
        display(image)
    print(df.iloc[id_list[i],1:])
    print('--------------------------------------------------')

## False positive

In [None]:
# Show the crop and its table row for every false positive of the class.
# No matplotlib figure is created here: the image is rendered with display(),
# and the empty figure previously opened per iteration was never used and
# only accumulated open figures.
for i in indices_fpos:
    # `with` closes the image file once it has been rendered.
    with Image.open(os.path.join(path, df.iloc[id_list[i], 1])) as image:
        display(image)
    print(df.iloc[id_list[i],1:])
    print('--------------------------------------------------')

## False negative

In [None]:
# Show the crop and its table row for every false negative of the class.
# No matplotlib figure is created here: the image is rendered with display(),
# and the empty figure previously opened per iteration was never used and
# only accumulated open figures.
for i in indices_fneg:
    # `with` closes the image file once it has been rendered.
    with Image.open(os.path.join(path, df.iloc[id_list[i], 1])) as image:
        display(image)
    print(df.iloc[id_list[i],1:])
    print('--------------------------------------------------')

# Save model

In [None]:
# Do NOT re-create the model here. This cell used to assign a freshly
# downloaded ResNet-18 to `model`, which would replace the trained network
# right before it is saved (and it referenced ResNet18_Weights, which is never
# imported). The check below makes sure `model` still carries the
# `num_classes`-output head that was trained above.
assert model.fc.out_features == num_classes, (
    '`model` does not have the trained classification head; re-run the training cells'
)


In [None]:
MODEL_SAVE_PATH = 'trained_networks/'
modelname = 'bacteria_trained_model_resnet34'

# torch.save does not create missing directories, so make sure the target
# folder exists before writing.
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

# Only the weights (state_dict) are stored; to load them, the same
# architecture has to be constructed first.
torch.save(model.state_dict(), os.path.join(MODEL_SAVE_PATH, modelname))