In [1]:
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import seaborn as sns
from scipy.linalg import sqrtm


from google.colab import drive
drive.mount('/content/gdrive')


if torch.cuda.is_available():
  print("hello GPU")
else:
  print("sadge")


class BaseModel(nn.Module, ABC):
    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x):
        pass

    @property
    def device(self):
        return next(self.parameters()).device

    def restore_checkpoint(self, ckpt_file, optimizer=None):
        if not ckpt_file:
            raise ValueError("No checkpoint file to be restored.")

        try:
            ckpt_dict = torch.load(ckpt_file)
        except RuntimeError:
            ckpt_dict = torch.load(ckpt_file,
                                   map_location=lambda storage, loc: storage)

        # Restore model weights
        self.load_state_dict(ckpt_dict['model_state_dict'])

        # Restore optimizer status if existing. Evaluation doesn't need this
        if optimizer:
            optimizer.load_state_dict(ckpt_dict['optimizer_state_dict'])

        # Return global step
        return ckpt_dict['global_step']

    def save_checkpoint(self,
                        directory,
                        global_step,
                        optimizer=None,
                        name=None):
        # Create directory to save to
        if not os.path.exists(directory):
            os.makedirs(directory)

        # Build checkpoint dict to save.
        ckpt_dict = {
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer.state_dict() if optimizer is not None else None,
            'global_step': global_step
        }

        # Save the file with specific name
        if name is None:
            name = "{}_{}_steps.pth".format(
                os.path.basename(directory),  # netD or netG
                global_step)

        torch.save(ckpt_dict, os.path.join(directory, name))

    def count_params(self):
        num_total_params = sum(p.numel() for p in self.parameters())
        num_trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        return num_total_params, num_trainable_params




class Discriminator(BaseModel):
    def __init__(self, **kwargs):
        super().__init__()
        self.count = 0
        self.errD_array = []
        self.bce = nn.BCELoss()
        self.ndf = 100

        self.sblock1 = nn.Linear(3152,2700)
        self.relu = nn.ReLU(True)
        self.sblock2 = nn.Linear(2700,2000)
        self.sblock3 = nn.Linear(2000,1500)
        self.sblock4 = nn.Linear(1500,1100)
        self.sblock5 = nn.Linear(1100,512)
        self.sblock8 = nn.Linear(512,128)
        self.sblock9 = nn.Linear(128,64)
        self.sblock10 = nn.Linear(64,1)
        self.dropout = nn.Dropout(0.2)

        # Final classification layer
        self.sigmoid = nn.Sigmoid()


    def forward(self, x):
        h = self.dropout(self.relu(self.sblock1(x)))
        h = self.dropout(self.relu(self.sblock2(h)))
        h = self.dropout(self.relu(self.sblock3(h)))
        h = self.dropout(self.relu(self.sblock4(h)))
        h = self.dropout(self.relu(self.sblock5(h)))
        h = self.dropout(self.relu(self.sblock8(h)))
        h = self.relu(self.sblock9(h))
        h = self.sblock10(h)
        h = self.sigmoid(h)
        return h.view(h.shape[0], 64)

    def compute_loss(self, output, actual):
        return self.bce(output, actual)

