In [6]:
import re
import torch 
import torch.utils.data as tor_utils
import os
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as torchmodels
import random
import string
import numpy as np
from torch.utils.data import Dataset, DataLoader
from fastai import *
from fastai.vision.all import *
from fastai.data.core import DataLoaders as fast_dataloaders
from scipy.ndimage import zoom
import orix
from orix.quaternion import Rotation, Symmetry
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from scipy.ndimage import median_filter
import cv2
import torch.utils

In [7]:
#data related functions


def getMps(steel_type):
    """
    Load the master pattern for a steel type from the 'steelH5s' folder.

    Args:
        steel_type (str): Name of the steel; the file that is read is
            "steelH5s/<steel_type>.h5".

    Returns:
        tuple: (masterPattern, lambertPattern), where masterPattern is the
        object returned by `kp.load` and lambertPattern is the result of
        calling `as_lambert()` on it.

    Note:
        `kp` is not imported in the visible import cell - presumably
        `import kikuchipy as kp` is executed elsewhere; TODO confirm.
    """
    h5Path = "steelH5s/" + steel_type + ".h5"
    masterPattern = kp.load(h5Path)
    lambertPattern = masterPattern.as_lambert()
    return masterPattern, lambertPattern

def quatBreakdown(fPath):
    """
    Extract the underscore-separated numeric fields embedded in a filename.

    The regex match is split on "_"; the first and the last piece are
    discarded and every piece in between is converted to float.

    Args:
        fPath (str or path-like): The filepath to break down (converted with str()).

    Returns:
        list of float: The numbers found between underscores in the filepath.

    Raises:
        AttributeError: If the pattern does not match (the search returns None).
        ValueError: If one of the inner pieces cannot be parsed as a float.
    """
    matcher = re.compile(r'[^\\_]+(?=_)+.+[^_.jpeg]')
    matched = matcher.search(str(fPath)).group()

    # Drop the leading (name) piece and the trailing piece, keep the numbers
    pieces = matched.split("_")
    return [float(piece) for piece in pieces[1:-1]]

def getNpsDataset(path, regexFunc):

    """
    Load all NPS files in a folder into a dataset, labelling each from its filename.

    Args:
        path (str): Directory path where NPS files are stored. A trailing
            path separator is optional.
        regexFunc (callable): Function that takes a file path and returns the
            label for that file (e.g. `quatBreakdown`).

    Returns:
        NpsDataset: A dataset over every regular file found directly in `path`.
    """

    # Build each path once with os.path.join so that the path that is tested
    # with isfile() is the same path that is stored. (Plain `path + name`
    # produced wrong paths whenever `path` lacked a trailing separator.)
    candidates = (os.path.join(path, name) for name in os.listdir(path))
    fileList = [candidate for candidate in candidates if os.path.isfile(candidate)]

    #load all the files to an Nps dataset
    return NpsDataset(fileList, regexFunc)

def loadNpsFile(filePath):
    """
    Read one NPS file from disk with `np.load`.

    Args:
        filePath (str): Location of the NPS file (anything `np.load` accepts).

    Returns:
        numpy array: The contents of the file as returned by `np.load`.
    """
    contents = np.load(filePath)
    return contents

class NpsDataset(tor_utils.Dataset):

    """
    PyTorch dataset that lazily loads NPS files and derives each label from the file's path.

    Attributes:
        fileList (list): Paths of the NPS files, one per item.
        regexFunc (function): Called with a file path; its return value is turned into the label tensor.
    """

    def __init__(self, fileList, regexFunc):
        #keep the paths and the label-extraction function; nothing is loaded yet
        self.fileList, self.regexFunc = fileList, regexFunc

    def __len__(self):
        """
        Number of files (and therefore items) in the dataset.

        Returns:
            int: Length of the file list.
        """
        return len(self.fileList)

    def __getitem__(self, idx):

        """
        Load one file and build its label.

        Args:
            idx (int): Position of the file in the file list.

        Returns:
            torch tensor: The loaded array with a new leading dimension added
            torch tensor: The output of the regex function for this file's path
        """

        path = self.fileList[idx]
        image = torch.tensor(loadNpsFile(path)).unsqueeze(0)
        label = torch.tensor(self.regexFunc(path))
        return image, label

    def subset(self, indices):

        """
        Eagerly load the items at the given indices.

        Args:
            indices (list): Indices of the wanted items.

        Returns:
            list: The (image, label) pairs for those indices, in the given order.
        """
        items = []
        for index in indices:
            items.append(self[index])
        return items

def resize(pat, size):
    """
    Resize a single image array to a square of side `size`.

    The array is converted to an image with `Image.fromarray`, resized with
    the image's default resampling, and converted back with `np.asarray`.

    Args:
        pat (numpy array): Image array to resize.
        size (int): Width and height of the output image.

    Returns:
        numpy array: The resized image.
    """
    target = (size, size)
    resized = Image.fromarray(pat).resize(target)
    return np.asarray(resized)

def resize_batch(pat, size):
    """
    Resize a batch of images to `size` x `size` using scipy's spline zoom.

    Args:
        pat (numpy array): Input 3D array representing a batch of images (number_of_images, height, width).
        size (int): Desired size for both width and height of each output image in the batch.

    Returns:
        numpy array: Resized batch of images; the batch axis is left untouched.
    """
    # Import under a private alias. This file also defines a transform class
    # named `zoom`, which rebinds the module-level name once it is executed;
    # relying on the global would then build a transform object instead of
    # resizing the array.
    from scipy.ndimage import zoom as _ndimageZoom

    #calculate zoom factor needed (1 keeps the batch axis unchanged)
    zoomFactors = (1, size / pat.shape[1], size / pat.shape[2])

    return _ndimageZoom(pat, zoomFactors)


class SymQuatBreakdown:
    """
    A class to find symmetrical points and return all distinguished points from a given quaternion

    Attributes:
        sym: The given symmetry for quaternion breakdown.
        maxSize (int): Number of rows every returned tensor is padded to (default 96).
        singleTarget: The most recent target, wrapped in an orix `Symmetry`.
        allTargetsData: Tensor containing the (padded) distinguished points of the most recent breakdown.
    """

    def __init__(self, sym, maxSize=96):
        """
        Args:
            sym: Symmetry passed as the first argument to `get_distinguished_points`.
            maxSize (int, optional): Row count of the padded output. Defaults to 96,
                the value that was previously hard-coded.
        """
        self.sym = sym
        self.maxSize = maxSize
        self.singleTarget = None
        self.allTargetsData = None

    def breakdown(self, singleTarget):
        """
        Break down the given target quaternion into distinguished points based on the provided symmetry.

        Args:
            singleTarget: The quaternion that needs to be broken down.

        Returns:
            torch tensor: Float tensor with `maxSize` rows: the distinguished points
            followed by copies of the first point used as padding.
        """

        #generate all distinguished points with similarity to the target
        self.singleTarget = Symmetry(singleTarget)
        points = orix.quaternion.symmetry.get_distinguished_points(self.sym, self.singleTarget)
        self.allTargetsData = torch.tensor(points.data).float()

        #pad by repeating the FIRST point (not zeros) so that all labels share
        #one shape while every row remains one of the distinguished points
        numSlots = self.maxSize - self.allTargetsData.size(0)
        firstValue = self.allTargetsData[0, :]
        filledData = firstValue.repeat(numSlots, 1)
        self.allTargetsData = torch.cat((self.allTargetsData, filledData), dim=0)
        return self.allTargetsData

