In [98]:
import torch
import torch.nn as nn
import pandas as pd
import re
import string
import numpy as np
from collections import Counter
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.functional as F

## Data preprocessing

In [9]:
# Convert each image to a tensor, then collapse it to a 1-D vector so an
# image can be consumed as a flat pixel sequence.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(lambda img: img.view(-1))]
)

# Both splits share the same root directory and transform; only `train` differs.
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

print("Training dataset size:", len(train_dataset))
print("Test dataset size:", len(test_dataset))

Training dataset size: 60000
Test dataset size: 10000


## Dataset class

In [127]:
class MNISTDataset(Dataset):
    """Wrap an (image, label) dataset as binary next-pixel prediction pairs.

    Each item of the wrapped dataset is expected to be a ``(pixels, label)``
    pair where ``pixels`` is a 1-D tensor; the label is discarded. Pixels are
    binarized (any value > 0 becomes 1, everything else 0) and split into an
    input sequence and a target sequence shifted by one position.
    """

    def __init__(self, data):
        """
        param data: indexable, sized collection of ``(pixels, label)`` pairs
        """
        self.data = data

    def __len__(self):
        """Number of sequences, i.e. the size of the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return ``(inputs, targets)`` for item ``idx``.

        inputs:  [L-1, 2] int64, one-hot encoding of pixels 0 .. L-2
        targets: [L-1]    int64, binary pixels 1 .. L-1 (the "next pixel")
        """
        pixels, _ = self.data[idx]
        # A comparison followed by a cast gives the same int64 0/1 tensor as
        # torch.where with scalar tensors, without allocating them per item.
        bits = (pixels > 0).long()
        return F.one_hot(bits[:-1], num_classes=2), bits[1:]

## CharRNN class

In [138]:
class CharRNN(nn.Module):
    """Two-layer LSTM that predicts the next symbol of a one-hot encoded sequence.

    In this notebook the "characters" are binary pixel values, so a
    "sentence" is one flattened image.
    """

    def __init__(self, vocab_size, hidden_dim):
        """
        param vocab_size (V): number of distinct symbols; also the width of
                              the one-hot input fed to the LSTM
        param hidden_dim (H): number of hidden dimension

        length (L): length of sentence
        batch (B): batch size
        """
        super(CharRNN, self).__init__()
        # The LSTM consumes one-hot vectors directly (there is no embedding
        # layer), so its input size must equal the vocabulary size.
        self.rnn = nn.LSTM(vocab_size, hidden_dim, num_layers=2, batch_first=True, dropout=0.2)
        self.linear = torch.nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, X):
        """
        Input (one-hot, float): [B, L, V]
        RNN: [B, L, H]
        Out (unnormalized logits): [B, L, V]
        """
        X, _ = self.rnn(X)
        out = self.linear(X)
        return out

    def generate(self, max_length=28**2):
        """
        Greedily generate a sequence, starting from the symbol 0.

        param max_length: length of the returned sequence, seed included
                          (the result always contains at least the seed)
        return: 1-D int64 tensor of symbol indices

        Only the most recent symbol is fed at each step; the LSTM state
        carries the history. Dropout is disabled while generating and the
        module's previous train/eval mode is restored afterwards.
        """
        # Build inputs on the same device / dtype as the model's weights.
        ref = next(self.parameters())
        n_symbols = self.rnn.input_size
        was_training = self.training
        self.eval()  # turn off the LSTM's inter-layer dropout while sampling
        sentence = [0]
        hidden = None
        try:
            with torch.no_grad():
                while len(sentence) < max_length:
                    # Shape [1, 1]: a batch of one sequence holding one step.
                    step = torch.tensor([[sentence[-1]]], dtype=torch.long, device=ref.device)
                    step = F.one_hot(step, num_classes=n_symbols).to(dtype=ref.dtype)
                    out, hidden = self.rnn(step, hidden)
                    logits = self.linear(out)
                    # Read the prediction for the newest step, not step 0.
                    sentence.append(torch.argmax(logits[0, -1, :]).item())
        finally:
            self.train(was_training)
        return torch.tensor(sentence)

## Training loop

In [134]:
def train(model, train_loader, n_epochs, lr=1e-3, device="cpu"):
    """
    Train ``model`` with Adam and cross-entropy, printing the mean batch loss per epoch.

    param model: module mapping a float tensor [B, L, V] to logits [B, L, C]
    param train_loader: iterable yielding (X, y) batches; X is cast to float
                        before the forward pass, y holds class indices [B, L]
    param n_epochs: number of passes over ``train_loader``
    param lr: Adam learning rate
    param device: device the model and every batch are moved to

    raises ValueError: if ``train_loader`` yields no batches in an epoch
    """
    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    for epoch in range(n_epochs):
        average_loss = 0
        n = 0
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            y_pred = model(X.float())
            # Flatten batch and time into one axis. The class count is read
            # from the logits rather than hard-coded, and reshape (unlike
            # view) also accepts non-contiguous tensors.
            loss = criterion(y_pred.reshape(-1, y_pred.size(-1)), y.reshape(-1))
            loss.backward()
            optimizer.step()
            average_loss += loss.item()
            n += 1
        if n == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError.
            raise ValueError("train_loader yielded no batches")
        print(f"Epoch {epoch + 1} average loss: {average_loss / n}")

## Training

In [135]:
# Fit on the training split; the test split stays held out for evaluation.
dataset = MNISTDataset(train_dataset)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

In [139]:
# Two symbols (binary pixels: off / on) and a 32-unit hidden state.
model = CharRNN(vocab_size=2, hidden_dim=32)

In [132]:
# Run 20 epochs with the default learning rate and device.
train(model, dataloader, n_epochs=20)

torch.Size([64, 783, 2])
torch.Size([64, 783, 2])
torch.Size([64, 783, 2])
torch.Size([64, 783, 2])
torch.Size([64, 783, 2])
torch.Size([64, 783, 2])
torch.Size([64, 783, 2])
torch.Size([64, 783, 2])
torch.Size([64, 783, 2])
torch.Size([64, 783, 2])
torch.Size([64, 783, 2])
torch.Size([64, 783, 2])
torch.Size([64, 783, 2])
torch.Size([64, 783, 2])


KeyboardInterrupt: 

## Testing: generate an image with the model

In [140]:
# Generate one full-length sequence: 28 * 28 binary pixel values.
img = model.generate(max_length=28 ** 2)

In [141]:
# Show the generated sequence of symbol indices (the first element is the fixed seed 0).
img

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,