In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import numpy as np

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Dataset locations, relative to the notebook's working directory.
train_path, valid_path = './dataset/train', './dataset/valid'

In [3]:
# Data preprocessing: convert to a tensor image, center-crop to 256x256,
# randomly mirror horizontally, convert to float32 with value scaling, then
# normalize each channel with the given mean / std.
transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.CenterCrop(size=(256, 256)),
        v2.RandomHorizontalFlip(p=0.5),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [4]:
# Training set read from `train_path`; every image goes through `transforms`.
train_dataset = ImageFolder(root=train_path, transform=transforms)

In [5]:
# Sanity check: largest value in the first preprocessed training sample.
# It can exceed 1 because Normalize subtracts the channel mean and divides by
# the channel std after the values were scaled.
train_dataset[0][0].max()

tensor(2.6400)

In [6]:
# Shuffled mini-batches of 16 samples drawn from the training set.
data_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=16,
)

In [7]:
class ResBock(nn.Module):
    """Residual block: two (3x3 conv -> BatchNorm -> SiLU) stages plus a skip connection.

    The first stage maps ``in_chanels`` to ``hid_chanels`` and the second stage
    maps back to ``in_chanels``, so the output always has the same shape as the
    input and can be added to it.

    Args:
        in_chanels: Number of channels of the input (and output) feature map.
        hid_chanels: Number of channels between the two convolutions.
    """

    def __init__(self, in_chanels, hid_chanels):
        super().__init__()
        self.in_chanels = in_chanels
        self.hid_chanels = hid_chanels
        self.cnn = nn.Sequential(
            nn.Conv2d(in_chanels, hid_chanels, kernel_size=3, padding=1),
            nn.BatchNorm2d(hid_chanels),
            nn.SiLU(),
            # The second stage consumes the hidden width and restores the input
            # width; otherwise the block only works when both widths are equal.
            nn.Conv2d(hid_chanels, in_chanels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_chanels),
            nn.SiLU(),
        )

    def forward(self, x):
        """Return ``x + cnn(x)``; the result has the same shape as ``x``."""
        return x + self.cnn(x)

class ResNet(nn.Module):
    """Convolutional network: encoder conv, a stack of ``ResBock``s, decoder conv, spatial mean.

    Args:
        in_chanels: Number of channels of the input image.
        out_chanels: Number of values produced per sample.
        hid_chanels: Width of the feature map flowing through the residual stack.
        res_chanels: Inner width passed to every ``ResBock`` as its hidden size.
        num_res: Number of ``ResBock``s in the stack.
    """

    def __init__(self, in_chanels, out_chanels, hid_chanels, res_chanels, num_res):
        super().__init__()
        self.in_chanels = in_chanels
        self.out_chanels = out_chanels
        self.res_chanels = res_chanels
        self.hid_chanels = hid_chanels
        self.num_res = num_res
        self.encoder = nn.Sequential(
            nn.Conv2d(in_chanels, hid_chanels, kernel_size=3, padding=1),
            nn.SiLU()
        )
        # Every ResBock returns x + cnn(x), so the residual stack hands the
        # decoder a feature map that is still hid_chanels wide (not res_chanels).
        self.decoder = nn.Sequential(
            nn.Conv2d(hid_chanels, out_chanels, kernel_size=3, padding=1)
        )
        self.hid_chanels_layers = nn.Sequential(
            *[ResBock(hid_chanels, res_chanels) for _ in range(num_res)]
        )

    def forward(self, x):
        """Run the network and average over the last two dims.

        Returns a tensor with one row per sample and ``out_chanels`` columns.
        """
        y = self.encoder(x)
        y = self.hid_chanels_layers(y)
        y = self.decoder(y)
        # Mean over the two spatial dims collapses each output map to a scalar.
        return y.mean(dim=[-2, -1])

In [8]:
# Smoke test: push one random 3x256x256 input through a freshly built network
# and print the output shape (one value per output channel).
x = torch.randn((1, 3, 256, 256), device=device)
resnet = ResNet(
    in_chanels=3,
    out_chanels=20,
    hid_chanels=32,
    res_chanels=32,
    num_res=20,
).to(device)
y = resnet(x)
print(y.shape)

torch.Size([1, 20])


In [9]:
def train_dataset(model,data_loader,epoch,batch,lr):
    optim = torch.optim.Adam(model.parameters(),lr=lr)
    loss_fun = nn.BCELoss()

    for ep in range(epoch):
        model.train()
        for i in range(batch):
            optim.zero_grad()
            data = next(iter(data_loader))