class greyDataset(Dataset):
    """
    Dataset pairing greyscale images with labels, with optional per-item transforms.

    Attributes:
        data (torch tensor): Greyscale image data, indexed per item.
        labels (torch tensor): Labels, indexed in step with `data`.
        tmfs (list, optional): Transformations to apply to each image.
        transform (torch transform, optional): `transforms.Compose` of `tmfs`, or None when no transforms were given.
    """
    def __init__(self, data, labels, tmfs=None):
        self.data = data
        self.labels = labels
        self.tmfs = tmfs
        # Only build the composed transform when there is something to compose
        if self.tmfs:
            self.transform = transforms.Compose(self.tmfs)
        else:
            self.transform = None

    def __len__(self):
        """
        Number of items, taken from the length of `data`.

        Returns:
            int: Total number of items in the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Fetch one (image, label) pair, transforming the image when a transform is set.

        Args:
            idx (int): Index of the desired item.

        Returns:
            image (pytorch tensor): The image at `idx`, transformed if a transform exists.
            label (pytorch tensor): The label at `idx`.
        """
        image, label = self.data[idx], self.labels[idx]

        if not self.transform:
            return image, label

        return self.transform(image), label
    
def closestQuaternion(reference, quaternions):
    """
    Finds the quaternion closest to the reference quaternion based on dot product.

    Parameters:
        reference (torch.Tensor): Reference quaternion (1-D).
        quaternions (torch.Tensor): 2-D tensor of quaternions to search through, one per row.

    Returns:
        torch.Tensor: The row of `quaternions` with the largest dot product with
        the reference (the first such row on ties).
    """
    #Calculate dot products between the reference and each quaternion in the list
    dots = torch.sum(reference.unsqueeze(0) * quaternions, dim=1)

    #Find the index of the maximum dot product. torch.argmax is used rather
    #than np.argmax so that this also works for CUDA tensors and for tensors
    #that track gradients, neither of which can be converted to numpy.
    closestInd = torch.argmax(dots)
    return quaternions[closestInd]
    
class oneHotSymQuatBreakdown:
    """
    A class to find symmetrical points and return all distinguished points from a given quaternion and include a portion of the tensor which represents the one hot encoding of the material type

    Attributes:
        numMaterial (int): Number of one-hot columns appended to every row.
        maxSize (int): Number of rows every output is padded to (maximum number of symmetrical points across all symmetries).
        sym: Symmetry used in the most recent call to `breakdown`.
        singleTarget: The most recent target, wrapped in an orix `Symmetry`.
        allTargetsData (torch tensor): Padded distinguished points of the most recent breakdown.
    """
    def __init__(self, numMaterial, maxSize = 96):
        self.numMaterial = numMaterial
        self.maxSize = maxSize
        self.sym = None
        self.singleTarget = None
        self.allTargetsData = None

    def breakdown(self, single_target, oneHotLoc, sym):
        """
        Build the padded distinguished-point tensor with one-hot material columns appended.

        Args:
            single_target: The quaternion that needs to be broken down.
            oneHotLoc (int): Column of the one-hot block that is set to 1.
            sym: Symmetry passed as the first argument to `get_distinguished_points`.

        Returns:
            torch tensor: `maxSize` rows; each row is a distinguished point
            followed by `numMaterial` one-hot columns.
        """
        self.sym = sym

        #Get symmetrical distinguished points for the target
        self.singleTarget = Symmetry(single_target)
        points = orix.quaternion.symmetry.get_distinguished_points(self.sym, self.singleTarget)
        self.allTargetsData = torch.tensor(points.data).float()

        # Extend the target data tensor to maxSize rows by repeating its first
        # row. self.maxSize is used (rather than a literal 96) so the row count
        # always matches the one-hot block concatenated below.
        numSlotsToFill = self.maxSize - self.allTargetsData.size(0)
        firstValue = self.allTargetsData[0, :]
        filledData = firstValue.repeat(numSlotsToFill, 1)
        self.allTargetsData = torch.cat((self.allTargetsData, filledData), dim=0)

        oneHot = torch.zeros((self.maxSize, self.numMaterial))
        oneHot[:, oneHotLoc] = 1
        return torch.cat([self.allTargetsData, oneHot], dim=1)
    
def grainSplit(img, thresh):
    """
    Mark grain boundaries where neighbouring pixels differ by more than a threshold.

    For each pixel the absolute differences to the pixel below and to the
    pixel on the right are summed over the last (channel) axis; the pixel is
    flagged when either sum exceeds `thresh`.

    Args:
        img (torch.Tensor): Image indexed as (row, column, channel).
        thresh (float): Threshold on the summed absolute channel difference.

    Returns:
        split (torch.Tensor): Boolean (row, column) mask, `True` on grain boundaries. The last row is never
        flagged by the vertical test and the last column never by the horizontal test.
    """
    # Vertical neighbours: row r+1 minus row r
    vertical = (img[1:, :] - img[:-1, :]).abs().sum(dim=2) > thresh

    # Horizontal neighbours: column c+1 minus column c
    horizontal = (img[:, 1:] - img[:, :-1]).abs().sum(dim=2) > thresh

    # Start from an all-False mask with the image's spatial shape
    boundaries = torch.zeros_like(img[:, :, 0], dtype=torch.bool)
    boundaries[:-1, :] = boundaries[:-1, :] | vertical
    boundaries[:, :-1] = boundaries[:, :-1] | horizontal

    return boundaries

