In [5]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

import numpy as np
from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
from collections import OrderedDict
from typing import Optional

from utils.vit_utils import Image_Embedding as IE# 이전 장의 image embedding
from utils.vit_utils import MultiHeadAttention as MHA # 이전 장의 Multi-Head Attention

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [6]:
class ResidualConnection(nn.Module):
    """Wrap a sub-module with a skip connection.

    ``forward`` returns ``layer(x, **kwargs) + x``: the wrapped layer's output
    plus its unchanged input. The two must be broadcast-compatible for the
    addition to succeed.

    Args:
        layer: the module to wrap; any keyword arguments given to ``forward``
            are passed straight through to it.
    """

    def __init__(self, layer):
        super().__init__()
        # Attribute name is part of the state_dict keys ("layer.*"); keep it.
        self.layer = layer

    def forward(self, x, **kwargs):
        layer_out = self.layer(x, **kwargs)
        return layer_out + x

In [7]:
class FeedForward(nn.Module):
    """Two-layer MLP applied over the last dimension: Linear -> GELU -> Dropout -> Linear.

    The hidden layer is ``expansion * embedding_size`` wide. The output's last
    dimension is ``embedding_size`` again, matching the input.

    Args:
        embedding_size: size of the input and output feature (last) dimension.
        expansion: multiplier for the hidden-layer width.
        dropout: dropout probability applied after the GELU activation.
    """

    def __init__(self,
                 embedding_size: int,
                 expansion: int = 4,
                 dropout: float = 0.):
        super().__init__()
        hidden_size = embedding_size * expansion
        # Order matters: sub-module indices (0..3) define the state_dict keys
        # and the order in which the Linear weights are randomly initialised.
        layers = [
            nn.Linear(embedding_size, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, embedding_size),
        ]
        self.ff_layer = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        return self.ff_layer(x)

In [8]:
class TransformerEncoderBlock(nn.Module):
    """One Transformer encoder block with LayerNorm applied before each sub-layer.

    The block applies two residual sub-layers in order:
      1. ``x + Dropout(MHA(LayerNorm(x)))``
      2. ``x + Dropout(FeedForward(LayerNorm(x)))``

    Args:
        embedding_size: feature size of each token embedding; both LayerNorms
            normalise over this last dimension.
        dropout: dropout probability applied to the output of the attention
            and feed-forward sub-layers (before the residual addition).
        forward_expansion: hidden-width multiplier for the FeedForward MLP.
        forward_dropout: dropout probability inside the FeedForward MLP.
        **kwargs: forwarded to ``MHA`` (the project's MultiHeadAttention from
            ``utils.vit_utils``). Its accepted arguments are not visible here.
    """

    def __init__(self,
                 embedding_size: int = 768,
                 dropout: float = 0.,
                 forward_expansion: int = 4,
                 forward_dropout: float = 0.,
                 **kwargs):
        super(TransformerEncoderBlock, self).__init__()
        # Outer nn.Sequential wrappers are kept so state_dict keys
        # ("norm_mha.0.layer.*", "norm_ff.0.layer.*") stay unchanged.
        self.norm_mha = nn.Sequential(
            ResidualConnection(
                nn.Sequential(
                    nn.LayerNorm(embedding_size),
                    MHA(embedding_size, **kwargs),
                    nn.Dropout(dropout)
                    )
                )
            )
        self.norm_ff = nn.Sequential(
            ResidualConnection(
                nn.Sequential(
                    nn.LayerNorm(embedding_size),
                    FeedForward(embedding_size, forward_expansion, forward_dropout),
                    nn.Dropout(dropout)
                )
            )
        )

    # Must be a method at class level: nn.Module.__call__ dispatches to
    # self.forward. The original nested `def forward(x)` inside __init__ was
    # only a dead local function, so calling the block raised NotImplementedError.
    def forward(self, x: Tensor) -> Tensor:
        # Only `x` is accepted: the inner nn.Sequential cannot forward extra
        # keyword arguments (e.g. an attention mask) to MHA.
        x = self.norm_mha(x)
        return self.norm_ff(x)

In [None]:
class TransformerEncoder(nn.Module):
    """Stack of ``depth`` TransformerEncoderBlock modules applied in sequence.

    Args:
        depth: number of encoder blocks to stack. With ``depth=0`` the empty
            nn.Sequential returns its input unchanged.
        **kwargs: forwarded unchanged to every TransformerEncoderBlock.
    """

    def __init__(self, depth: int = 12, **kwargs):
        super(TransformerEncoder, self).__init__()
        self.multi_encoder_layer = nn.Sequential(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])

    # Without a forward method, nn.Module raises NotImplementedError on call.
    def forward(self, x: Tensor) -> Tensor:
        return self.multi_encoder_layer(x)
