In [1]:
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import tqdm

In [2]:
# Experiment tracking with Weights & Biases: authenticate first, then open a
# run under the given project. The training cell reports per-epoch metrics
# through wandb.log, so this cell has to succeed before training is started.
# NOTE(review): the saved output of this cell is a ModuleNotFoundError, so the
# `wandb` package must be installed in the kernel environment beforehand.
import wandb
wandb.login()
wandb.init(project="Fashion-Mnist pytorch")

ModuleNotFoundError: No module named 'wandb'

In [None]:
class MyModel(torch.nn.Module):
    """Three-layer fully connected classifier for 28x28 single-channel images.

    Architecture: 784 -> 128 -> 64 -> 10 with ReLU after the first two layers.
    """

    def __init__(self) -> None:
        super().__init__()

        self.fc1 = torch.nn.Linear(784, 128)
        self.fc2 = torch.nn.Linear(128, 64)
        self.fc3 = torch.nn.Linear(64, 10)

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions hold 784 values in total (e.g. batch x 1 x 28 x 28).

        Returns:
            Tensor of shape (batch, 10) containing raw, unnormalised logits.
            No softmax is applied here: torch.nn.CrossEntropyLoss performs
            log-softmax internally, so feeding it probabilities would apply
            softmax twice and flatten the gradients. argmax over the logits
            yields the same predicted class as argmax over probabilities; call
            torch.softmax(logits, dim=1) explicitly when probabilities are
            needed.
        """
        # Flatten every image into a 784-element vector.
        x = x.reshape(x.shape[0], 784)

        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))

        return self.fc3(x)
        

In [None]:
model = MyModel()
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Put the model in training mode; as the last expression this also displays it.
model.train(True)

In [None]:
# Training hyperparameters: samples per mini-batch, number of passes over the
# training set, and the optimizer learning rate.
batch_size, epochs, lr = 128, 20, 1e-3

In [None]:
def calc_accuracy(preds, labels):
    """Return the fraction of rows in `preds` whose highest-scoring column
    equals the corresponding entry of `labels`.

    Args:
        preds: 2-D tensor of per-class scores, one row per sample.
        labels: 1-D tensor of class indices, one per sample.

    Returns:
        0-dim float64 tensor holding the accuracy for this batch.
    """
    predicted_classes = preds.argmax(dim=1)
    hits = (predicted_classes == labels).sum(dtype=torch.float64)
    return hits / preds.shape[0]

Data preparation

In [None]:
# Preprocessing pipeline applied to every image: convert it to a tensor, then
# normalise. With mean=0 and std=1 the normalisation computes (x - 0) / 1 and
# therefore leaves the values unchanged.
transform = torchvision.transforms.Compose(
    transforms=[
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=0, std=1),
    ]
)

In [None]:
# Fashion-MNIST train and test splits, downloaded into ./datasets when missing
# and passed through the shared preprocessing pipeline.
dataset_train = torchvision.datasets.FashionMNIST(
    root='datasets',
    train=True,
    transform=transform,
    download=True,
)
dataset_test = torchvision.datasets.FashionMNIST(
    root='datasets',
    train=False,
    transform=transform,
    download=True,
)

In [None]:
# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
# Evaluation does not benefit from shuffling; a fixed order keeps the batch
# composition (and therefore the reported test metrics) identical between runs.
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

In [None]:
@torch.no_grad()
def test_model(model):
    """Evaluate `model` on `test_loader` and print the mean loss and accuracy.

    Uses the notebook-level `test_loader`, `device`, `loss_function` and
    `calc_accuracy`. Gradient tracking is disabled by the decorator.

    Args:
        model: Module mapping a batch of images to per-class scores.

    Returns:
        Tuple `(mean_loss, mean_accuracy)` of Python floats, averaged over
        samples rather than over batches.
    """
    # Evaluate in eval mode, then restore whatever mode the caller had set.
    was_training = model.training
    model.eval()

    test_loss = 0.0
    test_acc = 0.0
    seen = 0
    try:
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images)
            loss = loss_function(preds, labels)

            # Weight each batch by its size so that a smaller final batch does
            # not count as much as a full one; .item() keeps plain floats.
            n = labels.size(0)
            test_loss += loss.item() * n
            test_acc += calc_accuracy(preds, labels).item() * n
            seen += n
    finally:
        model.train(was_training)

    total_loss = test_loss / seen
    total_acc = test_acc / seen

    print(f"\nloss_test: {total_loss}, Accuracy_test: {total_acc}")
    return total_loss, total_acc

Optimizer and loss function

In [None]:
# Cross-entropy objective for the 10-class problem, minimised with Adam.
# Plain SGD is a drop-in alternative:
#     optimizer = torch.optim.SGD(params=model.parameters(), lr=lr)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

Training

In [None]:
# Per-epoch training history, stored as plain Python floats so it can be
# logged and plotted directly.
total_loss = []
total_acc = []
for epoch in range(epochs):
    train_loss = 0.0
    train_acc = 0.0
    seen = 0

    for images, labels in tqdm.tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        preds = model(images)

        loss = loss_function(preds, labels)
        loss.backward()

        optimizer.step()

        # .item() detaches the values from the autograd graph and moves them
        # to the host: summing the loss tensor itself would keep graph history
        # alive across iterations and leave tensors that matplotlib cannot
        # convert. Weighting by batch size gives a true per-sample mean even
        # when the last batch is smaller.
        n = labels.size(0)
        train_loss += loss.item() * n
        train_acc += calc_accuracy(preds, labels).item() * n
        seen += n

    total_loss.append(train_loss / seen)
    total_acc.append(train_acc / seen)

    wandb.log({'accuracy': total_acc[-1], 'loss': total_loss[-1]})

    print(f"Epoch: {epoch}, loss: {total_loss[-1]}, Accuracy: {total_acc[-1]}")
    test_model(model)

# Training curves: one point per epoch.
plt.plot(total_loss, 'g*', label="loss")
plt.plot(total_acc, 'ro', label="accuracy")
plt.xlabel("epoch")
plt.legend(loc="upper right")
plt.show()

Save Model

In [None]:
# Persist only the learned parameters (the state dict), not the module object.
torch.save(obj=model.state_dict(), f="mnist.pth")

Inference

In [None]:
# Classify a single image file with the trained model.
model.train(False)  # equivalent to model.eval()

image_path = 'test.png'
img = cv2.imread(image_path)
# cv2.imread signals a missing or unreadable file by returning None instead of
# raising, which would otherwise surface as an obscure error in cvtColor.
if img is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img = cv2.resize(img, (28, 28))

# Apply the training preprocessing and add a leading batch dimension.
tensor = transform(img).unsqueeze(0).to(device)

# No gradients are needed for prediction.
with torch.no_grad():
    preds = model(tensor)

preds = preds.cpu().numpy()
# Index of the highest-scoring class.
output = np.argmax(preds)
output

Load Weights

In [None]:
# Rebuild the architecture and restore the saved parameters into it.
new_model = MyModel()
# map_location='cpu' lets a checkpoint written from a CUDA model be loaded on
# a machine without a GPU; new_model itself is created on the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
weight = torch.load('mnist.pth', map_location='cpu')
new_model.load_state_dict(weight)

In [None]:
# Bare reference to the bound method (it is not called), so the notebook
# displays the method's repr rather than the parameter tensors.
# NOTE(review): presumably intended as a quick look at the restored model;
# iterate new_model.parameters() to inspect the tensors themselves -- confirm.
new_model.parameters

In [None]:
# Classify a single image file with the model restored from disk, on the CPU.
new_model.train(False)  # equivalent to new_model.eval()

image_path = 'test.png'
img = cv2.imread(image_path)
# cv2.imread signals a missing or unreadable file by returning None instead of
# raising, which would otherwise surface as an obscure error in cvtColor.
if img is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img = cv2.resize(img, (28, 28))

# NOTE: this rebinds the notebook-level `device`; re-run the model setup cell
# before training again so batches and model end up on the same device.
device = torch.device('cpu')
# Apply the training preprocessing and add a leading batch dimension.
tensor = transform(img).unsqueeze(0).to(device)

# No gradients are needed for prediction.
with torch.no_grad():
    preds = new_model(tensor)

preds = preds.cpu().numpy()
# Index of the highest-scoring class.
output = np.argmax(preds)
output