In [8]:
#models
class DeepGreyResnet101(nn.Module):
    """
    ResNet101-based network for single-channel (greyscale) input images.

    The children of a pretrained `resnet101`, minus the last two, form the
    feature extractor. Its first convolution is swapped for a 1-input-channel
    convolution whose weights are the pretrained weights summed over the input
    channel axis. A four-layer fully connected head maps the 2048 pooled
    features to `n_classes` outputs.

    Attributes:
        resnet (nn.Sequential): Greyscale first convolution followed by the remaining backbone children.
        avgpool (nn.AdaptiveAvgPool2d): Adaptive average pooling to 1x1.
        flatten (nn.Flatten): Flattens the pooled features.
        dropout1, dropout2, dropout3 (nn.Dropout): Head dropout layers; only dropout1 and dropout2 are used in `forward`.
        fc1..fc4 (nn.Linear): Head layers 2048 -> 512 -> 128 -> 32 -> n_classes.
        activation (nn.LeakyReLU): Activation applied after fc1, fc2 and fc3.
    """

    def __init__(self, n_classes, p_conv=0.1, p_lin = 0.2):
        """
        Args:
            n_classes (int): Number of outputs of the final layer.
            p_conv (float, optional): Dropout probability for the Dropout2d inserted before any
                top-level backbone child that is itself an nn.Conv2d. Defaults to 0.1.
            p_lin (float, optional): Dropout probability for the head dropout layers. Defaults to 0.2.
        """
        super().__init__()
        #Take the pretrained backbone apart, dropping its last two children
        backbone = list(resnet101(pretrained=True).children())[:-2]

        #Single-channel replacement for the first convolution, initialised from
        #the pretrained weights summed over the input-channel axis
        greyConv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        greyConv.weight.data = backbone[0].weight.sum(1, keepdim=True)
        self.resnet = nn.Sequential(greyConv)

        #Append the remaining children, each named by the Sequential's current length
        for child in backbone[1:]:
            if isinstance(child, nn.Conv2d):
                self.resnet.add_module("conv_dropout", nn.Dropout2d(p_conv))
            self.resnet.add_module(str(len(self.resnet)), child)

        #Head: pooling, flatten, then the fully connected stack
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        self.dropout1 = nn.Dropout(p_lin)
        self.fc1 = nn.Linear(2048, 512)
        self.dropout2 = nn.Dropout(p_lin)
        self.fc2 = nn.Linear(512, 128)
        self.dropout3 = nn.Dropout(p_lin)
        self.fc3 = nn.Linear(128, 32)
        self.fc4 = nn.Linear(32, n_classes)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        """
        Run the backbone and the fully connected head.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            x (torch.Tensor): Output of the last linear layer (no final activation).
        """
        pooled = self.flatten(self.avgpool(self.resnet(x)))

        if self.training:
            pooled = self.dropout1(pooled)
        hidden = self.activation(self.fc1(pooled))

        if self.training:
            hidden = self.dropout2(hidden)
        hidden = self.activation(self.fc2(hidden))

        hidden = self.activation(self.fc3(hidden))
        return self.fc4(hidden)
    
#models
class DeepGreyResnet152(nn.Module):
    """
    ResNet152-based network for single-channel (greyscale) input images.

    The children of a pretrained `resnet152`, minus the last two, form the
    feature extractor; the first convolution is replaced by a 1-input-channel
    convolution initialised from the pretrained weights summed over the input
    channel axis. The head is a four-layer fully connected stack.

    Attributes:
        n_classes (int): Number of outputs of the final layer.
        resnet (nn.Sequential): Greyscale first convolution followed by the remaining backbone children.
        avgpool (nn.AdaptiveAvgPool2d): Adaptive average pooling to 1x1.
        flatten (nn.Flatten): Flattens the pooled features.
        dropout1, dropout2, dropout3 (nn.Dropout): Head dropout layers; only dropout1 and dropout2 are used in `forward`.
        fc1..fc4 (nn.Linear): Head layers 2048 -> 2048 -> 1024 -> 512 -> n_classes.
        activation (nn.LeakyReLU): Activation applied after fc1, fc2 and fc3.
    """

    def __init__(self, n_classes, p_conv=0.1, p_lin = 0.2):
        """
        Args:
            n_classes (int): Number of outputs of the final layer.
            p_conv (float, optional): Dropout probability for the Dropout2d inserted before any
                top-level backbone child that is itself an nn.Conv2d. Defaults to 0.1.
            p_lin (float, optional): Dropout probability for the head dropout layers. Defaults to 0.2.
        """
        super().__init__()
        self.n_classes = n_classes

        #Pretrained backbone without its last two children
        backbone = list(resnet152(pretrained=True).children())[:-2]

        #Greyscale first convolution built from the channel-summed pretrained weights
        greyConv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        greyConv.weight.data = backbone[0].weight.sum(1, keepdim=True)
        self.resnet = nn.Sequential(greyConv)

        #Remaining children are appended under names equal to the current length
        for child in backbone[1:]:
            if isinstance(child, nn.Conv2d):
                self.resnet.add_module("conv_dropout", nn.Dropout2d(p_conv))
            self.resnet.add_module(str(len(self.resnet)), child)

        #Fully connected head
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        self.dropout1 = nn.Dropout(p_lin)
        self.fc1 = nn.Linear(2048, 2048)
        self.dropout2 = nn.Dropout(p_lin)
        self.fc2 = nn.Linear(2048, 1024)
        self.dropout3 = nn.Dropout(p_lin)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, n_classes)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        """
        Run the backbone and the fully connected head.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            x (torch.Tensor): Output of the last linear layer (no final activation).
        """
        pooled = self.flatten(self.avgpool(self.resnet(x)))

        if self.training:
            pooled = self.dropout1(pooled)
        hidden = self.activation(self.fc1(pooled))

        if self.training:
            hidden = self.dropout2(hidden)
        hidden = self.activation(self.fc2(hidden))

        hidden = self.activation(self.fc3(hidden))

        # NOTE: a softmax over the trailing (n_classes - 4) outputs was sketched
        # here at one point but is disabled; raw outputs are returned.
        return self.fc4(hidden)
    

