Test if the ShiftFormer can do more things.

In [2]:
%load_ext autoreload
%autoreload 2

In [6]:
from mae_components_no_cls import *
import yaml
from PIL import Image
from mae_dataset import get_miniImageNetDataLoader
import torch.optim as optim
import torch
from tqdm import tqdm
import os
import torch.nn as nn
import sys
import json
import utils
from torch.utils.data import DataLoader,Dataset

os.environ["CUDA_VISIBLE_DEVICES"] = '2, 3'

In [7]:
def read_yaml_config(file_path):
    with open(file_path, 'r') as file:
        config = yaml.safe_load(file)
    return config

config = read_yaml_config('./mae_log/shift_test/config_r.yaml')

In [10]:
dataset_root="COCOSearch18"
with open(os.path.join(dataset_root,
               'coco_search18_fixations_TP_train_split1.json'#'coco_search18_fixations_TP_train.json'
               )) as json_file:
    human_scanpaths_train = json.load(json_file)
    
with open(os.path.join(dataset_root,
               'coco_search18_fixations_TP_validation_split1.json'#'coco_search18_fixations_TP_validation.json'
               )) as json_file:
    human_scanpaths_valid = json.load(json_file)

'''
max=0
for dict in human_scanpaths_train:
    if dict['length']>=max:
        max=dict['length']
print(max)
'''

dataset=utils.COCOSearch18(json=human_scanpaths_train,root='COCOSearch18/images')

In [11]:
training_loader=DataLoader(dataset,batch_size=16, shuffle=True,num_workers=8)

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
mae_encoder=MaskedViTEncoder(config, 224, 16, embed_dim=512, device=device).to(device)
shift_block=ShiftTransformer(config, img_size=224, patch_size=16, embed_dim=512, device=device).to(device)
mae_decoder=MaskedViTDecoder(config, 224, 16, encoder_embed_dim=512, decoder_embed_dim=256, device=device, masked_decoder_loss=True).to(device)

In [6]:
mae_encoder.load_state_dict(torch.load("./mae_log/no_cls/encoder_param.pth"),strict=False)
shift_block.load_state_dict(torch.load("./mae_log/shift_test/shift_param_r_1.pth"),strict=False)
mae_decoder.load_state_dict(torch.load("./mae_log/no_cls/decoder_param.pth"),strict=False)
mae_encoder.eval()
shift_block.train()
mae_decoder.eval()

MaskedViTDecoder(
  (encoder_to_decoder): Linear(in_features=512, out_features=256, bias=False)
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=256, out_features=768, bias=False)
        (attn_drop): Dropout(p=0.1, inplace=False)
        (proj): Linear(in_features=256, out_features=256, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=256, out_features=1024, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1024, out_features=256, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
        (act2): GELU(approximate='none')
      )
    )
    (1): Block(
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (attn): Attention(
        (qkv): L

In [7]:
#param_dict=[{'params':mae_encoder.parameters()},{'params':mae_decoder.parameters()}]
param_dict=[{'params':shift_block.parameters()}]
optimizer = optim.Adam(param_dict, lr=0.00005) #0.0001 for normal training, 0.00005 for finetuning

loss_fn=nn.CrossEntropyLoss()

dataloader, memo = get_miniImageNetDataLoader(batch_size=4, img_size=224, shuffle=True)


Data Preparation Done
Data Loaded.


In [8]:
mask_ratio = 0.0

log_iter_freq = 50
imagine_freq = 2
checkpoint = 50
num_epoch = 50

log = []

for epoch in range(num_epoch):
    
    #visual

    for idx, (img, target_id, fixations, correct, bbox) in enumerate(training_loader):
    for n_iter, (img, target) in enumerate(dataloader):
        img=img.to(device)
        encoded, batch_mask = mae_encoder.forward_encoder(img, mask_ratio)
        
        #randomize action
        xy = shift_block.generate_shift_cat(encoded)
        
        shifted = shift_block.forward_encoder(encoded, xy) #no cls_token
        reconstructed = mae_decoder.forward_decoder(shifted[:, :, 1:], batch_mask, vis=True) #assume shift complete, deprive assistant dim

        loss_rcs = mae_decoder.forward_shift_loss(imgs=img, pred=reconstructed, xy=xy)
        #loss_rcs = mae_decoder.forward_loss(imgs=img, pred=reconstructed, mask=batch_mask)

        #target=target.to(device)
        #target_pred = mae_encoder.forward(img)
        #loss_cls = loss_fn(target_pred, target)

        loss = loss_rcs #5*loss_cls + loss_rcs
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():   
            if n_iter % log_iter_freq == 0:
                print(f"Epoch:{epoch} {n_iter}/{len(dataloader)} Loss:{loss.detach().item():.3f}")
                log.append(loss.detach().item())
                #print("ACC:",torch.sum(torch.argmax(target, dim=1)==torch.argmax(target_pred, dim=1))/target.shape[0])
            if n_iter % checkpoint == 0  and n_iter != 0:
                '''torch.save(mae_encoder.state_dict(), "./mae_log/shift_test/encoder_param_2.pth")'''
                #torch.save(mae_decoder.state_dict(), "./mae_log/shift_test/decoder_param.pth")
                torch.save(shift_block.state_dict(), "./mae_log/shift_test/shift_param_r_1.pth")
                torch.save(log, "./mae_log/shift_test/loss_r_1.pt")
            


ValueError: 

In [None]:
import numpy as np
np.random.seed(114)
x=np.expand_dims(np.random.randint(0,16,5),(0,1)).repeat(3,1).transpose(2,0,1)
y=np.expand_dims(np.random.randint(0,16,5),(0,1)).repeat(3,1).transpose(2,0,1)
xy=np.concatenate([x,y], axis=1)
print(xy.shape)


In [None]:
np.random.seed(114)
#batch 20*20 512
shift_token = np.zeros((5,196,512))
shift_token[0,2,:]=1

x = np.random.randint(0,14,5)
y = np.random.randint(0,14,5)
xy=np.stack([x,y],axis=1)
print(xy.shape)
xy_dot=x*14 + y
xy_dot=np.tile(xy_dot,(1,196,1)).transpose()

print(xy_dot.shape)
print(xy)
print(shift_token[:,xy_dot].shape)