In [1]:
import torch
import torch.nn as nn

import numpy as np

from tqdm import tqdm

import cv2

import math

In [2]:
class VGG_Schmidtea(nn.Module):
    """VGG-style classifier for single-channel 32x32 image patches.

    Five convolutional stages each end with BatchNorm2d followed by a 2x2
    max-pool, so the spatial size goes 32 -> 16 -> 8 -> 4 -> 2 -> 1 and the
    flattened feature vector has 512 entries. A fully connected head
    (Dropout/Linear/ReLU/BatchNorm1d, twice, then a final Linear) maps it to
    ``n_classes`` scores.

    NOTE(review): no activation follows any Conv2d, so the stacked convolutions
    inside a stage compose into a single linear map (only the max-pooling is
    non-linear). Inserting nn.ReLU would shift the Sequential indices and
    invalidate existing state_dicts, so it is intentionally left as-is here --
    confirm whether this was meant to be a standard VGG.
    """

    def __init__(self, n_classes=72):
        super().__init__()

        # Each stage: `n_convs` 3x3 'same' convolutions, BatchNorm2d, MaxPool(2).
        self.conv_1 = self._stage(1, 64, n_convs=2)      # 32x32 -> 16x16
        self.conv_2 = self._stage(64, 128, n_convs=2)    # 16x16 -> 8x8
        self.conv_3 = self._stage(128, 256, n_convs=3)   # 8x8   -> 4x4
        self.conv_4 = self._stage(256, 512, n_convs=3)   # 4x4   -> 2x2
        self.conv_5 = self._stage(512, 512, n_convs=3)   # 2x2   -> 1x1

        hidden = 4096
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.25),
            nn.Linear(512, hidden),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(hidden),
            nn.Dropout(p=0.25),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(hidden),
            nn.Linear(hidden, n_classes),
        )

    @staticmethod
    def _stage(in_channels, out_channels, n_convs):
        """Build one convolutional stage as an nn.Sequential.

        Layer order (and therefore state_dict indices) is: the convolutions,
        then BatchNorm2d, then MaxPool2d(kernel_size=2).
        """
        layers = []
        channels = in_channels
        for _ in range(n_convs):
            layers.append(nn.Conv2d(channels, out_channels, kernel_size=3, padding=1))
            channels = out_channels
        layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.MaxPool2d(kernel_size=2))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a (batch, 1, 32, 32) tensor to (batch, n_classes) scores."""
        for stage in (self.conv_1, self.conv_2, self.conv_3, self.conv_4, self.conv_5):
            x = stage(x)
        return self.classifier(torch.flatten(x, 1))

In [10]:
def train(model, loss_vector, accuracy_vector,  confusion_matrix, train_loader, device, problem, criterion, optimizer, epoch, n_class, log_interval=700000):
    """Run one training epoch and append the mean per-batch loss to ``loss_vector``.

    Each batch from ``train_loader`` is a dict with an 'image' tensor and an
    'angle' target. For 'classification' the target is cast to long; for the
    regression problems it is cast to float, and for 'regression_2' /
    'regression_3' it is additionally converted from degrees to radians.
    The loss is always ``criterion(output, target)`` on the raw model output.

    Args:
        model: torch.nn.Module; put in training mode.
        loss_vector: list; receives one 0-d CPU tensor (no grad) holding the
            mean of the per-batch losses of this epoch.
        accuracy_vector: not modified (kept for signature symmetry with validate).
        confusion_matrix: not modified (kept for signature symmetry with validate).
        train_loader: iterable of batches; must support ``len()`` and expose
            ``.dataset`` (used for progress logging).
        device: torch device batches are moved to.
        problem: one of 'classification', 'regression_1', 'regression_2',
            'regression_3', 'regression_4'.
        criterion: loss function called as ``criterion(output, target)``.
        optimizer: torch optimizer over ``model``'s parameters.
        epoch: epoch number, only used in the progress message.
        n_class: currently unused (kept for call compatibility).
        log_interval: print progress every ``log_interval`` batches.

    Raises:
        ValueError: if ``problem`` is unknown or ``train_loader`` is empty.
    """
    known_problems = ('classification', 'regression_1', 'regression_2', 'regression_3', 'regression_4')
    if problem not in known_problems:
        raise ValueError(f"Unknown problem type: {problem!r}")
    n_batches = len(train_loader)
    if n_batches == 0:
        raise ValueError("train_loader is empty; cannot compute an average training loss")

    # Training mode: Dropout active, BatchNorm uses/updates batch statistics.
    model.train()
    running_loss = 0.0

    for batch_idx, batch in enumerate(train_loader):
        img = batch['image'].float().to(device)
        if problem == 'classification':
            angle = batch['angle'].long().to(device)
        else:
            angle = batch['angle'].float().to(device)
        if problem in ('regression_2', 'regression_3'):
            # These variants regress the angle in radians.
            angle = torch.deg2rad(angle)

        optimizer.zero_grad()
        output = model(img)
        loss = criterion(output, angle)
        loss.backward()
        optimizer.step()

        # detach() so the epoch total does not keep every batch's autograd graph
        # alive (and avoids a device sync per batch compared to .item()).
        running_loss = running_loss + loss.detach()

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(img), len(train_loader.dataset),
                100. * batch_idx / n_batches, loss.item()))

    loss_vector.append(torch.as_tensor(running_loss / n_batches).cpu())
    

In [1]:
def validate(model, loss_vector, accuracy_vector, confusion_matrix, validation_loader, device, problem, criterion, n_class, list_target):
    '''
    Evaluate ``model`` on ``validation_loader`` and record the epoch metrics.

    Input of the function:
        model: neural network model in Pytorch; switched to eval mode for the
            pass and restored to its previous train/eval mode afterwards.
        loss_vector: list; the mean per-batch validation loss is appended as a
            0-d CPU tensor.
        accuracy_vector: list; for 'classification' the accuracy in percent
            (0-d tensor) is appended; nothing is appended for regression.
        confusion_matrix: list; an (n_class, n_class) count matrix is appended
            (rows = target class, columns = predicted class). It stays all-zero
            for regression problems.
        validation_loader: iterable of dicts with 'image' and 'angle'; must
            support ``len()`` and expose ``.dataset``.
        device: torch device batches are moved to.
        problem: 'classification', 'regression_1', 'regression_2',
            'regression_3' or 'regression_4'. For 'regression_2'/'regression_3'
            the target is converted from degrees to radians.
        criterion: loss called as ``criterion(output, target)`` on the raw output.
        n_class: size of the confusion matrix.
        list_target: currently unused (kept for call compatibility).

    Raises:
        ValueError: if ``problem`` is unknown or ``validation_loader`` is empty.
    '''
    known_problems = ('classification', 'regression_1', 'regression_2', 'regression_3', 'regression_4')
    if problem not in known_problems:
        raise ValueError(f"Unknown problem type: {problem!r}")
    n_batches = len(validation_loader)
    if n_batches == 0:
        raise ValueError("validation_loader is empty; cannot compute an average validation loss")

    is_classification = problem == 'classification'
    running_loss = 0.0
    correct = torch.zeros((), dtype=torch.long)
    conf_mat = torch.zeros(n_class, n_class)

    was_training = model.training
    # Eval mode: BatchNorm must use (and not update) its running statistics and
    # Dropout must be disabled while measuring validation performance.
    model.eval()
    try:
        with torch.no_grad():
            for batch in validation_loader:
                img = batch['image'].float().to(device)
                if is_classification:
                    angle = batch['angle'].long().to(device)
                else:
                    angle = batch['angle'].float().to(device)
                if problem in ('regression_2', 'regression_3'):
                    angle = torch.deg2rad(angle)

                output = model(img)
                # The loss is computed on the raw output (logits for
                # classification), never on the arg-max indices.
                running_loss = running_loss + criterion(output, angle)

                if is_classification:
                    pred = output.argmax(dim=1)
                    correct += pred.eq(angle).sum().cpu()
                    targets = angle.view(-1).cpu()
                    preds = pred.view(-1).cpu()
                    # Vectorised histogram; accumulate=True sums repeated (target, pred) pairs.
                    conf_mat.index_put_((targets, preds), torch.ones(targets.numel()), accumulate=True)
    finally:
        model.train(was_training)

    val_loss = torch.as_tensor(running_loss / n_batches).cpu()
    loss_vector.append(val_loss)
    confusion_matrix.append(conf_mat)

    if is_classification:
        n_samples = len(validation_loader.dataset)
        accuracy = 100. * correct.to(torch.float32) / n_samples
        accuracy_vector.append(accuracy)

        print('Validation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
              .format(float(val_loss), int(correct), n_samples, float(accuracy)))
    else:
        print(f"Validation loss: {float(val_loss)}")

In [1]:
def angle_prediction_from_img(path, model, device, problem = 'classification'):
    """Predict an angle class for every detected centriole in an image.

    The detection mask is read from ``path[:-4] + '_centriole_detected.tif'``;
    every mask pixel equal to 1 marks a centriole centre. A 32x32 patch centred
    on that pixel is cropped from the image at ``path``, fed to ``model`` as a
    (1, 1, 32, 32) float tensor, and the arg-max over dimension 1 of the output
    is taken as the predicted class.

    Args:
        path: path of the source image. The last 4 characters are stripped to
            build the mask path, so a 3-character extension such as '.tif' is
            assumed.
        model: network mapping (1, 1, 32, 32) to (1, n_classes) scores. It is
            run in eval mode and restored to its previous mode afterwards.
        device: torch device patches are moved to.
        problem: currently ignored -- arg-max (classification) decoding is
            always used.

    Returns:
        list of ``((x, y), predicted_class)`` tuples, one per centriole whose
        whole 32x32 patch lies inside the image (others are skipped).

    Raises:
        FileNotFoundError: if OpenCV cannot read the mask or the image.
    """
    half = 16  # half of the 32x32 patch the model expects

    detectionpath = path[:-4] + '_centriole_detected.tif'
    mask = cv2.imread(detectionpath, cv2.IMREAD_UNCHANGED)
    if mask is None:
        raise FileNotFoundError(f"Could not read detection mask: {detectionpath}")

    # Coordinates (row, column) of the detected centrioles.
    ypts, xpts = np.where(mask == 1)

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    height, width = img.shape[0], img.shape[1]

    was_training = model.training
    # Inference: BatchNorm must use running statistics (BatchNorm1d rejects a
    # batch of one in training mode) and Dropout must be disabled.
    model.eval()

    a_list_of_centriole = []
    try:
        for i in tqdm(range(len(ypts))):
            x, y = xpts[i], ypts[i]

            # img[y-16:y+16, x-16:x+16] is a full 32x32 patch iff 16 <= y <= height-16
            # (same for x); the original strict bounds wrongly rejected both edges.
            if not (half <= y <= height - half and half <= x <= width - half):
                continue

            # NOTE(review): casting to uint8 wraps values of 16-bit images -- confirm
            # this matches the preprocessing used for the training data.
            centriole = np.asarray(img[y - half:y + half, x - half:x + half], dtype="uint8")
            centriole = centriole.reshape(1, 1, 2 * half, 2 * half)
            centriole = torch.from_numpy(centriole).float().to(device)

            with torch.no_grad():
                output = model(centriole)

            angle = output.max(1)[1].cpu().numpy()
            a_list_of_centriole.append(((x, y), angle[0]))
    finally:
        model.train(was_training)

    return a_list_of_centriole