#models
class DeepGreyResnet50(nn.Module):
    """
    ResNet50-based network for single-channel (greyscale) input images.

    Built the same way as the deeper variants in this file: the pretrained
    `resnet50` children (minus the last two) with the first convolution
    replaced by a 1-input-channel convolution, followed by a four-layer
    fully connected head.

    Attributes:
        n_classes (int): Number of outputs of the final layer.
        resnet (nn.Sequential): Greyscale first convolution followed by the remaining backbone children.
        avgpool (nn.AdaptiveAvgPool2d): Adaptive average pooling to 1x1.
        flatten (nn.Flatten): Flattens the pooled features.
        dropout1, dropout2, dropout3 (nn.Dropout): Head dropout layers; only dropout1 and dropout2 are used in `forward`.
        fc1..fc4 (nn.Linear): Head layers 2048 -> 2048 -> 1024 -> 512 -> n_classes.
        activation (nn.LeakyReLU): Activation applied after fc1, fc2 and fc3.
    """

    def __init__(self, n_classes, p_conv=0.1, p_lin = 0.2):
        """
        Args:
            n_classes (int): Number of outputs of the final layer.
            p_conv (float, optional): Dropout probability for the Dropout2d inserted before any
                top-level backbone child that is itself an nn.Conv2d. Defaults to 0.1.
            p_lin (float, optional): Dropout probability for the head dropout layers. Defaults to 0.2.
        """
        super().__init__()
        self.n_classes = n_classes

        #Pretrained backbone without its last two children
        pretrainedChildren = list(resnet50(pretrained=True).children())[:-2]
        rgbConv, remaining = pretrainedChildren[0], pretrainedChildren[1:]

        #Collapse the pretrained first-layer weights over the input-channel axis
        monoConv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        monoConv.weight.data = rgbConv.weight.sum(1, keepdim=True)
        self.resnet = nn.Sequential(monoConv)

        #Append the rest of the backbone, named by position
        for child in remaining:
            if isinstance(child, nn.Conv2d):
                self.resnet.add_module("conv_dropout", nn.Dropout2d(p_conv))
            self.resnet.add_module(str(len(self.resnet)), child)

        #Fully connected head
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        self.dropout1 = nn.Dropout(p_lin)
        self.fc1 = nn.Linear(2048, 2048)
        self.dropout2 = nn.Dropout(p_lin)
        self.fc2 = nn.Linear(2048, 1024)
        self.dropout3 = nn.Dropout(p_lin)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, n_classes)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        """
        Run the backbone and the fully connected head.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            x (torch.Tensor): Output of the last linear layer (no final activation).
        """
        feats = self.resnet(x)
        feats = self.avgpool(feats)
        feats = self.flatten(feats)

        if self.training:
            feats = self.dropout1(feats)
        out = self.activation(self.fc1(feats))

        if self.training:
            out = self.dropout2(out)
        out = self.activation(self.fc2(out))

        out = self.activation(self.fc3(out))

        # NOTE: a softmax over the trailing (n_classes - 4) outputs was sketched
        # here at one point but is disabled; raw outputs are returned.
        return self.fc4(out)
    

    
class oneGreyResnet50(nn.Module):
    """
    ResNet50-based network for greyscale input with two output branches.

    The pooled 2048 backbone features feed (a) a single linear layer with 2
    outputs and (b) a four-layer stack with 4 outputs. `forward` returns the
    4 outputs followed by the 2 outputs, i.e. 6 values per sample regardless
    of `n_classes`. (Presumably the 4-wide branch is a quaternion and the
    2-wide branch a material/class score - TODO confirm.)

    Attributes:
        n_classes (int): Stored constructor argument; not used to size any layer.
        resnet (nn.Sequential): Greyscale first convolution followed by the remaining backbone children.
        avgpool (nn.AdaptiveAvgPool2d): Adaptive average pooling to 1x1.
        flatten (nn.Flatten): Flattens the pooled features.
        fc1s (nn.Linear): 2048 -> 2 branch.
        fc1q..fc4q (nn.Linear): 2048 -> 1024 -> 512 -> 256 -> 4 branch.
        activation (nn.LeakyReLU): Activation between the layers of the 4-output branch.
        final: False by default; when set to something truthy it is called on the 2-output branch.
    """

    def __init__(self, n_classes, p_conv=0.1, p_lin = 0.2):
        """
        Args:
            n_classes (int): Stored on the instance only.
            p_conv (float, optional): Dropout probability for the Dropout2d inserted before any
                top-level backbone child that is itself an nn.Conv2d. Defaults to 0.1.
            p_lin (float, optional): Accepted for signature parity with the other models; unused here.
        """
        super().__init__()
        self.n_classes = n_classes

        #Pretrained backbone without its last two children
        backbone = list(resnet50(pretrained=True).children())[:-2]

        #Greyscale first convolution built from the channel-summed pretrained weights
        greyConv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        greyConv.weight.data = backbone[0].weight.sum(1, keepdim=True)
        self.resnet = nn.Sequential(greyConv)

        #Remaining children are appended under names equal to the current length
        for child in backbone[1:]:
            if isinstance(child, nn.Conv2d):
                self.resnet.add_module("conv_dropout", nn.Dropout2d(p_conv))
            self.resnet.add_module(str(len(self.resnet)), child)

        #Pooling shared by both branches
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()

        #2-output branch
        self.fc1s = nn.Linear(2048, 2)

        #4-output branch
        self.fc1q = nn.Linear(2048, 1024)
        self.fc2q = nn.Linear(1024, 512)
        self.fc3q = nn.Linear(512, 256)
        self.fc4q = nn.Linear(256, 4)
        self.activation = nn.LeakyReLU(0.1)

        #Optional callable applied to the 2-output branch; disabled by default
        self.final = False

    def forward(self, x):
        """
        Run the backbone and both output branches.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            x (torch.Tensor): The 4-output branch concatenated with the 2-output branch along the last axis.
        """
        features = self.flatten(self.avgpool(self.resnet(x)))

        #2-output branch, optionally post-processed by `final`
        sideOut = self.fc1s(features)
        if self.final:
            sideOut = self.final(sideOut)

        #4-output branch
        quatOut = self.activation(self.fc1q(features))
        quatOut = self.activation(self.fc2q(quatOut))
        quatOut = self.activation(self.fc3q(quatOut))
        quatOut = self.fc4q(quatOut)

        return torch.cat([quatOut, sideOut], dim = -1)

In [9]:
#losses

def l1Simple(q1, q2):
    """
    Compute the L1 loss between q1 and the candidate in q2 closest to a fixed reference.

    For every sample the candidate in q2 with the smallest Euclidean distance
    to the (normalised) reference [1, 1e-8, 1e-16, 1e-24] is selected, and the
    mean absolute difference between those candidates and q1 is returned.

    Args:
        q1 (torch.Tensor): Tensor of shape (batch, 4).
        q2 (torch.Tensor): Candidate tensor of shape (batch, candidates, 4).

    Returns:
        loss (torch.Tensor): Computed L1 loss (scalar).
    """
    # Build the reference on q2's device instead of forcing .cuda(), so the
    # loss also works on CPU and on GPUs other than the current default.
    ref = torch.tensor([1, 1e-8, 1e-16, 1e-24], device=q2.device)
    ref = ref / torch.norm(ref, dim=-1, keepdim=True)

    # Distance of every candidate to the reference, then the best one per sample
    distance = torch.norm(ref - q2, dim=-1)
    ind = torch.argmin(distance, dim=1)
    closest = q2[torch.arange(q1.shape[0], device=q2.device), ind, :]
    return F.l1_loss(closest, q1)

