In [1]:
# Enable IPython's autoreload extension in mode 2 so that edited project
# modules are re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
from mae_components import *
import yaml
from PIL import Image
from mae_dataset import get_miniImageNetDataLoader
import torch.optim as optim
import torch
from tqdm import tqdm
import os
import torch.nn as nn
import sys

# Expose only physical GPU 3 to this process; within the notebook that GPU is
# then addressed as "cuda:0".
# NOTE(review): this only takes effect if it runs before CUDA is initialized --
# confirm that none of the imports above touch CUDA at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = '3'

In [3]:
def read_yaml_config(file_path):
    """Parse a YAML file and return its contents.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the document
        (``None`` when the file is empty).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)

# Load the experiment configuration from config.yaml in the working directory.
config = read_yaml_config('config.yaml')

In [4]:
# Prefer the first visible GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [5]:
# Build the encoder, then move its parameters to the selected device.
# NOTE(review): 320 and 16 are presumably the image size and patch size --
# confirm against MaskedViTEncoder's signature.
mae_encoder = MaskedViTEncoder(config, 320, 16, embed_dim=768, device=device)
mae_encoder = mae_encoder.to(device)

In [6]:
# Load the pretrained encoder weights onto the device selected earlier.
# map_location follows `device` (instead of a hard-coded 'cuda:0') so the
# checkpoint also loads when the notebook has fallen back to the CPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
encoder_state = torch.load("./mae_log/encoder_param_320.pth", map_location=device)

# strict=False tolerates key mismatches between checkpoint and model; report
# them so a wrong or partial checkpoint is not silently accepted.
load_result = mae_encoder.load_state_dict(encoder_state, strict=False)
if load_result.missing_keys:
    print(f"Missing keys (left at their initial values): {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Unexpected keys (ignored): {load_result.unexpected_keys}")

# Switch to training mode; as the last expression this also displays the model.
mae_encoder.train()

MaskedViTEncoder(
  (cnn): PrefixCNN(
    (model): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (relu): ReLU(inplace=True)
  )
  (blocks): ModuleList(
    (0-23): 24 x Block(
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=False)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
        (act2): GELU(approximate='none')
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (he

In [7]:
# Only the classification head is optimized. Freeze every other parameter so
# that backward() does not compute gradients for the backbone: the optimizer
# never reads them, and optimizer.zero_grad() never clears them, so they would
# otherwise accumulate and occupy memory for the whole run.
# To fine-tune the full model later, re-enable them with requires_grad_(True).
head_param_ids = {id(p) for p in mae_encoder.head.parameters()}
for param in mae_encoder.parameters():
    if id(param) not in head_param_ids:
        param.requires_grad_(False)

optimizer = optim.Adam(mae_encoder.head.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

# NOTE(review): mask_ratio is not referenced by any cell visible here --
# confirm whether it is still needed.
mask_ratio=0.0

# The second return value (`memo`) is not used in the cells visible here.
dataloader, memo = get_miniImageNetDataLoader(batch_size=24, img_size=320, shuffle=True)

Data Preparation Done
Data Loaded.


In [8]:
log_iter_freq = 10  # print and record the loss every this many iterations
checkpoint = 20     # intended checkpoint interval; unused while saving is disabled below
num_epoch = 10

# Loss values sampled every `log_iter_freq` iterations, across all epochs.
log = []

for epoch in range(num_epoch):
    for n_iter, (img, target) in enumerate(dataloader):
        img = img.to(device)
        target = target.to(device)

        # Call the module itself rather than .forward() so that any registered
        # forward hooks run as well.
        pred = mae_encoder(img)
        loss = loss_fn(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if n_iter % log_iter_freq == 0:
            # .item() already yields a plain Python float, so no detach() or
            # no_grad() context is needed; read it once and reuse it.
            loss_value = loss.item()
            print(f"Epoch:{epoch} {n_iter}/{len(dataloader)} Loss:{loss_value:.3f}")
            log.append(loss_value)

        # Checkpointing is currently disabled, so nothing is saved during training.
        # NOTE(review): before re-enabling, fix the snippet below -- `mae_decoder`
        # is not defined in the cells visible here, and the "_448" file names do
        # not match the 320 checkpoint loaded earlier.
        # if n_iter % checkpoint == 0 and n_iter != 0:
        #     torch.save(mae_encoder.state_dict(), "./mae_log/encoder_param_448.pth")
        #     torch.save(mae_decoder.state_dict(), "./mae_log/decoder_param_448.pth")
        #     torch.save(log, "./mae_log/loss.pt")

Epoch:0 0/2500 Loss:4.907
Epoch:0 10/2500 Loss:4.751
Epoch:0 20/2500 Loss:4.686
Epoch:0 30/2500 Loss:4.606
Epoch:0 40/2500 Loss:4.560
Epoch:0 50/2500 Loss:4.504
Epoch:0 60/2500 Loss:4.518
Epoch:0 70/2500 Loss:4.446
Epoch:0 80/2500 Loss:4.421
Epoch:0 90/2500 Loss:4.352
Epoch:0 100/2500 Loss:4.374
Epoch:0 110/2500 Loss:4.412
Epoch:0 120/2500 Loss:4.350
Epoch:0 130/2500 Loss:4.327
Epoch:0 140/2500 Loss:4.371
Epoch:0 150/2500 Loss:4.283
Epoch:0 160/2500 Loss:4.352
Epoch:0 170/2500 Loss:4.115
Epoch:0 180/2500 Loss:4.119
Epoch:0 190/2500 Loss:4.257
Epoch:0 200/2500 Loss:4.025
Epoch:0 210/2500 Loss:3.927
Epoch:0 220/2500 Loss:3.894
Epoch:0 230/2500 Loss:3.931
Epoch:0 240/2500 Loss:4.366
Epoch:0 250/2500 Loss:3.850
Epoch:0 260/2500 Loss:4.068
Epoch:0 270/2500 Loss:3.906
Epoch:0 280/2500 Loss:3.945
Epoch:0 290/2500 Loss:3.985
Epoch:0 300/2500 Loss:4.255
Epoch:0 310/2500 Loss:3.972
Epoch:0 320/2500 Loss:3.796
Epoch:0 330/2500 Loss:3.936
Epoch:0 340/2500 Loss:3.867
Epoch:0 350/2500 Loss:3.960
Epo