In [1]:
import torch
from torch import nn
import numpy as np
from core.models.nts_net import NTSModel
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb
import settings

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### Load dataset

In [2]:
from torchvision.datasets import FGVCAircraft
from torch.utils.data import Subset, random_split
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from PIL import Image

# Transforms: resize every image to the configured input size, convert it to a
# tensor and normalise each channel with the commonly used ImageNet mean/std.
img_mean, img_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

transform = Compose([
    Resize((settings.IMAGE_HEIGHT, settings.IMAGE_WIDTH), Image.BILINEAR),
    ToTensor(),
    Normalize(mean=img_mean, std=img_std),
])

# Load the FGVC-Aircraft "train" split (downloaded into ./data if missing).
dataset = FGVCAircraft(root="data", split="train", download=True, transform=transform)

# Split data.
# One seeded generator drives both the subset draw and the train/val split, so
# the split is reproducible without touching the global (or CUDA) RNG state.
# This also works on CPU-only machines, unlike fork_rng(devices=[cpu_device]).
split_generator = torch.Generator().manual_seed(settings.SEED)

# Create a smaller subset. Draw indices WITHOUT replacement: sampling with
# replacement duplicates images, and a duplicated image can end up in both the
# train and the val split (leakage).
num_samples = len(dataset)
subset_size = settings.N_SAMPLES
if subset_size > num_samples:
    raise ValueError(f"N_SAMPLES={subset_size} exceeds dataset size {num_samples}")
rand_idxs = torch.randperm(num_samples, generator=split_generator)[:subset_size].tolist()
subset = Subset(dataset, rand_idxs)

# Create train-val split: TEST_SIZE is the fraction of the subset held out for validation.
val_split = int(subset_size * settings.TEST_SIZE)
train_data, val_data = random_split(
    subset, [subset_size - val_split, val_split], generator=split_generator
)

print("Train data size:", len(train_data))
print("Val data size:", len(val_data))

Train data size: 670
Val data size: 330


## Configure sweep

In [3]:
# Passed to NTSModel as `cat_num`. NOTE(review): presumably the number of top
# part proposals whose features concat_net combines, and so it should not
# exceed the swept proposal_num (whose minimum is 4) — TODO confirm in NTSModel.
CAT_NUM = 4

BATCH_SIZE = 8   # mini-batch size for both the train and the val DataLoader
NUM_EPOCHS = 8   # training epochs per sweep run

# Bayesian hyper-parameter search (sweep "method": "bayes"), not random/grid search.
NUM_TRIALS = 30  # number of runs the wandb agent executes for this sweep

In [8]:
from torch.optim import SGD, Adagrad, Adam, RMSprop

# Registry mapping the sweep's "optimizer" choice to a torch.optim class.
# Insertion order matters: it becomes the order of the sweep's "values" list.
optimizers = dict(
    sgd=SGD,
    adagrad=Adagrad,
    rmsprop=RMSprop,
    adam=Adam,
)

# Search space for the wandb sweep: Bayesian optimisation that maximises the
# `val_acc` metric logged by each run.
sweep_configuration = {
    "method": "bayes",
    "name": "sweep",
    "metric": {"goal": "maximize", "name": "val_acc"},
    "parameters": {
        # lr and weight_decay span several orders of magnitude, so sample them
        # log-uniformly; a plain uniform range over [1e-4, 0.1] would put ~90%
        # of lr draws above 1e-2 and almost never try small learning rates.
        "lr": {"distribution": "log_uniform_values", "max": 0.1, "min": 0.0001},
        # Momentum-like value; how it is mapped onto each optimizer is decided
        # in run_sweep.
        "momentum": {"max": 0.95, "min": 0.1},
        # Integer bounds, so wandb draws an integer number of part proposals.
        "proposal_num": {"max": 8, "min": 4},
        "weight_decay": {"distribution": "log_uniform_values", "max": 1e-2, "min": 1e-6},
        "optimizer": {"values": list(optimizers.keys())},
    },
}

# Register the sweep with wandb; the returned id is what the agent runs against.
sweep_id = wandb.sweep(sweep=sweep_configuration, project=settings.WANDB_PROJECT_NAME)

Create sweep with ID: 0riwkaen
Sweep URL: https://wandb.ai/info251-project/fgvca_aircraft/sweeps/0riwkaen


In [9]:
from core.loss import list_loss, ranking_loss

