In [1]:
import torchvision.models as models
import torch
import numpy as np
import pandas as pd

In [2]:
import pickle
with open("dataset.pkl", "rb") as f:
    dataset = pickle.load(f)

In [3]:
dataset["train"] = dataset["train"] / 255
dataset["train"] = dataset["train"].permute(0,3,1,2)
dataset["train"].shape

torch.Size([100000, 3, 64, 64])

In [4]:
data_path = "../../data/tiny-imagenet-200"

In [5]:
classes = pd.read_csv(f"{data_path}/words.txt", sep="\t", names=["id", "entity"])

In [6]:
id_to_label = {}
for _, row in classes.iterrows():
    id_to_label[row['id']] = row['entity']

In [7]:
## one hot encode labels and build dataset loader

In [8]:
labels = np.array(dataset['labels'])

In [9]:
label_to_idx = {}
for i, label in enumerate(np.unique(labels)):
    label_to_idx[label] = i

In [10]:
for i in range(len(labels)):
    labels[i] = label_to_idx[labels[i]]

In [11]:
labels = labels.astype(int)
labels = torch.from_numpy(labels)
labels_onehot = torch.nn.functional.one_hot(labels,).float()
raw_dataset = []
for i in range(len(labels_onehot)):
    raw_dataset.append((dataset["train"][i], labels_onehot[i]))

In [152]:
def compute_accs(preds, labels):
    k = 10
    _, ind = torch.topk(preds, k)
    
    accs = {}
    for _k in [1, 5, 10]:
        acc = ind[:,:_k].eq(labels).any(1).sum() / len(labels)
        accs[f"top_{_k}_acc"] = acc
    return accs

In [167]:
import os.path as osp
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from tqdm.notebook import tqdm as tqdm_nb
def evaluate(model, val_dl: DataLoader, verbose=True, nb=False):
    model.eval()
    _tqdm = tqdm_nb if nb else tqdm
    if verbose:
        pbar = _tqdm(enumerate(val_dl), total=len(val_dl), position=0, leave=True)
        pbar.set_description("Evaluation Progress")
    with torch.no_grad():
        batch_size = val_dl.batch_size
        avg_loss = defaultdict(int)
        num_train_imgs = len(val_dl.dataset)
        c = 0
        for batch_idx, data in enumerate(val_dl):
            c += 1

            imgs = data[0]
            labels = data[1]
            preds = model(imgs)
            preds = torch.nn.functional.softmax(preds, dim=1)

            loss = {
                "loss": loss_fnc(preds, labels),
                # "acc": 0
            }
            for k, v in compute_accs(preds, labels.argmax(1).view(-1,1)).items():
                loss[k] = v
            for k, v in loss.items():
                avg_loss[k] += v.item()
            if verbose:
                pbar.update()
        for k in avg_loss.keys():
            avg_loss[k] /= c
    model.train()
    if verbose:
        pbar.close()
    return avg_loss
def train(
    model,
    optim: torch.optim.Optimizer,
    epochs: int,
    train_dl: DataLoader,
    val_dl: DataLoader,
    start_epoch=0,
    save_freq=10,
    save_best=True,
    save_dir="./",
    prev_best_loss=np.inf,
    verbose=True,
    nb=False,
    train_cb=None,
):
    """
    train the vae model
    """
    loss_fnc = torch.nn.CrossEntropyLoss()
    _tqdm = tqdm_nb if nb else tqdm
    model.train()
    optimizer_idx = 0
    num_train_imgs = len(train_dl.dataset)
    batch_size = train_dl.batch_size
    if verbose:
        epoch_pbar = _tqdm(range(epochs), position=0, leave=True)
        epoch_pbar.set_description("Progress")
        pbar = _tqdm(enumerate(train_dl), total=len(train_dl), position=0, leave=True)
        pbar.set_description("Current Epoch Progress")
        epoch_pbar.update(start_epoch)
    for epoch in range(start_epoch, start_epoch + epochs):
        avg_loss = defaultdict(int)
        c = 0
        if verbose:
            epoch_pbar.update()
            pbar.reset()
        for batch_idx, data in enumerate(train_dl):

            c += 1
            
            imgs = data[0]
            labels = data[1]
            # imgs (B, C=1, W, H)
            optim.zero_grad()
            preds = model(imgs)
            preds = torch.nn.functional.softmax(preds, dim=1)
            loss = {
                "loss": loss_fnc(preds, labels),
                # "acc": 
            }
            for k, v in compute_accs(preds, labels.argmax(1).view(-1,1)).items():
                loss[k] = v
            for k, v in loss.items():
                avg_loss[k] += v.item()
            optimizer_idx += 1
            loss["loss"].backward()
            optim.step()
            if verbose:
                pbar.update()
        for k in avg_loss.keys():
            avg_loss[k] /= c
        eval_loss = evaluate(model, val_dl, nb=nb, verbose=False)
        if train_cb:
            train_cb(epoch=epoch, loss=avg_loss, eval_loss=eval_loss, model=model)
        save = False
        if eval_loss["loss"] < prev_best_loss:
            prev_best_loss = eval_loss["loss"]
            if save_best:
                save = True
        if epoch % save_freq == 0 or epoch == epochs - 1:
            save = True
        if save:
            torch.save(
                dict(
                    model_state_dict=model.state_dict(),
                    optim_State_dict=optim.state_dict(),
                    epoch=epoch,
                    loss=avg_loss,
                    eval_loss=eval_loss,
                    prev_best_loss=prev_best_loss,
                ),
                osp.join(save_dir, f"ckpt_{epoch}.pt"),
            )

