In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
import math, time
import matplotlib.pyplot as plt
import copy
import os
import csv

In [2]:
def show_imgs(imgs, l1=4, l2=5, s1=6, s2=6, name=""):
    plt.rcParams['figure.figsize'] = (s1, s2)
    imgs = imgs.cpu().reshape([-1, 28, 28])
    g, ax = plt.subplots(l1, l2)
    for i in range(l1):
        for j in range(l2):
            a = i * l2 + j
            if (a >= imgs.shape[0]):
                break
            ax[i][j].imshow(imgs[a, :, :], cmap='gray')
            ax[i][j].set_xticks([])
            ax[i][j].set_yticks([])
    if name != "":
        plt.savefig(path + str(name) + ".png")
    #plt.show()
    plt.clf()

def show_loss(train_loss_s, test_loss_s):
    plt.rcParams['figure.figsize'] = (10, 10)
    plt.plot(train_loss_s, "o-", label='train loss')
    plt.plot(test_loss_s, "o-", label='test loss')
    plt.ylabel("loss")
    plt.xlabel("epoch")
    plt.legend(loc=1)
    plt.savefig(path + "loss.png")
    plt.clf()

def show_Acc(train_Acc_s, test_Acc_s):
    plt.rcParams['figure.figsize'] = (10, 10)
    plt.plot(train_Acc_s, "o-", label='train Acc')
    plt.plot(test_Acc_s, "o-", label='test Acc')
    plt.ylabel("Acc")
    plt.xlabel("epoch")
    plt.legend(loc=1)
    plt.savefig(path + "Acc.png")
    plt.clf()

In [3]:
class CMPS(nn.Module):
    def __init__ (self, Dmax, n, c, mydevice=torch.device('cpu')):
        super(CMPS,self).__init__()
        self.Dmax = Dmax
        self.n = n
        self.c = c
        self.bond_dims = [self.Dmax for i in range(n - 1)] + [1]
        self.tensors = []
        for i in range(self.n - 1):
            t = torch.randn(self.bond_dims[i - 1], 2, self.bond_dims[i], device=mydevice)
            t = Variable(t, requires_grad = True)
            self.tensors.append(t)
        t = torch.rand(self.bond_dims[self.n - 2], 2, self.bond_dims[self.n - 1], c, device=mydevice)
        t = Variable(t, requires_grad = True)
        self.tensors.append(t)
        self.normalize()
    
    def getNorm(self):
        result = torch.tensordot(self.tensors[0], self.tensors[0], dims=([1], [1]))
        for i in range(1,self.n - 1):
            result = torch.einsum("niol,ijk,ljm->nkom", result, self.tensors[i],self.tensors[i])
        result = torch.einsum("niol,ijkp,ljmp->nkomp", result, self.tensors[self.n - 1], self.tensors[self.n - 1])
        return result
    
    
    def normalize(self):
        result = torch.tensordot(self.tensors[0], self.tensors[0], dims=([1], [1]))
        res_max = result.max()
        tensor = self.tensors[0].detach()
        tensor = tensor / torch.sqrt(res_max)
        self.tensors[0] = Variable(tensor, requires_grad = True)
        result /= res_max
        for i in range(1,self.n - 1):
            result = torch.einsum("niol,ijk,ljm->nkom", result, self.tensors[i],self.tensors[i])
            res_max = result.max()
            tensor = self.tensors[i].detach()
            tensor = tensor / torch.sqrt(res_max)
            self.tensors[i] = Variable(tensor, requires_grad = True)
            result /= res_max           
        result = torch.einsum("niol,ijkp,ljmp->nkomp", result, self.tensors[self.n - 1], self.tensors[self.n - 1])
        tensor = self.tensors[self.n - 1].detach()
        for i in range(self.c):
            tensor[:,:,:,i] = tensor[:,:,:,i] / torch.sqrt(result[0,0,0,0,i])
        self.tensors[self.n - 1] = Variable(tensor, requires_grad = True)
    
    def forward(self,x):
        result = torch.einsum("ijk,lmn,bj,bm->bkn", self.tensors[0], self.tensors[0], x[:,0], x[:,0])
        for i in range(1, 784-1):
            result = torch.einsum("bil,ijk,lmn,bj,bm->bkn",result, self.tensors[i], self.tensors[i], x[:,i], x[:,i])
            result *= 2.5
        result = torch.einsum("bil,ijkc,lmnc,bj,bm->bc",result, self.tensors[-1], self.tensors[-1], x[:,-1], x[:,-1])
        result = result/(self.getNorm().view(self.c))
        return result

In [4]:
def time_now():
    return time.strftime("%Y%m%d_%H%M%S", time.localtime())

def mapData2Spin(data):
    newData = torch.tensor(
        np.concatenate([
            np.cos(np.pi / 2 * data).reshape(len(data), n, 1),
            np.sin(np.pi / 2 * data).reshape(len(data), n, 1)
        ],
                       axis=2))
    return newData

