<a href="https://colab.research.google.com/github/KrajShuffle/ML_Audio_Models/blob/main/Nov_MaleWavFileClassification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Setup Modules & Datasets

In [1]:
!pip install praatio

Collecting praatio
  Downloading praatio-6.0.1-py3-none-any.whl (79 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.2/79.2 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: praatio
Successfully installed praatio-6.0.1


In [2]:
%cd /content/drive/MyDrive/Audio_SR_22050/
%pwd

/content/drive/MyDrive/Audio_SR_22050


'/content/drive/MyDrive/Audio_SR_22050'

In [3]:
import pandas as pd
import numpy as np
import librosa
import torch

### Converting 10/31 Male Model to Torchscript

In [4]:
from torch import nn
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = nn.ModuleList([nn.Conv2d(in_channels= 1, out_channels= 16,kernel_size = 5, stride = 1, padding= 0),
                                   nn.Conv2d(in_channels= 16, out_channels= 32,kernel_size = 3, stride = 1, padding= 0),
                                   nn.Conv2d(in_channels=32, out_channels=48, kernel_size = 3, stride = 1, padding = 0)])
        self.bn = nn.ModuleList([nn.BatchNorm2d(16), nn.BatchNorm2d(32), nn.BatchNorm2d(48)])
        self.fc = nn.ModuleList([nn.Linear(528, 200), nn.Linear(200, 1)])
        self.pooling = nn.AvgPool2d(kernel_size = 3)
        self.activation = nn.ReLU()
        self.conv_dp = nn.Dropout2d(p = 0.2)
        self.fc_dp = nn.Dropout(p= 0.15)

    def forward(self, x):
        x = self.pooling(self.conv_dp(self.activation(self.bn[0](self.conv[0](x)))))
        x = self.pooling(self.conv_dp(self.activation(self.bn[1](self.conv[1](x)))))
        x = self.pooling(self.conv_dp(self.activation(self.bn[2](self.conv[2](x)))))
        x = x.view(x.size(0), -1) # Equivalent to x = nn.Flatten() (x) Flattening output neurons of final conv layer to be input to fc
        x = self.fc[0](x)
        x = self.fc_dp(self.activation(x))
        x = self.fc[1](x)
        return x

In [5]:
from torchsummary import summary
# Specify the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
updated_cnn = CNN().to(device)

# Input is the (number of channels, image height, image width) if input is 2d, essentially the 2 dimensions of it
s = summary(updated_cnn, (1,64,345))

cpu
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 16, 60, 341]             416
       BatchNorm2d-2          [-1, 16, 60, 341]              32
              ReLU-3          [-1, 16, 60, 341]               0
         Dropout2d-4          [-1, 16, 60, 341]               0
         AvgPool2d-5          [-1, 16, 20, 113]               0
            Conv2d-6          [-1, 32, 18, 111]           4,640
       BatchNorm2d-7          [-1, 32, 18, 111]              64
              ReLU-8          [-1, 32, 18, 111]               0
         Dropout2d-9          [-1, 32, 18, 111]               0
        AvgPool2d-10            [-1, 32, 6, 37]               0
           Conv2d-11            [-1, 48, 4, 35]          13,872
      BatchNorm2d-12            [-1, 48, 4, 35]              96
             ReLU-13            [-1, 48, 4, 35]               0
        Dropout2d-14            [-1

In [6]:
updated_cnn.load_state_dict(torch.load("Male_1sec_2FC_3conv16_32_48_psz3_64mels_64hoplen_1024nfft_5000fmax.pt", map_location= device))
model_scripted = torch.jit.script(updated_cnn)
model_scripted.save('11_1_Male_3conv2fc_CNN_model.pt')

### Importing in 1 Second Male Dataset


In [7]:
df_1sec = pd.read_csv("all_spectrify_SR_22050_slen_0_1_clen_1.csv", index_col = 0)
df_1sec_male = df_1sec[df_1sec['sex'] == "M"].copy()
df_1sec_male['class'] = df_1sec_male['class'].apply(lambda x: 1 if x == 0 else 0);
df_1sec_male

Unnamed: 0,correct_filename,ds_type,begin_time,end_time,class,sex,session
0,Train/0171017001_h_00.TextGrid,Train,1.453696,2.463696,1,M,ses1017
1,Train/0171017001_h_00.TextGrid,Train,2.463696,3.563696,1,M,ses1017
2,Train/0171017001_h_00.TextGrid,Train,3.563696,4.763696,1,M,ses1017
3,Train/0171017002_h_00.TextGrid,Train,0.776871,1.796871,1,M,ses1017
4,Train/0171017002_h_00.TextGrid,Train,1.796871,2.936871,1,M,ses1017
...,...,...,...,...,...,...,...
59019,Test/5824078030_h_00.TextGrid,Test,16.660000,17.800000,0,M,ses4078
59020,Test/5824078030_h_00.TextGrid,Test,18.430000,19.460000,0,M,ses4078
59021,Test/5824078030_h_00.TextGrid,Test,20.820000,21.840000,0,M,ses4078
59022,Test/5824078030_h_00.TextGrid,Test,21.840000,22.880000,0,M,ses4078


In [8]:
dict_labels = {"S" : 0, "I" : 1}

In [9]:
def create_df_ds(df_all_maps, ds_type):
   """ Options include "Train", "D1", "Test" & Code for Labels -> {"S" : 0, "I" : 1}
   """
   return df_all_maps[df_all_maps['ds_type'] == ds_type].set_index('correct_filename')

In [10]:
df_1sec_train = create_df_ds(df_1sec_male, "Train")
print(df_1sec_train.shape)
df_1sec_train.head()

(12128, 6)


Unnamed: 0_level_0,ds_type,begin_time,end_time,class,sex,session
correct_filename,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Train/0171017001_h_00.TextGrid,Train,1.453696,2.463696,1,M,ses1017
Train/0171017001_h_00.TextGrid,Train,2.463696,3.563696,1,M,ses1017
Train/0171017001_h_00.TextGrid,Train,3.563696,4.763696,1,M,ses1017
Train/0171017002_h_00.TextGrid,Train,0.776871,1.796871,1,M,ses1017
Train/0171017002_h_00.TextGrid,Train,1.796871,2.936871,1,M,ses1017


In [11]:
df_1sec_train['class'].value_counts()

0    9426
1    2702
Name: class, dtype: int64

In [12]:
df_1sec_val = create_df_ds(df_1sec_male, "D1")
print(df_1sec_val.shape)
df_1sec_val.head()

(8332, 6)


Unnamed: 0_level_0,ds_type,begin_time,end_time,class,sex,session
correct_filename,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Val/0261027001_h_01.TextGrid,D1,1.428073,2.508073,1,M,ses1027
Val/0261027001_h_01.TextGrid,D1,2.508073,3.648073,1,M,ses1027
Val/0261027001_h_01.TextGrid,D1,3.648073,4.698073,1,M,ses1027
Val/0261027016_h_00.TextGrid,D1,1.706508,2.776508,1,M,ses1027
Val/0261027016_h_00.TextGrid,D1,2.776508,3.896508,1,M,ses1027


In [13]:
df_1sec_val[df_1sec_val['class'] == 1].index.unique()

Index(['Val/0261027001_h_01.TextGrid', 'Val/0261027016_h_00.TextGrid',
       'Val/0261027002_h_00.TextGrid', 'Val/0261027017_h_00.TextGrid',
       'Val/0261027003_h_00.TextGrid', 'Val/0261027018_h_00.TextGrid',
       'Val/0261027019_h_00.TextGrid', 'Val/0261027004_h_00.TextGrid',
       'Val/0261027005_h_00.TextGrid', 'Val/0261027020_h_00.TextGrid',
       ...
       'Val/5963097009_h_00.TextGrid', 'Val/5963097017_h_00.TextGrid',
       'Val/5963097018_h_00.TextGrid', 'Val/5963097010_h_00.TextGrid',
       'Val/5963097019_h_00.TextGrid', 'Val/5963097020_h_00.TextGrid',
       'Val/5963097021_h_00.TextGrid', 'Val/5963097022_h_00.TextGrid',
       'Val/5963097024_h_00.TextGrid', 'Val/5963097025_h_00.TextGrid'],
      dtype='object', name='correct_filename', length=387)

In [14]:
df_1sec_val['class'].value_counts()

0    6254
1    2078
Name: class, dtype: int64

In [15]:
df_1sec_test = create_df_ds(df_1sec_male, "Test")
print(df_1sec_test.shape)
df_1sec_test.head()

(9118, 6)


Unnamed: 0_level_0,ds_type,begin_time,end_time,class,sex,session
correct_filename,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Test/0321033001_h_00.TextGrid,Test,0.566939,1.596939,1,M,ses1033
Test/0321033001_h_00.TextGrid,Test,1.596939,2.616939,1,M,ses1033
Test/0321033015_h_00.TextGrid,Test,0.536939,1.586939,1,M,ses1033
Test/0321033015_h_00.TextGrid,Test,1.586939,2.656939,1,M,ses1033
Test/0321033002_h_01.TextGrid,Test,0.35,1.6,1,M,ses1033


#### Spectrify Class & Size Normalization Function

In [16]:
def equal_specs(input_ar, des_shape):
  """ Since num_rows = num_mels, need to ensure consistent time chunks or equal num cols
  """
  if input_ar.shape[1] > des_shape[1]:
    input_ar = input_ar[:, :des_shape[1]]
  elif input_ar.shape[1] < des_shape[1]:
    # First tuple assigns padding along rows, which is not required
    # Second tuple assigns padding along columns, which is needed to reach 264 columns
    pad_width = [(0, 0), (0, des_shape[1] - input_ar.shape[1])]
    input_ar =  np.pad(input_ar, pad_width, mode='constant', constant_values=0)
  return input_ar

In [17]:
import librosa
from praatio import textgrid

#Define Spectrify class with parameters
class Spectrify:
    def __init__(self, fmin, fmax, nmels, hop_length, n_fft, silence_len, chunk_len, des_shape, nml_tech01):
        self.fmin = fmin
        self.fmax = fmax
        self.nmels = nmels
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.silence_len = silence_len
        self.chunk_len = chunk_len
        self.desired_shape = des_shape # Tuple of required shape
        self.normal_tech_01 = nml_tech01


    # Define planner which extracts start and end times for each interval
    def planner(self, filename):
        tg = textgrid.openTextgrid(filename, False)
        entries = tg.tiers[0].entries
        entries = [(start, end, label) for start, end, label in entries]
        return self.phraser(entries, filename)

    # Define phraser, which creates target-len chunks that do not contain silence exceeding specified silence_len
    def phraser(self, entries, filename):
        phrases = []
        phrase_duration = 0
        current_phrase = []

        for start, end, label in entries:
            duration = end - start
            if label == "<p:>" and duration > self.silence_len: # pause
              current_phrase = []
              phrase_duration = 0
            # From original ALC, these are noise: ["<\"ah>", "<hm>", "<\"ahm>", "<hes>", "[sta]", "[int]", "[spk]", "<P>", "<PP>"]
            # In my copy, just need to make sure "<usb>" is counted as noise
            elif label == "<usb>": # noise (specific to KRAJ ALC Version)
              current_phrase = []
              phrase_duration = 0
            else: # Phoneme detected
              phrase_duration += duration # phoneme added to it
              current_phrase.append((start, end, label))

              if phrase_duration >= self.chunk_len: # See if accumulated phonemes exceed limit
                phrases.append(current_phrase)
                current_phrase = [] # Reset current phrase and try other potential phrases in textgrid
                phrase_duration = 0

        return phrases

    # Return spectrogram for chunk specified by parameters
    def spectrify(self, filename, beginning, end):
        filename = filename.replace("TextGrid", "wav")
        length = end - beginning
        y, sr = librosa.load(filename, offset=beginning, duration=length, sr = 22_050)

        S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=self.nmels, n_fft=self.n_fft, hop_length=self.hop_length,
                                               fmin=self.fmin, fmax=self.fmax)
        S_db = librosa.power_to_db(S)
        if self.normal_tech_01 == True:
          normalized_spec = ((S_db - np.min(S_db)) / (np.max(S_db) - np.min(S_db))) # normalizing between 0 and 1
        else:
          normalized_spec = 2*((S_db - np.min(S_db)) / (np.max(S_db) - np.min(S_db))) + -1 # normalizing between -1 and 1
        normalized_spec = equal_specs(normalized_spec, self.desired_shape) # normalizing size of retrieved chunks
        return normalized_spec

