In [3]:
import torch
from torch import nn
import numpy as np
from core.models.nts_net import NTSModel
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb
import settings

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Load dataset

In [6]:
from torchvision.datasets import FGVCAircraft
from torch.utils.data import Subset, random_split
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from PIL import Image

"""
Transforms
"""
img_mean, img_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

transform = Compose([
  Resize((settings.IMAGE_HEIGHT, settings.IMAGE_WIDTH), Image.BILINEAR),
  ToTensor(),
  Normalize(mean=img_mean, std=img_std),
])

"""
Load dataset
"""
dataset = FGVCAircraft(root="data", split="train", download=True, transform=transform)

"""
Split data
"""

# Create a smaller subset
num_samples = len(dataset)
subset_size = settings.N_SAMPLES
rand_idxs = np.random.choice(range(num_samples), subset_size)
subset = Subset(dataset, rand_idxs)


# Create train-val split
val_split = int(subset_size*settings.TEST_SIZE)
with torch.random.fork_rng(devices=[device]):
  torch.manual_seed(settings.SEED)
  train_data, val_data = random_split(subset, [subset_size - val_split, val_split])

print("Train data size:", len(train_data))
print("Val data size:", len(val_data))

Train data size: 670
Val data size: 330


### Hyperparameters

In [3]:
CAT_NUM = 4 # NOTE: should be 4

BATCH_SIZE = 8
NUM_EPOCHS = 10

# Random grid search
NUM_TRIALS = 10 # number of random searches to run

In [4]:
from pprint import pprint
from torch.optim import SGD, RMSprop, Adagrad, Adadelta, Adam, Adamax, NAdam
from core.loss import list_loss, ranking_loss


optimizers = {
    "sgd": SGD,
    "rmsprop": RMSprop,
    "adagrad": Adagrad,
    "adadelta": Adadelta,
    "adam": Adam,
    "adamax": Adamax,
    "nadam": NAdam,
}

search_grid = {
    "lr": [0.001, 0.01, 0.1],
    "momentum": [0.9, 0.95],
    "proposal_num": [4, 6, 8],
    "weight_decay": [1e-2, 1e-4, 1e-6],
    "optimizer": list(optimizers.keys()),
}

results = []