def train(open, closed, dupe=False):
  device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
  netD = Discriminator().to(device)
  optimizer = torch.optim.Adam(netD.parameters(), 0.0001, (0.5, 0.99))

  train_lossD = []
  train_accuracyD = []
  test_lossD = []
  test_accuracyD = []

  # Split data into training and testing sets
  train_open, test_open = train_test_split(open, test_size=0.2, random_state=42)
  train_closed, test_closed = train_test_split(closed, test_size=0.2, random_state=42)

  if (dupe):
      train_open = np.concatenate((train_open, train_open))
      train_closed = np.concatenate((train_closed, train_closed))

  # Move data to tensors and to device
  train_open = torch.tensor(train_open).float().to(device)
  train_closed = torch.tensor(train_closed).float().to(device)
  test_open = torch.tensor(test_open).float().to(device)
  test_closed = torch.tensor(test_closed).float().to(device)

  for epoch in range(20):
    # Training
    netD.zero_grad()

    train_output_open = netD.forward(train_open)
    train_output_closed = netD.forward(train_closed)

    train_vals_open = torch.full(train_output_open.shape, 1.0, dtype=torch.float, device=device)
    train_vals_closed = torch.full(train_output_closed.shape, 0.0, dtype=torch.float, device=device)

    train_pred_labels = torch.cat((train_output_open, train_output_closed))
    train_actual_labels = torch.cat((train_vals_open, train_vals_closed))

    train_accuracy = ((train_pred_labels > 0.5) == train_actual_labels).float().mean()
    train_accuracyD.append(train_accuracy.item())

    # Compute training loss
    train_errD_real = netD.compute_loss(train_output_open, train_vals_open)
    train_errD_fake = netD.compute_loss(train_output_closed, train_vals_closed)
    train_errD = train_errD_real + train_errD_fake
    train_errD.backward()
    optimizer.step()

    train_lossD.append(train_errD.item())

    # Testing
    with torch.no_grad():
      test_output_open = netD.forward(test_open)
      test_output_closed = netD.forward(test_closed)

      # Compute classification accuracy
      test_vals_open = torch.full(test_output_open.shape, 1.0, dtype=torch.float, device=device)
      test_vals_closed = torch.full(test_output_closed.shape, 0.0, dtype=torch.float, device=device)

      test_pred_labels = torch.cat((test_output_open, test_output_closed))
      test_actual_labels = torch.cat((test_vals_open, test_vals_closed))

      test_accuracy = ((test_pred_labels > 0.5) == test_actual_labels).float().mean()

      # Bin guess into T/F for confusion matrix
      test_output_open_cm = (test_output_open.mean(dim=1, keepdim=True) > 0.5).float()
      test_output_closed_cm = (test_output_closed.mean(dim=1, keepdim=True) > 0.5).float()
      test_vals_open_cm = (test_vals_open.mean(dim=1, keepdim=True) > 0.5).float()
      test_vals_closed_cm = (test_vals_closed.mean(dim=1, keepdim=True) > 0.5).float()

      # Compute testing loss
      test_errD_real = netD.compute_loss(test_output_open, test_vals_open)
      test_errD_fake = netD.compute_loss(test_output_closed, test_vals_closed)
      test_errD = test_errD_real + test_errD_fake

      test_accuracyD.append(test_accuracy.item())
      test_lossD.append(test_errD.item())

    # if epoch % 20 == 0:
    #   print('Epoch', epoch)
    #   print('Train Loss:', train_errD.item(), 'Test Loss:', test_errD.item())
    #   print('Train Accuracy:', train_accuracy.item(), 'Test Accuracy:', test_accuracy.item())
    #   plt.figure(figsize=(10,4))
    #   plt.subplot(1, 2, 1)
    #   plt.plot(train_lossD, label='Training Loss')
    #   plt.plot(test_lossD, label='Testing Loss')
    #   plt.legend()
    #   plt.subplot(1, 2, 2)
    #   plt.plot(train_accuracyD, label='Training Accuracy')
    #   plt.plot(test_accuracyD, label='Testing Accuracy')
    #   plt.legend()
    #   plt.show()

  print('Final Test Accuracy:', test_accuracy.item())
  ##return torch.cat((test_output_open_cm, test_output_closed_cm)), torch.cat((test_vals_open_cm, test_vals_closed_cm)), netD
  return test_accuracy.item()


Mounted at /content/gdrive
hello GPU


In [None]:
open = np.load("/content/gdrive/My Drive/Research_Paper/Training/normalized-training-open-64ch.npy")
closed = np.load("/content/gdrive/My Drive/Research_Paper/Training/normalized-training-closed-64ch.npy")

test_100_accuracies = []

for i in range(100):
  accuracy = train(open, closed)
  test_100_accuracies.append(accuracy)

print("Avg test accuracy: ", np.mean(test_100_accuracies))



