In [25]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST,CIFAR10
from torchvision.transforms import ToTensor
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from datetime import datetime


In [26]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [27]:
train_data = FashionMNIST(
    root="../data",
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = FashionMNIST(
    root="../data",
    train=False,
    download=True,
    transform=ToTensor()
)

In [28]:
def data_proc(data):
    imgs, labels = [],[]
    for img, label in data:
        img = img.reshape(-1)
        imgs.append(img)
        labels.append(label)
    return torch.stack(imgs), torch.tensor(labels, dtype=torch.long)

In [29]:
train_dl = DataLoader(train_data, batch_size=16, shuffle=True)
next(iter(train_dl))[0].shape

torch.Size([16, 1, 28, 28])

In [30]:
def train_model(model, batch_size, train_data, test_data,loss_fn, optimizer, epoch, device=None):
    log_dir = f"../source/logs/{datetime.now().strftime('%Y%m%d_%H%M')}"
    writer = SummaryWriter(log_dir)
    
    
    train_loss, valid_loss, valid_acc = 0, 0, 0
    train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=True)# , collate_fn=data_proc
    test_dl = DataLoader(test_data, batch_size=batch_size)

    shape = next(iter(train_dl))[0].shape

    if device is not None:
        model = model.to(device)
        loss_fn = loss_fn.to(device)
        writer.add_graph(model, torch.rand(batch_size,shape[1],shape[2],shape[3],device=device)) 
    
    for e in range(epoch):
        model.train()
        process_bar = tqdm(train_dl)
        train_total_loss = 0
        for i, (img, lbl) in enumerate(process_bar, start=1):
            img, lbl = img.to(device), lbl.to(device)
            optimizer.zero_grad()
            y_hat = model(img)
            loss = loss_fn(y_hat, lbl)
            train_total_loss += loss
            loss.backward()
            optimizer.step()
            process_bar.set_description(f"epoch: {e+1}, avg_train_loss: {(train_total_loss.item() / i):.4f}")
        writer.add_scalar("train_loss", loss.item(), train_loss)
        train_loss += 1
        
        model.eval()
        with torch.no_grad():
            correct = 0
            valid_total_loss = 0
            for img, lbl in test_dl:
                img, lbl = img.to(device), lbl.to(device)
                y_hat = model(img)
                loss = loss_fn(y_hat, lbl)
                correct += (y_hat.argmax(-1) == lbl).sum()
                valid_total_loss += loss
            print("validation accuracy:", correct.item() / len(test_data))
            print("validation per_batch loss:", valid_total_loss.item() / len(test_dl))
            print("*" * 50)
            
            writer.add_scalar("validation_loss", valid_total_loss.item() / len(test_dl), valid_loss)
            valid_loss += 1
            writer.add_scalar("validation_acc", correct.item() / len(test_data), valid_acc)
            valid_acc += 1

In [31]:
class ImgClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
                nn.Flatten(),
                nn.Linear(784, 256),
                nn.ReLU(),
                nn.Linear(256, 10)
            )
                
    def forward(self, X):
        return self.seq(X)


model = ImgClassifier()
model = model.to(device)
batch_size = 16
epoch = 20
lr = 1e-3
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [34]:
train_model(model, batch_size, train_data, test_data, loss_fn, optimizer, epoch,device)

epoch: 1, avg_train_loss: 0.4162: 100%|██████████| 3750/3750 [00:13<00:00, 279.22it/s]


validation accuracy: 0.8215
validation per_batch loss: 0.491342626953125
**************************************************


epoch: 2, avg_train_loss: 0.3495: 100%|██████████| 3750/3750 [00:13<00:00, 278.75it/s]


validation accuracy: 0.8576
validation per_batch loss: 0.3870287841796875
**************************************************


epoch: 3, avg_train_loss: 0.3179: 100%|██████████| 3750/3750 [00:13<00:00, 275.97it/s]


validation accuracy: 0.8529
validation per_batch loss: 0.3963774658203125
**************************************************