In [18]:
# Creating Spectrify Obj Dependent on Whether User is Male/Female:
def create_spectrify_obj(Gender):
  # Unique Spectrogram Generation hyperparameters for Male & Female
  if Gender == "M":
    spec_params = {'fmin' : 100, 'fmax' : 5000, 'nfft' : 1024, 'hoplen' : 64, 'nmels' : 64,
               'silen': 0.1, 'clen' : 1, 'nml_tech01' : False}
  else:
    spec_params = {'fmin' : 100, 'fmax' : 5000, 'nfft' : 512, 'hoplen' : 64, 'nmels' : 64,
               'silen': 0.1, 'clen' : 1, 'nml_tech01' : False}

  spec_params['des_shape'] = (spec_params['nmels'], int(22_050 / spec_params['hoplen']) + 1)
  spectrify_obj = Spectrify(fmin=spec_params['fmin'], fmax=spec_params['fmax'], nmels= spec_params['nmels'],
                              hop_length= spec_params['hoplen'], n_fft=spec_params['nfft'], silence_len=spec_params['silen'],
                          chunk_len= spec_params['clen'], des_shape = spec_params['des_shape'], nml_tech01= spec_params['nml_tech01'])
  return spectrify_obj
spect_male = create_spectrify_obj("M")
spect_male