Final Test Accuracy: 0.6510416865348816
Final Test Accuracy: 0.6579861044883728
Final Test Accuracy: 0.6571180820465088
Final Test Accuracy: 0.65625
Final Test Accuracy: 0.6215277910232544
Final Test Accuracy: 0.6684027910232544
Final Test Accuracy: 0.6605902910232544
Final Test Accuracy: 0.6380208134651184
Final Test Accuracy: 0.6267361044883728
Final Test Accuracy: 0.6449652910232544
Final Test Accuracy: 0.6354166865348816
Final Test Accuracy: 0.6640625
Final Test Accuracy: 0.6484375
Final Test Accuracy: 0.6362847089767456
Final Test Accuracy: 0.65625
Final Test Accuracy: 0.6701388955116272
Final Test Accuracy: 0.6215277910232544
Final Test Accuracy: 0.663194477558136
Final Test Accuracy: 0.647569477558136
Final Test Accuracy: 0.6501736044883728
Final Test Accuracy: 0.6605902910232544
Final Test Accuracy: 0.6284722089767456
Final Test Accuracy: 0.6414930820465088
Final Test Accuracy: 0.6579861044883728
Final Test Accuracy: 0.6449652910232544
Final Test Accuracy: 0.6449652910232544
Fi

In [3]:
open = np.load("/content/gdrive/My Drive/Research_Paper/Training/normalized-training-open-64ch.npy")
closed = np.load("/content/gdrive/My Drive/Research_Paper/Training/normalized-training-closed-64ch.npy")
open_generated1 = np.load("/content/gdrive/My Drive/Research_Paper/generated-data/generated-open-1.npy")
closed_generated1 = np.load("/content/gdrive/My Drive/Research_Paper/generated-data/generated-closed-1.npy")
open_generated2 = np.load("/content/gdrive/My Drive/Research_Paper/generated-data/generated-open-2.npy")
closed_generated2 = np.load("/content/gdrive/My Drive/Research_Paper/generated-data/generated-closed-2.npy")

open = np.concatenate((open, open_generated1))
closed = np.concatenate((closed, closed_generated1))
open = np.concatenate((open, open_generated2))
closed = np.concatenate((closed, closed_generated2))
test_100_accuracies = []

for i in range(100):
  accuracy = train(open, closed)
  test_100_accuracies.append(accuracy)

print("Avg test accuracy: ", np.mean(test_100_accuracies))



Final Test Accuracy: 0.8307783007621765
Final Test Accuracy: 0.8307783007621765
Final Test Accuracy: 0.8605542778968811
Final Test Accuracy: 0.8346108794212341
Final Test Accuracy: 0.8337264060974121
Final Test Accuracy: 0.8354952931404114
Final Test Accuracy: 0.8555424809455872
Final Test Accuracy: 0.8428655862808228
Final Test Accuracy: 0.829009473323822
Final Test Accuracy: 0.8307783007621765
Final Test Accuracy: 0.8337264060974121
Final Test Accuracy: 0.8325471878051758
Final Test Accuracy: 0.838738203048706
Final Test Accuracy: 0.8534787893295288
Final Test Accuracy: 0.8266509771347046
Final Test Accuracy: 0.8419811725616455
Final Test Accuracy: 0.8584905862808228
Final Test Accuracy: 0.8357900977134705
Final Test Accuracy: 0.8275353908538818
Final Test Accuracy: 0.8293042778968811
Final Test Accuracy: 0.839327871799469
Final Test Accuracy: 0.8295990824699402
Final Test Accuracy: 0.8354952931404114
Final Test Accuracy: 0.8416863083839417
Final Test Accuracy: 0.8605542778968811
Fin

In [2]:
open = np.load("/content/gdrive/My Drive/Research_Paper/Training/normalized-training-open-64ch.npy")
closed = np.load("/content/gdrive/My Drive/Research_Paper/Training/normalized-training-closed-64ch.npy")

test_100_accuracies = []

for i in range(100):
  accuracy = train(open, closed, True)
  test_100_accuracies.append(accuracy)

print("Avg test accuracy: ", np.mean(test_100_accuracies))