def _optimizer_kwargs(optimizer_name, config):
    """Build constructor kwargs for ``optimizers[optimizer_name]`` from a wandb config.

    ``lr`` and ``weight_decay`` are passed to every optimizer. The swept
    ``momentum`` is passed as ``momentum`` to SGD and RMSprop and as the first
    beta of Adam; it is not passed to Adagrad.
    """
    kwargs = {"lr": config.lr, "weight_decay": config.weight_decay}
    # Membership test: `name == ["sgd", "rmsprop"]` would compare a str to a
    # list and always be False, silently dropping momentum.
    if optimizer_name in ("sgd", "rmsprop"):
        kwargs["momentum"] = config.momentum
    elif optimizer_name == "adam":
        kwargs["betas"] = (config.momentum, 0.999)
    return kwargs


def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when ``count`` is 0 (e.g. an empty split)."""
    return total / count if count else float("nan")


def _train_one_epoch(model, loader, criterion, optimizer_list, proposal_num):
    """Train ``model`` for one pass over ``loader``.

    The optimized loss is resnet + concat + ranking + part-classification loss.
    Returns ``(mean concat loss, accuracy)`` averaged over samples (not over
    batches), so a short final batch is not over-weighted.
    """
    model.train()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with tqdm(total=len(loader)) as pbar:
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            batch_size = inputs.size(0)

            for opt in optimizer_list:
                opt.zero_grad()

            resnet_logits, concat_logits, part_logits, _, top_n_proba = model(inputs)

            # Each image contributes `proposal_num` part predictions, so its label
            # is repeated to line up with the flattened part logits.
            flat_part_logits = part_logits.view(batch_size * proposal_num, -1)
            part_labels = labels.unsqueeze(1).repeat(1, proposal_num).view(-1)

            resnet_loss = criterion(resnet_logits, labels)
            navigator_loss = list_loss(flat_part_logits, part_labels).view(batch_size, proposal_num)
            concat_loss = criterion(concat_logits, labels)
            rank_loss = ranking_loss(top_n_proba, navigator_loss, proposal_num=proposal_num)
            partcls_loss = criterion(flat_part_logits, part_labels)

            loss = resnet_loss + concat_loss + rank_loss + partcls_loss
            loss.backward()

            for opt in optimizer_list:
                opt.step()

            # Only the concat head's loss/accuracy is reported.
            total_loss += concat_loss.item() * batch_size
            total_correct += (concat_logits.argmax(dim=1) == labels).sum().item()
            total_seen += batch_size

            pbar.set_postfix_str("Train loss: {:.4f}, Train accuracy: {:.4f}".format(
                total_loss / total_seen, total_correct / total_seen))
            pbar.update(1)
    return _mean_or_nan(total_loss, total_seen), _mean_or_nan(total_correct, total_seen)


def _evaluate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` in eval mode without gradients.

    Returns ``(mean concat loss, accuracy)`` averaged over samples.
    """
    # eval() stops BatchNorm from updating running stats and disables dropout.
    model.eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.no_grad(), tqdm(total=len(loader)) as pbar:
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            batch_size = inputs.size(0)

            _, concat_logits, _, _, _ = model(inputs)
            concat_loss = criterion(concat_logits, labels)

            total_loss += concat_loss.item() * batch_size
            total_correct += (concat_logits.argmax(dim=1) == labels).sum().item()
            total_seen += batch_size

            pbar.set_postfix_str("Val loss: {:.4f}, Val accuracy: {:.4f}".format(
                total_loss / total_seen, total_correct / total_seen))
            pbar.update(1)
    return _mean_or_nan(total_loss, total_seen), _mean_or_nan(total_correct, total_seen)


