In [2]:
import torchvision.models as models
import torch
import numpy as np
import pandas as pd
device=torch.device("cuda")

In [2]:
import pickle
with open("dataset.pkl", "rb") as f:
    dataset = pickle.load(f)
dataset["train"] = dataset["train"] / 255
dataset["train"] = dataset["train"].permute(0,3,1,2)
dataset["train"].shape
data_path = "../../tiny-imagenet-200"
classes = pd.read_csv(f"{data_path}/words.txt", sep="\t", names=["id", "entity"])
id_to_label = {}
for _, row in classes.iterrows():
    id_to_label[row['id']] = row['entity']
labels = np.array(dataset['labels'])
label_to_idx = {}
for i, label in enumerate(np.unique(labels)):
    label_to_idx[label] = i
for i in range(len(labels)):
    labels[i] = label_to_idx[labels[i]]
labels = labels.astype(int)
labels = torch.from_numpy(labels)
raw_dataset = []
for i in range(len(labels)):
    raw_dataset.append((dataset["train"][i], labels[i]))

In [3]:
def compute_accs(preds, labels):
    k = 10
    _, ind = torch.topk(preds, k)
    
    accs = {}
    for _k in [1, 5, 10]:
        acc = ind[:,:_k].eq(labels).any(1).sum() / len(labels)
        accs[f"top_{_k}_acc"] = acc
    return accs

In [138]:
import os.path as osp
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from tqdm.notebook import tqdm as tqdm_nb
def evaluate(model, val_dl: DataLoader, verbose=True, nb=False):
    loss_fnc = torch.nn.CrossEntropyLoss()
    model.eval()
    _tqdm = tqdm_nb if nb else tqdm
    if verbose:
        pbar = _tqdm(enumerate(val_dl), total=len(val_dl), position=0, leave=True)
        pbar.set_description("Evaluation Progress")
    with torch.no_grad():
        batch_size = val_dl.batch_size
        avg_loss = defaultdict(int)
        num_train_imgs = len(val_dl.dataset)
        c = 0
        for batch_idx, data in enumerate(val_dl):
            c += 1

            imgs = data[0].to(device)
            labels = data[1].to(device)
            preds = model(imgs)
#             preds = torch.nn.functional.softmax(preds, dim=1)

            loss = {
                "loss": loss_fnc(preds, labels),
                # "acc": 0
            }
            for k, v in compute_accs(preds, labels.view(-1,1)).items():
                loss[k] = v
            for k, v in loss.items():
                avg_loss[k] += v.item()
            if verbose:
                pbar.update()
        for k in avg_loss.keys():
            avg_loss[k] /= c
    model.train()
    if verbose:
        pbar.close()
    return avg_loss
def train(
    model,
    optim: torch.optim.Optimizer,
    epochs: int,
    train_dl: DataLoader,
    val_dl: DataLoader,
    start_epoch=0,
    save_freq=10,
    save_best=True,
    save_dir="./",
    prev_best_loss=np.inf,
    verbose=True,
    nb=False,
    train_cb=None,
):
    """
    train the vae model
    """
    loss_fnc = torch.nn.CrossEntropyLoss()
    _tqdm = tqdm_nb if nb else tqdm
    model.train()
    optimizer_idx = 0
    num_train_imgs = len(train_dl.dataset)
    batch_size = train_dl.batch_size
    if verbose:
        epoch_pbar = _tqdm(range(epochs), position=0, leave=True)
        epoch_pbar.set_description("Progress")
        pbar = _tqdm(enumerate(train_dl), total=len(train_dl), position=0, leave=True)
        pbar.set_description("Current Epoch Progress")
        epoch_pbar.update(start_epoch)
    for epoch in range(start_epoch, start_epoch + epochs):
        avg_loss = defaultdict(int)
        c = 0
        if verbose:
            epoch_pbar.update()
            pbar.reset()
        for batch_idx, data in enumerate(train_dl):

            c += 1
            
            imgs = data[0].to(device)
            labels = data[1].to(device)
            # imgs (B, C=1, W, H)
            optim.zero_grad()
            preds = model(imgs)
