In [1]:
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import seaborn as sns


from google.colab import drive
drive.mount('/content/gdrive')


if torch.cuda.is_available():
  print("hello GPU")
else:
  print("sadge")


class BaseModel(nn.Module, ABC):
    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x):
        pass

    @property
    def device(self):
        return next(self.parameters()).device

    def restore_checkpoint(self, ckpt_file, optimizer=None):
        if not ckpt_file:
            raise ValueError("No checkpoint file to be restored.")

        try:
            ckpt_dict = torch.load(ckpt_file)
        except RuntimeError:
            ckpt_dict = torch.load(ckpt_file,
                                   map_location=lambda storage, loc: storage)

        # Restore model weights
        self.load_state_dict(ckpt_dict['model_state_dict'])

        # Restore optimizer status if existing. Evaluation doesn't need this
        if optimizer:
            optimizer.load_state_dict(ckpt_dict['optimizer_state_dict'])

        # Return global step
        return ckpt_dict['global_step']

    def save_checkpoint(self,
                        directory,
                        global_step,
                        optimizer=None,
                        name=None):
        # Create directory to save to
        if not os.path.exists(directory):
            os.makedirs(directory)

        # Build checkpoint dict to save.
        ckpt_dict = {
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer.state_dict() if optimizer is not None else None,
            'global_step': global_step
        }

        # Save the file with specific name
        if name is None:
            name = "{}_{}_steps.pth".format(
                os.path.basename(directory),  # netD or netG
                global_step)

        torch.save(ckpt_dict, os.path.join(directory, name))

    def count_params(self):
        num_total_params = sum(p.numel() for p in self.parameters())
        num_trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        return num_total_params, num_trainable_params


class DBlock(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels):
        super().__init__()
        self.rnn = nn.GRU(in_channels, out_channels, num_layers=1, batch_first=True)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        h, _ = self.rnn(x)
        h = self.activation(h)
        h = h.permute(1, 2, 0)
        h = F.avg_pool1d(h, 2)
        h = h.permute(2, 0, 1)
        return h

