In [None]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
from torch.nn import functional as F
from torch import optim
from torch import nn
from torch.utils.data import DataLoader
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import sys

In [None]:
# Report the installed PyTorch version and whether CUDA is usable, one per line.
print(torch.__version__, torch.cuda.is_available(), sep="\n")

In [None]:
# Fashion-MNIST train/test splits rooted at "data" (download requested), each
# with its own ToTensor transform.
train_data, test_data = (
    datasets.FashionMNIST(root="data", train=is_train, download=True, transform=ToTensor())
    for is_train in (True, False)
)

# Both loaders serve 64-sample batches; no shuffling is requested.
train_loader, test_loader = (
    DataLoader(split, batch_size=64) for split in (train_data, test_data)
)

In [None]:
# Preview a 3x3 grid of randomly drawn training images.
cols, rows = 3, 3
figure = plt.figure(figsize=(8, 8))
for i in range(cols * rows):
    # Draw a random dataset index for this panel.
    idx = torch.randint(len(train_data), size=(1,)).item()
    img, label = train_data[idx]
    ax = figure.add_subplot(rows, cols, i + 1)
    ax.imshow(img.squeeze())
plt.show()

In [None]:
class Network(nn.Module):
    """Fully connected classifier for 28x28 inputs.

    The input is flattened and passed through two 512-unit hidden layers with
    ReLU activations, followed by a 10-unit linear output layer. No softmax
    is applied to the output.
    """

    def __init__(self) -> None:
        super().__init__()
        self.flatten = nn.Flatten()
        layers = [
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 10 output scores for every sample in ``x``."""
        flat = self.flatten(x)
        scores = self.net(flat)
        return scores

# Build the classifier, place it on the CPU, and show its layer structure.
model = Network()
model = model.to("cpu")
print(model)

In [None]:
from torchsummary import summary

# torchsummary's `device` argument defaults to "cuda", so on a machine with a
# GPU it would feed a CUDA dummy input to this CPU-resident model and fail
# with a device mismatch. Request the CPU explicitly.
summary(model, (1, 28, 28), device="cpu")

In [None]:
# Cross-entropy objective and a plain SGD optimizer (learning rate 0.001)
# over the parameters of `model`.
loss_fn = nn.CrossEntropyLoss()
sgd = optim.SGD(params=model.parameters(), lr=0.001)

def train(data: DataLoader, model: nn.Module, lossfn, sgd):
    """Run one pass over ``data``, updating ``model`` after every batch.

    Args:
        data: Loader yielding ``(inputs, targets)`` batches. The dataset does
            not need to define ``__len__``; its size is never queried.
        model: Network to optimize; it is switched to training mode.
        lossfn: Callable mapping ``(predictions, targets)`` to a scalar loss.
        sgd: Optimizer bound to ``model``'s parameters.

    Returns:
        None. The loss of every 100th batch (starting with batch 0) is printed.
    """
    model.train()
    for batch, (X, y) in enumerate(data):
        pred = model(X)
        loss = lossfn(pred, y)

        # Clear gradients left from the previous step before backpropagating.
        sgd.zero_grad()
        loss.backward()
        sgd.step()

        if batch % 100 == 0:
            print(f"batch: {batch} loss: {loss.item()}")

In [None]:
def test(data: DataLoader, model: Network, lossfn):
    """Print the fraction of samples in ``data`` that ``model`` classifies correctly.

    The model is put in evaluation mode and run without gradient tracking.
    ``lossfn`` is accepted but not used by this function.
    """
    model.eval()
    hits = 0
    with torch.no_grad():
        for inputs, targets in data:
            predicted = model(inputs).argmax(1)
            matches = predicted == targets
            hits += matches.float().sum().item()
    accuracy = hits / len(data.dataset)
    print(f"Test accuracy: {accuracy}")


In [None]:
# Alternate one training pass and one evaluation pass for five epochs.
for i in range(5):
    print(f"epoch {i} ------------------")
    train(data=train_loader, model=model, lossfn=loss_fn, sgd=sgd)
    test(data=test_loader, model=model, lossfn=loss_fn)
print("Done!")

In [None]:
def plot_curve(data):
    """Plot the values in ``data`` against their index in a new figure and show it."""
    _, ax = plt.subplots()
    steps = range(len(data))
    ax.plot(steps, data)
    ax.legend(['value'], loc="upper right")
    ax.set_xlabel('step')
    ax.set_ylabel('value')
    plt.show()

def plot_image(img, label, name):
    """Show up to six images from a batch in a 2x3 grid with labelled titles.

    Args:
        img: Batch indexed as ``img[i][0]`` to get the 2-D image of sample
            ``i`` (assumes a channel-first batch - TODO confirm for other
            callers).
        label: Per-sample values; ``label[i].item()`` is shown in the title.
        name: Text placed before each label in the title.

    Pixel values are rescaled with ``value * 0.3081 + 0.1307`` before display,
    which inverts the ``Normalize((0.1307,), (0.3081,))`` transform used for
    the MNIST loaders in this notebook.
    """
    # Never index past the end of a batch that holds fewer than six samples.
    shown = min(6, len(img), len(label))
    plt.figure()
    for i in range(shown):
        plt.subplot(2, 3, i + 1)
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title(f"{name}: {label[i].item()}")
        plt.xticks([])
        plt.yticks([])
    if shown:
        # Lay out once, after every panel and title exists.
        plt.tight_layout()
    plt.show()

def one_hot(label, depth=10):
    """Encode class indices as one-hot float rows.

    Args:
        label: Tensor of class indices. Any integer dtype and any device is
            accepted; its first dimension gives the number of rows.
        depth: Number of classes, i.e. the width of every row.

    Returns:
        Float tensor of shape ``(label.size(0), depth)``, allocated on
        ``label``'s device, with a 1 in each row's class column and 0 elsewhere.
    """
    # scatter_ needs an int64 index with one column per row. `.long()` replaces
    # the legacy LongTensor constructor, which rejects non-int64 and non-CPU tensors.
    idx = label.long().reshape(-1, 1)
    out = torch.zeros(label.size(0), depth, device=label.device)
    out.scatter_(dim=1, index=idx, value=1)
    return out

In [None]:
batch_size = 512

# Train and test share the same preprocessing: ToTensor followed by
# single-channel normalization with mean 0.1307 and std 0.3081.
tf_train, tf_test = (
    torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    for _ in range(2)
)

# MNIST splits rooted at "mnist_data" (download requested).
mnist_train = datasets.MNIST(root="mnist_data", train=True, download=True, transform=tf_train)
mnist_test = datasets.MNIST(root="mnist_data", train=False, download=True, transform=tf_test)

# Only the training loader shuffles. Both names are rebound to MNIST here.
train_loader = DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=False)

In [None]:
class MnistNet(nn.Module):
    """Three-layer perceptron: 784 -> 256 -> 64 -> 10, ReLU after the first two layers."""

    def __init__(self) -> None:
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Flatten ``x`` and return the 10 output scores per sample (no final activation)."""
        hidden = self.flatten(x)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

In [None]:
# Fresh MNIST perceptron, optimized with SGD (learning rate 0.01, momentum 0.9).
net = MnistNet()
sgd = optim.SGD(params=net.parameters(), lr=1e-2, momentum=0.9)

In [None]:
def train(data: DataLoader, net: nn.Module, sgd, epochs: int = 3):
    """Train ``net`` with an MSE loss against one-hot targets, then plot the loss.

    Args:
        data: Loader yielding ``(inputs, integer labels)`` batches.
        net: Network to optimize. Its output must match the shape produced by
            ``one_hot`` with its default depth.
        sgd: Optimizer bound to ``net``'s parameters.
        epochs: Number of full passes over ``data``. Defaults to 3, the count
            that used to be hard-coded.

    Returns:
        None. Every batch loss is recorded, every 10th batch is printed, and
        the full loss history is handed to ``plot_curve`` at the end.
    """
    losses = []
    net.train()
    for epoch in range(epochs):
        for batch, (x, y) in enumerate(data):
            out = net(x)
            # Regress the raw network outputs onto one-hot encoded labels.
            loss = F.mse_loss(out, one_hot(y))
            sgd.zero_grad()
            loss.backward()
            sgd.step()
            losses.append(loss.item())

            if batch % 10 == 0:
                print(f"\repoch {epoch} ===========> batch {batch} loss {loss.item()}")
    plot_curve(losses)

In [None]:
# Fit the MNIST perceptron and plot its per-batch loss.
train(data=train_loader, net=net, sgd=sgd)

In [None]:
def validate(net, loader=None):
    """Print the classification accuracy of ``net`` over a data loader.

    Args:
        net: ``nn.Module`` mapping a batch of inputs to per-class scores.
        loader: Loader yielding ``(inputs, labels)`` batches. When omitted,
            the notebook-level ``test_loader`` is used, as before.

    The network is evaluated in eval mode with gradient tracking disabled;
    its previous train/eval mode is restored before returning.
    """
    if loader is None:
        loader = test_loader

    was_training = net.training
    net.eval()
    total = 0
    try:
        # No autograd graph is needed just to count correct predictions.
        with torch.no_grad():
            for x, y in loader:
                pred = net(x).argmax(dim=1)
                total += pred.eq(y).sum().item()
    finally:
        net.train(was_training)
    print(f"test acc: {total / len(loader.dataset)}")

In [None]:
# Report the trained network's accuracy on the test loader.
validate(net=net)

In [None]:
# Take one training batch, print its shapes and value range, then display
# the first few images titled with the network's predicted class.
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
out = net(x)
pred = torch.argmax(out, dim=1)
plot_image(img=x, label=pred, name='train')

In [None]:
class CnnNet(nn.Module):
    """Small CNN for 1x28x28 inputs: two conv+pool stages and three linear layers.

    Both convolutions use a 3x3 kernel with stride 1 and padding 1, so they
    keep the spatial size. Each 2x2 max-pool (stride 2) halves it:
    28 -> 14 -> 7. The flattened feature vector therefore has 16 * 7 * 7
    entries, not 16 * 28 * 28.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, 3, 1, 1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(8, 16, 3, 1, 1)
        # 16 channels of 7x7 remain after the two pooling stages.
        self.fc1 = nn.Linear(16 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return 10 output scores per sample for a ``(batch, 1, 28, 28)`` input."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample; pinning the batch dimension keeps one output
        # row per input sample whatever the batch size.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [None]:
# Switch to the convolutional model with an Adam optimizer (learning rate
# 0.001). A cross-entropy criterion is created alongside it.
net = CnnNet()
lossfn = nn.CrossEntropyLoss()
adam = optim.Adam(params=net.parameters(), lr=1e-3)

In [None]:
# Peek at one training batch: tensor shapes plus the min and max input value.
x, y = next(iter(train_loader))
batch_stats = (x.shape, y.shape, x.min(), x.max())
print(*batch_stats)

In [None]:
def traincnn(train_loader, net, adam, epochs: int = 3):
    """Train ``net`` with an MSE loss against one-hot targets, then plot the loss.

    Args:
        train_loader: Loader yielding ``(inputs, integer labels)`` batches.
        net: Network to optimize. Its output must have exactly the shape
            produced by ``one_hot`` for the batch labels.
        adam: Optimizer bound to ``net``'s parameters.
        epochs: Number of full passes over ``train_loader``. Defaults to 3,
            the count that used to be hard-coded.

    Raises:
        ValueError: If the network output and the one-hot target differ in
            shape. This replaces the per-batch shape debug prints.

    Returns:
        None. Every batch loss is recorded, every 10th batch is printed, and
        the full loss history is handed to ``plot_curve`` at the end.
    """
    losses = []
    net.train()
    for i in range(epochs):
        for batch, (x, y) in enumerate(train_loader):
            out = net(x)
            y_onehot = one_hot(y)
            # Fail fast with a readable message: mse_loss would otherwise
            # broadcast mismatched shapes or raise an opaque size error.
            if out.shape != y_onehot.shape:
                raise ValueError(
                    f"network output shape {tuple(out.shape)} does not match "
                    f"one-hot target shape {tuple(y_onehot.shape)}"
                )
            loss = F.mse_loss(out, y_onehot)
            adam.zero_grad()
            loss.backward()
            adam.step()
            losses.append(loss.item())

            if batch % 10 == 0:
                print(f"\repoch {i} ===========> batch {batch} loss {loss.item()}")
    plot_curve(losses)

In [None]:
# Fit the convolutional network and plot its per-batch loss.
traincnn(train_loader=train_loader, net=net, adam=adam)