In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from torchmetrics.functional import accuracy
from common_utils import Accumulator


def topIndices(vec, x, largest):
    """Rank the entries of ``vec`` and split the ranking after ``x`` positions.

    Args:
        vec: tensor of scores; it is ranked with ``torch.argsort`` and the
            resulting index tensor is sliced along its first dimension.
        x: how many leading indices go into the first group.
        largest: ``True`` ranks from highest to lowest score, ``False`` from
            lowest to highest.

    Returns:
        ``(top_x_idxs, other_idxs)``: the first ``x`` indices of the ranking
        and all remaining indices, in ranking order.
    """
    ranking = torch.argsort(vec, descending=largest)
    return ranking[:x], ranking[x:]

def irreducibleLoss(data=None, target=None, global_index=None,
                    irreducible_loss_generator=None, target_device=None):
    """Return the per-example irreducible loss for a batch.

    ``irreducible_loss_generator`` selects one of two modes:

    * a tensor (or tensor subclass) of precomputed losses: it is moved to
      ``target_device`` when that is given and differs from its current
      device, then indexed with ``global_index``;
    * any other callable (e.g. a model): its output on ``data`` is scored
      against ``target`` with unreduced cross-entropy.

    Args:
        data: model inputs; only used in the callable mode.
        target: class targets; only used in the callable mode.
        global_index: indices into the precomputed loss tensor; only used in
            the tensor mode.
        irreducible_loss_generator: precomputed loss tensor or callable model.
        target_device: optional device the precomputed tensor should live on.

    Returns:
        Tensor with one loss value per selected example.

    Raises:
        TypeError: if ``irreducible_loss_generator`` is neither a tensor nor
            callable.
    """
    # isinstance instead of an exact type comparison, so tensor subclasses
    # such as nn.Parameter are used as lookup tables rather than being
    # called like a model.
    if isinstance(irreducible_loss_generator, torch.Tensor):
        loss_table = irreducible_loss_generator
        if target_device is not None and loss_table.device != target_device:
            loss_table = loss_table.to(device=target_device)
        return loss_table[global_index]

    if not callable(irreducible_loss_generator):
        raise TypeError(
            "irreducible_loss_generator must be a torch.Tensor or a callable model, "
            f"got {type(irreducible_loss_generator).__name__}"
        )

    logits = irreducible_loss_generator(data)
    return F.cross_entropy(logits, target, reduction="none")

class ReducibleLossSelection:
    """Pick the examples of a batch whose reducible loss is largest.

    Reducible loss is the large model's per-example cross-entropy minus the
    irreducible loss returned by ``irreducibleLoss`` for the same examples.
    """

    # Class-level flag; it is not read anywhere inside this class.
    bald = False

    def __call__(self, selected_batch_size, data=None, target=None, global_index=None,
                 large_model=None, irreducible_loss_generator=None):
        """Select ``selected_batch_size`` examples by descending reducible loss.

        Args:
            selected_batch_size: number of examples to keep.
            data: batch inputs fed to ``large_model``.
            target: class targets for ``data``.
            global_index: indices forwarded to ``irreducibleLoss``.
            large_model: model whose current loss is being compared.
            irreducible_loss_generator: forwarded to ``irreducibleLoss``.

        Returns:
            ``(top_x_idxs, selected_irreducible_loss)``: positions of the kept
            examples within the batch and their irreducible losses.
        """
        # Selection only ranks examples, so no gradients are tracked here.
        with torch.no_grad():
            per_example_loss = F.cross_entropy(large_model(data), target, reduction="none")
            baseline_loss = irreducibleLoss(
                data,
                target,
                global_index,
                irreducible_loss_generator,
                per_example_loss.device,
            )
            chosen, _ = topIndices(
                per_example_loss - baseline_loss, selected_batch_size, largest=True
            )
            chosen_baseline_loss = baseline_loss[chosen]

        return chosen, chosen_baseline_loss
    

