In [16]:
import torch
from torch.nn import (
    Module,
    Linear,
    ReLU,
    Conv2d,
    Sequential,
    MaxPool2d,
    Flatten,
    Dropout
)

In [17]:
def vgg_block(num_convs,in_channels,out_channels):
    layers = []
    for i in range(num_convs):
        layers.append(
                Conv2d(in_channels,out_channels,kernel_size=3,padding=1)
            )
        in_channels = out_channels
        layers.append(ReLU())
    layers.append(MaxPool2d(kernel_size=2,stride=2))

    return Sequential(*layers)

conv_arch = ((1,64),(1,128),(2,256),(2,512),(2,512))

def vgg(conv_arch):
    conv_blks = []
    in_channels = 1
    for (num_convs,out_channels) in conv_arch:
        conv_blks.append(
            vgg_block(
                num_convs,
                in_channels,
                out_channels
            )
        )
        in_channels = out_channels

    return Sequential(
        *conv_blks,
        Flatten(),

        Linear(out_channels * 2 * 2,4096),
        ReLU(),
        Dropout(0.5),

        Linear(4096,4096),
        ReLU(),
        Dropout(0.5),

        Linear(4096,10)
    )

smmall_conv_vgg = [(pair[0],pair[1]//4) for pair in conv_arch]
net = vgg(smmall_conv_vgg)
def init_net_parpmter(layer):
    if isinstance(layer,(Linear,Conv2d)):
        torch.nn.init.kaiming_normal_(layer.weight,mode="fan_in",nonlinearity='relu')
net.apply(init_net_parpmter)


Sequential(
  (0): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (1): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (3): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (4): Seque

In [18]:
def get_device(chose_device = 0 ):
    return f'cuda:{chose_device}' if torch.cuda.is_available() else "cpu"

device = get_device()

net.to(device)
from torchvision import transforms
import torchvision

trans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((64,64))
    ]
)
mninst_train = torchvision.datasets.FashionMNIST(
    root='FashionMINIST',
    train=True,
    download=False,
    transform=trans
    )
mninst_text = torchvision.datasets. FashionMNIST(
    root="FashionMINIST",
    train=False,
    download=False,
    transform=trans
)


len(mninst_train),len(mninst_text)
from torch.utils.data import DataLoader

def get_dataloader(dataset,mode,batch_size=128):
    
    return DataLoader(
        dataset=dataset,
        shuffle= ('train' == mode),
        drop_last= ('train' == mode),
        batch_size=batch_size
    )

train_dataloader = get_dataloader(mninst_train,'train')
test_dataloader = get_dataloader(mninst_text,'test')


In [19]:
n_epoch = 10
lossfunction = torch.nn.CrossEntropyLoss()
optimer = getattr(torch.optim,'Adam')(net.parameters(),lr=0.005,weight_decay=1e-4)
from tqdm.auto import  tqdm
def val(val_dataloader , model , device):
    model.eval()
    with torch.no_grad():
        acc = 0
        run =0 
        for val_feature , val_label in tqdm(val_dataloader):
            val_feature = val_feature.to(device)
            run +=1
            val_label = val_label.to(device)

            y_predict = net(val_feature)
            max_index = torch.argmax(y_predict,1)

            acc += (max_index == val_label).float().mean().item()

    return acc/run * 100



net.train()
for epoch in tqdm(range(n_epoch)):
    acc = 0
    run =0 
    for train_feature,train_label in tqdm(train_dataloader):
        run += 1
        train_feature = train_feature.to(device)
        train_label = train_label.to(device)
        y_hat = net(train_feature)

        max_index = torch.argmax(y_hat,1)
        optimer.zero_grad()
        loss = lossfunction(y_hat,train_label).to(device)
        loss.backward()
        optimer.step()
        acc += (max_index == train_label).float().mean().item()
    
    print("train:",acc/run*100)

    accuracy = val(test_dataloader,net,device)
    print("val:",accuracy)

    


100%|██████████| 468/468 [00:13<00:00, 34.37it/s]


train: 74.65110844017094


100%|██████████| 79/79 [00:00<00:00, 93.40it/s] 
 10%|█         | 1/10 [00:14<02:10, 14.47s/it]

val: 83.59375


100%|██████████| 468/468 [00:13<00:00, 34.35it/s]


train: 86.11111111111111


100%|██████████| 79/79 [00:01<00:00, 74.22it/s]
 20%|██        | 2/10 [00:29<01:56, 14.60s/it]

val: 85.68037974683544


100%|██████████| 468/468 [00:14<00:00, 32.36it/s]


train: 87.7220219017094


100%|██████████| 79/79 [00:01<00:00, 67.18it/s]
 30%|███       | 3/10 [00:44<01:45, 15.08s/it]

val: 86.50118670886076


100%|██████████| 468/468 [00:13<00:00, 34.14it/s]


train: 88.21614583333334


100%|██████████| 79/79 [00:00<00:00, 105.49it/s]
 40%|████      | 4/10 [00:59<01:28, 14.83s/it]

val: 86.85719936708861


100%|██████████| 468/468 [00:12<00:00, 36.56it/s]


train: 88.96400908119658


100%|██████████| 79/79 [00:01<00:00, 68.23it/s]
 50%|█████     | 5/10 [01:13<01:12, 14.52s/it]

val: 87.30221518987342


100%|██████████| 468/468 [00:13<00:00, 34.60it/s]


train: 89.36965811965813


100%|██████████| 79/79 [00:01<00:00, 70.49it/s]
 60%|██████    | 6/10 [01:27<00:58, 14.56s/it]

val: 89.22072784810126


100%|██████████| 468/468 [00:12<00:00, 36.35it/s]


train: 89.81537126068376


100%|██████████| 79/79 [00:00<00:00, 94.87it/s]
 70%|███████   | 7/10 [01:41<00:42, 14.28s/it]

val: 88.37025316455697


100%|██████████| 468/468 [00:13<00:00, 35.18it/s]


train: 90.11585202991454


100%|██████████| 79/79 [00:00<00:00, 80.13it/s]
 80%|████████  | 8/10 [01:55<00:28, 14.29s/it]

val: 89.70530063291139


100%|██████████| 468/468 [00:13<00:00, 35.04it/s]


train: 90.24439102564102


100%|██████████| 79/79 [00:00<00:00, 101.78it/s]
 90%|█████████ | 9/10 [02:10<00:14, 14.24s/it]

val: 89.58662974683544


100%|██████████| 468/468 [00:14<00:00, 33.23it/s]


train: 90.54821047008546


100%|██████████| 79/79 [00:00<00:00, 96.29it/s]
100%|██████████| 10/10 [02:24<00:00, 14.49s/it]

val: 89.80419303797468



