In [1]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import  SummaryWriter
import einops
writer=SummaryWriter('logs')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
transform=torchvision.transforms.Compose([torchvision.transforms.RandAugment(num_ops=2,magnitude=6),torchvision.transforms.ToTensor()])
cifar_trainset=torchvision.datasets.CIFAR10('CIFAR10',train=True,transform=transform,download=True)
cifar_testset=torchvision.datasets.CIFAR10('CIFAR10',train=False,transform=torchvision.transforms.ToTensor(),download=True)
print(len(cifar_trainset))
cifar_trainset,cifar_validateset=torch.utils.data.random_split(cifar_trainset,[40000,10000])
load_trainset=DataLoader(cifar_trainset,batch_size=64,shuffle=True,drop_last=True)
load_validateset=DataLoader(cifar_validateset,batch_size=64,shuffle=True,drop_last=True)
load_testset=DataLoader(cifar_testset,batch_size=64,drop_last=False)
print(cifar_testset[0][0].shape)

50000
torch.Size([3, 32, 32])


In [None]:
class residual_connection(nn.Module):
    def __init__(self,in_channel,hidden_channel,out_channel,kernel_size=3,stride=1,padding=1):
        super().__init__()
        self.conv1=nn.Conv2d(in_channel,hidden_channel,kernel_size=1,stride=1,padding=0)
        self.conv2=nn.Conv2d(hidden_channel,hidden_channel,kernel_size,padding=padding)
        self.conv3=nn.Conv2d(hidden_channel,out_channel,kernel_size=1,padding=0)
        self.silu=nn.SiLU()
        self.batchnorm1=nn.BatchNorm2d(in_channel)
        self.batchnorm2=nn.BatchNorm2d(hidden_channel)
        self.batchnorm3=nn.BatchNorm2d(hidden_channel)
    def forward(self,x):
        y=x.clone()
        x=self.batchnorm1(x)
        x=self.conv1(x)
        x=self.silu(x)
        x=self.batchnorm2(x)
        x=self.conv2(x)
        x=self.silu(x)
        x=self.batchnorm3(x)
        x=self.conv3(x)
        x=self.silu(x)
        x=x+y
        return x
class resnetblock(nn.Module):
    def __init__(self,in_channel,hidden_channel,out_channel,num_residual_connection=2,kernel_size=3,stride=1,padding=1):
        super().__init__()
        self.residual_connections=nn.ModuleList([residual_connection(in_channel,hidden_channel,out_channel,kernel_size,stride,padding) for _ in range(num_residual_connection)])
    def forward(self,x):
        for residual_connection in self.residual_connections:
            x=residual_connection(x)
        return x
class resnet(nn.Module):
    def __init__(self):
        super().__init__()
        self.batchnorm1 = nn.BatchNorm2d(3)
        self.conv1 = nn.Conv2d(3, 32, 5, padding=2)
        self.resnetblock1=resnetblock(32,16,32,2,5,1,padding=2)
        self.pool1 = nn.MaxPool2d(2)
        self.silu1 = nn.SiLU()

        # self.batchnorm2 = nn.BatchNorm2d(32)
        # self.conv2 = nn.Conv2d(32, 32, 5, padding=2)
        self.resnetblock2=resnetblock(32,16,32,2,5,1,padding=2)
        self.pool2 = nn.MaxPool2d(2)
        self.silu2 = nn.SiLU()

        self.batchnorm3 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, 5, padding=2)
        self.resnetblock3=resnetblock(64,16,64,2,5,1,padding=2)
        self.pool3 = nn.MaxPool2d(2)
        self.silu3 = nn.SiLU()

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 16, 64)
        self.silu4 = nn.SiLU()
        self.fc2 = nn.Linear(64, 10)
        
    #     # 初始化权重为 Xavier
    #     self._initialize_weights()
    
    # def _initialize_weights(self):
    #     """使用 Xavier 初始化所有卷积层和全连接层的权重"""
    #     for m in self.modules():
    #         if isinstance(m, (nn.Conv2d, nn.Linear)):
    #             nn.init.xavier_normal_(m.weight)
    #             if m.bias is not None:
    #                 nn.init.constant_(m.bias, 0)
    
    def forward(self,x):
        x=self.batchnorm1(x)
        x=self.conv1(x)
        x=self.resnetblock1(x)
        x=self.pool1(x)
        x=self.silu1(x)
        # x=self.batchnorm2(x)
        x=self.resnetblock2(x)
        x=self.pool2(x)
        x=self.silu2(x)
        x=self.batchnorm3(x)
        x=self.conv3(x)
        x=self.resnetblock3(x)
        x=self.pool3(x)
        x=self.silu3(x)
        x=self.flatten(x)
        x=self.silu4(self.fc1(x))
        x=self.fc2(x)
        return x
model=resnet()
model=model.to(device)
test=torch.randn(2,3,32,32,device=device)
print(model(test))


