In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


In [12]:
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding """
    def __init__(self, img_size=(28, 28), patch_size=(16, 16), in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

In [13]:
def get_sinusoid_encoding_table(n_position, d_hid): 
  
    def get_position_angle_vec(position): 
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 
    return torch.FloatTensor(sinusoid_table).unsqueeze(0) 

In [14]:
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = attn_head_dim if attn_head_dim is not None else dim // num_heads
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) if qkv_bias else None
        self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) if qkv_bias else None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        print(x.shape)
        B, N, C = x.shape
        qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) if self.q_bias is not None else None
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj_drop(self.proj(x))
        return x

In [15]:

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
        self.norm2 = norm_layer(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)), act_layer(), nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(attn_drop)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

In [16]:
class Encoder(nn.Module):
    def __init__(self, img_size=28, patch_size=16, in_chans=1, embed_dim=768, norm_layer=nn.LayerNorm, num_classes=10, **block_kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        # Patch embedding
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # Positional encoding
        self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)

        # Transformer blocks
        self.blocks = nn.ModuleList([Block(768, 8, **block_kwargs) for i in range(2)])  # various arguments are not shown here for brevity purposes
        self.norm =  norm_layer(embed_dim)
        
        # Classifier (for fine-tuning only)
        self.fc_norm = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x, mask):
        print('encoder')
        x = self.patch_embed(x)
        x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
        print('mask: ', mask.shape) 
        B, _, C = x.shape
        mask = mask.reshape(B, 1, -1) 
        print('encoder x shape: ', x.shape)
        print('mask: ', mask.shape)
        if mask is not None:  # for pretraining only
            mask = mask.unsqueeze(-1)  # Add a channel dimension to match the shape of x
            x = x * (~mask) 
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        if self.num_classes > 0:  # for fine-tuning only
            x = self.fc_norm(x.mean(1))  # average pooling
            x = self.head(x)
        return x

In [17]:
class Decoder(nn.Module):
    def __init__(self, patch_size=16, embed_dim=768, norm_layer=nn.LayerNorm, num_classes=10, **block_kwargs):
        super().__init__()
        self.num_classes = num_classes
        #assert num_classes == 3 * patch_size ** 2
        self.num_features = self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.blocks = nn.ModuleList([Block(768, 8, **block_kwargs) for i in range(2)])  # various arguments are not shown here for brevity purposes
        self.norm =  norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x, return_token_num):
        print('decoder')
        for blk in self.blocks:
            x = blk(x)
        if return_token_num > 0:
            x = self.head(self.norm(x[:, -return_token_num:])) # only return the mask tokens predict pixels
        else:
            x = self.head(self.norm(x))
        return x

In [18]:
class MAE(nn.Module):
    def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer, num_classes, encoder_embed_dim, decoder_embed_dim, **block_kwargs):  # various arguments are not shown here for brevity purposes
        super().__init__()
        self.encoder = Encoder(img_size, patch_size, in_chans, embed_dim, norm_layer, num_classes=0, **block_kwargs)
        self.decoder = Decoder(patch_size, embed_dim, norm_layer, num_classes, **block_kwargs)
        self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=False)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.pos_embed = get_sinusoid_encoding_table(self.encoder.patch_embed.num_patches, decoder_embed_dim)
    
    def forward(self, x, mask):
        print('Mae')
        x_vis = self.encoder(x, mask)
        x_vis = self.encoder_to_decoder(x_vis)
        B, N, C = x_vis.shape
        mask = mask.reshape(B, 1, -1)
        expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
        pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
        pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)
        x_full = torch.cat([x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1)
        x = self.decoder(x_full, pos_emd_mask.shape[1]) # [B, N_mask, 3 * 16 * 16]
        return x

In [19]:
def train(model, train_loader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        # Create a mask with the same shape as the input data
        mask = data > 0.75
        #mask = torch.rand(size=(64, 1, 768)) < 0.75  # Example random masking
        print('mask:', mask.shape)
        print('data: ', data.shape)
        optimizer.zero_grad()
        output = model(data, mask)  # Pass the mask to the forward method
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 100 == 99:  # Print every 100 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (batch_idx + 1, batch_idx + 1, running_loss / 100))
            running_loss = 0.0

In [20]:
from torchvision import datasets, transforms
# Define transformations to be applied to the dataset
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert PIL Image to tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize the image tensors
])

# Download and load the dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Optionally, create DataLoader to handle batching and shuffling
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

img_size = (28,28)
patch_size = (16,16)
in_chans = 1
embed_dim = 768
norm_layer = nn.LayerNorm
num_classes = 10
encoder_embed_dim = 768
decoder_embed_dim = 768
model = MAE(img_size, patch_size, in_chans, embed_dim, norm_layer, num_classes, encoder_embed_dim, decoder_embed_dim)  # You need to provide necessary arguments for instantiation

# Define your optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Train your model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Now you can iterate over train_loader and test_loader to get batches of data
for batch_idx, (data, target) in enumerate(train_loader):
    train(model, train_loader, optimizer, criterion, device)

for batch_idx, (data, target) in enumerate(test_loader):
    # Testing loop
    pass



mask: torch.Size([64, 1, 28, 28])
data:  torch.Size([64, 1, 28, 28])
Mae
encoder
mask:  torch.Size([64, 1, 28, 28])
encoder x shape:  torch.Size([64, 1, 768])
mask:  torch.Size([64, 1, 784])
torch.Size([64, 64, 784, 768])


ValueError: too many values to unpack (expected 3)

In [None]:
num_epochs = 2
for epoch in range(num_epochs):
    train(model, train_loader, optimizer, criterion, device)

In [None]:
50176/64