In [35]:
import torchvision.models as models
import torch
import numpy as np
import pandas as pd
device=torch.device("cuda")

In [21]:
import pickle
with open("dataset.pkl", "rb") as f:
    dataset = pickle.load(f)
dataset["train"] = dataset["train"] / 255
dataset["train"] = dataset["train"].permute(0,3,1,2)
dataset["train"].shape
data_path = "../../tiny-imagenet-200"
classes = pd.read_csv(f"{data_path}/words.txt", sep="\t", names=["id", "entity"])
id_to_label = {}
for _, row in classes.iterrows():
    id_to_label[row['id']] = row['entity']
labels = np.array(dataset['labels'])
label_to_idx = {}
for i, label in enumerate(np.unique(labels)):
    label_to_idx[label] = i
for i in range(len(labels)):
    labels[i] = label_to_idx[labels[i]]
labels = labels.astype(int)
labels = torch.from_numpy(labels)
# labels_onehot = torch.nn.functional.one_hot(labels,)
raw_dataset = []
for i in range(len(labels_onehot)):
    raw_dataset.append((dataset["train"][i], labels[i]))

In [69]:
def compute_accs(preds, labels):
    k = 10
    _, ind = torch.topk(preds, k)
    
    accs = {}
    for _k in [1, 5, 10]:
        acc = ind[:,:_k].eq(labels).any(1).sum() / len(labels)
        accs[f"top_{_k}_acc"] = acc
    return accs

In [80]:
import os.path as osp
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from tqdm.notebook import tqdm as tqdm_nb
def evaluate(model, val_dl: DataLoader, verbose=True, nb=False):
    loss_fnc = torch.nn.CrossEntropyLoss()
    model.eval()
    _tqdm = tqdm_nb if nb else tqdm
    if verbose:
        pbar = _tqdm(enumerate(val_dl), total=len(val_dl), position=0, leave=True)
        pbar.set_description("Evaluation Progress")
    with torch.no_grad():
        batch_size = val_dl.batch_size
        avg_loss = defaultdict(int)
        num_train_imgs = len(val_dl.dataset)
        c = 0
        for batch_idx, data in enumerate(val_dl):
            c += 1

            imgs = data[0].to(device)
            labels = data[1].to(device)
            preds = model(imgs)
            preds = torch.nn.functional.softmax(preds, dim=1)

            loss = {
                "loss": loss_fnc(preds, labels),
                # "acc": 0
            }
            for k, v in compute_accs(preds, labels.view(-1,1)).items():
                loss[k] = v
            for k, v in loss.items():
                avg_loss[k] += v.item()
            if verbose:
                pbar.update()
        for k in avg_loss.keys():
            avg_loss[k] /= c
    model.train()
    if verbose:
        pbar.close()
    return avg_loss
def train(
    model,
    optim: torch.optim.Optimizer,
    epochs: int,
    train_dl: DataLoader,
    val_dl: DataLoader,
    start_epoch=0,
    save_freq=10,
    save_best=True,
    save_dir="./",
    prev_best_loss=np.inf,
    verbose=True,
    nb=False,
    train_cb=None,
):
    """
    train the vae model
    """
    loss_fnc = torch.nn.CrossEntropyLoss()
    _tqdm = tqdm_nb if nb else tqdm
    model.train()
    optimizer_idx = 0
    num_train_imgs = len(train_dl.dataset)
    batch_size = train_dl.batch_size
    if verbose:
        epoch_pbar = _tqdm(range(epochs), position=0, leave=True)
        epoch_pbar.set_description("Progress")
        pbar = _tqdm(enumerate(train_dl), total=len(train_dl), position=0, leave=True)
        pbar.set_description("Current Epoch Progress")
        epoch_pbar.update(start_epoch)
    for epoch in range(start_epoch, start_epoch + epochs):
        avg_loss = defaultdict(int)
        c = 0
        if verbose:
            epoch_pbar.update()
            pbar.reset()
        for batch_idx, data in enumerate(train_dl):

            c += 1
            
            imgs = data[0].to(device)
            labels = data[1].to(device)
            # imgs (B, C=1, W, H)
            optim.zero_grad()
            preds = model(imgs)
            preds = torch.nn.functional.softmax(preds, dim=1)
            loss = {
                "loss": loss_fnc(preds, labels),
                # "acc": 
            }
            with torch.no_grad():
                for k, v in compute_accs(preds, labels.view(-1,1)).items():
                    loss[k] = v
                for k, v in loss.items():
                    avg_loss[k] += v.item()
            optimizer_idx += 1
            loss["loss"].backward()
            optim.step()
            if verbose:
                pbar.update()
        for k in avg_loss.keys():
            avg_loss[k] /= c
        eval_loss = evaluate(model, val_dl, nb=nb, verbose=False)
        if train_cb:
            train_cb(epoch=epoch, loss=avg_loss, eval_loss=eval_loss, model=model)
        save = False
        if eval_loss["loss"] < prev_best_loss:
            prev_best_loss = eval_loss["loss"]
            if save_best:
                save = True
        if epoch % save_freq == 0 or epoch == epochs - 1:
            save = True
        if save:
            torch.save(
                dict(
                    model_state_dict=model.state_dict(),
                    optim_State_dict=optim.state_dict(),
                    epoch=epoch,
                    loss=avg_loss,
                    eval_loss=eval_loss,
                    prev_best_loss=prev_best_loss,
                ),
                osp.join(save_dir, f"ckpt_{epoch}.pt"),
            )

