In [1]:
import os
import torch
import random
import numpy as np
import matplotlib.pyplot as plt

# Identify Largest Class

In [2]:
def additional_samples_to_be_created(base_dir, classes=("DF", "PF", "Prick"), target_ratio=1.5):
    '''
    Counts the new samples to generate for each class.

    Every class is augmented up to the same target size: target_ratio times
    the original size of the largest class, rounded down.

    Parameters
    base_dir     - directory holding one sub-folder per class
                   (a trailing path separator is optional)
    classes      - names of the class sub-folders to count
    target_ratio - target class size as a multiple of the largest class

    Returns
    dict mapping class name -> number of additional samples to create

    Example
    Test/DF - 10
    Test/PF - 27
    Test/Prick - 5

    sizes = additional_samples_to_be_created("Test")

    Output (target size = int(27 * 1.5) = 40)
    {'DF': 30, 'PF': 13, 'Prick': 35}
    '''

    # Count the files currently present in each class folder.
    # os.path.join accepts base_dir with or without a trailing separator.
    sizes = {name: len(os.listdir(os.path.join(base_dir, name))) for name in classes}

    # All classes are topped up to the same size, derived from the largest class
    largest_class_size = max(sizes.values(), default=0)
    target_size = int(largest_class_size * target_ratio)

    return {name: target_size - count for name, count in sizes.items()}

# Augmentation

 - Adding Gaps
 - Mixing and Matching Signals (Permutation)
 - Shifting Signals (Time Shift)

In [3]:
def add_gaps(input_tensor):
    '''
    Adds 1 or 2 zeroed gaps to a copy of a signal to create a new augmented signal.

    One gap:  a single gap of 50-150 samples.
    Two gaps: two non-overlapping gaps totalling 150 samples
              (10% of a 1500 sample signal).

    Parameters
    input_tensor - signal whose first dimension is the sample (time) axis.
                   It is not modified. Written for 1500 sample signals, but any
                   signal of at least 201 samples and any channel count works.

    Returns
    A detached clone of input_tensor with the gap samples set to zero.

    Raises
    ValueError if the signal is too short to hold the gaps.
    '''

    # Shortest signal for which both branches below have a valid random range:
    # the two-gap branch needs 100 <= (length - 1) - 100
    signal_length = input_tensor.shape[0]
    if signal_length < 201:
        raise ValueError("add_gaps needs a signal of at least 201 samples, got " + str(signal_length))
    last_index = signal_length - 1

    # Work on a copy so the caller's tensor is left untouched
    new_file_1 = input_tensor.detach().clone()

    # Randomly decide whether to add one or two gaps to a signal
    number_of_gaps = random.randint(1, 2)

    if number_of_gaps == 1:

        # Randomly determine where to put the gap and how large the gap is (between 50-150 samples)
        gap_length = random.randint(50, 150)
        gap_start = random.randint(0, last_index - gap_length)

        # Scalar assignment works for any channel count, dtype and device
        new_file_1[gap_start:gap_start + gap_length] = 0
    else:

        # Randomly decide the length of both gaps such that they total 150 samples
        gap_length_1 = random.randint(50, 100)
        gap_length_2 = 150 - gap_length_1

        # Gap 1 starts at sample 100 or later so gap 2 (at most 100 samples)
        # always fits entirely before it, i.e. the gaps never overlap
        gap_start_1 = random.randint(100, last_index - gap_length_1)
        gap_start_2 = random.randint(0, gap_start_1 - gap_length_2)

        new_file_1[gap_start_1:gap_start_1 + gap_length_1] = 0
        new_file_1[gap_start_2:gap_start_2 + gap_length_2] = 0

    return new_file_1