def train_batch_rho_loss(self, batch, batch_idx, accumulator : Accumulator ):
    """Run one training step on the subset of ``batch`` chosen by ``self.selection_method``.

    Steps, as written:
      1. Unpack ``batch`` into ``(global_index, data, target)`` and derive the
         number of examples to keep from ``self.hparams.percent_train``
         (never fewer than 1).
      2. Ask ``self.selection_method`` which examples to keep, with the large
         model in train or eval mode depending on
         ``self.hparams.selection_train_mode``.
      3. If ``self.hparams.update_irreducible`` is set, take an optimizer step
         on ``self.irreducible_loss_generator``; then take one on
         ``self.large_model``. Both use only the selected examples.
      4. Log ``train_loss`` / ``train_acc`` and push statistics into
         ``accumulator``.

    Args:
        self: object providing ``hparams``, ``large_model``,
            ``selection_method``, ``irreducible_loss_generator``, ``loss``,
            ``optimizers()``, ``manual_backward()``, ``log()`` and
            ``global_step`` (presumably a Lightning module using manual
            optimization - confirm).
        batch: triple ``(global_index, data, target)``.
        batch_idx: not used in the body.
        accumulator: receives the ``average(...)`` and ``store(...)`` calls.

    Returns:
        None (there is no return statement).

    NOTE(review): as written this function cannot run to completion unless
    several names exist as globals; see the NOTE(review) comments in the body.
    """


    global_index, data, target = batch
    batch_size = len(data)
    selected_batch_size = max(1, int(batch_size * self.hparams.percent_train))

    if self.hparams.selection_train_mode:
        self.large_model.train()
    else:
        self.large_model.eval() # switch to eval mode to compute selection
    ### Selection Methods
    # NOTE(review): three values are unpacked here, but ReducibleLossSelection.__call__
    # defined earlier in this notebook returns only two (indices, selected
    # irreducible loss). Confirm which selection method is assigned to
    # self.selection_method; with ReducibleLossSelection this unpack raises ValueError.
    selected_indices, metrics_to_log, irreducible_loss = self.selection_method.__call__(
        selected_batch_size=selected_batch_size,
        data=data,
        target=target,
        global_index=global_index,
        large_model=self.large_model,
        irreducible_loss_generator=self.irreducible_loss_generator,
    )  # irreducible_loss will be None if the selection_method does not involve
    # irreducible_loss computation (e.g. uniform, CE loss selection)

    self.large_model.train()  # back to train mode for the gradient step below


    # From here on only the selected examples are used.
    data, target = data[selected_indices], target[selected_indices]


    # Note from Jan about the following if statement: the irreducible losses
    # are already computed in the selection_method. I did not change
    # anything in this if-statement, though, because I don't know if we even
    # use it.
    if self.hparams.update_irreducible:
        # Two optimizers are expected in this mode: large model first, then
        # the irreducible-loss model.
        opt_large_model, opt_irreducible_model = self.optimizers()

        opt_irreducible_model.zero_grad()
        logits = self.irreducible_loss_generator(data)
        # Overwrites the irreducible_loss returned by the selection step.
        irreducible_loss = self.loss(logits, target)
        self.manual_backward(irreducible_loss.mean())
        opt_irreducible_model.step()

        # logging
        preds = torch.argmax(F.log_softmax(logits, dim=1), dim=1)
        # NOTE(review): irreducible_acc is computed but never logged or used below.
        irreducible_acc = accuracy(preds, target)

    else:
        # Here self.optimizers() is used as a single optimizer object.
        opt_large_model = self.optimizers()


    # Gradient step for the large model on the selected examples.
    opt_large_model.zero_grad()
    logits = self.large_model(data)
    # .mean() is applied below, so self.loss presumably returns per-example losses - confirm.
    loss = self.loss(logits, target)
    self.manual_backward(loss.mean())
    opt_large_model.step()

    # training metrics
    preds = torch.argmax(F.log_softmax(logits, dim=1), dim=1)
    acc = accuracy(preds, target)

    

    self.log("train_loss", loss.mean(), on_step=True, on_epoch=True, logger=True)
    self.log("train_acc", acc, on_step=True, on_epoch=True, logger=True)

    # NOTE(review): batch_loss, n, batch_acc_sum and flag are never assigned in
    # this function or in any visible notebook cell; unless they exist as
    # globals this call raises NameError. They look like leftovers from another
    # training routine - confirm the intended values.
    accumulator.average( 
        train_loss = ( batch_loss, n) ,
        train_acc = ( batch_acc_sum, n) ,
        train_uniform_cnt = flag)
    
    # NOTE(review): max_p_i and num_unique_points are likewise never assigned here.
    accumulator.store( 
        max_p_i = max_p_i ,
        num_unique_points = num_unique_points)
    


    # NOTE(review): detailed_only_keys is read but not used afterwards; the
    # function ends after the next statement without logging metrics_to_log.
    detailed_only_keys = metrics_to_log["detailed_only_keys"]
    metrics_to_log["step"] = (self.global_step)  # add step to the logging, also might help us concretely cross-correlate the exact point in time.

        # batch statistics summary logging, depending on the metric that we ended up using.


