In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from models import ACMerge_vgg_trainer, ACMerge_resnet_trainer
import numpy as np
import json
import os
from torch.utils.data import DataLoader,Dataset
from PIL import Image
from matplotlib import pyplot as plot
import argparse
import utils
from torchvision.transforms import v2
import torch.optim as optim
import torch.nn as nn
import torchvision.models as models

In [3]:
#os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
parser = argparse.ArgumentParser(description='trainer')
parser.add_argument('--lr', type=float, default=0.008, help='learning rate') 
parser.add_argument('--data_dir', default='archive', help='data directory')
parser.add_argument('--batch_size', type=int, default=128,help='batch size')
parser.add_argument('--epochs', type=int, default=50, help='total epochs to run')
parser.add_argument('--verbose', type=int, default=1, help='verbose')
parser.add_argument('--log_freq', type=int, default=50, help='loss print freq')
parser.add_argument('--eval_freq', type=int, default=1, help='eval freq')
parser.add_argument('--device', default='cuda', help='cuda')
trainer_args = parser.parse_args("")

In [4]:
dataset_root="COCOSearch18"
with open(os.path.join(dataset_root,
               'coco_search18_fixations_TP_train_split1.json'#'coco_search18_fixations_TP_train.json'
               )) as json_file:
    human_scanpaths_train = json.load(json_file)
    
with open(os.path.join(dataset_root,
               'coco_search18_fixations_TP_validation_split1.json'#'coco_search18_fixations_TP_validation.json'
               )) as json_file:
    human_scanpaths_valid = json.load(json_file)

### Given a concatenated neuron state maps, to figure out the correlation between these maps and the target

In [5]:
train_dataset=utils.COCOSearch18(json=human_scanpaths_train,root='COCOSearch18/images')
validation_dataset=utils.COCOSearch18(json=human_scanpaths_valid,root='COCOSearch18/images')
training_loader=DataLoader(train_dataset,batch_size=trainer_args.batch_size,shuffle=True,num_workers=8,drop_last=True)
validation_loader=DataLoader(validation_dataset,batch_size=200,shuffle=True,num_workers=8,drop_last=True)

In [7]:
fovea_size=224
save_path='./models/acmerge_resnet_actor.pth'
def tensor_to_PIL(tensor):
    unloader = v2.ToPILImage()
    image = tensor.cpu().clone()
    image = image.squeeze(0)
    image = unloader(image)
    return image
    
acmerge=ACMerge_resnet_trainer(weights=models.ResNet18_Weights.IMAGENET1K_V1).to(trainer_args.device)#models.VGG16_Weights.IMAGENET1K_V1
#critic=Critic().to(trainer_args.device)
optimizer=optim.Adam([acmerge.actor_read_out.weight],lr=trainer_args.lr)
loss_fn=nn.CrossEntropyLoss()
acc_temp=0

for epoch in range(trainer_args.epochs):
    for idx, (img, target_id, fixations, correct, bbox) in enumerate(training_loader):
        
        center_x=torch.round(bbox[:,0]+bbox[:,2]/2).to(torch.int)
        center_y=torch.round(bbox[:,1]+bbox[:,3]/2).to(torch.int)
        fix=torch.stack((center_x,center_y),dim=0).T
        observation=[]
        for i in range(trainer_args.batch_size):
            height=fovea_size #fovea size
            width=fovea_size
            top=int(fix[i,1]-height/2)
            left=int(fix[i,0]-width/2)
            #print(top,left)
            observation.append(v2.functional.crop(img[i],top,left,height,width)) #self-padding
        sample=tensor_to_PIL(observation[0])
        #plot.imshow(sample)
        #plot.show()
        observations=torch.stack(observation).to(trainer_args.device)
        target_id=target_id.to(trainer_args.device)
        optimizer.zero_grad()
        output,_=acmerge(observations)
        #print(output.shape,target_id.shape)
        loss=loss_fn(output,target_id)+0.1*torch.norm(acmerge.actor_read_out.weight)
        #print(loss.item())
        loss.backward()
        optimizer.step()
    
        with torch.no_grad():
            if idx%trainer_args.log_freq==0:

                for valid_img, valid_target_id, _,_,valid_bbox in validation_loader:
                    center_x=torch.round(valid_bbox[:,0]+valid_bbox[:,2]/2).to(torch.int)
                    center_y=torch.round(valid_bbox[:,1]+valid_bbox[:,3]/2).to(torch.int)
                    fix=torch.stack((center_x,center_y),dim=0).T
                    valid_observation=[]
                    for i in range(200):
                        height=fovea_size #fovea size
                        width=fovea_size
                        top=int(fix[i,1]-height/2)
                        left=int(fix[i,0]-width/2)
                        #print(top,left)
                        valid_observation.append(v2.functional.crop(valid_img[i],top,left,height,width)) #self-padding
                    valid_observations=torch.stack(valid_observation).to(trainer_args.device)
                    valid_target_id=valid_target_id.to(trainer_args.device)
                    valid_output,_=acmerge(valid_observations)
                    
                    acc=torch.sum(torch.argmax(valid_output,dim=1)==valid_target_id)/200
                    if acc>=acc_temp and epoch>=1:
                        torch.save(acmerge.actor_read_out.weight,save_path)
                        acc_temp=acc
                    
                    
                    print(f'Epoch:{epoch},Step:{idx},acc:{acc:.3f}')
                    break




Epoch:0,Step:0,acc:0.080
Epoch:0,Step:50,acc:0.825
Epoch:0,Step:100,acc:0.820
Epoch:0,Step:150,acc:0.785
Epoch:1,Step:0,acc:0.775
Epoch:1,Step:50,acc:0.755
Epoch:1,Step:100,acc:0.785
Epoch:1,Step:150,acc:0.830
Epoch:2,Step:0,acc:0.760
Epoch:2,Step:50,acc:0.780
Epoch:2,Step:100,acc:0.860
Epoch:2,Step:150,acc:0.840
Epoch:3,Step:0,acc:0.815
Epoch:3,Step:50,acc:0.780
Epoch:3,Step:100,acc:0.795
Epoch:3,Step:150,acc:0.810
Epoch:4,Step:0,acc:0.810
Epoch:4,Step:50,acc:0.755
Epoch:4,Step:100,acc:0.875
Epoch:4,Step:150,acc:0.790
Epoch:5,Step:0,acc:0.795


KeyboardInterrupt: 

In [None]:
print(torch.max(acmerge.actor_read_out_fake.weight.squeeze()[8]))