In [8]:
# 从 metric.py 导入 get_metrics 函数
# 如果没有 metric.py，可以使用下面的函数定义
def get_metrics(model,load_testset,beta=1,fault_tolerance=0):
    acc=torch.zeros(10,device=device)
    predict=torch.zeros(10,device=device)
    total=torch.zeros(10,device=device)
    if fault_tolerance:
        for data in load_testset:
            imgs,label=data
            imgs=imgs.to(device)
            label=label.to(device)
            ans=model(imgs)
            for i in range(10):
                acc[i]+=torch.sum(torch.sum(torch.topk(ans,1+fault_tolerance,dim=1).indices==label.unsqueeze(1),dim=-1)*(label==i)).item()
                total[i]+=torch.sum(label==i).item()
                predict[i]+=torch.sum(torch.topk(ans,1+fault_tolerance,dim=1).indices==i).item()
    else:
        for data in load_testset:
            imgs,label=data
            imgs=imgs.to(device)
            label=label.to(device)
            ans=model(imgs)
            for i in range(10):
                acc[i]+=torch.sum((torch.argmax(ans,axis=1)==label)*(label==i)).item()
                total[i]+=torch.sum(label==i).item()
                predict[i]+=torch.sum(torch.argmax(ans,axis=1)==i).item()
    precision=acc/predict
    recall=acc/total
    f1=precision*recall/(beta**0.5*precision+recall)*(1+beta**0.5)
    accuracy=torch.sum(acc)/torch.sum(total)
    return f1,accuracy,precision,recall

print(get_metrics(model,load_testset,fault_tolerance=0))


(tensor([0.1302, 0.0539, 0.0815, 0.1872,    nan,    nan, 0.0057, 0.0472,    nan,
        0.1805], device='cuda:0'), tensor(0.1171, device='cuda:0'), tensor([0.1423, 0.0989, 0.1429, 0.1138, 0.0000,    nan, 0.0612, 0.0990,    nan,
        0.1168], device='cuda:0'), tensor([0.1200, 0.0370, 0.0570, 0.5260, 0.0000, 0.0000, 0.0030, 0.0310, 0.0000,
        0.3970], device='cuda:0'))


In [7]:
class cifarmodel(nn.Module):
    def __init__(self):
        super().__init__()
        self.batchnorm1 = nn.BatchNorm2d(3)
        self.conv1 = nn.Conv2d(3,32,5,padding=2)
        self.silu1 = nn.SiLU()
        self.pool1 = nn.MaxPool2d(2)
        
        self.batchnorm2 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32,32,5,padding=2)
        self.silu2 = nn.SiLU()
        self.pool2 = nn.MaxPool2d(2)
        
        self.batchnorm3 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32,64,5,padding=2)
        self.silu3 = nn.SiLU()
        self.pool3 = nn.MaxPool2d(2)
        
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64*16,64)
        self.silu4 = nn.SiLU()
        self.fc2 = nn.Linear(64,10)
    
    def forward(self,x):
        x = self.batchnorm1(x)
        x = self.conv1(x)
        x = self.silu1(x)
        x = self.pool1(x)
        
        x = self.batchnorm2(x)
        x = self.conv2(x)
        x = self.silu2(x)
        x = self.pool2(x)
        
        x = self.batchnorm3(x)
        x = self.conv3(x)
        x = self.silu3(x)
        x = self.pool3(x)
        
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.silu4(x)
        x = self.fc2(x)
        return x
model=cifarmodel()
model=model.to(device)

In [14]:
class ViT(nn.Module):
    def __init__(self,patch1=4,patch2=4,channel=3):
        super().__init__()
        self.patch1=patch1
        self.patch2=patch2
        self.tokenize=nn.Linear(patch1*patch2*channel,64)
        self.embedding=nn.parameter.Parameter(torch.randn(65,64,device=device),requires_grad=True)
        self.classembedding=nn.parameter.Parameter(torch.randn(64,device=device),requires_grad=True)
        self.transformer=nn.TransformerEncoder(nn.TransformerEncoderLayer(64,4,dim_feedforward=128,batch_first=True),num_layers=3)
        self.fc=nn.Linear(64,10)
    def forward(self,x):
        x=einops.rearrange(x,'batch channel (h patch1) (w patch2) -> batch channel h w (patch1 patch2)',patch1=self.patch1,patch2=self.patch2)
        x=einops.rearrange(x,'batch channel h w patch -> batch (h w) (channel patch)')
        x=self.tokenize(x)
        x=torch.cat((self.classembedding.unsqueeze(0).unsqueeze(0).expand(x.shape[0],-1,-1),x),dim=1)
        x=x+self.embedding.unsqueeze(0).expand(x.shape[0],-1,-1)
        x=self.transformer(x)
        x=self.fc(x[:,0])
        return x
model=ViT()
model=model.to(device)

In [9]:
optimizer=torch.optim.AdamW(model.parameters(),lr=0.01)
loss=nn.CrossEntropyLoss()
loss=loss.to(device)
for epoch in range(5):
    model.train()
    totalloss=0
    for data in load_trainset:
        imgs,label=data
        imgs=imgs.to(device)
        label=label.to(device)
        result=model(imgs)
        result_loss=loss(result,label)
        optimizer.zero_grad()
        result_loss.backward()
        optimizer.step()
        totalloss+=result_loss.item()
    print(f"epoch={epoch},loss={totalloss},",end='')
    model.eval()
    writer.add_scalar(tag='loss',scalar_value=totalloss,global_step=epoch)
    with torch.no_grad():
        f1,acc,precision,recall=get_metrics(model,load_validateset)
        print(f"f1={f1.mean()},acc={acc}")
        writer.add_scalar(tag='f1',scalar_value=f1.mean(),global_step=epoch)
        writer.add_scalar(tag='acc',scalar_value=acc,global_step=epoch)
writer.close()


epoch=0,loss=1291.5514106750488,f1=nan,acc=0.18699920177459717
epoch=1,loss=1113.1976212263107,f1=0.36242586374282837,acc=0.38982370495796204
epoch=2,loss=925.9437798261642,f1=0.514957070350647,acc=0.5205328464508057
epoch=3,loss=810.2734661102295,f1=0.510861337184906,acc=0.528245210647583
epoch=4,loss=729.5858160853386,f1=0.5992826819419861,acc=0.6025640964508057
