# In [None]:
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import torch.nn as nn
import torchquantum as tq
import random
from torch.optim.lr_scheduler import CosineAnnealingLR

import pandas as pd
from collections import OrderedDict
from torchquantum.encoding import encoder_op_list_name_dict
from torchquantum.layers import U3CU3Layer0, RandomLayer
import os 

import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import datasets, transforms

from torch.distributions.bernoulli import Bernoulli
from torchquantum.encoding import encoder_op_list_name_dict as enc_dict
from torchquantum.layers import U3CU3Layer0 
from models import Dataset
from torch.utils.data import DataLoader

# Expose only physical GPU 2 to this process; it is addressed as cuda:0 below.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# Seed the Python, NumPy and PyTorch RNGs so shuffling and init are repeatable.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
num_client = 10

# Number of classes; QNN also reads out one qubit per class.
num_class = 4
EPOCH = 100
# Shuffled 256-sample batches over the first 18623 training examples.
# NOTE(review): the third Dataset argument (0) is passed through unchanged;
# its meaning is defined in models.Dataset, which is not visible here.
TrainLoader = DataLoader(Dataset(np.load('data/mnist/train_x.npy')[:18623],
                                 np.load('data/mnist/train_y.npy')[:18623], 0),
                         batch_size=256, shuffle=True)

# Keep the first three shuffled batches as the train / test / valid pools.
Data = []
Labels = []
for i, (data, labels) in enumerate(TrainLoader):
    Data.append(data)
    Labels.append(labels)
    if i == 2:
        break
if len(Data) < 3:
    raise RuntimeError(
        f"expected at least 3 batches for the train/test/valid pools, got {len(Data)}")

TrainLoader = DataLoader(Dataset(Data[0], Labels[0], 0), batch_size=32, shuffle=True)
TestLoader  = DataLoader(Dataset(Data[1], Labels[1], 0), batch_size=32, shuffle=True)
# 32 single-sample batches, used for the per-parameter gradient statistics.
ValidLoader = DataLoader(Dataset(Data[2][:32], Labels[2][:32], 0), batch_size=1, shuffle=True)
# Fall back to the CPU when no GPU is visible (e.g. a machine without GPU 2)
# instead of failing on the first .to(device).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
criterion = nn.CrossEntropyLoss()
q_device = tq.QuantumDevice(n_wires=4).to(device)

class QNN(tq.QuantumModule):
    """Four-wire quantum classifier.

    Encodes each input with torchquantum's ``'4x4_ryzxy'`` encoder recipe,
    applies a randomly generated parameterized circuit, and returns the
    PauliZ expectation value of wires ``0 .. num_class - 1``.
    """

    def __init__(self):
        super().__init__()
        self.n_wires = 4
        # NOTE(review): the '4x4' recipe presumably expects 16 features per
        # sample — confirm against torchquantum's encoder table.
        self.encoder = tq.GeneralEncoder(enc_dict['4x4_ryzxy'])
        # 50 random gates over every wire; these hold the trainable
        # parameters that train() inspects for barren plateaus.
        self.pqc = tq.RandomLayer(n_ops=50, wires=list(range(self.n_wires)))

    def forward(self, x, q_device=q_device):
        """Encode ``x``, run the circuit, and measure PauliZ per class.

        ``x`` is flattened to one row per sample and cast to complex64
        before encoding. The result is passed through ``.squeeze()``, which
        removes every size-1 dimension — including the batch dimension when
        the batch holds a single sample.
        """
        flat = x.reshape(x.shape[0], -1).to(dtype=torch.complex64)
        self.encoder(q_device, flat)
        self.pqc(q_device)
        readout_wires = list(range(num_class))
        observables = [tq.PauliZ() for _ in readout_wires]
        return tq.expval(q_device, readout_wires, observables).squeeze()
    
# Model and optimizer for the experiment.
model = QNN().to(device)
opt   = torch.optim.Adam(model.parameters(), lr=5e-3)



