In [None]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

In [None]:
class Net(nn.Module):
    """Fully connected classifier: 784 -> 2048 -> 2048 -> 10.

    Each sample is flattened to a vector before the first linear layer, so
    any input whose non-batch dimensions multiply to 784 is accepted.
    ``forward`` returns log-probabilities (log-softmax over dim 1), not raw
    logits.
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as the state_dict key prefixes.
        self.fc1 = nn.Linear(784, 2048)
        self.fc2 = nn.Linear(2048, 2048)
        self.fc3 = nn.Linear(2048, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch ``x``."""
        # Collapse every dimension after the batch dimension.
        flat = x.flatten(start_dim=1)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return F.log_softmax(logits, dim=1)


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [None]:

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# The torch.device object is kept (previously it was built and discarded).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Optimizer step size and the number of passes over the training set.
lr = 0.00001
epochs = 10

# Convert each image to a tensor; no further preprocessing is applied.
transform = transforms.Compose([
    transforms.ToTensor()
])

# dataset1 is the MNIST training split (downloaded if missing),
# dataset2 is the held-out test split read from the same directory.
dataset1 = datasets.MNIST('../data', train=True, download=True, transform = transform)
dataset2 = datasets.MNIST('../data', train=False, transform = transform)


# Only the training batches are shuffled; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(dataset1, batch_size = 64, shuffle= True)
test_loader = torch.utils.data.DataLoader(dataset2, batch_size = 64)

In [None]:

# Build the network, move its parameters to the selected device, and attach
# an Adam optimizer that uses the configured learning rate.
model = Net()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=lr)


In [None]:
from tqdm import tqdm

# Run exactly `epochs` passes over the training data. Epoch numbers are
# 1-based for logging, so the upper bound must be epochs + 1
# (range(1, epochs) stopped one epoch short).
for epoch in range(1, epochs + 1):

    model.train()

    # Wrap the loader itself rather than enumerate(...): tqdm can then read
    # len(train_loader) and show a real progress bar with an ETA.
    progress = tqdm(train_loader, desc='Epoch {}/{}'.format(epoch, epochs))
    for batch_idx, (data, target) in enumerate(progress):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)

        # Net.forward already applies log_softmax, so the matching loss is
        # NLL (the same one test() uses); cross_entropy would apply
        # log_softmax a second time.
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 200 == 0:
            # tqdm.write prints without breaking the active progress bar.
            tqdm.write('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


25it [00:00, 122.56it/s]



215it [00:01, 111.99it/s]



415it [00:03, 113.67it/s]



615it [00:05, 86.73it/s]



815it [00:07, 112.87it/s]



938it [00:09, 102.88it/s]
19it [00:00, 91.86it/s]



224it [00:02, 117.04it/s]



414it [00:03, 119.91it/s]



622it [00:05, 121.06it/s]



824it [00:07, 117.33it/s]



938it [00:08, 112.88it/s]
25it [00:00, 118.82it/s]



214it [00:01, 120.04it/s]



409it [00:03, 90.99it/s]



619it [00:05, 117.59it/s]



822it [00:07, 120.97it/s]



938it [00:08, 110.46it/s]
24it [00:00, 114.29it/s]



215it [00:01, 119.71it/s]



416it [00:03, 116.78it/s]



621it [00:05, 121.81it/s]



816it [00:07, 94.38it/s]



938it [00:08, 110.40it/s]
19it [00:00, 96.22it/s]



221it [00:01, 108.93it/s]



416it [00:03, 118.27it/s]



617it [00:05, 116.17it/s]



813it [00:07, 110.85it/s]



938it [00:08, 114.34it/s]
25it [00:00, 119.68it/s]



219it [00:01, 120.39it/s]



419it [00:04, 90.36it/s]



619it [00:05, 117.66it/s]



818it [00:07, 121.11it/s]



938it [00:08, 108.79it/s]
25it [00:00, 119.57it/s]



216it [00:01, 118.82it/s]



415it [00:03, 114.24it/s]



621it [00:05, 121.06it/s]



823it [00:07, 94.48it/s]



938it [00:08, 108.97it/s]
24it [00:00, 115.10it/s]



218it [00:01, 118.32it/s]



419it [00:03, 120.59it/s]



615it [00:05, 118.58it/s]



819it [00:06, 121.19it/s]



938it [00:07, 118.89it/s]
13it [00:00, 120.86it/s]



214it [00:02, 93.04it/s]



414it [00:04, 118.69it/s]



620it [00:05, 121.81it/s]



815it [00:07, 121.51it/s]



938it [00:08, 110.32it/s]


In [None]:
# Evaluate on the device selected earlier instead of a hard-coded 'cuda',
# which fails on machines without a GPU.
test(model, device, test_loader)


Test set: Average loss: 0.1531, Accuracy: 9553/10000 (96%)



In [None]:
import os

import numpy as np

# Export every state_dict entry as its own .npy file.
weights_dir = '/content/model_weights'
# np.save does not create missing directories, so make sure the target exists.
os.makedirs(weights_dir, exist_ok=True)

# Build the state dict once instead of on every iteration.
for param_name, tensor in model.state_dict().items():
    # Dots in the key become underscores, e.g. 'fc1.weight' -> 'fc1_weight'.
    file_name = param_name.replace('.', '_')
    # Passing a path lets np.save open and close the file itself; the
    # previous open(...) handle was never closed.
    np.save(os.path.join(weights_dir, f'{file_name}.npy'),
            tensor.cpu().numpy())