epoch: 4, avg_train_loss: 0.2968: 100%|██████████| 3750/3750 [00:13<00:00, 278.49it/s]


validation accuracy: 0.8677
validation per_batch loss: 0.360337451171875
**************************************************


epoch: 5, avg_train_loss: 0.2822: 100%|██████████| 3750/3750 [00:13<00:00, 283.64it/s]


validation accuracy: 0.8773
validation per_batch loss: 0.34071748046875
**************************************************


epoch: 6, avg_train_loss: 0.2673: 100%|██████████| 3750/3750 [00:13<00:00, 283.46it/s]


validation accuracy: 0.877
validation per_batch loss: 0.356355224609375
**************************************************


epoch: 7, avg_train_loss: 0.2563: 100%|██████████| 3750/3750 [00:13<00:00, 282.98it/s]


validation accuracy: 0.8775
validation per_batch loss: 0.3431870361328125
**************************************************


epoch: 8, avg_train_loss: 0.2467: 100%|██████████| 3750/3750 [00:12<00:00, 290.36it/s]


validation accuracy: 0.8832
validation per_batch loss: 0.347034912109375
**************************************************


epoch: 9, avg_train_loss: 0.2379: 100%|██████████| 3750/3750 [00:12<00:00, 293.08it/s]


validation accuracy: 0.8775
validation per_batch loss: 0.3660745849609375
**************************************************


epoch: 10, avg_train_loss: 0.2276: 100%|██████████| 3750/3750 [00:12<00:00, 302.65it/s]


validation accuracy: 0.8759
validation per_batch loss: 0.3677202392578125
**************************************************


epoch: 11, avg_train_loss: 0.2214: 100%|██████████| 3750/3750 [00:12<00:00, 295.76it/s]


validation accuracy: 0.8863
validation per_batch loss: 0.341362841796875
**************************************************


epoch: 12, avg_train_loss: 0.2146: 100%|██████████| 3750/3750 [00:12<00:00, 296.88it/s]


validation accuracy: 0.8884
validation per_batch loss: 0.34800673828125
**************************************************


epoch: 13, avg_train_loss: 0.2051: 100%|██████████| 3750/3750 [00:12<00:00, 297.09it/s]


validation accuracy: 0.8777
validation per_batch loss: 0.3713918701171875
**************************************************


epoch: 14, avg_train_loss: 0.1994: 100%|██████████| 3750/3750 [00:12<00:00, 297.66it/s]


validation accuracy: 0.879
validation per_batch loss: 0.39333115234375
**************************************************


epoch: 15, avg_train_loss: 0.1954: 100%|██████████| 3750/3750 [00:12<00:00, 296.73it/s]


validation accuracy: 0.8914
validation per_batch loss: 0.3642329345703125
**************************************************


epoch: 16, avg_train_loss: 0.1889: 100%|██████████| 3750/3750 [00:12<00:00, 301.83it/s]


validation accuracy: 0.8655
validation per_batch loss: 0.488431201171875
**************************************************


epoch: 17, avg_train_loss: 0.1838: 100%|██████████| 3750/3750 [00:12<00:00, 300.56it/s]


validation accuracy: 0.8885
validation per_batch loss: 0.3623718505859375
**************************************************


epoch: 18, avg_train_loss: 0.1794: 100%|██████████| 3750/3750 [00:12<00:00, 298.55it/s]


validation accuracy: 0.885
validation per_batch loss: 0.3887227783203125
**************************************************


epoch: 19, avg_train_loss: 0.1748: 100%|██████████| 3750/3750 [00:12<00:00, 298.95it/s]


validation accuracy: 0.8795
validation per_batch loss: 0.4017836669921875
**************************************************


epoch: 20, avg_train_loss: 0.1696: 100%|██████████| 3750/3750 [00:12<00:00, 296.16it/s]


validation accuracy: 0.8847
validation per_batch loss: 0.4087773681640625
**************************************************


卷积