In [5]:
def train(cmps,
          SpinData_train,
          train_label,
          SpinData_test,
          test_label,
          batch_size,
          epochs=5,
          learning_rate=0.001,
          mydevice=torch.device('cpu')):
    train_loss_s = []
    train_Acc_s = []
    test_loss_s = []
    test_Acc_s = []
    n = cmps.n
    start_time = time.time()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(cmps.tensors,lr=learning_rate)
    for epoch in range(epochs):
        train_loss = []
        train_Acc = []
        test_loss = []
        test_Acc = []
        maxnum = 0
        for batch_now in range(len(SpinData_train) // batch_size):
            SpinDatas = SpinData_train[(batch_now) *batch_size:(batch_now + 1) * batch_size]
            label = train_label[(batch_now) *batch_size:(batch_now + 1) * batch_size]
            out = cmps(SpinDatas)
            loss = criterion(out,label)
            loss.backward()
            optimizer.step()
            cmps.normalize()
            maxnum = max(maxnum,torch.max(out.data))
            pred_label = torch.max(out.data, 1).indices
            correct_cnt = (pred_label == label).sum()
            train_loss.append(loss.data)
            train_Acc.append(correct_cnt)
        train_loss_s.append(np.mean(train_loss))
        train_Acc_s.append(np.sum(train_Acc)/len(SpinData_train))
        for batch_now in range(len(SpinData_test) // batch_size):
            SpinDatas = SpinData_test[(batch_now) *batch_size:(batch_now + 1) * batch_size]
            label = test_label[(batch_now) *batch_size:(batch_now + 1) * batch_size]
            out = cmps(SpinDatas)
            loss = criterion(out,label)
            pred_label = torch.max(out.data, 1).indices
            correct_cnt = (pred_label == label).sum()
            test_loss.append(loss.data)
            test_Acc.append(correct_cnt)
        test_loss_s.append(np.mean(test_loss))
        test_Acc_s.append(np.sum(test_Acc)/len(SpinData_test))
        print(epoch + 1,train_loss_s[-1],train_Acc_s[-1],test_loss_s[-1],test_Acc_s[-1],maxnum,time.time()-start_time)
        show_loss(train_loss_s, test_loss_s)
        show_Acc(train_Acc_s, test_Acc_s)

In [6]:
torch.set_default_tensor_type(torch.DoubleTensor)
plt.rcParams['figure.dpi'] = 300
mydevice = torch.device('cpu')
n = 784
m = 1000
tm = 500
Dmax = 5
c = 10
batch_size = 50
learning_rate = 0.005
path = time_now() + "_" + str(m) + "_" + str(batch_size) + "_" + str(
    tm) + "_" + str(Dmax) + "_" + str(learning_rate)
print(path)
os.makedirs(path)

20220724_214907_1000_50_500_5_0.005


In [7]:
path = "./" + path + "/"
raw_data = np.load("data_bin.npz")
train_data = raw_data['train_data']
train_data = torch.Tensor(train_data[:m, :])
show_imgs(train_data, 5, 15, 15, 5, "train_fig")
SpinData_train = mapData2Spin(train_data)
print(SpinData_train.shape)
test_data = raw_data['train_data']
test_data = torch.Tensor(test_data[:tm, :])
SpinData_test = mapData2Spin(test_data)
print(SpinData_test.shape)
train_label = torch.tensor(raw_data['train_label'][:m], dtype = torch.long)
test_label = torch.tensor(raw_data['test_label'][:m], dtype = torch.long)
print(train_label.shape)
print(test_label.shape)
torch.manual_seed(1)
np.random.seed(1)
SpinData_train = SpinData_train.to(mydevice)
SpinData_test = SpinData_test.to(mydevice)

cmps = CMPS(Dmax, n, c, mydevice)

torch.Size([1000, 784, 2])
torch.Size([500, 784, 2])
torch.Size([1000])
torch.Size([1000])


<Figure size 4500x1500 with 0 Axes>

In [8]:
#train(cmps,SpinData_train,train_label,SpinData_test,test_label,batch_size,epochs = 2,learning_rate = 1e-3)
train(cmps,SpinData_train,train_label,SpinData_test,test_label,batch_size,epochs = 10,learning_rate = 1e-5)

1 5.400017113666363e+24 0.092 1.3937024997432444e+22 0.102 tensor(9.0501e+27) 28.32327675819397
2 5.377613917886157e+24 0.092 1.3755582588963722e+22 0.102 tensor(9.0368e+27) 59.473910331726074
3 5.355206870709274e+24 0.092 1.357416796033969e+22 0.102 tensor(9.0236e+27) 91.62163615226746
4 5.332795353068174e+24 0.092 1.339278127097473e+22 0.102 tensor(9.0104e+27) 123.08135485649109
5 5.310379376563772e+24 0.092 1.3211422680166227e+22 0.102 tensor(8.9971e+27) 151.09310936927795
6 5.287958952796341e+24 0.092 1.3030092347118576e+22 0.102 tensor(8.9839e+27) 178.99333000183105
7 5.265534093370896e+24 0.092 1.2848790430940209e+22 0.102 tensor(8.9707e+27) 206.9028639793396
8 5.243104809890814e+24 0.092 1.266751709062411e+22 0.102 tensor(8.9575e+27) 235.43025088310242
9 5.220671113961638e+24 0.092 1.2486272485073564e+22 0.102 tensor(8.9443e+27) 273.92173171043396
10 5.198233017191393e+24 0.092 1.2305056773082362e+22 0.102 tensor(8.9311e+27) 310.42026138305664


<Figure size 3000x3000 with 0 Axes>