#             preds = torch.nn.functional.softmax(preds, dim=1)
            loss = {
                "loss": loss_fnc(preds, labels),
                # "acc": 
            }
            with torch.no_grad():
                for k, v in compute_accs(preds, labels.view(-1,1)).items():
                    loss[k] = v
                for k, v in loss.items():
                    avg_loss[k] += v.item()
            optimizer_idx += 1
            loss["loss"].backward()
            optim.step()
            if verbose:
                pbar.update()
        for k in avg_loss.keys():
            avg_loss[k] /= c
        eval_loss = evaluate(model, val_dl, nb=nb, verbose=False)
        if train_cb:
            train_cb(epoch=epoch, loss=avg_loss, eval_loss=eval_loss, model=model)
        save = False
        if eval_loss["loss"] < prev_best_loss:
            prev_best_loss = eval_loss["loss"]
            if save_best:
                save = True
        if epoch % save_freq == 0 or epoch == epochs - 1:
            save = True
        torch.save(
            dict(
                epoch=epoch,
                loss=avg_loss,
                eval_loss=eval_loss,
                prev_best_loss=prev_best_loss,
            ),
            osp.join(save_dir, f"history/ckpt_{epoch}.pt"),
        )
        if save:
            torch.save(
                dict(
                    model_state_dict=model.state_dict(),
                    optim_State_dict=optim.state_dict(),
                    epoch=epoch,
                    loss=avg_loss,
                    eval_loss=eval_loss,
                    prev_best_loss=prev_best_loss,
                ),
                osp.join(save_dir, f"weights/ckpt_{epoch}.pt"),
            )

In [139]:
from sklearn.utils import shuffle
shuffled_raw_dataset = shuffle(raw_dataset, random_state=3407)
train_raw_dataset = shuffled_raw_dataset[:60000]
val_raw_dataset = shuffled_raw_dataset[60000:60000+20000]
test_raw_dataset = shuffled_raw_dataset[60000+20000:]

In [1]:
# CHANGE THIS
!mkdir resnet18_ckpts
!mkdir resnet18_ckpts/history
!mkdir resnet18_ckpts/weights

In [143]:
torch.manual_seed(340)
batch_size=256
train_dl = torch.utils.data.DataLoader(train_raw_dataset, shuffle=True, batch_size=batch_size)
val_dl = torch.utils.data.DataLoader(val_raw_dataset, shuffle=True, batch_size=batch_size)
# CHOOSE YOUR MODEL AND SAVE DIRECTORY
model = models.resnet18(pretrained=False, num_classes=200).to(device)
save_dir = "resnet18_ckpts"
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

In [144]:
def train_cb(epoch, loss, eval_loss,model):
    print(f"Epoch={epoch}, loss={loss}, eval_loss={eval_loss}")
train(model=model, optim=optim, epochs=10, train_dl=train_dl, val_dl=train_dl, train_cb=train_cb, save_dir=save_dir)

Current Epoch Progress: 100%|█████████▉| 234/235 [00:31<00:00,  7.48it/s]

Epoch=0, loss=defaultdict(<class 'int'>, {'loss': 4.54694629121334, 'top_1_acc': 0.07383089541120733, 'top_5_acc': 0.21560837765957447, 'top_10_acc': 0.31831781914893614}), eval_loss=defaultdict(<class 'int'>, {'loss': 4.722347981879052, 'top_1_acc': 0.06657247340425532, 'top_5_acc': 0.20848847519844135, 'top_10_acc': 0.31912677313419097})


Current Epoch Progress:  23%|██▎       | 54/235 [00:07<00:24,  7.39it/s] 

KeyboardInterrupt: 