In [32]:
import torch
from torch.nn import (
    Module,
    Linear,
    ReLU,
    Conv2d,
    Sequential,
    MaxPool2d,
    Flatten,
    Dropout,
    AdaptiveAvgPool2d
)

In [33]:
def nin_block(in_channel,out_channel,kernel_size, strides,padding):
    return Sequential(
        Conv2d(in_channel,out_channel,kernel_size=kernel_size,stride=strides,padding=padding),
        ReLU(),
        Conv2d(out_channel,out_channel,kernel_size=1),
        ReLU(),
        Conv2d(out_channel,out_channel,kernel_size=1),
        ReLU(),
    )

In [34]:
net = Sequential(
    nin_block(1,96,11,4,0),
    MaxPool2d(3,stride=2),
    nin_block(96,256,5,1,2),
    MaxPool2d(3,stride=2),
    nin_block(256,384,3,1,1),
    MaxPool2d(3,stride=2),
    Dropout(0.5),
    nin_block(384,10,3,1,1),
    AdaptiveAvgPool2d((1,1)),
    Flatten()
)
x = torch.rand(size=(1,1,224,224))
try:
    for index,layer in enumerate(net):
        x = layer(x)
        print(layer.__class__.__name__ , "output_size",x.shape)
except Exception as e :
    print(index)

Sequential output_size torch.Size([1, 96, 54, 54])
MaxPool2d output_size torch.Size([1, 96, 26, 26])
Sequential output_size torch.Size([1, 256, 26, 26])
MaxPool2d output_size torch.Size([1, 256, 12, 12])
Sequential output_size torch.Size([1, 384, 12, 12])
MaxPool2d output_size torch.Size([1, 384, 5, 5])
Dropout output_size torch.Size([1, 384, 5, 5])
Sequential output_size torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d output_size torch.Size([1, 10, 1, 1])
Flatten output_size torch.Size([1, 10])


In [37]:
def init_net_parpmter(layer):
    if isinstance(layer,(Linear,Conv2d)):
        torch.nn.init.kaiming_normal_(layer.weight,mode="fan_in",nonlinearity='relu')
net.apply(init_net_parpmter)
def get_device(chose_device = 0 ):
    return f'cuda:{chose_device}' if torch.cuda.is_available() else "cpu"

device = get_device()

net.to(device)
from torchvision import transforms
import torchvision

trans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224,224))
    ]
)
mninst_train = torchvision.datasets.FashionMNIST(
    root='FashionMINIST',
    train=True,
    download=False,
    transform=trans
    )
mninst_text = torchvision.datasets. FashionMNIST(
    root="FashionMINIST",
    train=False,
    download=False,
    transform=trans
)


len(mninst_train),len(mninst_text)
from torch.utils.data import DataLoader

def get_dataloader(dataset,mode,batch_size=128):
    
    return DataLoader(
        dataset=dataset,
        shuffle= ('train' == mode),
        drop_last= ('train' == mode),
        batch_size=batch_size
    )

train_dataloader = get_dataloader(mninst_train,'train')
test_dataloader = get_dataloader(mninst_text,'test')
n_epoch = 10
lossfunction = torch.nn.CrossEntropyLoss()
optimer = getattr(torch.optim,'Adam')(net.parameters(),lr=0.0001,weight_decay=1e-4)
from tqdm.auto import  tqdm
def val(val_dataloader , model , device):
    model.eval()
    with torch.no_grad():
        acc = 0
        run =0 
        for val_feature , val_label in tqdm(val_dataloader):
            val_feature = val_feature.to(device)
            run +=1
            val_label = val_label.to(device)

            y_predict = net(val_feature)
            max_index = torch.argmax(y_predict,1)

            acc += (max_index == val_label).float().mean().item()

    return acc/run * 100



net.train()
for epoch in tqdm(range(n_epoch)):
    acc = 0
    run =0 
    for train_feature,train_label in tqdm(train_dataloader):
        run += 1
        train_feature = train_feature.to(device)
        train_label = train_label.to(device)
        y_hat = net(train_feature)

        max_index = torch.argmax(y_hat,1)
        optimer.zero_grad()
        loss = lossfunction(y_hat,train_label).to(device)
        loss.backward()
        optimer.step()
        acc += (max_index == train_label).float().mean().item()
    
    print("train:",acc/run*100)

    accuracy = val(test_dataloader,net,device)
    print("val:",accuracy)

    

100%|██████████| 468/468 [01:09<00:00,  6.73it/s]


train: 54.8828125


100%|██████████| 79/79 [00:05<00:00, 15.78it/s]
 10%|█         | 1/10 [01:14<11:10, 74.55s/it]

val: 74.1495253164557


  4%|▍         | 21/468 [00:03<01:06,  6.72it/s]
 10%|█         | 1/10 [01:17<11:39, 77.68s/it]


KeyboardInterrupt: 