In [2]:
!git clone https://github.com/Stanpie3/importance_sampling
!mv importance_sampling/* .
!rm -r importance_sampling

Cloning into 'importance_sampling'...
remote: Enumerating objects: 130, done.[K
remote: Counting objects: 100% (130/130), done.[K
remote: Compressing objects: 100% (94/94), done.[K
remote: Total 130 (delta 75), reused 69 (delta 32), pack-reused 0[K
Receiving objects: 100% (130/130), 810.39 KiB | 5.83 MiB/s, done.
Resolving deltas: 100% (75/75), done.


In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from common_utils import Accumulator
from torch_importance_sampling_tr import VarReductionCondition, get_g

In [4]:
def create_batches(p, N, k, b):
    all_points = np.arange(N)
    batches = []
    for j in range(k):
        batch = np.random.choice(all_points, size=b, replace=False, p=p)
        batches.append(batch)
        tmp = p[batch].sum()
        p[batch] = 0
        if j != k-1:
            p[p>0] += tmp / p[p>0].shape[0]
    return batches

def get_batches(loss, se, b):
    sorted_arg_loss = (-loss).argsort()
    N = loss.shape[0]
    p = np.zeros(N)
    for i in range(N):
        p[sorted_arg_loss[i]] = 1.0 / np.exp(np.log(se) / N)**(i+1)
    p /= p.sum()
    print(f'Minimal probability {p.min()}, maximal probability {p.max()}')
    return create_batches(p, N, N//b, b)

In [6]:
def train_batch_is(model,
                x_batch,
                y_batch,
                loss_fn,
                optimizer,
                accumulator):

    flag = False
    model.train()
    optimizer.zero_grad()

    batch_size = x_batch.shape[0]
    output = model(x_batch)
    loss = loss_fn(output, y_batch)
    loss = loss.mean()

    loss.backward()

    optimizer.step()

    n = len(output)
    with torch.no_grad():
        batch_loss = loss.mean().cpu().item()
        batch_acc_sum = (output.argmax(dim=1) == y_batch).sum().cpu().item()/n

    accumulator.average(
        train_loss = ( batch_loss, n),
        train_acc = ( batch_acc_sum, n),
        train_uniform_cnt = flag)

In [16]:
def train_full(model, train_dataloader, loss_fn, optimizer, n_epochs, eval = None, callback=None):
    epochs = tqdm(range(n_epochs), desc='Epochs', leave=True)
    X = torch.tensor(train_dataloader.dataset.data, dtype=torch.float32).transpose(1, -1)
    y = torch.tensor(train_dataloader.dataset.targets).long()
    batch_size = int(train_dataloader.batch_size)
    # n_batches = len(train_dataloader)

    if callback :
        callback.setMeta(
            large_batch = batch_size,
            n_epochs = n_epochs)
    n_batches = y.shape[0] // batch_size
    losses = torch.zeros(n_batches * batch_size)
    se_0, se_end = 10.0**2, 1.0
    se = se_0
    for i_epoch in epochs:
        accum = Accumulator()
        batch_indices = get_batches(np.abs(losses), se, batch_size)
        print(len(batch_indices), len(batch_indices[0]))
        for batch_idx in range(len(batch_indices)):
            train_batch_is( model,
                            X[batch_indices[batch_idx]].to(model.device),
                            y[batch_indices[batch_idx]].to(model.device),
                            loss_fn,
                            optimizer,
                            accum)
        model.eval()
        test_loss = 0
        correct = 0

        with torch.no_grad():
            losses = torch.zeros(n_batches * batch_size)
            for batch_idx in range(len(batch_indices)):
                output = model(X[batch_indices[batch_idx]].to(model.device))
                losses[batch_indices[batch_idx]] = loss_fn(output, y[batch_indices[batch_idx]].to(model.device)).cpu()
        se = se_0 * np.exp(np.log(se_end/se_0)/n_epochs) ** i_epoch
        print(losses)
        if callback :
            val_scores = eval(model) if eval else {}
            cb_dict = callback(**accum.getAll(), **val_scores)
            epochs.set_postfix(cb_dict)

In [17]:
def evaluate(model, dataloader, loss_fn):
    model.eval()
    logits = []
    targets = []
    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            output = model(X_batch.to(model.device)).cpu()
            logits.append(output)
            targets.append(y_batch)
    logits = torch.cat(logits)
    targets = torch.cat(targets)
    loss = loss_fn(logits, targets).mean().item()
    acc = (logits.argmax(dim=1) == targets).sum().item() / len(targets)
    return loss, acc

In [18]:
from data_loaders import train_val_dataloader, test_dataloader

train_dataloader, val_dataloader = train_val_dataloader(batch_size=64)
test_loader = test_dataloader(batch_size=64)


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [19]:
from models import ResNet50
from common_utils import UnCallBack

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = ResNet50()
model.to(device)
model.device = device

loss_fn = nn.CrossEntropyLoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3 )

callback = UnCallBack( info_list = ['train_loss', 'train_acc', 'train_w_loss', 'val_loss', 'val_acc', 'train_uniform_cnt'])

def eval_callback(model):
    loss, acc = evaluate(model, val_dataloader, loss_fn)
    return {"val_loss": loss, "val_acc": acc}

train_full(model, train_dataloader, loss_fn, optimizer, n_epochs=40, eval=eval_callback, callback=callback)
callback.save("callback")

cuda


Epochs:   0%|          | 0/40 [00:00<?, ?it/s]

torch.Size([50000, 3, 32, 32]) torch.Size([50000])
704
Minimal probability 9.306780869542679e-07, maximal probability 9.305923448407336e-05
781 64
torch.Size([49984])
tensor([0.2480, 0.3404, 0.5616,  ..., 0.5887, 0.6187, 0.6015])
Minimal probability 9.306780869542681e-07, maximal probability 9.305923448407337e-05
781 64
torch.Size([49984])
tensor([0.0385, 0.3548, 0.1676,  ..., 2.1717, 0.5843, 0.5424])
Minimal probability 1.0193872709318126e-06, maximal probability 9.084482523682283e-05
781 64
torch.Size([49984])
tensor([2.5582e-02, 3.5899e-04, 5.8519e-03,  ..., 9.2551e-01, 2.2986e-01,
        1.3537e+00])
Minimal probability 1.1159877789555003e-06, maximal probability 8.863830172651895e-05
781 64
torch.Size([49984])
tensor([3.3414e-02, 1.2802e-04, 1.2731e-02,  ..., 4.1161e-02, 1.1736e-02,
        3.1298e+00])
Minimal probability 1.2211056110374244e-06, maximal probability 8.644028995457406e-05
781 64
torch.Size([49984])
tensor([3.2140e-02, 3.4212e-05, 4.2000e-03,  ..., 4.7191e-01, 4.29

KeyboardInterrupt: 