In [None]:
def evaluate(model, dataloader, loss_fn):
    """Compute the mean loss and accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode and left in that mode. All logits and
    targets are gathered on the CPU, and the loss is computed once over the
    concatenated tensors.

    Args:
        model: network to evaluate. Inputs are sent to ``model.device`` when
            that attribute exists; otherwise to the device of the model's
            first parameter (CPU if it has no parameters).
        dataloader: iterable yielding ``(inputs, targets)`` batches.
        loss_fn: callable ``loss_fn(logits, targets)``; ``.mean()`` is taken
            of its result.

    Returns:
        ``(loss, acc)``: mean loss and fraction of correct argmax predictions,
        both as Python floats.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()

    # A plain nn.Module has no `.device` attribute, so fall back to wherever
    # its parameters live instead of requiring callers to patch one on.
    device = getattr(model, "device", None)
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    logits = []
    targets = []
    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            logits.append(model(X_batch.to(device)).cpu())
            # Keep targets on the CPU too, so they match the gathered logits.
            targets.append(y_batch.cpu())

    if not logits:
        raise ValueError("evaluate() received a dataloader that yielded no batches")

    logits = torch.cat(logits)
    targets = torch.cat(targets)
    loss = loss_fn(logits, targets).mean().item()
    acc = (logits.argmax(dim=1) == targets).sum().item() / len(targets)
    return loss, acc



In [None]:
from data_loaders import train_val_dataloader, test_dataloader

# One batch size for every loader, so train/validation/test batches stay comparable.
# NOTE(review): the trailing "128*3" presumably records another batch size that was considered - confirm.
BATCH_SIZE = 120  # 128*3

train_dataloader, val_dataloader = train_val_dataloader(batch_size=BATCH_SIZE)
test_loader = test_dataloader(batch_size=BATCH_SIZE)


In [None]:
# Load IPython's autoreload extension so edits to helper modules are picked up
# without restarting the kernel.
%load_ext autoreload

# Register common_utils for reloading; mode 1 reloads only modules registered
# through %aimport, each time before a cell is executed.
%aimport common_utils
%autoreload 1

In [None]:
%autoreload 1
from models import ResNet50
from common_utils import UnCallBack

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Build the network and place it on the chosen device.
model = ResNet50()
model.to(device)
# evaluate() reads `model.device` to decide where to send input batches.
model.device = device

# reduction="none" keeps one loss value per example; callers average it themselves.
loss_fn = nn.CrossEntropyLoss(reduction="none")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Names of the metrics handed to the callback as `info_list`.
callback = UnCallBack(
    info_list=[
        "train_loss",
        "train_acc",
        "train_w_loss",
        "val_loss",
        "val_acc",
        "train_uniform_cnt",
    ]
)

def eval_callback(model):
    """Evaluate ``model`` on the validation loader and package the metrics.

    Uses the notebook-level ``val_dataloader`` and ``loss_fn``.

    Returns:
        Dict with keys ``"val_loss"`` and ``"val_acc"``.
    """
    val_loss, val_acc = evaluate(model, val_dataloader, loss_fn)
    return dict(val_loss=val_loss, val_acc=val_acc)

# Run the training loop, then save what the callback collected.
# NOTE(review): train_full is not defined or imported in any cell visible in
# this notebook - confirm where it comes from before relying on Run All.
train_full(
    model,
    train_dataloader,
    loss_fn,
    optimizer,
    n_epochs=40,
    eval=eval_callback,
    callback=callback,
    presample=3,
    tau_th=None,
)
callback.save("callback")

In [None]:
# Optional: uncomment to restore a previously saved callback instead of re-running training.
#from common_utils import UnCallBack
#callback = UnCallBack.load("callback.pickle")
#callback

In [None]:

# Evaluate the trained network on the test loader and report both metrics.
loss, acc = evaluate(model, test_loader, loss_fn)
print(
    f'ResNet50, test loss: {loss}',
    f'ResNet50, test accuracy: {acc}',
    sep='\n',
)

In [None]:
# Loss curves recorded by the callback, one point per epoch (x axis is 1-based).
epochs = np.arange(len(callback.train_loss)) + 1

fig, ax = plt.subplots(figsize=(15, 10))
ax.set_title('ResNet50, train and validation loss')
ax.set_xlabel('Number of epoch')
ax.set_ylabel('Loss')

ax.plot(epochs, callback.train_loss, label='Train')
ax.plot(epochs, callback.train_w_loss, label='Train weighted')
ax.plot(epochs, callback.val_loss, label='Validation')

ax.legend()
ax.grid(True)


In [None]:
# Lowest value of each recorded loss and the 1-based epoch at which it occurred.
print(
    f'Best loss on train: {np.min(callback.train_loss)}, on {np.argmin(callback.train_loss) + 1} epoch',
    f'Best weighted loss on train: {np.min(callback.train_w_loss)}, on {np.argmin(callback.train_w_loss) + 1} epoch',
    f'Best loss on validation: {np.min(callback.val_loss)}, on {np.argmin(callback.val_loss) + 1} epoch',
    sep='\n',
)

In [None]:
# Accuracy curves recorded by the callback, one point per epoch (x axis is 1-based).
epochs = np.arange(len(callback.train_acc)) + 1

fig, ax = plt.subplots(figsize=(15, 10))
ax.set_title('ResNet50, train and validation accuracy')
ax.set_xlabel('Number of epoch')
ax.set_ylabel('Accuracy')

ax.plot(epochs, callback.train_acc, label='Train')
ax.plot(epochs, callback.val_acc, label='Validation')

ax.legend()
ax.grid(True)

In [None]:
# Highest recorded accuracy on each split and the 1-based epoch at which it occurred.
print(
    f'Best accuracy on train: {np.max(callback.train_acc)}, on {np.argmax(callback.train_acc) + 1} epoch',
    f'Best accuracy on validation: {np.max(callback.val_acc)}, on {np.argmax(callback.val_acc) + 1} epoch',
    sep='\n',
)