In [1]:
# Change the working directory to the parent folder (the recorded output shows
# /home/cml0/rl-aux2) before the project-local imports in the next cell.
# NOTE(review): the recorded run of the next cell still raised
# "ModuleNotFoundError: No module named 'datasets'" — verify that the new
# working directory is actually on sys.path.
%cd ..

/home/cml0/rl-aux2


In [2]:
import subprocess
from tabnanny import verbose

import torch
from stable_baselines3 import PPO
from torch import nn
import matplotlib.pyplot as plt
import torchvision.transforms as T
from torchvision.datasets import CIFAR100
from datasets.cifar100 import CIFAR100, CoarseLabelCIFAR100
from datasets.transforms import cifar_trans_train, cifar_trans_test
from environment.learn_weight_aux_task import AuxTaskEnv
from environment.weight_training.weight_training_environment import WeightTuningEnv
from networks.ppo.ppo import get_ppo_agent
from networks.primary.vgg import VGG16
from networks.weight_training.ppo import get_weight_training_ppo_agent
from train.train_auxilary_agent import train_auxilary_agent
from utils.analysis.network_details import print_aux_weights
from utils.log import log_print, change_log_location
from utils.path_name import create_path_name, save_all_parameters

# --- Experiment configuration ------------------------------------------------
# Checkpoint handed to `auxilary_task_agent.set_parameters(...)` at the end of this cell.
# NOTE(review): absolute path into a different checkout (`rl-aux`, while the
# notebook runs from `rl-aux2`); consider making it relative/configurable.
# NOTE(review): the checkpoint name encodes `learn_weights_False` and
# `aux_weight_1`, whereas LEARN_WEIGHTS / AUX_WEIGHT below are True / 0 —
# confirm the agent built in this cell is compatible with the loaded parameters.
LOAD_MODEL_PATH = "/home/cml0/rl-aux/trained_models/PPO_VGG_learn_weights_False_train_ratio_1_aux_weight_1_obs_dim_256_CIFAR100-20v2/best_model_auxiliary"
# Shared by both DataLoaders, AuxTaskEnv and the PPO agent.
BATCH_SIZE = 100
# Output sizes of the VGG16 auxiliary / primary heads; also passed to the env
# as aux_dim / pri_dim (and AUX_DIMENSION to the agent as auxiliary_dim).
AUX_DIMENSION = 100
PRIMARY_DIMENSION = 20
# Passed to get_ppo_agent as `feature_dim`.
OBSERVATION_FEATURE_DIMENSION = 256
# Not referenced anywhere in the visible cells.
TOTAL_EPOCH = 200
# Learning rate of the SGD optimiser created by `optimizer_callback`.
PRIMARY_LEARNING_RATE = 0.01
# Learning rate passed to get_ppo_agent.
PPO_LEARNING_RATE = 1e-4
# StepLR parameters used by `scheduler_callback`.
SCHEDULER_STEP_SIZE = 50
SCHEDULER_GAMMA = 0.5
# Passed to AuxTaskEnv as `aux_weight` / `learn_weights`.
AUX_WEIGHT = 0
LEARN_WEIGHTS = True
# Not referenced anywhere in the visible cells.
TRAIN_RATIO = 1

# Passed to AuxTaskEnv as `save_path` (current working directory).
SAVE_PATH="./"

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Data --------------------------------------------------------------------
# `CIFAR100` here is the project class from `datasets.cifar100`: that import
# comes after (and therefore shadows) the torchvision import of the same name.
# NOTE(review): the *train* split is built with `cifar_trans_test` although
# `cifar_trans_train` is imported — confirm this is deliberate for this notebook.
cifar100_train_set = CIFAR100(root='dataset', train=True, transform=cifar_trans_test, download=True)
cifar100_test_set = CIFAR100(root='dataset', train=False, transform=cifar_trans_test, download=True)

# Wrap both splits in the project's CoarseLabelCIFAR100 dataset wrapper.
course_cifar_train_set = CoarseLabelCIFAR100(cifar100_train_set)
course_cifar_test_set = CoarseLabelCIFAR100(cifar100_test_set)

# Shuffled mini-batch loaders over the wrapped datasets.
# Neither loader is used again in the visible cells.
cifar100_train_loader = torch.utils.data.DataLoader(
    dataset=course_cifar_train_set,
    batch_size=BATCH_SIZE,
    shuffle=True)