<__main__.Spectrify at 0x7f2d4aee7070>

Extracting chunk mappings from a random user-provided TextGrid

In [19]:
import random
# file string with only one chunk: Val/0261027004_h_00.TextGrid
rand_file_str = random.choice(df_1sec_test.index)
rand_file_str

'Test/5264027010_h_00.TextGrid'

In [20]:
def create_df_chunk_mapping(file_str, spect_obj):
  """ Expecting an input Textgrid filename string & a male/female spectrify obj
  """
  list_chunks = []
  #for file_str in df_ds.index:
  phrases = spect_obj.planner(file_str)
  if len(phrases) > 0:
    for phrase in phrases:
      begin_pt, end_pt = phrase[0][0], phrase[-1][1]
      list_chunks.append((file_str, begin_pt, end_pt))
  df_chunks_ds = pd.DataFrame(list_chunks, columns = ['filename', 'begin_time', 'end_time'])
  return df_chunks_ds

In [21]:
df_chunks_file = create_df_chunk_mapping(rand_file_str, spect_male)
df_chunks_file

Unnamed: 0,filename,begin_time,end_time
0,Test/5264027010_h_00.TextGrid,3.550181,4.580181
1,Test/5264027010_h_00.TextGrid,4.930181,5.930181
2,Test/5264027010_h_00.TextGrid,9.220181,10.280181
3,Test/5264027010_h_00.TextGrid,10.280181,11.320181
4,Test/5264027010_h_00.TextGrid,13.490181,14.510181
5,Test/5264027010_h_00.TextGrid,16.560181,17.850181
6,Test/5264027010_h_00.TextGrid,22.500159,23.720181
7,Test/5264027010_h_00.TextGrid,23.720181,24.760181
8,Test/5264027010_h_00.TextGrid,27.480181,28.480181
9,Test/5264027010_h_00.TextGrid,28.480181,29.570181