In [81]:
from sklearn.utils import shuffle
shuffled_raw_dataset = shuffle(raw_dataset, random_state=3407)
train_raw_dataset = shuffled_raw_dataset[:60000]
val_raw_dataset = shuffled_raw_dataset[60000:60000+20000]
test_raw_dataset = shuffled_raw_dataset[60000+20000:]

In [93]:
!mkdir mobilenet_ckpts

In [91]:
torch.manual_seed(3407)
batch_size=256
train_dl = torch.utils.data.DataLoader(train_raw_dataset, shuffle=True, batch_size=batch_size)
val_dl = torch.utils.data.DataLoader(val_raw_dataset, shuffle=True, batch_size=batch_size)
mobilenet_v2 = models.mobilenet_v2(pretrained=False, num_classes=200).to(device)
optim = torch.optim.Adam(mobilenet_v2.parameters(), lr=1e-3)

In [92]:
def train_cb(epoch, loss, eval_loss,model):
    print(f"Epoch={epoch}, loss={loss}, eval_loss={eval_loss}")
train(model=mobilenet_v2, optim=optim, epochs=10, train_dl=train_dl, val_dl=train_dl, train_cb=train_cb, save_dir="mobilenet_ckpts")

Current Epoch Progress:   0%|          | 0/8 [00:00<?, ?it/s]        

Epoch=0, loss=defaultdict(<class 'int'>, {'loss': 5.298348546028137, 'top_1_acc': 0.0054837740608491, 'top_5_acc': 0.023737980984151363, 'top_10_acc': 0.04627403896301985}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.2978644371032715, 'top_1_acc': 0.0080378606216982, 'top_5_acc': 0.037297175731509924, 'top_10_acc': 0.07019982021301985})


Current Epoch Progress:   0%|          | 0/8 [00:00<?, ?it/s]        

Epoch=1, loss=defaultdict(<class 'int'>, {'loss': 5.293198823928833, 'top_1_acc': 0.01911808899603784, 'top_5_acc': 0.0555889424867928, 'top_10_acc': 0.09592848550528288}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.297120749950409, 'top_1_acc': 0.0078125, 'top_5_acc': 0.04139122646301985, 'top_10_acc': 0.07726111821830273})


Current Epoch Progress: 100%|██████████| 8/8 [00:00<00:00, 15.76it/s]

Epoch=2, loss=defaultdict(<class 'int'>, {'loss': 5.281932175159454, 'top_1_acc': 0.022798978490754962, 'top_5_acc': 0.06666917074471712, 'top_10_acc': 0.10471754800528288}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.296472132205963, 'top_1_acc': 0.00837590149603784, 'top_5_acc': 0.03872445924207568, 'top_10_acc': 0.07707331795245409})


Current Epoch Progress:   0%|          | 0/8 [00:00<?, ?it/s]        

