In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
import math, time
import matplotlib.pyplot as plt
import copy
import os
import csv

In [2]:
def show_imgs(imgs, l1=4, l2=5, s1=6, s2=6, name=""):
    plt.rcParams['figure.figsize'] = (s1, s2)
    imgs = imgs.cpu().reshape([-1, 28, 28])
    g, ax = plt.subplots(l1, l2)
    for i in range(l1):
        for j in range(l2):
            a = i * l2 + j
            if (a >= imgs.shape[0]):
                break
            ax[i][j].imshow(imgs[a, :, :], cmap='gray')
            ax[i][j].set_xticks([])
            ax[i][j].set_yticks([])
    if name != "":
        plt.savefig(path + str(name) + ".png")
    #plt.show()
    plt.clf()

def show_loss(train_loss_s, test_loss_s):
    plt.rcParams['figure.figsize'] = (10, 10)
    plt.plot(train_loss_s, "o-", label='train loss')
    plt.plot(test_loss_s, "o-", label='test loss')
    plt.ylabel("loss")
    plt.xlabel("epoch")
    plt.legend(loc=1)
    plt.savefig(path + "loss.png")
    plt.clf()

def show_Acc(train_Acc_s, test_Acc_s):
    plt.rcParams['figure.figsize'] = (10, 10)
    plt.plot(train_Acc_s, "o-", label='train Acc')
    plt.plot(test_Acc_s, "o-", label='test Acc')
    plt.ylabel("Acc")
    plt.xlabel("epoch")
    plt.legend(loc=1)
    plt.savefig(path + "Acc.png")
    plt.clf()

In [3]:
class CMPS(nn.Module):
    def __init__ (self, Dmax, n, c, mydevice=torch.device('cpu')):
        super(CMPS,self).__init__()
        self.Dmax = Dmax
        self.n = n
        self.c = c
        self.bond_dims = [self.Dmax for i in range(n - 1)] + [1]
        self.tensors = []
        for i in range(self.n - 1):
            t = torch.randn(self.bond_dims[i - 1], 2, self.bond_dims[i], device=mydevice)
            t = Variable(t, requires_grad = True)
            self.tensors.append(t)
        t = torch.rand(self.bond_dims[self.n - 2], 2, self.bond_dims[self.n - 1], c, device=mydevice)
        t = Variable(t, requires_grad = True)
        self.tensors.append(t)
        self.normalize()
    
    def getNorm(self):
        result = torch.tensordot(self.tensors[0], self.tensors[0], dims=([1], [1]))
        for i in range(1,self.n - 1):
            result = torch.einsum("niol,ijk,ljm->nkom", result, self.tensors[i],self.tensors[i])
        result = torch.einsum("niol,ijkp,ljmp->nkomp", result, self.tensors[self.n - 1], self.tensors[self.n - 1])
        return result
    
    
    def normalize(self):
        result = torch.tensordot(self.tensors[0], self.tensors[0], dims=([1], [1]))
        res_max = result.max()
        tensor = self.tensors[0].detach()
        tensor = tensor / torch.sqrt(res_max)
        self.tensors[0] = Variable(tensor, requires_grad = True)
        result /= res_max
        for i in range(1,self.n - 1):
            result = torch.einsum("niol,ijk,ljm->nkom", result, self.tensors[i],self.tensors[i])
            res_max = result.max()
            tensor = self.tensors[i].detach()
            tensor = tensor / torch.sqrt(res_max)
            self.tensors[i] = Variable(tensor, requires_grad = True)
            result /= res_max           
        result = torch.einsum("niol,ijkp,ljmp->nkomp", result, self.tensors[self.n - 1], self.tensors[self.n - 1])
        tensor = self.tensors[self.n - 1].detach()
        for i in range(self.c):
            tensor[:,:,:,i] = tensor[:,:,:,i] / torch.sqrt(result[0,0,0,0,i])
        self.tensors[self.n - 1] = Variable(tensor, requires_grad = True)
    
    def forward(self,x,label):
        result = torch.einsum("ijk,bj->bk", self.tensors[0], x[:,0])
        for i in range(1, 784-1):
            result = torch.einsum("bi,ijk,bj->bk",result, self.tensors[i], x[:,i])
        result = torch.einsum("bi,ijkc,bj,bc->b",result, self.tensors[-1], x[:,-1],label)
        return result

In [4]:
def time_now():
    return time.strftime("%Y%m%d_%H%M%S", time.localtime())

def mapData2Spin(data):
    newData = torch.tensor(
        np.concatenate([
            np.cos(np.pi / 2 * data).reshape(len(data), n, 1),
            np.sin(np.pi / 2 * data).reshape(len(data), n, 1)
        ],
                       axis=2))
    return newData