#### Audio Dataset Class

In [22]:
def cnn_reshape(input_arr):
  return input_arr.reshape(1, input_arr.shape[0], input_arr.shape[1])

In [23]:
from torch.utils.data import Dataset

class Audio_DS(Dataset):
    def __init__(self, data, spectrify_obj, device = "cuda"):

        self.df_mapping = data # Loading in dataframe of filenames as index and associated class
        self.device = device
        self.spectrify_obj = spectrify_obj # Using predefined spectrify object with audio parameters
        self.file_strs = self.df_mapping['filename'].values
        self.begin_pts = self.df_mapping['begin_time'].values
        self.end_pts = self.df_mapping['end_time'].values

    def __len__(self):
      # index is derived from Weighted Random Sampler which is using max number as number of chunks
      return len(self.file_strs)

    def __getitem__(self, idx):
      chunk_filename = self.file_strs[idx]
      chunk_start_pt = self.begin_pts[idx]
      chunk_end_pt = self.end_pts[idx]
      # Transform filename input into spectrogram & convert spectrogram shape into shape (num_channels, img_height, img_width)
      arr_spec = self.spectrify_obj.spectrify(chunk_filename, chunk_start_pt, chunk_end_pt) # Creating Standardized spectrogram array
      arr_spec = cnn_reshape(arr_spec) # Transform into (num_channels, img_height, img_width)

      return arr_spec