# NOTE(review): the test loader is shuffled as well; shuffling is normally not
# needed for evaluation — confirm.
cifar100_test_loader = torch.utils.data.DataLoader(
    dataset=course_cifar_test_set,
    batch_size=BATCH_SIZE,
    shuffle=True)

# --- Primary network ---------------------------------------------------------
# Project VGG16 with a primary head of PRIMARY_DIMENSION outputs and an
# auxiliary head of AUX_DIMENSION outputs, moved to the selected device.
primary_model = VGG16(
    primary_task_output=PRIMARY_DIMENSION,
    auxiliary_task_output=AUX_DIMENSION
).to(device)
criterion = nn.CrossEntropyLoss()

# Factories rather than instances: both are handed to AuxTaskEnv as
# `optimizer_func` / `scheduler_func`, which is expected to call them itself.
# `optimizer_callback` takes a module and returns an SGD optimiser over its parameters;
# `scheduler_callback` takes an optimiser and returns a StepLR scheduler for it.
optimizer_callback = lambda x: torch.optim.SGD(x.parameters(), lr=PRIMARY_LEARNING_RATE)
scheduler_callback = lambda x: torch.optim.lr_scheduler.StepLR(x, step_size=SCHEDULER_STEP_SIZE, gamma=SCHEDULER_GAMMA)
# ---------

# --- Environment and PPO agent -----------------------------------------------
# Build the auxiliary-task environment around the primary model and the
# coarse-label training set, using the configuration constants defined above.
task_env = AuxTaskEnv(
    train_dataset=course_cifar_train_set,
    device=device,
    model=primary_model,
    criterion=criterion,
    optimizer_func=optimizer_callback,
    scheduler_func=scheduler_callback,
    batch_size=BATCH_SIZE,
    pri_dim=PRIMARY_DIMENSION,
    aux_dim=AUX_DIMENSION,
    aux_weight=AUX_WEIGHT,
    save_path=SAVE_PATH,
    learn_weights=LEARN_WEIGHTS,
    verbose=True,
)

# Create the PPO agent for `task_env` via the project helper.
# NOTE(review): ent_coef=0.01, n_steps=79, n_epochs=10 and weight_bins=21 are
# magic numbers whose rationale is not visible here; they presumably have to
# match the settings the checkpoint was trained with — confirm.
auxilary_task_agent = get_ppo_agent(env=task_env,
                                    feature_dim=OBSERVATION_FEATURE_DIMENSION,
                                    auxiliary_dim=AUX_DIMENSION,
                                    learning_rate=PPO_LEARNING_RATE,
                                    device=device,
                                    ent_coef=0.01,
                                    n_steps=79,
                                    n_epochs=10,
                                    batch_size=BATCH_SIZE,
                                    weight_bins=21,
                                    )

# Overwrite the freshly initialised agent parameters with the saved checkpoint.
auxilary_task_agent.set_parameters(LOAD_MODEL_PATH, device=device)



ModuleNotFoundError: No module named 'datasets'

In [None]:
# Sanity check: print the first item of the training dataset as returned by
# its __getitem__ (the get_i_image cell unpacks items as an (image, label) pair).
print(cifar100_train_set[0])

In [None]:
def get_i_image(i):
    """Display the ``i``-th image of ``cifar100_test_set`` with normalisation undone.

    Args:
        i: Index of the sample in the module-level ``cifar100_test_set``.

    Returns:
        None. The image is rendered with matplotlib as a side effect.

    The sample is treated as a channel-first image tensor (it is broadcast
    against ``(3, 1, 1)`` statistics and permuted to channel-last for plotting).
    """
    # Use the requested index (previously a hard-coded sample was always shown).
    img_norm, _ = cifar100_test_set[i]

    # NOTE(review): these statistics must mirror the Normalize step used by
    # `cifar_trans_test`; that transform is not visible here — confirm.
    mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
    std = torch.tensor([0.2, 0.2, 0.2]).view(3, 1, 1)

    # Invert the normalisation and clip to the valid display range [0, 1].
    img = (img_norm * std + mean).clamp(0, 1)

    # imshow expects channel-last data, hence the permute.
    plt.imshow(img.permute(1, 2, 0).cpu().numpy())
    plt.axis("off")
    plt.show()

In [None]:
# Print the agent's deterministic prediction for each of the first 1000
# samples of the test set, feeding the image under the "image" observation key.
for i in range(1000):
    img_norm, label_idx = cifar100_test_set[i]
    obs = {"image": img_norm}
    print(auxilary_task_agent.predict(obs, deterministic=True))