from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# Timestamp (month-day-hour-minute-second); pass f'runs/{today}' to
# SummaryWriter instead of the fixed directory for one log dir per run.
today = datetime.today().strftime("%m%d%H%M%S")
# All runs log into the same fixed directory.
writer = SummaryWriter('ICSE_2023')
def train(ep,
          train_loader,
          test_loader,
          valid_loader, 
          model, 
          device, 
          criterion,
          optimizer):
    
    Train_Loss = 0 
    Test_Acc   = 0
    # Train #
    for niter, (data, labels) in enumerate(train_loader):
        inputs  = data.to(device,dtype=torch.float32)
        targets = labels.to(device,dtype=torch.long)

        outputs = model(inputs)
        loss    = criterion(torch.softmax(outputs,dim=-1), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        Train_Loss = loss.item()
        
    # Test #
    with torch.no_grad():
        Size = 0
        Corrects = 0
        for _, (x,y) in enumerate(test_loader):
            x = x.to(device,dtype=torch.float32)
            y = y.to(device,dtype=torch.long)
            y_hat = model(x) 
            _, indices = y_hat.topk(1, dim=1)
            masks = indices.eq(y.view(-1, 1).expand_as(indices))
            Size += y.shape[0]
            Corrects += masks.sum().item()
        Test_Acc = Corrects / Size
            
    # Barren Plateaus # 
    grad_bp,mean, var = {},{},{}
    
    for i,(name, params) in enumerate(model.pqc.named_parameters()):
        grad_bp[name] = []
    
    for niter, (data, labels) in enumerate(valid_loader):
        inputs  = data.to(device,dtype=torch.float32)
        targets = labels.to(device,dtype=torch.long)
        outputs = model(inputs)
        loss    = criterion(torch.softmax(outputs,dim=-1).unsqueeze(0), targets)
        optimizer.zero_grad()
        loss.backward()
        for i, (name, params) in enumerate(model.pqc.named_parameters()):
            grad_bp[name].append(params.grad.clone().detach().cpu().numpy())
        optimizer.zero_grad()
        
        if niter==31:
            break

    
    for key in grad_bp.keys():
        grads     = grad_bp[key]
        grads     = np.array(grads)
        mean[key] = np.mean(grads)
        var[key]  = np.var(grads)
    
    return Train_Loss, Test_Acc, var

def _ordinal(n):
    """Return ``n`` with its English ordinal suffix (1 -> '1st', 12 -> '12th')."""
    if 10 <= n % 100 <= 20:
        suffix = 'th'
    else:
        suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
    return f"{n}{suffix}"


def Helper(ep, var, threshold=1e-5):
    """Build a text summary of parameters whose gradient variance is tiny.

    Args:
        ep: epoch number shown in each message.
        var: mapping from parameter name to gradient variance. Names are
            parsed as ``<prefix>.<index>.<gate>_params``: the gate label is
            the third dot-separated field with everything from ``_params``
            onward removed, and the index is reported 1-based.
            NOTE(review): this naming scheme is assumed for
            ``model.pqc.named_parameters()`` — verify against the installed
            torchquantum version.
        threshold: variance at or below which a parameter is flagged
            (default 1e-5, the previously hard-coded value). NaN is never
            flagged.

    Returns:
        One message per flagged parameter joined with ``'<br>'``, or an
        empty string when nothing is flagged.
    """
    events = []
    for key, bp_value in var.items():
        if not bp_value <= threshold:
            continue
        parts = key.split('.')
        gate = parts[2].split('_params')[0]
        order = int(parts[1]) + 1
        events.append(
            f"[Epoch {ep}] {_ordinal(order)} params ({gate} Gate) "
            f"has barren plateaus (BP value: {bp_value})"
        )
    return '<br>'.join(events)

# Run every epoch and log loss, accuracy, per-parameter gradient variance,
# and a text summary of flagged parameters. Logged steps are 1-based.
for ep in range(EPOCH):
    step = ep + 1
    Train_Loss, Test_Acc, var = train(ep,
                                      TrainLoader,
                                      TestLoader,
                                      ValidLoader,
                                      model,
                                      device,
                                      criterion,
                                      opt)

    writer.add_scalars('Metric/Loss', {'loss': Train_Loss}, step)
    writer.add_scalars('Metric/Accuracy', {'acc': Test_Acc}, step)
    writer.add_scalars('Metric/BarrenPlateaus', var, step)
    writer.add_text('Event', Helper(step, var), step)

# SummaryWriter buffers events; flush so the final epochs are written to disk
# even if the kernel stops before the writer is closed.
writer.flush()