def l1Sym(q1, q2):
    """
    L1 loss between q1 and, per sample, the candidate in q2 nearest to q1.

    The nearest candidate is chosen by Euclidean distance to the unit-normalised
    q1; the loss itself is taken against q1 as given (not normalised).

    Args:
        q1 (torch.Tensor): Tensor of shape (batch, 4).
        q2 (torch.Tensor): Candidate tensor of shape (batch, candidates, 4).

    Returns:
    - torch.Tensor: Computed L1 loss (scalar).
    """
    unitQ1 = q1 / torch.norm(q1, dim=-1, keepdim=True)

    # (batch, 1, 4) against (batch, candidates, 4) -> one distance per candidate
    gaps = torch.norm(unitQ1.unsqueeze(1) - q2, dim=-1)
    nearest = torch.argmin(gaps, dim=1)

    rows = torch.arange(q1.shape[0])
    picked = q2[rows, nearest, :]
    return F.l1_loss(picked, q1)

def trueRotDistance(q1, q2):
    """
    Mean, over the batch, of the smallest angle between q1 and any candidate in q2.

    Both inputs are normalised along the last axis. The dot product is clamped
    just inside [-1, 1] before `acos`, and the angle is scaled by
    360 / 3.14159 (i.e. reported in degrees).

    Args:
        q1 (torch.Tensor): Tensor of shape (batch, 4).
        q2 (torch.Tensor): Candidate tensor of shape (batch, candidates, 4).

    Returns:
        torch.Tensor: Scalar mean of the per-sample minimum angle.
    """
    unitQ1 = q1 / torch.norm(q1, dim=-1, keepdim=True)
    unitQ2 = q2 / torch.norm(q2, dim=-1, keepdim=True)

    # Cosine between each sample and each of its candidates, kept strictly
    # inside acos' domain
    cosines = (unitQ1.unsqueeze(1) * unitQ2).sum(dim=-1)
    cosines = torch.clamp(cosines, -1.0 + 1e-7, 1.0 - 1e-7)

    # Geodesic distance to every candidate, then the best candidate per sample
    angles = torch.acos(cosines) * 360 / 3.14159
    bestAngles = torch.min(angles, dim = 1)[0]
    return torch.mean(bestAngles)

def mixedLoss(q1, q2):
    """
    Combine the angular loss with ten times the symmetric L1 loss.

    Args:
        q1 (torch.Tensor): Tensor of shape (batch, 4).
        q2 (torch.Tensor): Candidate tensor of shape (batch, candidates, 4).

    Returns:
    - torch.Tensor: `trueRotDistance(q1, q2) + 10 * l1Sym(q1, q2)`.
    """
    angularTerm = trueRotDistance(q1, q2)
    l1Term = l1Sym(q1, q2)
    return angularTerm + 10 * l1Term

import numpy
import numpy.matlib as npm

def avgRotDistance(q1, q2):
    """
    Mean angle between q1 and every candidate in q2, treating q and -q as equal.

    Both inputs are normalised along the last axis; the absolute value of the
    dot product is used, and the angle is scaled by 360 / 3.14159 (degrees).

    Args:
        q1 (torch.Tensor): Tensor of shape (batch, 4).
        q2 (torch.Tensor): Candidate tensor of shape (batch, candidates, 4).

    Returns:
        torch.Tensor: Scalar mean over all samples and candidates.
    """
    q1Norm = q1 / torch.norm(q1, dim=-1, keepdim=True)
    q2Norm = q2 / torch.norm(q2, dim=-1, keepdim=True)

    dotProduct = torch.abs(torch.sum(q1Norm.unsqueeze(1) * q2Norm, dim=-1))

    # Rounding can push the dot product of two unit vectors slightly above 1,
    # where acos returns NaN and poisons the mean; cap it at exactly 1.
    dotProduct = torch.clamp(dotProduct, max=1.0)

    distance = torch.acos(dotProduct) * 360 / 3.14159
    return torch.mean(distance)

def averageQuaternions(Q):
    """
    Compute the average quaternion of an array of quaternions.

    This function calculates the average by determining the eigenvector
    associated with the largest eigenvalue of the mean outer-product matrix
    (1/M) * sum_i q_i q_i^T. Because q q^T equals (-q)(-q)^T, the sign of the
    individual inputs does not matter, and the sign of the result is arbitrary.
    Reference: https://github.com/christophhagen/averaging-quaternions/blob/master/averageQuaternions.py

    Args:
        Q (numpy.ndarray): An Mx4 matrix, where M is the number of quaternions and each quaternion is represented by a row of 4 elements.

    Returns:
        numpy.ndarray: A 1-D array of length 4 representing the average.

    Raises:
        ValueError: If Q contains no quaternions.
    """
    # Plain ndarray arithmetic replaces the deprecated numpy.matlib / np.matrix
    Q = numpy.asarray(Q, dtype=float)
    M = Q.shape[0]
    if M == 0:
        raise ValueError("averageQuaternions requires at least one quaternion")

    # Sum of the outer products q q^T over all rows equals Q^T Q; scale by 1/M
    A = (Q.T @ Q) / M

    # compute eigenvalues and -vectors
    eigenValues, eigenVectors = numpy.linalg.eig(A)

    # Column belonging to the largest eigenvalue
    largest = eigenValues.argsort()[::-1][0]

    # return the real part of the largest eigenvector (has only real part)
    return numpy.real(eigenVectors[:, largest])

def quatL1Simple(q1, q2):
    """
    L1 loss on the first four components, using the candidate closest to a fixed reference.

    Same selection rule as `l1Simple`, but only the first four entries of the
    last axis take part in the selection and in the loss; any further entries
    are ignored.

    Args:
        q1 (torch.Tensor): Tensor of shape (batch, >=4).
        q2 (torch.Tensor): Candidate tensor of shape (batch, candidates, >=4).

    Returns:
        torch.Tensor: Computed L1 loss (scalar).
    """
    # Build the reference on q2's device instead of forcing .cuda(), so the
    # loss also works on CPU and on GPUs other than the current default.
    ref = torch.tensor([1, 1e-8, 1e-16, 1e-24], device=q2.device)
    ref = ref / torch.norm(ref, dim=-1, keepdim=True)

    distance = torch.norm(ref - q2[:, :, :4], dim=-1)
    ind = torch.argmin(distance, dim=1)
    closest = q2[torch.arange(q1.shape[0], device=q2.device), ind, :]
    return F.l1_loss(closest[:, :4], q1[:, :4])