In [4]:
def shift_signal(signal):
    '''
    Shifts the signal back or forth by 50-200 time steps, filling in the
    unknown samples with a zero gap so the output keeps the input's length.

    Parameters
    signal - tensor whose first dimension is the sample (time) axis.
             It is not modified. Written for 1500 sample signals, but any
             signal of at least 200 samples and any channel count works.

    Returns
    A new tensor with the same shape, dtype and device as signal.

    Raises
    ValueError if the signal is shorter than the largest possible shift.
    '''

    signal_length = signal.shape[0]
    if signal_length < 200:
        raise ValueError("shift_signal needs a signal of at least 200 samples, got " + str(signal_length))

    # Decide amount of shift and direction
    shift = random.randint(50, 200)
    front_or_back = [-1, 1][random.randint(0, 1)]

    source = signal.detach()

    # new_zeros matches the signal's trailing dimensions, dtype and device
    gap = source.new_zeros((shift,) + tuple(source.shape[1:]))

    # Create Augmented Shifted Signal (torch.cat always allocates a new tensor)
    if front_or_back < 0:
        # Shift backward: drop the first samples and pad the end
        augmented_signal = torch.cat((source[shift:], gap))
    else:
        # Shift forward: pad the start and drop the last samples
        augmented_signal = torch.cat((gap, source[:signal_length - shift]))

    return augmented_signal

In [5]:
def _pick_mix_pair(files, used_pairs, max_attempts=100):
    '''Pick two different file names whose ordered pairing has not been used yet.'''
    pair = (random.choice(files), random.choice(files))
    attempts = 1
    # Bounded retry: with too few files (or all pairings used up) a valid pair
    # may not exist, so fall back to the last draw instead of looping forever
    while (pair[0] == pair[1] or pair in used_pairs) and attempts < max_attempts:
        pair = (random.choice(files), random.choice(files))
        attempts += 1
    used_pairs.add(pair)
    return pair


def Generate_Augmented_Dataset(ratnum, output_dir):
    '''
    Generates the augmented samples for one rat and saves them as .pt files.

    For each class (DF, PF, Prick) the number of new samples is taken from
    additional_samples_to_be_created. The augmentation method cycles with the
    running sample number:
        number % 3 == 0 - mix and match two signals of the same class and rat
        number % 3 == 1 - add gaps
        number % 3 == 2 - shift the signal

    Parameters
    ratnum     - rat identifier as a string (e.g. "2"); source files are read
                 from "Rat <ratnum>/<class>/"
    output_dir - root folder of the augmented dataset; files are saved as
                 "<output_dir>/Rat <ratnum>/<class>/Rat<ratnum>_<class>_<number>.pt"
    '''

    # Base Directory for files. The trailing separator is kept on purpose so
    # the path also works when a class name is appended by plain concatenation.
    base_dir = os.path.join("Rat " + ratnum, "")

    # Determine the Augmented class size (same for every class, so computed once)
    sizes = additional_samples_to_be_created(base_dir)

    for i in ["DF", "PF", "Prick"]:

        class_dir = os.path.join(base_dir, i)
        save_dir = os.path.join(output_dir, "Rat " + ratnum, i)

        # List of Files before augmentation (sorted so a seeded run is repeatable)
        files = sorted(os.listdir(class_dir))

        # Tracking: new samples are numbered after the existing ones
        new_no = len(files)

        # Tracking to prevent repeated mix and matches
        combos = set()

        # Ensure directory exists to save files (only if something will be saved)
        if sizes[i] > 0:
            os.makedirs(save_dir, exist_ok=True)

        for _ in range(sizes[i]):

            new_no += 1
            method = new_no % 3

            # NOTE: torch.load unpickles the file - only load trusted data.
            if method == 0: # Mix and Match Method

                name_1, name_2 = _pick_mix_pair(files, combos)

                file_1 = torch.load(os.path.join(class_dir, name_1))
                file_2 = torch.load(os.path.join(class_dir, name_2))

                # Mix the first and last third of one signal with the second third of another
                # Signals are the same type and same rat
                new_file = torch.cat((file_2[:500], file_1[500:1000], file_2[1000:]), 0)

            elif method == 1: # Adding Gaps Method

                file_1 = torch.load(os.path.join(class_dir, random.choice(files)))

                new_file = add_gaps(file_1)

            else: # Shifting Signal Method

                file_1 = torch.load(os.path.join(class_dir, random.choice(files)))

                new_file = shift_signal(file_1)

            # Output File name
            file_name = "Rat" + ratnum + "_" + str(i) + "_" + str(new_no) + ".pt"

            # Save New Augmented File
            torch.save(new_file, os.path.join(save_dir, file_name))

In [7]:
# Build the augmented dataset for rats 2 through 10
for rat_number in range(2, 11):
    Generate_Augmented_Dataset(str(rat_number), "Augmented_Dataset")

# Generate_Augmented_Dataset("Test", "Augmented_Dataset")