In [81]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm 
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [82]:
# (batch_size, seq_len, feature)
data = torch.randn(size=(10, 5, 32))
rnn = nn.RNN(32, 128, num_layers=1, bidirectional=False, batch_first=True)
output, h_n = rnn(data)
print(output.shape)
print(h_n.shape)

torch.Size([10, 5, 128])
torch.Size([1, 10, 128])


RNNCell

In [83]:
rnncell = nn.RNNCell(32, 128)
h_t = torch.zeros((10,128)) # (batch_size, hidden_size)
for i in range(data.shape[1]):
    x_t = data[:,i,:]
    h_t = rnncell(x_t, h_t)
    print(h_t.shape)    
    break

torch.Size([10, 128])


In [84]:
print(output.shape) # (batch_size, seq_len, hidden_size)
print(h_n.shape) # (1, batch_size, hidden_size)

torch.Size([10, 5, 128])
torch.Size([1, 10, 128])


In [111]:
config = {
    "batch_size":16,
    "epoch":10,
    "lr":1e-3,
    "device":"cuda" if torch.cuda.is_available() else "cpu"
}


In [86]:
ds_train = MNIST("../data", download=True, train=True, transform=ToTensor())
dl_train = DataLoader(ds_train, config["batch_size"], shuffle=True)

In [87]:
#  查看图片 shape
next(iter(dl_train))[0].shape # (batch_size, C, H, W)

torch.Size([16, 1, 28, 28])

In [88]:
class MinistClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_labels):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_labels)

    def forward(self, X):
        output, h_t = self.rnn(X)
        return self.fc(h_t[0])

In [89]:
class MinistClassifierAdv(nn.Module):
    def __init__(self, input_size, hidden_size, num_labels):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_labels)

    def forward(self, X):
        output, (h_t, c_t) = self.rnn(X)
        return self.fc(h_t[0])

In [90]:
# data.shape    # torch.Size([10, 5, 32])
# model = MinistClassifier(32, 64, 10)
# model(data).shape     # torch.Size([10, 10])

In [92]:
model_rnn = MinistClassifier(28, 100, 10)
optimizer = torch.optim.Adam(model_rnn.parameters(), lr=config["lr"])
loss_fn = nn.CrossEntropyLoss()
model_rnn = model_rnn.to(config["device"])
loss_fn = loss_fn.to(config["device"])
for epoch in range(config["epoch"]):
    process_bar = tqdm(dl_train)
    model_rnn.train()
    total_loss = 0
    for i, (img, label) in enumerate(process_bar, start=1):
        img, label = img.squeeze().to(config["device"]), label.to(config["device"])
        y_hat = model_rnn(img)
        loss = loss_fn(y_hat, label)
        total_loss += loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        process_bar.set_description(f"epoch: {epoch + 1}, loss: {loss.item():.4f}")



epoch: 1, loss: 0.6060: 100%|██████████| 3750/3750 [00:14<00:00, 263.82it/s]
epoch: 2, loss: 0.3385: 100%|██████████| 3750/3750 [00:13<00:00, 272.53it/s]


In [110]:
model_lstm = MinistClassifierAdv(28, 100, 10)
optimizer = torch.optim.Adam(model_lstm.parameters(), lr=config["lr"])
loss_fn = nn.CrossEntropyLoss()
model_lstm = model_lstm.to(config["device"])
loss_fn = loss_fn.to(config["device"])
for epoch in range(config["epoch"]):
    process_bar = tqdm(dl_train)
    model_lstm.train()
    total_loss = 0
    for i, (img, label) in enumerate(process_bar, start=1):
        img, label = img.squeeze().to(config["device"]), label.to(config["device"])
        y_hat = model_lstm(img)
        loss = loss_fn(y_hat, label)
        total_loss += loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        process_bar.set_description(f"epoch: {epoch + 1}, loss: {loss.item():.4f}")



epoch: 1, loss: 0.0810: 100%|██████████| 3750/3750 [00:19<00:00, 193.88it/s]
epoch: 2, loss: 0.2359: 100%|██████████| 3750/3750 [00:18<00:00, 197.52it/s]


In [112]:
class MinistClassifierAdvBidireciton(nn.Module):
    def __init__(self, input_size, hidden_size, num_labels):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True, num_layers=2)
        self.fc = nn.Linear(hidden_size * (2 if self.rnn.bidirectional == True else 1)
                            , num_labels)

    def forward(self, X):
        output, (h_t, c_t) = self.rnn(X)
        return self.fc(output[:,-1,:])
    
model_bidirection_lstm = MinistClassifierAdvBidireciton(28, 100, 10)
optimizer = torch.optim.Adam(model_bidirection_lstm.parameters(), lr=config["lr"])
loss_fn = nn.CrossEntropyLoss()
model_bidirection_lstm = model_bidirection_lstm.to(config["device"])
loss_fn = loss_fn.to(config["device"])
for epoch in range(config["epoch"]):
    process_bar = tqdm(dl_train)
    model_bidirection_lstm.train()
    total_loss = 0
    for i, (img, label) in enumerate(process_bar, start=1):
        img, label = img.squeeze().to(config["device"]), label.to(config["device"])
        y_hat = model_bidirection_lstm(img)
        loss = loss_fn(y_hat, label)
        total_loss += loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        process_bar.set_description(f"epoch: {epoch + 1}, loss: {loss.item():.4f}")



epoch: 1, loss: 0.0502: 100%|██████████| 3750/3750 [00:22<00:00, 169.47it/s]
epoch: 2, loss: 0.0018: 100%|██████████| 3750/3750 [00:21<00:00, 173.98it/s]
epoch: 3, loss: 0.2013: 100%|██████████| 3750/3750 [00:21<00:00, 170.92it/s]
epoch: 4, loss: 0.0253: 100%|██████████| 3750/3750 [00:22<00:00, 166.33it/s]
epoch: 5, loss: 0.0010: 100%|██████████| 3750/3750 [00:21<00:00, 173.05it/s]
epoch: 6, loss: 0.2020: 100%|██████████| 3750/3750 [00:21<00:00, 170.76it/s]
epoch: 7, loss: 0.0003: 100%|██████████| 3750/3750 [00:21<00:00, 170.72it/s]
epoch: 8, loss: 0.0034: 100%|██████████| 3750/3750 [00:21<00:00, 173.22it/s]
epoch: 9, loss: 0.0330: 100%|██████████| 3750/3750 [00:21<00:00, 174.15it/s]
epoch: 10, loss: 0.0005: 100%|██████████| 3750/3750 [00:21<00:00, 171.40it/s]


In [113]:
torch.save(model_bidirection_lstm.state_dict(), "../source/model_dict/lstm_mini_cls.pth")