In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from mae_components import *
import yaml
from PIL import Image
from mae_dataset import get_miniImageNetDataLoader
import torch.optim as optim
import torch
from tqdm import tqdm
import os
import torch.nn as nn
import sys

os.environ["CUDA_VISIBLE_DEVICES"] = '1, 2, 3'

In [3]:
def read_yaml_config(file_path):
    with open(file_path, 'r') as file:
        config = yaml.safe_load(file)
    return config

config = read_yaml_config('config.yaml')

In [4]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [5]:
mae_encoder=MaskedViTEncoder(config, 320, 16, embed_dim=768, device=device).to(device)
mae_decoder=MaskedViTDecoder(config, 320, 16, encoder_embed_dim=768, decoder_embed_dim=512, device=device, masked_decoder_loss=True).to(device)

In [6]:
mae_encoder.load_state_dict(torch.load("./mae_log/encoder_param_w_cls.pth"),strict=False)
mae_decoder.load_state_dict(torch.load("./mae_log/decoder_param_w_cls.pth"),strict=False)
mae_encoder.train()
mae_decoder.train()

MaskedViTDecoder(
  (encoder_to_decoder): Linear(in_features=768, out_features=512, bias=False)
  (blocks): ModuleList(
    (0-7): 8 x Block(
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=512, out_features=1536, bias=False)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=512, out_features=512, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=512, out_features=2048, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=2048, out_features=512, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
        (act2): GELU(approximate='none')
      )
    )
  )
  (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (decoder_pred): Linear(in_features=512, out_featu

In [7]:
param_dict=[{'params':mae_encoder.parameters()},{'params':mae_decoder.parameters()}]
optimizer = optim.Adam(param_dict, lr=0.0001)

loss_fn=nn.CrossEntropyLoss()

dataloader, memo = get_miniImageNetDataLoader(batch_size=16, img_size=320, shuffle=True)


Data Preparation Done
Data Loaded.


In [8]:
log_iter_freq = 10
checkpoint = 20
num_epoch = 10

log = []

for epoch in range(num_epoch):
    for n_iter, (img, target) in enumerate(dataloader):
        img=img.to(device)
        encoded, _ = mae_encoder.forward_encoder(img, mask_ratio=0.0)
        #pass no encoded_embedding to decoder but only the cls_token
        fake_mask = np.zeros((encoded.shape[0], encoded.shape[1]-1), dtype=bool)
        reconstructed = mae_decoder.forward_decoder(encoded, fake_mask)
        loss_rcs = mae_decoder.forward_loss(imgs=img, pred=reconstructed, mask=fake_mask)

        '''
        target=target.to(device)
        target_pred = mae_encoder.forward(img)
        loss_cls = loss_fn(target_pred, target)
        '''

        loss =  loss_rcs #+ 5*loss_cls
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():   
            if n_iter % log_iter_freq == 0:
                print(f"Epoch:{epoch} {n_iter}/{len(dataloader)} Loss:{loss.detach().item():.3f}")
                log.append(loss.detach().item())
                #print("ACC:",torch.sum(torch.argmax(target, dim=1)==torch.argmax(target_pred, dim=1))/target.shape[0])
            if n_iter % checkpoint == 0  and n_iter != 0:
                torch.save(mae_encoder.state_dict(), "./mae_log/encoder_param_imagine.pth")
                torch.save(mae_decoder.state_dict(), "./mae_log/decoder_param_imagine.pth")
                torch.save(log, "./mae_log/imagine_loss.pt")

Epoch:0 0/3750 Loss:0.143
Epoch:0 10/3750 Loss:0.067
Epoch:0 20/3750 Loss:0.054
Epoch:0 30/3750 Loss:0.074
Epoch:0 40/3750 Loss:0.062
Epoch:0 50/3750 Loss:0.055
Epoch:0 60/3750 Loss:0.054
Epoch:0 70/3750 Loss:0.055
Epoch:0 80/3750 Loss:0.048
Epoch:0 90/3750 Loss:0.048
Epoch:0 100/3750 Loss:0.056
Epoch:0 110/3750 Loss:0.057
Epoch:0 120/3750 Loss:0.040
Epoch:0 130/3750 Loss:0.051
Epoch:0 140/3750 Loss:0.047
Epoch:0 150/3750 Loss:0.045
Epoch:0 160/3750 Loss:0.046
Epoch:0 170/3750 Loss:0.045
Epoch:0 180/3750 Loss:0.055
Epoch:0 190/3750 Loss:0.052
Epoch:0 200/3750 Loss:0.049
Epoch:0 210/3750 Loss:0.068
Epoch:0 220/3750 Loss:0.045
Epoch:0 230/3750 Loss:0.034
Epoch:0 240/3750 Loss:0.045
Epoch:0 250/3750 Loss:0.042
Epoch:0 260/3750 Loss:0.045
Epoch:0 270/3750 Loss:0.043
Epoch:0 280/3750 Loss:0.048
Epoch:0 290/3750 Loss:0.049
Epoch:0 300/3750 Loss:0.044
Epoch:0 310/3750 Loss:0.048
Epoch:0 320/3750 Loss:0.040
Epoch:0 330/3750 Loss:0.034
Epoch:0 340/3750 Loss:0.049
Epoch:0 350/3750 Loss:0.046
Epo