### Model Architecture for 1 Second

In [24]:
CNN_1sec = torch.jit.load("11_1_Male_3conv2fc_CNN_model.pt")
CNN_1sec

RecursiveScriptModule(
  original_name=CNN
  (conv): RecursiveScriptModule(
    original_name=ModuleList
    (0): RecursiveScriptModule(original_name=Conv2d)
    (1): RecursiveScriptModule(original_name=Conv2d)
    (2): RecursiveScriptModule(original_name=Conv2d)
  )
  (bn): RecursiveScriptModule(
    original_name=ModuleList
    (0): RecursiveScriptModule(original_name=BatchNorm2d)
    (1): RecursiveScriptModule(original_name=BatchNorm2d)
    (2): RecursiveScriptModule(original_name=BatchNorm2d)
  )
  (fc): RecursiveScriptModule(
    original_name=ModuleList
    (0): RecursiveScriptModule(original_name=Linear)
    (1): RecursiveScriptModule(original_name=Linear)
  )
  (pooling): RecursiveScriptModule(original_name=AvgPool2d)
  (activation): RecursiveScriptModule(original_name=ReLU)
  (conv_dp): RecursiveScriptModule(original_name=Dropout2d)
  (fc_dp): RecursiveScriptModule(original_name=Dropout)
)

### Code for Generating Predictions on Chunks

In [25]:
from torch.utils.data import DataLoader
# Val Datasets & DataLoaders
def gen_chunk_preds(spect_obj, df_val, device, model):
  """ Input is Male/Female Spectrify Obj, dataframe of chunk time mappings, device: CPU or GPU, Male/Female Model
  """
  val_ds = Audio_DS(df_val,spect_obj, device = device)
  val_data_loader = DataLoader(val_ds, 32, shuffle = False, num_workers = 2, prefetch_factor= 4, drop_last = False)
  # Setting model to eval to switch off dropout
  all_preds = []
  all_probs = []
  dict_probs_preds = {}
  model.eval()
  for batch in val_data_loader:
    with torch.no_grad():
      batch_logits = model(batch)
      batch_probs = torch.squeeze(torch.sigmoid(batch_logits))
      batch_preds = (batch_probs >= 0.5).float()

    if torch.numel(batch_preds) == 1: # In case only 1 chunk can be extracted from audio file
        all_preds.append(batch_preds.item())
        all_probs.append(batch_probs.item())
    else:                             # If more than 1 chunk can be extracted from audio file
        all_preds.extend(batch_preds.tolist())
        all_probs.extend(batch_probs.tolist())
  dict_probs_preds['preds'] = np.array(all_preds)
  dict_probs_preds['probs'] = np.array(all_probs)
  return dict_probs_preds


In [26]:
dict_predobs = gen_chunk_preds(spect_male, df_chunks_file, device, CNN_1sec)

In [27]:
# Chunk prediction made by seeing
print("All chunk predictions from file:", dict_predobs['preds'])

# All probabilities should be interpreted as probability of being intoxicated
print("All chunk probs of Intoxicated:", dict_predobs['probs'])


All chunk predictions from file: [1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 1. 1. 0. 0. 1. 0. 0. 0. 0.]
All chunk probs of Intoxicated: [0.56128865 0.35295668 0.20208196 0.29619214 0.26295063 0.24045829
 0.52506196 0.17962949 0.13694027 0.24404563 0.19101156 0.64199513
 0.49553871 0.54675531 0.52353907 0.49115264 0.40180776 0.55739623
 0.49504736 0.45286164 0.30234396 0.1592734 ]