In [5]:
def train(cmps,
          SpinData_train,
          train_label,
          SpinData_test,
          test_label,
          batch_size,
          epochs=5,
          learning_rate=0.001,
          mydevice=torch.device('cpu')):
    train_loss_s = []
    train_Acc_s = []
    test_loss_s = []
    test_Acc_s = []
    n = cmps.n
    start_time = time.time()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(cmps.tensors,lr=learning_rate)
    for epoch in range(epochs):
        train_loss = []
        train_Acc = []
        test_loss = []
        test_Acc = []
        maxnum = 0
        for batch_now in range(len(SpinData_train) // batch_size):
            SpinDatas = SpinData_train[(batch_now) *batch_size:(batch_now + 1) * batch_size]
            label = train_label[(batch_now) *batch_size:(batch_now + 1) * batch_size]
            out = cmps(SpinDatas,label)
            #print(out)
            loss = -2*torch.mean(torch.log(torch.abs(out)))
            #print(out)
            #print(torch.log(torch.abs(out)))
            #print(loss)
            loss.backward()
            optimizer.step()
            cmps.normalize()
            maxnum = max(maxnum,torch.max(out.data))
            #pred_label = torch.max(out.data, 1).indices
            #correct_cnt = (pred_label == label).sum()
            train_loss.append(loss.data)
            #train_Acc.append(correct_cnt)
        train_loss_s.append(np.mean(train_loss))
        train_Acc_s.append(np.sum(train_Acc)/len(SpinData_train))
        for batch_now in range(len(SpinData_test) // batch_size):
            SpinDatas = SpinData_test[(batch_now) *batch_size:(batch_now + 1) * batch_size]
            label = test_label[(batch_now) *batch_size:(batch_now + 1) * batch_size]
            out = cmps(SpinDatas,label)
            loss = -2*torch.mean(torch.log(torch.abs(out)))
            #pred_label = torch.max(out.data, 1).indices
            #correct_cnt = (pred_label == label).sum()
            test_loss.append(loss.data)
            #test_Acc.append(correct_cnt)
        test_loss_s.append(np.mean(test_loss))
        test_Acc_s.append(np.sum(test_Acc)/len(SpinData_test))
        print(epoch + 1,train_loss_s[-1],train_Acc_s[-1],test_loss_s[-1],test_Acc_s[-1],maxnum,time.time()-start_time)
        show_loss(train_loss_s, test_loss_s)
        show_Acc(train_Acc_s, test_Acc_s)

In [6]:
torch.set_default_tensor_type(torch.DoubleTensor)
plt.rcParams['figure.dpi'] = 300
mydevice = torch.device('cpu')
n = 784
m = 60000
tm = 10000
Dmax = 10
c = 10
batch_size = 10000
learning_rate = 0.005
path = time_now() + "_" + str(m) + "_" + str(batch_size) + "_" + str(
    tm) + "_" + str(Dmax) + "_" + str(learning_rate)
print(path)
os.makedirs(path)

20220724_234024_60000_10000_10000_10_0.005


In [7]:
path = "./" + path + "/"
raw_data = np.load("data_bin.npz")
train_data = raw_data['train_data']
train_data = torch.Tensor(train_data[:m, :])
show_imgs(train_data, 5, 15, 15, 5, "train_fig")
SpinData_train = mapData2Spin(train_data)
print(SpinData_train.shape)
test_data = raw_data['train_data']
test_data = torch.Tensor(test_data[:tm, :])
SpinData_test = mapData2Spin(test_data)
print(SpinData_test.shape)
train_label = raw_data['train_label'][:m]
train_label = np.eye(c)[train_label]
train_label = torch.Tensor(train_label)
test_label = raw_data['test_label'][:m]
test_label = np.eye(c)[test_label]
test_label = torch.Tensor(test_label)
print(train_label.shape)
print(test_label.shape)
torch.manual_seed(1)
np.random.seed(1)
SpinData_train = SpinData_train.to(mydevice)
SpinData_test = SpinData_test.to(mydevice)

cmps = CMPS(Dmax, n, c, mydevice)

torch.Size([60000, 784, 2])
torch.Size([10000, 784, 2])
torch.Size([60000, 10])
torch.Size([10000, 10])


<Figure size 4500x1500 with 0 Axes>

In [8]:
#train(cmps,SpinData_train,train_label,SpinData_test,test_label,batch_size,epochs = 2,learning_rate = 1e-3)
train(cmps,SpinData_train,train_label,SpinData_test,test_label,batch_size,epochs = 10,learning_rate = 1e-2)

1 514.1412229624588 0.0 494.3837938948187 0.0 tensor(1.3109e-99) 14.589141607284546
2 493.9373686663661 0.0 493.33029501119034 0.0 tensor(1.9704e-99) 31.76314640045166
3 493.1309675310192 0.0 492.768547823417 0.0 tensor(2.5560e-99) 55.26916980743408
4 492.6693832737619 0.0 492.42375134656373 0.0 tensor(3.0115e-99) 74.94097971916199
5 492.38000260523177 0.0 492.20196819854107 0.0 tensor(3.3529e-99) 95.35163402557373
6 492.193125562249 0.0 492.0575995647587 0.0 tensor(3.6021e-99) 115.58974885940552
7 492.07104525172014 0.0 491.96331517315934 0.0 tensor(3.7809e-99) 131.67711758613586
8 491.9910321491236 0.0 491.90098750106563 0.0 tensor(3.9080e-99) 151.4003245830536
9 491.9386196266396 0.0 491.8602296740993 0.0 tensor(3.9979e-99) 167.0203239917755
10 491.9043914282334 0.0 491.8336759768667 0.0 tensor(4.0615e-99) 182.52707719802856


<Figure size 3000x3000 with 0 Axes>