In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to project files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from mae_components_no_cls import *
import yaml
from PIL import Image
from mae_dataset import get_miniImageNetDataLoader
import torch.optim as optim
import torch
from tqdm import tqdm
import os
import torch.nn as nn
import sys

# Restrict this process to physical GPUs 2 and 3.
# NOTE(review): the variable is only honoured if it is set before CUDA is
# initialised; this line runs after the torch / project imports above, so
# confirm none of them touches CUDA at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = "2, 3"

In [3]:
def read_yaml_config(file_path):
    """Parse a YAML file and return its content.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the document
        (``None`` when the file is empty).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale default encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)

# Settings for this experiment; `config` is handed to both model constructors.
config = read_yaml_config("./mae_log/no_cls/config.yaml")

In [4]:
# Use the first visible GPU when CUDA is available, otherwise fall back to CPU.
# With CUDA_VISIBLE_DEVICES set in the import cell, "cuda:0" refers to the
# first device listed there.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [5]:
# Build the encoder/decoder pair and move both to `device`.
# The positional 224 and 16 are presumably image size and patch size
# (the data loader below is created with img_size=224) -- TODO confirm
# against the constructors in mae_components_no_cls.
mae_encoder = MaskedViTEncoder(
    config, 224, 16,
    embed_dim=512,
    device=device,
).to(device)

# encoder_embed_dim matches the encoder's embed_dim above (512).
mae_decoder = MaskedViTDecoder(
    config, 224, 16,
    encoder_embed_dim=512,
    decoder_embed_dim=256,
    device=device,
    masked_decoder_loss=False,
).to(device)

In [6]:
# Disabled: warm-start both networks from weights saved under ./mae_log/224/
# (non-strict load) and switch them to training mode. Uncomment to resume.
#mae_encoder.load_state_dict(torch.load("./mae_log/224/encoder_param_2.pth"),strict=False)
#mae_decoder.load_state_dict(torch.load("./mae_log/224/decoder_param_2.pth"),strict=False)
#mae_encoder.train()
#mae_decoder.train()

In [7]:
# One Adam optimizer updates both networks; each gets its own parameter group.
# (Despite its name, `param_dict` is a list of group dicts.)
param_dict = [
    {'params': mae_encoder.parameters()},
    {'params': mae_decoder.parameters()},
]
optimizer = optim.Adam(param_dict, lr=1e-4)

# Not used by the active training code visible in this notebook; only the
# disabled classification term refers to it.
loss_fn = nn.CrossEntropyLoss()

# Passed to mae_encoder.forward_encoder by the training loop.
mask_ratio = 0.7

dataloader, memo = get_miniImageNetDataLoader(
    batch_size=128,
    img_size=224,
    shuffle=True,
)


Data Preparation Done
Data Loaded.


In [8]:
# ---- Schedule ---------------------------------------------------------------
# When True, every `imagine_freq`-th epoch (never epoch 0) runs an "imagine"
# epoch: the encoder is called with mask_ratio=0.0 and the decoder receives an
# all-False mask. All other epochs train with the global `mask_ratio`.
enable_imagine = False

log_iter_freq = 50  # print and record the loss every `log_iter_freq` iterations
imagine_freq = 2
checkpoint = 50     # write weights and the loss log every `checkpoint` iterations
num_epoch = 50

# Single output directory for the whole run. Previously the "imagine" branch
# wrote to ./mae_log/224/ while the regular branch wrote to ./mae_log/no_cls/,
# which would have split one run's state across two experiment folders.
checkpoint_dir = "./mae_log/no_cls"
os.makedirs(checkpoint_dir, exist_ok=True)

# Loss values sampled every `log_iter_freq` iterations.
log = []


def _save_training_state():
    """Write encoder weights, decoder weights and the loss log to `checkpoint_dir`."""
    torch.save(mae_encoder.state_dict(), os.path.join(checkpoint_dir, "encoder_param.pth"))
    torch.save(mae_decoder.state_dict(), os.path.join(checkpoint_dir, "decoder_param.pth"))
    torch.save(log, os.path.join(checkpoint_dir, "loss.pt"))


def _reconstruction_loss(img, imagine):
    """Encode `img`, decode it and return the decoder's reconstruction loss.

    Args:
        img: Image batch, already moved to `device`.
        imagine: If True, encode with mask_ratio=0.0 and give the decoder an
            all-False boolean mask; otherwise encode with the global
            `mask_ratio` and use the mask returned by the encoder.

    Returns:
        The value returned by `mae_decoder.forward_loss`.
    """
    if imagine:
        encoded, _unused_mask = mae_encoder.forward_encoder(img, mask_ratio=0.0)
        # NOTE(review): the original comment here read "pass no
        # encoded_embedding to decoder but only the cls_token", and the mask is
        # one token shorter than `encoded`. This notebook uses the *no_cls*
        # components, so confirm the `- 1` is still correct.
        # NOTE(review): `np` is not imported explicitly in this notebook; it is
        # presumably provided by `from mae_components_no_cls import *`.
        mask = np.zeros((encoded.shape[0], encoded.shape[1] - 1), dtype=bool)
    else:
        encoded, mask = mae_encoder.forward_encoder(img, mask_ratio)

    reconstructed = mae_decoder.forward_decoder(encoded, mask)
    return mae_decoder.forward_loss(imgs=img, pred=reconstructed, mask=mask)


for epoch in range(num_epoch):
    imagine = enable_imagine and epoch != 0 and epoch % imagine_freq == 0
    if imagine:
        print("============Now REM Sleeping==============")
    else:
        print("=================Day Time================")

    # Labels (`target`) are loaded but unused: the classification term
    # (formerly `5 * loss_cls`, computed with `loss_fn` on
    # `mae_encoder.forward(img)`) is disabled, so only the reconstruction loss
    # is optimised.
    for n_iter, (img, target) in enumerate(dataloader):
        img = img.to(device)
        loss = _reconstruction_loss(img, imagine)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # `.item()` already yields a plain Python float, so no no_grad/detach
        # is needed for logging.
        if n_iter % log_iter_freq == 0:
            loss_value = loss.item()
            print(f"Epoch:{epoch} {n_iter}/{len(dataloader)} Loss:{loss_value:.3f}")
            log.append(loss_value)
        if n_iter % checkpoint == 0 and n_iter != 0:
            _save_training_state()

    # Save once more at the end of every epoch so the iterations after the last
    # periodic checkpoint (and the final state of the run) are not lost.
    _save_training_state()


Epoch:0 0/469 Loss:0.589
Epoch:0 50/469 Loss:0.082
Epoch:0 100/469 Loss:0.079
Epoch:0 150/469 Loss:0.056
Epoch:0 200/469 Loss:0.050
Epoch:0 250/469 Loss:0.047
Epoch:0 300/469 Loss:0.048
Epoch:0 350/469 Loss:0.045
Epoch:0 400/469 Loss:0.039
Epoch:0 450/469 Loss:0.038
Epoch:1 0/469 Loss:0.034
Epoch:1 50/469 Loss:0.033
Epoch:1 100/469 Loss:0.032
Epoch:1 150/469 Loss:0.032
Epoch:1 200/469 Loss:0.030
Epoch:1 250/469 Loss:0.031
Epoch:1 300/469 Loss:0.027
Epoch:1 350/469 Loss:0.027
Epoch:1 400/469 Loss:0.027
Epoch:1 450/469 Loss:0.029
Epoch:2 0/469 Loss:0.025
Epoch:2 50/469 Loss:0.028
Epoch:2 100/469 Loss:0.027
Epoch:2 150/469 Loss:0.028
Epoch:2 200/469 Loss:0.025
Epoch:2 250/469 Loss:0.026
Epoch:2 300/469 Loss:0.023
Epoch:2 350/469 Loss:0.023
Epoch:2 400/469 Loss:0.023
Epoch:2 450/469 Loss:0.024
Epoch:3 0/469 Loss:0.024
Epoch:3 50/469 Loss:0.023
Epoch:3 100/469 Loss:0.023
Epoch:3 150/469 Loss:0.024
Epoch:3 200/469 Loss:0.022
Epoch:3 250/469 Loss:0.023
Epoch:3 300/469 Loss:0.021
Epoch:3 350/4

In [9]:
# Report how many entries a saved loss log contains.
# NOTE(review): this reads ./mae_log/1st_trial/cls_loss.pt, not the
# ./mae_log/no_cls/loss.pt file written by the training cell -- confirm that
# inspecting a different run is intended.
saved_losses = torch.load("./mae_log/1st_trial/cls_loss.pt")
print(len(saved_losses))

3750
