In [2]:
import sys
# True when the `google.colab` module is already loaded in this interpreter,
# i.e. the notebook is being executed inside Google Colab.
IN_COLAB = 'google.colab' in sys.modules
# Disabled Colab bootstrap: on Colab the project repository would be cloned here.
# It is commented out, so the QViT_HEP_ML4Sci package must already be importable.
#if IN_COLAB:
#!git clone https://github.com/EyupBunlu/QViT_HEP_ML4Sci

Installations

In [3]:
!pip install tensorcircuit
!pip install pennylane



In [4]:
# extra installs due to local git clone
!pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124
!pip install jax
!pip install optax

Looking in indexes: https://download.pytorch.org/whl/nightly/cu124


In [5]:
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import math
from tqdm.notebook import tqdm
import time
from torchvision.datasets import MNIST
from torchvision import transforms
import h5py
#if IN_COLAB: from QViT_HEP_ML4Sci.QViT import *
from QViT_HEP_ML4Sci.QViT import *
#from sklearn.metrics import roc_auc_score,roc_curve,confusion_matrix

# New floating-point tensors are created as float32 unless a dtype is given explicitly.
torch.set_default_dtype(torch.float32)
# Limit PyTorch to 8 threads for CPU operations (hard-coded; adjust to the host machine).
torch.set_num_threads(8)

Please first ``pip install -U qiskit`` to enable related functionality in translation module
Please first ``pip install -U cirq`` to enable related functionality in translation module


In [6]:
# Device on which the models are placed and to which train() moves each batch.
# This notebook is the CPU variant: for CUDA, use the separate notebook whose
# batch sizes are configured for the GPU.
device = 'cpu' # FOR CUDA VERSION USE SEPARATE FILE WITH CONFIGURED BATCH SIZES

Models

In [7]:
# Download (if needed) the MNIST training split; the raw images live in `.data`.
mnist_trainset = MNIST(root='./data', train=True, download=True)
n,d= 7,28  # NOTE(review): not referenced by any visible cell; kept in case other code reads them

# Down-sample to 14x14 and convert to float64. Normalize(0,1) subtracts 0 and divides
# by 1, so it leaves the values unchanged.
# NOTE(review): the data is float64 while the default dtype is float32 - confirm that
# patcher / simple_dataset / the model handle the cast.
transform = transforms.Compose([ transforms.Resize((14,14)), transforms.ConvertImageDtype(torch.float64),transforms.Normalize(0,1)])
data = mnist_trainset.data  # size = (60000, 28, 28)
data = transform(data)      # size = (60000, 14, 14)
print(data.device)
# Project helper from QViT; presumably cuts each image into [2,14] patches - TODO confirm.
data_patched = patcher(data,[2,14])

# Wrap patches and labels in the project's dataset class (rebinds `mnist_trainset`).
mnist_trainset = simple_dataset(data_patched,mnist_trainset.targets)
tr_len = 4000
val_len = 1000
# random_split requires the lengths to sum to len(dataset), so the remainder is the test set.
test_len = len(mnist_trainset)-tr_len-val_len
# Fixed seed: without it the split differs on every kernel restart, so saved histories
# and checkpoints are not comparable, and a model restored from disk could be validated
# on samples it was trained on.
SPLIT_SEED = 0
tr_set,val_set,test_set = torch.utils.data.random_split(
    mnist_trainset,[tr_len,val_len,test_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED))
tr_dl = DataLoader(tr_set,batch_size=32,shuffle=True)
# Validation metrics are sums over the whole set, so shuffling it adds nothing.
val_dl = DataLoader(val_set,batch_size=32,shuffle=False)

cpu


In [8]:
# Input geometry, read off the patched data: the second-to-last axis gives
# 'Token_Dim' and the last axis gives 'Image_Dim'.
transformer_dims = dict(
    Token_Dim=data_patched.shape[-2],
    Image_Dim=data_patched.shape[-1],
)
# Architecture hyper-parameters forwarded to HViT as keyword arguments.
transformer_hyper = dict(
    n_layers=2,
    FC_layers=[10],
    head_dimension=8,
    Embed_Dim=16,
    ff_dim=32,
)
# Classification / positional-embedding options forwarded to HViT.
transformer_type = dict(classifying_type='max', pos_embedding=True)