Final Test Accuracy: 0.6371527910232544
Final Test Accuracy: 0.6440972089767456
Final Test Accuracy: 0.514756977558136
Final Test Accuracy: 0.6493055820465088
Final Test Accuracy: 0.6388888955116272
Final Test Accuracy: 0.6032986044883728
Final Test Accuracy: 0.6597222089767456
Final Test Accuracy: 0.6032986044883728
Final Test Accuracy: 0.65625
Final Test Accuracy: 0.6432291865348816
Final Test Accuracy: 0.6605902910232544
Final Test Accuracy: 0.6597222089767456
Final Test Accuracy: 0.6336805820465088
Final Test Accuracy: 0.6371527910232544
Final Test Accuracy: 0.6423611044883728
Final Test Accuracy: 0.6597222089767456
Final Test Accuracy: 0.6215277910232544
Final Test Accuracy: 0.6380208134651184
Final Test Accuracy: 0.6545138955116272
Final Test Accuracy: 0.6432291865348816
Final Test Accuracy: 0.6510416865348816
Final Test Accuracy: 0.6736111044883728
Final Test Accuracy: 0.6293402910232544
Final Test Accuracy: 0.6467013955116272
Final Test Accuracy: 0.6432291865348816
Final Test A

In [1]:
from scipy import stats

