In [1]:
# Load IPython's autoreload extension and switch it to mode 2, so edits to
# imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [10]:
import time
import torch
import numpy as np
import json
import os
from torch.utils.data import DataLoader,Dataset
from torchvision.datasets import ImageFolder
from environment import StaticImgEnv
from PIL import Image
from matplotlib import pyplot as plot
import argparse
from ppo import PPO
from agent_rssm import Agent
import utils
import yaml
from torchvision.transforms import functional

# Raise the element-count threshold at which printed tensors are summarized
# to 10000 elements.
torch.set_printoptions(threshold=10000)

In [3]:
def read_yaml_config(file_path):
    """Parse a YAML file and return its contents.

    Args:
        file_path: path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's contents.
    """
    with open(file_path, 'r') as stream:
        return yaml.safe_load(stream)

# Experiment configuration; later cells read its nested sections
# (e.g. config['Environment']).
config = read_yaml_config('./config_shift.yaml')

# Limit the CUDA devices visible to this process to physical GPUs 2 and 3.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"


def _build_parser(description, option_specs):
    """Return an ArgumentParser with every (flag, kwargs) pair registered in order."""
    built = argparse.ArgumentParser(description=description)
    for flag, kwargs in option_specs:
        built.add_argument(flag, **kwargs)
    return built


# Trainer hyper-parameters. They are parsed from an empty argument list below,
# so only the defaults listed here are ever used inside the notebook.
_TRAINER_OPTIONS = [
    ('--lr', dict(type=float, default=0.05, help='learning rate')),
    ('--data_dir', dict(default='archive', help='data directory')),
    ('--batch_size', dict(type=int, default=16, help='batch size')),
    ('--ppo_rollout_batch_size', dict(type=int, default=16, help='batch size')),
    ('--epochs', dict(type=int, default=10, help='total epochs to run')),
    ('--verbose', dict(type=int, default=1, help='verbose')),
    ('--loss_freq', dict(type=int, default=50, help='loss print freq')),
    ('--eval_freq', dict(type=int, default=1, help='eval freq')),
    ('--device', dict(default='cuda', help='cuda')),
    ('--critic_param_path', dict(default='./models/acmerge_resnet_value.pth', help='pretrained critic')),
    ('--actor_param_path', dict(default='./models/acmerge_resnet_actor.pth', help='pretrained actor')),
    ('--disable_critic', dict(default=False, help='no Critic')),
]

# Environment settings, likewise reduced to their defaults.
_ENV_OPTIONS = [
    ('--radius', dict(type=int, default=112, help='fovea radius')),
    ('--action_range', dict(type=int, default=224, help='action range')),
    ('--max_steps', dict(type=int, default=15, help='max steps: -1 for unlimited')),
    ('--grid_size', dict(type=int, default=7, help='action space grid size')),
    ('--plot_freq', dict(type=int, default=10, help='plot trajectory freq')),
    ('--target_num', dict(type=int, default=18, help='how many targets to find')),
]

parser = _build_parser('trainer', _TRAINER_OPTIONS)
trainer_args = parser.parse_args("")

parser = _build_parser('env', _ENV_OPTIONS)
env_args = parser.parse_args("")

print(config)