def quatTrueRotDistance(q1, q2):
    """
    Mean of the per-sample smallest angle, computed on the first four components only.

    Identical in form to `trueRotDistance`, except that only the first four
    entries of the last axis of q1 and q2 are used; further entries are ignored.

    Args:
        q1 (torch.Tensor): Tensor of shape (batch, >=4).
        q2 (torch.Tensor): Candidate tensor of shape (batch, candidates, >=4).

    Returns:
        torch.Tensor: Scalar mean of the per-sample minimum angle (scaled by 360 / 3.14159).
    """
    quatPart1 = q1[:, :4]
    unitQ1 = quatPart1 / torch.norm(quatPart1, dim=-1, keepdim=True)
    quatPart2 = q2[:, :, :4]
    unitQ2 = quatPart2 / torch.norm(quatPart2, dim=-1, keepdim=True)

    # Cosine per candidate, kept strictly inside acos' domain
    cosines = (unitQ1.unsqueeze(1) * unitQ2).sum(dim=-1)
    cosines = torch.clamp(cosines, -1.0 + 1e-7, 1.0 - 1e-7)

    # Geodesic distance to every candidate, then the best candidate per sample
    angles = torch.acos(cosines) * 360 / 3.14159
    bestAngles = torch.min(angles, dim = 1)[0]
    return torch.mean(bestAngles)

In [10]:
#transforms

class zoom:
    """
    Random zoom-in transform: enlarge the image and crop the centre back to the original size.

    Attributes:
        zoomFactor (float): Upper bound of the random enlargement factor (lower bound is 1).
        imgSize (int): Side length of the square input and output image.
    """
    def __init__(self, zoomFactor = 2.0, imgSize = 144):
        self.zoomFactor = zoomFactor
        self.imgSize = imgSize

    def __call__(self, img):
        """
        Apply the random zoom to one image.

        Args:
            img (torch.Tensor): Image whose first channel (`img[0]`) is zoomed.

        Returns:
            torch.Tensor: Zoomed image with a leading channel axis, or the input itself
            when the crop does not come out at `imgSize` columns (e.g. when the
            margin is 0, because a `0:-0` slice is empty).
        """
        target = self.imgSize

        # Enlargement factor drawn uniformly from [1, zoomFactor)
        scale = 1 + (self.zoomFactor - 1) * random.random()
        zoomSize = max(target, int(scale * target))

        enlarged = cv2.resize(img[0, :, :].numpy(), (zoomSize, zoomSize), interpolation = cv2.INTER_LANCZOS4)

        # Trim an equal margin from every side, then cut down to the target size
        margin = int((zoomSize - target) / 2)
        cropped = enlarged[margin: -margin, margin: -margin][:target, :target]

        if cropped.shape[1] != target:
            return img
        return torch.Tensor(cropped).unsqueeze(0)
        
def createCircularArray(rows, cols, center, radius):
    """
    Creates a 2D binary array of given dimensions where cells within a given radius
    from a specified center are set to 1, and others are set to 0.

    Parameters:
        rows (int): Number of rows for the 2D array.
        cols (int): Number of columns for the 2D array.
        center (tuple): (row, column) coordinates of the center.
        radius (int): Radius for the circle; cells at exactly this distance are included.

    Returns:
        numpy.ndarray: 2D binary integer array of shape (rows, cols).
    """
    # Open grids of row / column indices broadcast to (rows, cols); this replaces
    # the per-cell Python loop, which was slow when a mask is rebuilt per sample.
    rowIdx, colIdx = np.ogrid[:rows, :cols]

    #Distance from every cell to the center.
    distance = np.sqrt((rowIdx - center[0]) ** 2 + (colIdx - center[1]) ** 2)

    #Cells whose distance is less than or equal to the radius become 1.
    return (distance <= radius).astype(int)

class AddNoise:
    """
    Callable transform that, half of the time, adds Gaussian noise to an image.

    Attributes:
        noiseFactor (float): Upper bound of the randomly drawn noise amplitude.
    """
    def __init__(self, noiseFactor):
        #Upper bound for the per-call noise amplitude.
        self.noiseFactor = noiseFactor

    def __call__(self, img):
        """
        Return the image unchanged or with scaled standard-normal noise added.

        Args:
        - img (torch.Tensor): Input image.

        Returns:
        - torch.Tensor: The input itself (5 draws out of 10) or a noisy copy.
        """

        #Draws 6..10 leave the image untouched, i.e. noise is added half of the time
        if random.randint(1, 10) >= 6:
            return img

        #Amplitude is uniform in [0, noiseFactor]; noise has the image's shape.
        amplitude = random.uniform(0, self.noiseFactor)
        noise = torch.randn(*img.shape)
        return img + amplitude * noise

