In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
batch_size = 100
# ToTensor converts each image to a float tensor of shape (1, 28, 28) with values in [0, 1].
training_data = datasets.FashionMNIST(root='./fashion_mnist', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.FashionMNIST(root='./fashion_mnist', train=False, download=True, transform=transforms.ToTensor())

# Reshuffle the training set every epoch so SGD does not see minibatches in a fixed order.
# The test set order does not affect evaluation, so it stays unshuffled.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

In [3]:
# Each 28x28 image is fed to the LSTM row by row: 28 time steps of 28 pixels each.
sequence_len = 28  # time steps per sample (image rows)
input_len = 28  # features per time step (image columns)
hidden_size = 128
num_layers = 2
num_classes = 10
num_epochs = 2
learning_rate = 0.01
seed = 0

# Seed torch's global RNG so weight initialisation and data shuffling are reproducible
# under Restart & Run All.
torch.manual_seed(seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
class LSTM(nn.Module):
    """Many-to-one LSTM classifier.

    Runs a stacked LSTM over a batch-first sequence and maps the output of the
    last time step to class logits.

    Args:
        input_len: number of features per time step (``nn.LSTM`` input_size).
        hidden_size: hidden-state width of every LSTM layer.
        num_class: number of output logits.
        num_layers: number of stacked LSTM layers.
        device: stored as ``self.device`` for backward compatibility. The
            initial states are created on the input's device, so moving the
            module with ``.to(...)`` after construction keeps working.
    """

    def __init__(self, input_len, hidden_size, num_class, num_layers, device):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_len, hidden_size, num_layers, batch_first=True)
        self.output = nn.Linear(hidden_size, num_class)
        self.device = device

    def forward(self, x):
        """Return logits of shape (batch, num_class) for x of shape (batch, seq, input_len)."""
        # Zero initial (h0, c0) states matching x's device and dtype. A device
        # fixed at __init__ would break after model.to(...) or model.double().
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        hidden_states = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        cell_states = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        out, _ = self.lstm(x, (hidden_states, cell_states))
        # Classify from the last time step only (batch_first => dim 1 is time).
        return self.output(out[:, -1, :])


In [5]:
# Build the classifier and move its parameters to the selected device in one step.
# Ending the cell with `model` displays its structure exactly once.
model = LSTM(input_len, hidden_size, num_classes, num_layers, device).to(device)
model

LSTM(
  (lstm): LSTM(28, 128, num_layers=2, batch_first=True)
  (output): Linear(in_features=128, out_features=10, bias=True)
)


LSTM(
  (lstm): LSTM(28, 128, num_layers=2, batch_first=True)
  (output): Linear(in_features=128, out_features=10, bias=True)
)

In [6]:
# Cross-entropy on raw logits as the training criterion; Adam updates every model parameter.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [7]:
def train(num_epochs, model, train_dataloader, loss_fn, device, optimizer=None,
          sequence_len=None, input_len=None, log_every=100):
    """Train ``model`` for ``num_epochs`` passes over ``train_dataloader``.

    Each image batch is reshaped to (batch, sequence_len, input_len) before
    being passed to the model.

    Args:
        num_epochs: number of full passes over the loader.
        model: module mapping (batch, seq, features) tensors to class logits.
        train_dataloader: iterable of (images, labels) batches.
        loss_fn: criterion called as ``loss_fn(output, labels)``.
        device: device each batch is moved to.
        optimizer: optimizer over ``model``'s parameters. If omitted, the
            notebook-level ``optimizer`` is used (the previous implicit behaviour).
        sequence_len, input_len: per-sample sequence shape. If omitted, the
            notebook-level globals of the same name are used when defined;
            otherwise the last two dimensions of each image batch are used.
        log_every: print the current batch loss every this many steps
            (0 or None disables logging).
    """
    # Backward compatibility: this function previously read these from the notebook globals.
    namespace = globals()
    if optimizer is None:
        optimizer = namespace["optimizer"]
    if sequence_len is None:
        sequence_len = namespace.get("sequence_len")
    if input_len is None:
        input_len = namespace.get("input_len")

    total_steps = len(train_dataloader)
    model.train()  # ensure training-mode behaviour for any dropout/batch-norm layers

    for epoch in range(num_epochs):
        for step, (images, labels) in enumerate(train_dataloader):
            seq = sequence_len if sequence_len is not None else images.size(-2)
            feat = input_len if input_len is not None else images.size(-1)
            images = images.reshape(-1, seq, feat).to(device)
            labels = labels.to(device)
            output = model(images)
            loss = loss_fn(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if log_every and (step+1) % log_every == 0:
                print(f'{epoch=}, steps={step+1} / {total_steps=}, {loss=:.4f}')

In [8]:
train(num_epochs, model, train_dataloader, loss_fn, device)

epoch=0, steps=100 / total_steps=600, loss=1.0140
epoch=0, steps=200 / total_steps=600, loss=0.6746
epoch=0, steps=300 / total_steps=600, loss=0.5460
epoch=0, steps=400 / total_steps=600, loss=0.5031
epoch=0, steps=500 / total_steps=600, loss=0.5928
epoch=0, steps=600 / total_steps=600, loss=0.3812
epoch=1, steps=100 / total_steps=600, loss=0.4537
epoch=1, steps=200 / total_steps=600, loss=0.3349
epoch=1, steps=300 / total_steps=600, loss=0.3734
epoch=1, steps=400 / total_steps=600, loss=0.3632
epoch=1, steps=500 / total_steps=600, loss=0.5282
epoch=1, steps=600 / total_steps=600, loss=0.2743


In [9]:
# Evaluate on the held-out test split: mean per-sample loss and accuracy.
model.eval()  # inference-mode behaviour for any dropout/batch-norm layers
with torch.no_grad():
    losses = []  # per-batch mean losses, kept for later inspection
    total_loss = 0.0
    correct = 0
    for x, y in test_dataloader:
        x = x.reshape(-1, sequence_len, input_len).to(device)
        y = y.to(device)
        y_hat = model(x)
        loss = loss_fn(y_hat, y)
        losses.append(loss.item())
        # loss_fn (default reduction='mean') returns the batch mean. Weighting by the
        # batch size keeps a shorter final batch from being over-counted.
        total_loss += loss.item() * y.size(0)
        correct += (y_hat.argmax(1) == y).sum().item()
    batches = len(test_dataloader)
    # Divide by the real number of samples, not batch_size * batches, which
    # overcounts when the last batch is partial.
    num_samples = len(test_dataloader.dataset)
    final_loss = total_loss / num_samples
    acc = correct / num_samples
    print(f'{final_loss=} {acc=}')

final_loss=0.3772645643353462 acc=0.8601
