In [None]:
import torch
import torchvision
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils import data
from torchvision import transforms
from tqdm import tqdm

# Reproducibility: a single seed shared by PyTorch (CPU and, when present,
# every CUDA device) and NumPy.
myseed = 1314
torch.manual_seed(myseed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(myseed)
np.random.seed(myseed)

# Ask cuDNN for deterministic kernels and disable its benchmark autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Preprocessing applied to every sample: Resize(224) first, then ToTensor().
trans = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])

# Fashion-MNIST train / test splits, stored under ./MNIST and downloaded on
# first use; both splits share the same preprocessing.
train_data = torchvision.datasets.FashionMNIST(
    root="./MNIST", train=True, download=True, transform=trans)
test_data = torchvision.datasets.FashionMNIST(
    root="./MNIST", train=False, download=True, transform=trans)
# Use the GPU when CUDA is actually usable, otherwise fall back to the CPU.
# `is_available` has to be *called*: the bare function object is always
# truthy, which would select 'cuda' even on machines without a GPU.
device  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Mini-batch size used for the training loader.
BATCH_SIZE = 256

# Training batches are reshuffled on every pass; the test loader keeps the
# dataset order and serves 1000 samples per batch.
train_dataloader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = data.DataLoader(test_data, batch_size=1000)

# Human-readable class names, indexed by integer label id (0-9).
class_str = ["T-shirt","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot"]

def get_class(indices):
    """Return the class name for a single label id.

    Args:
        indices: One label id; anything `int()` accepts (a Python int, a
            0-d / single-element tensor, ...).

    Returns:
        The matching entry of `class_str`.

    Raises:
        IndexError: If the id is outside the range of `class_str`.
    """
    return class_str[int(indices)]

def show_img(img,label,nrow,ncol,scale=1.5):
    """Plot a grid of grayscale images, each titled with its class name.

    Args:
        img: Iterable of images (e.g. a batch tensor); each item is squeezed
            before being drawn.
        label: Indexable labels aligned with `img`; `label[i]` titles the
            i-th image via `get_class`.
        nrow: Number of grid rows.
        ncol: Number of grid columns.
        scale: Width/height of one grid cell in inches.

    At most `nrow * ncol` images are drawn; extra images are ignored.
    """
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 (or 1xN) grid;
    # without it a 1x1 grid yields a bare Axes object that has no .flatten().
    _, axes = plt.subplots(nrow, ncol, figsize=(ncol*scale, nrow*scale), squeeze=False)
    # zip stops at whichever runs out first: the images or the grid cells.
    for i, (image, ax) in enumerate(zip(img, axes.flatten())):
        ax.imshow(image.squeeze(), cmap="Greys")
        ax.set_title(get_class(label[i]))
        ax.axis('off')
    plt.show()

# Pull one training batch, preview it on a 2x9 grid (so only the first 18
# images are drawn), then report the shape of the batch tensor.
x, y = next(iter(train_dataloader))
show_img(img=x, label=y, nrow=2, ncol=9)
print(x.shape)