Initialization

In [9]:
# Report whether a CUDA device is visible. torch.cuda.current_device() raises when
# CUDA is unavailable, so the device name is only queried when it exists; this
# notebook itself runs on `device = 'cpu'` either way.
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    print(f"PyTorch is using: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("PyTorch is using: CPU (no CUDA device visible)")

True
PyTorch is using: NVIDIA GeForce RTX 4070 Ti SUPER


In [10]:
# Build both models from the same configuration; only the attention type differs.
def _build_hvit(attention_type):
    """Construct an HViT with the shared config dicts and move it to `device`."""
    net = HViT(**transformer_dims,**transformer_hyper,**transformer_type,attention_type=attention_type)
    return net.to(device)

classical_model = _build_hvit('classic')
hybrid2_model = _build_hvit('hybrid2')

In [11]:
# Start from scratch, or restore a saved history and model weights from disk.
def _reset_or_restore(model, reset, history_path, state_path):
    """Return None when `reset` is true; otherwise load the saved history,
    load the saved state dict into `model`, and return the history."""
    if reset:
        return None
    saved_history = torch.load(history_path)
    model.load_state_dict(torch.load(state_path))
    return saved_history

reset_classical=True
reset_hybrid2=True

classical_history = _reset_or_restore(
    classical_model, reset_classical, 'classical history', 'classical state dict')
hybrid2_history = _reset_or_restore(
    hybrid2_model, reset_hybrid2, 'hybrid2 history', 'hybrid2 state dict')

Training

In [12]:
# One Adam optimiser per model, both with learning rate 1e-3.
classical_optim, hybrid2_optim = [
    torch.optim.Adam(net.parameters(), lr=1e-3)
    for net in (classical_model, hybrid2_model)
]
# Number of passes over the training loader.
n_epochs = 80
# reduction='none' keeps one loss value per sample; train() sums them itself.
loss_fn = nn.CrossEntropyLoss(reduction='none')

In [13]:
# Local train() (the call sites in this notebook use this definition).
def _hard_predictions(yhat, y):
    """Convert raw model outputs into hard predictions paired with their targets.

    A 1-D output, or one whose last axis has size 1, is treated as a binary logit and
    thresholded at sigmoid(yhat) > 0.5; anything else is arg-maxed over the last axis.
    Both returned tensors are on the CPU and flattened to 1-D so that they compare
    element-wise (comparing a (B,) prediction with a (B, 1) target would broadcast
    to (B, B) and miscount).
    """
    yhat = yhat.detach()
    if len(yhat.shape)==1 or yhat.shape[-1]==1:
        pred = (torch.sigmoid(yhat)>.5).cpu()
    else:
        pred = yhat.argmax(axis=-1).cpu()
    return pred.reshape(-1), y.detach().cpu().reshape(-1)


def train(model,tr_dl,val_dl,loss_fn,optim,n_epochs,device='cuda'):
    """Train `model` for `n_epochs` epochs, validating after every epoch.

    Args:
        model: module called as model(batch['input']); moved by the caller.
        tr_dl: iterable of training batches, each a dict with 'input' and 'output'.
        val_dl: iterable of validation batches of the same form; must expose
            `.dataset` with a length (used to average the validation metrics).
        loss_fn: loss called as loss_fn(yhat, y). Its result is summed and divided by
            the number of samples, so it is expected to return per-sample losses
            (e.g. reduction='none').
        optim: optimiser over the model's parameters.
        n_epochs: number of epochs to run.
        device: device each batch is moved to before the forward pass.

    Returns:
        dict with per-epoch lists 'tr' and 'val' (mean loss) and 'tr_acc' and
        'val_acc' (accuracy). On KeyboardInterrupt the epochs completed so far
        are returned instead of raising.

    Side effects (files written to the working directory):
        'best_state_on_training_loss': weights of the epoch with the lowest
            validation loss so far.
        'best_state_on_training_acc': weights of the latest epoch whose validation
            accuracy equals the best seen so far.
        'temp_history': the history dict, rewritten after every epoch.
    """
    # Created before the try block so the interrupt handler always has it.
    history = {'tr':[],'val':[],'tr_acc':[],'val_acc':[]}
    min_loss = np.inf
    try:
        bar_epoch = tqdm(range(n_epochs))
        for epoch in bar_epoch:
            loss = 0
            val_loss = 0
            total_samples = 0
            pred_tr, real_tr = [], []
            pred_val, real_val = [], []

            bar_batch = tqdm(tr_dl)
            model.train()
            for batch in bar_batch:
                optim.zero_grad()
                yhat = model(batch['input'].to(device))
                y = batch['output']
                loss_ = loss_fn(yhat,y.to(device))

                loss_.sum().backward()
                optim.step()

                loss += loss_.sum().item()
                total_samples += y.shape[0]
                pred, real = _hard_predictions(yhat, y)
                pred_tr.append(pred)
                real_tr.append(real)

                bar_batch.set_postfix_str(f'loss:{loss/total_samples}')

            model.eval()
            with torch.no_grad():
                for batch in val_dl:
                    yhat = model(batch['input'].to(device))
                    y = batch['output']
                    val_loss += loss_fn(yhat,y.to(device)).sum().item()
                    pred, real = _hard_predictions(yhat, y)
                    pred_val.append(pred)
                    real_val.append(real)

            n_val = len(val_dl.dataset)
            history['tr_acc'].append((torch.cat(pred_tr)==torch.cat(real_tr)).sum()/total_samples )
            history['val_acc'].append((torch.cat(pred_val)==torch.cat(real_val)).sum()/n_val )
            history['val'].append(val_loss/n_val)
            history['tr'].append(loss/total_samples)
            bar_epoch.set_postfix_str(
                f'loss:{loss/total_samples}, v.loss:{val_loss/n_val},'
                f'            tr_acc:{history["tr_acc"][-1] }, val_acc:{ history["val_acc"][-1] }')

            # Checkpoint on lowest validation loss.
            if history['val'][-1]<min_loss:
                min_loss = history['val'][-1]
                torch.save(model.state_dict(),'best_state_on_training_loss')
            # Checkpoint on best validation accuracy. This must not touch `min_loss`:
            # doing so would let a later, worse epoch overwrite the best-loss checkpoint.
            if history['val_acc'][-1]==max(history['val_acc']):
                torch.save(model.state_dict(),'best_state_on_training_acc')
            torch.save(history,'temp_history')
        return history
    except KeyboardInterrupt:
        # Manual stop: hand back whatever has been recorded so far.
        return history

In [14]:
# classical training
# The train() defined in this notebook takes no `history` / `save_path` arguments:
# it always starts a fresh history, so the artifacts are written explicitly below.
classical_history = train(classical_model,tr_dl,val_dl,loss_fn,classical_optim,n_epochs,device=device)
torch.save(classical_history, "classical history")
# The restore branch (reset_classical = False) loads 'classical state dict', but train()
# only writes its own checkpoint files, so save the final weights under that name too.
torch.save(classical_model.state_dict(), "classical state dict")

  0%|          | 0/80 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

In [15]:
# hybrid2 training
# The train() defined in this notebook takes no `history` / `save_path` arguments:
# it always starts a fresh history, so the artifacts are written explicitly below.
hybrid2_history = train(hybrid2_model,tr_dl,val_dl,loss_fn,hybrid2_optim,n_epochs, device=device)
torch.save(hybrid2_history, 'hybrid2 history')
# The restore branch (reset_hybrid2 = False) loads 'hybrid2 state dict', but train()
# only writes its own checkpoint files, so save the final weights under that name too.
torch.save(hybrid2_model.state_dict(), 'hybrid2 state dict')

  0%|          | 0/80 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

Plotting

In [1]:
# Requires the import cell and the classical training cell to have run in this kernel.
#model(data_patched[[0]].to(device))
#print(f'# of parameters: {sum([np.prod(i.shape) for i in model.parameters()])}')
#model.load_state_dict(torch.load('best_state_on_training_acc'))

# Per-epoch loss curves for the classical model.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(classical_history['val'],label='val_loss')
loss_ax.plot(classical_history['tr'],label='tr_loss')
loss_ax.set_title('Classical Model Loss')
loss_ax.set_xlabel('epoch')
loss_ax.set_ylabel('loss')
loss_ax.legend()

# Per-epoch accuracy curves for the classical model.
acc_fig, acc_ax = plt.subplots()
acc_ax.plot(classical_history['val_acc'],label='val_acc')
acc_ax.plot(classical_history['tr_acc'],label='tr_acc')
acc_ax.set_title('Classical Model Accuracy')
acc_ax.set_xlabel('epoch')
acc_ax.set_ylabel('accuracy')
acc_ax.legend()
plt.show()

# Disabled test-set evaluation, kept for reference. It was previously parked in a bare
# triple-quoted string, which Jupyter echoed as the cell's output. Re-enabling it needs
# the (currently commented-out) sklearn.metrics import and a `model` variable.
#pred = model(test_set.dataset.data[test_set.indices[:1000]].to(device)).cpu().argmax(axis=1)
#plt.figure()
#plt.imshow(confusion_matrix(pred.cpu(),test_set.dataset.target[test_set.indices[:1000]],normalize='true'),
#          vmin=0,vmax=1)
#plt.colorbar()
#plt.title('The normalized confusion matrix')
#print(f'Wrongly Predicted Ratio:{ (pred != test_set.dataset.target[test_set.indices[:1000]]).sum()/pred.shape[0]}')

NameError: name 'plt' is not defined

In [None]:
# Per-epoch loss curves for the hybrid2 model, drawn on the current axes.
loss_ax = plt.gca()
for key, label in (('val', 'val_loss'), ('tr', 'tr_loss')):
    loss_ax.plot(hybrid2_history[key], label=label)
loss_ax.set_title('Hybrid2 Model Loss')
loss_ax.set_xlabel('epoch')
loss_ax.set_ylabel('loss')
loss_ax.legend()

# Per-epoch accuracy curves on a fresh figure.
acc_ax = plt.figure().gca()
for key in ('val_acc', 'tr_acc'):
    acc_ax.plot(hybrid2_history[key], label=key)
acc_ax.set_title('Hybrid2 Model Accuracy')
acc_ax.set_xlabel('epoch')
acc_ax.set_ylabel('accuracy')
acc_ax.legend()

Proof of Concept

In [None]:
# Display element 0 of the transformed (resized, float) image tensor `data`.
data[0]

In [None]:
# Render element 0 of `data` as a grayscale image.
plt.imshow(data[0], cmap='gray')

In [None]:
# Render element 2 of `data` as a grayscale image.
plt.imshow(data[2], cmap='gray')

In [None]:
# Softmax over the last axis (the class scores). The axis is given explicitly because
# nn.Softmax() without `dim` is deprecated and infers the axis from the input rank,
# choosing axis 0 for 3-D input, which would normalise across the wrong dimension.
softmax = torch.nn.Softmax(dim=-1)

In [None]:
def predict(model, data_point):
    """Return class probabilities for `data_point`.

    Runs `model(data_point)` without recording gradients and applies a softmax over
    the last axis of the output, so the values along that axis sum to 1.

    Args:
        model: callable producing raw scores; assumed to put classes on the last
            axis - TODO confirm for HViT.
        data_point: input passed to `model` unchanged.

    Returns:
        Tensor with the same shape as the model output, detached from autograd.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(data_point)
    # Explicit axis: the implicit-dim softmax picks axis 0 for 1-D and 3-D inputs.
    pred_probs = torch.softmax(logits, dim=-1)
    return pred_probs

In [None]:
# Run predict() with the classical model on element 0 of the patched data.
predict(classical_model, data_patched[0])

In [None]:
# The same computation written inline: softmax of the classical model's output for element 0.
softmax(classical_model(data_patched[0]))

In [None]:
# hybrid2 model on the same element; squeeze() drops size-1 axes before the softmax.
softmax(hybrid2_model(data_patched[0]).squeeze())