In [22]:
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
from torchvision.transforms import v2
from torch.utils.data import DataLoader
device = torch.device('cuda:0')

In [23]:
train_path = r'../dataset/train'
valid_path = r'../dataset/valid'

## 数据加载器

In [24]:
size = 128
transforms = v2.Compose([
    v2.ToImage(),
    v2.Resize(size),
    v2.CenterCrop(size),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [25]:
train_dataset = ImageFolder(train_path,transforms)

In [26]:
train_data_loader = DataLoader(train_dataset,batch_size=32,shuffle=True)

## 残差神经网络

In [27]:
class ResBlock(nn.Module):
    def __init__(self, in_channels, hid_channels):
        super().__init__()
        self.in_channels = in_channels
        self.hid_channels = hid_channels
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels,hid_channels,kernel_size=3,padding=1),
            nn.BatchNorm2d(hid_channels),
            nn.SiLU(),
            nn.Conv2d(hid_channels,hid_channels,kernel_size=7,padding=3),
            nn.BatchNorm2d(hid_channels),
            nn.SiLU(),
            nn.Conv2d(hid_channels,in_channels,kernel_size=3,padding=1)
        )

    def forward(self, x):
        return x + self.cnn(x)
    
class ResNet(nn.Module):
    def __init__(self, size, in_channels, num_classes, res_channels, hid_channels, num_res):
        super().__init__()
        self.size = size
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.res_channels = res_channels
        self.hid_channels = hid_channels
        self.num_res = num_res
        self.encoder = nn.Conv2d(in_channels,res_channels,kernel_size=3,padding=1)
        self.decoder = nn.Sequential(
            nn.SiLU(),
            nn.Conv2d(res_channels,num_classes,kernel_size=3,padding=1)
        )
        self.hid_res_layers = nn.Sequential(
            *[ResBlock(res_channels,hid_channels) for _ in range(num_res)]
        )
    
    def forward(self, x):
        y = self.encoder(x)
        y = self.hid_res_layers(y)
        y = self.decoder(y)
        return y.mean(dim=[-2,-1])

In [28]:
x = torch.randn([1,3,size,size],device=device)
resnet = ResNet([size,size],3,20,32,64,5).to(device)
y = resnet(x)

In [29]:
def train(model,dataloader,epoch,lr):
    optim = torch.optim.Adam(model.parameters(),lr=lr)
    loss_fun = nn.CrossEntropyLoss()

    for i in range(epoch):
        model.train()
        for data, label in dataloader:
            optim.zero_grad()
            pred = model(data.to(device))
            loss = loss_fun(pred,label.to(device))
            loss.backward()
            optim.step()
            print(i, loss.item())

In [30]:
train(resnet,train_data_loader,10,1e-3)

0 3.055065631866455
0 2.860116481781006
0 2.9382009506225586
0 3.1308631896972656
0 2.9581871032714844
0 2.8232150077819824
0 2.824767589569092
0 2.7621848583221436
0 2.8927369117736816
0 2.636106252670288
0 2.858337879180908
0 2.7278406620025635
0 2.894214391708374
0 2.837451934814453
0 2.7823691368103027
0 2.8448100090026855
0 2.56420636177063
0 2.865365505218506
0 2.8040270805358887
0 2.639960289001465
0 2.9402823448181152
0 2.786759614944458
0 2.688908100128174
0 2.677626609802246
0 2.5823912620544434
0 2.876296281814575
0 2.750661611557007
0 2.826972007751465
0 2.7340121269226074
0 2.6149139404296875
0 2.9647927284240723
0 2.8132948875427246
0 2.7355782985687256
0 2.6962785720825195
0 2.523383617401123
0 2.6128854751586914
0 2.7519588470458984
0 2.610339641571045
0 2.8236851692199707
0 2.6183505058288574
0 2.5226759910583496
0 2.6314849853515625
0 2.7119803428649902
0 2.9558873176574707
0 2.8779993057250977
0 2.5914134979248047
0 2.6089534759521484
0 2.869424343109131
0 2.63371467

KeyboardInterrupt: 