In [None]:
class AlexNet(nn.Module):
    """AlexNet-style convolutional classifier.

    Five convolutional layers (with three max-pooling stages) feed a
    three-layer fully connected head with dropout. The head's first Linear
    layer expects 256*5*5 flattened features, which is what the convolutional
    stack produces for 224x224 inputs; other input sizes will not fit.

    Args:
        in_channels: Number of channels of the input images (default 1).
        num_classes: Number of output scores per sample (default 10).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        super().__init__()
        # Kept as one flat Sequential so callers can iterate the layers in
        # order and so parameter names (layers.<idx>.*) stay stable.
        self.layers = nn.Sequential(
            # Convolutional feature extractor.
            nn.Conv2d(in_channels,96,kernel_size=11,stride=4,padding=1),nn.ReLU(),
            nn.MaxPool2d(kernel_size=3,stride=2),
            nn.Conv2d(96,256,kernel_size=5,stride=1,padding=2),nn.ReLU(),
            nn.MaxPool2d(kernel_size=3,stride=2),
            nn.Conv2d(256,384,kernel_size=3,stride=1,padding=1),nn.ReLU(),
            nn.Conv2d(384,384,kernel_size=3,stride=1,padding=1),nn.ReLU(),
            nn.Conv2d(384,256,kernel_size=3,stride=1,padding=1),nn.ReLU(),
            nn.MaxPool2d(kernel_size=3,stride=2),nn.Flatten(1),
            # Fully connected classifier head; outputs raw (unnormalised) scores.
            nn.Linear(256*5*5,4096),nn.ReLU(),nn.Dropout(p=0.5),
            nn.Linear(4096,4096),nn.ReLU(),nn.Dropout(p=0.5),
            nn.Linear(4096,num_classes),
        )

    def forward(self,x):
        """Run a batch shaped (N, in_channels, 224, 224) through the network.

        Returns:
            A (N, num_classes) tensor of class scores.
        """
        return self.layers(x)
    
# Shape probe: push a dummy batch through the network one layer at a time and
# print each layer's output shape. The dummy input is drawn *before* the model
# is built so the global RNG is consumed in the same order as before.
x = torch.randn((2,1,224,224))
alexnet = AlexNet()
for layer in alexnet.layers.children():
    x = layer(x)
    print("Name: %-10s"%(type(layer).__name__),'output shape: ',x.shape)

def init_weights(m):
    """Xavier-uniform initialise the weight of a Linear or Conv2d module.

    Only modules whose type is exactly `nn.Linear` or `nn.Conv2d` are touched
    (subclasses are skipped); biases are left as they are. Intended to be
    passed to `Module.apply`.
    """
    if type(m) in (nn.Linear, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)
# Re-initialise the model: Module.apply calls init_weights on every submodule
# and on the model itself.
alexnet.apply(init_weights)

In [None]:
def accuracy(y_pred,y):
    """Print the accuracy of one batch and return the raw counts.

    Args:
        y_pred: Class scores; the predicted class is the argmax along dim 1.
        y: Ground-truth labels; its first dimension is the batch size.

    Returns:
        A tuple `(right, total)`: `right` is a 0-d tensor holding the number
        of correct predictions, `total` is `y.shape[0]`.
    """
    # Take the argmax of the scores directly. Softmax is monotonic, so it can
    # never change the winning class on purpose, but in float32 it can round
    # nearly-equal scores into an exact tie and flip the result by accident.
    predicted = y_pred.argmax(1)
    total = y.shape[0]
    right = (predicted == y).sum()
    print("Acc:%.2f%%"%(right/total*100))
    return (right,total)


def train(train_loader,test_loader,model,n_epoch,lr,device):
    """Train `model` with SGD and print a quick test accuracy after each epoch.

    Args:
        train_loader: Iterable of `(x, y)` training batches; must support `len()`.
        test_loader: Iterable of `(x, y)` test batches. Only its first batch
            is used for the per-epoch accuracy check.
        model: The network to optimise; moved to `device` in place.
        n_epoch: Number of passes over `train_loader`.
        lr: SGD learning rate.
        device: Device the model and every batch are moved to.

    The model is left in evaluation mode when the function returns.
    """
    model.to(device)
    # NOTE(review): momentum=0.001 is unusually small for SGD — confirm it is
    # intended. Left unchanged here.
    optim = torch.optim.SGD(model.parameters(),lr,weight_decay=0.01,momentum=0.001)
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(n_epoch):
        # Training mode: enables the Dropout layers (the previous epoch's
        # evaluation switched them off).
        model.train()
        bar = tqdm(train_loader,total=len(train_loader))
        # The description is the same for every batch of an epoch; set it once.
        bar.set_description(f'Epoch:[{epoch+1:02d}/{n_epoch}]')
        for x,y in bar:
            optim.zero_grad()
            x,y = x.to(device),y.to(device)
            loss = criterion(model(x),y)
            loss.backward()
            optim.step()
            bar.set_postfix(loss=float(loss))
        # Evaluation mode: disables Dropout so the reported accuracy is not
        # degraded (and randomised) by dropped activations.
        model.eval()
        with torch.no_grad():
            for x,y in test_loader:
                x,y = x.to(device),y.to(device)
                accuracy(model(x),y)
                break  # quick check on the first test batch only
            
# Train the AlexNet model for 20 epochs with learning rate 0.005 on `device`.
train(train_dataloader,test_dataloader,model=alexnet,n_epoch=20,lr=0.005,device=device)
    

In [None]:
# Final accuracy check on the first test batch.
# eval() switches the Dropout layers off; without it the forward pass would
# randomly drop activations and under-report (and randomise) the accuracy.
alexnet.eval()
with torch.no_grad():
    for x,y in test_dataloader:
        x,y = x.to(device),y.to(device)
        pred = alexnet(x)
        accuracy(pred,y)
        break