In [28]:
df_acfile_values = df_1sec_male.set_index('correct_filename').loc[rand_file_str]
if isinstance(df_acfile_values, pd.Series):
  df_acfile_values = pd.DataFrame(df_acfile_values).T
df_acfile_values

Unnamed: 0_level_0,ds_type,begin_time,end_time,class,sex,session
correct_filename,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Test/5264027010_h_00.TextGrid,Test,3.550181,4.580181,0,M,ses4027
Test/5264027010_h_00.TextGrid,Test,4.930181,5.930181,0,M,ses4027
Test/5264027010_h_00.TextGrid,Test,9.220181,10.280181,0,M,ses4027
Test/5264027010_h_00.TextGrid,Test,10.280181,11.320181,0,M,ses4027
Test/5264027010_h_00.TextGrid,Test,13.490181,14.510181,0,M,ses4027
Test/5264027010_h_00.TextGrid,Test,16.560181,17.850181,0,M,ses4027
Test/5264027010_h_00.TextGrid,Test,22.500159,23.720181,0,M,ses4027
Test/5264027010_h_00.TextGrid,Test,23.720181,24.760181,0,M,ses4027
Test/5264027010_h_00.TextGrid,Test,27.480181,28.480181,0,M,ses4027
Test/5264027010_h_00.TextGrid,Test,28.480181,29.570181,0,M,ses4027


In [29]:
# Check if number of chunk predictions matches expected number of chunks extracted from file
assert(df_acfile_values.shape[0] == len(dict_predobs['preds']))

In [30]:
unique_vals, unique_counts = np.unique(dict_predobs['preds'], return_counts = True)
print(unique_vals, unique_counts)

[0. 1.] [16  6]


In [31]:
np.where(dict_predobs['preds'] == 1)[0]
dict_predobs['probs'][np.where(dict_predobs['preds'] == 1)[0]]  # Probs of Intoxicated


array([0.56128865, 0.52506196, 0.64199513, 0.54675531, 0.52353907,
       0.55739623])

In [32]:
dict_labels

{'S': 0, 'I': 1}

In [33]:
def tg_class_pred(dict_preds_probs):
  """ Returns predicted class for a given textgrid file based on chunk class votes outputted by CNN
      Also, returns pseudo probabilities if file prediction is intoxicated (1)
      Labeling_scheme: 0 -> Sober, 1 -> Intox"""

  unique_vals, unique_counts = np.unique(dict_preds_probs['preds'], return_counts = True)
  if len(unique_vals) == 2: # If there is at least 1 vote for both sober and intoxicated classes
    num_zeros, num_ones = unique_counts[0], unique_counts[1]
    if num_zeros > num_ones: # More sober votes than intox votes, so pred = sober
      pred = 0
      return pred
    elif num_ones == num_zeros: # Equal sober and intox votes, have to side with caution so pred = intox
      pred = 1
      all_probs_intoxicated = dict_preds_probs['probs'][np.where(dict_preds_probs['preds'] == 1)[0]]
      average_prob_intoxicated = np.mean(all_probs_intoxicated)
      return pred, average_prob_intoxicated
    else: # Final case of more intox than sober votes, pred = intox
      pred = 1
      all_probs_intoxicated = dict_preds_probs['probs'][np.where(dict_preds_probs['preds'] == 1)[0]]
      average_prob_intoxicated = np.mean(all_probs_intoxicated)
      return pred, average_prob_intoxicated # Return prediction of intox and average probability of it
  elif len(unique_vals) == 1: # If there is only 1 vote for either sober or intoxicated classes
    pred = unique_vals[0] # Capture predicted class for sole chunk
    if pred == 1: # If prediction is intoxicated
      all_probs_intoxicated = dict_preds_probs['probs'].item()
      return pred, all_probs_intoxicated # Return pred of 1 & probability of intoxicated
    else:
      return pred # Return sober (0) if sole chunk outputted sober (0)

In [34]:
tg_predicted_class = tg_class_pred(dict_predobs)
tg_predicted_class

0

In [35]:
actual_class = df_acfile_values['class'][0]
actual_class

0