In [35]:
class ImgClassifier_conv(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding="valid")
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding="valid")
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(64 * 24 * 24, 10)  # 注意这里的输入尺寸需要根据卷积层的输出调整

    def forward(self, X):
        X = self.relu(self.conv1(X))
        X = self.relu(self.conv2(X))
        X = self.flatten(X)
        return self.fc(X)

model_conv = ImgClassifier_conv()
optimizer_conv = torch.optim.Adam(model_conv.parameters(), lr=lr)
train_model(model_conv, batch_size, train_data, test_data, loss_fn, optimizer_conv, epoch, device)

epoch: 1, avg_train_loss: 0.3720: 100%|██████████| 3750/3750 [00:14<00:00, 261.97it/s]


validation accuracy: 0.8822
validation per_batch loss: 0.318248095703125
**************************************************


epoch: 2, avg_train_loss: 0.2421: 100%|██████████| 3750/3750 [00:14<00:00, 265.21it/s]


validation accuracy: 0.9074
validation per_batch loss: 0.2603870849609375
**************************************************


epoch: 3, avg_train_loss: 0.1855: 100%|██████████| 3750/3750 [00:13<00:00, 269.03it/s]


validation accuracy: 0.9114
validation per_batch loss: 0.261960498046875
**************************************************


epoch: 4, avg_train_loss: 0.1446: 100%|██████████| 3750/3750 [00:13<00:00, 268.08it/s]


validation accuracy: 0.9118
validation per_batch loss: 0.273916357421875
**************************************************


epoch: 5, avg_train_loss: 0.1147: 100%|██████████| 3750/3750 [00:13<00:00, 269.55it/s]


validation accuracy: 0.9109
validation per_batch loss: 0.3017988037109375
**************************************************


epoch: 6, avg_train_loss: 0.0914: 100%|██████████| 3750/3750 [00:14<00:00, 266.07it/s]


validation accuracy: 0.9074
validation per_batch loss: 0.34546748046875
**************************************************


epoch: 7, avg_train_loss: 0.0740: 100%|██████████| 3750/3750 [00:14<00:00, 267.15it/s]


validation accuracy: 0.9062
validation per_batch loss: 0.3726767578125
**************************************************


epoch: 8, avg_train_loss: 0.0598: 100%|██████████| 3750/3750 [00:13<00:00, 267.94it/s]


validation accuracy: 0.9073
validation per_batch loss: 0.4082210693359375
**************************************************


epoch: 9, avg_train_loss: 0.0503: 100%|██████████| 3750/3750 [00:14<00:00, 264.29it/s]


validation accuracy: 0.9026
validation per_batch loss: 0.43705986328125
**************************************************


epoch: 10, avg_train_loss: 0.0418: 100%|██████████| 3750/3750 [00:13<00:00, 269.57it/s]


validation accuracy: 0.9037
validation per_batch loss: 0.464376123046875
**************************************************


epoch: 11, avg_train_loss: 0.0354: 100%|██████████| 3750/3750 [00:13<00:00, 268.10it/s]


validation accuracy: 0.904
validation per_batch loss: 0.505641943359375
**************************************************


epoch: 12, avg_train_loss: 0.0318: 100%|██████████| 3750/3750 [00:14<00:00, 265.26it/s]


validation accuracy: 0.9038
validation per_batch loss: 0.53517099609375
**************************************************


epoch: 13, avg_train_loss: 0.0276: 100%|██████████| 3750/3750 [00:13<00:00, 272.13it/s]


validation accuracy: 0.8998
validation per_batch loss: 0.59549306640625
**************************************************


epoch: 14, avg_train_loss: 0.0245: 100%|██████████| 3750/3750 [00:14<00:00, 266.80it/s]


validation accuracy: 0.9003
validation per_batch loss: 0.6547228515625
**************************************************


epoch: 15, avg_train_loss: 0.0222: 100%|██████████| 3750/3750 [00:14<00:00, 266.23it/s]


validation accuracy: 0.8983
validation per_batch loss: 0.67363046875
**************************************************


