In [1]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Select the compute device used by every later cell.
# `device` is always defined: without a CPU fallback the DataLoader generators,
# the loss module and the training loop would fail with NameError on machines
# that have no CUDA device.
if torch.cuda.is_available():
    device = torch.device("cuda:0")  # first GPU; use "cuda:1", "cuda:2", ... for others
    # Make newly created tensors and modules land on the GPU by default.
    torch.set_default_device(device)
    print("Running on the GPU")
else:
    device = torch.device("cpu")
    print("Running on the CPU")

Running on the GPU


In [3]:
# FashionMNIST train / test splits, stored under ./data (downloaded if missing).
# ToTensor() converts each image to a tensor when a sample is read.
training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())
test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=ToTensor())

# Each loader gets its own random generator created on `device`
# (presumably required for shuffling once the default device is CUDA — TODO confirm).
train_dataloader = torch.utils.data.DataLoader(
    training_data,
    batch_size=512,
    shuffle=True,
    generator=torch.Generator(device),
    num_workers=16,
)
test_dataloader = torch.utils.data.DataLoader(
    test_data,
    batch_size=128,
    shuffle=True,
    generator=torch.Generator(device),
    num_workers=8,
)

In [5]:
from torch import nn

# LeNet-style classifier: two conv/pool stages followed by three linear layers.
# The first linear layer expects 16 * 5 * 5 flattened features, which matches a
# 1x28x28 input passed through the stages below.
model = nn.Sequential(
    # Stage 1: 1 -> 6 channels; padding=2 keeps the spatial size for a 5x5 kernel.
    nn.Conv2d(1, 6, kernel_size=5, padding=2),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # Stage 2: 6 -> 16 channels, no padding.
    nn.Conv2d(6, 16, kernel_size=5),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # Classifier head: flatten, then 400 -> 120 -> 84 -> 10 outputs.
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.Sigmoid(),
    nn.Linear(84, 10),
)

In [7]:
# Training hyperparameters.
lr = 1e-2        # learning rate passed to the optimizer
max_epochs = 10  # number of full passes over the training loader

In [8]:
# Cross-entropy criterion, placed on the selected device.
loss = torch.nn.CrossEntropyLoss().to(device)

# Adam over all model parameters with the configured learning rate.
# Alternative configuration kept for reference:
#   torch.optim.Adam(params=model.parameters(), lr=lr, weight_decay=0.02, betas=(0.9, 0.99))
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=lr,
)

In [9]:
# Train for `max_epochs` passes over `train_dataloader`.
# `history` collects one float per epoch: the loss of that epoch's final
# mini-batch, i.e. the same value that stays on the progress line.
history = []
model.train()  # make sure the model is in training mode, e.g. after an evaluation run
for epoch in range(max_epochs):
    last_loss = None  # stays None if the loader yields no batches
    for x, y in train_dataloader:
        x = x.to(device)
        y = y.to(device)
        # Clear gradients before the backward pass so leftovers from an
        # interrupted earlier run of this cell cannot leak into this step.
        optimizer.zero_grad()
        l = loss(model(x), y)
        l.backward()
        optimizer.step()
        last_loss = l.item()
        # "\r" rewrites the same console line for every batch.
        print(f"\r epoch {epoch} loss = {last_loss}", end='')
    if last_loss is not None:
        history.append(last_loss)
    print("")

 epoch 0 loss = 1.1998394727706918
 epoch 1 loss = 0.5566245913505554
 epoch 2 loss = 0.49025988578796387
 epoch 3 loss = 0.48947295546531684
 epoch 4 loss = 0.36543229222297673
 epoch 5 loss = 0.40794387459754944
 epoch 6 loss = 0.36911487579345703
 epoch 7 loss = 0.27616426348686223
 epoch 8 loss = 0.22874332964420319
 epoch 9 loss = 0.38741970062255865


In [10]:
# Count correctly classified test samples.
# Inputs and labels are both moved to `device` before comparing, so the count
# does not depend on which device the loader happens to place the labels on.
true_flag = 0
model.eval()  # switch to inference mode before evaluating
with torch.no_grad():  # no gradients are needed for evaluation
    for x, y in test_dataloader:
        result = torch.argmax(model(x.to(device)), 1)  # predicted class per sample
        # Boolean matches sum to an integer count; keep the running total on the CPU.
        true_flag += (result == y.to(device)).sum().cpu()
true_flag

tensor(8512)

In [11]:
# Test accuracy: correct predictions divided by the size of the test set.
n_test = len(test_data)
acc = true_flag / n_test
acc

tensor(0.8512)