In [1]:
# 使用伪码表生成wave从而得到对比训练集，防止恒正运算
# server给subject提供一个伪码表，subject利用伪码表生成对应的对照集
# 伪码与最终的hash key或者BCH code无关
# 判别器需要保留一定的分辨能力
# data index range [2, 66]
import torch
import torch.nn as nn
from torch.nn import *
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import os

from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader

from preprocess import Process

In [2]:
#定义生成器
class Generator(nn.Module):
    def __init__(self, subject_num):
        super(Generator,self).__init__()
        self.linear1 = nn.Linear(100,512)
        self.bn1=nn.BatchNorm1d(512)
        self.subject_num = subject_num
        self.linear2 = nn.Linear(subject_num,512)
        self.bn2=nn.BatchNorm1d(512)
        
        self.deconv1 = nn.Conv1d(1, 3, kernel_size=2, padding='same')
        self.bn3=nn.BatchNorm1d(3)
        self.deconv2 = nn.Conv1d(3, 6, kernel_size=2, padding='same')
        self.bn4=nn.BatchNorm1d(6)
        self.deconv3 = nn.Conv1d(6, 1, kernel_size=2, padding='same')
        
    def forward(self,x1,x2):
        x1=F.relu(self.linear1(x1.to(torch.float64)))
        x1=self.bn1(x1)
        x2=F.relu(self.linear2(x2))
        x2=self.bn2(x2)
        x=torch.cat([x1,x2],axis=1)
        x=F.relu(self.deconv1(torch.reshape(x, (x.size(0), 1, x.size(1)))))
        x=self.bn3(x)
        x=F.relu(self.deconv2(x))
        x=self.bn4(x)
        x=torch.tanh(self.deconv3(x))
        return x

In [3]:
#设备的配置
device='cuda' if torch.cuda.is_available() else 'cpu'
gen = torch.load("generator").to(device)

In [4]:
labels = np.zeros((54, 54), dtype='double')
for i in range(54):
    labels[i][i] = 1.0
labels = torch.tensor(labels).to(device)

In [5]:
random_seed = torch.randn(54,100,device=device)