epoch: 16, avg_train_loss: 0.0205: 100%|██████████| 3750/3750 [00:14<00:00, 264.30it/s]


validation accuracy: 0.9001
validation per_batch loss: 0.69286455078125
**************************************************


epoch: 17, avg_train_loss: 0.0190: 100%|██████████| 3750/3750 [00:14<00:00, 267.17it/s]


validation accuracy: 0.8999
validation per_batch loss: 0.766913232421875
**************************************************


epoch: 18, avg_train_loss: 0.0176: 100%|██████████| 3750/3750 [00:14<00:00, 267.20it/s]


validation accuracy: 0.8996
validation per_batch loss: 0.803832080078125
**************************************************


epoch: 19, avg_train_loss: 0.0164: 100%|██████████| 3750/3750 [00:14<00:00, 265.05it/s]


validation accuracy: 0.8961
validation per_batch loss: 0.79076650390625
**************************************************


epoch: 20, avg_train_loss: 0.0162: 100%|██████████| 3750/3750 [00:14<00:00, 266.19it/s]


validation accuracy: 0.8984
validation per_batch loss: 0.93686328125
**************************************************


迁移学习

In [36]:
from torchvision.models import efficientnet_b0
model_efficientnet = efficientnet_b0(weights="DEFAULT")

In [37]:
print(model_efficientnet)

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivat

In [38]:
custom_classifier = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(1280, 10)
)

# 冻结底层网络
# model_efficientnet.features.trainable = False # torch.no_grad()
# 冻结所有参数
model_efficientnet.requires_grad_(False)

# 解冻自定义分类器的参数
model_efficientnet.classifier.requires_grad_(True)



model_efficientnet.classifier = custom_classifier

In [39]:
print(model_efficientnet)

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivat

In [40]:
model_efficientnet.features[0][0].weight[0]

tensor([[[ 0.1216,  0.6563,  0.4567],
         [-0.1109, -0.6100, -0.3334],
         [ 0.0280, -0.1031, -0.1032]],

        [[ 0.0636,  1.6552,  1.7436],
         [-0.1365, -1.5367, -1.5937],
         [ 0.0502, -0.1136, -0.1260]],

        [[ 0.0873,  0.3613,  0.2695],
         [-0.1197, -0.2812, -0.2188],
         [ 0.0367, -0.0708, -0.0819]]])

