In [1]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# torch.set_default_dtype(torch.float32)
if torch.cuda.is_available():
    device = torch.device("cuda:0")  # you can continue going on here, like cuda:1 cuda:2....etc. 
    print("Running on the GPU")
    torch.set_default_device(device)

Running on the GPU


In [3]:
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

In [4]:
train_dataloader = torch.utils.data.DataLoader(training_data, batch_size=512, shuffle=True, generator=torch.Generator(device), num_workers=16)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=True, generator=torch.Generator(device), num_workers=8)

In [5]:
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.f = torch.nn.Flatten()
        self.l1 = torch.nn.Linear(784, 256)
        self.a1 = torch.nn.ReLU()
        self.d1 = torch.nn.Dropout(0.2)

        self.l2 = torch.nn.Linear(256, 64)
        self.a2 = torch.nn.ReLU()
        self.d2 = torch.nn.Dropout(0.2)

        self.out = torch.nn.Linear(64, 10)

    def forward(self, X):
        X = self.f(X)
        X = self.l1(X)
        X = self.a1(X)
        X = self.d1(X)

        X = self.l2(X)
        X = self.a2(X)
        X = self.d2(X)

        return self.out(X)

In [6]:
model = MyModel()
#model = torch.nn.Sequential(
#    torch.nn.Flatten(),
#    nn.Linear(28*28, 512),
#    nn.ReLU(),
#    nn.Linear(512, 512),
#    nn.ReLU(),
#    nn.Linear(512, 10),
#)
def init_weights(m):
    if type(m) == torch.nn.Linear:
        torch.nn.init.normal_(m.weight, std=0.01)
model.apply(init_weights)

MyModel(
  (f): Flatten(start_dim=1, end_dim=-1)
  (l1): Linear(in_features=784, out_features=256, bias=True)
  (a1): ReLU()
  (d1): Dropout(p=0.2, inplace=False)
  (l2): Linear(in_features=256, out_features=64, bias=True)
  (a2): ReLU()
  (d2): Dropout(p=0.2, inplace=False)
  (out): Linear(in_features=64, out_features=10, bias=True)
)

In [7]:
lr = 0.01
max_epochs = 10

In [8]:
loss = torch.nn.CrossEntropyLoss().to(device)
# optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, weight_decay=0.02, betas=(0.9, 0.99))
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [9]:
history = []
for epoch in range(max_epochs):
    for x, y in train_dataloader:
        x = x.to(device)
        y = y.to(device)
        l = loss(model(x), y)
        l.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(f"\r epoch {epoch} loss = {l.item()}", end='')
    # history.append(l.clone().cpu().detach().numpy())
    print("")

 epoch 0 loss = 0.94375991821289067
 epoch 1 loss = 0.42107251286506653
 epoch 2 loss = 0.43251430988311773
 epoch 3 loss = 0.54981791973114017
 epoch 4 loss = 0.40007331967353826
 epoch 5 loss = 0.37136054039001465
 epoch 6 loss = 0.55628317594528254
 epoch 7 loss = 0.39071992039680486
 epoch 8 loss = 0.34586739540100173
 epoch 9 loss = 0.51932793855667115
 epoch 10 loss = 0.45277574658393866
 epoch 11 loss = 0.24700701236724854
 epoch 12 loss = 0.37118294835090637
 epoch 13 loss = 0.49035087227821353
 epoch 14 loss = 0.47379252314567566
 epoch 15 loss = 0.51320648193359384
 epoch 16 loss = 0.25035589933395386
 epoch 17 loss = 0.51172327995300295
 epoch 18 loss = 0.32590240240097046
 epoch 19 loss = 0.41547405719757087
 epoch 20 loss = 0.23498380184173584
 epoch 21 loss = 0.52293843030929574
 epoch 22 loss = 0.46691003441810614
 epoch 23 loss = 0.51272577047348025
 epoch 24 loss = 0.51913297176361086
 epoch 25 loss = 0.43857869505882263
 epoch 26 loss = 0.37982630729675293
 epoch 27 l

KeyboardInterrupt: 

In [10]:
true_flag = 0
with torch.no_grad():
    for x, y in test_dataloader:
        result = torch.argmax(model(x.to(device)), 1).cpu()
        true_flag += torch.where(result == y, 1, 0).sum()
true_flag

tensor(8354)

In [11]:
acc = true_flag / len(test_data)
acc

tensor(0.8354)