In [8]:
import os
import re
import imageio as iio
from pathlib import Path
import numpy as np
from glob import glob
import matplotlib.cm as cm
import matplotlib.pyplot as plt

#AIRXD model import

from airxd_cnn.cnn import ARIXD_CNN as cmodel
from airxd.model import ARIXD
from airxd.dataset import Dataset, parse_imctrl
from torchvision.transforms import v2
import torch

In [40]:
import re
import airxd_cnn
from airxd_cnn.transforms import powder_normalize
from airxd_cnn.powder_dataset_v2 import powder_dset
from sklearn.metrics import confusion_matrix as CM
import shutil
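#NOTE (intent inferred): when globbing data/*, filter out any directories whose names contain 'normalized' (e.g. data/normalized_train) so they are not treated as classes.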

# Helper functions

In [63]:
def valid_pred(model, train_classes, valid_path,
               results, repeat_idx, mask_dirname="masks"):
    '''
    Evaluate one trained model on the held-out images of every class.

    For each class in ``train_classes``, every validation image that lives in
    that class's directory is predicted with ``model.predict_old`` and compared
    against its ground-truth mask. The flattened 2x2 confusion matrix of each
    image is accumulated into ``results[class_idx, repeat_idx, :]``.

    Args:
        model: trained model exposing ``predict_old(image_path)``.
        train_classes: class names, i.e. names of the directories holding the
            images (e.g. 'Nickel').
        valid_path: validation image paths of the form '<root>/<class>/<image>.tif'.
        results: accumulator array indexed as results[class_idx, repeat_idx, :],
            last axis of length 4 (assumed shape (n_classes, n_repeats, 4)).
        repeat_idx: index of the current repeat along axis 1 of ``results``.
        mask_dirname: sub-directory of each class directory that holds the
            '<image>_mask.tif' files. Defaults to 'masks', the same folder the
            training masks are globbed from.

    Returns:
        None. ``results`` is updated in place.
    '''
    for class_idx, class_name in enumerate(train_classes):
        # The class of an image is the name of the directory that contains it;
        # this works for relative and absolute paths alike.
        class_valid_paths = [f for f in valid_path
                             if os.path.basename(os.path.dirname(f)) == class_name]

        for image_file in class_valid_paths:
            label_file = os.path.join(os.path.dirname(image_file), mask_dirname,
                                      os.path.basename(image_file).replace(".tif", "_mask.tif"))
            labels = np.array(iio.v2.volread(label_file))

            # predict_old takes the raw image path (noted as normalizing internally)
            pred = model.predict_old(image_file)

            # Accumulate the flattened confusion matrix for this class/repeat
            results[class_idx, repeat_idx, :] += calculate_CM(pred, labels)

def calculate_CM(pred, labels, classes=(0, 1)):
    '''
    Flattened confusion matrix of a prediction against its ground-truth mask.

    The class list is passed to sklearn explicitly. Without it, an image whose
    mask and prediction contain only one class yields a 1x1 matrix, which then
    broadcasts into every slot of a 4-entry accumulator and corrupts the counts.

    Args:
        pred: predicted label array (any shape; flattened).
        labels: ground-truth label array with the same number of elements.
        classes: class values indexing the matrix rows/columns, in order.
            Defaults to the binary case (0, 1).

    Returns:
        1-D array of length len(classes)**2. For the default binary case the
        order is [tn, fp, fn, tp].
    '''
    matrix = CM(labels.ravel(), pred.ravel(), labels=list(classes))
    return matrix.ravel()

In [57]:
def create_train_and_val_subsets(directory_train, class_paths, subset_size, rng=None):
    '''
    Randomly split every class's images into a training subset and a validation set.

    For each class directory, ``subset_size`` raw images ('<class>/<stem>.tif')
    are drawn without replacement for training; the remaining images become
    validation images. Training uses the pre-normalized copy of each image
    ('<directory_train>/<stem>_norm.tif'; normalization is done beforehand) and
    its mask ('<class>/masks/<stem>_mask.tif'). Validation keeps the raw path.

    Args:
        directory_train: directory holding the pre-normalized training images.
        class_paths: class directories, each containing raw '*.tif' images and
            a 'masks' sub-directory.
        subset_size: number of training images drawn per class.
        rng: optional ``np.random.Generator`` for reproducible draws. When None,
            the global ``np.random`` state is used.

    Returns:
        (train_image_paths, train_mask_paths, val_paths), where
        train_image_paths[k] and train_mask_paths[k] refer to the same sample.

    Raises:
        FileNotFoundError: if an image in a class has no matching mask.
        ValueError: if a class has fewer than ``subset_size`` images.
    '''
    sampler = np.random if rng is None else rng

    # Normalized image data for training (normalization was done beforehand)
    train_image_paths = []
    # Mask data for training, stored next to the raw images in '<class>/masks'
    train_mask_paths = []
    # Raw image data for validation
    val_paths = []

    for path_to_class in class_paths:
        # One canonical ordering of the raw samples. The normalized-image and
        # mask lists are derived from it, so index i refers to the same sample
        # in all three. Sorting each list with its own key can misalign them:
        # 'a-b.tif' sorts before 'a.tif', but 'a_mask' sorts before 'a-b_mask'.
        sample_paths = sorted(glob(path_to_class + '/' + '*.tif'), key=os.path.basename)
        stems = [os.path.splitext(os.path.basename(p))[0] for p in sample_paths]
        sample_norm_paths = [os.path.join(directory_train,
                                          os.path.basename(p).replace(".tif", "_norm.tif"))
                             for p in sample_paths]

        # Masks are named '<stem>_mask.tif'; match them by stem, not by position
        masks_by_stem = {os.path.basename(m).split('_mask')[0]: m
                         for m in glob(path_to_class + '/masks/*.tif')}
        missing = [s for s in stems if s not in masks_by_stem]
        if missing:
            raise FileNotFoundError(
                f"No '<stem>_mask.tif' in {path_to_class}/masks for: {missing}")
        mask_paths = [masks_by_stem[s] for s in stems]

        n_images = len(sample_paths)
        if subset_size > n_images:
            raise ValueError(
                f"subset_size={subset_size} exceeds the {n_images} images in {path_to_class}")

        # Training indices keep their random draw order; validation indices ascend
        train_subset_indices = [int(i) for i in sampler.choice(n_images, subset_size, replace=False)]
        chosen = set(train_subset_indices)
        val_subset_indices = [i for i in range(n_images) if i not in chosen]

        train_image_paths.extend(sample_norm_paths[i] for i in train_subset_indices)
        train_mask_paths.extend(mask_paths[i] for i in train_subset_indices)
        val_paths.extend(sample_paths[i] for i in val_subset_indices)

    return train_image_paths, train_mask_paths, val_paths


In [58]:

#Preview a single random train/validation split. No files are created or normalized here,
#and the sweep below draws a fresh split for every repeat.
#NOTE: `normalized_train_dir` and `class_paths` are defined in the configuration cell
#further down; on a fresh kernel run that cell before this one.
train_image_paths, train_mask_paths, val_paths = create_train_and_val_subsets(
    normalized_train_dir, class_paths, subset_size=3)


In [34]:
from airxd_cnn.transforms import RandomRotation, RandomFlip
from torch.utils.data import Subset, DataLoader
from airxd_cnn.model import ARIXD_CNN

In [67]:
#Training configuration: data locations, augmentation, dataset and training options.
#Every name read here is defined in this cell or in the import cells above, so the
#cell also works on a fresh kernel.
from airxd_cnn.powder_dataset_v2 import powder_dset

#List of acceptable classes for training (sub-directories of data/)
train_classes = ['Nickel', 'battery1', 'battery2', 'battery3', 'battery4']

#Class directories under data/ whose name is a training class. Sorted because glob
#order is arbitrary and the per-class random draws follow this order.
class_paths = sorted(f for f in glob('data/*') if os.path.basename(f) in train_classes)
mask_paths = sorted(f for f in glob('data/*/masks/*.tif')
                    if os.path.basename(os.path.dirname(os.path.dirname(f))) in train_classes)

#Directory holding the pre-normalized training images ('<image>_norm.tif').
#Normalization happens beforehand; this cell only records the location.
normalized_train_dir = 'data/normalized_train'

#Shared settings. batch_size is defined here because training_params reads it;
#the hyperparameter cell below sets the same value.
DEVICE = 'cuda:0'
batch_size = 64

#Augmentation transforms, passed to powder_dset as `transforms`
transform_pipeline = v2.Compose([
    RandomRotation(),
    RandomFlip()
])

#Dataset (powder_dset) parameters
other_params = {"transforms": transform_pipeline,
                "input_map_path": 'data/input_mmap',
                "target_map_path": 'data/target_mmap',
                "device": DEVICE,
                "minority_threshold": 30,
                "create_memmap": True}

#Model training parameters ('epoch' is a default; the sweep sets it per run)
training_params = {'device': DEVICE,
                   'amp': False,
                   'clip_value': None,
                   'epoch': 20,
                   'batch_size': batch_size,
                   'shuffle': True,
                   'drop_last': True,
                   'lr_rate': 1e-2,
                   'weights': [1.0, 10.0],  # presumably per-class loss weights (background, foreground) — confirm
                   'save_path': 'models_scratch'}


In [68]:
#Hyperparameter grid for the paper-replication sweep
n_repeats = 5            # random train/validation splits per (N, epoch) setting
batch_size = 64          # DataLoader batch size
N_epochs = [10, 20, 30]  # training epoch counts to compare
N_sizes = [128, 256]     # quilter window / model input size (N x N)
results_master = {}      # (N, epoch) -> stacked per-class rates, filled by the sweep

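Train and evaluate a model for every (N, epoch) setting, repeating each setting with a new random train/validation split; results are kept in memory in `results_master` (not written to disk).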

In [None]:
#Paper-replication sweep: for each patch size N and epoch count, train n_repeats models
#on random train/validation splits and record, per class and repeat, the true-negative
#and true-positive rates on the held-out images.
SEED = 0  # repeat i is seeded with SEED + i, so every (N, epoch) setting sees the same splits

#Quilter geometry shared by all runs
side_crop = 20
image_d = 2880


def _image_stem(path):
    # Sample stem of a normalized training image '<stem>_norm.tif'
    return path.split('/')[-1].split('_norm')[0]


def _mask_stem(path):
    # Sample stem of a mask '<stem>_mask.tif'
    return path.split('/')[-1].split('_mask')[0]


for N in N_sizes:
    #Quilter parameters depend only on N: window N, step N//2, border N//8
    M = N // 2
    B = M // 4
    quilter_params = {'Y': image_d - 2*side_crop,
                      'X': image_d - 2*side_crop,
                      'window': (N, N),
                      'step': (M, M),
                      'border': (B, B),
                      'border_weight': 0,
                      'crop': side_crop}

    #TUNet params (from dlsia)
    model_params = {'image_shape': (N, N),
                    'in_channels': 1,
                    'out_channels': 2,
                    'base_channels': 8,
                    'growth_rate': 2,
                    'depth': 4}

    for epoch in N_epochs:
        #Per-run copy so the shared config dict is not mutated; batch_size is kept
        #identical to the DataLoader batch size used below
        run_training_params = {**training_params, 'epoch': epoch, 'batch_size': batch_size}

        #results[class_idx, repeat, :] accumulates flattened 2x2 confusion matrices,
        #assumed ordered [tn, fp, fn, tp] (sklearn's binary ravel order)
        results = np.zeros((len(train_classes), n_repeats, 4))

        for i in range(n_repeats):
            np.random.seed(SEED + i)
            torch.manual_seed(SEED + i)

            #Dataset randomization for training
            train_image_paths, train_mask_paths, val_paths = create_train_and_val_subsets(
                normalized_train_dir, class_paths, subset_size=3)

            #Sort both lists by sample stem and verify the 1:1 image/mask pairing
            train_image_paths.sort(key=_image_stem)
            train_mask_paths.sort(key=_mask_stem)
            assert ([_image_stem(p) for p in train_image_paths]
                    == [_mask_stem(p) for p in train_mask_paths]), \
                "training images and masks are not paired 1:1"

            train_dataset = powder_dset(train_image_paths, train_mask_paths, quilter_params, **other_params)
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

            cnn_model = ARIXD_CNN(model_params, run_training_params, quilter_params)

            #The training loader doubles as the validation loader on purpose:
            #evaluation is done afterwards by valid_pred on the held-out images
            cnn_model.train(train_loader, train_loader)

            valid_pred(cnn_model, train_classes, val_paths, results, i)

        #Rates per (class, repeat). The +1 avoids division by zero when a class has
        #no pixels of one kind, at the cost of a small downward bias.
        tn = results[:, :, 0] / (results[:, :, 0] + results[:, :, 1] + 1)
        tp = results[:, :, 3] / (results[:, :, 3] + results[:, :, 2] + 1)

        results_master[(N, epoch)] = np.stack((tn, tp), axis=2)


        
        