real = [0.6510416865348816, 0.6579861044883728, 0.6571180820465088, 0.65625, 0.6215277910232544, 0.6684027910232544, 0.6605902910232544, 0.6380208134651184, 0.6267361044883728, 0.6449652910232544, 0.6354166865348816, 0.6640625, 0.6484375, 0.6362847089767456, 0.65625, 0.6701388955116272, 0.6215277910232544, 0.663194477558136, 0.647569477558136, 0.6501736044883728, 0.6605902910232544, 0.6284722089767456, 0.6414930820465088, 0.6579861044883728, 0.6449652910232544, 0.6449652910232544, 0.6440972089767456, 0.5998263955116272, 0.6571180820465088, 0.6510416865348816, 0.6510416865348816, 0.6371527910232544, 0.647569477558136, 0.6223958134651184, 0.6388888955116272, 0.6258680820465088, 0.6145833134651184, 0.624131977558136, 0.6432291865348816, 0.655381977558136, 0.6493055820465088, 0.6293402910232544, 0.5703125, 0.5390625, 0.6493055820465088, 0.6467013955116272, 0.655381977558136, 0.6493055820465088, 0.6336805820465088, 0.6328125, 0.6458333134651184, 0.6701388955116272, 0.640625, 0.6579861044883728, 0.6232638955116272, 0.609375, 0.6345486044883728, 0.6284722089767456, 0.5711805820465088, 0.6588541865348816, 0.6684027910232544, 0.6588541865348816, 0.6501736044883728, 0.6432291865348816, 0.65625, 0.6657986044883728, 0.6258680820465088, 0.65625, 0.6032986044883728, 0.639756977558136, 0.6493055820465088, 0.6692708134651184, 0.6597222089767456, 0.6336805820465088, 0.6432291865348816, 0.6467013955116272, 0.6501736044883728, 0.5885416865348816, 0.6440972089767456, 0.6545138955116272, 0.6414930820465088, 0.592881977558136, 0.6336805820465088, 0.6423611044883728, 0.6458333134651184, 0.6614583134651184, 0.6519097089767456, 0.6232638955116272, 0.640625, 0.6458333134651184, 0.6440972089767456, 0.6302083134651184, 0.6362847089767456, 0.6432291865348816, 0.639756977558136, 0.6458333134651184, 0.6345486044883728, 0.6536458134651184, 0.6293402910232544, 0.631944477558136]
real_generated = [0.8307783007621765, 0.8307783007621765, 0.8605542778968811, 0.8346108794212341, 0.8337264060974121, 0.8354952931404114, 0.8555424809455872, 0.8428655862808228, 0.829009473323822, 0.8307783007621765, 0.8337264060974121, 0.8325471878051758, 0.838738203048706, 0.8534787893295288, 0.8266509771347046, 0.8419811725616455, 0.8584905862808228, 0.8357900977134705, 0.8275353908538818, 0.8293042778968811, 0.839327871799469, 0.8295990824699402, 0.8354952931404114, 0.8416863083839417, 0.8605542778968811, 0.8307783007621765, 0.8452240824699402, 0.8405070900917053, 0.8461084961891174, 0.8434551954269409, 0.8349056839942932, 0.8354952931404114, 0.8605542778968811, 0.8316627740859985, 0.8328419923782349, 0.8360849022865295, 0.8363797068595886, 0.8378537893295288, 0.8278301954269409, 0.834316074848175, 0.833136796951294, 0.8316627740859985, 0.8304834961891174, 0.8723466992378235, 0.8522995114326477, 0.8449292778968811, 0.8337264060974121, 0.8340212106704712, 0.8328419923782349, 0.8328419923782349, 0.8434551954269409, 0.8410966992378235, 0.8301886916160583, 0.8531839847564697, 0.8316627740859985, 0.833431601524353, 0.823113203048706, 0.8301886916160583, 0.8316627740859985, 0.8360849022865295, 0.8528891801834106, 0.8537735939025879, 0.8502358794212341, 0.833431601524353, 0.8434551954269409, 0.8328419923782349, 0.8390330076217651, 0.8390330076217651, 0.8319575786590576, 0.8390330076217651, 0.828125, 0.8337264060974121, 0.833136796951294, 0.8363797068595886, 0.8452240824699402, 0.8375589847564697, 0.8416863083839417, 0.84375, 0.829009473323822, 0.8611438870429993, 0.8307783007621765, 0.8328419923782349, 0.8307783007621765, 0.8378537893295288, 0.8310731053352356, 0.8399174809455872, 0.8455188870429993, 0.8337264060974121, 0.8390330076217651, 0.8381485939025879, 0.8431603908538818, 0.833431601524353, 0.8363797068595886, 0.8319575786590576, 0.8313679099082947, 0.8284198045730591, 0.8629127740859985, 0.8328419923782349, 0.8325471878051758, 0.8354952931404114]
real_real = [0.8130896091461182, 0.8157429099082947, 0.8136792778968811, 0.8139740824699402, 0.8083726763725281, 0.8116155862808228, 0.797759473323822, 0.8328419923782349, 0.7385023832321167, 0.7535377740859985, 0.8181014060974121, 0.8284198045730591, 0.7573702931404114, 0.8039504885673523, 0.822818398475647, 0.813384473323822, 0.8027712106704712, 0.817806601524353, 0.8027712106704712, 0.8310731053352356, 0.8151533007621765, 0.8051297068595886, 0.8293042778968811, 0.7735849022865295, 0.8254716992378235, 0.8222287893295288, 0.8248820900917053, 0.5386202931404114, 0.8402122855186462, 0.813384473323822, 0.803066074848175, 0.739681601524353, 0.8219339847564697, 0.8104363083839417, 0.8095518946647644, 0.823113203048706, 0.8269457817077637, 0.7556014060974121, 0.8275353908538818, 0.8110259771347046, 0.8493514060974121, 0.8113207817077637, 0.822818398475647, 0.8163325786590576, 0.698113203048706, 0.801886796951294, 0.8057193756103516, 0.8015919923782349, 0.8481721878051758, 0.828125, 0.7402712106704712, 0.7582547068595886, 0.7614976763725281, 0.7511792778968811, 0.7455778121948242, 0.8057193756103516, 0.8092570900917053, 0.8104363083839417, 0.8337264060974121, 0.8045400977134705, 0.8145636916160583, 0.8045400977134705, 0.8169221878051758, 0.823113203048706, 0.8095518946647644, 0.822818398475647, 0.807193398475647, 0.7986438870429993, 0.8310731053352356, 0.78125, 0.8033608794212341, 0.8216391801834106, 0.8319575786590576, 0.8352004885673523, 0.8183962106704712, 0.8125, 0.8172169923782349, 0.8107311725616455, 0.7956957817077637, 0.8119103908538818, 0.8104363083839417, 0.8225235939025879, 0.8119103908538818, 0.8166273832321167, 0.8313679099082947, 0.8275353908538818, 0.8248820900917053, 0.8360849022865295, 0.8172169923782349, 0.829009473323822, 0.8045400977134705, 0.8139740824699402, 0.8066037893295288, 0.8095518946647644, 0.7594339847564697, 0.8057193756103516, 0.8242924809455872, 0.8101415038108826, 0.8172169923782349, 0.8272405862808228]

# real vs generated/real
t_statistic, p_value = stats.ttest_rel(real, real_generated)
print(p_value)

# real vs realx2
t_statistic, p_value = stats.ttest_rel(real_real, real_generated)
print(p_value)



9.997838508142085e-94
1.924721611940723e-13