class CircularCrop:
    """
    Callable class that crops an image in a circular shape.

    Attributes:
        centre (tuple): Coordinates of the center for cropping (size // 2 on both axes).
        rand (bool): Flag to determine if random cropping is enabled.
        radius (int): Radius for the circular crop (the minimum radius in random mode).
        size (int): Size of the side for cropping.
        circularArray (numpy.ndarray): Binary mask with a leading axis of length 1 used for cropping.
    """
    def __init__(self, radius = 30, size = 60, rand = False):
        self.centre = (size // 2, size // 2)
        self.rand = rand
        self.radius = radius
        self.size = size

        #Generate a circular array of ones inside the specified radius.
        circularArray = createCircularArray(self.size, self.size, self.centre, self.radius)
        #Add a leading axis so the mask broadcasts against (channel, row, column) images.
        self.circularArray = np.expand_dims(circularArray, 0)

    def __call__(self, img):
        """
        Crop the given image in a circular shape.

        Args:
            img (torch.Tensor): Input image.

        Returns:
            torch.Tensor: Masked image, converted to half precision.
        """

        #If random cropping is enabled:
        if self.rand:
            #Draw a radius between the configured radius and half the image size.
            #Integer division is required: random.randint rejects float bounds
            #on current Python versions.
            randRadius = random.randint(self.radius, self.size // 2)

            #Recreate the circular array based on the new radius.
            self.circularArray = createCircularArray(self.size, self.size, self.centre, randRadius)
            self.circularArray = np.expand_dims(self.circularArray, 0)

        #Multiply the image with the circular array to get the cropped area.
        data = img * self.circularArray
        return data.half()
    
class Bright:
    """
    Adds a radial brightness gradient, centred at a random offset, to an image.

    The transform is skipped (image returned unchanged) on roughly half of the
    calls.
    
    Attributes:
        brightFactor (float): Bound of the random gain applied to the gradient;
            the gain is drawn from [-brightFactor, brightFactor].
        size (int): Side length of the generated gradient grid.
    """

    def __init__(self, brightFactor, size = 60):
        self.size = size
        self.brightFactor = brightFactor
        
    def __call__(self, img):
        #Draws of 6..10 leave the image untouched.
        if random.randint(1, 10) > 5:
            return img

        #Shift the grid origin by up to a quarter of the size along each axis.
        rowShift, colShift = np.random.randint(-self.size  // 4, self.size  // 4, size=2)
        rowStart = -self.size  // 2 + rowShift
        rowStop = self.size  // 2 + rowShift
        colStart = -self.size  // 2 + colShift
        colStop = self.size  // 2 + colShift
        rows, cols = np.mgrid[rowStart:rowStop, colStart:colStop]

        #Exponential fall-off with distance from the shifted origin.
        distance = np.sqrt(cols ** 2 + rows ** 2)
        falloff = -2e-3 * random.uniform(0.1, 4)
        gradient = np.exp(falloff * distance)

        #Scale the peak to 1, then apply a random signed gain.
        unitGradient = gradient / gradient.max()
        gain = random.uniform(-self.brightFactor, self.brightFactor)
        return img + unitGradient * gain

class Normalise: 
    """
    Normalises an image so its values range between 0 and 1.

    A constant image (max == min) would otherwise divide 0 by 0 and turn into
    NaN everywhere; in that case the result is all zeros instead.
    """
    def __call__(self, img):
        """
        Rescale the tensor linearly so min maps to 0 and max maps to 1.

        Args:
            img (torch.Tensor): Non-empty input tensor.

        Returns:
            torch.Tensor: (img - min) / (max - min); all zeros when max == min.
        """
        lowest = torch.min(img)
        span = torch.max(img) - lowest

        #Swap a zero span for 1 so constant images give 0 / 1 = 0, not 0 / 0.
        safeSpan = torch.where(span == 0, torch.ones_like(span), span)
        return (img - lowest) / safeSpan

class randomBlur:
    """
    Gaussian-blurs an image with a randomly chosen odd kernel size.

    The blur is skipped (image returned unchanged, without any dtype cast) on
    roughly half of the calls.
    
    Attributes:
        maxKernelSize (int): Upper bound for the randomly drawn kernel size.
    """
    def __init__(self, maxKernelSize = 5):
        self.maxKernelSize = maxKernelSize

    def __call__(self, img): 
        #Draws of 6..10 leave the image untouched.
        if random.randint(1, 10) > 5:
            return img

        floatImg = img.to(torch.float32)
        drawnSize = random.randint(1, self.maxKernelSize) 

        #Even draws are lowered by one so the kernel size is always odd.
        oddSize = drawnSize if drawnSize % 2 != 0 else drawnSize - 1

        return transforms.GaussianBlur(oddSize, sigma=(0.1, 2.0))(floatImg)
    
class BrightNoise:
    """
    Adds random Gaussian noise within a circular region of the image.

    The transform is skipped (image returned unchanged) on roughly half of the
    calls.
    
    Attributes:
        bright_factor (float): Intensity factor of brightness noise.
        max_radius (int): Maximum possible radius for the noise circle; also
            the margin kept between the circle centre and the image border
            when the image is large enough.
    """
    def __init__(self, bright_factor, max_radius = 30):
        self.bright_factor = bright_factor
        self.max_radius = max_radius
        
    def __call__(self, img):
        """
        Add scaled Gaussian noise inside a randomly placed circle.

        Args:
            img (torch.Tensor): Image whose axes 1 and 2 are read as height
                and width.

        Returns:
            torch.Tensor: ``img`` plus the masked noise, or ``img`` itself when
            the transform is skipped.
        """
        chance = random.randint(1, 10) 
        if chance >= 6:
            return img
        h, w = img.shape[1], img.shape[2]
        reach = self.max_radius

        #Keep the centre at least max_radius from each border. When the image
        #is smaller than 2 * max_radius that range is empty (randint would
        #raise), so fall back to the image centre.
        centre_x = random.randint(min(reach, w // 2), max(w - reach, w // 2))
        centre_y = random.randint(min(reach, h // 2), max(h - reach, h // 2))
        #Lower bound is normally 5, clamped so max_radius < 5 stays valid.
        circle_radius = random.randint(min(5, reach), reach)
        
        y, x = np.ogrid[:h, :w]
        
        #True INSIDE the circle: the noise is confined to the circular region.
        #(The previous `>=` comparison put the noise everywhere outside it.)
        inside = ((x - centre_x) ** 2 + (y - centre_y) ** 2) <= circle_radius ** 2

        #Convert explicitly to tensors instead of relying on implicit
        #tensor-by-ndarray multiplication.
        noise = torch.from_numpy(np.random.normal(size=(1, h, w))).float()
        noise = noise * torch.from_numpy(inside) * self.bright_factor
        return noise + img
    
class addTo:
    """
    Shifts every pixel value of an image by one random constant.
    
    Attributes:
        max_add (float): Bound of the shift; the constant is drawn uniformly
            from [-max_add, max_add].
    """
    def __init__(self, max_add = 0.25):
        self.max_add = max_add
    
    def __call__(self, img):
        #One draw per call, applied to the whole image.
        bound = self.max_add
        offset = random.uniform(-bound, bound)
        return img + offset

class scaleBy: 
    """
    Multiplies the pixel values of an image by one random factor.
    
    Attributes:
        maxScale (float): Controls the factor range; the factor is drawn
            uniformly from [1 - maxScale, 1 / (1 - maxScale)].
    """
    def __init__(self, maxScale = 0.5):
        self.maxScale = maxScale
    
    def __call__(self, img):
        #The upper bound is the reciprocal of the lower bound.
        lowerBound = 1 - self.maxScale
        upperBound = 1 / (1 - self.maxScale)
        return img * random.uniform(lowerBound, upperBound)
    
class mixedTransform:
    """
    Runs a list of transforms on the PIL form of an image: the input is
    converted to a PIL image, each transform is applied in order, and the
    result is converted back to a tensor.
    
    Attributes:
        transformList (list): Transforms applied in sequence to the PIL image.
    """
    def __init__(self, transformList):
        self.transformList = transformList
    
    def __call__(self, img):
        pilImage = to_pil_image(img)

        #Each transform receives the output of the previous one.
        for step in self.transformList:
            pilImage = step(pilImage)

        toTensor = transforms.ToTensor()
        return toTensor(pilImage)

class setNaNToZero:
    """
    Transform that sets NaN values in a tensor to a specific replacement value.
    
    Attributes:
        replacementValue (float): Value to replace NaN values with. Defaults to 0.0.
    """
    def __init__(self, replacementValue = 0.0):
        """Initialize setNaNToZero class with the given replacement value."""
        self.replacementValue = replacementValue

    def __call__(self, img):
        """
        Return a float32 copy of the image tensor with NaN values replaced.

        The input tensor is left untouched; the previous implementation wrote
        the replacement into the caller's tensor in place.
        
        Args:
            img (torch.Tensor): Input image tensor.
        
        Returns:
            torch.Tensor: New float32 tensor with every NaN replaced.
        """
        # Locate the NaN entries of the input.
        nanMask = torch.isnan(img)
        
        # masked_fill (no trailing underscore) is out-of-place, so the caller's
        # tensor is never modified even when it is already float32.
        return img.float().masked_fill(nanMask, self.replacementValue)
    

class StretchAndCropTransform:
    """
    Transform that distorts a tensor's aspect ratio by zero-padding one of its
    two trailing axes by a random amount and then resizing to a square of
    standard size.

    The transform is skipped on roughly half of the calls; in that case the
    input is returned unchanged (and NOT resized).
    
    Attributes:
        xStretch (int): Upper bound of the random factor applied to shape[1].
        yStretch (int): Upper bound of the random factor applied to shape[2].
        standard_size (int): Side length of the resized output. Defaults to 144.
    """
    def __init__(self, xStretch = 2, yStretch = 2, standard_size = 144):
        """Initialize StretchAndCropTransform class with given stretch factors and standard size."""
        self.xStretch = xStretch
        self.yStretch = yStretch
        self.standard_size = standard_size
        
    def __call__(self, tensor):
        """
        Apply the random padding and resize to the tensor.
        
        Args:
            tensor (torch.Tensor): 3-D input tensor (assumed channel-first,
                i.e. (C, H, W) -- TODO confirm against callers).
        
        Returns:
            torch.Tensor: Tensor of shape (C, standard_size, standard_size), or
            the unmodified input when the transform is skipped.
        """
        chance = random.randint(1, 10) 
        if chance >= 6:
            return tensor

        #Neutral names: the x factor scales shape[1] and the y factor scales
        #shape[2], exactly as before. NOTE(review): for a (C, H, W) tensor
        #shape[1] is the height, so the earlier width/height labels were swapped.
        firstExtent, secondExtent = tensor.shape[1], tensor.shape[2]
        
        xStretch = random.uniform(1,  self.xStretch) * firstExtent
        yStretch = random.uniform(1,  self.yStretch) * secondExtent
        
        #Only the axis with the smaller stretched extent is padded, by half of
        #the difference on each side.
        padHeight = int(max((yStretch - xStretch), 0) / 2)
        padWidth = int(max((xStretch - yStretch), 0) / 2)
        
        #pad order is (left, right, top, bottom) over the two trailing axes;
        #the leading batch axis is added because interpolate expects 4-D input.
        padded = torch.nn.functional.pad(tensor.unsqueeze(0), (padWidth, padWidth, padHeight, padHeight), value=0)

        #Use torch.nn.functional explicitly (as for pad above) rather than
        #depending on a star-imported `F` alias.
        resized = torch.nn.functional.interpolate(padded, self.standard_size, mode='bilinear', align_corners=True)
        resized = resized.squeeze(0)

        return resized


class CircularCrop:
    """
    Transform that masks an image tensor with a fixed centred circle.

    NOTE(review): this definition reuses the name of an earlier class in this
    file and therefore replaces it once the file has been executed.
    
    Attributes:
        centre (tuple): Centre of the circle, (size // 2, size // 2).
        radius (int): Radius of the circle.
        size (int): Side length passed to createCircularArray.
        circular_array: Mask returned by createCircularArray, built once.
    """
    def __init__(self, radius = 30, size = 60):
        """Store the geometry and build the mask for it."""
        self.size = size
        self.radius = radius
        self.centre = (size // 2, size // 2) 
        
        #The mask is fixed for the lifetime of the instance.
        self.circular_array = createCircularArray(self.size, self.size, self.centre, self.radius)
        
    def __call__(self, img):
        """
        Multiply the tensor by the circular mask.
        
        Args:
            img (torch.Tensor): Input tensor; must broadcast against the mask.
        
        Returns:
            torch.Tensor: Masked tensor converted to float32.
        """
        return (img * self.circular_array).float()

class RandCircularCrop:
    """
    Transform that masks an image tensor with a centred circle of random radius.

    One mask is precomputed for every integer radius from ``radius`` up to and
    including ``size // 2``; each call picks one of them uniformly at random.
    
    Attributes:
        centre (tuple): Centre of the circle, (size // 2, size // 2).
        radius (int): Smallest radius that can be drawn.
        size (int): Side length passed to createCircularArray.
        circular_array_list (list): Precomputed masks, one per radius.
    """
    
    def __init__(self, radius = 30, size = 60):
        """
        Precompute the masks for all admissible radii.

        Raises:
            ValueError: If ``radius`` exceeds ``size // 2`` (no mask possible).
        """
        self.centre = (size // 2, size // 2) 
        self.radius = radius
        self.size = size

        #Fail early instead of on the first call.
        if radius > size // 2:
            raise ValueError("radius must not exceed size // 2")

        #Upper bound is inclusive: an exclusive range was empty for the
        #defaults (radius=30, size=60), which made every call raise.
        self.circular_array_list = [
            createCircularArray(self.size, self.size, self.centre, candidate)
            for candidate in range(radius, size // 2 + 1)
        ]
            
    def __call__(self, img):
        """
        Apply a circular mask of random radius.
        
        Args:
            img (torch.Tensor): Input tensor; must broadcast against the mask.
        
        Returns:
            data (torch.Tensor): Masked tensor converted to float32.
        """
        #Choose among the masks that actually exist rather than recomputing an
        #index range from size and radius.
        mask = random.choice(self.circular_array_list)

        data = img * mask
        return data.float()

def randomTransform(transform, p=0.5):
    """
    Wrap a single transform so that it is only applied with probability p.
    
    Args:
        transform: The transformation to wrap.
        p (float, optional): Probability of applying the transform. Default is 0.5.
    
    Returns:
        A transforms.RandomApply object holding the one given transform.
    """
    #RandomApply takes a list, so the single transform goes into a one-item list.
    wrapped = [transform]
    return transforms.RandomApply(wrapped, p=p)

def noiseReductionMedian(image, filterSize):
    """
    Reduce noise in an image by running scipy's median filter over it.
    
    Parameters:
        image (torch tensor): Input image that noise reduction will be applied to.
        filterSize (int): Passed to median_filter as its ``size`` argument.
    
    Returns:
        denoisedImage: The result returned by median_filter.
    """
    return median_filter(image, size = filterSize)