def run_sweep():
    """Train and validate one NTS-Net configuration proposed by the wandb sweep.

    Intended as the ``function`` for ``wandb.agent``. It starts a wandb run and
    reads the hyper-parameters from ``wandb.config``. It then trains for
    ``NUM_EPOCHS`` epochs and logs train/val loss and accuracy after each epoch.

    Returns:
        dict: per-epoch history with keys ``train_loss``, ``val_loss``,
        ``train_accuracy`` and ``val_accuracy`` (``wandb.agent`` ignores it).
    """
    wandb.init()
    config = wandb.config

    proposal_num = config.proposal_num
    optimizer = config.optimizer
    optim_params = _optimizer_kwargs(optimizer, config)

    # Dataloaders: only the training set needs shuffling.
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    model = NTSModel(top_n=proposal_num, cat_num=CAT_NUM, n_classes=len(dataset.classes),
                     image_height=settings.IMAGE_HEIGHT, image_width=settings.IMAGE_WIDTH).to(device)
    criterion = torch.nn.CrossEntropyLoss()

    # One optimizer per sub-network, all sharing the swept hyper-parameters.
    # They are built before the DataParallel wrap, which wraps this same module.
    optimizer_cls = optimizers[optimizer]
    optimizer_list = [
        optimizer_cls(list(sub_net.parameters()), **optim_params)
        for sub_net in (model.resnet, model.navigator, model.concat_net, model.partcls_net)
    ]

    model = nn.DataParallel(model)

    history = {
        "train_loss": [],
        "val_loss": [],
        "train_accuracy": [],
        "val_accuracy": [],
    }

    for epoch in range(NUM_EPOCHS):
        epoch_loss, epoch_accuracy = _train_one_epoch(
            model, train_loader, criterion, optimizer_list, proposal_num)
        epoch_val_loss, epoch_val_accuracy = _evaluate(model, val_loader, criterion)

        history["train_loss"].append(epoch_loss)
        history["val_loss"].append(epoch_val_loss)
        history["train_accuracy"].append(epoch_accuracy)
        history["val_accuracy"].append(epoch_val_accuracy)

        print(f"Epoch {epoch+1} - Loss: {epoch_loss:.4f} - Accuracy: {epoch_accuracy:.4f} - Val Loss: {epoch_val_loss:.4f} - Val Accuracy: {epoch_val_accuracy:.4f}")

        # Log to wandb; `val_acc` is the metric the sweep maximises.
        wandb.log({
            'epoch': epoch+1,
            'train_acc': epoch_accuracy,
            'train_loss': epoch_loss,
            'val_acc': epoch_val_accuracy,
            'val_loss': epoch_val_loss
        })

    return history

In [10]:
# Launch the sweep agent: it calls run_sweep once per trial, NUM_TRIALS times.
wandb.agent(sweep_id=sweep_id, function=run_sweep, count=NUM_TRIALS)

[34m[1mwandb[0m: Agent Starting Run: qohgnji5 with config:
[34m[1mwandb[0m: 	lr: 0.08201755247479438
[34m[1mwandb[0m: 	momentum: 0.17691607455637787
[34m[1mwandb[0m: 	optimizer: adagrad
[34m[1mwandb[0m: 	proposal_num: 6
[34m[1mwandb[0m: 	weight_decay: 0.0013594197983668064
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


100%|██████████| 112/112 [03:27<00:00,  1.85s/it, Train loss: 7.5114, Train accuracy: 0.0104]
100%|██████████| 55/55 [00:46<00:00,  1.18it/s, Val loss: 4.9264, Val accuracy: 0.0061]

Epoch 1 - Loss: 7.5114 - Accuracy: 0.0104 - Val Loss: 4.9264 - Val Accuracy: 0.0061





VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁
train_acc,▁
train_loss,▁
val_acc,▁
val_loss,▁

0,1
epoch,1.0
train_acc,0.01042
train_loss,7.51138
val_acc,0.00606
val_loss,4.92636


[34m[1mwandb[0m: Agent Starting Run: y7fn8n75 with config:
[34m[1mwandb[0m: 	lr: 0.03494514580033677
[34m[1mwandb[0m: 	momentum: 0.6676218140611027
[34m[1mwandb[0m: 	optimizer: adam
[34m[1mwandb[0m: 	proposal_num: 6
[34m[1mwandb[0m: 	weight_decay: 0.0005050581614010104
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


100%|██████████| 112/112 [03:29<00:00,  1.87s/it, Train loss: 6.0732, Train accuracy: 0.0074]
100%|██████████| 55/55 [00:47<00:00,  1.17it/s, Val loss: 5.4719, Val accuracy: 0.0030]

Epoch 1 - Loss: 6.0732 - Accuracy: 0.0074 - Val Loss: 5.4719 - Val Accuracy: 0.0030





0,1
epoch,▁
train_acc,▁
train_loss,▁
val_acc,▁
val_loss,▁

0,1
epoch,1.0
train_acc,0.00744
train_loss,6.07319
val_acc,0.00303
val_loss,5.47191