{'Environment': {'radius': 112, 'vision_radius': 500, 'action_range': 224, 'max_steps': 15, 'grid_size': 7, 'plot_freq': 10, 'target_num': 18}, 'PrefixResnet': {'depth': 6, 'kernel_size': '([3,3],[3,3],[3,3])', 'in_channel': 3, 'requires_grad': False}, 'PrefixCNN': {'img_size': 224, 'patch_size': 16, 'in_channel': 3, 'out_channel': 512, 'requires_grad': True, 'pretrained_model_path': ''}, 'ViTEncoder': {'embed_dim': 512, 'depth': 24, 'num_heads': 16, 'mlp_ratio': 4.0, 'qkv_bias': False, 'qk_scale': False, 'drop_rate': 0.2, 'attn_drop_rate': 0.1, 'drop_path_rate': 0.1, 'num_classes': 100, 'weight_path': './mae_log/no_cls/encoder_param.pth'}, 'ViTDecoder': {'embed_dim': 256, 'depth': 8, 'num_heads': 16, 'mlp_ratio': 4.0, 'qkv_bias': False, 'qk_scale': False, 'drop_rate': 0.0, 'attn_drop_rate': 0.1, 'drop_path_rate': 0.1}, 'ShiftTransformer': {'embed_dim': 513, 'depth': 2, 'num_heads': 16, 'mlp_ratio': 4.0, 'qkv_bias': False, 'qk_scale': False, 'drop_rate': 0.0, 'attn_drop_rate': 0.0, 'dr

In [4]:
# Folder that holds the fixation JSON files read below.
dataset_root="COCOSearch18"

# Split-1 fixation files. The non-split alternatives noted in the original were
# 'coco_search18_fixations_TP_train.json' and
# 'coco_search18_fixations_TP_validation.json'.
train_fixation_path = os.path.join(dataset_root, 'coco_search18_fixations_TP_train_split1.json')
valid_fixation_path = os.path.join(dataset_root, 'coco_search18_fixations_TP_validation_split1.json')

with open(train_fixation_path) as json_file:
    human_scanpaths_train = json.load(json_file)

with open(valid_fixation_path) as json_file:
    human_scanpaths_valid = json.load(json_file)

# Disabled scratch snippet (prints the largest 'length' value among the
# training scanpaths):
#max=0
#for dict in human_scanpaths_train:
#    if dict['length']>=max:
#        max=dict['length']
#print(max)

# Training data pipeline built from the training-split scanpaths.
dataset = utils.COCOSearch18(
    json=human_scanpaths_train,
    root='COCOSearch18/images',
    test=True,
)
training_loader = DataLoader(
    dataset,
    batch_size=trainer_args.batch_size,
    shuffle=True,
    num_workers=8,
)

# Environment and agent are both configured from the YAML config; only the
# device and the critic switch come from the argparse defaults.
env = StaticImgEnv(env_args=config['Environment'], show_plot=True)
agent = Agent(
    device=trainer_args.device,
    config=config,
    grid_size=config['Environment']['grid_size'],
    disable_critic=trainer_args.disable_critic,
    recog_threshold=0.5,
)

#ppo=PPO(agent,lr=0.01,betas=[0.9,0.999],clip_param=0.2,num_epoch=1,batch_size=trainer_args.ppo_rollout_batch_size,\
#        value_coef=0.5,entropy_coef=0.8,drop_failed=False,vgg_backbone_fixed=True) #lr=0.01 clip=0.2 value_coef=1

#for i in agent.named_children():
#     print(i)

#print(agent.acmerge.value_read_out.weight[0])

<All keys matched successfully>
<All keys matched successfully>


In [5]:
# Sanity check: print the `requires_grad` attribute of the agent's MAE encoder.
print(agent.mae_encoder.requires_grad)

False


In [9]:
def tensor_to_PIL(tensor):
    """Clamp an image tensor to [0, 1] and convert it to a PIL image.

    Args:
        tensor: image tensor accepted by torchvision's
            ``functional.to_pil_image`` (presumably a (C, H, W) float
            tensor -- TODO confirm against callers).

    Returns:
        The PIL image produced by ``functional.to_pil_image``.
    """
    # Clamp the argument itself. The previous body read a local `img` before
    # assigning it, so every call raised UnboundLocalError and the `tensor`
    # parameter was never used.
    clipped = torch.clip(tensor, 0.0, 1.0)
    return functional.to_pil_image(clipped)
    
# Training loop: for every batch, collect rollouts with the current agent,
# post-process them, and hand them to `ppo.update`. The total wall-clock time
# is printed at the end.
# NOTE(review): `ppo` is only constructed in commented-out code in the setup
# cell, so on a fresh kernel this loop would fail with NameError -- confirm
# where `ppo` is meant to be created.
# NOTE(review): the saved output of this cell ends in a ValueError, so the loop
# did not complete in the recorded run.
start=time.time()
for epoch in range(trainer_args.epochs):
    for idx, (img, target_id, fixations, correct, bbox) in enumerate(training_loader):
        # Load the current batch into the environment before rolling out.
        env.set_data(img,target_id, bbox) #bbox->target bbox [x,y,w,h]
        # Rollout collection runs without gradient tracking.
        # NOTE(review): the step limit comes from the argparse defaults
        # (`env_args`), while the environment itself was built from
        # config['Environment'] -- confirm the two stay in sync.
        with torch.no_grad():
            trajs_all=utils.collect_trajs(env,agent,max_traj_length=env_args.max_steps)
        # Return value is not used, so this presumably updates `trajs_all`
        # in place -- TODO confirm in utils.process_trajs.
        utils.process_trajs(trajs_all,gamma=0.9,epistemic_coef=2)
        '''Be cautious the trajectories collected->[traj_length,batch_size,...]'''
        #for item in trajs_all:
        #    print(item,trajs_all[item].shape)
        rollouts=utils.RolloutStorage(trajs_all)
    
        # `loss` is assigned but never read or logged in this cell.
        loss=ppo.update(rollouts)
    
        #for i, sample in enumerate(data_generator):

# Release cached GPU memory held by PyTorch's allocator once training ends.
torch.cuda.empty_cache()
end=time.time()
print(end-start)

torch.Size([16, 196, 512]) torch.Size([16, 196, 512])


ValueError: 

In [None]:
from PIL import Image, ImageDraw
def tensor_to_PIL(tensor):
    """Convert a tensor to a PIL image via ``v2.ToPILImage``.

    The input is copied to the CPU and dimension 0 is squeezed (a no-op unless
    that dimension has size 1) before conversion, so the caller's tensor is
    left untouched.

    Note: this re-defines the ``tensor_to_PIL`` from an earlier cell, and relies
    on ``v2`` being imported before the function is called.
    """
    to_pil = v2.ToPILImage()
    cpu_copy = tensor.cpu().clone().squeeze(0)
    return to_pil(cpu_copy)
from torchvision.transforms import v2

# All-zero 1x3x224x224 tensor used as the background.
bg = torch.zeros((1, 3, 224, 224))

# crop arguments are (top, left, height, width): a 100x100 window at top=0, left=2.
img = tensor_to_PIL(v2.functional.crop(bg, 0, 2, 100, 100))
#img.show()

# Mark two points on the cropped image in red.
draw = ImageDraw.Draw(img)
fixations = list(map(tuple, [[1, 2], [20, 20]]))
print(fixations)
draw.point(fixations, fill='red')
img.show()

In [None]:
from torch.distributions import MultivariateNormal,Categorical
import torch.nn.functional as F
# Placeholder integer actions; overwritten by the sampled actions below.
act_batch = torch.randint(-100, 100, (64, 2))
#m = MultivariateNormal(act_batch, torch.eye(2))
#m.sample()

# Despite the name, `logits` holds softmax-normalised probabilities (each row
# sums to 1). Categorical's first positional parameter is `probs`, so it is
# passed by keyword here to make that explicit.
logits = (torch.randn(5, 225) + 1).abs().softmax(dim=-1)
print(logits)
mvn = Categorical(probs=logits)

# One sampled category index per row -> shape (5,).
act_batch = mvn.sample()
print(act_batch.shape)

In [None]:
import torch.nn as nn
# Two random tensors with values in [0, 1), usable as BCE inputs/targets.
a = torch.rand(14, 64, 18)
b = torch.rand(14, 64, 18)

# `entropy` is an element-wise binary cross-entropy (no reduction applied).
entropy = nn.BCELoss(reduction='none')
#print(a,b)

# NOTE(review): dim 0 has only 14 entries, so a[:32] is all of `a` and b[32:]
# is empty, which makes `c` equal to `a`. If the intent was to mix the two
# halves of the size-64 dimension, the slices and `dim` need to target dim 1
# -- confirm.
c = torch.concatenate([a[:32], b[32:]], dim=0)
per_item_loss = entropy(b, c).mean(-1)
print(per_item_loss.shape)

d = torch.tensor([0, 1.0, 0])
e = torch.tensor([1.0, 0, 1.0])
#print(entropy(d,e).mean(-1))
# tensor(0.9964)

In [None]:
a = torch.zeros(16, 18, 100)
# One index per row of `a`; `high` is exclusive, so values lie in [0, 16].
b = torch.randint(0, 17, (16,))
print(b)
# Paired advanced indexing selects a[i, b[i]] for each i -> shape (16, 100).
row_ids = torch.arange(a.shape[0])
a[row_ids, b].shape