In [41]:
train_data = CIFAR10(
    root="../data",
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = CIFAR10(
    root="../data",
    train=False,
    download=True,
    transform=ToTensor()
)

Files already downloaded and verified
Files already downloaded and verified


In [42]:
optimizer_efficientnet = torch.optim.Adam(model_efficientnet.parameters(), lr = lr)

In [43]:
train_model(model_efficientnet, batch_size, train_data, test_data, loss_fn, optimizer_efficientnet, epoch, device)

epoch: 1, avg_train_loss: 2.0064: 100%|██████████| 3125/3125 [00:31<00:00, 98.89it/s] 


validation accuracy: 0.3145
validation per_batch loss: 2.026816015625
**************************************************


epoch: 2, avg_train_loss: 2.0054: 100%|██████████| 3125/3125 [00:33<00:00, 94.37it/s] 


validation accuracy: 0.3195
validation per_batch loss: 1.9668712890625
**************************************************


epoch: 3, avg_train_loss: 2.0054: 100%|██████████| 3125/3125 [00:31<00:00, 98.29it/s] 


validation accuracy: 0.314
validation per_batch loss: 2.014690625
**************************************************


epoch: 4, avg_train_loss: 1.9970: 100%|██████████| 3125/3125 [00:30<00:00, 100.94it/s]


validation accuracy: 0.3151
validation per_batch loss: 1.99359140625
**************************************************


epoch: 5, avg_train_loss: 2.0019: 100%|██████████| 3125/3125 [00:31<00:00, 98.72it/s] 


validation accuracy: 0.3127
validation per_batch loss: 1.9954884765625
**************************************************


epoch: 6, avg_train_loss: 2.0029: 100%|██████████| 3125/3125 [00:31<00:00, 100.32it/s]


validation accuracy: 0.3217
validation per_batch loss: 1.966137890625
**************************************************


epoch: 7, avg_train_loss: 2.0073: 100%|██████████| 3125/3125 [00:31<00:00, 98.39it/s] 


validation accuracy: 0.3197
validation per_batch loss: 1.961698046875
**************************************************


epoch: 8, avg_train_loss: 2.0008: 100%|██████████| 3125/3125 [00:31<00:00, 98.66it/s] 


validation accuracy: 0.3154
validation per_batch loss: 2.0008875
**************************************************


epoch: 9, avg_train_loss: 2.0034: 100%|██████████| 3125/3125 [00:32<00:00, 96.59it/s] 


validation accuracy: 0.3173
validation per_batch loss: 1.99473359375
**************************************************


epoch: 10, avg_train_loss: 1.9980: 100%|██████████| 3125/3125 [00:29<00:00, 104.63it/s]


validation accuracy: 0.3133
validation per_batch loss: 1.9955732421875
**************************************************


epoch: 11, avg_train_loss: 1.9992: 100%|██████████| 3125/3125 [00:30<00:00, 103.18it/s]


validation accuracy: 0.3083
validation per_batch loss: 1.9851837890625
**************************************************


epoch: 12, avg_train_loss: 2.0029: 100%|██████████| 3125/3125 [00:31<00:00, 100.64it/s]


validation accuracy: 0.3141
validation per_batch loss: 2.041709375
**************************************************


epoch: 13, avg_train_loss: 1.9958: 100%|██████████| 3125/3125 [00:30<00:00, 103.43it/s]


validation accuracy: 0.3315
validation per_batch loss: 1.962728515625
**************************************************


epoch: 14, avg_train_loss: 2.0049: 100%|██████████| 3125/3125 [00:30<00:00, 100.88it/s]


validation accuracy: 0.3167
validation per_batch loss: 1.9671353515625
**************************************************


epoch: 15, avg_train_loss: 1.9991: 100%|██████████| 3125/3125 [00:32<00:00, 95.84it/s] 


validation accuracy: 0.3203
validation per_batch loss: 1.959537109375
**************************************************


epoch: 16, avg_train_loss: 2.0080: 100%|██████████| 3125/3125 [00:31<00:00, 98.42it/s] 


validation accuracy: 0.3152
validation per_batch loss: 1.9769986328125
**************************************************


epoch: 17, avg_train_loss: 1.9999: 100%|██████████| 3125/3125 [00:32<00:00, 97.33it/s] 


validation accuracy: 0.3083
validation per_batch loss: 1.994339453125
**************************************************


epoch: 18, avg_train_loss: 2.0050: 100%|██████████| 3125/3125 [00:30<00:00, 101.34it/s]


validation accuracy: 0.3097
validation per_batch loss: 2.0061033203125
**************************************************


epoch: 19, avg_train_loss: 2.0033: 100%|██████████| 3125/3125 [00:30<00:00, 101.10it/s]


validation accuracy: 0.3168
validation per_batch loss: 1.9621154296875
**************************************************


epoch: 20, avg_train_loss: 1.9968: 100%|██████████| 3125/3125 [00:30<00:00, 102.85it/s]


validation accuracy: 0.3098
validation per_batch loss: 1.9883439453125
**************************************************


In [44]:
model_efficientnet.features[0][0].weight[0]

tensor([[[ 0.1216,  0.6563,  0.4567],
         [-0.1109, -0.6100, -0.3334],
         [ 0.0280, -0.1031, -0.1032]],

        [[ 0.0636,  1.6552,  1.7436],
         [-0.1365, -1.5367, -1.5937],
         [ 0.0502, -0.1136, -0.1260]],

        [[ 0.0873,  0.3613,  0.2695],
         [-0.1197, -0.2812, -0.2188],
         [ 0.0367, -0.0708, -0.0819]]], device='cuda:0')