In [1]:
import numpy as np
import torch as th
import matplotlib.pylab as plt
from optical_nn import *
import complex_torch_var as ct
from mnist import *
import os
from time import time
from functools import partial
from glob import glob
from default_params import *

Using CUDA
torch.Size([8, 784])


In [2]:
def load_complex(f=F_COMPLEX_TRAIN):
    print(f)
    net = mnist_complex()
    print(net)
    net.load_state_dict(th.load(f, map_location=DEVICE))
    return net.to(DEVICE)

In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)



In [4]:
def load_cnn(f=F_CNN_TRAIN):
    print(f)
    net = Net()
    print(net)
    net.load_state_dict(th.load(f, map_location=DEVICE))
    return net.to(DEVICE)

In [7]:
def load_grid(f=os.path.join(DIR_TRAINED_MODELS, 'grid_net_cd2.pth'), rand_S=True, report_acc=True):
    if f is None:
        f = F_GRID_TRAIN if rand_S else F_GRID_ORD_TRAIN
    net = mnist_ONN()
    net.load_state_dict(th.load(f, map_location=DEVICE))
    acc, confusion_matrix = get_acc(net)
    print(f'GridNetOrdered loaded from {f} with accuracy {acc}.')
    print(confusion_matrix)
    return net.to(DEVICE)

In [11]:
import numpy as np
import torch
from scipy import interp
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn.metrics import roc_curve, auc
import math

model = load_grid(f=os.path.join(DIR_TRAINED_MODELS, 'grid_net_cd2.pth'))
print(model)
for data, target in mnist_test_loader(train=False, batch_size=100, shuffle=False):
    continue
#data = data.view(-1, 28**2)
data, target = data.to(DEVICE), target.to(DEVICE)
print(data.shape)
Y_pred = model(data)
out = Y_pred.data.max(1, keepdim=True)[1]
print(target)
print(out)
#print(get_acc(model))
Y_pred = torch.exp(Y_pred)
Y_pred.detach()
Y_real = torch.zeros((8, 2), requires_grad = False)
Y_pred = Y_pred.detach().numpy()
print(out.eq(target.data.view_as(out)).sum())
for i in range(8):
    Y_real[i][target[i]] = 1
# Plot linewidth.
Y_real = Y_real.numpy()
lw = 2

# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(2):
    temp = []
    fpr[i], tpr[i], _ = roc_curve(Y_real[:, i], Y_pred[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(Y_real.ravel(), Y_pred.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Compute macro-average ROC curve and ROC area

# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(2)]))

# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(2):
    mean_tpr += interp(all_fpr, fpr[i], tpr[i])

# Finally average it and compute AUC
mean_tpr /= 2

fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr



roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

# Plot micro/macro ROC curves
plt.figure(1)
plt.plot(fpr["micro"], tpr["micro"],
         label='micro-average ROC curve (area = {0:0.2f})'
               ''.format(roc_auc["micro"]),
         color='deeppink', linestyle=':', linewidth=4)

plt.plot(fpr["macro"], tpr["macro"],
         label='macro-average ROC curve (area = {0:0.2f})'
               ''.format(roc_auc["macro"]),
         color='navy', linestyle=':', linewidth=4)

plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic for multi-class')
plt.legend(loc="lower right")
plt.show()

# Plot class ROC curves
plt.figure(2)
colors = cycle(['aqua', 'darkorange', 'cornflowerblue','deeppink','navy','yellow','purple','gray','red','green','brown','blue'])
for i, color in zip(range(2), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=lw,
             label='Class {0} (AUC = {1:0.6f})'
             ''.format(i, roc_auc[i]))

plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.01, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('')
plt.legend(loc="lower right")
plt.savefig('CNN_ROC.png')
plt.show()

GridNetOrdered loaded from E:\nit dgp\2nd sem\Quantum-Optical-ConvNet\trained_models\grid_net_cd2.pth with accuracy 0.747757613658905.
[[3271  133]
 [1020  147]]
NoisySequential(
  (0): Linear(
    (VH): Unitary()
    (S): Diagonal()
    (U): Unitary()
  )
  (1): ModNonlinearity(
    (f): ShiftedSoftplus()
  )
  (2): Linear(
    (VH): Unitary()
    (S): Diagonal()
    (U): Unitary()
  )
  (3): ModNonlinearity(
    (f): ShiftedSoftplus()
  )
  (4): Linear(
    (VH): Unitary()
    (S): Diagonal()
    (U): Unitary()
  )
  (5): ComplexNorm()
  (6): LogSoftmax(dim=1)
)
torch.Size([8, 1, 28, 28])


ValueError: too many values to unpack (expected 2)