In [32]:
import torch
from torch.nn import (
    Module,
    Linear,
    ReLU,
    Conv2d,
    Sequential,
    MaxPool2d,
    Flatten,
    Dropout
)


In [33]:
class AlexNet(Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.conv = Sequential(
            Conv2d(1,96,kernel_size=11,stride=4,padding=1),
            ReLU(),
            MaxPool2d(kernel_size=3,stride=2),

            Conv2d(96,256,kernel_size=3,padding=1),
            ReLU(),

            Conv2d(256,384,kernel_size=3,padding=1),
            ReLU(),

            Conv2d(384,384,kernel_size=3,padding=1),
            ReLU(),

            Conv2d(384,384,kernel_size=3,padding=1),
            ReLU(),

            MaxPool2d(kernel_size=3,stride=2)

        )
        self.linear = Sequential(
            Flatten(),

            Linear(1536,4096),
            ReLU(),
            Dropout(0.5),

            Linear(4096,4096),
            ReLU(),
            Dropout(0.5),

            Linear(4096,10)
            
        )

    def forward(self,input):
        output = self.conv(input)
        output = self.linear(output)

        return output
    
net = AlexNet()
def init_net_parpmter(layer):
    if isinstance(layer,(Linear,Conv2d)):
        torch.nn.init.kaiming_normal_(layer.weight,mode="fan_in",nonlinearity='relu')
net.apply(init_net_parpmter)


AlexNet(
  (conv): Sequential(
    (0): Conv2d(1, 96, kernel_size=(11, 11), stride=(4, 4), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(96, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU()
    (7): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): ReLU()
    (11): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (linear): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=1536, out_features=4096, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU()
    (6): Dropout(p=0.5, inplace=False)
    (7): Linear(in_features=

In [34]:
def get_device(chose_device = 0 ):
    return f'cuda:{chose_device}' if torch.cuda.is_available() else "cpu"

device = get_device()

net.to(device)

AlexNet(
  (conv): Sequential(
    (0): Conv2d(1, 96, kernel_size=(11, 11), stride=(4, 4), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(96, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU()
    (7): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): ReLU()
    (11): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (linear): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=1536, out_features=4096, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU()
    (6): Dropout(p=0.5, inplace=False)
    (7): Linear(in_features=

In [35]:
from torchvision import transforms
import torchvision

trans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((64,64))
    ]
)
mninst_train = torchvision.datasets.FashionMNIST(
    root='FashionMINIST',
    train=True,
    download=False,
    transform=trans
    )
mninst_text = torchvision.datasets. FashionMNIST(
    root="FashionMINIST",
    train=False,
    download=False,
    transform=trans
)


len(mninst_train),len(mninst_text)

(60000, 10000)

In [36]:
from torch.utils.data import DataLoader

def get_dataloader(dataset,mode,batch_size=128):
    
    return DataLoader(
        dataset=dataset,
        shuffle= ('train' == mode),
        drop_last= ('train' == mode),
        batch_size=batch_size
    )

train_dataloader = get_dataloader(mninst_train,'train')
test_dataloader = get_dataloader(mninst_text,'test')


In [37]:
n_epoch = 50
lossfunction = torch.nn.CrossEntropyLoss()
optimer = getattr(torch.optim,'Adam')(net.parameters(),lr=0.01,weight_decay=1e-4)
from tqdm.auto import  tqdm
def val(val_dataloader , model , device):
    model.eval()
    with torch.no_grad():
        acc = 0
        run =0 
        for val_feature , val_label in tqdm(val_dataloader):
            val_feature = val_feature.to(device)
            run +=1
            val_label = val_label.to(device)

            y_predict = net(val_feature)
            max_index = torch.argmax(y_predict,1)

            acc += (max_index == val_label).float().mean().item()

    return acc/run * 100



net.train()
for epoch in tqdm(range(n_epoch)):
    acc = 0
    run =0 
    for train_feature,train_label in tqdm(train_dataloader):
        run += 1
        train_feature = train_feature.to(device)
        train_label = train_label.to(device)
        y_hat = net(train_feature)

        max_index = torch.argmax(y_hat,1)
        optimer.zero_grad()
        loss = lossfunction(y_hat,train_label).to(device)
        loss.backward()
        optimer.step()
        acc += (max_index == train_label).float().mean().item()
    
    print("train:",acc/run*100)

    accuracy = val(test_dataloader,net,device)
    print("val:",accuracy)

    


100%|██████████| 468/468 [00:15<00:00, 29.33it/s]


train: 52.448918269230774


100%|██████████| 79/79 [00:00<00:00, 98.32it/s]
  2%|▏         | 1/50 [00:16<13:41, 16.76s/it]

val: 65.82278481012658


100%|██████████| 468/468 [00:16<00:00, 29.24it/s]


train: 74.2988782051282


100%|██████████| 79/79 [00:00<00:00, 91.53it/s]
  4%|▍         | 2/50 [00:33<13:27, 16.83s/it]

val: 75.07911392405063


100%|██████████| 468/468 [00:16<00:00, 28.75it/s]


train: 76.59922542735043


100%|██████████| 79/79 [00:00<00:00, 85.37it/s]
  6%|▌         | 3/50 [00:50<13:19, 17.00s/it]

val: 77.48219936708861


100%|██████████| 468/468 [00:16<00:00, 28.58it/s]


train: 77.65257745726495


100%|██████████| 79/79 [00:00<00:00, 87.66it/s]
  8%|▊         | 4/50 [01:08<13:07, 17.11s/it]

val: 76.76028481012658


100%|██████████| 468/468 [00:16<00:00, 28.39it/s]


train: 78.22182158119658


100%|██████████| 79/79 [00:00<00:00, 87.47it/s]
 10%|█         | 5/50 [01:25<12:54, 17.21s/it]

val: 78.25356012658227


100%|██████████| 468/468 [00:16<00:00, 28.38it/s]


train: 79.34027777777779


100%|██████████| 79/79 [00:00<00:00, 86.34it/s]
 12%|█▏        | 6/50 [01:42<12:40, 17.28s/it]

val: 77.90743670886076


100%|██████████| 468/468 [00:16<00:00, 28.27it/s]


train: 79.80435363247864


100%|██████████| 79/79 [00:00<00:00, 86.94it/s]
 14%|█▍        | 7/50 [02:00<12:25, 17.34s/it]

val: 79.3117088607595


100%|██████████| 468/468 [00:16<00:00, 28.45it/s]


train: 80.06810897435898


100%|██████████| 79/79 [00:00<00:00, 85.94it/s]
 16%|█▌        | 8/50 [02:17<12:08, 17.35s/it]

val: 79.49960443037975


100%|██████████| 468/468 [00:16<00:00, 28.58it/s]


train: 80.75754540598291


100%|██████████| 79/79 [00:00<00:00, 87.05it/s]
 18%|█▊        | 9/50 [02:35<11:50, 17.33s/it]

val: 80.17207278481013


100%|██████████| 468/468 [00:16<00:00, 29.18it/s]


train: 80.62566773504274


100%|██████████| 79/79 [00:00<00:00, 87.74it/s]
 20%|██        | 10/50 [02:51<11:28, 17.21s/it]

val: 79.16337025316456


100%|██████████| 468/468 [00:16<00:00, 28.54it/s]


train: 81.03298611111111


100%|██████████| 79/79 [00:00<00:00, 86.84it/s]
 22%|██▏       | 11/50 [03:09<11:12, 17.24s/it]

val: 79.94462025316456


100%|██████████| 468/468 [00:16<00:00, 28.81it/s]


train: 81.35516826923077


100%|██████████| 79/79 [00:00<00:00, 86.41it/s]
 24%|██▍       | 12/50 [03:26<10:54, 17.22s/it]

val: 81.26977848101265


100%|██████████| 468/468 [00:16<00:00, 28.49it/s]


train: 81.51876335470085


100%|██████████| 79/79 [00:00<00:00, 85.63it/s]
 26%|██▌       | 13/50 [03:43<10:38, 17.26s/it]

val: 81.35878164556962


100%|██████████| 468/468 [00:16<00:00, 28.79it/s]


train: 81.63895566239316


100%|██████████| 79/79 [00:00<00:00, 85.33it/s]
 28%|██▊       | 14/50 [04:01<10:20, 17.24s/it]

val: 80.35007911392405


100%|██████████| 468/468 [00:16<00:00, 28.70it/s]


train: 82.0713141025641


100%|██████████| 79/79 [00:00<00:00, 86.93it/s]
 30%|███       | 15/50 [04:18<10:03, 17.23s/it]

val: 81.40822784810126


100%|██████████| 468/468 [00:15<00:00, 29.29it/s]


train: 82.33673878205127


100%|██████████| 79/79 [00:00<00:00, 85.78it/s]
 32%|███▏      | 16/50 [04:35<09:42, 17.13s/it]

val: 81.6257911392405


100%|██████████| 468/468 [00:16<00:00, 29.24it/s]


train: 82.60550213675214


100%|██████████| 79/79 [00:00<00:00, 85.30it/s]
 34%|███▍      | 17/50 [04:52<09:23, 17.07s/it]

val: 81.05221518987342


100%|██████████| 468/468 [00:16<00:00, 27.97it/s]


train: 82.13641826923077


100%|██████████| 79/79 [00:01<00:00, 72.77it/s]
 36%|███▌      | 18/50 [05:09<09:13, 17.30s/it]

val: 82.22903481012658


100%|██████████| 468/468 [00:16<00:00, 28.65it/s]


train: 82.98444177350427


100%|██████████| 79/79 [00:00<00:00, 86.15it/s]
 38%|███▊      | 19/50 [05:27<08:55, 17.29s/it]

val: 83.2871835443038


100%|██████████| 468/468 [00:16<00:00, 29.10it/s]


train: 82.88094284188034


100%|██████████| 79/79 [00:00<00:00, 87.28it/s]
 40%|████      | 20/50 [05:44<08:35, 17.20s/it]

val: 82.51582278481013


100%|██████████| 468/468 [00:16<00:00, 28.92it/s]


train: 83.17307692307693


100%|██████████| 79/79 [00:00<00:00, 85.60it/s]
 42%|████▏     | 21/50 [06:01<08:17, 17.17s/it]

val: 83.13884493670885


100%|██████████| 468/468 [00:16<00:00, 28.72it/s]


train: 83.35002670940172


100%|██████████| 79/79 [00:00<00:00, 87.30it/s]
 44%|████▍     | 22/50 [06:18<08:01, 17.18s/it]

val: 81.98180379746836


100%|██████████| 468/468 [00:16<00:00, 28.93it/s]


train: 83.3717280982906


100%|██████████| 79/79 [00:00<00:00, 87.52it/s]
 46%|████▌     | 23/50 [06:35<07:43, 17.15s/it]

val: 81.93235759493672


 43%|████▎     | 199/468 [00:07<00:09, 28.11it/s]
 46%|████▌     | 23/50 [06:42<07:52, 17.50s/it]


KeyboardInterrupt: 