In [6]:
comp_data = gen(random_seed, labels)

  return F.conv1d(input, weight, bias, self.stride,


In [7]:
comp_data.shape

torch.Size([54, 1, 1024])

In [8]:
series = []
items = 3
step = 3
subject_num = 0

try:
    series += Process(2).prepro(1024, step, items)
    subject_num += 1
except:
    print(f'subject {num} abandoned')

In [9]:
target = torch.reshape(torch.tensor(series).to(device), (items, 1, 1024))

In [10]:
in_data = torch.cat([comp_data, target], 0)

In [11]:
in_data.shape

torch.Size([57, 1, 1024])

In [12]:
comp_label = np.zeros((54,1), dtype='double')
comp_label = torch.tensor(comp_label).to(device)

In [13]:
true_label = np.ones((3,1), dtype='double')
true_label = torch.tensor(true_label).to(device)

In [14]:
in_label = torch.cat([comp_label, true_label], 0)

In [15]:
class MyDataset(Dataset):
    def __init__(self, sig, label):
        self.sig = sig
        self.label = label

    # need to overload
    def __len__(self):
        return self.sig.size(0)

    # need to overload
    def __getitem__(self, idx):
        return self.sig[idx], self.label[idx]

In [16]:
dataset = MyDataset(in_data, in_label)
dataloader = torch.utils.data.DataLoader(dataset,batch_size=4)

In [17]:
class ConvolutionBlock(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(ConvolutionBlock, self).__init__()
        self.block = nn.Sequential(nn.Conv1d(input_size, hidden_size, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Conv1d(hidden_size, hidden_size, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU()
        )
        
    def forward(self, x):
        x = self.block(x)
        return x

In [18]:
class TransformerClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes, num_layers, num_heads, dropout, num_conv_blocks):
        super(TransformerClassifier, self).__init__()
        self.conv1 = ConvolutionBlock(input_size, hidden_size)
        self.conv_blocks = nn.ModuleList([
            ConvolutionBlock(hidden_size, hidden_size) for _ in range(num_conv_blocks)
        ])
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(hidden_size, num_heads, dim_feedforward=hidden_size, dropout=dropout),
            num_layers
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        for conv_block in self.conv_blocks:
            x = conv_block(x)
        x = x.permute(0, 2, 1)  # Reshape to (batch_size, hidden_size, seq_len)
        x = self.transformer(x)
        x = x.mean(dim=1)  # Average the sequence dimension
        x = torch.sigmoid(self.fc(x))
        return x

In [19]:
# Example usage
input_size = 1
hidden_size = 128
num_classes = 1
num_layers = 6
num_heads = 4
dropout = 0.2
batch_size = 4
num_epochs = 10
num_conv_blocks = 2

In [20]:
cls = TransformerClassifier(input_size, hidden_size, num_classes, num_layers, num_heads, dropout, num_conv_blocks).to(device)
cls = cls.double()

In [21]:
criterion = nn.BCELoss()
optimizer = optim.Adam(cls.parameters(), lr=1e-6)

In [22]:
C_loss = []
C_acc = []

In [23]:
for epoch in range(200):
    epoch_iterator = tqdm(dataloader, desc="Training Epoch %d" % (epoch + 1), ncols = 100)
    #初始化损失值
    c_epoch_loss = 0
    acc_num = 0
    num = 0
    count = len(dataloader) #返回批次数
    #对数据集进行迭代
    for step, (subject, label) in enumerate(epoch_iterator):
        subject = subject.to(device) #把数据放到设备上
        label = label.to(device)
        size = subject.size(0)
        num += size
        
        class_train = cls(subject.data)
        c_loss = criterion(class_train, label.data)
        c_loss.backward()
        optimizer.step()
        
        for i in range(subject.size(0)):
            idx1 = torch.argmax(class_train[i])
            idx2 = torch.argmax(label[i])
            if torch.abs(class_train[i] - label[i]) < 0.5:
                acc_num += 1
        
        #累计每一个批次的loss
        with torch.no_grad():
            c_epoch_loss += c_loss
        epoch_iterator.set_postfix({"c_loss": '{0:1.5f}'.format(c_epoch_loss), "accuracy": '{0:1.3f}'.format(acc_num / num)})
        epoch_iterator.update(1)
            
    #求平均损失
    with torch.no_grad():
        c_epoch_loss /= count
        acc = acc_num / num
        C_loss.append(c_epoch_loss)
        C_acc.append(acc)

Training Epoch 1: 100%|█████████████| 15/15 [00:01<00:00, 13.97it/s, c_loss=9.51179, accuracy=0.947]
Training Epoch 2: 100%|█████████████| 15/15 [00:00<00:00, 16.14it/s, c_loss=9.20270, accuracy=0.947]
Training Epoch 3: 100%|█████████████| 15/15 [00:00<00:00, 16.06it/s, c_loss=8.85879, accuracy=0.947]
Training Epoch 4: 100%|█████████████| 15/15 [00:00<00:00, 16.09it/s, c_loss=8.48047, accuracy=0.947]
Training Epoch 5: 100%|█████████████| 15/15 [00:00<00:00, 16.04it/s, c_loss=8.10233, accuracy=0.947]
Training Epoch 6: 100%|█████████████| 15/15 [00:00<00:00, 16.08it/s, c_loss=7.75870, accuracy=0.947]
Training Epoch 7: 100%|█████████████| 15/15 [00:00<00:00, 16.02it/s, c_loss=7.41715, accuracy=0.947]
Training Epoch 8: 100%|█████████████| 15/15 [00:00<00:00, 15.92it/s, c_loss=7.10382, accuracy=0.947]
Training Epoch 9: 100%|█████████████| 15/15 [00:00<00:00, 15.97it/s, c_loss=6.81637, accuracy=0.947]
Training Epoch 10: 100%|████████████| 15/15 [00:00<00:00, 15.98it/s, c_loss=6.53481, accura

Training Epoch 82: 100%|████████████| 15/15 [00:00<00:00, 15.91it/s, c_loss=2.33070, accuracy=0.947]
Training Epoch 83: 100%|████████████| 15/15 [00:00<00:00, 15.86it/s, c_loss=2.30158, accuracy=0.947]
Training Epoch 84: 100%|████████████| 15/15 [00:00<00:00, 15.81it/s, c_loss=2.26870, accuracy=0.965]
Training Epoch 85: 100%|████████████| 15/15 [00:00<00:00, 15.87it/s, c_loss=2.26314, accuracy=0.965]
Training Epoch 86: 100%|████████████| 15/15 [00:00<00:00, 15.75it/s, c_loss=2.23005, accuracy=0.965]
Training Epoch 87: 100%|████████████| 15/15 [00:00<00:00, 15.87it/s, c_loss=2.19819, accuracy=0.965]
Training Epoch 88: 100%|████████████| 15/15 [00:00<00:00, 15.79it/s, c_loss=2.17725, accuracy=0.965]
Training Epoch 89: 100%|████████████| 15/15 [00:00<00:00, 15.79it/s, c_loss=2.13981, accuracy=0.965]
Training Epoch 90: 100%|████████████| 15/15 [00:00<00:00, 15.74it/s, c_loss=2.10843, accuracy=0.965]
Training Epoch 91: 100%|████████████| 15/15 [00:00<00:00, 15.85it/s, c_loss=2.07790, accura

Training Epoch 163: 100%|███████████| 15/15 [00:00<00:00, 15.88it/s, c_loss=0.73566, accuracy=1.000]
Training Epoch 164: 100%|███████████| 15/15 [00:00<00:00, 15.89it/s, c_loss=0.72603, accuracy=1.000]
Training Epoch 165: 100%|███████████| 15/15 [00:00<00:00, 15.66it/s, c_loss=0.71964, accuracy=1.000]
Training Epoch 166: 100%|███████████| 15/15 [00:00<00:00, 15.87it/s, c_loss=0.71036, accuracy=1.000]
Training Epoch 167: 100%|███████████| 15/15 [00:00<00:00, 15.88it/s, c_loss=0.70068, accuracy=1.000]
Training Epoch 168: 100%|███████████| 15/15 [00:00<00:00, 15.87it/s, c_loss=0.68968, accuracy=1.000]
Training Epoch 169: 100%|███████████| 15/15 [00:00<00:00, 15.90it/s, c_loss=0.68397, accuracy=1.000]
Training Epoch 170: 100%|███████████| 15/15 [00:00<00:00, 15.82it/s, c_loss=0.67724, accuracy=1.000]
Training Epoch 171: 100%|███████████| 15/15 [00:00<00:00, 15.74it/s, c_loss=0.66852, accuracy=1.000]
Training Epoch 172: 100%|███████████| 15/15 [00:00<00:00, 15.81it/s, c_loss=0.65760, accura