Epoch=3, loss=defaultdict(<class 'int'>, {'loss': 5.277269899845123, 'top_1_acc': 0.024752103490754962, 'top_5_acc': 0.06693209148943424, 'top_10_acc': 0.11151592619717121}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.296350419521332, 'top_1_acc': 0.0080378606216982, 'top_5_acc': 0.03910006023943424, 'top_10_acc': 0.07590895425528288})


Current Epoch Progress: 100%|██████████| 8/8 [00:00<00:00, 14.00it/s]

Epoch=4, loss=defaultdict(<class 'int'>, {'loss': 5.27487188577652, 'top_1_acc': 0.030461238231509924, 'top_5_acc': 0.072265625, 'top_10_acc': 0.11331881023943424}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.296162366867065, 'top_1_acc': 0.0090144231216982, 'top_5_acc': 0.0389873799867928, 'top_10_acc': 0.07635967619717121})


Current Epoch Progress:   0%|          | 0/8 [00:00<?, ?it/s]        

Epoch=5, loss=defaultdict(<class 'int'>, {'loss': 5.271430015563965, 'top_1_acc': 0.03278996399603784, 'top_5_acc': 0.07579627446830273, 'top_10_acc': 0.11992938723415136}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.296140015125275, 'top_1_acc': 0.0089017428108491, 'top_5_acc': 0.0389873799867928, 'top_10_acc': 0.07489483244717121})


Current Epoch Progress:   0%|          | 0/8 [00:00<?, ?it/s]        

Epoch=6, loss=defaultdict(<class 'int'>, {'loss': 5.272705435752869, 'top_1_acc': 0.03203876223415136, 'top_5_acc': 0.07681039720773697, 'top_10_acc': 0.12327223643660545}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.295967757701874, 'top_1_acc': 0.0089017428108491, 'top_5_acc': 0.0394756612367928, 'top_10_acc': 0.07786207925528288})


Current Epoch Progress:   0%|          | 0/8 [00:00<?, ?it/s]        

Epoch=7, loss=defaultdict(<class 'int'>, {'loss': 5.274469792842865, 'top_1_acc': 0.03425480774603784, 'top_5_acc': 0.07850060146301985, 'top_10_acc': 0.12041766848415136}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.294496297836304, 'top_1_acc': 0.011080228490754962, 'top_5_acc': 0.0414663462433964, 'top_10_acc': 0.07448167074471712})


Current Epoch Progress: 100%|██████████| 8/8 [00:00<00:00, 15.70it/s]

Epoch=8, loss=defaultdict(<class 'int'>, {'loss': 5.2656965255737305, 'top_1_acc': 0.04345703125, 'top_5_acc': 0.0847355779260397, 'top_10_acc': 0.11977914720773697}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.2867830991744995, 'top_1_acc': 0.017390324734151363, 'top_5_acc': 0.060396634973585606, 'top_10_acc': 0.09795673098415136})


Current Epoch Progress: 100%|██████████| 8/8 [00:00<00:00, 13.57it/s]

Epoch=9, loss=defaultdict(<class 'int'>, {'loss': 5.259801983833313, 'top_1_acc': 0.0463115987367928, 'top_5_acc': 0.08680138271301985, 'top_10_acc': 0.12105619069188833}), eval_loss=defaultdict(<class 'int'>, {'loss': 5.261766076087952, 'top_1_acc': 0.043644831981509924, 'top_5_acc': 0.0852989787235856, 'top_10_acc': 0.12075570970773697})


Current Epoch Progress: 100%|██████████| 8/8 [00:11<00:00, 13.57it/s]

In [165]:
for p in train_dl:
    a = mobilenet_v2(p[0])
    a = torch.nn.functional.softmax(a, dim=1)
    # k=5
    # ind = np.argpartition(a, -k, 1)[-k:]
    # # print(a.argmax(1), p[1].argmax(1))
    # print(ind.shape)
    k=5
    v, ind = torch.topk(a, k)
    print(ind[0])
    # print(p[1].argmax(1))
    print(p[1][0].argmax())
    print(ind.shape, p[1].argmax(1).shape)
    print(ind.eq(p[1].argmax(1).view(-1, 1)).any(1))
    
    # print(ind, a[0][ind])
    # print(a[0][ind[:,0]])
    break

tensor([148,   8,   2,   3,   4])
tensor(186)
torch.Size([16, 5]) torch.Size([16])
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False])


  a = torch.nn.functional.softmax(a)
