# Linear Layers


In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader
from torch.nn import Linear, Module
from torch.utils.tensorboard import SummaryWriter

# 1. Prepare the dataset: the CIFAR-10 test split (train=False), with each image
#    converted to a tensor by ToTensor(); downloaded into ../data if missing.
dataset = torchvision.datasets.CIFAR10(
    "../data",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
# drop_last=True discards the final batch when it holds fewer than 64 samples.
# NOTE(review): the flatten(start_dim=1) + Linear(3072, 10) pipeline in this file
# works for any batch size, so this flag does NOT prevent a shape error there; it
# only drops the trailing partial batch. It is kept in case other cells rely on
# every batch containing exactly 64 samples.
dataloader = DataLoader(dataset, batch_size=64, drop_last=True)

# 2. 搭建网络 
class MyModel(Module):
    """A single fully-connected layer from 3072 input features to 10 outputs.

    The input is expected to be already flattened so that its last dimension is
    3072 (a 32x32 RGB image: 3 * 32 * 32). ``forward`` does no reshaping itself.
    """

    def __init__(self):
        super().__init__()
        in_features = 32 * 32 * 3  # 32x32 RGB image flattened -> 3072 features
        self.linear1 = Linear(in_features, 10)  # 10 output classes

    def forward(self, x):
        """Apply the linear layer to ``x`` (last dimension must be 3072)."""
        return self.linear1(x)

my_model = MyModel()

# 2. Run every batch through the model and log the input images to TensorBoard.
#    The `with` statement closes the writer even if the loop raises or is
#    interrupted; no_grad() skips building an autograd graph, since this demo
#    only runs forward passes and never calls backward().
step = 0
with SummaryWriter("logs") as writer, torch.no_grad():
    for imgs, targets in dataloader:
        # Show the raw input batch in TensorBoard under the "input_imgs" tag.
        writer.add_images("input_imgs", imgs, step)

        # Flatten every dimension except the batch dimension (dim 0).
        # Assumes imgs is (N, C, H, W) from ToTensor(), giving (N, C*H*W) here.
        output_input = torch.flatten(imgs, start_dim=1)

        # Pass the flattened batch through the linear layer.
        output = my_model(output_input)

        print(f"输出形状: {output.shape}")

        step += 1