class Discriminator(BaseModel):
    def __init__(self, **kwargs):
        super().__init__()
        self.count = 0
        self.errD_array = []
        self.bce = nn.BCELoss()
        self.ndf = 100

        self.sblock1 = DBlock(64, self.ndf)
        self.sblock2 = DBlock(self.ndf, self.ndf)
        self.sblock3 = DBlock(self.ndf, self.ndf)
        self.sblock4 = DBlock(self.ndf, self.ndf)
        self.sblock5 = DBlock(self.ndf, self.ndf)
        self.c = nn.GRU(self.ndf, 64, num_layers=1, batch_first=True)
        self.end = nn.Linear(197, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = x.float()
        h = self.sblock1(x)
        h = self.sblock2(h)
        h = self.sblock3(h)
        h = self.sblock4(h)
        h, _ = self.c(h)
        h = h.permute(1,2,0)
        h = self.end(h)
        h = self.sigmoid(h)
        return h.view(h.shape[0], 64)

    def compute_loss(self, output, actual):
        return self.bce(output, actual)


def train(open, closed, dupe=False):
  device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
  netD = Discriminator().to(device)
  optimizer = torch.optim.Adam(netD.parameters(), 0.0001, (0.5, 0.99))

  train_lossD = []
  train_accuracyD = []
  test_lossD = []
  test_accuracyD = []

  # Split data into training and testing sets
  train_open, test_open = train_test_split(open, test_size=0.2, random_state=42)
  train_closed, test_closed = train_test_split(closed, test_size=0.2, random_state=42)

  if (dupe):
      train_open = np.concatenate((train_open, train_open))
      train_closed = np.concatenate((train_closed, train_closed))

  # Move data to tensors and to device
  train_open = torch.tensor(train_open).float().to(device)
  train_closed = torch.tensor(train_closed).float().to(device)
  test_open = torch.tensor(test_open).float().to(device)
  test_closed = torch.tensor(test_closed).float().to(device)

  train_open = train_open.permute((2,0,1))
  train_closed = train_closed.permute((2,0,1))
  test_open = test_open.permute((2,0,1))
  test_closed = test_closed.permute((2,0,1))

  for epoch in range(60):
    # Training
    netD.zero_grad()

    train_output_open = netD.forward(train_open)
    train_output_closed = netD.forward(train_closed)

    train_vals_open = torch.full(train_output_open.shape, 1.0, dtype=torch.float, device=device)
    train_vals_closed = torch.full(train_output_closed.shape, 0.0, dtype=torch.float, device=device)

    train_pred_labels = torch.cat((train_output_open, train_output_closed))
    train_actual_labels = torch.cat((train_vals_open, train_vals_closed))

    train_accuracy = ((train_pred_labels > 0.5) == train_actual_labels).float().mean()
    train_accuracyD.append(train_accuracy.item())

    # Compute training loss
    train_errD_real = netD.compute_loss(train_output_open, train_vals_open)
    train_errD_fake = netD.compute_loss(train_output_closed, train_vals_closed)
    train_errD = train_errD_real + train_errD_fake
    train_errD.backward()
    optimizer.step()

    train_lossD.append(train_errD.item())

    # Testing
    with torch.no_grad():
      test_output_open = netD.forward(test_open)
      test_output_closed = netD.forward(test_closed)

      # Compute classification accuracy
      test_vals_open = torch.full(test_output_open.shape, 1.0, dtype=torch.float, device=device)
      test_vals_closed = torch.full(test_output_closed.shape, 0.0, dtype=torch.float, device=device)

      test_pred_labels = torch.cat((test_output_open, test_output_closed))
      test_actual_labels = torch.cat((test_vals_open, test_vals_closed))

      test_accuracy = ((test_pred_labels > 0.5) == test_actual_labels).float().mean()

      # Bin guess into T/F for confusion matrix
      test_output_open_cm = (test_output_open.mean(dim=1, keepdim=True) > 0.5).float()
      test_output_closed_cm = (test_output_closed.mean(dim=1, keepdim=True) > 0.5).float()
      test_vals_open_cm = (test_vals_open.mean(dim=1, keepdim=True) > 0.5).float()
      test_vals_closed_cm = (test_vals_closed.mean(dim=1, keepdim=True) > 0.5).float()

      # Compute testing loss
      test_errD_real = netD.compute_loss(test_output_open, test_vals_open)
      test_errD_fake = netD.compute_loss(test_output_closed, test_vals_closed)
      test_errD = test_errD_real + test_errD_fake

      test_accuracyD.append(test_accuracy.item())
      test_lossD.append(test_errD.item())

    # if epoch % 20 == 0:
      # print('Epoch', epoch)
      # print('Train Loss:', train_errD.item(), 'Test Loss:', test_errD.item())
      # print('Train Accuracy:', train_accuracy.item(), 'Test Accuracy:', test_accuracy.item())
      # plt.figure(figsize=(10,4))
      # plt.subplot(1, 2, 1)
      # plt.plot(train_lossD, label='Training Loss')
      # plt.plot(test_lossD, label='Testing Loss')
      # plt.legend()
      # plt.subplot(1, 2, 2)
      # plt.plot(train_accuracyD, label='Training Accuracy')
      # plt.plot(test_accuracyD, label='Testing Accuracy')
      # plt.legend()
      # plt.show()

  print('Final Test Accuracy:', test_accuracy.item())
  return test_accuracy.item()
  ##return torch.cat((test_output_open_cm, test_output_closed_cm)), torch.cat((test_vals_open_cm, test_vals_closed_cm)), netD


Mounted at /content/gdrive
hello GPU


In [8]:
open = np.load("/content/gdrive/My Drive/Research_Paper/Training/normalized-training-open-64ch.npy")
closed = np.load("/content/gdrive/My Drive/Research_Paper/Training/normalized-training-closed-64ch.npy")

test_100_accuracies = []

for i in range(100):
  accuracy = train(open, closed)
  test_100_accuracies.append(accuracy)

print("Avg test accuracy: ", np.mean(test_100_accuracies))



Final Test Accuracy: 0.7578125
Final Test Accuracy: 0.7482638955116272
Final Test Accuracy: 0.8012152910232544
Final Test Accuracy: 0.7690972089767456
Final Test Accuracy: 0.7934027910232544
Final Test Accuracy: 0.7717013955116272
Final Test Accuracy: 0.7699652910232544
Final Test Accuracy: 0.5607638955116272
Final Test Accuracy: 0.7673611044883728
Final Test Accuracy: 0.7682291865348816
Final Test Accuracy: 0.7760416865348816
Final Test Accuracy: 0.5130208134651184
Final Test Accuracy: 0.6970486044883728
Final Test Accuracy: 0.7907986044883728
Final Test Accuracy: 0.7734375
Final Test Accuracy: 0.7604166865348816
Final Test Accuracy: 0.8098958134651184
Final Test Accuracy: 0.7560763955116272
Final Test Accuracy: 0.7560763955116272
Final Test Accuracy: 0.7777777910232544
Final Test Accuracy: 0.7170138955116272
Final Test Accuracy: 0.733506977558136
Final Test Accuracy: 0.756944477558136
Final Test Accuracy: 0.7395833134651184
Final Test Accuracy: 0.6953125
Final Test Accuracy: 0.745659

In [15]:
open = np.load("/content/gdrive/My Drive/Research_Paper/Training/normalized-training-open-64ch.npy")
closed = np.load("/content/gdrive/My Drive/Research_Paper/Training/normalized-training-closed-64ch.npy")
open_generated1 = np.load("/content/gdrive/My Drive/Research_Paper/generated-data/generated-open-1.npy")
closed_generated1 = np.load("/content/gdrive/My Drive/Research_Paper/generated-data/generated-closed-1.npy")
open_generated2 = np.load("/content/gdrive/My Drive/Research_Paper/generated-data/generated-open-2.npy")
closed_generated2 = np.load("/content/gdrive/My Drive/Research_Paper/generated-data/generated-closed-2.npy")

open = np.concatenate((open, open_generated1))
closed = np.concatenate((closed, closed_generated1))
open = np.concatenate((open, open_generated2))
closed = np.concatenate((closed, closed_generated2))
test_100_accuracies = []

for i in range(100):
  accuracy = train(open, closed)
  test_100_accuracies.append(accuracy)

print("Avg test accuracy: ", np.mean(test_100_accuracies))



Final Test Accuracy: 0.8903301954269409
Final Test Accuracy: 0.8257665038108826
Final Test Accuracy: 0.8679245710372925
Final Test Accuracy: 0.8629127740859985
Final Test Accuracy: 0.8977004885673523
Final Test Accuracy: 0.8808962106704712
Final Test Accuracy: 0.895931601524353
Final Test Accuracy: 0.8808962106704712
Final Test Accuracy: 0.895636796951294
Final Test Accuracy: 0.8814858794212341
Final Test Accuracy: 0.8947523832321167
Final Test Accuracy: 0.8428655862808228
Final Test Accuracy: 0.8847287893295288
Final Test Accuracy: 0.9139150977134705
Final Test Accuracy: 0.8932783007621765
Final Test Accuracy: 0.8691037893295288
Final Test Accuracy: 0.8732311725616455
Final Test Accuracy: 0.8788325786590576
Final Test Accuracy: 0.8181014060974121
Final Test Accuracy: 0.8897405862808228
Final Test Accuracy: 0.9059551954269409
Final Test Accuracy: 0.8744103908538818
Final Test Accuracy: 0.8838443756103516
Final Test Accuracy: 0.8882665038108826
Final Test Accuracy: 0.8764740824699402
Fi

In [2]:
open = np.load("/content/gdrive/My Drive/Research_Paper/Training/normalized-training-open-64ch.npy")
closed = np.load("/content/gdrive/My Drive/Research_Paper/Training/normalized-training-closed-64ch.npy")

test_100_accuracies = []

for i in range(100):
  accuracy = train(open, closed, True)
  test_100_accuracies.append(accuracy)

print("Avg test accuracy: ", np.mean(test_100_accuracies))



Final Test Accuracy: 0.7395833134651184
Final Test Accuracy: 0.7152777910232544
Final Test Accuracy: 0.7352430820465088
Final Test Accuracy: 0.5737847089767456
Final Test Accuracy: 0.7769097089767456
Final Test Accuracy: 0.7083333134651184
Final Test Accuracy: 0.7855902910232544
Final Test Accuracy: 0.7942708134651184
Final Test Accuracy: 0.5078125
Final Test Accuracy: 0.796875
Final Test Accuracy: 0.796875
Final Test Accuracy: 0.6857638955116272
Final Test Accuracy: 0.7708333134651184
Final Test Accuracy: 0.7465277910232544
Final Test Accuracy: 0.764756977558136
Final Test Accuracy: 0.4982638955116272
Final Test Accuracy: 0.7586805820465088
Final Test Accuracy: 0.7204861044883728
Final Test Accuracy: 0.788194477558136
Final Test Accuracy: 0.7873263955116272
Final Test Accuracy: 0.6875
Final Test Accuracy: 0.8046875
Final Test Accuracy: 0.7673611044883728
Final Test Accuracy: 0.8090277910232544
Final Test Accuracy: 0.7690972089767456
Final Test Accuracy: 0.7204861044883728
Final Test A

In [2]:
from scipy import stats

real = [0.7578125, 0.7482638955116272, 0.8012152910232544, 0.7690972089767456, 0.7934027910232544, 0.7717013955116272, 0.7699652910232544, 0.5607638955116272, 0.7673611044883728, 0.7682291865348816, 0.7760416865348816, 0.5130208134651184, 0.6970486044883728, 0.7907986044883728, 0.7734375, 0.7604166865348816, 0.8098958134651184, 0.7560763955116272, 0.7560763955116272, 0.7777777910232544, 0.7170138955116272, 0.733506977558136, 0.756944477558136, 0.7395833134651184, 0.6953125, 0.7456597089767456, 0.7925347089767456, 0.741319477558136, 0.8142361044883728, 0.772569477558136, 0.8211805820465088, 0.7916666865348816, 0.8168402910232544, 0.7751736044883728, 0.7465277910232544, 0.7074652910232544, 0.7673611044883728, 0.7274305820465088, 0.741319477558136, 0.7526041865348816, 0.7777777910232544, 0.7907986044883728, 0.734375, 0.8177083134651184, 0.7213541865348816, 0.7534722089767456, 0.7751736044883728, 0.7317708134651184, 0.7630208134651184, 0.65625, 0.5963541865348816, 0.7934027910232544, 0.8203125, 0.7161458134651184, 0.796875, 0.7135416865348816, 0.7769097089767456, 0.734375, 0.6935763955116272, 0.8333333134651184, 0.6953125, 0.8177083134651184, 0.8107638955116272, 0.7222222089767456, 0.7699652910232544, 0.6996527910232544, 0.7951388955116272, 0.8151041865348816, 0.78125, 0.7638888955116272, 0.7552083134651184, 0.803819477558136, 0.7552083134651184, 0.7864583134651184, 0.7526041865348816, 0.8255208134651184, 0.7795138955116272, 0.8107638955116272, 0.7769097089767456, 0.8177083134651184, 0.7222222089767456, 0.65625, 0.8237847089767456, 0.796006977558136, 0.7907986044883728, 0.8385416865348816, 0.803819477558136, 0.7213541865348816, 0.7838541865348816, 0.7578125, 0.7387152910232544, 0.8003472089767456, 0.7352430820465088, 0.717881977558136, 0.7682291865348816, 0.6979166865348816, 0.7517361044883728, 0.8168402910232544, 0.7517361044883728, 0.7560763955116272]
real_generated = [0.8903301954269409, 0.8257665038108826, 0.8679245710372925, 0.8629127740859985, 0.8977004885673523, 0.8808962106704712, 0.895931601524353, 0.8808962106704712, 0.895636796951294, 0.8814858794212341, 0.8947523832321167, 0.8428655862808228, 0.8847287893295288, 0.9139150977134705, 0.8932783007621765, 0.8691037893295288, 0.8732311725616455, 0.8788325786590576, 0.8181014060974121, 0.8897405862808228, 0.9059551954269409, 0.8744103908538818, 0.8838443756103516, 0.8882665038108826, 0.8764740824699402, 0.8832547068595886, 0.9056603908538818, 0.8484669923782349, 0.8770636916160583, 0.8717570900917053, 0.911261796951294, 0.8920990824699402, 0.9000589847564697, 0.8581957817077637, 0.8770636916160583, 0.8864976763725281, 0.8590801954269409, 0.8764740824699402, 0.8974056839942932, 0.8923938870429993, 0.864386796951294, 0.834316074848175, 0.8596698045730591, 0.8929834961891174, 0.8720518946647644, 0.8929834961891174, 0.8688089847564697, 0.859375, 0.8455188870429993, 0.9053655862808228, 0.8944575786590576, 0.886202871799469, 0.9033018946647644, 0.9015330076217651, 0.885613203048706, 0.8829599022865295, 0.8599646091461182, 0.8953419923782349, 0.8920990824699402, 0.8888561725616455, 0.865566074848175, 0.8938679099082947, 0.9265919923782349, 0.8744103908538818, 0.8649764060974121, 0.8885613083839417, 0.8399174809455872, 0.8673349022865295, 0.9077240824699402, 0.8752948045730591, 0.859375, 0.8770636916160583, 0.8791273832321167, 0.8879716992378235, 0.8909198045730591, 0.8552476763725281, 0.8879716992378235, 0.916568398475647, 0.8882665038108826, 0.8920990824699402, 0.8879716992378235, 0.8752948045730591, 0.8947523832321167, 0.900943398475647, 0.8923938870429993, 0.849941074848175, 0.8561320900917053, 0.8688089847564697, 0.8867924809455872, 0.8478773832321167, 0.8817806839942932, 0.8900353908538818, 0.8219339847564697, 0.834316074848175, 0.9035966992378235, 0.8820754885673523, 0.8688089847564697, 0.854952871799469, 0.8808962106704712, 0.9103773832321167]
real_real = [0.7594339847564697, 0.803066074848175, 0.802181601524353, 0.8546580076217651, 0.8042452931404114, 0.7954009771347046, 0.834316074848175, 0.8378537893295288, 0.7435141801834106, 0.818691074848175, 0.7326061725616455, 0.8083726763725281, 0.8157429099082947, 0.860259473323822, 0.8723466992378235, 0.8581957817077637, 0.7800707817077637, 0.823113203048706, 0.7426297068595886, 0.8319575786590576, 0.8558372855186462, 0.8083726763725281, 0.8517099022865295, 0.7184551954269409, 0.7473466992378235, 0.7010613083839417, 0.8139740824699402, 0.8779481053352356, 0.7974646091461182, 0.8564268946647644, 0.8620283007621765, 0.8390330076217651, 0.8054245114326477, 0.8714622855186462, 0.7685731053352356, 0.8652712106704712, 0.8493514060974121, 0.7933372855186462, 0.8637971878051758, 0.8077830076217651, 0.7868514060974121, 0.7771226763725281, 0.880306601524353, 0.8068985939025879, 0.8242924809455872, 0.7959905862808228, 0.8920990824699402, 0.8077830076217651, 0.7620872855186462, 0.7839033007621765, 0.8319575786590576, 0.7880306839942932, 0.7576650977134705, 0.8576061725616455, 0.854068398475647, 0.849056601524353, 0.8051297068595886, 0.7600235939025879, 0.8410966992378235, 0.8729363083839417, 0.8352004885673523, 0.817511796951294, 0.8352004885673523, 0.8325471878051758, 0.7809551954269409, 0.8378537893295288, 0.7930424809455872, 0.7780070900917053, 0.7895047068595886, 0.7948113083839417, 0.8416863083839417, 0.7243514060974121, 0.739386796951294, 0.8122051954269409, 0.7535377740859985, 0.8098466992378235, 0.770931601524353, 0.8145636916160583, 0.8611438870429993, 0.8127948045730591, 0.7797759771347046, 0.7423349022865295, 0.8419811725616455, 0.8245872855186462, 0.7632665038108826, 0.885613203048706, 0.8767688870429993, 0.8201650977134705, 0.8163325786590576, 0.6188089847564697, 0.8508254885673523, 0.8316627740859985, 0.7438089847564697, 0.8372641801834106, 0.7785966992378235, 0.8304834961891174, 0.7724056839942932, 0.7774174809455872, 0.735259473323822, 0.8782429099082947]

# real vs generated/real
t_statistic, p_value = stats.ttest_rel(real, real_generated)
print(p_value)

# real vs realx2
t_statistic, p_value = stats.ttest_rel(real_real, real_generated)
print(p_value)


3.2656397012126445e-40
7.833380606096976e-26
