In [296]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim

In [297]:
device = torch.device('mps')
# device = ('mps' if torch.mps.is_available() else "cpu")

In [298]:
train_dataset = datasets.MNIST(root = 'datasets/', train = True, transform = transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset = train_dataset, batch_size=batch_size, shuffle=True)

In [299]:
test_dataset = datasets.MNIST(root = 'datasets/', train = False, transform = transforms.ToTensor(), download=True)
test_loader = DataLoader(dataset = test_dataset, batch_size=batch_size, shuffle=True)

In [300]:
for data, target in train_loader:
    data = data.to(device = device)
    target = target.to(device = device)

In [301]:
print(data.shape)
data = data.reshape(data.shape[0],-1)
print(data.shape)

torch.Size([32, 1, 28, 28])
torch.Size([32, 784])


In [302]:
print(target.shape)
print(target.size())
print(target)

torch.Size([32])
torch.Size([32])
tensor([1, 6, 1, 7, 4, 5, 2, 0, 9, 9, 3, 8, 5, 3, 1, 2, 5, 1, 2, 6, 4, 0, 8, 5,
        0, 9, 4, 4, 2, 5, 3, 4], device='mps:0')


## Create RNN

In [303]:
import torch
import torch.nn as nn

class Recurrent(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes,sequence_length):
        super(Recurrent, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.sequence_length = sequence_length

        # Define the RNN layer
        self.rnn = nn.RNN(
            input_size=self.input_size, 
            hidden_size=self.hidden_size, 
            num_layers=self.num_layers, 
            batch_first=True
        )

        # Define the fully connected layer
        self.fc = nn.Linear(hidden_size, num_classes)  # Output layer based on number of classes

    def forward(self, x):
        # Initialize hidden state with zeros (num_layers, batch_size, hidden_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)

        # Pass the input through the RNN (batch_size, seq_length, input_size)
        out, _ = self.rnn(x, h0)

        # Get the last hidden state output for classification (batch_size, hidden_size)
        out = out[:, -1, :]

        # Pass through the fully connected layer (batch_size, num_classes)
        out = self.fc(out)

        return out


# Instantiate the model
input_size = 28 
sequence_length = 28# Example: number of features per time step
hidden_size = 256  # Example: hidden size of the RNN
num_layers = 2  # Example: number of RNN layers
num_classes = 10  # Example: number of output classes
device = torch.device("mps")
batch_size = 64
lr = 0.001





In [304]:
# Instantiate the model
model = Recurrent(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, num_classes=num_classes, sequence_length=sequence_length).to(device)

In [305]:

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [306]:
model.parameters

<bound method Module.parameters of Recurrent(
  (rnn): RNN(28, 256, num_layers=2, batch_first=True)
  (fc): Linear(in_features=256, out_features=10, bias=True)
)>

In [307]:
from tqdm import tqdm
from torchmetrics import Accuracy

In [308]:
acc = Accuracy(task='multiclass', num_classes=10).to(device=device)

In [309]:
# Example Training Loop
from tqdm import tqdm

# Assuming train_loader is defined
epochs = 5
for epoch in range(epochs):
    loop = tqdm(enumerate(train_loader), total=len(train_loader))
    for idx, (features, targets) in loop:
        features = features.squeeze(1).to(device)  # Move data to the device
        targets = targets.to(device)  # Move targets to the device

        # Forward pass
        scores = model(features)
        loss = criterion(scores, targets)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Calculate accuracy (assuming acc is your accuracy function)
        acc_value = acc(scores, targets)

        # Update progress bar
        loop.set_description(f"Epoch [{epoch}/{epochs}]")
        loop.set_postfix(loss=loss.item(), accuracy=acc_value.item())

Epoch [0/5]: 100%|██████████| 938/938 [00:38<00:00, 24.53it/s, accuracy=0.875, loss=0.261]
Epoch [1/5]: 100%|██████████| 938/938 [00:37<00:00, 24.84it/s, accuracy=1, loss=0.0672]    
Epoch [2/5]: 100%|██████████| 938/938 [00:37<00:00, 24.72it/s, accuracy=1, loss=0.0136]    
Epoch [3/5]: 100%|██████████| 938/938 [00:38<00:00, 24.28it/s, accuracy=0.938, loss=0.178] 
Epoch [4/5]: 100%|██████████| 938/938 [00:37<00:00, 24.93it/s, accuracy=1, loss=0.0427]    


In [310]:
def check_accuracy(loader,model):
    num_correct = 0
    num_samples = 0
    model.eval()
    
    with torch.no_grad():
        for x,y in loader:
            x = x.to(device=device).squeeze(1)
            y = y.to(device=device)
            
            scores = model(x)
            _,predictions = scores.max(1)
            
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
            
        print(f"Got {num_correct} / {num_samples} with accuracy \
            {float(num_correct)/float(num_samples)*100:.2f}")
        model.train()
    

In [311]:
check_accuracy(train_loader,model)

Got 57715 / 60000 with accuracy             96.19
