In [1]:
import torch
from torch.nn import (
    Linear,
    Conv2d,
    AvgPool2d,
    Module,
    Sequential,
    ReLU,
    Flatten,
    MaxPool2d
)

class LetNet(Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv = Sequential(
            Conv2d(1,6,kernel_size=5,padding=2),
            ReLU(),
            AvgPool2d(2),

            Conv2d(6,16,kernel_size=5),
            ReLU(),
            AvgPool2d(2),

            Flatten(),
        )
        self.linear = Sequential(
            Linear(400,120),
            ReLU(),

            Linear(120,84),
            ReLU(),

            Linear(84,10)
        )

    def forward(self,input):
        input = input.reshape((-1,1,28,28))
        input = self.conv(input)
        output = self.linear(input)

        return output

net = LetNet()
def init_net_parpmter(layer):
    if isinstance(layer,(Linear,Conv2d)):
        torch.nn.init.kaiming_normal_(layer.weight,mode="fan_in",nonlinearity='relu')
net.apply(init_net_parpmter)


LetNet(
  (conv): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (6): Flatten(start_dim=1, end_dim=-1)
  )
  (linear): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU()
    (4): Linear(in_features=84, out_features=10, bias=True)
  )
)

In [2]:
def get_device(chose_device = 0 ):
    return f'cuda:{chose_device}' if torch.cuda.is_available() else "cpu"

device = get_device()

net.to(device)


LetNet(
  (conv): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (6): Flatten(start_dim=1, end_dim=-1)
  )
  (linear): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU()
    (4): Linear(in_features=84, out_features=10, bias=True)
  )
)

In [3]:
from torchvision import transforms
import torchvision

trans = transforms.Compose(
    [
        transforms.ToTensor()
    ]
)
mninst_train = torchvision.datasets.FashionMNIST(
    root='FashionMINIST',
    train=True,
    download=True,
    transform=trans
    )
mninst_text = torchvision.datasets. FashionMNIST(
    root="FashionMINIST",
    train=False,
    download=True,
    transform=trans
)


len(mninst_train),len(mninst_text)

(60000, 10000)

In [4]:
from torch.utils.data import DataLoader

def get_dataloader(dataset,mode,batch_size=256):
    
    return DataLoader(
        dataset=dataset,
        shuffle= ('train' == mode),
        drop_last= ('train' == mode),
        batch_size=batch_size
    )

train_dataloader = get_dataloader(mninst_train,'train')
test_dataloader = get_dataloader(mninst_text,'test')


In [5]:
n_epoch = 50
lossfunction = torch.nn.CrossEntropyLoss()
optimer = getattr(torch.optim,'Adam')(net.parameters(),lr=0.005,weight_decay=1e-4)
from tqdm.auto import  tqdm
def val(val_dataloader , model , device):
    model.eval()
    with torch.no_grad():
        acc = 0
        run =0 
        for val_feature , val_label in tqdm(val_dataloader):
            val_feature = val_feature.to(device)
            run +=1
            val_label = val_label.to(device)

            y_predict = net(val_feature)
            max_index = torch.argmax(y_predict,1)

            acc += (max_index == val_label).float().mean().item()

    return acc/run * 100



net.train()
for epoch in tqdm(range(n_epoch)):
    acc = 0
    run =0 
    for train_feature,train_label in tqdm(train_dataloader):
        run += 1
        train_feature = train_feature.to(device)
        train_label = train_label.to(device)
        y_hat = net(train_feature)

        max_index = torch.argmax(y_hat,1)
        optimer.zero_grad()
        loss = lossfunction(y_hat,train_label).to(device)
        loss.backward()
        optimer.step()
        acc += (max_index == train_label).float().mean().item()
    
    print("train:",acc/run*100)

    accuracy = val(test_dataloader,net,device)
    print("val:",accuracy)

    


  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 234/234 [00:02<00:00, 115.33it/s]


train: 80.13989049145299


100%|██████████| 40/40 [00:00<00:00, 157.56it/s]
  2%|▏         | 1/50 [00:02<01:52,  2.29s/it]

val: 85.380859375


100%|██████████| 234/234 [00:01<00:00, 139.85it/s]


train: 87.0693108974359


100%|██████████| 40/40 [00:00<00:00, 163.65it/s]
  4%|▍         | 2/50 [00:04<01:39,  2.07s/it]

val: 87.265625


100%|██████████| 234/234 [00:01<00:00, 137.62it/s]


train: 88.73030181623932


100%|██████████| 40/40 [00:00<00:00, 162.66it/s]
  6%|▌         | 3/50 [00:06<01:34,  2.02s/it]

val: 88.53515625


100%|██████████| 234/234 [00:01<00:00, 136.27it/s]


train: 89.47148771367522


100%|██████████| 40/40 [00:00<00:00, 161.63it/s]
  8%|▊         | 4/50 [00:08<01:31,  2.00s/it]

val: 87.607421875


 12%|█▏        | 29/234 [00:00<00:01, 131.75it/s]
  8%|▊         | 4/50 [00:08<01:35,  2.09s/it]


KeyboardInterrupt: 