In [None]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Optional, Tuple

In [None]:
# Report MPS support, then select a device that actually exists on this machine.
print(f"MPS 장치를 지원하도록 build가 되었는가? {torch.backends.mps.is_built()}")
print(f"MPS 장치가 사용 가능한가? {torch.backends.mps.is_available()}")
# Fall back to the CPU when MPS is unavailable so that later `.to(device)`
# calls do not fail on machines without an MPS backend.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [None]:
class LSTMCell(nn.Module):
    """One LSTM time step.

    The input and the previous hidden state are each projected to
    ``4 * hidden_size`` features, summed, and split into the input gate,
    forget gate, candidate values and output gate (in that order).

    Args:
        input_size: Size of the last dimension of ``x``.
        hidden_size: Size of the last dimension of the hidden and cell state.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.input_size, self.hidden_size = input_size, hidden_size
        gate_width = 4 * hidden_size
        # Only the hidden projection carries a bias; the two projections are
        # summed, so a single bias term is enough.
        self.hidden_lin = nn.Linear(hidden_size, gate_width)
        self.input_lin = nn.Linear(input_size, gate_width, bias=False)

    def forward(self, x, h_in, c_in):
        """Advance the cell by one step.

        Args:
            x: Input for this step, last dimension ``input_size``.
            h_in: Previous hidden state, last dimension ``hidden_size``.
            c_in: Previous cell state, last dimension ``hidden_size``.

        Returns:
            Tuple ``(h_next, c_next)`` with the new hidden and cell state.
        """
        # Linearly transform the input and the hidden state, then add them.
        gates = self.input_lin(x) + self.hidden_lin(h_in)
        in_gate, forget_gate, candidate, out_gate = torch.chunk(gates, 4, dim=-1)

        keep = torch.sigmoid(forget_gate)
        write = torch.sigmoid(in_gate)
        expose = torch.sigmoid(out_gate)

        c_next = c_in * keep + write * torch.tanh(candidate)
        h_next = torch.tanh(c_next) * expose
        return h_next, c_next

In [None]:
class LSTM(nn.Module):
    """Multi-layer LSTM classifier built from stacked ``LSTMCell`` modules.

    The sequence is processed one time step at a time. Within a step, the
    hidden state produced by layer ``k`` is the input of layer ``k + 1``.
    After the last step, the top layer's hidden state is projected to
    ``n_classes`` logits.

    Args:
        input_size: Feature size of each time step of the input sequence.
        hidden_size: Size of the hidden and cell state of every layer.
        n_layers: Number of stacked cells; must be at least 1.
        n_classes: Number of output logits.

    Raises:
        ValueError: If ``n_layers`` is smaller than 1.
    """

    def __init__(self, input_size: int, hidden_size: int, n_layers: int, n_classes: int):
        super().__init__()
        if n_layers < 1:
            # With no layers there is no top hidden state to classify from;
            # fail here instead of with an IndexError inside forward().
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.n_classes = n_classes
        # The first cell consumes the raw input; every later cell consumes the
        # hidden state of the cell below it.
        self.cells = nn.ModuleList(
            [LSTMCell(input_size=input_size, hidden_size=hidden_size)]
            + [LSTMCell(input_size=hidden_size, hidden_size=hidden_size) for _ in range(n_layers - 1)]
        )
        self.linear = nn.Linear(self.hidden_size, self.n_classes)

    def forward(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        """Run the whole sequence and classify from the final hidden state.

        Args:
            x: Input of shape ``[batch_size, seq_len, input_size]``.
            state: Optional ``(h, c)`` pair whose tensors are stacked per
                layer along dimension 0. When ``None`` (the default), all
                states start at zero with the dtype and device of ``x``.

        Returns:
            Tuple ``(out, (h, c))`` where ``out`` holds the logits reshaped to
            ``[-1, n_classes]`` and ``h`` / ``c`` are the final hidden and
            cell states of all layers stacked along dimension 0.
        """
        batch_size, seq_len = x.shape[:2]  # x: [batch_size, seq_len, input_size]

        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
            h_init, c_init = state
            # Split the stacked states into one tensor per layer.
            h, c = list(torch.unbind(h_init)), list(torch.unbind(c_init))

        for t in range(seq_len):
            inp = x[:, t, :]
            for layer, cell in enumerate(self.cells):
                h[layer], c[layer] = cell(inp, h[layer], c[layer])
                # The fresh hidden state of this layer feeds the next layer.
                inp = h[layer]

        # Only the top layer's hidden state after the last step is classified,
        # so per-step outputs are not collected.
        out = self.linear(h[-1]).view(-1, self.n_classes)
        return out, (torch.stack(h), torch.stack(c))

In [None]:
# --- Sequence layout -------------------------------------------------------
# Each image is fed to the model as `seq_len` time steps of `input_size`
# values (32 * 96 = 3 * 32 * 32 values in total).
seq_len = 32
input_size = 32 * 3

# --- Model -----------------------------------------------------------------
hidden_size = 128
n_layers = 2
n_classes = 10

# --- Optimisation ----------------------------------------------------------
batch_size = 100
n_epoch = 100
lr = 0.001

# --- Data ------------------------------------------------------------------
# CIFAR-10 is downloaded into ./cifar10 on first use; samples are converted
# to tensors by ToTensor().
train_dataset = datasets.CIFAR10(
    root="./cifar10",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_dataset = datasets.CIFAR10(
    root="./cifar10",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Only the training data is shuffled.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

In [None]:
# Build the classifier with one output logit per dataset class, together with
# the loss and the optimizer used for training.
model = LSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    n_layers=n_layers,
    n_classes=len(train_dataset.classes),
)
loss_func = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=lr)

In [None]:
def train(model, seq_len, input_len, loss_func, opt, n_epoch, train_loader):
    """Train ``model`` on ``train_loader`` for ``n_epoch`` epochs.

    Every batch is reshaped to ``[-1, seq_len, input_len]`` and passed to the
    model as ``model(image, None)``. The first element of the model's output
    is scored against the labels with ``loss_func`` and one optimizer step is
    taken per batch. A progress line is printed every 100 batches.

    Args:
        model: Module called as ``model(image, None)``; element 0 of its
            return value is used as the prediction.
        seq_len: Number of time steps each sample is reshaped to.
        input_len: Number of features per time step after reshaping.
        loss_func: Callable ``loss_func(prediction, label)`` returning a
            scalar loss tensor.
        opt: Optimizer that updates the parameters of ``model``.
        n_epoch: Number of passes over ``train_loader``.
        train_loader: Sized iterable of ``(image, label)`` tensor batches.

    Returns:
        None.
    """
    model.train()

    # Batches are moved to the device the model's parameters live on, so the
    # loop also works when the model was moved off the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    for epoch in range(n_epoch):
        for batch, (image, label) in enumerate(train_loader):
            # NOTE(review): if the loader yields [batch, C, H, W] images
            # (presumably the case for torchvision's ToTensor — confirm), this
            # plain reshape makes each time step a run of values from a single
            # channel plane rather than one RGB pixel row; verify it is intended.
            image = image.reshape(-1, seq_len, input_len)
            if device is not None:
                image, label = image.to(device), label.to(device)

            pred = model(image, None)
            loss = loss_func(pred[0], label)

            opt.zero_grad()
            loss.backward()
            opt.step()

            if (batch + 1) % 100 == 0:
                # .item() logs a plain number instead of the tensor repr.
                print(f"Epoch: {epoch}; Batch: {batch + 1} / {len(train_loader)}; Loss: {loss.item():.4f};")

In [None]:
# Start training with the hyperparameters configured above.
train(
    model=model,
    seq_len=seq_len,
    input_len=input_size,
    loss_func=loss_func,
    opt=opt,
    n_epoch=n_epoch,
    train_loader=train_loader,
)