In [1]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms

In [2]:
batchsize=100

training_data=datasets.FashionMNIST(root="../fashion_mnist",train=True,transform=transforms.ToTensor(),download=True)
test_data=datasets.FashionMNIST(root="../fashion_mnist",train=False,transform=transforms.ToTensor(),download=True)

train_dataloader=DataLoader(training_data,batch_size=batchsize)
test_dataloader=DataLoader(test_data,batch_size=batchsize)

100%|██████████| 26.4M/26.4M [00:30<00:00, 873kB/s] 
100%|██████████| 29.5k/29.5k [00:00<00:00, 158kB/s]
100%|██████████| 4.42M/4.42M [00:02<00:00, 2.20MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 17.8MB/s]


In [3]:
#hyperparameters lwk
sequence_len=28
input_len=28
hidden_size=128
num_layer=2
num_classes=10
num_epochs=10
learning_rate=0.01

In [4]:
class LSTM(nn.Module):
   def __init__(self,input_len,hidden_size,num_classes,num_layer):
      super(LSTM,self).__init__()
      self.hidden_size=hidden_size
      self.num_layer=num_layer
      self.lstm=nn.LSTM(input_len,hidden_size,num_layer,batch_first=True)
      self.output_layer=nn.Linear(hidden_size,num_classes)

   def forward(self,X):
      hidden_states=torch.zeros(self.num_layer,X.size(0),self.hidden_size)
      cell_states=torch.zeros(self.num_layer,X.size(0),self.hidden_size)

      out,_ = self.lstm(X,(hidden_states,cell_states))
      out=self.output_layer(out[:,-1,:])
      return out

In [5]:
model=LSTM(input_len,hidden_size,num_classes,num_layer)
print(model)

LSTM(
  (lstm): LSTM(28, 128, num_layers=2, batch_first=True)
  (output_layer): Linear(in_features=128, out_features=10, bias=True)
)


In [6]:
loss_func=nn.CrossEntropyLoss()
optimizer=optim.Adam(model.parameters(),lr=learning_rate)

In [7]:
#training loop

def train(num_epochs,model,train_dataloader,loss_func,optimizer):
    for epoch in range(num_epochs):
        for batch,(images,labels) in enumerate(train_dataloader):
            images=images.reshape(-1,sequence_len,input_len)
            outputs=model(images)
            loss=loss_func(outputs,labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if(batch+1)%100==0:
                print(f"Epoch [{epoch+1}/{num_epochs}],Step [{batch+1}/{len(train_dataloader)}],Loss:{loss.item():.4f}")


In [8]:
train(num_epochs,model,train_dataloader,loss_func,optimizer)

Epoch [1/10],Step [100/600],Loss:0.9308
Epoch [1/10],Step [200/600],Loss:0.8574
Epoch [1/10],Step [300/600],Loss:0.5480
Epoch [1/10],Step [400/600],Loss:0.4933
Epoch [1/10],Step [500/600],Loss:0.6332
Epoch [1/10],Step [600/600],Loss:0.4495
Epoch [2/10],Step [100/600],Loss:0.4347
Epoch [2/10],Step [200/600],Loss:0.3629
Epoch [2/10],Step [300/600],Loss:0.3426
Epoch [2/10],Step [400/600],Loss:0.4892
Epoch [2/10],Step [500/600],Loss:0.5684
Epoch [2/10],Step [600/600],Loss:0.2672
Epoch [3/10],Step [100/600],Loss:0.3030
Epoch [3/10],Step [200/600],Loss:0.2879
Epoch [3/10],Step [300/600],Loss:0.2370
Epoch [3/10],Step [400/600],Loss:0.4291
Epoch [3/10],Step [500/600],Loss:0.4832
Epoch [3/10],Step [600/600],Loss:0.2791
Epoch [4/10],Step [100/600],Loss:0.2697
Epoch [4/10],Step [200/600],Loss:0.2798
Epoch [4/10],Step [300/600],Loss:0.2582
Epoch [4/10],Step [400/600],Loss:0.3099
Epoch [4/10],Step [500/600],Loss:0.4735
Epoch [4/10],Step [600/600],Loss:0.2538
Epoch [5/10],Step [100/600],Loss:0.2579


In [9]:
#testing loop
with torch.no_grad():
    correct=0
    total=0
    acc=0
    best_acc=0
    for images,labels in test_dataloader:
        images=images.reshape(-1,sequence_len,input_len)
        outputs=model(images)
        _,predicted=torch.max(outputs,1)
        total+=labels.size(0)
        correct+=(predicted==labels).sum().item()
        acc=100*correct/total
        if acc>best_acc:
            best_acc=acc
            torch.save(model.state_dict(), 'best_lstm.pth')
    print(f"Accuracy of the model on the test images:{best_acc}%")

Accuracy of the model on the test images:92.0%