for t in range(NUM_TRIALS):
    print("Running trial {} of {} trials".format(t+1, NUM_TRIALS))
    hr = {
        "lr": np.random.choice(search_grid["lr"]),
        "momentum": np.random.choice(search_grid["momentum"]),
        "proposal_num": np.random.choice(search_grid["proposal_num"]),
        "weight_decay": np.random.choice(search_grid["weight_decay"]),
        "optimizer": np.random.choice(search_grid["optimizer"]),
    }

    ################################# SETUP TRIAL #################################

    optim_params = {"lr": hr["lr"], "weight_decay": hr["weight_decay"]}

    # Add momentum
    if hr["optimizer"] in ['rmsprop', 'sgd']:
        optim_params["momentum"] = hr["momentum"]
    else:
        del hr["momentum"]

    """
    Configure wandb
    """
    wandb.init(
        project=settings.WANDB_PROJECT_NAME,
        config={
            "architecture": "NTS-net",
            "dataset": "FGVCAircraft",
            "epochs": NUM_EPOCHS,
            "batch_size": BATCH_SIZE,
            **hr,
        }
    )

    """
    Initialize dataloaders
    """
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

    """
    Initialize model
    """
    model = NTSModel(top_n=hr["proposal_num"], cat_num=CAT_NUM, n_classes=len(dataset.classes), image_height=settings.IMAGE_HEIGHT, image_width=settings.IMAGE_WIDTH).to(device)
    criterion = torch.nn.CrossEntropyLoss()

    """
    Setup optimizers
    """
    resnet_parameters = list(model.resnet.parameters())
    navigator_parameters = list(model.navigator.parameters())
    concat_parameters = list(model.concat_net.parameters())
    partcls_parameters = list(model.partcls_net.parameters())

    resnet_optimizer = torch.optim.SGD(resnet_parameters, **optim_params)
    navigator_optimizer = torch.optim.SGD(navigator_parameters, **optim_params)
    concat_optimizer = torch.optim.SGD(concat_parameters, **optim_params)
    partcls_optimizer = torch.optim.SGD(partcls_parameters, **optim_params)


    model = nn.DataParallel(model)

    history = {
        "train_loss": [],
        "val_loss": [],
        "train_accuracy": [],
        "val_accuracy": [],
    }

    for epoch in range(NUM_EPOCHS):

        epoch_loss = 0
        epoch_accuracy = 0
        epoch_val_loss = 0
        epoch_val_accuracy = 0
        with tqdm(total=len(train_loader)) as pbar:
            for i, (inputs, labels) in enumerate(train_loader):
                inputs, labels = inputs.to(device), labels.to(device)
                batch_size = inputs.size(0)

                resnet_optimizer.zero_grad()
                navigator_optimizer.zero_grad()
                concat_optimizer.zero_grad()
                partcls_optimizer.zero_grad()

                resnet_logits, concat_logits, part_logits, top_n_idxs, top_n_proba = model(inputs)
                
                # Losses
                resnet_loss = criterion(resnet_logits, labels)
                navigator_loss = list_loss(part_logits.view(batch_size * hr["proposal_num"], -1),
                                        labels.unsqueeze(1).repeat(1, hr["proposal_num"]).view(-1)).view(batch_size, hr["proposal_num"])
                concat_loss = criterion(concat_logits, labels)
                rank_loss = ranking_loss(top_n_proba, navigator_loss, proposal_num=hr["proposal_num"])
                partcls_loss = criterion(part_logits.view(batch_size * hr["proposal_num"], -1),
                                    labels.unsqueeze(1).repeat(1, hr["proposal_num"]).view(-1))
                
                loss = resnet_loss + concat_loss + rank_loss + partcls_loss
                loss.backward()

                resnet_optimizer.step()
                navigator_optimizer.step()
                concat_optimizer.step()
                partcls_optimizer.step()

                accuracy = (concat_logits.argmax(dim=1) == labels).float().mean()
                
                epoch_loss += concat_loss.item()
                epoch_accuracy += accuracy.item()

                pbar.set_postfix_str("Train loss: {:.4f}, Train accuracy: {:.4f}".format(epoch_loss / (i+1), epoch_accuracy / (i+1)))
                pbar.update(1)


        with tqdm(total=(len(val_loader))) as pbar:
            with torch.no_grad():
                for i, (inputs, labels) in enumerate(val_loader):
                    inputs, labels = inputs.to(device), labels.to(device)
                    batch_size = inputs.size(0)

                    _, concat_logits, _, _, _ = model(inputs)

                    concat_loss = criterion(concat_logits, labels)
                    

                    accuracy = (concat_logits.argmax(dim=1) == labels).float().mean()

                    epoch_val_loss += concat_loss.item()
                    epoch_val_accuracy += accuracy.item()


                    pbar.set_postfix_str("Val loss: {:.4f}, Val accuracy: {:.4f}".format(epoch_val_loss / (i+1), epoch_val_accuracy / (i+1)))
                    pbar.update(1)

        epoch_loss = epoch_loss/len(train_loader)
        epoch_val_loss = epoch_val_loss/len(val_loader)

        epoch_accuracy = epoch_accuracy/len(train_loader)
        epoch_val_accuracy = epoch_val_accuracy/len(val_loader)

        history["train_loss"].append(epoch_loss)
        history["val_loss"].append(epoch_val_loss)    

        history["train_accuracy"].append(epoch_accuracy)
        history["val_accuracy"].append(epoch_val_accuracy) 

        print(f"Epoch {epoch+1} - Loss: {epoch_loss:.4f} - Accuracy: {epoch_accuracy:.4f} - Val Loss: {epoch_val_loss:.4f} - Val Accuracy: {epoch_val_accuracy:.4f}")
    
        # Log to wandb
        wandb.log({"train accuracy": epoch_accuracy, "train loss": epoch_loss, "val accuracy": epoch_val_accuracy, "val loss": epoch_val_loss})



    # Append trial result
    results.append({"trial": t+1, "accuracy": history["val_accuracy"][-1], "params": hr})
    wandb.finish()
    print("Finished trial {}".format(t+1))
    

results = sorted(results, key=lambda x: x["accuracy"], reverse=True)
best_result = results[0]


print("\n", "="*25, "BEST PARAMETERS", "="*25)
pprint(best_result, sort_dicts=False, indent=2)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Running trial 1 of 2 trials


[34m[1mwandb[0m: Currently logged in as: [33msimekri[0m ([33minfo251-project[0m). Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 63/63 [00:30<00:00,  2.09it/s, Train loss: 4.6092, Train accuracy: 0.0060]
100%|██████████| 63/63 [00:13<00:00,  4.63it/s, Val loss: 4.6184, Val accuracy: 0.0079]

Epoch 1 - Loss: 4.6092 - Accuracy: 0.0060 - Val Loss: 4.6184 - Val Accuracy: 0.0079





0,1
train accuracy,▁
train loss,▁
val accuracy,▁
val loss,▁

0,1
train accuracy,0.00595
train loss,4.60925
val accuracy,0.00794
val loss,4.61838


Finished trial 1
Running trial 2 of 2 trials


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

100%|██████████| 63/63 [00:22<00:00,  2.80it/s, Train loss: 4.6391, Train accuracy: 0.0060]
100%|██████████| 63/63 [00:12<00:00,  5.16it/s, Val loss: 4.6041, Val accuracy: 0.0218]

Epoch 1 - Loss: 4.6391 - Accuracy: 0.0060 - Val Loss: 4.6041 - Val Accuracy: 0.0218





VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train accuracy,▁
train loss,▁
val accuracy,▁
val loss,▁

0,1
train accuracy,0.00595
train loss,4.63912
val accuracy,0.02183
val loss,4.60407


Finished trial 2

{ 'trial': 2,
  'accuracy': 0.021825396825396824,
  'params': { 'lr': 0.1,
              'proposal_num': 4,
              'weight_decay': 0.0001,
              'optimizer': 'nadam'}}