In [168]:
from sklearn.utils import shuffle
raw_dataset = shuffle(raw_dataset, random_state=3407)
train_raw_dataset = raw_dataset[:60000]
val_raw_dataset = raw_dataset[60000:60000+20000]
test_raw_dataset = raw_dataset[60000+20000:]

In [169]:
len(train_raw_dataset)

60000

In [172]:
torch.manual_seed(3407)
batch_size=16
train_dl = torch.utils.data.DataLoader(train_raw_dataset[:2000], shuffle=True, batch_size=batch_size)
val_dl = torch.utils.data.DataLoader(val_raw_dataset[:2000], shuffle=True, batch_size=batch_size)
mobilenet_v2 = models.mobilenet_v2(pretrained=False, num_classes=200)
optim = torch.optim.Adam(mobilenet_v2.parameters(), lr=1e-3)

In [173]:
def train_cb(epoch, loss, eval_loss,model):
    print(f"Epoch={epoch}, loss={loss}, eval_loss={eval_loss}")
train(model=mobilenet_v2, optim=optim, epochs=10, train_dl=train_dl, val_dl=train_dl, train_cb=train_cb)

Current Epoch Progress:   0%|                                      | 0/125 [00:00<?, ?it/s]

Epoch=0, loss=defaultdict(<class 'int'>, {'loss': 5.29936605834961, 'top_1_acc': 0.004, 'top_5_acc': 0.0205, 'top_10_acc': 0.0475}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.295231540679931, 'top_1_acc': 0.012, 'top_5_acc': 0.034, 'top_10_acc': 0.0745})


Current Epoch Progress:   2%|▍                             | 2/125 [00:00<00:09, 13.63it/s]

Epoch=1, loss=defaultdict(<class 'int'>, {'loss': 5.2977480163574215, 'top_1_acc': 0.0055, 'top_5_acc': 0.0285, 'top_10_acc': 0.063}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.294805473327637, 'top_1_acc': 0.013, 'top_5_acc': 0.0485, 'top_10_acc': 0.0825})


Current Epoch Progress:   2%|▍                             | 2/125 [00:00<00:09, 13.62it/s]

Epoch=2, loss=defaultdict(<class 'int'>, {'loss': 5.298720592498779, 'top_1_acc': 0.006, 'top_5_acc': 0.0405, 'top_10_acc': 0.073}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.300784381866455, 'top_1_acc': 0.005, 'top_5_acc': 0.038, 'top_10_acc': 0.0785})


Current Epoch Progress:   2%|▍                             | 2/125 [00:00<00:09, 12.38it/s]

Epoch=3, loss=defaultdict(<class 'int'>, {'loss': 5.296588264465332, 'top_1_acc': 0.011, 'top_5_acc': 0.042, 'top_10_acc': 0.0755}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.297063140869141, 'top_1_acc': 0.009, 'top_5_acc': 0.0385, 'top_10_acc': 0.0795})


Current Epoch Progress:   2%|▍                             | 2/125 [00:00<00:09, 13.35it/s]

Epoch=4, loss=defaultdict(<class 'int'>, {'loss': 5.294265602111817, 'top_1_acc': 0.0125, 'top_5_acc': 0.0395, 'top_10_acc': 0.075}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.29630114364624, 'top_1_acc': 0.01, 'top_5_acc': 0.045, 'top_10_acc': 0.0775})


Current Epoch Progress:   2%|▍                             | 2/125 [00:00<00:09, 13.58it/s]

Epoch=5, loss=defaultdict(<class 'int'>, {'loss': 5.294894966125488, 'top_1_acc': 0.012, 'top_5_acc': 0.0385, 'top_10_acc': 0.0715}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.300703346252441, 'top_1_acc': 0.006, 'top_5_acc': 0.038, 'top_10_acc': 0.073})


Current Epoch Progress:   2%|▍                             | 2/125 [00:00<00:09, 13.15it/s]

Epoch=6, loss=defaultdict(<class 'int'>, {'loss': 5.293234920501709, 'top_1_acc': 0.0135, 'top_5_acc': 0.0425, 'top_10_acc': 0.074}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.296468486785889, 'top_1_acc': 0.009, 'top_5_acc': 0.04, 'top_10_acc': 0.072})


Current Epoch Progress:   2%|▍                             | 2/125 [00:00<00:09, 13.42it/s]

Epoch=7, loss=defaultdict(<class 'int'>, {'loss': 5.300375061035156, 'top_1_acc': 0.0065, 'top_5_acc': 0.036, 'top_10_acc': 0.0665}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.298682540893554, 'top_1_acc': 0.0075, 'top_5_acc': 0.035, 'top_10_acc': 0.066})


Current Epoch Progress:  70%|████████████████████▍        | 88/125 [00:06<00:03, 12.13it/s]

KeyboardInterrupt: 

In [165]:
for p in train_dl:
    a = mobilenet_v2(p[0])
    a = torch.nn.functional.softmax(a, dim=1)
    # k=5
    # ind = np.argpartition(a, -k, 1)[-k:]
    # # print(a.argmax(1), p[1].argmax(1))
    # print(ind.shape)
    k=5
    v, ind = torch.topk(a, k)
    print(ind[0])
    # print(p[1].argmax(1))
    print(p[1][0].argmax())
    print(ind.shape, p[1].argmax(1).shape)
    print(ind.eq(p[1].argmax(1).view(-1, 1)).any(1))
    
    # print(ind, a[0][ind])
    # print(a[0][ind[:,0]])
    break

tensor([148,   8,   2,   3,   4])
tensor(186)
torch.Size([16, 5]) torch.Size([16])
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False])


  a = torch.nn.functional.softmax(a)
