In [None]:
# imports

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt

In [None]:
# Bring in the project's loader and unpack the four splits it returns.
from loadData import load_data

train_data, train_labels, test_data, test_labels = load_data()


# For each split print its name, its Python type and its shape, one per line.
print("train_data", type(train_data), train_data.shape, sep="\n")

print("train_labels", type(train_labels), train_labels.shape, sep="\n")

print("test_data", type(test_data), test_data.shape, sep="\n")

print("test_labels", type(test_labels), test_labels.shape, sep="\n")

In [None]:
# Custom dataset class
class CSVDataset(Dataset):
    """Map-style dataset that pairs ``data[i]`` with ``labels[i]``.

    Both containers are stored as given (no copy, no conversion) and are
    indexed with whatever ``index`` the caller supplies.
    """

    def __init__(self, data, labels):
        # Keep references to the caller's containers.
        self.data, self.labels = data, labels

    def __len__(self):
        # The sample count comes from ``data`` alone; ``labels`` is not consulted.
        sample_count = len(self.data)
        return sample_count

    def __getitem__(self, index):
        # Return the (sample, label) pair stored at ``index``.
        return self.data[index], self.labels[index]


In [None]:
# define the network class

class Net(nn.Module):
    """Three-layer fully connected classifier: 784 -> 128 -> 64 -> 10.

    The two hidden layers use ReLU; the output layer returns raw scores
    (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Compute the 10 output scores for ``x``.

        Accepted inputs:
            * any tensor whose last dimension already equals ``fc1.in_features``
              (e.g. ``(784,)`` or ``(batch, 784)``) -- passed through unchanged;
            * a 2-D tensor such as ``(28, 28)`` -- flattened to one sample;
            * a tensor with more than two dimensions such as ``(batch, 28, 28)``
              or ``(batch, 1, 28, 28)`` -- everything after the first dimension
              is flattened.

        Returns:
            Tensor of scores whose last dimension is 10.
        """
        in_features = self.fc1.in_features
        # The previous ``torch.flatten(x, -1)`` flattened the last dimension
        # into itself (a no-op), so image-shaped input reached fc1 unflattened.
        if x.dim() == 0 or x.shape[-1] != in_features:
            if x.dim() > 2:
                # Treat the first dimension as the batch dimension.
                x = x.reshape(x.shape[0], -1)
            else:
                # A single un-batched sample.
                x = x.reshape(-1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [None]:
# define the training function

def train(model, train_loader, criterion, optimizer, epochs):
    """Train ``model`` with one optimizer step per batch and plot the losses.

    Args:
        model: ``nn.Module`` called as ``model(inputs)``; it is put into
            training mode before the first epoch.
        train_loader: iterable yielding ``(inputs, labels)`` batches.
        criterion: callable ``criterion(output, labels)`` returning a scalar
            loss tensor.
        optimizer: object providing ``zero_grad()`` and ``step()``.
        epochs: number of passes over ``train_loader``.

    Returns:
        list[float]: the loss of every processed batch, in order. The same
        values are drawn with ``plt.plot``.
    """
    model.train()
    loss_list = []
    for epoch in range(epochs):
        epoch_losses = []
        # Unpack the batch directly: no unused enumerate index, and the
        # builtin ``input`` is no longer shadowed.
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())
        loss_list.extend(epoch_losses)

        # Log once per reported epoch (mean batch loss) instead of once per
        # batch; skip the line when the loader produced no batches.
        if epoch % 10 == 0 and epoch_losses:
            mean_loss = sum(epoch_losses) / len(epoch_losses)
            print("Epoch: " + str(epoch) + ", Loss: " + str(mean_loss))

    # use matplotlib to plot the loss
    plt.plot(loss_list)
    plt.xlabel("batch")
    plt.ylabel("loss")

    print("Finished Training")
    return loss_list

In [None]:
# define the testing function

def test(model, test_loader):
    correct = 0
    total = 0

    with torch.no_grad():
        for data, labels in test_loader:
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    print(f"Accuracy: {correct / total}")

In [None]:
# data loader
BATCH_SIZE = 1000  # samples per mini-batch, shared by both loaders

train_dataset = CSVDataset(train_data, train_labels)
test_dataset = CSVDataset(test_data, test_labels)

# Shuffle only the training split. Accuracy does not depend on sample order,
# so the evaluation loader keeps a fixed order for repeatable runs.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# create an instance of the network
model = Net()

# define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Smoke test: a single forward pass on the first training sample. No loss is
# computed and no backward pass or optimizer step happens here. The module is
# called directly (not via .forward) so that Module.__call__ and any
# registered hooks run.
model(train_data[0])


In [None]:
# Fit the model for 500 passes over the training loader.
train(model, train_loader, criterion, optimizer, epochs=500)

# Evaluate on the test loader; the trailing semicolon keeps the cell from
# echoing a return value.
test(model, test_loader);

In [None]:
# Compare the stored label of the first test sample with the model's prediction.
# One forward pass is enough (the earlier extra call discarded its result), and
# no_grad avoids building an autograd graph for inference.
with torch.no_grad():
    first_sample_scores = model(test_data[0])

print(test_labels[0].item())
print(torch.argmax(first_sample_scores).item())


In [None]:
# Write the model's state dict (its parameters and buffers) to model.pt.
torch.save(obj=model.state_dict(), f="model.pt")

In [None]:
# Optional: build a fresh network and restore the parameters saved in model.pt.
model = Net()
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
model.load_state_dict(torch.load(f="model.pt"))
# Put the restored model into evaluation mode.
model.eval()

In [36]:
import csv

# Parse user_input.csv: every non-empty row becomes one sample and each value
# is divided by 255 (presumably 0-255 pixel intensities -- TODO confirm).
# newline='' is the csv-module recommended way to open the file; blank lines
# are skipped so they cannot produce a ragged list that torch.tensor rejects.
with open('user_input.csv', 'r', newline='') as csvfile:
    reader = csv.reader(csvfile, delimiter=',')
    pixel_rows = [[float(value) / 255 for value in row] for row in reader if row]

if not pixel_rows:
    raise ValueError("user_input.csv contains no data rows")

user_input = torch.tensor(pixel_rows)
print(user_input.shape)

# Call the module itself rather than .forward so Module.__call__ (and hooks) run.
prediction = model(user_input[0])
print(prediction)
# print index of max value
print(torch.argmax(prediction).item())


torch.Size([1, 784])
tensor([-3.9840,  1.3469, -2.7200, -2.4575,  3.2632, -0.8208, -2.0065,  0.9444,
        -0.5887,  3.5824], grad_fn=<AddBackward0>)
9
